diff --git a/CMakeLists.txt b/CMakeLists.txt index 477359afd..a413631e4 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -68,7 +68,7 @@ message(STATUS "CMAKE_BUILD_TYPE = ${CMAKE_BUILD_TYPE}") ## Set the standard required compile flags # Nov 18th --- removed -DHAVE_CONFIG_H set(REMOVE_WARNING_FLAGS "-Wno-unused-function;-Wno-unused-local-typedefs") -set(TGT_COMPILE_FLAGS "-ftree-vectorize;-funroll-loops;-fPIC;-fomit-frame-pointer;-O3;-DNDEBUG;-DSTX_NO_STD_STRING_VIEW") +set(TGT_COMPILE_FLAGS "-ftree-vectorize;-funroll-loops;-fPIC;-fomit-frame-pointer;-O3;-DNDEBUG;-DSTX_NO_STD_STRING_VIEW;-D__STDC_FORMAT_MACROS") set(TGT_WARN_FLAGS "-Wall;-Wno-unknown-pragmas;-Wno-reorder;-Wno-unused-variable;-Wreturn-type;-Werror=return-type;${REMOVE_WARNING_FLAGS}") #set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -fsanitize=address") #set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fsanitize=address") @@ -596,10 +596,15 @@ if (NOT CEREAL_FOUND) endif() ## Try and find TBB first -find_package(TBB 2018.0 COMPONENTS tbb tbbmalloc tbbmalloc_proxy) +find_package(TBB 2019.0 COMPONENTS tbb tbbmalloc tbbmalloc_proxy) +## NOTE: we actually require at least 2019 U4 or greater +## since we are using tbb::global_control. However, they +## seem not to have tagged minor version numbers in their +## source. Check before release if we can bump to the 2020 +## version (requires having tbb 2020 for OSX). if (${TBB_FOUND}) - if (${TBB_VERSION} VERSION_GREATER_EQUAL 2018.0) + if (${TBB_VERSION} VERSION_GREATER_EQUAL 2019.0) message("FOUND SUITABLE TBB VERSION : ${TBB_VERSION}") set(TBB_TARGET_EXISTED TRUE) else() @@ -627,7 +632,7 @@ endif() message("Build system will fetch and build Intel Threading Building Blocks") message("==================================================================") # These are useful for the custom install step we'll do later -set(TBB_SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/external/oneTBB-2019_U8) +set(TBB_SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/external/oneTBB-2020.2) set(TBB_INSTALL_DIR ${CMAKE_CURRENT_SOURCE_DIR}/external/install) if("${TBB_COMPILER}" STREQUAL "gcc") @@ -640,10 +645,10 @@ set(TBB_CXXFLAGS "${TBB_CXXFLAGS} ${CXXSTDFLAG}") externalproject_add(libtbb DOWNLOAD_DIR ${CMAKE_CURRENT_SOURCE_DIR}/external - DOWNLOAD_COMMAND curl -k -L https://github.com/intel/tbb/archive/2019_U8.tar.gz -o tbb-2019_U8.tgz && - ${SHASUM} 6b540118cbc79f9cbc06a35033c18156c21b84ab7b6cf56d773b168ad2b68566 tbb-2019_U8.tgz && - tar -xzvf tbb-2019_U8.tgz - SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/external/oneTBB-2019_U8 + DOWNLOAD_COMMAND curl -k -L https://github.com/oneapi-src/oneTBB/archive/v2020.2.tar.gz -o tbb-2020_U2.tgz && + ${SHASUM} 4804320e1e6cbe3a5421997b52199e3c1a3829b2ecb6489641da4b8e32faf500 tbb-2020_U2.tgz && + tar -xzvf tbb-2020_U2.tgz + SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/external/oneTBB-2020.2 INSTALL_DIR ${CMAKE_CURRENT_SOURCE_DIR}/external/install PATCH_COMMAND "${TBB_PATCH_STEP}" CONFIGURE_COMMAND "" @@ -652,9 +657,6 @@ externalproject_add(libtbb BUILD_IN_SOURCE 1 ) - - - set(RECONFIG_FLAGS ${RECONFIG_FLAGS} -DTBB_WILL_RECONFIGURE=FALSE -DTBB_RECONFIGURE=TRUE) externalproject_add_step(libtbb reconfigure COMMAND ${CMAKE_COMMAND} ${CMAKE_CURRENT_SOURCE_DIR} ${RECONFIG_FLAGS} @@ -721,14 +723,22 @@ message("TBB_LIBRARIES = ${TBB_LIBRARIES}") #message("TBB_LIBRARY_DIRS ${TBB_LIBRARY_DIRS}") #message("TBB_LIBRARIES ${TBB_LIBRARIES} ") -find_package(libgff) -if(NOT LIBGFF_FOUND) +find_package(libgff 2.0.0 +HINTS ${LIB_GFF_PATH} ${GFF_ROOT} +) +if(libgff_FOUND) + message(STATUS "libgff ver. ${LIB_GFF_VERSION} found.") + message(STATUS " include: ${LIB_GFF_INCLUDE_DIR}") + message(STATUS " lib : ${LIB_GFF_LIBRARY_DIR}") +endif() + +if(NOT libgff_FOUND) message("Build system will compile libgff") message("==================================================================") externalproject_add(libgff DOWNLOAD_DIR ${CMAKE_CURRENT_SOURCE_DIR}/external - DOWNLOAD_COMMAND curl -k -L https://github.com/COMBINE-lab/libgff/archive/v1.2.tar.gz -o libgff.tgz && - ${SHASUM} bfabf143da828e8db251104341b934458c19d3e3c592d418d228de42becf98eb libgff.tgz && + DOWNLOAD_COMMAND curl -k -L https://github.com/COMBINE-lab/libgff/archive/v2.0.0.tar.gz -o libgff.tgz && + ${SHASUM} 7656b19459a7ca7d2fd0fcec4f2e0fd0deec1b4f39c703a114e8f4c22d82a99c libgff.tgz && tar -xzvf libgff.tgz ## #URL https://github.com/COMBINE-lab/libgff/archive/v1.1.tar.gz @@ -736,11 +746,11 @@ if(NOT LIBGFF_FOUND) #URL_HASH SHA1=37b3147d78391d1fabbe6a0df313fbf516abbc6f #TLS_VERIFY FALSE ## - SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/external/libgff-1.2 + SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/external/libgff-2.0.0 #UPDATE_COMMAND sh -c "mkdir -p /build" INSTALL_DIR ${CMAKE_CURRENT_SOURCE_DIR}/external/install - BINARY_DIR ${CMAKE_CURRENT_SOURCE_DIR}/external/libgff-1.2/build - CMAKE_ARGS -DCMAKE_INSTALL_PREFIX:PATH=${CMAKE_CURRENT_SOURCE_DIR}/external/install + BINARY_DIR ${CMAKE_CURRENT_SOURCE_DIR}/external/libgff-2.0.0/build + CMAKE_ARGS -DCMAKE_INSTALL_PREFIX:PATH= -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER} -DCMAKE_C_COMPILER=${CMAKE_C_COMPILER} ) externalproject_add_step(libgff makedir COMMAND mkdir -p /build @@ -748,6 +758,7 @@ if(NOT LIBGFF_FOUND) DEPENDEES download DEPENDERS configure) set(FETCHED_GFF TRUE) + set(LIB_GFF_PATH ${CMAKE_CURRENT_SOURCE_DIR}/external/install) endif() # Because of the way that Apple has changed SIP diff --git a/README.md b/README.md index 0f4326f05..bbebb8cf9 100644 --- a/README.md +++ b/README.md @@ -36,8 +36,6 @@ Give salmon a try! You can find the latest binary releases [here](https://githu The current version number of the master branch of Salmon can be found [**here**](http://combine-lab.github.io/salmon/version_info/latest) -**NOTE**: Salmon works by (quasi)-mapping sequencing reads directly to the *transcriptome*. This means the Salmon index should be built on a set of target transcripts, **not** on the *genome* of the underlying organism. If indexing appears to be taking a very long time, or using a tremendous amount of memory (which it should not), please ensure that you are not attempting to build an index on the genome of your organism! - Documentation ============== diff --git a/current_version.txt b/current_version.txt index 8e8883241..f3da0dca9 100644 --- a/current_version.txt +++ b/current_version.txt @@ -1,3 +1,3 @@ VERSION_MAJOR 1 -VERSION_MINOR 2 -VERSION_PATCH 1 +VERSION_MINOR 3 +VERSION_PATCH 0 diff --git a/doc/source/conf.py b/doc/source/conf.py index 0ff2ead23..d3bbf85dd 100644 --- a/doc/source/conf.py +++ b/doc/source/conf.py @@ -55,9 +55,9 @@ # built documents. # # The short X.Y version. -version = '1.2' +version = '1.3' # The full version, including alpha/beta/rc tags. -release = '1.2.1' +release = '1.3.0' # The language for content autogenerated by Sphinx. Refer to documentation # for a list of supported languages. diff --git a/docker/Dockerfile b/docker/Dockerfile index 8888731f4..4021bd5a2 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -6,7 +6,7 @@ MAINTAINER salmon.maintainer@gmail.com ENV PACKAGES git gcc make g++ libboost-all-dev liblzma-dev libbz2-dev \ ca-certificates zlib1g-dev libcurl4-openssl-dev curl unzip autoconf apt-transport-https ca-certificates gnupg software-properties-common wget -ENV SALMON_VERSION 1.2.1 +ENV SALMON_VERSION 1.3.0 # salmon binary will be installed in /home/salmon/bin/salmon diff --git a/docker/build_test.sh b/docker/build_test.sh index 9a001d209..5bbe6dc76 100644 --- a/docker/build_test.sh +++ b/docker/build_test.sh @@ -1,3 +1,3 @@ #! /bin/bash -SALMON_VERSION=1.2.1 +SALMON_VERSION=1.3.0 docker build --no-cache -t combinelab/salmon:${SALMON_VERSION} -t combinelab/salmon:latest . diff --git a/include/AlevinOpts.hpp b/include/AlevinOpts.hpp index cf08e041f..10855b1bb 100644 --- a/include/AlevinOpts.hpp +++ b/include/AlevinOpts.hpp @@ -40,6 +40,9 @@ struct AlevinOpts { bool dumpBFH; // dump per cell level umi-graph bool dumpUmiGraph; + // dump per cell level de-duplicated + // equivalence class + bool dumpCellEq; //Stop progress sumps bool quiet; //flag for deduplication @@ -80,6 +83,8 @@ struct AlevinOpts { uint32_t maxNumBarcodes; // number of bootstraps to perform uint32_t numBootstraps; + // number of gibbs samples to perform + uint32_t numGibbsSamples; // force the number of cells uint32_t forceCells; // define a close upper bound on expected number of cells diff --git a/include/AlignmentLibrary.hpp b/include/AlignmentLibrary.hpp index 218e7d895..1cd89f913 100644 --- a/include/AlignmentLibrary.hpp +++ b/include/AlignmentLibrary.hpp @@ -27,9 +27,10 @@ extern "C" { #include "SalmonOpts.hpp" #include "SalmonUtils.hpp" #include "SimplePosBias.hpp" -#include "SpinLock.hpp" // RapMap's with try_lock +#include "SpinLock.hpp" // From pufferfish, with try_lock #include "Transcript.hpp" #include "concurrentqueue.h" +#include "parallel_hashmap/phmap.h" // Boost includes #include @@ -38,6 +39,7 @@ extern "C" { #include #include #include +#include template class NullFragmentFilter; @@ -109,8 +111,48 @@ template class AlignmentLibrary { // Figure out aligner information from the header if we can aligner_ = salmon::bam_utils::inferAlignerFromHeader(header); + // in this case check for decoys and make a list of their names + phmap::flat_hash_set decoys; + if (aligner_ == salmon::bam_utils::AlignerDetails::PUFFERFISH) { + // for each reference + for (decltype(header->nref) i = 0; i < header->nref; ++i) { + // for each tag + SAM_hdr_tag *tag; + for (tag = header->ref[i].tag; tag; tag = tag->next) { + // if this tag marks it as a decoy + if ((tag->len == 4) and (std::strncmp(tag->str, "DS:D", 4) == 0)) { + decoys.insert(header->ref[i].name); + break; + } // end if decoy tag + + } // end for each tag + } // end for each referecne + } + + if (!decoys.empty()) { + bq->forceEndParsing(); + bq.reset(); + salmonOpts.jointLog->error( + "Salmon is being run in alignment-mode with a SAM/BAM file that contains decoy\n" + "sequences (marked as such during salmon indexing). This SAM/BAM file had {}\n" + "such sequences tagged in the header. Since alignments to decoys are not\n" + "intended for decoy-level quantification, this functionality is not currently\n" + "supported. If you wish to run salmon with this SAM/BAM file, please \n" + "filter out / remove decoy transcripts (those tagged with `DS:D`) from the \n" + "header, and all SAM/BAM records that represent alignments to decoys \n" + "(those tagged with `XT:A:D`). If you believe you are receiving this message\n" + "in error, please report this issue on GitHub.", decoys.size()); + salmonOpts.jointLog->flush(); + std::stringstream ss; + ss << "\nCannot quantify from SAM/BAM file containing decoy transcripts or alignment records!\n"; + throw std::runtime_error(ss.str()); + } + // The transcript file existed, so load up the transcripts double alpha = 0.005; + // we know how many we will have, so reserve the space for + // them. + transcripts_.reserve(header->nref); for (decltype(header->nref) i = 0; i < header->nref; ++i) { transcripts_.emplace_back(i, header->ref[i].name, header->ref[i].len, alpha); diff --git a/include/AtomicMatrix.hpp b/include/AtomicMatrix.hpp index ed998ab00..0212da8de 100644 --- a/include/AtomicMatrix.hpp +++ b/include/AtomicMatrix.hpp @@ -1,42 +1,53 @@ #ifndef ATOMIC_MATRIX #define ATOMIC_MATRIX -#include "tbb/atomic.h" #include "tbb/concurrent_vector.h" #include "SalmonMath.hpp" +#include "SalmonUtils.hpp" #include template class AtomicMatrix { public: + AtomicMatrix() { + nRow_ = 0; + nCol_ = 0; + alpha_ = salmon::math::LOG_0; + logSpace_ = true; + } + AtomicMatrix(size_t nRow, size_t nCol, T alpha, bool logSpace = true) - : storage_(nRow * nCol, logSpace ? std::log(alpha) : alpha), - rowsums_(nRow, logSpace ? std::log(nCol * alpha) : nCol * alpha), - nRow_(nRow), nCol_(nCol), alpha_(alpha), logSpace_(logSpace) {} + : nRow_(nRow), nCol_(nCol), alpha_(alpha), logSpace_(logSpace) { + + decltype(storage_) storage_tmp(nRow * nCol); + std::swap(storage_, storage_tmp); + T e = logSpace ? std::log(alpha) : alpha; + std::fill(storage_.begin(), storage_.end(), e); + + decltype(rowsums_) rowsums_tmp(nRow); + std::swap(rowsums_, rowsums_tmp); + T ers = logSpace ? std::log(nCol * alpha) : nCol * alpha; + std::fill(rowsums_.begin(), rowsums_.end(), ers); + } + + AtomicMatrix& operator=(AtomicMatrix&& o) { + std::swap(storage_, o.storage_); + std::swap(rowsums_, o.rowsums_); + nRow_ = o.nRow_; + nCol_ = o.nCol_; + alpha_ = o.alpha_; + logSpace_ = o.logSpace_; + return *this; + } void incrementUnnormalized(size_t rowInd, size_t colInd, T amt) { using salmon::math::logAdd; size_t k = rowInd * nCol_ + colInd; if (logSpace_) { - T oldVal = storage_[k]; - T retVal = oldVal; - T newVal = logAdd(oldVal, amt); - do { - oldVal = retVal; - newVal = logAdd(oldVal, amt); - retVal = storage_[k].compare_and_swap(newVal, oldVal); - } while (retVal != oldVal); - + salmon::utils::incLoopLog(storage_[k], amt); } else { - T oldVal = storage_[k]; - T retVal = oldVal; - T newVal = oldVal + amt; - do { - oldVal = retVal; - newVal = oldVal + amt; - retVal = storage_[k].compare_and_swap(newVal, oldVal); - } while (retVal != oldVal); + salmon::utils::incLoop(storage_[k], amt); } } @@ -55,41 +66,11 @@ template class AtomicMatrix { using salmon::math::logAdd; size_t k = rowInd * nCol_ + colInd; if (logSpace_) { - T oldVal = storage_[k]; - T retVal = oldVal; - T newVal = logAdd(oldVal, amt); - do { - oldVal = retVal; - newVal = logAdd(oldVal, amt); - retVal = storage_[k].compare_and_swap(newVal, oldVal); - } while (retVal != oldVal); - - oldVal = rowsums_[rowInd]; - retVal = oldVal; - newVal = logAdd(oldVal, amt); - do { - oldVal = retVal; - newVal = logAdd(oldVal, amt); - retVal = rowsums_[rowInd].compare_and_swap(newVal, oldVal); - } while (retVal != oldVal); + salmon::utils::incLoopLog(storage_[k], amt); + salmon::utils::incLoopLog(rowsums_[rowInd], amt); } else { - T oldVal = storage_[k]; - T retVal = oldVal; - T newVal = oldVal + amt; - do { - oldVal = retVal; - newVal = oldVal + amt; - retVal = storage_[k].compare_and_swap(newVal, oldVal); - } while (retVal != oldVal); - - oldVal = rowsums_[rowInd]; - retVal = oldVal; - newVal = oldVal + amt; - do { - oldVal = retVal; - newVal = oldVal + amt; - retVal = rowsums_[rowInd].compare_and_swap(newVal, oldVal); - } while (retVal != oldVal); + salmon::utils::incLoop(storage_[k], amt); + salmon::utils::incLoop(rowsums_[rowInd], amt); } } @@ -106,8 +87,8 @@ template class AtomicMatrix { size_t nCol() const { return nCol_; } private: - std::vector> storage_; - std::vector> rowsums_; + std::vector> storage_; + std::vector> rowsums_; size_t nRow_, nCol_; T alpha_; bool logSpace_; diff --git a/include/BAMQueue.hpp b/include/BAMQueue.hpp index 9576b234e..d5df05c7d 100644 --- a/include/BAMQueue.hpp +++ b/include/BAMQueue.hpp @@ -10,7 +10,6 @@ #include #include #include -#include #include #include diff --git a/include/BAMQueue.tpp b/include/BAMQueue.tpp index 8efb4552d..76d8e5e72 100644 --- a/include/BAMQueue.tpp +++ b/include/BAMQueue.tpp @@ -116,7 +116,7 @@ void BAMQueue::reset() { template BAMQueue::~BAMQueue() { fmt::print(stderr, "\nFreeing memory used by read queue . . . "); - parsingThread_->join(); + if (parsingThread_) { parsingThread_->join(); } fmt::print(stderr, "\nJoined parsing thread . . . "); for (auto& file : files_) { diff --git a/include/CollapsedCellOptimizer.hpp b/include/CollapsedCellOptimizer.hpp index 4ae0c08a4..2aeaefd46 100644 --- a/include/CollapsedCellOptimizer.hpp +++ b/include/CollapsedCellOptimizer.hpp @@ -6,9 +6,6 @@ #include #include -#include "tbb/atomic.h" -#include "tbb/task_scheduler_init.h" - #include #include "ReadExperiment.hpp" @@ -45,7 +42,7 @@ struct CellState { class CollapsedCellOptimizer { public: - using VecType = std::vector>; + using VecType = std::vector>; using SerialVecType = std::vector; CollapsedCellOptimizer(); @@ -77,11 +74,13 @@ void optimizeCell(std::vector& trueBarcodes, std::vector& umiCount, std::vector& skippedCB, bool verbose, GZipWriter& gzw, bool noEM, bool useVBEM, - bool quiet, tbb::atomic& totalDedupCounts, - tbb::atomic& totalExpGeneCounts, double priorWeight, + bool quiet, std::atomic& totalDedupCounts, + std::atomic& totalExpGeneCounts, double priorWeight, spp::sparse_hash_map& txpToGeneMap, - uint32_t numGenes, uint32_t umiLength, uint32_t numBootstraps, - bool naiveEqclass, bool dumpUmiGraph, bool useAllBootstraps, + uint32_t numGenes, uint32_t umiLength, + uint32_t numBootstraps, uint32_t numGibbsSamples, + bool naiveEqclass, bool dumpUmiGraph, + bool dumpCellEq, bool useAllBootstraps, bool initUniform, CFreqMapT& freqCounter, bool dumpArboFragCounts, spp::sparse_hash_set& mRnaGenes, spp::sparse_hash_set& rRnaGenes, diff --git a/include/CollapsedEMOptimizer.hpp b/include/CollapsedEMOptimizer.hpp index f2b406e8c..397ae4bc4 100644 --- a/include/CollapsedEMOptimizer.hpp +++ b/include/CollapsedEMOptimizer.hpp @@ -1,12 +1,10 @@ #ifndef COLLAPSED_EM_OPTIMIZER_HPP #define COLLAPSED_EM_OPTIMIZER_HPP +#include #include #include -#include "tbb/atomic.h" -#include "tbb/task_scheduler_init.h" - #include "ReadExperiment.hpp" #include "SalmonOpts.hpp" @@ -17,7 +15,7 @@ class BootstrapWriter; class CollapsedEMOptimizer { public: - using VecType = std::vector>; + using VecType = std::vector>; using SerialVecType = std::vector; CollapsedEMOptimizer(); diff --git a/include/CollapsedGibbsSampler.hpp b/include/CollapsedGibbsSampler.hpp index db8e8cba5..c5196dfde 100644 --- a/include/CollapsedGibbsSampler.hpp +++ b/include/CollapsedGibbsSampler.hpp @@ -3,10 +3,6 @@ #include #include - -#include "tbb/atomic.h" -#include "tbb/task_scheduler_init.h" - #include "SalmonOpts.hpp" #include "Eigen/Dense" diff --git a/include/DistributionUtils.hpp b/include/DistributionUtils.hpp index 99d266d3e..ee0a02e5f 100644 --- a/include/DistributionUtils.hpp +++ b/include/DistributionUtils.hpp @@ -12,13 +12,23 @@ class Transcript; namespace distribution_utils { enum class DistributionSpace : uint8_t { LOG = 0, LINEAR = 1 }; +class DistSummary { +public: +DistSummary(double mean_in, double sd_in, uint32_t support) : + mean(mean_in), sd(sd_in), samples(support,0) {} + +double mean; +double sd; +std::vector samples; +}; + /** * Draw samples from the provided fragment length distribution. * \param fld A pointer to the FragmentLengthDistribution from which * samples will be drawn. * \param numSamples The number of samples to draw. */ -std::vector samplesFromLogPMF(FragmentLengthDistribution* fld, +DistSummary samplesFromLogPMF(FragmentLengthDistribution* fld, int32_t numSamples); /** @@ -86,6 +96,58 @@ class LogCMFCache { std::vector cachedCMF_; }; +template +class VersionedValue { + public: + T val{T()}; + uint64_t gen{0}; +}; + +// A simple cache where values are retreived by their index +// and the cached items are versioned +template +class IndexedVersionedCache { + public: + + IndexedVersionedCache(size_t max_index) : + cache_(max_index+1), max_index_(max_index), current_gen_(0) {} + + inline void increment_generation() { ++current_gen_; } + + inline bool get_value(size_t index, T& v) { + size_t idx = (index > max_index_) ? max_index_ : index; + const VersionedValue& vv = cache_[idx]; + bool is_stale = vv.gen < current_gen_; + v = vv.val; + return is_stale; + } + + inline void update_value(size_t index, T v) { + size_t idx = (index > max_index_) ? max_index_ : index; + VersionedValue& vv = cache_[idx]; + vv.val = v; + vv.gen = current_gen_; + } + + template + inline T get_or_update(size_t index, F& gen_value) { + size_t idx = (index > max_index_) ? max_index_ : index; + VersionedValue& vv = cache_[idx]; + // if the current value is stale, compute a new one and cache it + if (vv.gen < current_gen_) { + vv.val = gen_value(idx); + vv.gen = current_gen_; + } + // return the (possibly newly) cached value + return vv.val; + } + + private: + std::vector> cache_; + size_t max_index_{0}; + uint64_t current_gen_{0}; +}; + } // namespace distribution_utils #endif // __DISTRIBUTION_UTILS__ diff --git a/include/EMUtils.hpp b/include/EMUtils.hpp index 3d2d543f4..0a4eada73 100644 --- a/include/EMUtils.hpp +++ b/include/EMUtils.hpp @@ -2,7 +2,7 @@ #define EM_UTILS_HPP #include -#include "tbb/atomic.h" +#include #include "Transcript.hpp" template diff --git a/include/FragmentLengthDistribution.hpp b/include/FragmentLengthDistribution.hpp index e9ec6ff53..84fbbd59e 100644 --- a/include/FragmentLengthDistribution.hpp +++ b/include/FragmentLengthDistribution.hpp @@ -10,7 +10,6 @@ #ifndef FRAGMENT_LENGTH_DISTRIBUTION #define FRAGMENT_LENGTH_DISTRIBUTION -#include "tbb/atomic.h" #include #include #include @@ -33,7 +32,7 @@ class FragmentLengthDistribution { /** * A private vector that stores the observed (logged) mass for each length. */ - std::vector> hist_; + std::vector> hist_; /** * A private vector that stores the observed (logged) mass for each length. @@ -47,12 +46,12 @@ class FragmentLengthDistribution { /** * A private double that stores the total observed (logged) mass. */ - tbb::atomic totMass_; + std::atomic totMass_; /** * A private double that stores the (logged) sum of the product of observed * lengths and masses for quick mean calculations. */ - tbb::atomic sum_; + std::atomic sum_; /** * A private int that stores the minimum observed length. */ diff --git a/include/FragmentStartPositionDistribution.hpp b/include/FragmentStartPositionDistribution.hpp index c400096f8..4d9dd6f1f 100644 --- a/include/FragmentStartPositionDistribution.hpp +++ b/include/FragmentStartPositionDistribution.hpp @@ -9,7 +9,7 @@ #ifndef FRAGMENT_START_POSITION_DISTRIBUTION #define FRAGMENT_START_POSITION_DISTRIBUTION -#include "tbb/atomic.h" +// #include "tbb/atomic.h" #include #include #include @@ -26,12 +26,12 @@ class FragmentStartPositionDistribution { /** * A private vector that stores the observed (logged) mass for each length. */ - std::vector> pmf_; - std::vector> cmf_; + std::vector> pmf_; + std::vector> cmf_; /** * A private double that stores the total observed (logged) mass. */ - tbb::atomic totMass_; + std::atomic totMass_; /** * A private double that stores the (logged) sum of the product of observed * lengths and masses for quick mean calculations. diff --git a/include/GZipWriter.hpp b/include/GZipWriter.hpp index 94629b71a..8ea847aca 100644 --- a/include/GZipWriter.hpp +++ b/include/GZipWriter.hpp @@ -92,6 +92,11 @@ class GZipWriter { bool writeCellEQVec(size_t barcode, const std::vector& offsets, const std::vector& counts, bool quiet = true); + bool writeDedupCellEQVec(size_t barcode, + const std::vector>& labels, + const std::vector& counts, + bool quiet = true); + bool writeUmiGraph(alevin::graph::Graph& g, std::string& trueBarcodeStr); bool setSamplingPath(const SalmonOpts& sopt); @@ -112,6 +117,7 @@ class GZipWriter { std::unique_ptr tierMatrixStream_{nullptr}; std::unique_ptr umiGraphStream_{nullptr}; std::unique_ptr cellEQStream_{nullptr}; + std::unique_ptr cellDedupEQStream_{nullptr}; std::unique_ptr arboMatrixStream_{nullptr}; std::unique_ptr bcNameStream_{nullptr}; std::unique_ptr bcFeaturesStream_{nullptr}; diff --git a/include/SBModel.hpp b/include/SBModel.hpp index 8ee10c1bf..772956dce 100644 --- a/include/SBModel.hpp +++ b/include/SBModel.hpp @@ -4,13 +4,14 @@ #include #include "UtilityFunctions.hpp" -#include "jellyfish/mer_dna.hpp" +//#include "jellyfish/mer_dna.hpp" //#include "rapmap/Kmer.hpp" +#include "pufferfish/Kmer.hpp" #include #include -using Mer = jellyfish::mer_dna_ns::mer_base_static; -//using Mer = combinelib::kmers::Kmer<32,2>; +//using Mer = jellyfish::mer_dna_ns::mer_base_static; +using SBMer = combinelib::kmers::Kmer<32,4>; class SBModel { public: @@ -31,13 +32,13 @@ class SBModel { } bool addSequence(const char* seqIn, bool revCmp, double weight = 1.0); - bool addSequence(const Mer& mer, double weight); + bool addSequence(const SBMer& sbmer, double weight); Eigen::MatrixXd& counts(); Eigen::MatrixXd& marginals(); double evaluateLog(const char* seqIn); - double evaluateLog(const Mer& mer); + double evaluateLog(const SBMer& sbmer); bool normalize(); @@ -88,10 +89,12 @@ class SBModel { Eigen::MatrixXd _probs; Eigen::MatrixXd _marginals; - Mer _mer; + //Mer _mer; + SBMer _sbmer; std::vector _order; std::vector _shifts; std::vector _widths; + constexpr static const double _prior_prob = 1e-10; }; #endif //__SB_MODEL_HPP__ diff --git a/include/SalmonConfig.hpp b/include/SalmonConfig.hpp index 27ea70b15..2ead95ae9 100644 --- a/include/SalmonConfig.hpp +++ b/include/SalmonConfig.hpp @@ -26,9 +26,9 @@ namespace salmon { constexpr char majorVersion[] = "1"; -constexpr char minorVersion[] = "2"; -constexpr char patchVersion[] = "1"; -constexpr char version[] = "1.2.1"; +constexpr char minorVersion[] = "3"; +constexpr char patchVersion[] = "0"; +constexpr char version[] = "1.3.0"; constexpr uint32_t indexVersion = 5; constexpr char requiredQuasiIndexVersion[] = "p7"; } // namespace salmon diff --git a/include/SalmonDefaults.hpp b/include/SalmonDefaults.hpp index dff3d66e5..1b4b079ca 100644 --- a/include/SalmonDefaults.hpp +++ b/include/SalmonDefaults.hpp @@ -24,13 +24,16 @@ namespace defaults { constexpr const bool disableSA{false}; constexpr const float consensusSlack{0.0}; constexpr const double minScoreFraction{0.65}; + constexpr const double pre_merge_chain_sub_thresh{0.75}; + constexpr const double post_merge_chain_sub_thresh{0.9}; + constexpr const double orphan_chain_sub_thresh{0.95}; constexpr const double scoreExp{1.0}; constexpr const int16_t matchScore{2}; constexpr const int16_t mismatchPenalty{-4}; - constexpr const int16_t gapOpenPenalty{4}; + constexpr const int16_t gapOpenPenalty{6}; constexpr const int16_t gapExtendPenalty{2}; constexpr const int32_t dpBandwidth{15}; - constexpr const uint32_t mismatchSeedSkip{5}; + constexpr const uint32_t mismatchSeedSkip{3}; constexpr const bool disableChainingHeuristic{false}; constexpr const bool eqClassMode{false}; constexpr const bool hardFilter{false}; @@ -143,11 +146,13 @@ namespace defaults { constexpr const bool dumpFeatures{false}; constexpr const bool dumpBFH{false}; constexpr const bool dumpUmiGraph{false}; + constexpr const bool dumpCellEq{false}; constexpr const bool dumpMtx{false}; constexpr const bool noEM{false}; constexpr const bool debug{true}; constexpr const uint32_t trimRight{0}; constexpr const uint32_t numBootstraps{0}; + constexpr const uint32_t numGibbsSamples{0}; constexpr const uint32_t lowRegionMinNumBarcodes{200}; constexpr const uint32_t maxNumBarcodes{100000}; constexpr const uint32_t expectCells{0}; diff --git a/include/SalmonMappingUtils.hpp b/include/SalmonMappingUtils.hpp index e1d9b3d66..8f4c8c261 100644 --- a/include/SalmonMappingUtils.hpp +++ b/include/SalmonMappingUtils.hpp @@ -56,23 +56,77 @@ #include "pufferfish/ksw2pp/KSW2Aligner.hpp" #include "pufferfish/metro/metrohash64.h" #include "pufferfish/SelectiveAlignmentUtils.hpp" +#include "pufferfish/chobo/small_vector.hpp" +#include "parallel_hashmap/phmap.h" namespace salmon { namespace mapping_utils { using MateStatus = pufferfish::util::MateStatus; + constexpr const int32_t invalid_score_ = std::numeric_limits::min(); + constexpr const int32_t invalid_index_ = std::numeric_limits::min(); + constexpr const size_t static_vec_size = 32; + + class MappingScoreInfo { + public: + MappingScoreInfo() + : bestScore(invalid_score_), secondBestScore(invalid_score_), + bestDecoyScore(invalid_score_), decoyThresh(1.0), collect_decoy_info_(false) {} + + MappingScoreInfo(double decoyThreshIn) : MappingScoreInfo() { + decoyThresh = decoyThreshIn; + } + + void collect_decoys(bool do_collect) { collect_decoy_info_ = do_collect; } + bool collect_decoys() const { return collect_decoy_info_; } + + // clear everything but the decoy threshold + void clear(size_t num_hits) { + bestScore = invalid_score_; + secondBestScore = invalid_score_; + bestDecoyScore = invalid_score_; + scores_.clear(); scores_.resize(num_hits, invalid_score_); + bestScorePerTranscript_.clear(); + perm_.clear(); + // NOTE: we do _not_ reset decoyThresh here + if (collect_decoy_info_) { + bool revert_to_static = best_decoy_hits.size() > static_vec_size; + best_decoy_hits.clear(); + if (revert_to_static) { best_decoy_hits.revert_to_static(); } + } + } + + bool haveOnlyDecoyMappings() const { + // if the best non-decoy mapping has score less than decoyThresh * + // bestDecoyScore and if the bestDecoyScore is a valid value, then we + // have no valid non-decoy mappings. + return (bestScore < static_cast(decoyThresh * bestDecoyScore)) and + (bestDecoyScore > std::numeric_limits::min()); + } + + inline bool update_decoy_mappings(int32_t hitScore, size_t idx, uint32_t tid) { + const bool better_score = hitScore > bestDecoyScore; + if (hitScore > bestDecoyScore) { + bestDecoyScore = hitScore; + if (collect_decoy_info_) { + best_decoy_hits.clear(); + best_decoy_hits.push_back(std::make_pair(static_cast(idx), static_cast(tid))); + } + } else if (collect_decoy_info_ and (hitScore == bestDecoyScore)){ + best_decoy_hits.push_back(std::make_pair(static_cast(idx), static_cast(tid))); + } + return better_score; + } - struct MappingScoreInfo { int32_t bestScore; int32_t secondBestScore; int32_t bestDecoyScore; double decoyThresh; - bool haveOnlyDecoyMappings() const { - // if the best non-decoy mapping has score less than decoyThresh * bestDecoyScore - // and if the bestDecoyScore is a valid value, then we have no valid non-decoy mappings. - return (bestScore < static_cast(decoyThresh * bestDecoyScore)) and - (bestDecoyScore > std::numeric_limits::min()); - } + chobo::small_vector> best_decoy_hits; + bool collect_decoy_info_; + std::vector scores_; + phmap::flat_hash_map> bestScorePerTranscript_; + std::vector> perm_; }; template @@ -83,7 +137,8 @@ inline bool initMapperSettings(SalmonOpts& salmonOpts, MemCollector& mem memCollector.setConsensusFraction(consensusFraction); memCollector.setHitFilterPolicy(salmonOpts.hitFilterPolicy); memCollector.setAltSkip(salmonOpts.mismatchSeedSkip); - + memCollector.setChainSubOptThresh(salmonOpts.pre_merge_chain_sub_thresh); + //Initialize ksw aligner ksw2pp::KSW2Config config; config.dropoff = -1; @@ -101,6 +156,9 @@ inline bool initMapperSettings(SalmonOpts& salmonOpts, MemCollector& mem aconf.refExtendLength = 20; aconf.fullAlignment = salmonOpts.fullLengthAlignment; + aconf.mismatchPenalty = salmonOpts.mismatchPenalty; + aconf.bestStrata = false; + aconf.decoyPresent = false; aconf.matchScore = salmonOpts.matchScore; aconf.gapExtendPenalty = salmonOpts.gapExtendPenalty; aconf.gapOpenPenalty = salmonOpts.gapOpenPenalty; @@ -129,6 +187,9 @@ inline bool initMapperSettings(SalmonOpts& salmonOpts, MemCollector& mem mpol.noDovetail = !salmonOpts.allowDovetail; aconf.noDovetail = mpol.noDovetail; + mpol.setPostMergeChainSubThresh(salmonOpts.post_merge_chain_sub_thresh); + mpol.setOrphanChainSubThresh(salmonOpts.orphan_chain_sub_thresh); + return true; } @@ -136,15 +197,8 @@ inline bool initMapperSettings(SalmonOpts& salmonOpts, MemCollector& mem inline void updateRefMappings(uint32_t tid, int32_t hitScore, bool isCompat, size_t idx, const std::vector& transcripts, int32_t invalidScore, - salmon::mapping_utils::MappingScoreInfo& msi, - /* - int32_t& bestScore, - int32_t& secondBestScore, - int32_t& bestDecoyScore, - */ - std::vector& scores, - phmap::flat_hash_map>& bestScorePerTranscript, - std::vector>& perm) { + salmon::mapping_utils::MappingScoreInfo& msi) { + auto& scores = msi.scores_; scores[idx] = hitScore; auto& t = transcripts[tid]; bool isDecoy = t.isDecoy(); @@ -153,10 +207,13 @@ inline void updateRefMappings(uint32_t tid, int32_t hitScore, bool isCompat, siz //if (hitScore < decoyCutoff or (hitScore == invalidScore)) { } if (isDecoy) { - // if this is a decoy and its score is at least as good - // as the bestDecoyScore, then update that and go to the next mapping - msi.bestDecoyScore = std::max(hitScore, msi.bestDecoyScore); - return; + // NOTE: decide here if we need to process any of this if the + // current score is < the best (non-decoy) score. I think not. + + // if this is a decoy and its score is better than the best decoy score + bool did_update = msi.update_decoy_mappings(hitScore, idx, tid); + (void)did_update; + return; } else if (hitScore < decoyCutoff or (hitScore == invalidScore)) { // if the current score is to a valid target but doesn't // exceed the necessary decoy threshold, then skip it. @@ -164,6 +221,8 @@ inline void updateRefMappings(uint32_t tid, int32_t hitScore, bool isCompat, siz } // otherwise, we have a "high-scoring" hit to a non-decoy + auto& perm = msi.perm_; + auto& bestScorePerTranscript = msi.bestScorePerTranscript_; // removing duplicate hits from a read to the same transcript auto it = bestScorePerTranscript.find(tid); if (it == bestScorePerTranscript.end()) { @@ -187,15 +246,12 @@ inline void updateRefMappings(uint32_t tid, int32_t hitScore, bool isCompat, siz msi.secondBestScore = msi.bestScore; msi.bestScore = hitScore; } - //bestScore = (hitScore > bestScore) ? hitScore : bestScore; perm.push_back(std::make_pair(idx, tid)); } inline void filterAndCollectAlignments( std::vector& jointHits, - const std::vector& scores, - std::vector>& perm, uint32_t readLen, uint32_t mateLen, bool singleEnd, @@ -203,13 +259,7 @@ inline void filterAndCollectAlignments( bool hardFilter, double scoreExp, double minAlnProb, - // intentionally passing by value below --- come back - // and make sure it's necessary - salmon::mapping_utils::MappingScoreInfo msi, - /*int32_t bestScore, - int32_t secondBestScore, - int32_t bestDecoyScore, - */ + salmon::mapping_utils::MappingScoreInfo& msi, std::vector& jointAlignments) { auto invalidScore = std::numeric_limits::min(); @@ -218,6 +268,8 @@ inline void filterAndCollectAlignments( int32_t decoyThreshold = static_cast(msi.decoyThresh * msi.bestDecoyScore); //auto filterScore = (bestDecoyScore < secondBestScore) ? secondBestScore : bestDecoyScore; + auto& scores = msi.scores_; + auto& perm = msi.perm_; // throw away any pairs for which we should not produce valid alignments : // ====== @@ -300,6 +352,90 @@ inline void filterAndCollectAlignments( // done moving our alinged / score jointMEMs over to QuasiAlignment objects } + +inline void filterAndCollectAlignmentsDecoy( + std::vector& jointHits, + uint32_t readLen, + uint32_t mateLen, + bool singleEnd, + bool tryAlign, + bool hardFilter, + double scoreExp, + double minAlnProb, + salmon::mapping_utils::MappingScoreInfo& msi, + std::vector& jointAlignments) { +// NOTE: this function should only be called in the case that we have valid decoy mappings to report. +// Currently, this happens only when there are *no valid non-decoy* mappings. +// Further, this function will only add equally *best* decoy mappings to the output jointAlignments object +// regardless of the the status of hardFilter (i.e. no sub-optimal decoy mappings will be reported). +(void) hardFilter; +(void) minAlnProb; +(void) scoreExp; +double estAlnProb = 1.0; //std::exp(-scoreExp * 0.0); +for (auto& idxTxp : msi.best_decoy_hits) { + int32_t ctr = idxTxp.first; + int32_t tid = idxTxp.second; + auto& jointHit = jointHits[ctr]; + + if (singleEnd or jointHit.isOrphan()) { + readLen = jointHit.isLeftAvailable() ? readLen : mateLen; + jointAlignments.emplace_back( + tid, // reference id + jointHit.orphanClust()->getTrFirstHitPos(), // reference pos + jointHit.orphanClust()->isFw, // fwd direction + readLen, // read length + jointHit.orphanClust()->cigar, // cigar string + jointHit.fragmentLen, // fragment length + false); + auto& qaln = jointAlignments.back(); + // NOTE : score should not be filled in from a double + qaln.score = !tryAlign + ? static_cast(jointHit.orphanClust()->coverage) + : jointHit.alignmentScore; + qaln.estAlnProb(estAlnProb); + // NOTE : wth is numHits? + qaln.numHits = + static_cast(jointHits.size()); // orphanClust()->coverage; + qaln.mateStatus = jointHit.mateStatus; + if (singleEnd) { + qaln.mateLen = readLen; + qaln.mateCigar.clear(); + qaln.matePos = 0; + qaln.mateIsFwd = true; + qaln.mateScore = 0; + qaln.mateStatus = MateStatus::SINGLE_END; + } + } else { + jointAlignments.emplace_back( + tid, // reference id + jointHit.leftClust->getTrFirstHitPos(), // reference pos + jointHit.leftClust->isFw, // fwd direction + readLen, // read length + jointHit.leftClust->cigar, // cigar string + jointHit.fragmentLen, // fragment length + true); // properly paired + // Fill in the mate info + auto& qaln = jointAlignments.back(); + qaln.mateLen = mateLen; + qaln.mateCigar = jointHit.rightClust->cigar; + qaln.matePos = + static_cast(jointHit.rightClust->getTrFirstHitPos()); + qaln.mateIsFwd = jointHit.rightClust->isFw; + qaln.mateStatus = MateStatus::PAIRED_END_PAIRED; + // NOTE : wth is numHits? + qaln.numHits = static_cast(jointHits.size()); + // NOTE : score should not be filled in from a double + qaln.score = !tryAlign ? static_cast(jointHit.leftClust->coverage) + : jointHit.alignmentScore; + qaln.estAlnProb(estAlnProb); + qaln.mateScore = !tryAlign + ? static_cast(jointHit.rightClust->coverage) + : jointHit.mateAlignmentScore; + } +} // end for over best decoy hits +} + + } // namespace mapping_utils } // namespace salmon diff --git a/include/SalmonOpts.hpp b/include/SalmonOpts.hpp index 73121635e..8b2473dd7 100644 --- a/include/SalmonOpts.hpp +++ b/include/SalmonOpts.hpp @@ -260,6 +260,9 @@ struct SalmonOpts { bool validateMappings; bool disableSA{false}; // this cannot be done right now. float consensusSlack; + double pre_merge_chain_sub_thresh; + double post_merge_chain_sub_thresh; + double orphan_chain_sub_thresh; bool disableChainingHeuristic; bool disableAlignmentCache; double minScoreFraction; diff --git a/include/SalmonUtils.hpp b/include/SalmonUtils.hpp index 6ed0c54cd..04a452532 100644 --- a/include/SalmonUtils.hpp +++ b/include/SalmonUtils.hpp @@ -159,15 +159,12 @@ updateEffectiveLengths(SalmonOpts& sopt, ReadExpT& readExp, * val + inc (*in log-space*). Update occurs in a loop in case other * threads update in the meantime. */ -inline void incLoopLog(tbb::atomic& val, double inc) { +inline void incLoopLog(std::atomic& val, double inc) { double oldMass = val.load(); - double returnedMass = oldMass; - double newMass{salmon::math::LOG_0}; + double newMass; do { - oldMass = returnedMass; newMass = salmon::math::logAdd(oldMass, inc); - returnedMass = val.compare_and_swap(newMass, oldMass); - } while (returnedMass != oldMass); + } while (! val.compare_exchange_strong(oldMass, newMass)); } /* @@ -180,15 +177,13 @@ inline void incLoop(double& val, double inc) { val += inc; } * val + inc. Update occurs in a loop in case other * threads update in the meantime. */ -inline void incLoop(tbb::atomic& val, double inc) { + +inline void incLoop(std::atomic& val, double inc) { double oldMass = val.load(); - double returnedMass = oldMass; - double newMass{oldMass + inc}; + double newMass; do { - oldMass = returnedMass; newMass = oldMass + inc; - returnedMass = val.compare_and_swap(newMass, oldMass); - } while (returnedMass != oldMass); + } while (!val.compare_exchange_strong(oldMass, newMass)); } std::string getCurrentTimeAsString(); diff --git a/include/Sampler.hpp b/include/Sampler.hpp index 1bbf88f53..2518d7bc2 100644 --- a/include/Sampler.hpp +++ b/include/Sampler.hpp @@ -20,7 +20,6 @@ extern "C" { #include #include #include -#include #include #include #include diff --git a/include/SingleCellProtocols.hpp b/include/SingleCellProtocols.hpp index 53cba31ce..cb46ac37f 100644 --- a/include/SingleCellProtocols.hpp +++ b/include/SingleCellProtocols.hpp @@ -46,7 +46,7 @@ namespace alevin{ }; struct CITESeq : Rule{ - CITESeq(): Rule(16, 12, BarcodeEnd::FIVE, 4294967295){ + CITESeq(): Rule(16, 10, BarcodeEnd::FIVE, 4294967295){ featureLength = 15; featureStart = 10; } diff --git a/include/Transcript.hpp b/include/Transcript.hpp index 4edf6a79d..ac02aa5b9 100644 --- a/include/Transcript.hpp +++ b/include/Transcript.hpp @@ -7,7 +7,6 @@ #include "SalmonStringUtils.hpp" #include "SalmonUtils.hpp" #include "SequenceBiasModel.hpp" -#include "tbb/atomic.h" #include "stx/string_view.hpp" #include "IOUtils.hpp" #include @@ -480,6 +479,8 @@ class Transcript { } } + bool have_sequence() const { if (Sequence_) { return true; } else { return false; } } + const char* Sequence() const { return Sequence_.get(); } uint8_t* SAMSequence() const { return const_cast(SAMSequence_.data()); } @@ -675,10 +676,10 @@ class Transcript { std::atomic uniqueCount_; std::atomic totalCount_; double priorMass_; - tbb::atomic mass_; - tbb::atomic sharedCount_; - tbb::atomic cachedEffectiveLength_; - tbb::atomic avgMassBias_; + std::atomic mass_; + std::atomic sharedCount_; + std::atomic cachedEffectiveLength_; + std::atomic avgMassBias_; uint32_t lengthClassIndex_; double logPerBasePrior_; // In a paired-end protocol, a transcript has diff --git a/include/VersionChecker.hpp b/include/VersionChecker.hpp index 0ae49039f..d999472c0 100644 --- a/include/VersionChecker.hpp +++ b/include/VersionChecker.hpp @@ -11,41 +11,10 @@ #ifndef VERSION_CHECKER_HPP #define VERSION_CHECKER_HPP -#include -#include -#include #include -#include -#include #include #include -using boost::asio::ip::tcp; - -class VersionChecker { -public: - VersionChecker(boost::asio::io_service& io_service, const std::string& server, - const std::string& path); - std::string message(); - -private: - void cancel_upgrade_check(const boost::system::error_code& err); - void handle_resolve(const boost::system::error_code& err, - tcp::resolver::iterator endpoint_iterator); - void handle_connect(const boost::system::error_code& err); - void handle_write_request(const boost::system::error_code& err); - void handle_read_status_line(const boost::system::error_code& err); - void handle_read_headers(const boost::system::error_code& err); - void handle_read_content(const boost::system::error_code& err); - - tcp::resolver resolver_; - tcp::socket socket_; - boost::asio::streambuf request_; - boost::asio::streambuf response_; - boost::asio::deadline_timer deadline_; - std::stringstream messageStream_; -}; - std::string getVersionMessage(); #endif // VERSION_CHECKER_HPP \ No newline at end of file diff --git a/include/httplib.hpp b/include/httplib.hpp new file mode 100644 index 000000000..af4532c50 --- /dev/null +++ b/include/httplib.hpp @@ -0,0 +1,5124 @@ +// +// httplib.h +// +// Copyright (c) 2020 Yuji Hirose. All rights reserved. +// MIT License +// + +#ifndef CPPHTTPLIB_HTTPLIB_H +#define CPPHTTPLIB_HTTPLIB_H + +/* + * Configuration + */ + +#ifndef CPPHTTPLIB_KEEPALIVE_TIMEOUT_SECOND +#define CPPHTTPLIB_KEEPALIVE_TIMEOUT_SECOND 5 +#endif + +#ifndef CPPHTTPLIB_KEEPALIVE_TIMEOUT_USECOND +#define CPPHTTPLIB_KEEPALIVE_TIMEOUT_USECOND 0 +#endif + +#ifndef CPPHTTPLIB_KEEPALIVE_MAX_COUNT +#define CPPHTTPLIB_KEEPALIVE_MAX_COUNT 5 +#endif + +#ifndef CPPHTTPLIB_READ_TIMEOUT_SECOND +#define CPPHTTPLIB_READ_TIMEOUT_SECOND 5 +#endif + +#ifndef CPPHTTPLIB_READ_TIMEOUT_USECOND +#define CPPHTTPLIB_READ_TIMEOUT_USECOND 0 +#endif + +#ifndef CPPHTTPLIB_REQUEST_URI_MAX_LENGTH +#define CPPHTTPLIB_REQUEST_URI_MAX_LENGTH 8192 +#endif + +#ifndef CPPHTTPLIB_REDIRECT_MAX_COUNT +#define CPPHTTPLIB_REDIRECT_MAX_COUNT 20 +#endif + +#ifndef CPPHTTPLIB_PAYLOAD_MAX_LENGTH +#define CPPHTTPLIB_PAYLOAD_MAX_LENGTH ((std::numeric_limits::max)()) +#endif + +#ifndef CPPHTTPLIB_RECV_BUFSIZ +#define CPPHTTPLIB_RECV_BUFSIZ size_t(4096u) +#endif + +#ifndef CPPHTTPLIB_THREAD_POOL_COUNT +#define CPPHTTPLIB_THREAD_POOL_COUNT \ + ((std::max)(1u, std::thread::hardware_concurrency() - 1)) +#endif + +/* + * Headers + */ + +#ifdef _WIN32 +#ifndef _CRT_SECURE_NO_WARNINGS +#define _CRT_SECURE_NO_WARNINGS +#endif //_CRT_SECURE_NO_WARNINGS + +#ifndef _CRT_NONSTDC_NO_DEPRECATE +#define _CRT_NONSTDC_NO_DEPRECATE +#endif //_CRT_NONSTDC_NO_DEPRECATE + +#if defined(_MSC_VER) +#ifdef _WIN64 +using ssize_t = __int64; +#else +using ssize_t = int; +#endif + +#if _MSC_VER < 1900 +#define snprintf _snprintf_s +#endif +#endif // _MSC_VER + +#ifndef S_ISREG +#define S_ISREG(m) (((m)&S_IFREG) == S_IFREG) +#endif // S_ISREG + +#ifndef S_ISDIR +#define S_ISDIR(m) (((m)&S_IFDIR) == S_IFDIR) +#endif // S_ISDIR + +#ifndef NOMINMAX +#define NOMINMAX +#endif // NOMINMAX + +#include +#include +#include + +#ifndef WSA_FLAG_NO_HANDLE_INHERIT +#define WSA_FLAG_NO_HANDLE_INHERIT 0x80 +#endif + +#ifdef _MSC_VER +#pragma comment(lib, "ws2_32.lib") +#endif + +#ifndef strcasecmp +#define strcasecmp _stricmp +#endif // strcasecmp + +using socket_t = SOCKET; +#ifdef CPPHTTPLIB_USE_POLL +#define poll(fds, nfds, timeout) WSAPoll(fds, nfds, timeout) +#endif + +#else // not _WIN32 + +#include +#include +#include +#include +#include +#ifdef CPPHTTPLIB_USE_POLL +#include +#endif +#include +#include +#include +#include +#include + +using socket_t = int; +#define INVALID_SOCKET (-1) +#endif //_WIN32 + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +#include +#include +#include +#include + +#include +#include +#include + +// #if OPENSSL_VERSION_NUMBER < 0x1010100fL +// #error Sorry, OpenSSL versions prior to 1.1.1 are not supported +// #endif + +#if OPENSSL_VERSION_NUMBER < 0x10100000L +#include +inline const unsigned char *ASN1_STRING_get0_data(const ASN1_STRING *asn1) { + return M_ASN1_STRING_data(asn1); +} +#endif +#endif + +#ifdef CPPHTTPLIB_ZLIB_SUPPORT +#include +#endif +/* + * Declaration + */ +namespace httplib { + +namespace detail { + +struct ci { + bool operator()(const std::string &s1, const std::string &s2) const { + return std::lexicographical_compare( + s1.begin(), s1.end(), s2.begin(), s2.end(), + [](char c1, char c2) { return ::tolower(c1) < ::tolower(c2); }); + } +}; + +} // namespace detail + +using Headers = std::multimap; + +using Params = std::multimap; +using Match = std::smatch; + +using Progress = std::function; + +struct Response; +using ResponseHandler = std::function; + +struct MultipartFormData { + std::string name; + std::string content; + std::string filename; + std::string content_type; +}; +using MultipartFormDataItems = std::vector; +using MultipartFormDataMap = std::multimap; + +class DataSink { +public: + DataSink() = default; + DataSink(const DataSink &) = delete; + DataSink &operator=(const DataSink &) = delete; + DataSink(DataSink &&) = delete; + DataSink &operator=(DataSink &&) = delete; + + std::function write; + std::function done; + std::function is_writable; +}; + +using ContentProvider = + std::function; + +using ContentReceiver = + std::function; + +using MultipartContentHeader = + std::function; + +class ContentReader { +public: + using Reader = std::function; + using MultipartReader = std::function; + + ContentReader(Reader reader, MultipartReader muitlpart_reader) + : reader_(reader), muitlpart_reader_(muitlpart_reader) {} + + bool operator()(MultipartContentHeader header, + ContentReceiver receiver) const { + return muitlpart_reader_(header, receiver); + } + + bool operator()(ContentReceiver receiver) const { return reader_(receiver); } + + Reader reader_; + MultipartReader muitlpart_reader_; +}; + +using Range = std::pair; +using Ranges = std::vector; + +struct Request { + std::string method; + std::string path; + Headers headers; + std::string body; + + std::string remote_addr; + int remote_port = -1; + + // for server + std::string version; + std::string target; + Params params; + MultipartFormDataMap files; + Ranges ranges; + Match matches; + + // for client + size_t redirect_count = CPPHTTPLIB_REDIRECT_MAX_COUNT; + ResponseHandler response_handler; + ContentReceiver content_receiver; + Progress progress; + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + const SSL *ssl; +#endif + + bool has_header(const char *key) const; + std::string get_header_value(const char *key, size_t id = 0) const; + size_t get_header_value_count(const char *key) const; + void set_header(const char *key, const char *val); + void set_header(const char *key, const std::string &val); + + bool has_param(const char *key) const; + std::string get_param_value(const char *key, size_t id = 0) const; + size_t get_param_value_count(const char *key) const; + + bool is_multipart_form_data() const; + + bool has_file(const char *key) const; + MultipartFormData get_file_value(const char *key) const; + + // private members... + size_t content_length; + ContentProvider content_provider; +}; + +struct Response { + std::string version; + int status = -1; + Headers headers; + std::string body; + + bool has_header(const char *key) const; + std::string get_header_value(const char *key, size_t id = 0) const; + size_t get_header_value_count(const char *key) const; + void set_header(const char *key, const char *val); + void set_header(const char *key, const std::string &val); + + void set_redirect(const char *url, int status = 302); + void set_content(const char *s, size_t n, const char *content_type); + void set_content(std::string s, const char *content_type); + + void set_content_provider( + size_t length, + std::function + provider, + std::function resource_releaser = [] {}); + + void set_chunked_content_provider( + std::function provider, + std::function resource_releaser = [] {}); + + Response() = default; + Response(const Response &) = default; + Response &operator=(const Response &) = default; + Response(Response &&) = default; + Response &operator=(Response &&) = default; + ~Response() { + if (content_provider_resource_releaser) { + content_provider_resource_releaser(); + } + } + + // private members... + size_t content_length = 0; + ContentProvider content_provider; + std::function content_provider_resource_releaser; +}; + +class Stream { +public: + virtual ~Stream() = default; + + virtual bool is_readable() const = 0; + virtual bool is_writable() const = 0; + + virtual ssize_t read(char *ptr, size_t size) = 0; + virtual ssize_t write(const char *ptr, size_t size) = 0; + virtual void get_remote_ip_and_port(std::string &ip, int &port) const = 0; + + template + ssize_t write_format(const char *fmt, const Args &... args); + ssize_t write(const char *ptr); + ssize_t write(const std::string &s); +}; + +class TaskQueue { +public: + TaskQueue() = default; + virtual ~TaskQueue() = default; + + virtual void enqueue(std::function fn) = 0; + virtual void shutdown() = 0; + + virtual void on_idle(){}; +}; + +class ThreadPool : public TaskQueue { +public: + explicit ThreadPool(size_t n) : shutdown_(false) { + while (n) { + threads_.emplace_back(worker(*this)); + n--; + } + } + + ThreadPool(const ThreadPool &) = delete; + ~ThreadPool() override = default; + + void enqueue(std::function fn) override { + std::unique_lock lock(mutex_); + jobs_.push_back(fn); + cond_.notify_one(); + } + + void shutdown() override { + // Stop all worker threads... + { + std::unique_lock lock(mutex_); + shutdown_ = true; + } + + cond_.notify_all(); + + // Join... + for (auto &t : threads_) { + t.join(); + } + } + +private: + struct worker { + explicit worker(ThreadPool &pool) : pool_(pool) {} + + void operator()() { + for (;;) { + std::function fn; + { + std::unique_lock lock(pool_.mutex_); + + pool_.cond_.wait( + lock, [&] { return !pool_.jobs_.empty() || pool_.shutdown_; }); + + if (pool_.shutdown_ && pool_.jobs_.empty()) { break; } + + fn = pool_.jobs_.front(); + pool_.jobs_.pop_front(); + } + + assert(true == static_cast(fn)); + fn(); + } + } + + ThreadPool &pool_; + }; + friend struct worker; + + std::vector threads_; + std::list> jobs_; + + bool shutdown_; + + std::condition_variable cond_; + std::mutex mutex_; +}; + +using Logger = std::function; + +class Server { +public: + using Handler = std::function; + using HandlerWithContentReader = std::function; + using Expect100ContinueHandler = + std::function; + + Server(); + + virtual ~Server(); + + virtual bool is_valid() const; + + Server &Get(const char *pattern, Handler handler); + Server &Post(const char *pattern, Handler handler); + Server &Post(const char *pattern, HandlerWithContentReader handler); + Server &Put(const char *pattern, Handler handler); + Server &Put(const char *pattern, HandlerWithContentReader handler); + Server &Patch(const char *pattern, Handler handler); + Server &Patch(const char *pattern, HandlerWithContentReader handler); + Server &Delete(const char *pattern, Handler handler); + Server &Delete(const char *pattern, HandlerWithContentReader handler); + Server &Options(const char *pattern, Handler handler); + + [[deprecated]] bool set_base_dir(const char *dir, + const char *mount_point = nullptr); + bool set_mount_point(const char *mount_point, const char *dir); + bool remove_mount_point(const char *mount_point); + void set_file_extension_and_mimetype_mapping(const char *ext, + const char *mime); + void set_file_request_handler(Handler handler); + + void set_error_handler(Handler handler); + void set_logger(Logger logger); + + void set_expect_100_continue_handler(Expect100ContinueHandler handler); + + void set_keep_alive_max_count(size_t count); + void set_read_timeout(time_t sec, time_t usec); + void set_payload_max_length(size_t length); + + bool bind_to_port(const char *host, int port, int socket_flags = 0); + int bind_to_any_port(const char *host, int socket_flags = 0); + bool listen_after_bind(); + + bool listen(const char *host, int port, int socket_flags = 0); + + bool is_running() const; + void stop(); + + std::function new_task_queue; + +protected: + bool process_request(Stream &strm, bool last_connection, + bool &connection_close, + const std::function &setup_request); + + size_t keep_alive_max_count_; + time_t read_timeout_sec_; + time_t read_timeout_usec_; + size_t payload_max_length_; + +private: + using Handlers = std::vector>; + using HandlersForContentReader = + std::vector>; + + socket_t create_server_socket(const char *host, int port, + int socket_flags) const; + int bind_internal(const char *host, int port, int socket_flags); + bool listen_internal(); + + bool routing(Request &req, Response &res, Stream &strm); + bool handle_file_request(Request &req, Response &res, bool head = false); + bool dispatch_request(Request &req, Response &res, Handlers &handlers); + bool dispatch_request_for_content_reader(Request &req, Response &res, + ContentReader content_reader, + HandlersForContentReader &handlers); + + bool parse_request_line(const char *s, Request &req); + bool write_response(Stream &strm, bool last_connection, const Request &req, + Response &res); + bool write_content_with_provider(Stream &strm, const Request &req, + Response &res, const std::string &boundary, + const std::string &content_type); + bool read_content(Stream &strm, Request &req, Response &res); + bool + read_content_with_content_receiver(Stream &strm, Request &req, Response &res, + ContentReceiver receiver, + MultipartContentHeader multipart_header, + ContentReceiver multipart_receiver); + bool read_content_core(Stream &strm, Request &req, Response &res, + ContentReceiver receiver, + MultipartContentHeader mulitpart_header, + ContentReceiver multipart_receiver); + + virtual bool process_and_close_socket(socket_t sock); + + std::atomic is_running_; + std::atomic svr_sock_; + std::vector> base_dirs_; + std::map file_extension_and_mimetype_map_; + Handler file_request_handler_; + Handlers get_handlers_; + Handlers post_handlers_; + HandlersForContentReader post_handlers_for_content_reader_; + Handlers put_handlers_; + HandlersForContentReader put_handlers_for_content_reader_; + Handlers patch_handlers_; + HandlersForContentReader patch_handlers_for_content_reader_; + Handlers delete_handlers_; + HandlersForContentReader delete_handlers_for_content_reader_; + Handlers options_handlers_; + Handler error_handler_; + Logger logger_; + Expect100ContinueHandler expect_100_continue_handler_; +}; + +class Client { +public: + explicit Client(const std::string &host, int port = 80, + const std::string &client_cert_path = std::string(), + const std::string &client_key_path = std::string()); + + virtual ~Client(); + + virtual bool is_valid() const; + + std::shared_ptr Get(const char *path); + + std::shared_ptr Get(const char *path, const Headers &headers); + + std::shared_ptr Get(const char *path, Progress progress); + + std::shared_ptr Get(const char *path, const Headers &headers, + Progress progress); + + std::shared_ptr Get(const char *path, + ContentReceiver content_receiver); + + std::shared_ptr Get(const char *path, const Headers &headers, + ContentReceiver content_receiver); + + std::shared_ptr + Get(const char *path, ContentReceiver content_receiver, Progress progress); + + std::shared_ptr Get(const char *path, const Headers &headers, + ContentReceiver content_receiver, + Progress progress); + + std::shared_ptr Get(const char *path, const Headers &headers, + ResponseHandler response_handler, + ContentReceiver content_receiver); + + std::shared_ptr Get(const char *path, const Headers &headers, + ResponseHandler response_handler, + ContentReceiver content_receiver, + Progress progress); + + std::shared_ptr Head(const char *path); + + std::shared_ptr Head(const char *path, const Headers &headers); + + std::shared_ptr Post(const char *path); + + std::shared_ptr Post(const char *path, const std::string &body, + const char *content_type); + + std::shared_ptr Post(const char *path, const Headers &headers, + const std::string &body, + const char *content_type); + + std::shared_ptr Post(const char *path, size_t content_length, + ContentProvider content_provider, + const char *content_type); + + std::shared_ptr Post(const char *path, const Headers &headers, + size_t content_length, + ContentProvider content_provider, + const char *content_type); + + std::shared_ptr Post(const char *path, const Params ¶ms); + + std::shared_ptr Post(const char *path, const Headers &headers, + const Params ¶ms); + + std::shared_ptr Post(const char *path, + const MultipartFormDataItems &items); + + std::shared_ptr Post(const char *path, const Headers &headers, + const MultipartFormDataItems &items); + + std::shared_ptr Put(const char *path); + + std::shared_ptr Put(const char *path, const std::string &body, + const char *content_type); + + std::shared_ptr Put(const char *path, const Headers &headers, + const std::string &body, + const char *content_type); + + std::shared_ptr Put(const char *path, size_t content_length, + ContentProvider content_provider, + const char *content_type); + + std::shared_ptr Put(const char *path, const Headers &headers, + size_t content_length, + ContentProvider content_provider, + const char *content_type); + + std::shared_ptr Put(const char *path, const Params ¶ms); + + std::shared_ptr Put(const char *path, const Headers &headers, + const Params ¶ms); + + std::shared_ptr Patch(const char *path, const std::string &body, + const char *content_type); + + std::shared_ptr Patch(const char *path, const Headers &headers, + const std::string &body, + const char *content_type); + + std::shared_ptr Patch(const char *path, size_t content_length, + ContentProvider content_provider, + const char *content_type); + + std::shared_ptr Patch(const char *path, const Headers &headers, + size_t content_length, + ContentProvider content_provider, + const char *content_type); + + std::shared_ptr Delete(const char *path); + + std::shared_ptr Delete(const char *path, const std::string &body, + const char *content_type); + + std::shared_ptr Delete(const char *path, const Headers &headers); + + std::shared_ptr Delete(const char *path, const Headers &headers, + const std::string &body, + const char *content_type); + + std::shared_ptr Options(const char *path); + + std::shared_ptr Options(const char *path, const Headers &headers); + + bool send(const Request &req, Response &res); + + bool send(const std::vector &requests, + std::vector &responses); + + void stop(); + + void set_timeout_sec(time_t timeout_sec); + + void set_read_timeout(time_t sec, time_t usec); + + void set_keep_alive_max_count(size_t count); + + void set_basic_auth(const char *username, const char *password); + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + void set_digest_auth(const char *username, const char *password); +#endif + + void set_follow_location(bool on); + + void set_compress(bool on); + + void set_interface(const char *intf); + + void set_proxy(const char *host, int port); + + void set_proxy_basic_auth(const char *username, const char *password); + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + void set_proxy_digest_auth(const char *username, const char *password); +#endif + + void set_logger(Logger logger); + +protected: + bool process_request(Stream &strm, const Request &req, Response &res, + bool last_connection, bool &connection_close); + + std::atomic sock_; + + const std::string host_; + const int port_; + const std::string host_and_port_; + + // Settings + std::string client_cert_path_; + std::string client_key_path_; + + time_t timeout_sec_ = 300; + time_t read_timeout_sec_ = CPPHTTPLIB_READ_TIMEOUT_SECOND; + time_t read_timeout_usec_ = CPPHTTPLIB_READ_TIMEOUT_USECOND; + + size_t keep_alive_max_count_ = CPPHTTPLIB_KEEPALIVE_MAX_COUNT; + + std::string basic_auth_username_; + std::string basic_auth_password_; +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + std::string digest_auth_username_; + std::string digest_auth_password_; +#endif + + bool follow_location_ = false; + + bool compress_ = false; + + std::string interface_; + + std::string proxy_host_; + int proxy_port_; + + std::string proxy_basic_auth_username_; + std::string proxy_basic_auth_password_; +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + std::string proxy_digest_auth_username_; + std::string proxy_digest_auth_password_; +#endif + + Logger logger_; + + void copy_settings(const Client &rhs) { + client_cert_path_ = rhs.client_cert_path_; + client_key_path_ = rhs.client_key_path_; + timeout_sec_ = rhs.timeout_sec_; + read_timeout_sec_ = rhs.read_timeout_sec_; + read_timeout_usec_ = rhs.read_timeout_usec_; + keep_alive_max_count_ = rhs.keep_alive_max_count_; + basic_auth_username_ = rhs.basic_auth_username_; + basic_auth_password_ = rhs.basic_auth_password_; +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + digest_auth_username_ = rhs.digest_auth_username_; + digest_auth_password_ = rhs.digest_auth_password_; +#endif + follow_location_ = rhs.follow_location_; + compress_ = rhs.compress_; + interface_ = rhs.interface_; + proxy_host_ = rhs.proxy_host_; + proxy_port_ = rhs.proxy_port_; + proxy_basic_auth_username_ = rhs.proxy_basic_auth_username_; + proxy_basic_auth_password_ = rhs.proxy_basic_auth_password_; +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + proxy_digest_auth_username_ = rhs.proxy_digest_auth_username_; + proxy_digest_auth_password_ = rhs.proxy_digest_auth_password_; +#endif + logger_ = rhs.logger_; + } + +private: + socket_t create_client_socket() const; + bool read_response_line(Stream &strm, Response &res); + bool write_request(Stream &strm, const Request &req, bool last_connection); + bool redirect(const Request &req, Response &res); + bool handle_request(Stream &strm, const Request &req, Response &res, + bool last_connection, bool &connection_close); +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + bool connect(socket_t sock, Response &res, bool &error); +#endif + + std::shared_ptr send_with_content_provider( + const char *method, const char *path, const Headers &headers, + const std::string &body, size_t content_length, + ContentProvider content_provider, const char *content_type); + + virtual bool process_and_close_socket( + socket_t sock, size_t request_count, + std::function + callback); + + virtual bool is_ssl() const; +}; + +inline void Get(std::vector &requests, const char *path, + const Headers &headers) { + Request req; + req.method = "GET"; + req.path = path; + req.headers = headers; + requests.emplace_back(std::move(req)); +} + +inline void Get(std::vector &requests, const char *path) { + Get(requests, path, Headers()); +} + +inline void Post(std::vector &requests, const char *path, + const Headers &headers, const std::string &body, + const char *content_type) { + Request req; + req.method = "POST"; + req.path = path; + req.headers = headers; + if (content_type) { req.headers.emplace("Content-Type", content_type); } + req.body = body; + requests.emplace_back(std::move(req)); +} + +inline void Post(std::vector &requests, const char *path, + const std::string &body, const char *content_type) { + Post(requests, path, Headers(), body, content_type); +} + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +class SSLServer : public Server { +public: + SSLServer(const char *cert_path, const char *private_key_path, + const char *client_ca_cert_file_path = nullptr, + const char *client_ca_cert_dir_path = nullptr); + + SSLServer(X509 *cert, EVP_PKEY *private_key, + X509_STORE *client_ca_cert_store = nullptr); + + ~SSLServer() override; + + bool is_valid() const override; + +private: + bool process_and_close_socket(socket_t sock) override; + + SSL_CTX *ctx_; + std::mutex ctx_mutex_; +}; + +class SSLClient : public Client { +public: + explicit SSLClient(const std::string &host, int port = 443, + const std::string &client_cert_path = std::string(), + const std::string &client_key_path = std::string()); + + SSLClient(const std::string &host, int port, X509 *client_cert, + EVP_PKEY *client_key); + + ~SSLClient() override; + + bool is_valid() const override; + + void set_ca_cert_path(const char *ca_ceert_file_path, + const char *ca_cert_dir_path = nullptr); + + void set_ca_cert_store(X509_STORE *ca_cert_store); + + void enable_server_certificate_verification(bool enabled); + + long get_openssl_verify_result() const; + + SSL_CTX *ssl_context() const; + +private: + bool process_and_close_socket( + socket_t sock, size_t request_count, + std::function + callback) override; + bool is_ssl() const override; + + bool verify_host(X509 *server_cert) const; + bool verify_host_with_subject_alt_name(X509 *server_cert) const; + bool verify_host_with_common_name(X509 *server_cert) const; + bool check_host_name(const char *pattern, size_t pattern_len) const; + + SSL_CTX *ctx_; + std::mutex ctx_mutex_; + std::vector host_components_; + + std::string ca_cert_file_path_; + std::string ca_cert_dir_path_; + X509_STORE *ca_cert_store_ = nullptr; + bool server_certificate_verification_ = false; + long verify_result_ = 0; +}; +#endif + +// ---------------------------------------------------------------------------- + +/* + * Implementation + */ + +namespace detail { + +inline bool is_hex(char c, int &v) { + if (0x20 <= c && isdigit(c)) { + v = c - '0'; + return true; + } else if ('A' <= c && c <= 'F') { + v = c - 'A' + 10; + return true; + } else if ('a' <= c && c <= 'f') { + v = c - 'a' + 10; + return true; + } + return false; +} + +inline bool from_hex_to_i(const std::string &s, size_t i, size_t cnt, + int &val) { + if (i >= s.size()) { return false; } + + val = 0; + for (; cnt; i++, cnt--) { + if (!s[i]) { return false; } + int v = 0; + if (is_hex(s[i], v)) { + val = val * 16 + v; + } else { + return false; + } + } + return true; +} + +inline std::string from_i_to_hex(size_t n) { + const char *charset = "0123456789abcdef"; + std::string ret; + do { + ret = charset[n & 15] + ret; + n >>= 4; + } while (n > 0); + return ret; +} + +inline size_t to_utf8(int code, char *buff) { + if (code < 0x0080) { + buff[0] = (code & 0x7F); + return 1; + } else if (code < 0x0800) { + buff[0] = static_cast(0xC0 | ((code >> 6) & 0x1F)); + buff[1] = static_cast(0x80 | (code & 0x3F)); + return 2; + } else if (code < 0xD800) { + buff[0] = static_cast(0xE0 | ((code >> 12) & 0xF)); + buff[1] = static_cast(0x80 | ((code >> 6) & 0x3F)); + buff[2] = static_cast(0x80 | (code & 0x3F)); + return 3; + } else if (code < 0xE000) { // D800 - DFFF is invalid... + return 0; + } else if (code < 0x10000) { + buff[0] = static_cast(0xE0 | ((code >> 12) & 0xF)); + buff[1] = static_cast(0x80 | ((code >> 6) & 0x3F)); + buff[2] = static_cast(0x80 | (code & 0x3F)); + return 3; + } else if (code < 0x110000) { + buff[0] = static_cast(0xF0 | ((code >> 18) & 0x7)); + buff[1] = static_cast(0x80 | ((code >> 12) & 0x3F)); + buff[2] = static_cast(0x80 | ((code >> 6) & 0x3F)); + buff[3] = static_cast(0x80 | (code & 0x3F)); + return 4; + } + + // NOTREACHED + return 0; +} + +// NOTE: This code came up with the following stackoverflow post: +// https://stackoverflow.com/questions/180947/base64-decode-snippet-in-c +inline std::string base64_encode(const std::string &in) { + static const auto lookup = + "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; + + std::string out; + out.reserve(in.size()); + + int val = 0; + int valb = -6; + + for (auto c : in) { + val = (val << 8) + static_cast(c); + valb += 8; + while (valb >= 0) { + out.push_back(lookup[(val >> valb) & 0x3F]); + valb -= 6; + } + } + + if (valb > -6) { out.push_back(lookup[((val << 8) >> (valb + 8)) & 0x3F]); } + + while (out.size() % 4) { + out.push_back('='); + } + + return out; +} + +inline bool is_file(const std::string &path) { + struct stat st; + return stat(path.c_str(), &st) >= 0 && S_ISREG(st.st_mode); +} + +inline bool is_dir(const std::string &path) { + struct stat st; + return stat(path.c_str(), &st) >= 0 && S_ISDIR(st.st_mode); +} + +inline bool is_valid_path(const std::string &path) { + size_t level = 0; + size_t i = 0; + + // Skip slash + while (i < path.size() && path[i] == '/') { + i++; + } + + while (i < path.size()) { + // Read component + auto beg = i; + while (i < path.size() && path[i] != '/') { + i++; + } + + auto len = i - beg; + assert(len > 0); + + if (!path.compare(beg, len, ".")) { + ; + } else if (!path.compare(beg, len, "..")) { + if (level == 0) { return false; } + level--; + } else { + level++; + } + + // Skip slash + while (i < path.size() && path[i] == '/') { + i++; + } + } + + return true; +} + +inline void read_file(const std::string &path, std::string &out) { + std::ifstream fs(path, std::ios_base::binary); + fs.seekg(0, std::ios_base::end); + auto size = fs.tellg(); + fs.seekg(0); + out.resize(static_cast(size)); + fs.read(&out[0], size); +} + +inline std::string file_extension(const std::string &path) { + std::smatch m; + static auto re = std::regex("\\.([a-zA-Z0-9]+)$"); + if (std::regex_search(path, m, re)) { return m[1].str(); } + return std::string(); +} + +template void split(const char *b, const char *e, char d, Fn fn) { + int i = 0; + int beg = 0; + + while (e ? (b + i != e) : (b[i] != '\0')) { + if (b[i] == d) { + fn(&b[beg], &b[i]); + beg = i + 1; + } + i++; + } + + if (i) { fn(&b[beg], &b[i]); } +} + +// NOTE: until the read size reaches `fixed_buffer_size`, use `fixed_buffer` +// to store data. The call can set memory on stack for performance. +class stream_line_reader { +public: + stream_line_reader(Stream &strm, char *fixed_buffer, size_t fixed_buffer_size) + : strm_(strm), fixed_buffer_(fixed_buffer), + fixed_buffer_size_(fixed_buffer_size) {} + + const char *ptr() const { + if (glowable_buffer_.empty()) { + return fixed_buffer_; + } else { + return glowable_buffer_.data(); + } + } + + size_t size() const { + if (glowable_buffer_.empty()) { + return fixed_buffer_used_size_; + } else { + return glowable_buffer_.size(); + } + } + + bool end_with_crlf() const { + auto end = ptr() + size(); + return size() >= 2 && end[-2] == '\r' && end[-1] == '\n'; + } + + bool getline() { + fixed_buffer_used_size_ = 0; + glowable_buffer_.clear(); + + for (size_t i = 0;; i++) { + char byte; + auto n = strm_.read(&byte, 1); + + if (n < 0) { + return false; + } else if (n == 0) { + if (i == 0) { + return false; + } else { + break; + } + } + + append(byte); + + if (byte == '\n') { break; } + } + + return true; + } + +private: + void append(char c) { + if (fixed_buffer_used_size_ < fixed_buffer_size_ - 1) { + fixed_buffer_[fixed_buffer_used_size_++] = c; + fixed_buffer_[fixed_buffer_used_size_] = '\0'; + } else { + if (glowable_buffer_.empty()) { + assert(fixed_buffer_[fixed_buffer_used_size_] == '\0'); + glowable_buffer_.assign(fixed_buffer_, fixed_buffer_used_size_); + } + glowable_buffer_ += c; + } + } + + Stream &strm_; + char *fixed_buffer_; + const size_t fixed_buffer_size_; + size_t fixed_buffer_used_size_ = 0; + std::string glowable_buffer_; +}; + +inline int close_socket(socket_t sock) { +#ifdef _WIN32 + return closesocket(sock); +#else + return close(sock); +#endif +} + +inline int select_read(socket_t sock, time_t sec, time_t usec) { +#ifdef CPPHTTPLIB_USE_POLL + struct pollfd pfd_read; + pfd_read.fd = sock; + pfd_read.events = POLLIN; + + auto timeout = static_cast(sec * 1000 + usec / 1000); + + return HANDLE_EINTR(poll, &pfd_read, 1, timeout); +#else + fd_set fds; + FD_ZERO(&fds); + FD_SET(sock, &fds); + + timeval tv; + tv.tv_sec = static_cast(sec); + tv.tv_usec = static_cast(usec); + + return select(static_cast(sock + 1), &fds, nullptr, nullptr, &tv); +#endif +} + +inline int select_write(socket_t sock, time_t sec, time_t usec) { +#ifdef CPPHTTPLIB_USE_POLL + struct pollfd pfd_read; + pfd_read.fd = sock; + pfd_read.events = POLLOUT; + + auto timeout = static_cast(sec * 1000 + usec / 1000); + + return HANDLE_EINTR(poll, &pfd_read, 1, timeout); +#else + fd_set fds; + FD_ZERO(&fds); + FD_SET(sock, &fds); + + timeval tv; + tv.tv_sec = static_cast(sec); + tv.tv_usec = static_cast(usec); + + return select(static_cast(sock + 1), nullptr, &fds, nullptr, &tv); +#endif +} + +inline bool wait_until_socket_is_ready(socket_t sock, time_t sec, time_t usec) { +#ifdef CPPHTTPLIB_USE_POLL + struct pollfd pfd_read; + pfd_read.fd = sock; + pfd_read.events = POLLIN | POLLOUT; + + auto timeout = static_cast(sec * 1000 + usec / 1000); + + int poll_res = HANDLE_EINTR(poll, &pfd_read, 1, timeout); + if (poll_res > 0 && pfd_read.revents & (POLLIN | POLLOUT)) { + int error = 0; + socklen_t len = sizeof(error); + auto res = getsockopt(sock, SOL_SOCKET, SO_ERROR, + reinterpret_cast(&error), &len); + return res >= 0 && !error; + } + return false; +#else + fd_set fdsr; + FD_ZERO(&fdsr); + FD_SET(sock, &fdsr); + + auto fdsw = fdsr; + auto fdse = fdsr; + + timeval tv; + tv.tv_sec = static_cast(sec); + tv.tv_usec = static_cast(usec); + + if (select(static_cast(sock + 1), &fdsr, &fdsw, &fdse, &tv) > 0 && + (FD_ISSET(sock, &fdsr) || FD_ISSET(sock, &fdsw))) { + int error = 0; + socklen_t len = sizeof(error); + return getsockopt(sock, SOL_SOCKET, SO_ERROR, + reinterpret_cast(&error), &len) >= 0 && + !error; + } + return false; +#endif +} + +class SocketStream : public Stream { +public: + SocketStream(socket_t sock, time_t read_timeout_sec, + time_t read_timeout_usec); + ~SocketStream() override; + + bool is_readable() const override; + bool is_writable() const override; + ssize_t read(char *ptr, size_t size) override; + ssize_t write(const char *ptr, size_t size) override; + void get_remote_ip_and_port(std::string &ip, int &port) const override; + +private: + socket_t sock_; + time_t read_timeout_sec_; + time_t read_timeout_usec_; +}; + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +class SSLSocketStream : public Stream { +public: + SSLSocketStream(socket_t sock, SSL *ssl, time_t read_timeout_sec, + time_t read_timeout_usec); + ~SSLSocketStream() override; + + bool is_readable() const override; + bool is_writable() const override; + ssize_t read(char *ptr, size_t size) override; + ssize_t write(const char *ptr, size_t size) override; + void get_remote_ip_and_port(std::string &ip, int &port) const override; + +private: + socket_t sock_; + SSL *ssl_; + time_t read_timeout_sec_; + time_t read_timeout_usec_; +}; +#endif + +class BufferStream : public Stream { +public: + BufferStream() = default; + ~BufferStream() override = default; + + bool is_readable() const override; + bool is_writable() const override; + ssize_t read(char *ptr, size_t size) override; + ssize_t write(const char *ptr, size_t size) override; + void get_remote_ip_and_port(std::string &ip, int &port) const override; + + const std::string &get_buffer() const; + +private: + std::string buffer; + size_t position = 0; +}; + +template +inline bool process_socket(bool is_client_request, socket_t sock, + size_t keep_alive_max_count, time_t read_timeout_sec, + time_t read_timeout_usec, T callback) { + assert(keep_alive_max_count > 0); + + auto ret = false; + + if (keep_alive_max_count > 1) { + auto count = keep_alive_max_count; + while (count > 0 && + (is_client_request || + select_read(sock, CPPHTTPLIB_KEEPALIVE_TIMEOUT_SECOND, + CPPHTTPLIB_KEEPALIVE_TIMEOUT_USECOND) > 0)) { + SocketStream strm(sock, read_timeout_sec, read_timeout_usec); + auto last_connection = count == 1; + auto connection_close = false; + + ret = callback(strm, last_connection, connection_close); + if (!ret || connection_close) { break; } + + count--; + } + } else { // keep_alive_max_count is 0 or 1 + SocketStream strm(sock, read_timeout_sec, read_timeout_usec); + auto dummy_connection_close = false; + ret = callback(strm, true, dummy_connection_close); + } + + return ret; +} + +template +inline bool process_and_close_socket(bool is_client_request, socket_t sock, + size_t keep_alive_max_count, + time_t read_timeout_sec, + time_t read_timeout_usec, T callback) { + auto ret = process_socket(is_client_request, sock, keep_alive_max_count, + read_timeout_sec, read_timeout_usec, callback); + close_socket(sock); + return ret; +} + +inline int shutdown_socket(socket_t sock) { +#ifdef _WIN32 + return shutdown(sock, SD_BOTH); +#else + return shutdown(sock, SHUT_RDWR); +#endif +} + +template +socket_t create_socket(const char *host, int port, Fn fn, + int socket_flags = 0) { +#ifdef _WIN32 +#define SO_SYNCHRONOUS_NONALERT 0x20 +#define SO_OPENTYPE 0x7008 + + int opt = SO_SYNCHRONOUS_NONALERT; + setsockopt(INVALID_SOCKET, SOL_SOCKET, SO_OPENTYPE, (char *)&opt, + sizeof(opt)); +#endif + + // Get address info + struct addrinfo hints; + struct addrinfo *result = nullptr; + + memset(&hints, 0, sizeof(struct addrinfo)); + hints.ai_family = AF_UNSPEC; + hints.ai_socktype = SOCK_STREAM; + hints.ai_flags = socket_flags; + hints.ai_protocol = 0; + + auto service = std::to_string(port); + + if (getaddrinfo(host, service.c_str(), &hints, &result)) { + if (result != nullptr) { freeaddrinfo(result); } + return INVALID_SOCKET; + } + + for (auto rp = result; rp; rp = rp->ai_next) { + // Create a socket +#ifdef _WIN32 + auto sock = WSASocketW(rp->ai_family, rp->ai_socktype, rp->ai_protocol, + nullptr, 0, WSA_FLAG_NO_HANDLE_INHERIT); + /** + * Since the WSA_FLAG_NO_HANDLE_INHERIT is only supported on Windows 7 SP1 + * and above the socket creation fails on older Windows Systems. + * + * Let's try to create a socket the old way in this case. + * + * Reference: + * https://docs.microsoft.com/en-us/windows/win32/api/winsock2/nf-winsock2-wsasocketa + * + * WSA_FLAG_NO_HANDLE_INHERIT: + * This flag is supported on Windows 7 with SP1, Windows Server 2008 R2 with + * SP1, and later + * + */ + if (sock == INVALID_SOCKET) { + sock = socket(rp->ai_family, rp->ai_socktype, rp->ai_protocol); + } +#else + auto sock = socket(rp->ai_family, rp->ai_socktype, rp->ai_protocol); +#endif + if (sock == INVALID_SOCKET) { continue; } + +#ifndef _WIN32 + if (fcntl(sock, F_SETFD, FD_CLOEXEC) == -1) { continue; } +#endif + + // Make 'reuse address' option available + int yes = 1; + setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, reinterpret_cast(&yes), + sizeof(yes)); +#ifdef SO_REUSEPORT + setsockopt(sock, SOL_SOCKET, SO_REUSEPORT, reinterpret_cast(&yes), + sizeof(yes)); +#endif + + // bind or connect + if (fn(sock, *rp)) { + freeaddrinfo(result); + return sock; + } + + close_socket(sock); + } + + freeaddrinfo(result); + return INVALID_SOCKET; +} + +inline void set_nonblocking(socket_t sock, bool nonblocking) { +#ifdef _WIN32 + auto flags = nonblocking ? 1UL : 0UL; + ioctlsocket(sock, FIONBIO, &flags); +#else + auto flags = fcntl(sock, F_GETFL, 0); + fcntl(sock, F_SETFL, + nonblocking ? (flags | O_NONBLOCK) : (flags & (~O_NONBLOCK))); +#endif +} + +inline bool is_connection_error() { +#ifdef _WIN32 + return WSAGetLastError() != WSAEWOULDBLOCK; +#else + return errno != EINPROGRESS; +#endif +} + +inline bool bind_ip_address(socket_t sock, const char *host) { + struct addrinfo hints; + struct addrinfo *result; + + memset(&hints, 0, sizeof(struct addrinfo)); + hints.ai_family = AF_UNSPEC; + hints.ai_socktype = SOCK_STREAM; + hints.ai_protocol = 0; + + if (getaddrinfo(host, "0", &hints, &result)) { return false; } + + auto ret = false; + for (auto rp = result; rp; rp = rp->ai_next) { + const auto &ai = *rp; + if (!::bind(sock, ai.ai_addr, static_cast(ai.ai_addrlen))) { + ret = true; + break; + } + } + + freeaddrinfo(result); + return ret; +} + +inline std::string if2ip(const std::string &ifn) { +#ifndef _WIN32 + struct ifaddrs *ifap; + getifaddrs(&ifap); + for (auto ifa = ifap; ifa; ifa = ifa->ifa_next) { + if (ifa->ifa_addr && ifn == ifa->ifa_name) { + if (ifa->ifa_addr->sa_family == AF_INET) { + auto sa = reinterpret_cast(ifa->ifa_addr); + char buf[INET_ADDRSTRLEN]; + if (inet_ntop(AF_INET, &sa->sin_addr, buf, INET_ADDRSTRLEN)) { + freeifaddrs(ifap); + return std::string(buf, INET_ADDRSTRLEN); + } + } + } + } + freeifaddrs(ifap); +#endif + return std::string(); +} + +inline socket_t create_client_socket(const char *host, int port, + time_t timeout_sec, + const std::string &intf) { + return create_socket( + host, port, [&](socket_t sock, struct addrinfo &ai) -> bool { + if (!intf.empty()) { + auto ip = if2ip(intf); + if (ip.empty()) { ip = intf; } + if (!bind_ip_address(sock, ip.c_str())) { return false; } + } + + set_nonblocking(sock, true); + + auto ret = + ::connect(sock, ai.ai_addr, static_cast(ai.ai_addrlen)); + if (ret < 0) { + if (is_connection_error() || + !wait_until_socket_is_ready(sock, timeout_sec, 0)) { + close_socket(sock); + return false; + } + } + + set_nonblocking(sock, false); + return true; + }); +} + +inline void get_remote_ip_and_port(const struct sockaddr_storage &addr, + socklen_t addr_len, std::string &ip, + int &port) { + if (addr.ss_family == AF_INET) { + port = ntohs(reinterpret_cast(&addr)->sin_port); + } else if (addr.ss_family == AF_INET6) { + port = + ntohs(reinterpret_cast(&addr)->sin6_port); + } + + std::array ipstr{}; + if (!getnameinfo(reinterpret_cast(&addr), addr_len, + ipstr.data(), static_cast(ipstr.size()), nullptr, + 0, NI_NUMERICHOST)) { + ip = ipstr.data(); + } +} + +inline void get_remote_ip_and_port(socket_t sock, std::string &ip, int &port) { + struct sockaddr_storage addr; + socklen_t addr_len = sizeof(addr); + + if (!getpeername(sock, reinterpret_cast(&addr), + &addr_len)) { + get_remote_ip_and_port(addr, addr_len, ip, port); + } +} + +inline const char * +find_content_type(const std::string &path, + const std::map &user_data) { + auto ext = file_extension(path); + + auto it = user_data.find(ext); + if (it != user_data.end()) { return it->second.c_str(); } + + if (ext == "txt") { + return "text/plain"; + } else if (ext == "html" || ext == "htm") { + return "text/html"; + } else if (ext == "css") { + return "text/css"; + } else if (ext == "jpeg" || ext == "jpg") { + return "image/jpg"; + } else if (ext == "png") { + return "image/png"; + } else if (ext == "gif") { + return "image/gif"; + } else if (ext == "svg") { + return "image/svg+xml"; + } else if (ext == "ico") { + return "image/x-icon"; + } else if (ext == "json") { + return "application/json"; + } else if (ext == "pdf") { + return "application/pdf"; + } else if (ext == "js") { + return "application/javascript"; + } else if (ext == "wasm") { + return "application/wasm"; + } else if (ext == "xml") { + return "application/xml"; + } else if (ext == "xhtml") { + return "application/xhtml+xml"; + } + return nullptr; +} + +inline const char *status_message(int status) { + switch (status) { + case 100: return "Continue"; + case 101: return "Switching Protocol"; + case 102: return "Processing"; + case 103: return "Early Hints"; + case 200: return "OK"; + case 201: return "Created"; + case 202: return "Accepted"; + case 203: return "Non-Authoritative Information"; + case 204: return "No Content"; + case 205: return "Reset Content"; + case 206: return "Partial Content"; + case 207: return "Multi-Status"; + case 208: return "Already Reported"; + case 226: return "IM Used"; + case 300: return "Multiple Choice"; + case 301: return "Moved Permanently"; + case 302: return "Found"; + case 303: return "See Other"; + case 304: return "Not Modified"; + case 305: return "Use Proxy"; + case 306: return "unused"; + case 307: return "Temporary Redirect"; + case 308: return "Permanent Redirect"; + case 400: return "Bad Request"; + case 401: return "Unauthorized"; + case 402: return "Payment Required"; + case 403: return "Forbidden"; + case 404: return "Not Found"; + case 405: return "Method Not Allowed"; + case 406: return "Not Acceptable"; + case 407: return "Proxy Authentication Required"; + case 408: return "Request Timeout"; + case 409: return "Conflict"; + case 410: return "Gone"; + case 411: return "Length Required"; + case 412: return "Precondition Failed"; + case 413: return "Payload Too Large"; + case 414: return "URI Too Long"; + case 415: return "Unsupported Media Type"; + case 416: return "Range Not Satisfiable"; + case 417: return "Expectation Failed"; + case 418: return "I'm a teapot"; + case 421: return "Misdirected Request"; + case 422: return "Unprocessable Entity"; + case 423: return "Locked"; + case 424: return "Failed Dependency"; + case 425: return "Too Early"; + case 426: return "Upgrade Required"; + case 428: return "Precondition Required"; + case 429: return "Too Many Requests"; + case 431: return "Request Header Fields Too Large"; + case 451: return "Unavailable For Legal Reasons"; + case 501: return "Not Implemented"; + case 502: return "Bad Gateway"; + case 503: return "Service Unavailable"; + case 504: return "Gateway Timeout"; + case 505: return "HTTP Version Not Supported"; + case 506: return "Variant Also Negotiates"; + case 507: return "Insufficient Storage"; + case 508: return "Loop Detected"; + case 510: return "Not Extended"; + case 511: return "Network Authentication Required"; + + default: + case 500: return "Internal Server Error"; + } +} + +#ifdef CPPHTTPLIB_ZLIB_SUPPORT +inline bool can_compress(const std::string &content_type) { + return !content_type.find("text/") || content_type == "image/svg+xml" || + content_type == "application/javascript" || + content_type == "application/json" || + content_type == "application/xml" || + content_type == "application/xhtml+xml"; +} + +inline bool compress(std::string &content) { + z_stream strm; + strm.zalloc = Z_NULL; + strm.zfree = Z_NULL; + strm.opaque = Z_NULL; + + auto ret = deflateInit2(&strm, Z_DEFAULT_COMPRESSION, Z_DEFLATED, 31, 8, + Z_DEFAULT_STRATEGY); + if (ret != Z_OK) { return false; } + + strm.avail_in = static_cast(content.size()); + strm.next_in = + const_cast(reinterpret_cast(content.data())); + + std::string compressed; + + std::array buff{}; + do { + strm.avail_out = buff.size(); + strm.next_out = reinterpret_cast(buff.data()); + ret = deflate(&strm, Z_FINISH); + assert(ret != Z_STREAM_ERROR); + compressed.append(buff.data(), buff.size() - strm.avail_out); + } while (strm.avail_out == 0); + + assert(ret == Z_STREAM_END); + assert(strm.avail_in == 0); + + content.swap(compressed); + + deflateEnd(&strm); + return true; +} + +class decompressor { +public: + decompressor() { + std::memset(&strm, 0, sizeof(strm)); + strm.zalloc = Z_NULL; + strm.zfree = Z_NULL; + strm.opaque = Z_NULL; + + // 15 is the value of wbits, which should be at the maximum possible value + // to ensure that any gzip stream can be decoded. The offset of 32 specifies + // that the stream type should be automatically detected either gzip or + // deflate. + is_valid_ = inflateInit2(&strm, 32 + 15) == Z_OK; + } + + ~decompressor() { inflateEnd(&strm); } + + bool is_valid() const { return is_valid_; } + + template + bool decompress(const char *data, size_t data_length, T callback) { + int ret = Z_OK; + + strm.avail_in = static_cast(data_length); + strm.next_in = const_cast(reinterpret_cast(data)); + + std::array buff{}; + do { + strm.avail_out = buff.size(); + strm.next_out = reinterpret_cast(buff.data()); + + ret = inflate(&strm, Z_NO_FLUSH); + assert(ret != Z_STREAM_ERROR); + switch (ret) { + case Z_NEED_DICT: + case Z_DATA_ERROR: + case Z_MEM_ERROR: inflateEnd(&strm); return false; + } + + if (!callback(buff.data(), buff.size() - strm.avail_out)) { + return false; + } + } while (strm.avail_out == 0); + + return ret == Z_OK || ret == Z_STREAM_END; + } + +private: + bool is_valid_; + z_stream strm; +}; +#endif + +inline bool has_header(const Headers &headers, const char *key) { + return headers.find(key) != headers.end(); +} + +inline const char *get_header_value(const Headers &headers, const char *key, + size_t id = 0, const char *def = nullptr) { + auto it = headers.find(key); + std::advance(it, static_cast(id)); + if (it != headers.end()) { return it->second.c_str(); } + return def; +} + +inline uint64_t get_header_value_uint64(const Headers &headers, const char *key, + uint64_t def = 0) { + auto it = headers.find(key); + if (it != headers.end()) { + return std::strtoull(it->second.data(), nullptr, 10); + } + return def; +} + +inline bool read_headers(Stream &strm, Headers &headers) { + const auto bufsiz = 2048; + char buf[bufsiz]; + stream_line_reader line_reader(strm, buf, bufsiz); + + for (;;) { + if (!line_reader.getline()) { return false; } + + // Check if the line ends with CRLF. + if (line_reader.end_with_crlf()) { + // Blank line indicates end of headers. + if (line_reader.size() == 2) { break; } + } else { + continue; // Skip invalid line. + } + + // Skip trailing spaces and tabs. + auto end = line_reader.ptr() + line_reader.size() - 2; + while (line_reader.ptr() < end && (end[-1] == ' ' || end[-1] == '\t')) { + end--; + } + + // Horizontal tab and ' ' are considered whitespace and are ignored when on + // the left or right side of the header value: + // - https://stackoverflow.com/questions/50179659/ + // - https://www.w3.org/Protocols/rfc2616/rfc2616-sec4.html + static const std::regex re(R"(([^:]+):[\t ]*(.+))"); + + std::cmatch m; + if (std::regex_match(line_reader.ptr(), end, m, re)) { + auto key = std::string(m[1]); + auto val = std::string(m[2]); + headers.emplace(key, val); + } + } + + return true; +} + +inline bool read_content_with_length(Stream &strm, uint64_t len, + Progress progress, ContentReceiver out) { + char buf[CPPHTTPLIB_RECV_BUFSIZ]; + + uint64_t r = 0; + while (r < len) { + auto read_len = static_cast(len - r); + auto n = strm.read(buf, (std::min)(read_len, CPPHTTPLIB_RECV_BUFSIZ)); + if (n <= 0) { return false; } + + if (!out(buf, static_cast(n))) { return false; } + + r += static_cast(n); + + if (progress) { + if (!progress(r, len)) { return false; } + } + } + + return true; +} + +inline void skip_content_with_length(Stream &strm, uint64_t len) { + char buf[CPPHTTPLIB_RECV_BUFSIZ]; + uint64_t r = 0; + while (r < len) { + auto read_len = static_cast(len - r); + auto n = strm.read(buf, (std::min)(read_len, CPPHTTPLIB_RECV_BUFSIZ)); + if (n <= 0) { return; } + r += static_cast(n); + } +} + +inline bool read_content_without_length(Stream &strm, ContentReceiver out) { + char buf[CPPHTTPLIB_RECV_BUFSIZ]; + for (;;) { + auto n = strm.read(buf, CPPHTTPLIB_RECV_BUFSIZ); + if (n < 0) { + return false; + } else if (n == 0) { + return true; + } + if (!out(buf, static_cast(n))) { return false; } + } + + return true; +} + +inline bool read_content_chunked(Stream &strm, ContentReceiver out) { + const auto bufsiz = 16; + char buf[bufsiz]; + + stream_line_reader line_reader(strm, buf, bufsiz); + + if (!line_reader.getline()) { return false; } + + unsigned long chunk_len; + while (true) { + char *end_ptr; + + chunk_len = std::strtoul(line_reader.ptr(), &end_ptr, 16); + + if (end_ptr == line_reader.ptr()) { return false; } + if (chunk_len == ULONG_MAX) { return false; } + + if (chunk_len == 0) { break; } + + if (!read_content_with_length(strm, chunk_len, nullptr, out)) { + return false; + } + + if (!line_reader.getline()) { return false; } + + if (strcmp(line_reader.ptr(), "\r\n")) { break; } + + if (!line_reader.getline()) { return false; } + } + + if (chunk_len == 0) { + // Reader terminator after chunks + if (!line_reader.getline() || strcmp(line_reader.ptr(), "\r\n")) + return false; + } + + return true; +} + +inline bool is_chunked_transfer_encoding(const Headers &headers) { + return !strcasecmp(get_header_value(headers, "Transfer-Encoding", 0, ""), + "chunked"); +} + +template +bool read_content(Stream &strm, T &x, size_t payload_max_length, int &status, + Progress progress, ContentReceiver receiver) { + + ContentReceiver out = [&](const char *buf, size_t n) { + return receiver(buf, n); + }; + +#ifdef CPPHTTPLIB_ZLIB_SUPPORT + decompressor decompressor; + + std::string content_encoding = x.get_header_value("Content-Encoding"); + if (content_encoding.find("gzip") != std::string::npos || + content_encoding.find("deflate") != std::string::npos) { + if (!decompressor.is_valid()) { + status = 500; + return false; + } + + out = [&](const char *buf, size_t n) { + return decompressor.decompress( + buf, n, [&](const char *buf, size_t n) { return receiver(buf, n); }); + }; + } +#else + if (x.get_header_value("Content-Encoding") == "gzip") { + status = 415; + return false; + } +#endif + + auto ret = true; + auto exceed_payload_max_length = false; + + if (is_chunked_transfer_encoding(x.headers)) { + ret = read_content_chunked(strm, out); + } else if (!has_header(x.headers, "Content-Length")) { + ret = read_content_without_length(strm, out); + } else { + auto len = get_header_value_uint64(x.headers, "Content-Length", 0); + if (len > payload_max_length) { + exceed_payload_max_length = true; + skip_content_with_length(strm, len); + ret = false; + } else if (len > 0) { + ret = read_content_with_length(strm, len, progress, out); + } + } + + if (!ret) { status = exceed_payload_max_length ? 413 : 400; } + + return ret; +} + +template +inline ssize_t write_headers(Stream &strm, const T &info, + const Headers &headers) { + ssize_t write_len = 0; + for (const auto &x : info.headers) { + if (x.first == "EXCEPTION_WHAT") { continue; } + auto len = + strm.write_format("%s: %s\r\n", x.first.c_str(), x.second.c_str()); + if (len < 0) { return len; } + write_len += len; + } + for (const auto &x : headers) { + auto len = + strm.write_format("%s: %s\r\n", x.first.c_str(), x.second.c_str()); + if (len < 0) { return len; } + write_len += len; + } + auto len = strm.write("\r\n"); + if (len < 0) { return len; } + write_len += len; + return write_len; +} + +inline ssize_t write_content(Stream &strm, ContentProvider content_provider, + size_t offset, size_t length) { + size_t begin_offset = offset; + size_t end_offset = offset + length; + while (offset < end_offset) { + ssize_t written_length = 0; + + DataSink data_sink; + data_sink.write = [&](const char *d, size_t l) { + offset += l; + written_length = strm.write(d, l); + }; + data_sink.done = [&](void) { written_length = -1; }; + data_sink.is_writable = [&](void) { return strm.is_writable(); }; + + content_provider(offset, end_offset - offset, data_sink); + if (written_length < 0) { return written_length; } + } + return static_cast(offset - begin_offset); +} + +template +inline ssize_t write_content_chunked(Stream &strm, + ContentProvider content_provider, + T is_shutting_down) { + size_t offset = 0; + auto data_available = true; + ssize_t total_written_length = 0; + while (data_available && !is_shutting_down()) { + ssize_t written_length = 0; + + DataSink data_sink; + data_sink.write = [&](const char *d, size_t l) { + data_available = l > 0; + offset += l; + + // Emit chunked response header and footer for each chunk + auto chunk = from_i_to_hex(l) + "\r\n" + std::string(d, l) + "\r\n"; + written_length = strm.write(chunk); + }; + data_sink.done = [&](void) { + data_available = false; + written_length = strm.write("0\r\n\r\n"); + }; + data_sink.is_writable = [&](void) { return strm.is_writable(); }; + + content_provider(offset, 0, data_sink); + + if (written_length < 0) { return written_length; } + total_written_length += written_length; + } + return total_written_length; +} + +template +inline bool redirect(T &cli, const Request &req, Response &res, + const std::string &path) { + Request new_req = req; + new_req.path = path; + new_req.redirect_count -= 1; + + if (res.status == 303 && (req.method != "GET" && req.method != "HEAD")) { + new_req.method = "GET"; + new_req.body.clear(); + new_req.headers.clear(); + } + + Response new_res; + + auto ret = cli.send(new_req, new_res); + if (ret) { res = new_res; } + return ret; +} + +inline std::string encode_url(const std::string &s) { + std::string result; + + for (size_t i = 0; s[i]; i++) { + switch (s[i]) { + case ' ': result += "%20"; break; + case '+': result += "%2B"; break; + case '\r': result += "%0D"; break; + case '\n': result += "%0A"; break; + case '\'': result += "%27"; break; + case ',': result += "%2C"; break; + // case ':': result += "%3A"; break; // ok? probably... + case ';': result += "%3B"; break; + default: + auto c = static_cast(s[i]); + if (c >= 0x80) { + result += '%'; + char hex[4]; + auto len = snprintf(hex, sizeof(hex) - 1, "%02X", c); + assert(len == 2); + result.append(hex, static_cast(len)); + } else { + result += s[i]; + } + break; + } + } + + return result; +} + +inline std::string decode_url(const std::string &s, + bool convert_plus_to_space) { + std::string result; + + for (size_t i = 0; i < s.size(); i++) { + if (s[i] == '%' && i + 1 < s.size()) { + if (s[i + 1] == 'u') { + int val = 0; + if (from_hex_to_i(s, i + 2, 4, val)) { + // 4 digits Unicode codes + char buff[4]; + size_t len = to_utf8(val, buff); + if (len > 0) { result.append(buff, len); } + i += 5; // 'u0000' + } else { + result += s[i]; + } + } else { + int val = 0; + if (from_hex_to_i(s, i + 1, 2, val)) { + // 2 digits hex codes + result += static_cast(val); + i += 2; // '00' + } else { + result += s[i]; + } + } + } else if (convert_plus_to_space && s[i] == '+') { + result += ' '; + } else { + result += s[i]; + } + } + + return result; +} + +inline std::string params_to_query_str(const Params ¶ms) { + std::string query; + + for (auto it = params.begin(); it != params.end(); ++it) { + if (it != params.begin()) { query += "&"; } + query += it->first; + query += "="; + query += detail::encode_url(it->second); + } + + return query; +} + +inline void parse_query_text(const std::string &s, Params ¶ms) { + split(&s[0], &s[s.size()], '&', [&](const char *b, const char *e) { + std::string key; + std::string val; + split(b, e, '=', [&](const char *b2, const char *e2) { + if (key.empty()) { + key.assign(b2, e2); + } else { + val.assign(b2, e2); + } + }); + params.emplace(decode_url(key, true), decode_url(val, true)); + }); +} + +inline bool parse_multipart_boundary(const std::string &content_type, + std::string &boundary) { + auto pos = content_type.find("boundary="); + if (pos == std::string::npos) { return false; } + + boundary = content_type.substr(pos + 9); + return true; +} + +inline bool parse_range_header(const std::string &s, Ranges &ranges) { + static auto re_first_range = std::regex(R"(bytes=(\d*-\d*(?:,\s*\d*-\d*)*))"); + std::smatch m; + if (std::regex_match(s, m, re_first_range)) { + auto pos = static_cast(m.position(1)); + auto len = static_cast(m.length(1)); + bool all_valid_ranges = true; + split(&s[pos], &s[pos + len], ',', [&](const char *b, const char *e) { + if (!all_valid_ranges) return; + static auto re_another_range = std::regex(R"(\s*(\d*)-(\d*))"); + std::cmatch cm; + if (std::regex_match(b, e, cm, re_another_range)) { + ssize_t first = -1; + if (!cm.str(1).empty()) { + first = static_cast(std::stoll(cm.str(1))); + } + + ssize_t last = -1; + if (!cm.str(2).empty()) { + last = static_cast(std::stoll(cm.str(2))); + } + + if (first != -1 && last != -1 && first > last) { + all_valid_ranges = false; + return; + } + ranges.emplace_back(std::make_pair(first, last)); + } + }); + return all_valid_ranges; + } + return false; +} + +class MultipartFormDataParser { +public: + MultipartFormDataParser() = default; + + void set_boundary(std::string boundary) { boundary_ = std::move(boundary); } + + bool is_valid() const { return is_valid_; } + + template + bool parse(const char *buf, size_t n, T content_callback, U header_callback) { + static const std::regex re_content_type(R"(^Content-Type:\s*(.*?)\s*$)", + std::regex_constants::icase); + + static const std::regex re_content_disposition( + "^Content-Disposition:\\s*form-data;\\s*name=\"(.*?)\"(?:;\\s*filename=" + "\"(.*?)\")?\\s*$", + std::regex_constants::icase); + static const std::string dash_ = "--"; + static const std::string crlf_ = "\r\n"; + + buf_.append(buf, n); // TODO: performance improvement + + while (!buf_.empty()) { + switch (state_) { + case 0: { // Initial boundary + auto pattern = dash_ + boundary_ + crlf_; + if (pattern.size() > buf_.size()) { return true; } + auto pos = buf_.find(pattern); + if (pos != 0) { + is_done_ = true; + return false; + } + buf_.erase(0, pattern.size()); + off_ += pattern.size(); + state_ = 1; + break; + } + case 1: { // New entry + clear_file_info(); + state_ = 2; + break; + } + case 2: { // Headers + auto pos = buf_.find(crlf_); + while (pos != std::string::npos) { + // Empty line + if (pos == 0) { + if (!header_callback(file_)) { + is_valid_ = false; + is_done_ = false; + return false; + } + buf_.erase(0, crlf_.size()); + off_ += crlf_.size(); + state_ = 3; + break; + } + + auto header = buf_.substr(0, pos); + { + std::smatch m; + if (std::regex_match(header, m, re_content_type)) { + file_.content_type = m[1]; + } else if (std::regex_match(header, m, re_content_disposition)) { + file_.name = m[1]; + file_.filename = m[2]; + } + } + + buf_.erase(0, pos + crlf_.size()); + off_ += pos + crlf_.size(); + pos = buf_.find(crlf_); + } + break; + } + case 3: { // Body + { + auto pattern = crlf_ + dash_; + if (pattern.size() > buf_.size()) { return true; } + + auto pos = buf_.find(pattern); + if (pos == std::string::npos) { pos = buf_.size(); } + if (!content_callback(buf_.data(), pos)) { + is_valid_ = false; + is_done_ = false; + return false; + } + + off_ += pos; + buf_.erase(0, pos); + } + + { + auto pattern = crlf_ + dash_ + boundary_; + if (pattern.size() > buf_.size()) { return true; } + + auto pos = buf_.find(pattern); + if (pos != std::string::npos) { + if (!content_callback(buf_.data(), pos)) { + is_valid_ = false; + is_done_ = false; + return false; + } + + off_ += pos + pattern.size(); + buf_.erase(0, pos + pattern.size()); + state_ = 4; + } else { + if (!content_callback(buf_.data(), pattern.size())) { + is_valid_ = false; + is_done_ = false; + return false; + } + + off_ += pattern.size(); + buf_.erase(0, pattern.size()); + } + } + break; + } + case 4: { // Boundary + if (crlf_.size() > buf_.size()) { return true; } + if (buf_.find(crlf_) == 0) { + buf_.erase(0, crlf_.size()); + off_ += crlf_.size(); + state_ = 1; + } else { + auto pattern = dash_ + crlf_; + if (pattern.size() > buf_.size()) { return true; } + if (buf_.find(pattern) == 0) { + buf_.erase(0, pattern.size()); + off_ += pattern.size(); + is_valid_ = true; + state_ = 5; + } else { + is_done_ = true; + return true; + } + } + break; + } + case 5: { // Done + is_valid_ = false; + return false; + } + } + } + + return true; + } + +private: + void clear_file_info() { + file_.name.clear(); + file_.filename.clear(); + file_.content_type.clear(); + } + + std::string boundary_; + + std::string buf_; + size_t state_ = 0; + size_t is_valid_ = false; + size_t is_done_ = false; + size_t off_ = 0; + MultipartFormData file_; +}; + +inline std::string to_lower(const char *beg, const char *end) { + std::string out; + auto it = beg; + while (it != end) { + out += static_cast(::tolower(*it)); + it++; + } + return out; +} + +inline std::string make_multipart_data_boundary() { + static const char data[] = + "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"; + + std::random_device seed_gen; + std::mt19937 engine(seed_gen()); + + std::string result = "--cpp-httplib-multipart-data-"; + + for (auto i = 0; i < 16; i++) { + result += data[engine() % (sizeof(data) - 1)]; + } + + return result; +} + +inline std::pair +get_range_offset_and_length(const Request &req, size_t content_length, + size_t index) { + auto r = req.ranges[index]; + + if (r.first == -1 && r.second == -1) { + return std::make_pair(0, content_length); + } + + auto slen = static_cast(content_length); + + if (r.first == -1) { + r.first = slen - r.second; + r.second = slen - 1; + } + + if (r.second == -1) { r.second = slen - 1; } + + return std::make_pair(r.first, r.second - r.first + 1); +} + +inline std::string make_content_range_header_field(size_t offset, size_t length, + size_t content_length) { + std::string field = "bytes "; + field += std::to_string(offset); + field += "-"; + field += std::to_string(offset + length - 1); + field += "/"; + field += std::to_string(content_length); + return field; +} + +template +bool process_multipart_ranges_data(const Request &req, Response &res, + const std::string &boundary, + const std::string &content_type, + SToken stoken, CToken ctoken, + Content content) { + for (size_t i = 0; i < req.ranges.size(); i++) { + ctoken("--"); + stoken(boundary); + ctoken("\r\n"); + if (!content_type.empty()) { + ctoken("Content-Type: "); + stoken(content_type); + ctoken("\r\n"); + } + + auto offsets = get_range_offset_and_length(req, res.body.size(), i); + auto offset = offsets.first; + auto length = offsets.second; + + ctoken("Content-Range: "); + stoken(make_content_range_header_field(offset, length, res.body.size())); + ctoken("\r\n"); + ctoken("\r\n"); + if (!content(offset, length)) { return false; } + ctoken("\r\n"); + } + + ctoken("--"); + stoken(boundary); + ctoken("--\r\n"); + + return true; +} + +inline std::string make_multipart_ranges_data(const Request &req, Response &res, + const std::string &boundary, + const std::string &content_type) { + std::string data; + + process_multipart_ranges_data( + req, res, boundary, content_type, + [&](const std::string &token) { data += token; }, + [&](const char *token) { data += token; }, + [&](size_t offset, size_t length) { + data += res.body.substr(offset, length); + return true; + }); + + return data; +} + +inline size_t +get_multipart_ranges_data_length(const Request &req, Response &res, + const std::string &boundary, + const std::string &content_type) { + size_t data_length = 0; + + process_multipart_ranges_data( + req, res, boundary, content_type, + [&](const std::string &token) { data_length += token.size(); }, + [&](const char *token) { data_length += strlen(token); }, + [&](size_t /*offset*/, size_t length) { + data_length += length; + return true; + }); + + return data_length; +} + +inline bool write_multipart_ranges_data(Stream &strm, const Request &req, + Response &res, + const std::string &boundary, + const std::string &content_type) { + return process_multipart_ranges_data( + req, res, boundary, content_type, + [&](const std::string &token) { strm.write(token); }, + [&](const char *token) { strm.write(token); }, + [&](size_t offset, size_t length) { + return write_content(strm, res.content_provider, offset, length) >= 0; + }); +} + +inline std::pair +get_range_offset_and_length(const Request &req, const Response &res, + size_t index) { + auto r = req.ranges[index]; + + if (r.second == -1) { + r.second = static_cast(res.content_length) - 1; + } + + return std::make_pair(r.first, r.second - r.first + 1); +} + +inline bool expect_content(const Request &req) { + if (req.method == "POST" || req.method == "PUT" || req.method == "PATCH" || + req.method == "PRI" || req.method == "DELETE") { + return true; + } + // TODO: check if Content-Length is set + return false; +} + +inline bool has_crlf(const char *s) { + auto p = s; + while (*p) { + if (*p == '\r' || *p == '\n') { return true; } + p++; + } + return false; +} + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +template +inline std::string message_digest(const std::string &s, Init init, + Update update, Final final, + size_t digest_length) { + using namespace std; + + std::vector md(digest_length, 0); + CTX ctx; + init(&ctx); + update(&ctx, s.data(), s.size()); + final(md.data(), &ctx); + + stringstream ss; + for (auto c : md) { + ss << setfill('0') << setw(2) << hex << (unsigned int)c; + } + return ss.str(); +} + +inline std::string MD5(const std::string &s) { + return message_digest(s, MD5_Init, MD5_Update, MD5_Final, + MD5_DIGEST_LENGTH); +} + +inline std::string SHA_256(const std::string &s) { + return message_digest(s, SHA256_Init, SHA256_Update, SHA256_Final, + SHA256_DIGEST_LENGTH); +} + +inline std::string SHA_512(const std::string &s) { + return message_digest(s, SHA512_Init, SHA512_Update, SHA512_Final, + SHA512_DIGEST_LENGTH); +} +#endif + +template +inline ssize_t handle_EINTR(T fn) { + ssize_t res = false; + while (true) { + res = fn(); + if (res < 0 && errno == EINTR) { + continue; + } + break; + } + return res; +} + +#define HANDLE_EINTR(method, ...) (handle_EINTR([&]() { return method(__VA_ARGS__); })) + +#ifdef _WIN32 +class WSInit { +public: + WSInit() { + WSADATA wsaData; + WSAStartup(0x0002, &wsaData); + } + + ~WSInit() { WSACleanup(); } +}; + +static WSInit wsinit_; +#endif + +} // namespace detail + +// Header utilities +inline std::pair make_range_header(Ranges ranges) { + std::string field = "bytes="; + auto i = 0; + for (auto r : ranges) { + if (i != 0) { field += ", "; } + if (r.first != -1) { field += std::to_string(r.first); } + field += '-'; + if (r.second != -1) { field += std::to_string(r.second); } + i++; + } + return std::make_pair("Range", field); +} + +inline std::pair +make_basic_authentication_header(const std::string &username, + const std::string &password, + bool is_proxy = false) { + auto field = "Basic " + detail::base64_encode(username + ":" + password); + auto key = is_proxy ? "Proxy-Authorization" : "Authorization"; + return std::make_pair(key, field); +} + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +inline std::pair make_digest_authentication_header( + const Request &req, const std::map &auth, + size_t cnonce_count, const std::string &cnonce, const std::string &username, + const std::string &password, bool is_proxy = false) { + using namespace std; + + string nc; + { + stringstream ss; + ss << setfill('0') << setw(8) << hex << cnonce_count; + nc = ss.str(); + } + + auto qop = auth.at("qop"); + if (qop.find("auth-int") != std::string::npos) { + qop = "auth-int"; + } else { + qop = "auth"; + } + + std::string algo = "MD5"; + if (auth.find("algorithm") != auth.end()) { algo = auth.at("algorithm"); } + + string response; + { + auto H = algo == "SHA-256" + ? detail::SHA_256 + : algo == "SHA-512" ? detail::SHA_512 : detail::MD5; + + auto A1 = username + ":" + auth.at("realm") + ":" + password; + + auto A2 = req.method + ":" + req.path; + if (qop == "auth-int") { A2 += ":" + H(req.body); } + + response = H(H(A1) + ":" + auth.at("nonce") + ":" + nc + ":" + cnonce + + ":" + qop + ":" + H(A2)); + } + + auto field = "Digest username=\"hello\", realm=\"" + auth.at("realm") + + "\", nonce=\"" + auth.at("nonce") + "\", uri=\"" + req.path + + "\", algorithm=" + algo + ", qop=" + qop + ", nc=\"" + nc + + "\", cnonce=\"" + cnonce + "\", response=\"" + response + "\""; + + auto key = is_proxy ? "Proxy-Authorization" : "Authorization"; + return std::make_pair(key, field); +} +#endif + +inline bool parse_www_authenticate(const httplib::Response &res, + std::map &auth, + bool is_proxy) { + auto auth_key = is_proxy ? "Proxy-Authenticate" : "WWW-Authenticate"; + if (res.has_header(auth_key)) { + static auto re = std::regex(R"~((?:(?:,\s*)?(.+?)=(?:"(.*?)"|([^,]*))))~"); + auto s = res.get_header_value(auth_key); + auto pos = s.find(' '); + if (pos != std::string::npos) { + auto type = s.substr(0, pos); + if (type == "Basic") { + return false; + } else if (type == "Digest") { + s = s.substr(pos + 1); + auto beg = std::sregex_iterator(s.begin(), s.end(), re); + for (auto i = beg; i != std::sregex_iterator(); ++i) { + auto m = *i; + auto key = s.substr(static_cast(m.position(1)), + static_cast(m.length(1))); + auto val = m.length(2) > 0 + ? s.substr(static_cast(m.position(2)), + static_cast(m.length(2))) + : s.substr(static_cast(m.position(3)), + static_cast(m.length(3))); + auth[key] = val; + } + return true; + } + } + } + return false; +} + +// https://stackoverflow.com/questions/440133/how-do-i-create-a-random-alpha-numeric-string-in-c/440240#answer-440240 +inline std::string random_string(size_t length) { + auto randchar = []() -> char { + const char charset[] = "0123456789" + "ABCDEFGHIJKLMNOPQRSTUVWXYZ" + "abcdefghijklmnopqrstuvwxyz"; + const size_t max_index = (sizeof(charset) - 1); + return charset[static_cast(rand()) % max_index]; + }; + std::string str(length, 0); + std::generate_n(str.begin(), length, randchar); + return str; +} + +// Request implementation +inline bool Request::has_header(const char *key) const { + return detail::has_header(headers, key); +} + +inline std::string Request::get_header_value(const char *key, size_t id) const { + return detail::get_header_value(headers, key, id, ""); +} + +inline size_t Request::get_header_value_count(const char *key) const { + auto r = headers.equal_range(key); + return static_cast(std::distance(r.first, r.second)); +} + +inline void Request::set_header(const char *key, const char *val) { + if (!detail::has_crlf(key) && !detail::has_crlf(val)) { + headers.emplace(key, val); + } +} + +inline void Request::set_header(const char *key, const std::string &val) { + if (!detail::has_crlf(key) && !detail::has_crlf(val.c_str())) { + headers.emplace(key, val); + } +} + +inline bool Request::has_param(const char *key) const { + return params.find(key) != params.end(); +} + +inline std::string Request::get_param_value(const char *key, size_t id) const { + auto it = params.find(key); + std::advance(it, static_cast(id)); + if (it != params.end()) { return it->second; } + return std::string(); +} + +inline size_t Request::get_param_value_count(const char *key) const { + auto r = params.equal_range(key); + return static_cast(std::distance(r.first, r.second)); +} + +inline bool Request::is_multipart_form_data() const { + const auto &content_type = get_header_value("Content-Type"); + return !content_type.find("multipart/form-data"); +} + +inline bool Request::has_file(const char *key) const { + return files.find(key) != files.end(); +} + +inline MultipartFormData Request::get_file_value(const char *key) const { + auto it = files.find(key); + if (it != files.end()) { return it->second; } + return MultipartFormData(); +} + +// Response implementation +inline bool Response::has_header(const char *key) const { + return headers.find(key) != headers.end(); +} + +inline std::string Response::get_header_value(const char *key, + size_t id) const { + return detail::get_header_value(headers, key, id, ""); +} + +inline size_t Response::get_header_value_count(const char *key) const { + auto r = headers.equal_range(key); + return static_cast(std::distance(r.first, r.second)); +} + +inline void Response::set_header(const char *key, const char *val) { + if (!detail::has_crlf(key) && !detail::has_crlf(val)) { + headers.emplace(key, val); + } +} + +inline void Response::set_header(const char *key, const std::string &val) { + if (!detail::has_crlf(key) && !detail::has_crlf(val.c_str())) { + headers.emplace(key, val); + } +} + +inline void Response::set_redirect(const char *url, int status) { + if (!detail::has_crlf(url)) { + set_header("Location", url); + if (300 <= status && status < 400) { + this->status = status; + } else { + this->status = 302; + } + } +} + +inline void Response::set_content(const char *s, size_t n, + const char *content_type) { + body.assign(s, n); + set_header("Content-Type", content_type); +} + +inline void Response::set_content(std::string s, const char *content_type) { + body = std::move(s); + set_header("Content-Type", content_type); +} + +inline void Response::set_content_provider( + size_t in_length, + std::function provider, + std::function resource_releaser) { + assert(in_length > 0); + content_length = in_length; + content_provider = [provider](size_t offset, size_t length, DataSink &sink) { + provider(offset, length, sink); + }; + content_provider_resource_releaser = resource_releaser; +} + +inline void Response::set_chunked_content_provider( + std::function provider, + std::function resource_releaser) { + content_length = 0; + content_provider = [provider](size_t offset, size_t, DataSink &sink) { + provider(offset, sink); + }; + content_provider_resource_releaser = resource_releaser; +} + +// Rstream implementation +inline ssize_t Stream::write(const char *ptr) { + return write(ptr, strlen(ptr)); +} + +inline ssize_t Stream::write(const std::string &s) { + return write(s.data(), s.size()); +} + +template +inline ssize_t Stream::write_format(const char *fmt, const Args &... args) { + std::array buf; + +#if defined(_MSC_VER) && _MSC_VER < 1900 + auto sn = _snprintf_s(buf, bufsiz, buf.size() - 1, fmt, args...); +#else + auto sn = snprintf(buf.data(), buf.size() - 1, fmt, args...); +#endif + if (sn <= 0) { return sn; } + + auto n = static_cast(sn); + + if (n >= buf.size() - 1) { + std::vector glowable_buf(buf.size()); + + while (n >= glowable_buf.size() - 1) { + glowable_buf.resize(glowable_buf.size() * 2); +#if defined(_MSC_VER) && _MSC_VER < 1900 + n = static_cast(_snprintf_s(&glowable_buf[0], glowable_buf.size(), + glowable_buf.size() - 1, fmt, + args...)); +#else + n = static_cast( + snprintf(&glowable_buf[0], glowable_buf.size() - 1, fmt, args...)); +#endif + } + return write(&glowable_buf[0], n); + } else { + return write(buf.data(), n); + } +} + +namespace detail { + +// Socket stream implementation +inline SocketStream::SocketStream(socket_t sock, time_t read_timeout_sec, + time_t read_timeout_usec) + : sock_(sock), read_timeout_sec_(read_timeout_sec), + read_timeout_usec_(read_timeout_usec) {} + +inline SocketStream::~SocketStream() {} + +inline bool SocketStream::is_readable() const { + return select_read(sock_, read_timeout_sec_, read_timeout_usec_) > 0; +} + +inline bool SocketStream::is_writable() const { + return select_write(sock_, 0, 0) > 0; +} + +inline ssize_t SocketStream::read(char *ptr, size_t size) { + if (!is_readable()) { return -1; } + +#ifdef _WIN32 + if (size > static_cast(std::numeric_limits::max())) { + return -1; + } + return recv(sock_, ptr, static_cast(size), 0); +#else + return HANDLE_EINTR(recv, sock_, ptr, size, 0); +#endif +} + +inline ssize_t SocketStream::write(const char *ptr, size_t size) { + if (!is_writable()) { return -1; } + +#ifdef _WIN32 + if (size > static_cast(std::numeric_limits::max())) { + return -1; + } + return send(sock_, ptr, static_cast(size), 0); +#else + return send(sock_, ptr, size, 0); +#endif +} + +inline void SocketStream::get_remote_ip_and_port(std::string &ip, + int &port) const { + return detail::get_remote_ip_and_port(sock_, ip, port); +} + +// Buffer stream implementation +inline bool BufferStream::is_readable() const { return true; } + +inline bool BufferStream::is_writable() const { return true; } + +inline ssize_t BufferStream::read(char *ptr, size_t size) { +#if defined(_MSC_VER) && _MSC_VER < 1900 + auto len_read = buffer._Copy_s(ptr, size, size, position); +#else + auto len_read = buffer.copy(ptr, size, position); +#endif + position += static_cast(len_read); + return static_cast(len_read); +} + +inline ssize_t BufferStream::write(const char *ptr, size_t size) { + buffer.append(ptr, size); + return static_cast(size); +} + +inline void BufferStream::get_remote_ip_and_port(std::string & /*ip*/, + int & /*port*/) const {} + +inline const std::string &BufferStream::get_buffer() const { return buffer; } + +} // namespace detail + +// HTTP server implementation +inline Server::Server() + : keep_alive_max_count_(CPPHTTPLIB_KEEPALIVE_MAX_COUNT), + read_timeout_sec_(CPPHTTPLIB_READ_TIMEOUT_SECOND), + read_timeout_usec_(CPPHTTPLIB_READ_TIMEOUT_USECOND), + payload_max_length_(CPPHTTPLIB_PAYLOAD_MAX_LENGTH), is_running_(false), + svr_sock_(INVALID_SOCKET) { +#ifndef _WIN32 + signal(SIGPIPE, SIG_IGN); +#endif + new_task_queue = [] { return new ThreadPool(CPPHTTPLIB_THREAD_POOL_COUNT); }; +} + +inline Server::~Server() {} + +inline Server &Server::Get(const char *pattern, Handler handler) { + get_handlers_.push_back(std::make_pair(std::regex(pattern), handler)); + return *this; +} + +inline Server &Server::Post(const char *pattern, Handler handler) { + post_handlers_.push_back(std::make_pair(std::regex(pattern), handler)); + return *this; +} + +inline Server &Server::Post(const char *pattern, + HandlerWithContentReader handler) { + post_handlers_for_content_reader_.push_back( + std::make_pair(std::regex(pattern), handler)); + return *this; +} + +inline Server &Server::Put(const char *pattern, Handler handler) { + put_handlers_.push_back(std::make_pair(std::regex(pattern), handler)); + return *this; +} + +inline Server &Server::Put(const char *pattern, + HandlerWithContentReader handler) { + put_handlers_for_content_reader_.push_back( + std::make_pair(std::regex(pattern), handler)); + return *this; +} + +inline Server &Server::Patch(const char *pattern, Handler handler) { + patch_handlers_.push_back(std::make_pair(std::regex(pattern), handler)); + return *this; +} + +inline Server &Server::Patch(const char *pattern, + HandlerWithContentReader handler) { + patch_handlers_for_content_reader_.push_back( + std::make_pair(std::regex(pattern), handler)); + return *this; +} + +inline Server &Server::Delete(const char *pattern, Handler handler) { + delete_handlers_.push_back(std::make_pair(std::regex(pattern), handler)); + return *this; +} + +inline Server &Server::Delete(const char *pattern, + HandlerWithContentReader handler) { + delete_handlers_for_content_reader_.push_back( + std::make_pair(std::regex(pattern), handler)); + return *this; +} + +inline Server &Server::Options(const char *pattern, Handler handler) { + options_handlers_.push_back(std::make_pair(std::regex(pattern), handler)); + return *this; +} + +inline bool Server::set_base_dir(const char *dir, const char *mount_point) { + return set_mount_point(mount_point, dir); +} + +inline bool Server::set_mount_point(const char *mount_point, const char *dir) { + if (detail::is_dir(dir)) { + std::string mnt = mount_point ? mount_point : "/"; + if (!mnt.empty() && mnt[0] == '/') { + base_dirs_.emplace_back(mnt, dir); + return true; + } + } + return false; +} + +inline bool Server::remove_mount_point(const char *mount_point) { + for (auto it = base_dirs_.begin(); it != base_dirs_.end(); ++it) { + if (it->first == mount_point) { + base_dirs_.erase(it); + return true; + } + } + return false; +} + +inline void Server::set_file_extension_and_mimetype_mapping(const char *ext, + const char *mime) { + file_extension_and_mimetype_map_[ext] = mime; +} + +inline void Server::set_file_request_handler(Handler handler) { + file_request_handler_ = std::move(handler); +} + +inline void Server::set_error_handler(Handler handler) { + error_handler_ = std::move(handler); +} + +inline void Server::set_logger(Logger logger) { logger_ = std::move(logger); } + +inline void +Server::set_expect_100_continue_handler(Expect100ContinueHandler handler) { + expect_100_continue_handler_ = std::move(handler); +} + +inline void Server::set_keep_alive_max_count(size_t count) { + keep_alive_max_count_ = count; +} + +inline void Server::set_read_timeout(time_t sec, time_t usec) { + read_timeout_sec_ = sec; + read_timeout_usec_ = usec; +} + +inline void Server::set_payload_max_length(size_t length) { + payload_max_length_ = length; +} + +inline bool Server::bind_to_port(const char *host, int port, int socket_flags) { + if (bind_internal(host, port, socket_flags) < 0) return false; + return true; +} +inline int Server::bind_to_any_port(const char *host, int socket_flags) { + return bind_internal(host, 0, socket_flags); +} + +inline bool Server::listen_after_bind() { return listen_internal(); } + +inline bool Server::listen(const char *host, int port, int socket_flags) { + return bind_to_port(host, port, socket_flags) && listen_internal(); +} + +inline bool Server::is_running() const { return is_running_; } + +inline void Server::stop() { + if (is_running_) { + assert(svr_sock_ != INVALID_SOCKET); + std::atomic sock(svr_sock_.exchange(INVALID_SOCKET)); + detail::shutdown_socket(sock); + detail::close_socket(sock); + } +} + +inline bool Server::parse_request_line(const char *s, Request &req) { + const static std::regex re( + "(GET|HEAD|POST|PUT|DELETE|CONNECT|OPTIONS|TRACE|PATCH|PRI) " + "(([^?]+)(?:\\?(.*?))?) (HTTP/1\\.[01])\r\n"); + + std::cmatch m; + if (std::regex_match(s, m, re)) { + req.version = std::string(m[5]); + req.method = std::string(m[1]); + req.target = std::string(m[2]); + req.path = detail::decode_url(m[3], false); + + // Parse query text + auto len = std::distance(m[4].first, m[4].second); + if (len > 0) { detail::parse_query_text(m[4], req.params); } + + return true; + } + + return false; +} + +inline bool Server::write_response(Stream &strm, bool last_connection, + const Request &req, Response &res) { + assert(res.status != -1); + + if (400 <= res.status && error_handler_) { error_handler_(req, res); } + + detail::BufferStream bstrm; + + // Response line + if (!bstrm.write_format("HTTP/1.1 %d %s\r\n", res.status, + detail::status_message(res.status))) { + return false; + } + + // Headers + if (last_connection || req.get_header_value("Connection") == "close") { + res.set_header("Connection", "close"); + } + + if (!last_connection && req.get_header_value("Connection") == "Keep-Alive") { + res.set_header("Connection", "Keep-Alive"); + } + + if (!res.has_header("Content-Type") && + (!res.body.empty() || res.content_length > 0)) { + res.set_header("Content-Type", "text/plain"); + } + + if (!res.has_header("Accept-Ranges") && req.method == "HEAD") { + res.set_header("Accept-Ranges", "bytes"); + } + + std::string content_type; + std::string boundary; + + if (req.ranges.size() > 1) { + boundary = detail::make_multipart_data_boundary(); + + auto it = res.headers.find("Content-Type"); + if (it != res.headers.end()) { + content_type = it->second; + res.headers.erase(it); + } + + res.headers.emplace("Content-Type", + "multipart/byteranges; boundary=" + boundary); + } + + if (res.body.empty()) { + if (res.content_length > 0) { + size_t length = 0; + if (req.ranges.empty()) { + length = res.content_length; + } else if (req.ranges.size() == 1) { + auto offsets = + detail::get_range_offset_and_length(req, res.content_length, 0); + auto offset = offsets.first; + length = offsets.second; + auto content_range = detail::make_content_range_header_field( + offset, length, res.content_length); + res.set_header("Content-Range", content_range); + } else { + length = detail::get_multipart_ranges_data_length(req, res, boundary, + content_type); + } + res.set_header("Content-Length", std::to_string(length)); + } else { + if (res.content_provider) { + res.set_header("Transfer-Encoding", "chunked"); + } else { + res.set_header("Content-Length", "0"); + } + } + } else { + if (req.ranges.empty()) { + ; + } else if (req.ranges.size() == 1) { + auto offsets = + detail::get_range_offset_and_length(req, res.body.size(), 0); + auto offset = offsets.first; + auto length = offsets.second; + auto content_range = detail::make_content_range_header_field( + offset, length, res.body.size()); + res.set_header("Content-Range", content_range); + res.body = res.body.substr(offset, length); + } else { + res.body = + detail::make_multipart_ranges_data(req, res, boundary, content_type); + } + +#ifdef CPPHTTPLIB_ZLIB_SUPPORT + // TODO: 'Accept-Encoding' has gzip, not gzip;q=0 + const auto &encodings = req.get_header_value("Accept-Encoding"); + if (encodings.find("gzip") != std::string::npos && + detail::can_compress(res.get_header_value("Content-Type"))) { + if (detail::compress(res.body)) { + res.set_header("Content-Encoding", "gzip"); + } + } +#endif + + auto length = std::to_string(res.body.size()); + res.set_header("Content-Length", length); + } + + if (!detail::write_headers(bstrm, res, Headers())) { return false; } + + // Flush buffer + auto &data = bstrm.get_buffer(); + strm.write(data.data(), data.size()); + + // Body + if (req.method != "HEAD") { + if (!res.body.empty()) { + if (!strm.write(res.body)) { return false; } + } else if (res.content_provider) { + if (!write_content_with_provider(strm, req, res, boundary, + content_type)) { + return false; + } + } + } + + // Log + if (logger_) { logger_(req, res); } + + return true; +} + +inline bool +Server::write_content_with_provider(Stream &strm, const Request &req, + Response &res, const std::string &boundary, + const std::string &content_type) { + if (res.content_length) { + if (req.ranges.empty()) { + if (detail::write_content(strm, res.content_provider, 0, + res.content_length) < 0) { + return false; + } + } else if (req.ranges.size() == 1) { + auto offsets = + detail::get_range_offset_and_length(req, res.content_length, 0); + auto offset = offsets.first; + auto length = offsets.second; + if (detail::write_content(strm, res.content_provider, offset, length) < + 0) { + return false; + } + } else { + if (!detail::write_multipart_ranges_data(strm, req, res, boundary, + content_type)) { + return false; + } + } + } else { + auto is_shutting_down = [this]() { + return this->svr_sock_ == INVALID_SOCKET; + }; + if (detail::write_content_chunked(strm, res.content_provider, + is_shutting_down) < 0) { + return false; + } + } + return true; +} + +inline bool Server::read_content(Stream &strm, Request &req, Response &res) { + MultipartFormDataMap::iterator cur; + if (read_content_core( + strm, req, res, + // Regular + [&](const char *buf, size_t n) { + if (req.body.size() + n > req.body.max_size()) { return false; } + req.body.append(buf, n); + return true; + }, + // Multipart + [&](const MultipartFormData &file) { + cur = req.files.emplace(file.name, file); + return true; + }, + [&](const char *buf, size_t n) { + auto &content = cur->second.content; + if (content.size() + n > content.max_size()) { return false; } + content.append(buf, n); + return true; + })) { + const auto &content_type = req.get_header_value("Content-Type"); + if (!content_type.find("application/x-www-form-urlencoded")) { + detail::parse_query_text(req.body, req.params); + } + return true; + } + return false; +} + +inline bool Server::read_content_with_content_receiver( + Stream &strm, Request &req, Response &res, ContentReceiver receiver, + MultipartContentHeader multipart_header, + ContentReceiver multipart_receiver) { + return read_content_core(strm, req, res, receiver, multipart_header, + multipart_receiver); +} + +inline bool Server::read_content_core(Stream &strm, Request &req, Response &res, + ContentReceiver receiver, + MultipartContentHeader mulitpart_header, + ContentReceiver multipart_receiver) { + detail::MultipartFormDataParser multipart_form_data_parser; + ContentReceiver out; + + if (req.is_multipart_form_data()) { + const auto &content_type = req.get_header_value("Content-Type"); + std::string boundary; + if (!detail::parse_multipart_boundary(content_type, boundary)) { + res.status = 400; + return false; + } + + multipart_form_data_parser.set_boundary(std::move(boundary)); + out = [&](const char *buf, size_t n) { + return multipart_form_data_parser.parse(buf, n, multipart_receiver, + mulitpart_header); + }; + } else { + out = receiver; + } + + if (!detail::read_content(strm, req, payload_max_length_, res.status, + Progress(), out)) { + return false; + } + + if (req.is_multipart_form_data()) { + if (!multipart_form_data_parser.is_valid()) { + res.status = 400; + return false; + } + } + + return true; +} + +inline bool Server::handle_file_request(Request &req, Response &res, + bool head) { + for (const auto &kv : base_dirs_) { + const auto &mount_point = kv.first; + const auto &base_dir = kv.second; + + // Prefix match + if (!req.path.find(mount_point)) { + std::string sub_path = "/" + req.path.substr(mount_point.size()); + if (detail::is_valid_path(sub_path)) { + auto path = base_dir + sub_path; + if (path.back() == '/') { path += "index.html"; } + + if (detail::is_file(path)) { + detail::read_file(path, res.body); + auto type = + detail::find_content_type(path, file_extension_and_mimetype_map_); + if (type) { res.set_header("Content-Type", type); } + res.status = 200; + if (!head && file_request_handler_) { + file_request_handler_(req, res); + } + return true; + } + } + } + } + return false; +} + +inline socket_t Server::create_server_socket(const char *host, int port, + int socket_flags) const { + return detail::create_socket( + host, port, + [](socket_t sock, struct addrinfo &ai) -> bool { + if (::bind(sock, ai.ai_addr, static_cast(ai.ai_addrlen))) { + return false; + } + if (::listen(sock, 5)) { // Listen through 5 channels + return false; + } + return true; + }, + socket_flags); +} + +inline int Server::bind_internal(const char *host, int port, int socket_flags) { + if (!is_valid()) { return -1; } + + svr_sock_ = create_server_socket(host, port, socket_flags); + if (svr_sock_ == INVALID_SOCKET) { return -1; } + + if (port == 0) { + struct sockaddr_storage addr; + socklen_t addr_len = sizeof(addr); + if (getsockname(svr_sock_, reinterpret_cast(&addr), + &addr_len) == -1) { + return -1; + } + if (addr.ss_family == AF_INET) { + return ntohs(reinterpret_cast(&addr)->sin_port); + } else if (addr.ss_family == AF_INET6) { + return ntohs(reinterpret_cast(&addr)->sin6_port); + } else { + return -1; + } + } else { + return port; + } +} + +inline bool Server::listen_internal() { + auto ret = true; + is_running_ = true; + + { + std::unique_ptr task_queue(new_task_queue()); + + for (;;) { + if (svr_sock_ == INVALID_SOCKET) { + // The server socket was closed by 'stop' method. + break; + } + + auto val = detail::select_read(svr_sock_, 0, 100000); + + if (val == 0) { // Timeout + task_queue->on_idle(); + continue; + } + + socket_t sock = accept(svr_sock_, nullptr, nullptr); + + if (sock == INVALID_SOCKET) { + if (errno == EMFILE) { + // The per-process limit of open file descriptors has been reached. + // Try to accept new connections after a short sleep. + std::this_thread::sleep_for(std::chrono::milliseconds(1)); + continue; + } + if (svr_sock_ != INVALID_SOCKET) { + detail::close_socket(svr_sock_); + ret = false; + } else { + ; // The server socket was closed by user. + } + break; + } + +#if __cplusplus > 201703L + task_queue->enqueue([=, this]() { process_and_close_socket(sock); }); +#else + task_queue->enqueue([=]() { process_and_close_socket(sock); }); +#endif + } + + task_queue->shutdown(); + } + + is_running_ = false; + return ret; +} + +inline bool Server::routing(Request &req, Response &res, Stream &strm) { + // File handler + bool is_head_request = req.method == "HEAD"; + if ((req.method == "GET" || is_head_request) && + handle_file_request(req, res, is_head_request)) { + return true; + } + + if (detail::expect_content(req)) { + // Content reader handler + { + ContentReader reader( + [&](ContentReceiver receiver) { + return read_content_with_content_receiver(strm, req, res, receiver, + nullptr, nullptr); + }, + [&](MultipartContentHeader header, ContentReceiver receiver) { + return read_content_with_content_receiver(strm, req, res, nullptr, + header, receiver); + }); + + if (req.method == "POST") { + if (dispatch_request_for_content_reader( + req, res, reader, post_handlers_for_content_reader_)) { + return true; + } + } else if (req.method == "PUT") { + if (dispatch_request_for_content_reader( + req, res, reader, put_handlers_for_content_reader_)) { + return true; + } + } else if (req.method == "PATCH") { + if (dispatch_request_for_content_reader( + req, res, reader, patch_handlers_for_content_reader_)) { + return true; + } + } else if (req.method == "DELETE") { + if (dispatch_request_for_content_reader( + req, res, reader, delete_handlers_for_content_reader_)) { + return true; + } + } + } + + // Read content into `req.body` + if (!read_content(strm, req, res)) { return false; } + } + + // Regular handler + if (req.method == "GET" || req.method == "HEAD") { + return dispatch_request(req, res, get_handlers_); + } else if (req.method == "POST") { + return dispatch_request(req, res, post_handlers_); + } else if (req.method == "PUT") { + return dispatch_request(req, res, put_handlers_); + } else if (req.method == "DELETE") { + return dispatch_request(req, res, delete_handlers_); + } else if (req.method == "OPTIONS") { + return dispatch_request(req, res, options_handlers_); + } else if (req.method == "PATCH") { + return dispatch_request(req, res, patch_handlers_); + } + + res.status = 400; + return false; +} + +inline bool Server::dispatch_request(Request &req, Response &res, + Handlers &handlers) { + + try { + for (const auto &x : handlers) { + const auto &pattern = x.first; + const auto &handler = x.second; + + if (std::regex_match(req.path, req.matches, pattern)) { + handler(req, res); + return true; + } + } + } catch (const std::exception &ex) { + res.status = 500; + res.set_header("EXCEPTION_WHAT", ex.what()); + } catch (...) { + res.status = 500; + res.set_header("EXCEPTION_WHAT", "UNKNOWN"); + } + return false; +} + +inline bool Server::dispatch_request_for_content_reader( + Request &req, Response &res, ContentReader content_reader, + HandlersForContentReader &handlers) { + for (const auto &x : handlers) { + const auto &pattern = x.first; + const auto &handler = x.second; + + if (std::regex_match(req.path, req.matches, pattern)) { + handler(req, res, content_reader); + return true; + } + } + return false; +} + +inline bool +Server::process_request(Stream &strm, bool last_connection, + bool &connection_close, + const std::function &setup_request) { + std::array buf{}; + + detail::stream_line_reader line_reader(strm, buf.data(), buf.size()); + + // Connection has been closed on client + if (!line_reader.getline()) { return false; } + + Request req; + Response res; + + res.version = "HTTP/1.1"; + + // Check if the request URI doesn't exceed the limit + if (line_reader.size() > CPPHTTPLIB_REQUEST_URI_MAX_LENGTH) { + Headers dummy; + detail::read_headers(strm, dummy); + res.status = 414; + return write_response(strm, last_connection, req, res); + } + + // Request line and headers + if (!parse_request_line(line_reader.ptr(), req) || + !detail::read_headers(strm, req.headers)) { + res.status = 400; + return write_response(strm, last_connection, req, res); + } + + if (req.get_header_value("Connection") == "close") { + connection_close = true; + } + + if (req.version == "HTTP/1.0" && + req.get_header_value("Connection") != "Keep-Alive") { + connection_close = true; + } + + strm.get_remote_ip_and_port(req.remote_addr, req.remote_port); + req.set_header("REMOTE_ADDR", req.remote_addr); + req.set_header("REMOTE_PORT", std::to_string(req.remote_port)); + + if (req.has_header("Range")) { + const auto &range_header_value = req.get_header_value("Range"); + if (!detail::parse_range_header(range_header_value, req.ranges)) { + // TODO: error + } + } + + if (setup_request) { setup_request(req); } + + if (req.get_header_value("Expect") == "100-continue") { + auto status = 100; + if (expect_100_continue_handler_) { + status = expect_100_continue_handler_(req, res); + } + switch (status) { + case 100: + case 417: + strm.write_format("HTTP/1.1 %d %s\r\n\r\n", status, + detail::status_message(status)); + break; + default: return write_response(strm, last_connection, req, res); + } + } + + // Rounting + if (routing(req, res, strm)) { + if (res.status == -1) { res.status = req.ranges.empty() ? 200 : 206; } + } else { + if (res.status == -1) { res.status = 404; } + } + + return write_response(strm, last_connection, req, res); +} + +inline bool Server::is_valid() const { return true; } + +inline bool Server::process_and_close_socket(socket_t sock) { + return detail::process_and_close_socket( + false, sock, keep_alive_max_count_, read_timeout_sec_, read_timeout_usec_, + [this](Stream &strm, bool last_connection, bool &connection_close) { + return process_request(strm, last_connection, connection_close, + nullptr); + }); +} + +// HTTP client implementation +inline Client::Client(const std::string &host, int port, + const std::string &client_cert_path, + const std::string &client_key_path) + : sock_(INVALID_SOCKET), host_(host), port_(port), + host_and_port_(host_ + ":" + std::to_string(port_)), + client_cert_path_(client_cert_path), client_key_path_(client_key_path) {} + +inline Client::~Client() { stop(); } + +inline bool Client::is_valid() const { return true; } + +inline socket_t Client::create_client_socket() const { + if (!proxy_host_.empty()) { + return detail::create_client_socket(proxy_host_.c_str(), proxy_port_, + timeout_sec_, interface_); + } + return detail::create_client_socket(host_.c_str(), port_, timeout_sec_, + interface_); +} + +inline bool Client::read_response_line(Stream &strm, Response &res) { + std::array buf; + + detail::stream_line_reader line_reader(strm, buf.data(), buf.size()); + + if (!line_reader.getline()) { return false; } + + const static std::regex re("(HTTP/1\\.[01]) (\\d+?) .*\r\n"); + + std::cmatch m; + if (std::regex_match(line_reader.ptr(), m, re)) { + res.version = std::string(m[1]); + res.status = std::stoi(std::string(m[2])); + } + + return true; +} + +inline bool Client::send(const Request &req, Response &res) { + sock_ = create_client_socket(); + if (sock_ == INVALID_SOCKET) { return false; } + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + if (is_ssl() && !proxy_host_.empty()) { + bool error; + if (!connect(sock_, res, error)) { return error; } + } +#endif + + return process_and_close_socket( + sock_, 1, + [&](Stream &strm, bool last_connection, bool &connection_close) { + return handle_request(strm, req, res, last_connection, + connection_close); + }); +} + +inline bool Client::send(const std::vector &requests, + std::vector &responses) { + size_t i = 0; + while (i < requests.size()) { + sock_ = create_client_socket(); + if (sock_ == INVALID_SOCKET) { return false; } + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + if (is_ssl() && !proxy_host_.empty()) { + Response res; + bool error; + if (!connect(sock_, res, error)) { return false; } + } +#endif + + if (!process_and_close_socket(sock_, requests.size() - i, + [&](Stream &strm, bool last_connection, + bool &connection_close) -> bool { + auto &req = requests[i++]; + auto res = Response(); + auto ret = handle_request(strm, req, res, + last_connection, + connection_close); + if (ret) { + responses.emplace_back(std::move(res)); + } + return ret; + })) { + return false; + } + } + + return true; +} + +inline bool Client::handle_request(Stream &strm, const Request &req, + Response &res, bool last_connection, + bool &connection_close) { + if (req.path.empty()) { return false; } + + bool ret; + + if (!is_ssl() && !proxy_host_.empty()) { + auto req2 = req; + req2.path = "http://" + host_and_port_ + req.path; + ret = process_request(strm, req2, res, last_connection, connection_close); + } else { + ret = process_request(strm, req, res, last_connection, connection_close); + } + + if (!ret) { return false; } + + if (300 < res.status && res.status < 400 && follow_location_) { + ret = redirect(req, res); + } + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + if (res.status == 401 || res.status == 407) { + auto is_proxy = res.status == 407; + const auto &username = + is_proxy ? proxy_digest_auth_username_ : digest_auth_username_; + const auto &password = + is_proxy ? proxy_digest_auth_password_ : digest_auth_password_; + + if (!username.empty() && !password.empty()) { + std::map auth; + if (parse_www_authenticate(res, auth, is_proxy)) { + Request new_req = req; + auto key = is_proxy ? "Proxy-Authorization" : "WWW-Authorization"; + new_req.headers.erase(key); + new_req.headers.insert(make_digest_authentication_header( + req, auth, 1, random_string(10), username, password, is_proxy)); + + Response new_res; + + ret = send(new_req, new_res); + if (ret) { res = new_res; } + } + } + } +#endif + + return ret; +} + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +inline bool Client::connect(socket_t sock, Response &res, bool &error) { + error = true; + Response res2; + + if (!detail::process_socket( + true, sock, 1, read_timeout_sec_, read_timeout_usec_, + [&](Stream &strm, bool /*last_connection*/, bool &connection_close) { + Request req2; + req2.method = "CONNECT"; + req2.path = host_and_port_; + return process_request(strm, req2, res2, false, connection_close); + })) { + detail::close_socket(sock); + error = false; + return false; + } + + if (res2.status == 407) { + if (!proxy_digest_auth_username_.empty() && + !proxy_digest_auth_password_.empty()) { + std::map auth; + if (parse_www_authenticate(res2, auth, true)) { + Response res3; + if (!detail::process_socket( + true, sock, 1, read_timeout_sec_, read_timeout_usec_, + [&](Stream &strm, bool /*last_connection*/, + bool &connection_close) { + Request req3; + req3.method = "CONNECT"; + req3.path = host_and_port_; + req3.headers.insert(make_digest_authentication_header( + req3, auth, 1, random_string(10), + proxy_digest_auth_username_, proxy_digest_auth_password_, + true)); + return process_request(strm, req3, res3, false, + connection_close); + })) { + detail::close_socket(sock); + error = false; + return false; + } + } + } else { + res = res2; + return false; + } + } + + return true; +} +#endif + +inline bool Client::redirect(const Request &req, Response &res) { + if (req.redirect_count == 0) { return false; } + + auto location = res.get_header_value("location"); + if (location.empty()) { return false; } + + const static std::regex re( + R"(^(?:(https?):)?(?://([^:/?#]*)(?::(\d+))?)?([^?#]*(?:\?[^#]*)?)(?:#.*)?)"); + + std::smatch m; + if (!std::regex_match(location, m, re)) { return false; } + + auto scheme = is_ssl() ? "https" : "http"; + + auto next_scheme = m[1].str(); + auto next_host = m[2].str(); + auto port_str = m[3].str(); + auto next_path = m[4].str(); + + auto next_port = port_; + if (!port_str.empty()) { + next_port = std::stoi(port_str); + } else if (!next_scheme.empty()) { + next_port = next_scheme == "https" ? 443 : 80; + } + + if (next_scheme.empty()) { next_scheme = scheme; } + if (next_host.empty()) { next_host = host_; } + if (next_path.empty()) { next_path = "/"; } + + if (next_scheme == scheme && next_host == host_ && next_port == port_) { + return detail::redirect(*this, req, res, next_path); + } else { + if (next_scheme == "https") { +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + SSLClient cli(next_host.c_str(), next_port); + cli.copy_settings(*this); + return detail::redirect(cli, req, res, next_path); +#else + return false; +#endif + } else { + Client cli(next_host.c_str(), next_port); + cli.copy_settings(*this); + return detail::redirect(cli, req, res, next_path); + } + } +} + +inline bool Client::write_request(Stream &strm, const Request &req, + bool last_connection) { + detail::BufferStream bstrm; + + // Request line + const auto &path = detail::encode_url(req.path); + + bstrm.write_format("%s %s HTTP/1.1\r\n", req.method.c_str(), path.c_str()); + + // Additonal headers + Headers headers; + if (last_connection) { headers.emplace("Connection", "close"); } + + if (!req.has_header("Host")) { + if (is_ssl()) { + if (port_ == 443) { + headers.emplace("Host", host_); + } else { + headers.emplace("Host", host_and_port_); + } + } else { + if (port_ == 80) { + headers.emplace("Host", host_); + } else { + headers.emplace("Host", host_and_port_); + } + } + } + + if (!req.has_header("Accept")) { headers.emplace("Accept", "*/*"); } + + if (!req.has_header("User-Agent")) { + headers.emplace("User-Agent", "cpp-httplib/0.5"); + } + + if (req.body.empty()) { + if (req.content_provider) { + auto length = std::to_string(req.content_length); + headers.emplace("Content-Length", length); + } else { + headers.emplace("Content-Length", "0"); + } + } else { + if (!req.has_header("Content-Type")) { + headers.emplace("Content-Type", "text/plain"); + } + + if (!req.has_header("Content-Length")) { + auto length = std::to_string(req.body.size()); + headers.emplace("Content-Length", length); + } + } + + if (!basic_auth_username_.empty() && !basic_auth_password_.empty()) { + headers.insert(make_basic_authentication_header( + basic_auth_username_, basic_auth_password_, false)); + } + + if (!proxy_basic_auth_username_.empty() && + !proxy_basic_auth_password_.empty()) { + headers.insert(make_basic_authentication_header( + proxy_basic_auth_username_, proxy_basic_auth_password_, true)); + } + + detail::write_headers(bstrm, req, headers); + + // Flush buffer + auto &data = bstrm.get_buffer(); + strm.write(data.data(), data.size()); + + // Body + if (req.body.empty()) { + if (req.content_provider) { + size_t offset = 0; + size_t end_offset = req.content_length; + + DataSink data_sink; + data_sink.write = [&](const char *d, size_t l) { + auto written_length = strm.write(d, l); + offset += static_cast(written_length); + }; + data_sink.is_writable = [&](void) { return strm.is_writable(); }; + + while (offset < end_offset) { + req.content_provider(offset, end_offset - offset, data_sink); + } + } + } else { + strm.write(req.body); + } + + return true; +} + +inline std::shared_ptr Client::send_with_content_provider( + const char *method, const char *path, const Headers &headers, + const std::string &body, size_t content_length, + ContentProvider content_provider, const char *content_type) { + Request req; + req.method = method; + req.headers = headers; + req.path = path; + + if (content_type) { req.headers.emplace("Content-Type", content_type); } + +#ifdef CPPHTTPLIB_ZLIB_SUPPORT + if (compress_) { + if (content_provider) { + size_t offset = 0; + + DataSink data_sink; + data_sink.write = [&](const char *data, size_t data_len) { + req.body.append(data, data_len); + offset += data_len; + }; + data_sink.is_writable = [&](void) { return true; }; + + while (offset < content_length) { + content_provider(offset, content_length - offset, data_sink); + } + } else { + req.body = body; + } + + if (!detail::compress(req.body)) { return nullptr; } + req.headers.emplace("Content-Encoding", "gzip"); + } else +#endif + { + if (content_provider) { + req.content_length = content_length; + req.content_provider = content_provider; + } else { + req.body = body; + } + } + + auto res = std::make_shared(); + + return send(req, *res) ? res : nullptr; +} + +inline bool Client::process_request(Stream &strm, const Request &req, + Response &res, bool last_connection, + bool &connection_close) { + // Send request + if (!write_request(strm, req, last_connection)) { return false; } + + // Receive response and headers + if (!read_response_line(strm, res) || + !detail::read_headers(strm, res.headers)) { + return false; + } + + if (res.get_header_value("Connection") == "close" || + res.version == "HTTP/1.0") { + connection_close = true; + } + + if (req.response_handler) { + if (!req.response_handler(res)) { return false; } + } + + // Body + if (req.method != "HEAD" && req.method != "CONNECT") { + auto out = + req.content_receiver + ? static_cast([&](const char *buf, size_t n) { + return req.content_receiver(buf, n); + }) + : static_cast([&](const char *buf, size_t n) { + if (res.body.size() + n > res.body.max_size()) { return false; } + res.body.append(buf, n); + return true; + }); + + int dummy_status; + if (!detail::read_content(strm, res, (std::numeric_limits::max)(), + dummy_status, req.progress, out)) { + return false; + } + } + + // Log + if (logger_) { logger_(req, res); } + + return true; +} + +inline bool Client::process_and_close_socket( + socket_t sock, size_t request_count, + std::function + callback) { + request_count = (std::min)(request_count, keep_alive_max_count_); + return detail::process_and_close_socket(true, sock, request_count, + read_timeout_sec_, read_timeout_usec_, + callback); +} + +inline bool Client::is_ssl() const { return false; } + +inline std::shared_ptr Client::Get(const char *path) { + return Get(path, Headers(), Progress()); +} + +inline std::shared_ptr Client::Get(const char *path, + Progress progress) { + return Get(path, Headers(), std::move(progress)); +} + +inline std::shared_ptr Client::Get(const char *path, + const Headers &headers) { + return Get(path, headers, Progress()); +} + +inline std::shared_ptr +Client::Get(const char *path, const Headers &headers, Progress progress) { + Request req; + req.method = "GET"; + req.path = path; + req.headers = headers; + req.progress = std::move(progress); + + auto res = std::make_shared(); + return send(req, *res) ? res : nullptr; +} + +inline std::shared_ptr Client::Get(const char *path, + ContentReceiver content_receiver) { + return Get(path, Headers(), nullptr, std::move(content_receiver), Progress()); +} + +inline std::shared_ptr Client::Get(const char *path, + ContentReceiver content_receiver, + Progress progress) { + return Get(path, Headers(), nullptr, std::move(content_receiver), + std::move(progress)); +} + +inline std::shared_ptr Client::Get(const char *path, + const Headers &headers, + ContentReceiver content_receiver) { + return Get(path, headers, nullptr, std::move(content_receiver), Progress()); +} + +inline std::shared_ptr Client::Get(const char *path, + const Headers &headers, + ContentReceiver content_receiver, + Progress progress) { + return Get(path, headers, nullptr, std::move(content_receiver), + std::move(progress)); +} + +inline std::shared_ptr Client::Get(const char *path, + const Headers &headers, + ResponseHandler response_handler, + ContentReceiver content_receiver) { + return Get(path, headers, std::move(response_handler), content_receiver, + Progress()); +} + +inline std::shared_ptr Client::Get(const char *path, + const Headers &headers, + ResponseHandler response_handler, + ContentReceiver content_receiver, + Progress progress) { + Request req; + req.method = "GET"; + req.path = path; + req.headers = headers; + req.response_handler = std::move(response_handler); + req.content_receiver = std::move(content_receiver); + req.progress = std::move(progress); + + auto res = std::make_shared(); + return send(req, *res) ? res : nullptr; +} + +inline std::shared_ptr Client::Head(const char *path) { + return Head(path, Headers()); +} + +inline std::shared_ptr Client::Head(const char *path, + const Headers &headers) { + Request req; + req.method = "HEAD"; + req.headers = headers; + req.path = path; + + auto res = std::make_shared(); + + return send(req, *res) ? res : nullptr; +} + +inline std::shared_ptr Client::Post(const char *path) { + return Post(path, std::string(), nullptr); +} + +inline std::shared_ptr Client::Post(const char *path, + const std::string &body, + const char *content_type) { + return Post(path, Headers(), body, content_type); +} + +inline std::shared_ptr Client::Post(const char *path, + const Headers &headers, + const std::string &body, + const char *content_type) { + return send_with_content_provider("POST", path, headers, body, 0, nullptr, + content_type); +} + +inline std::shared_ptr Client::Post(const char *path, + const Params ¶ms) { + return Post(path, Headers(), params); +} + +inline std::shared_ptr Client::Post(const char *path, + size_t content_length, + ContentProvider content_provider, + const char *content_type) { + return Post(path, Headers(), content_length, content_provider, content_type); +} + +inline std::shared_ptr +Client::Post(const char *path, const Headers &headers, size_t content_length, + ContentProvider content_provider, const char *content_type) { + return send_with_content_provider("POST", path, headers, std::string(), + content_length, content_provider, + content_type); +} + +inline std::shared_ptr +Client::Post(const char *path, const Headers &headers, const Params ¶ms) { + auto query = detail::params_to_query_str(params); + return Post(path, headers, query, "application/x-www-form-urlencoded"); +} + +inline std::shared_ptr +Client::Post(const char *path, const MultipartFormDataItems &items) { + return Post(path, Headers(), items); +} + +inline std::shared_ptr +Client::Post(const char *path, const Headers &headers, + const MultipartFormDataItems &items) { + auto boundary = detail::make_multipart_data_boundary(); + + std::string body; + + for (const auto &item : items) { + body += "--" + boundary + "\r\n"; + body += "Content-Disposition: form-data; name=\"" + item.name + "\""; + if (!item.filename.empty()) { + body += "; filename=\"" + item.filename + "\""; + } + body += "\r\n"; + if (!item.content_type.empty()) { + body += "Content-Type: " + item.content_type + "\r\n"; + } + body += "\r\n"; + body += item.content + "\r\n"; + } + + body += "--" + boundary + "--\r\n"; + + std::string content_type = "multipart/form-data; boundary=" + boundary; + return Post(path, headers, body, content_type.c_str()); +} + +inline std::shared_ptr Client::Put(const char *path) { + return Put(path, std::string(), nullptr); +} + +inline std::shared_ptr Client::Put(const char *path, + const std::string &body, + const char *content_type) { + return Put(path, Headers(), body, content_type); +} + +inline std::shared_ptr Client::Put(const char *path, + const Headers &headers, + const std::string &body, + const char *content_type) { + return send_with_content_provider("PUT", path, headers, body, 0, nullptr, + content_type); +} + +inline std::shared_ptr Client::Put(const char *path, + size_t content_length, + ContentProvider content_provider, + const char *content_type) { + return Put(path, Headers(), content_length, content_provider, content_type); +} + +inline std::shared_ptr +Client::Put(const char *path, const Headers &headers, size_t content_length, + ContentProvider content_provider, const char *content_type) { + return send_with_content_provider("PUT", path, headers, std::string(), + content_length, content_provider, + content_type); +} + +inline std::shared_ptr Client::Put(const char *path, + const Params ¶ms) { + return Put(path, Headers(), params); +} + +inline std::shared_ptr +Client::Put(const char *path, const Headers &headers, const Params ¶ms) { + auto query = detail::params_to_query_str(params); + return Put(path, headers, query, "application/x-www-form-urlencoded"); +} + +inline std::shared_ptr Client::Patch(const char *path, + const std::string &body, + const char *content_type) { + return Patch(path, Headers(), body, content_type); +} + +inline std::shared_ptr Client::Patch(const char *path, + const Headers &headers, + const std::string &body, + const char *content_type) { + return send_with_content_provider("PATCH", path, headers, body, 0, nullptr, + content_type); +} + +inline std::shared_ptr Client::Patch(const char *path, + size_t content_length, + ContentProvider content_provider, + const char *content_type) { + return Patch(path, Headers(), content_length, content_provider, content_type); +} + +inline std::shared_ptr +Client::Patch(const char *path, const Headers &headers, size_t content_length, + ContentProvider content_provider, const char *content_type) { + return send_with_content_provider("PATCH", path, headers, std::string(), + content_length, content_provider, + content_type); +} + +inline std::shared_ptr Client::Delete(const char *path) { + return Delete(path, Headers(), std::string(), nullptr); +} + +inline std::shared_ptr Client::Delete(const char *path, + const std::string &body, + const char *content_type) { + return Delete(path, Headers(), body, content_type); +} + +inline std::shared_ptr Client::Delete(const char *path, + const Headers &headers) { + return Delete(path, headers, std::string(), nullptr); +} + +inline std::shared_ptr Client::Delete(const char *path, + const Headers &headers, + const std::string &body, + const char *content_type) { + Request req; + req.method = "DELETE"; + req.headers = headers; + req.path = path; + + if (content_type) { req.headers.emplace("Content-Type", content_type); } + req.body = body; + + auto res = std::make_shared(); + + return send(req, *res) ? res : nullptr; +} + +inline std::shared_ptr Client::Options(const char *path) { + return Options(path, Headers()); +} + +inline std::shared_ptr Client::Options(const char *path, + const Headers &headers) { + Request req; + req.method = "OPTIONS"; + req.path = path; + req.headers = headers; + + auto res = std::make_shared(); + + return send(req, *res) ? res : nullptr; +} + +inline void Client::stop() { + if (sock_ != INVALID_SOCKET) { + std::atomic sock(sock_.exchange(INVALID_SOCKET)); + detail::shutdown_socket(sock); + detail::close_socket(sock); + } +} + +inline void Client::set_timeout_sec(time_t timeout_sec) { + timeout_sec_ = timeout_sec; +} + +inline void Client::set_read_timeout(time_t sec, time_t usec) { + read_timeout_sec_ = sec; + read_timeout_usec_ = usec; +} + +inline void Client::set_keep_alive_max_count(size_t count) { + keep_alive_max_count_ = count; +} + +inline void Client::set_basic_auth(const char *username, const char *password) { + basic_auth_username_ = username; + basic_auth_password_ = password; +} + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +inline void Client::set_digest_auth(const char *username, + const char *password) { + digest_auth_username_ = username; + digest_auth_password_ = password; +} +#endif + +inline void Client::set_follow_location(bool on) { follow_location_ = on; } + +inline void Client::set_compress(bool on) { compress_ = on; } + +inline void Client::set_interface(const char *intf) { interface_ = intf; } + +inline void Client::set_proxy(const char *host, int port) { + proxy_host_ = host; + proxy_port_ = port; +} + +inline void Client::set_proxy_basic_auth(const char *username, + const char *password) { + proxy_basic_auth_username_ = username; + proxy_basic_auth_password_ = password; +} + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +inline void Client::set_proxy_digest_auth(const char *username, + const char *password) { + proxy_digest_auth_username_ = username; + proxy_digest_auth_password_ = password; +} +#endif + +inline void Client::set_logger(Logger logger) { logger_ = std::move(logger); } + +/* + * SSL Implementation + */ +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +namespace detail { + +template +inline bool process_and_close_socket_ssl( + bool is_client_request, socket_t sock, size_t keep_alive_max_count, + time_t read_timeout_sec, time_t read_timeout_usec, SSL_CTX *ctx, + std::mutex &ctx_mutex, U SSL_connect_or_accept, V setup, T callback) { + assert(keep_alive_max_count > 0); + + SSL *ssl = nullptr; + { + std::lock_guard guard(ctx_mutex); + ssl = SSL_new(ctx); + } + + if (!ssl) { + close_socket(sock); + return false; + } + + auto bio = BIO_new_socket(static_cast(sock), BIO_NOCLOSE); + SSL_set_bio(ssl, bio, bio); + + if (!setup(ssl)) { + SSL_shutdown(ssl); + { + std::lock_guard guard(ctx_mutex); + SSL_free(ssl); + } + + close_socket(sock); + return false; + } + + auto ret = false; + + if (SSL_connect_or_accept(ssl) == 1) { + if (keep_alive_max_count > 1) { + auto count = keep_alive_max_count; + while (count > 0 && + (is_client_request || + select_read(sock, CPPHTTPLIB_KEEPALIVE_TIMEOUT_SECOND, + CPPHTTPLIB_KEEPALIVE_TIMEOUT_USECOND) > 0)) { + SSLSocketStream strm(sock, ssl, read_timeout_sec, read_timeout_usec); + auto last_connection = count == 1; + auto connection_close = false; + + ret = callback(ssl, strm, last_connection, connection_close); + if (!ret || connection_close) { break; } + + count--; + } + } else { + SSLSocketStream strm(sock, ssl, read_timeout_sec, read_timeout_usec); + auto dummy_connection_close = false; + ret = callback(ssl, strm, true, dummy_connection_close); + } + } + + if (ret) { + SSL_shutdown(ssl); // shutdown only if not already closed by remote + } + { + std::lock_guard guard(ctx_mutex); + SSL_free(ssl); + } + + close_socket(sock); + + return ret; +} + +#if OPENSSL_VERSION_NUMBER < 0x10100000L +static std::shared_ptr> openSSL_locks_; + +class SSLThreadLocks { +public: + SSLThreadLocks() { + openSSL_locks_ = + std::make_shared>(CRYPTO_num_locks()); + CRYPTO_set_locking_callback(locking_callback); + } + + ~SSLThreadLocks() { CRYPTO_set_locking_callback(nullptr); } + +private: + static void locking_callback(int mode, int type, const char * /*file*/, + int /*line*/) { + auto &lk = (*openSSL_locks_)[static_cast(type)]; + if (mode & CRYPTO_LOCK) { + lk.lock(); + } else { + lk.unlock(); + } + } +}; + +#endif + +class SSLInit { +public: + SSLInit() { +#if OPENSSL_VERSION_NUMBER < 0x1010001fL + SSL_load_error_strings(); + SSL_library_init(); +#else + OPENSSL_init_ssl( + OPENSSL_INIT_LOAD_SSL_STRINGS | OPENSSL_INIT_LOAD_CRYPTO_STRINGS, NULL); +#endif + } + + ~SSLInit() { +#if OPENSSL_VERSION_NUMBER < 0x1010001fL + ERR_free_strings(); +#endif + } + +private: +#if OPENSSL_VERSION_NUMBER < 0x10100000L + SSLThreadLocks thread_init_; +#endif +}; + +// SSL socket stream implementation +inline SSLSocketStream::SSLSocketStream(socket_t sock, SSL *ssl, + time_t read_timeout_sec, + time_t read_timeout_usec) + : sock_(sock), ssl_(ssl), read_timeout_sec_(read_timeout_sec), + read_timeout_usec_(read_timeout_usec) {} + +inline SSLSocketStream::~SSLSocketStream() {} + +inline bool SSLSocketStream::is_readable() const { + return detail::select_read(sock_, read_timeout_sec_, read_timeout_usec_) > 0; +} + +inline bool SSLSocketStream::is_writable() const { + return detail::select_write(sock_, 0, 0) > 0; +} + +inline ssize_t SSLSocketStream::read(char *ptr, size_t size) { + if (SSL_pending(ssl_) > 0 || + select_read(sock_, read_timeout_sec_, read_timeout_usec_) > 0) { + return SSL_read(ssl_, ptr, static_cast(size)); + } + return -1; +} + +inline ssize_t SSLSocketStream::write(const char *ptr, size_t size) { + if (is_writable()) { return SSL_write(ssl_, ptr, static_cast(size)); } + return -1; +} + +inline void SSLSocketStream::get_remote_ip_and_port(std::string &ip, + int &port) const { + detail::get_remote_ip_and_port(sock_, ip, port); +} + +static SSLInit sslinit_; + +} // namespace detail + +// SSL HTTP server implementation +inline SSLServer::SSLServer(const char *cert_path, const char *private_key_path, + const char *client_ca_cert_file_path, + const char *client_ca_cert_dir_path) { + ctx_ = SSL_CTX_new(SSLv23_server_method()); + + if (ctx_) { + SSL_CTX_set_options(ctx_, + SSL_OP_ALL | SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3 | + SSL_OP_NO_COMPRESSION | + SSL_OP_NO_SESSION_RESUMPTION_ON_RENEGOTIATION); + + // auto ecdh = EC_KEY_new_by_curve_name(NID_X9_62_prime256v1); + // SSL_CTX_set_tmp_ecdh(ctx_, ecdh); + // EC_KEY_free(ecdh); + + if (SSL_CTX_use_certificate_chain_file(ctx_, cert_path) != 1 || + SSL_CTX_use_PrivateKey_file(ctx_, private_key_path, SSL_FILETYPE_PEM) != + 1) { + SSL_CTX_free(ctx_); + ctx_ = nullptr; + } else if (client_ca_cert_file_path || client_ca_cert_dir_path) { + // if (client_ca_cert_file_path) { + // auto list = SSL_load_client_CA_file(client_ca_cert_file_path); + // SSL_CTX_set_client_CA_list(ctx_, list); + // } + + SSL_CTX_load_verify_locations(ctx_, client_ca_cert_file_path, + client_ca_cert_dir_path); + + SSL_CTX_set_verify( + ctx_, + SSL_VERIFY_PEER | + SSL_VERIFY_FAIL_IF_NO_PEER_CERT, // SSL_VERIFY_CLIENT_ONCE, + nullptr); + } + } +} + +inline SSLServer::SSLServer(X509 *cert, EVP_PKEY *private_key, + X509_STORE *client_ca_cert_store) { + ctx_ = SSL_CTX_new(SSLv23_server_method()); + + if (ctx_) { + SSL_CTX_set_options(ctx_, + SSL_OP_ALL | SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3 | + SSL_OP_NO_COMPRESSION | + SSL_OP_NO_SESSION_RESUMPTION_ON_RENEGOTIATION); + + if (SSL_CTX_use_certificate(ctx_, cert) != 1 || + SSL_CTX_use_PrivateKey(ctx_, private_key) != 1) { + SSL_CTX_free(ctx_); + ctx_ = nullptr; + } else if (client_ca_cert_store) { + + SSL_CTX_set_cert_store(ctx_, client_ca_cert_store); + + SSL_CTX_set_verify( + ctx_, + SSL_VERIFY_PEER | + SSL_VERIFY_FAIL_IF_NO_PEER_CERT, // SSL_VERIFY_CLIENT_ONCE, + nullptr); + } + } +} + +inline SSLServer::~SSLServer() { + if (ctx_) { SSL_CTX_free(ctx_); } +} + +inline bool SSLServer::is_valid() const { return ctx_; } + +inline bool SSLServer::process_and_close_socket(socket_t sock) { + return detail::process_and_close_socket_ssl( + false, sock, keep_alive_max_count_, read_timeout_sec_, read_timeout_usec_, + ctx_, ctx_mutex_, SSL_accept, [](SSL * /*ssl*/) { return true; }, + [this](SSL *ssl, Stream &strm, bool last_connection, + bool &connection_close) { + return process_request(strm, last_connection, connection_close, + [&](Request &req) { req.ssl = ssl; }); + }); +} + +// SSL HTTP client implementation +inline SSLClient::SSLClient(const std::string &host, int port, + const std::string &client_cert_path, + const std::string &client_key_path) + : Client(host, port, client_cert_path, client_key_path) { + ctx_ = SSL_CTX_new(SSLv23_client_method()); + + detail::split(&host_[0], &host_[host_.size()], '.', + [&](const char *b, const char *e) { + host_components_.emplace_back(std::string(b, e)); + }); + if (!client_cert_path.empty() && !client_key_path.empty()) { + if (SSL_CTX_use_certificate_file(ctx_, client_cert_path.c_str(), + SSL_FILETYPE_PEM) != 1 || + SSL_CTX_use_PrivateKey_file(ctx_, client_key_path.c_str(), + SSL_FILETYPE_PEM) != 1) { + SSL_CTX_free(ctx_); + ctx_ = nullptr; + } + } +} + +inline SSLClient::SSLClient(const std::string &host, int port, + X509 *client_cert, EVP_PKEY *client_key) + : Client(host, port) { + ctx_ = SSL_CTX_new(SSLv23_client_method()); + + detail::split(&host_[0], &host_[host_.size()], '.', + [&](const char *b, const char *e) { + host_components_.emplace_back(std::string(b, e)); + }); + if (client_cert != nullptr && client_key != nullptr) { + if (SSL_CTX_use_certificate(ctx_, client_cert) != 1 || + SSL_CTX_use_PrivateKey(ctx_, client_key) != 1) { + SSL_CTX_free(ctx_); + ctx_ = nullptr; + } + } +} + +inline SSLClient::~SSLClient() { + if (ctx_) { SSL_CTX_free(ctx_); } +} + +inline bool SSLClient::is_valid() const { return ctx_; } + +inline void SSLClient::set_ca_cert_path(const char *ca_cert_file_path, + const char *ca_cert_dir_path) { + if (ca_cert_file_path) { ca_cert_file_path_ = ca_cert_file_path; } + if (ca_cert_dir_path) { ca_cert_dir_path_ = ca_cert_dir_path; } +} + +inline void SSLClient::set_ca_cert_store(X509_STORE *ca_cert_store) { + if (ca_cert_store) { ca_cert_store_ = ca_cert_store; } +} + +inline void SSLClient::enable_server_certificate_verification(bool enabled) { + server_certificate_verification_ = enabled; +} + +inline long SSLClient::get_openssl_verify_result() const { + return verify_result_; +} + +inline SSL_CTX *SSLClient::ssl_context() const { return ctx_; } + +inline bool SSLClient::process_and_close_socket( + socket_t sock, size_t request_count, + std::function + callback) { + + request_count = std::min(request_count, keep_alive_max_count_); + + return is_valid() && + detail::process_and_close_socket_ssl( + true, sock, request_count, read_timeout_sec_, read_timeout_usec_, + ctx_, ctx_mutex_, + [&](SSL *ssl) { + if (ca_cert_file_path_.empty() && ca_cert_store_ == nullptr) { + SSL_CTX_set_verify(ctx_, SSL_VERIFY_NONE, nullptr); + } else if (!ca_cert_file_path_.empty()) { + if (!SSL_CTX_load_verify_locations( + ctx_, ca_cert_file_path_.c_str(), nullptr)) { + return false; + } + SSL_CTX_set_verify(ctx_, SSL_VERIFY_PEER, nullptr); + } else if (ca_cert_store_ != nullptr) { + if (SSL_CTX_get_cert_store(ctx_) != ca_cert_store_) { + SSL_CTX_set_cert_store(ctx_, ca_cert_store_); + } + SSL_CTX_set_verify(ctx_, SSL_VERIFY_PEER, nullptr); + } + + if (SSL_connect(ssl) != 1) { return false; } + + if (server_certificate_verification_) { + verify_result_ = SSL_get_verify_result(ssl); + + if (verify_result_ != X509_V_OK) { return false; } + + auto server_cert = SSL_get_peer_certificate(ssl); + + if (server_cert == nullptr) { return false; } + + if (!verify_host(server_cert)) { + X509_free(server_cert); + return false; + } + X509_free(server_cert); + } + + return true; + }, + [&](SSL *ssl) { + SSL_set_tlsext_host_name(ssl, host_.c_str()); + return true; + }, + [&](SSL * /*ssl*/, Stream &strm, bool last_connection, + bool &connection_close) { + return callback(strm, last_connection, connection_close); + }); +} + +inline bool SSLClient::is_ssl() const { return true; } + +inline bool SSLClient::verify_host(X509 *server_cert) const { + /* Quote from RFC2818 section 3.1 "Server Identity" + + If a subjectAltName extension of type dNSName is present, that MUST + be used as the identity. Otherwise, the (most specific) Common Name + field in the Subject field of the certificate MUST be used. Although + the use of the Common Name is existing practice, it is deprecated and + Certification Authorities are encouraged to use the dNSName instead. + + Matching is performed using the matching rules specified by + [RFC2459]. If more than one identity of a given type is present in + the certificate (e.g., more than one dNSName name, a match in any one + of the set is considered acceptable.) Names may contain the wildcard + character * which is considered to match any single domain name + component or component fragment. E.g., *.a.com matches foo.a.com but + not bar.foo.a.com. f*.com matches foo.com but not bar.com. + + In some cases, the URI is specified as an IP address rather than a + hostname. In this case, the iPAddress subjectAltName must be present + in the certificate and must exactly match the IP in the URI. + + */ + return verify_host_with_subject_alt_name(server_cert) || + verify_host_with_common_name(server_cert); +} + +inline bool +SSLClient::verify_host_with_subject_alt_name(X509 *server_cert) const { + auto ret = false; + + auto type = GEN_DNS; + + struct in6_addr addr6; + struct in_addr addr; + size_t addr_len = 0; + +#ifndef __MINGW32__ + if (inet_pton(AF_INET6, host_.c_str(), &addr6)) { + type = GEN_IPADD; + addr_len = sizeof(struct in6_addr); + } else if (inet_pton(AF_INET, host_.c_str(), &addr)) { + type = GEN_IPADD; + addr_len = sizeof(struct in_addr); + } +#endif + + auto alt_names = static_cast( + X509_get_ext_d2i(server_cert, NID_subject_alt_name, nullptr, nullptr)); + + if (alt_names) { + auto dsn_matched = false; + auto ip_mached = false; + + auto count = sk_GENERAL_NAME_num(alt_names); + + for (auto i = 0; i < count && !dsn_matched; i++) { + auto val = sk_GENERAL_NAME_value(alt_names, i); + if (val->type == type) { + auto name = (const char *)ASN1_STRING_get0_data(val->d.ia5); + auto name_len = (size_t)ASN1_STRING_length(val->d.ia5); + + if (strlen(name) == name_len) { + switch (type) { + case GEN_DNS: dsn_matched = check_host_name(name, name_len); break; + + case GEN_IPADD: + if (!memcmp(&addr6, name, addr_len) || + !memcmp(&addr, name, addr_len)) { + ip_mached = true; + } + break; + } + } + } + } + + if (dsn_matched || ip_mached) { ret = true; } + } + + GENERAL_NAMES_free((STACK_OF(GENERAL_NAME) *)alt_names); + + return ret; +} + +inline bool SSLClient::verify_host_with_common_name(X509 *server_cert) const { + const auto subject_name = X509_get_subject_name(server_cert); + + if (subject_name != nullptr) { + char name[BUFSIZ]; + auto name_len = X509_NAME_get_text_by_NID(subject_name, NID_commonName, + name, sizeof(name)); + + if (name_len != -1) { + return check_host_name(name, static_cast(name_len)); + } + } + + return false; +} + +inline bool SSLClient::check_host_name(const char *pattern, + size_t pattern_len) const { + if (host_.size() == pattern_len && host_ == pattern) { return true; } + + // Wildcard match + // https://bugs.launchpad.net/ubuntu/+source/firefox-3.0/+bug/376484 + std::vector pattern_components; + detail::split(&pattern[0], &pattern[pattern_len], '.', + [&](const char *b, const char *e) { + pattern_components.emplace_back(std::string(b, e)); + }); + + if (host_components_.size() != pattern_components.size()) { return false; } + + auto itr = pattern_components.begin(); + for (const auto &h : host_components_) { + auto &p = *itr; + if (p != h && p != "*") { + auto partial_match = (p.size() > 0 && p[p.size() - 1] == '*' && + !p.compare(0, p.size() - 1, h)); + if (!partial_match) { return false; } + } + ++itr; + } + + return true; +} +#endif + +namespace url { + +struct Options { + // TODO: support more options... + bool follow_location = false; + std::string client_cert_path; + std::string client_key_path; + + std::string ca_cert_file_path; + std::string ca_cert_dir_path; + bool server_certificate_verification = false; +}; + +inline std::shared_ptr Get(const char *url, Options &options) { + const static std::regex re( + R"(^(https?)://([^:/?#]+)(?::(\d+))?([^?#]*(?:\?[^#]*)?)(?:#.*)?)"); + + std::cmatch m; + if (!std::regex_match(url, m, re)) { return nullptr; } + + auto next_scheme = m[1].str(); + auto next_host = m[2].str(); + auto port_str = m[3].str(); + auto next_path = m[4].str(); + + auto next_port = !port_str.empty() ? std::stoi(port_str) + : (next_scheme == "https" ? 443 : 80); + + if (next_path.empty()) { next_path = "/"; } + + if (next_scheme == "https") { +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + SSLClient cli(next_host.c_str(), next_port, options.client_cert_path, + options.client_key_path); + cli.set_follow_location(options.follow_location); + cli.set_ca_cert_path(options.ca_cert_file_path.c_str(), + options.ca_cert_dir_path.c_str()); + cli.enable_server_certificate_verification( + options.server_certificate_verification); + return cli.Get(next_path.c_str()); +#else + return nullptr; +#endif + } else { + Client cli(next_host.c_str(), next_port, options.client_cert_path, + options.client_key_path); + cli.set_follow_location(options.follow_location); + return cli.Get(next_path.c_str()); + } +} + +inline std::shared_ptr Get(const char *url) { + Options options; + return Get(url, options); +} + +} // namespace url + +namespace detail { + +#undef HANDLE_EINTR + +} // namespace detail + +// ---------------------------------------------------------------------------- + +} // namespace httplib + +#endif // CPPHTTPLIB_HTTPLIB_H + diff --git a/include/parallel_hashmap/phmap.h b/include/parallel_hashmap/phmap.h index ffe68b22c..7c1cd95a0 100644 --- a/include/parallel_hashmap/phmap.h +++ b/include/parallel_hashmap/phmap.h @@ -34,6 +34,23 @@ // limitations under the License. // --------------------------------------------------------------------------- +#ifdef _MSC_VER + #pragma warning(push) + + #pragma warning(disable : 4127) // conditional expression is constant + #pragma warning(disable : 4324) // structure was padded due to alignment specifier + #pragma warning(disable : 4514) // unreferenced inline function has been removed + #pragma warning(disable : 4623) // default constructor was implicitly defined as deleted + #pragma warning(disable : 4625) // copy constructor was implicitly defined as deleted + #pragma warning(disable : 4626) // assignment operator was implicitly defined as deleted + #pragma warning(disable : 4710) // function not inlined + #pragma warning(disable : 4711) // selected for automatic inline expansion + #pragma warning(disable : 4820) // '6' bytes padding added after data member + #pragma warning(disable : 4868) // compiler may not enforce left-to-right evaluation order in braced initializer list + #pragma warning(disable : 5027) // move assignment operator was implicitly defined as deleted + #pragma warning(disable : 5045) // Compiler will insert Spectre mitigation for memory load if /Qspectre switch specified +#endif + #include #include #include @@ -45,10 +62,11 @@ #include #include #include +#include +#include "phmap_fwd_decl.h" #include "phmap_utils.h" #include "phmap_base.h" -#include "phmap_fwd_decl.h" #if PHMAP_HAVE_STD_STRING_VIEW #include @@ -77,7 +95,7 @@ class probe_seq offset_ &= mask_; } // 0-based probe index. The i-th probe in the probe sequence. - size_t index() const { return index_; } + size_t getindex() const { return index_; } private: size_t mask_; @@ -267,7 +285,7 @@ inline size_t H1(size_t hash, const ctrl_t* ) { #endif -inline ctrl_t H2(size_t hash) { return hash & 0x7F; } +inline ctrl_t H2(size_t hash) { return (ctrl_t)(hash & 0x7F); } inline bool IsEmpty(ctrl_t c) { return c == kEmpty; } inline bool IsFull(ctrl_t c) { return c >= 0; } @@ -276,6 +294,11 @@ inline bool IsEmptyOrDeleted(ctrl_t c) { return c < kSentinel; } #if PHMAP_HAVE_SSE2 +#ifdef _MSC_VER + #pragma warning(push) + #pragma warning(disable : 4365) // conversion from 'int' to 'T', signed/unsigned mismatch +#endif + // -------------------------------------------------------------------------- // https://github.com/abseil/abseil-cpp/issues/209 // https://gcc.gnu.org/bugzilla/show_bug.cgi?id=87853 @@ -289,7 +312,7 @@ inline __m128i _mm_cmpgt_epi8_fixed(__m128i a, __m128i b) { #pragma GCC diagnostic ignored "-Woverflow" if (std::is_unsigned::value) { - const __m128i mask = _mm_set1_epi8(0x80); + const __m128i mask = _mm_set1_epi8(static_cast(0x80)); const __m128i diff = _mm_subs_epi8(b, a); return _mm_cmpeq_epi8(_mm_and_si128(diff, mask), mask); } @@ -312,7 +335,7 @@ struct GroupSse2Impl // Returns a bitmask representing the positions of slots that match hash. // ---------------------------------------------------------------------- BitMask Match(h2_t hash) const { - auto match = _mm_set1_epi8(hash); + auto match = _mm_set1_epi8((char)hash); return BitMask( _mm_movemask_epi8(_mm_cmpeq_epi8(match, ctrl))); } @@ -361,6 +384,11 @@ struct GroupSse2Impl __m128i ctrl; }; + +#ifdef _MSC_VER + #pragma warning(pop) +#endif + #endif // PHMAP_HAVE_SSE2 // -------------------------------------------------------------------------- @@ -404,7 +432,7 @@ struct GroupPortableImpl uint32_t CountLeadingEmptyOrDeleted() const { constexpr uint64_t gaps = 0x00FEFEFEFEFEFEFEULL; - return (TrailingZeros(((~ctrl & (ctrl >> 7)) | gaps) + 1) + 7) >> 3; + return (uint32_t)((TrailingZeros(((~ctrl & (ctrl >> 7)) | gaps) + 1) + 7) >> 3); } void ConvertSpecialToEmptyAndFullToDeleted(ctrl_t* dst) const { @@ -468,9 +496,12 @@ inline size_t CapacityToGrowth(size_t capacity) { assert(IsValidCapacity(capacity)); // `capacity*7/8` - if (Group::kWidth == 8 && capacity == 7) { - // x-x/8 does not work when x==7. - return 6; + PHMAP_IF_CONSTEXPR (Group::kWidth == 8) { + if (capacity == 7) + { + // x-x/8 does not work when x==7. + return 6; + } } return capacity - capacity / 8; } @@ -482,9 +513,12 @@ inline size_t CapacityToGrowth(size_t capacity) inline size_t GrowthToLowerboundCapacity(size_t growth) { // `growth*8/7` - if (Group::kWidth == 8 && growth == 7) { - // x+(x-1)/7 does not work when x==7. - return 8; + PHMAP_IF_CONSTEXPR (Group::kWidth == 8) { + if (growth == 7) + { + // x+(x-1)/7 does not work when x==7. + return 8; + } } return growth + static_cast((static_cast(growth) - 1) / 7); } @@ -642,78 +676,6 @@ DecomposePairImpl(F&& f, std::pair, V> p) { } // namespace memory_internal -// Helper functions for asan and msan. -// ---------------------------------------------------------------------------- -inline void SanitizerPoisonMemoryRegion(const void* m, size_t s) { -#ifdef ADDRESS_SANITIZER - ASAN_POISON_MEMORY_REGION(m, s); -#endif -#ifdef MEMORY_SANITIZER - __msan_poison(m, s); -#endif - (void)m; - (void)s; -} - -inline void SanitizerUnpoisonMemoryRegion(const void* m, size_t s) { -#ifdef ADDRESS_SANITIZER - ASAN_UNPOISON_MEMORY_REGION(m, s); -#endif -#ifdef MEMORY_SANITIZER - __msan_unpoison(m, s); -#endif - (void)m; - (void)s; -} - -template -inline void SanitizerPoisonObject(const T* object) { - SanitizerPoisonMemoryRegion(object, sizeof(T)); -} - -template -inline void SanitizerUnpoisonObject(const T* object) { - SanitizerUnpoisonMemoryRegion(object, sizeof(T)); -} - -// ---------------------------------------------------------------------------- -// Allocates at least n bytes aligned to the specified alignment. -// Alignment must be a power of 2. It must be positive. -// -// Note that many allocators don't honor alignment requirements above certain -// threshold (usually either alignof(std::max_align_t) or alignof(void*)). -// Allocate() doesn't apply alignment corrections. If the underlying allocator -// returns insufficiently alignment pointer, that's what you are going to get. -// ---------------------------------------------------------------------------- -template -void* Allocate(Alloc* alloc, size_t n) { - static_assert(Alignment > 0, ""); - assert(n && "n must be positive"); - struct alignas(Alignment) M {}; - using A = typename phmap::allocator_traits::template rebind_alloc; - using AT = typename phmap::allocator_traits::template rebind_traits; - A mem_alloc(*alloc); - void* p = AT::allocate(mem_alloc, (n + sizeof(M) - 1) / sizeof(M)); - assert(reinterpret_cast(p) % Alignment == 0 && - "allocator does not respect alignment"); - return p; -} - -// ---------------------------------------------------------------------------- -// The pointer must have been previously obtained by calling -// Allocate(alloc, n). -// ---------------------------------------------------------------------------- -template -void Deallocate(Alloc* alloc, void* p, size_t n) { - static_assert(Alignment > 0, ""); - assert(n && "n must be positive"); - struct alignas(Alignment) M {}; - using A = typename phmap::allocator_traits::template rebind_alloc; - using AT = typename phmap::allocator_traits::template rebind_traits; - A mem_alloc(*alloc); - AT::deallocate(mem_alloc, static_cast(p), - (n + sizeof(M) - 1) / sizeof(M)); -} // ---------------------------------------------------------------------------- // R A W _ H A S H _ S E T @@ -1202,6 +1164,8 @@ class raw_hash_set // compared to destruction of the elements of the container. So we pick the // largest bucket_count() threshold for which iteration is still fast and // past that we simply deallocate the array. + if (empty()) + return; if (capacity_ > 127) { destroy_slots(); } else if (capacity_) { @@ -1532,6 +1496,14 @@ class raw_hash_set } } +#ifndef PHMAP_NON_DETERMINISTIC + template + bool dump(OutputArchive&) const; + + template + bool load(InputArchive&); +#endif + void rehash(size_t n) { if (n == 0 && capacity_ == 0) return; if (n == 0 && size_ == 0) { @@ -1541,7 +1513,7 @@ class raw_hash_set } // bitor is a faster way of doing `max` here. We will round up to the next // power-of-2-minus-1, so bitor is good enough. - auto m = NormalizeCapacity(n | GrowthToLowerboundCapacity(size())); + auto m = NormalizeCapacity((std::max)(n, size())); // n == 0 unconditionally rehashes as per the standard. if (n == 0 || m > capacity_) { resize(m); @@ -1561,7 +1533,7 @@ class raw_hash_set // s.count("abc"); template size_t count(const key_arg& key) const { - return find(key) == end() ? 0 : 1; + return find(key) == end() ? size_t(0) : size_t(1); } // Issues CPU prefetch instructions for the memory needed to find or insert @@ -1595,11 +1567,11 @@ class raw_hash_set auto seq = probe(hash); while (true) { Group g{ctrl_ + seq.offset()}; - for (int i : g.Match(H2(hash))) { + for (int i : g.Match((h2_t)H2(hash))) { if (PHMAP_PREDICT_TRUE(PolicyTraits::apply( EqualElement{key, eq_ref()}, - PolicyTraits::element(slots_ + seq.offset(i))))) - return iterator_at(seq.offset(i)); + PolicyTraits::element(slots_ + seq.offset((size_t)i))))) + return iterator_at(seq.offset((size_t)i)); } if (PHMAP_PREDICT_TRUE(g.MatchEmpty())) return end(); @@ -1770,7 +1742,7 @@ class raw_hash_set void erase_meta_only(const_iterator it) { assert(IsFull(*it.inner_.ctrl_) && "erasing a dangling iterator"); --size_; - const size_t index = it.inner_.ctrl_ - ctrl_; + const size_t index = (size_t)(it.inner_.ctrl_ - ctrl_); const size_t index_before = (index - Group::kWidth) & capacity_; const auto empty_after = Group(it.inner_.ctrl_).MatchEmpty(); const auto empty_before = Group(ctrl_ + index_before).MatchEmpty(); @@ -1929,14 +1901,14 @@ class raw_hash_set auto seq = probe(hash); while (true) { Group g{ctrl_ + seq.offset()}; - for (int i : g.Match(H2(hash))) { - if (PHMAP_PREDICT_TRUE(PolicyTraits::element(slots_ + seq.offset(i)) == + for (int i : g.Match((h2_t)H2(hash))) { + if (PHMAP_PREDICT_TRUE(PolicyTraits::element(slots_ + seq.offset((size_t)i)) == elem)) return true; } if (PHMAP_PREDICT_TRUE(g.MatchEmpty())) return false; seq.next(); - assert(seq.index() < capacity_ && "full table!"); + assert(seq.getindex() < capacity_ && "full table!"); } return false; } @@ -1966,9 +1938,9 @@ class raw_hash_set Group g{ctrl_ + seq.offset()}; auto mask = g.MatchEmptyOrDeleted(); if (mask) { - return {seq.offset(mask.LowestBitSet()), seq.index()}; + return {seq.offset((size_t)mask.LowestBitSet()), seq.getindex()}; } - assert(seq.index() < capacity_ && "full table!"); + assert(seq.getindex() < capacity_ && "full table!"); seq.next(); } } @@ -1991,11 +1963,11 @@ class raw_hash_set auto seq = probe(hash); while (true) { Group g{ctrl_ + seq.offset()}; - for (int i : g.Match(H2(hash))) { + for (int i : g.Match((h2_t)H2(hash))) { if (PHMAP_PREDICT_TRUE(PolicyTraits::apply( EqualElement{key, eq_ref()}, - PolicyTraits::element(slots_ + seq.offset(i))))) - return {seq.offset(i), false}; + PolicyTraits::element(slots_ + seq.offset((size_t)i))))) + return {seq.offset((size_t)i), false}; } if (PHMAP_PREDICT_TRUE(g.MatchEmpty())) break; seq.next(); @@ -2244,14 +2216,16 @@ class raw_hash_map : public raw_hash_set template MappedReference

at(const key_arg& key) { auto it = this->find(key); - if (it == this->end()) std::abort(); + if (it == this->end()) + phmap::base_internal::ThrowStdOutOfRange("phmap at(): lookup non-existent key"); return Policy::value(&*it); } template MappedConstReference

at(const key_arg& key) const { auto it = this->find(key); - if (it == this->end()) std::abort(); + if (it == this->end()) + phmap::base_internal::ThrowStdOutOfRange("phmap at(): lookup non-existent key"); return Policy::value(&*it); } @@ -2351,7 +2325,7 @@ class parallel_hash_set using Lockable = phmap::LockableImpl; // -------------------------------------------------------------------- - struct alignas(64) Inner : public Lockable + struct Inner : public Lockable { bool operator==(const Inner& o) const { @@ -2892,7 +2866,7 @@ class parallel_hash_set Inner& inner = sets_[subidx(hashval)]; auto& set = inner.set_; typename Lockable::UniqueLock m(inner); - return make_iterator(&inner, set.lazy_emplace(key, hashval, std::forward(f))); + return make_iterator(&inner, set.lazy_emplace_with_hash(key, hashval, std::forward(f))); } // Extension API: support for heterogeneous keys. @@ -3008,7 +2982,12 @@ class parallel_hash_set } } - void reserve(size_t n) { rehash(GrowthToLowerboundCapacity(n)); } + void reserve(size_t n) + { + size_t target = GrowthToLowerboundCapacity(n); + size_t normalized = 16 * NormalizeCapacity(n / num_tables); + rehash(normalized > target ? normalized : target); + } // Extension API: support for heterogeneous keys. // @@ -3053,11 +3032,8 @@ class parallel_hash_set // -------------------------------------------------------------------- template iterator find(const key_arg& key, size_t hashval) { - Inner& inner = sets_[subidx(hashval)]; - auto& set = inner.set_; - typename Lockable::SharedLock m(inner); - auto it = set.find(key, hashval); - return make_iterator(&inner, it); + typename Lockable::SharedLock m; + return find(key, hashval, m); } template @@ -3132,6 +3108,14 @@ class parallel_hash_set a.swap(b); } +#ifndef PHMAP_NON_DETERMINISTIC + template + bool dump(OutputArchive& ar) const; + + template + bool load(InputArchive& ar); +#endif + private: template friend struct phmap::container_internal::hashtable_debug_internal::HashtableDebugAccess; @@ -3207,6 +3191,15 @@ class parallel_hash_set } protected: + template + iterator find(const key_arg& key, size_t hashval, typename Lockable::SharedLock &mutexlock) { + Inner& inner = sets_[subidx(hashval)]; + auto& set = inner.set_; + mutexlock = std::move(typename Lockable::SharedLock(inner)); + auto it = set.find(key, hashval); + return make_iterator(&inner, it); + } + template std::tuple find_or_prepare_insert(const K& key, typename Lockable::UniqueLock &mutexlock) { @@ -3232,7 +3225,7 @@ class parallel_hash_set } template - size_t hash(const K& key) { + size_t hash(const K& key) const { return HashElement{hash_ref()}(key); } @@ -3387,17 +3380,32 @@ class parallel_hash_map : public parallel_hash_set MappedReference

at(const key_arg& key) { auto it = this->find(key); - if (it == this->end()) std::abort(); + if (it == this->end()) + phmap::base_internal::ThrowStdOutOfRange("phmap at(): lookup non-existent key"); return Policy::value(&*it); } template MappedConstReference

at(const key_arg& key) const { auto it = this->find(key); - if (it == this->end()) std::abort(); + if (it == this->end()) + phmap::base_internal::ThrowStdOutOfRange("phmap at(): lookup non-existent key"); return Policy::value(&*it); } + template + bool if_contains(const key_arg& key, F&& f) const { +#if __cplusplus >= 201703L + static_assert(std::is_invocable::value); +#endif + typename Lockable::SharedLock m; + auto it = const_cast(this)->find(key, this->hash(key), m); + if (it == this->end()) + return false; + std::forward(f)(Policy::value(&*it)); + return true; + } + template MappedReference

operator[](key_arg&& key) { return Policy::value(&*try_emplace(std::forward(key)).first); @@ -3521,214 +3529,6 @@ DecomposeValue(F&& f, Arg&& arg) { } -namespace memory_internal { - -// ---------------------------------------------------------------------------- -// If Pair is a standard-layout type, OffsetOf::kFirst and -// OffsetOf::kSecond are equivalent to offsetof(Pair, first) and -// offsetof(Pair, second) respectively. Otherwise they are -1. -// -// The purpose of OffsetOf is to avoid calling offsetof() on non-standard-layout -// type, which is non-portable. -// ---------------------------------------------------------------------------- -template -struct OffsetOf { - static constexpr size_t kFirst = -1; - static constexpr size_t kSecond = -1; -}; - -template -struct OffsetOf::type> -{ - static constexpr size_t kFirst = offsetof(Pair, first); - static constexpr size_t kSecond = offsetof(Pair, second); -}; - -// ---------------------------------------------------------------------------- -template -struct IsLayoutCompatible -{ -private: - struct Pair { - K first; - V second; - }; - - // Is P layout-compatible with Pair? - template - static constexpr bool LayoutCompatible() { - return std::is_standard_layout

() && sizeof(P) == sizeof(Pair) && - alignof(P) == alignof(Pair) && - memory_internal::OffsetOf

::kFirst == - memory_internal::OffsetOf::kFirst && - memory_internal::OffsetOf

::kSecond == - memory_internal::OffsetOf::kSecond; - } - -public: - // Whether pair and pair are layout-compatible. If they are, - // then it is safe to store them in a union and read from either. - static constexpr bool value = std::is_standard_layout() && - std::is_standard_layout() && - memory_internal::OffsetOf::kFirst == 0 && - LayoutCompatible>() && - LayoutCompatible>(); -}; - -} // namespace memory_internal - -// ---------------------------------------------------------------------------- -// The internal storage type for key-value containers like flat_hash_map. -// -// It is convenient for the value_type of a flat_hash_map to be -// pair; the "const K" prevents accidental modification of the key -// when dealing with the reference returned from find() and similar methods. -// However, this creates other problems; we want to be able to emplace(K, V) -// efficiently with move operations, and similarly be able to move a -// pair in insert(). -// -// The solution is this union, which aliases the const and non-const versions -// of the pair. This also allows flat_hash_map to work, even though -// that has the same efficiency issues with move in emplace() and insert() - -// but people do it anyway. -// -// If kMutableKeys is false, only the value member can be accessed. -// -// If kMutableKeys is true, key can be accessed through all slots while value -// and mutable_value must be accessed only via INITIALIZED slots. Slots are -// created and destroyed via mutable_value so that the key can be moved later. -// -// Accessing one of the union fields while the other is active is safe as -// long as they are layout-compatible, which is guaranteed by the definition of -// kMutableKeys. For C++11, the relevant section of the standard is -// https://timsong-cpp.github.io/cppwp/n3337/class.mem#19 (9.2.19) -// ---------------------------------------------------------------------------- -template -union map_slot_type -{ - map_slot_type() {} - ~map_slot_type() = delete; - using value_type = std::pair; - using mutable_value_type = std::pair; - - value_type value; - mutable_value_type mutable_value; - K key; -}; - -// ---------------------------------------------------------------------------- -// ---------------------------------------------------------------------------- -template -struct map_slot_policy -{ - using slot_type = map_slot_type; - using value_type = std::pair; - using mutable_value_type = std::pair; - -private: - static void emplace(slot_type* slot) { - // The construction of union doesn't do anything at runtime but it allows us - // to access its members without violating aliasing rules. - new (slot) slot_type; - } - // If pair and pair are layout-compatible, we can accept one - // or the other via slot_type. We are also free to access the key via - // slot_type::key in this case. - using kMutableKeys = memory_internal::IsLayoutCompatible; - -public: - static value_type& element(slot_type* slot) { return slot->value; } - static const value_type& element(const slot_type* slot) { - return slot->value; - } - - static const K& key(const slot_type* slot) { - return kMutableKeys::value ? slot->key : slot->value.first; - } - - template - static void construct(Allocator* alloc, slot_type* slot, Args&&... args) { - emplace(slot); - if (kMutableKeys::value) { - phmap::allocator_traits::construct(*alloc, &slot->mutable_value, - std::forward(args)...); - } else { - phmap::allocator_traits::construct(*alloc, &slot->value, - std::forward(args)...); - } - } - - // Construct this slot by moving from another slot. - template - static void construct(Allocator* alloc, slot_type* slot, slot_type* other) { - emplace(slot); - if (kMutableKeys::value) { - phmap::allocator_traits::construct( - *alloc, &slot->mutable_value, std::move(other->mutable_value)); - } else { - phmap::allocator_traits::construct(*alloc, &slot->value, - std::move(other->value)); - } - } - - template - static void destroy(Allocator* alloc, slot_type* slot) { - if (kMutableKeys::value) { - phmap::allocator_traits::destroy(*alloc, &slot->mutable_value); - } else { - phmap::allocator_traits::destroy(*alloc, &slot->value); - } - } - - template - static void transfer(Allocator* alloc, slot_type* new_slot, - slot_type* old_slot) { - emplace(new_slot); - if (kMutableKeys::value) { - phmap::allocator_traits::construct( - *alloc, &new_slot->mutable_value, std::move(old_slot->mutable_value)); - } else { - phmap::allocator_traits::construct(*alloc, &new_slot->value, - std::move(old_slot->value)); - } - destroy(alloc, old_slot); - } - - template - static void swap(Allocator* alloc, slot_type* a, slot_type* b) { - if (kMutableKeys::value) { - using std::swap; - swap(a->mutable_value, b->mutable_value); - } else { - value_type tmp = std::move(a->value); - phmap::allocator_traits::destroy(*alloc, &a->value); - phmap::allocator_traits::construct(*alloc, &a->value, - std::move(b->value)); - phmap::allocator_traits::destroy(*alloc, &b->value); - phmap::allocator_traits::construct(*alloc, &b->value, - std::move(tmp)); - } - } - - template - static void move(Allocator* alloc, slot_type* src, slot_type* dest) { - if (kMutableKeys::value) { - dest->mutable_value = std::move(src->mutable_value); - } else { - phmap::allocator_traits::destroy(*alloc, &dest->value); - phmap::allocator_traits::construct(*alloc, &dest->value, - std::move(src->value)); - } - } - - template - static void move(Allocator* alloc, slot_type* first, slot_type* last, - slot_type* result) { - for (slot_type *src = first, *dest = result; src != last; ++src, ++dest) - move(alloc, src, dest); - } -}; - // -------------------------------------------------------------------------- // Policy: a policy defines how to perform different operations on // the slots of the hashtable (see hash_policy_traits.h for the full interface @@ -3977,7 +3777,7 @@ struct StringHashT size_t operator()(std::basic_string_view v) const { std::string_view bv{reinterpret_cast(v.data()), v.size() * sizeof(CharT)}; - return phmap::Hash{}(bv); + return std::hash()(bv); } }; @@ -4019,6 +3819,7 @@ struct HashEq : StringHashEqT {}; #endif // Supports heterogeneous lookup for pointers and smart pointers. +// ------------------------------------------------------------- template struct HashEq { @@ -4040,6 +3841,7 @@ struct HashEq private: static const T* ToPtr(const T* ptr) { return ptr; } + template static const T* ToPtr(const std::unique_ptr& ptr) { return ptr.get(); @@ -4078,7 +3880,7 @@ struct HashtableDebugAccess> if (Traits::apply( typename Set::template EqualElement{ key, set.eq_ref()}, - Traits::element(set.slots_ + seq.offset(i)))) + Traits::element(set.slots_ + seq.offset((size_t)i)))) return num_probes; ++num_probes; } @@ -4131,7 +3933,6 @@ struct HashtableDebugAccess> // Its interface is similar to that of `std::unordered_set` with the // following notable differences: // -// * Requires keys that are CopyConstructible // * Supports heterogeneous lookup, through `find()`, `operator[]()` and // `insert()`, provided that the set is provided a compatible heterogeneous // hashing function and equality operator. @@ -4139,7 +3940,7 @@ struct HashtableDebugAccess> // `rehash()`. // * Contains a `capacity()` member function indicating the number of element // slots (open, deleted, and empty) within the hash set. -// * Returns `void` from the `erase(iterator)` overload. +// * Returns `void` from the `_erase(iterator)` overload. // ----------------------------------------------------------------------------- template // default values in phmap_fwd_decl.h class flat_hash_set @@ -4194,8 +3995,6 @@ class flat_hash_set // cases. Its interface is similar to that of `std::unordered_map` with // the following notable differences: // -// * Requires keys that are CopyConstructible -// * Requires values that are MoveConstructible // * Supports heterogeneous lookup, through `find()`, `operator[]()` and // `insert()`, provided that the map is provided a compatible heterogeneous // hashing function and equality operator. @@ -4203,7 +4002,7 @@ class flat_hash_set // `rehash()`. // * Contains a `capacity()` member function indicating the number of element // slots (open, deleted, and empty) within the hash map. -// * Returns `void` from the `erase(iterator)` overload. +// * Returns `void` from the `_erase(iterator)` overload. // ----------------------------------------------------------------------------- template // default values in phmap_fwd_decl.h class flat_hash_map : public phmap::container_internal::raw_hash_map< @@ -4600,4 +4399,9 @@ class parallel_node_hash_map } // namespace phmap +#ifdef _MSC_VER + #pragma warning(pop) +#endif + + #endif // phmap_h_guard_ diff --git a/include/parallel_hashmap/phmap_base.h b/include/parallel_hashmap/phmap_base.h index bbc6712d9..27976826c 100644 --- a/include/parallel_hashmap/phmap_base.h +++ b/include/parallel_hashmap/phmap_base.h @@ -54,6 +54,17 @@ #include // after "phmap_config.h" #endif +#ifdef _MSC_VER + #pragma warning(push) + #pragma warning(disable : 4514) // unreferenced inline function has been removed + #pragma warning(disable : 4582) // constructor is not implicitly called + #pragma warning(disable : 4625) // copy constructor was implicitly defined as deleted + #pragma warning(disable : 4626) // assignment operator was implicitly defined as deleted + #pragma warning(disable : 4710) // function not inlined + #pragma warning(disable : 4711) // selected for automatic inline expansion + #pragma warning(disable : 4820) // '6' bytes padding added after data member +#endif // _MSC_VER + namespace phmap { template using Allocator = typename std::allocator; @@ -69,6 +80,15 @@ struct EqualTo } }; +template +struct Less +{ + inline bool operator()(const T& a, const T& b) const + { + return std::less()(a, b); + } +}; + namespace type_traits_internal { template @@ -213,177 +233,31 @@ struct disjunction : T {}; template <> struct disjunction<> : std::false_type {}; -// --------------------------------------------------------------------------- -// negation -// -// Performs a compile-time logical NOT operation on the passed type (which -// must have `::value` members convertible to `bool`. -// -// This metafunction is designed to be a drop-in replacement for the C++17 -// `std::negation` metafunction. -// --------------------------------------------------------------------------- template struct negation : std::integral_constant {}; -// --------------------------------------------------------------------------- -// is_trivially_destructible() -// -// Determines whether the passed type `T` is trivially destructable. -// -// This metafunction is designed to be a drop-in replacement for the C++11 -// `std::is_trivially_destructible()` metafunction for platforms that have -// incomplete C++11 support (such as libstdc++ 4.x). On any platforms that do -// fully support C++11, we check whether this yields the same result as the std -// implementation. -// -// NOTE: the extensions (__has_trivial_xxx) are implemented in gcc (version >= -// 4.3) and clang. Since we are supporting libstdc++ > 4.7, they should always -// be present. These extensions are documented at -// https://gcc.gnu.org/onlinedocs/gcc/Type-Traits.html#Type-Traits. -// --------------------------------------------------------------------------- template struct is_trivially_destructible : std::integral_constant::value> -{ -#ifdef PHMAP_HAVE_STD_IS_TRIVIALLY_DESTRUCTIBLE -private: - static constexpr bool compliant = std::is_trivially_destructible::value == - is_trivially_destructible::value; - static_assert(compliant || std::is_trivially_destructible::value, - "Not compliant with std::is_trivially_destructible; " - "Standard: false, Implementation: true"); - static_assert(compliant || !std::is_trivially_destructible::value, - "Not compliant with std::is_trivially_destructible; " - "Standard: true, Implementation: false"); -#endif -}; + std::is_destructible::value> {}; -// --------------------------------------------------------------------------- -// is_trivially_default_constructible() -// -// Determines whether the passed type `T` is trivially default constructible. -// -// This metafunction is designed to be a drop-in replacement for the C++11 -// `std::is_trivially_default_constructible()` metafunction for platforms that -// have incomplete C++11 support (such as libstdc++ 4.x). On any platforms that -// do fully support C++11, we check whether this yields the same result as the -// std implementation. -// -// NOTE: according to the C++ standard, Section: 20.15.4.3 [meta.unary.prop] -// "The predicate condition for a template specialization is_constructible shall be satisfied if and only if the following variable -// definition would be well-formed for some invented variable t: -// -// T t(declval()...); -// -// is_trivially_constructible additionally requires that the -// variable definition does not call any operation that is not trivial. -// For the purposes of this check, the call to std::declval is considered -// trivial." -// -// Notes from https://en.cppreference.com/w/cpp/types/is_constructible: -// In many implementations, is_nothrow_constructible also checks if the -// destructor throws because it is effectively noexcept(T(arg)). Same -// applies to is_trivially_constructible, which, in these implementations, also -// requires that the destructor is trivial. -// GCC bug 51452: https://gcc.gnu.org/bugzilla/show_bug.cgi?id=51452 -// LWG issue 2116: http://cplusplus.github.io/LWG/lwg-active.html#2116. -// -// "T obj();" need to be well-formed and not call any nontrivial operation. -// Nontrivially destructible types will cause the expression to be nontrivial. -// --------------------------------------------------------------------------- template struct is_trivially_default_constructible : std::integral_constant::value && - is_trivially_destructible::value> -{ -#ifdef PHMAP_HAVE_STD_IS_TRIVIALLY_CONSTRUCTIBLE -private: - static constexpr bool compliant = - std::is_trivially_default_constructible::value == - is_trivially_default_constructible::value; - static_assert(compliant || std::is_trivially_default_constructible::value, - "Not compliant with std::is_trivially_default_constructible; " - "Standard: false, Implementation: true"); - static_assert(compliant || !std::is_trivially_default_constructible::value, - "Not compliant with std::is_trivially_default_constructible; " - "Standard: true, Implementation: false"); -#endif -}; + is_trivially_destructible::value> {}; -// --------------------------------------------------------------------------- -// is_trivially_copy_constructible() -// -// Determines whether the passed type `T` is trivially copy constructible. -// -// This metafunction is designed to be a drop-in replacement for the C++11 -// `std::is_trivially_copy_constructible()` metafunction for platforms that have -// incomplete C++11 support (such as libstdc++ 4.x). On any platforms that do -// fully support C++11, we check whether this yields the same result as the std -// implementation. -// -// NOTE: `T obj(declval());` needs to be well-formed and not call any -// nontrivial operation. Nontrivially destructible types will cause the -// expression to be nontrivial. -// --------------------------------------------------------------------------- template struct is_trivially_copy_constructible : std::integral_constant::value && - is_trivially_destructible::value> -{ -#ifdef PHMAP_HAVE_STD_IS_TRIVIALLY_CONSTRUCTIBLE -private: - static constexpr bool compliant = - std::is_trivially_copy_constructible::value == - is_trivially_copy_constructible::value; - static_assert(compliant || std::is_trivially_copy_constructible::value, - "Not compliant with std::is_trivially_copy_constructible; " - "Standard: false, Implementation: true"); - static_assert(compliant || !std::is_trivially_copy_constructible::value, - "Not compliant with std::is_trivially_copy_constructible; " - "Standard: true, Implementation: false"); -#endif -}; + is_trivially_destructible::value> {}; -// --------------------------------------------------------------------------- -// is_trivially_copy_assignable() -// -// Determines whether the passed type `T` is trivially copy assignable. -// -// This metafunction is designed to be a drop-in replacement for the C++11 -// `std::is_trivially_copy_assignable()` metafunction for platforms that have -// incomplete C++11 support (such as libstdc++ 4.x). On any platforms that do -// fully support C++11, we check whether this yields the same result as the std -// implementation. -// -// NOTE: `is_assignable::value` is `true` if the expression -// `declval() = declval()` is well-formed when treated as an unevaluated -// operand. `is_trivially_assignable` requires the assignment to call no -// operation that is not trivial. `is_trivially_copy_assignable` is simply -// `is_trivially_assignable`. -// --------------------------------------------------------------------------- template struct is_trivially_copy_assignable : std::integral_constant< bool, __has_trivial_assign(typename std::remove_reference::type) && - phmap::is_copy_assignable::value> -{ -#ifdef PHMAP_HAVE_STD_IS_TRIVIALLY_ASSIGNABLE -private: - static constexpr bool compliant = - std::is_trivially_copy_assignable::value == - is_trivially_copy_assignable::value; - static_assert(compliant || std::is_trivially_copy_assignable::value, - "Not compliant with std::is_trivially_copy_assignable; " - "Standard: false, Implementation: true"); - static_assert(compliant || !std::is_trivially_copy_assignable::value, - "Not compliant with std::is_trivially_copy_assignable; " - "Standard: true, Implementation: false"); -#endif -}; + phmap::is_copy_assignable::value> {}; // ----------------------------------------------------------------------------- // C++14 "_t" trait aliases @@ -447,14 +321,19 @@ using enable_if_t = typename std::enable_if::type; template using conditional_t = typename std::conditional::type; + template using common_type_t = typename std::common_type::type; template using underlying_type_t = typename std::underlying_type::type; -template -using result_of_t = typename std::result_of::type; +template< class F, class... ArgTypes> +#if __cplusplus >= 201703L +using invoke_result_t = typename std::invoke_result_t; +#else +using invoke_result_t = typename std::result_of::type; +#endif namespace type_traits_internal { @@ -1226,6 +1105,11 @@ auto apply(Functor&& functor, Tuple&& t) typename std::remove_reference::type>::value>{}); } +#ifdef _MSC_VER + #pragma warning(push) + #pragma warning(disable : 4365) // '=': conversion from 'T' to 'T', signed/unsigned mismatch +#endif // _MSC_VER + // exchange // // Replaces the value of `obj` with `new_value` and returns the old value of @@ -1247,6 +1131,11 @@ T exchange(T& obj, U&& new_value) return old_value; } +#ifdef _MSC_VER + #pragma warning(pop) +#endif // _MSC_VER + + } // namespace phmap // ----------------------------------------------------------------------------- @@ -2845,6 +2734,12 @@ struct KeyArg using type = key_type; }; +#ifdef _MSC_VER + #pragma warning(push) + // warning C4820: '6' bytes padding added after data member + #pragma warning(disable : 4820) +#endif + // The node_handle concept from C++17. // We specialize node_handle for sets and maps. node_handle_base holds the // common API of both. @@ -2883,10 +2778,25 @@ class node_handle_base protected: friend struct CommonAccess; + struct transfer_tag_t {}; + node_handle_base(transfer_tag_t, const allocator_type& a, slot_type* s) + : alloc_(a) { + PolicyTraits::transfer(alloc(), slot(), s); + } + + struct move_tag_t {}; + node_handle_base(move_tag_t, const allocator_type& a, slot_type* s) + : alloc_(a) { + PolicyTraits::construct(alloc(), slot(), s); + } + node_handle_base(const allocator_type& a, slot_type* s) : alloc_(a) { PolicyTraits::transfer(alloc(), slot(), s); } + //node_handle_base(const node_handle_base&) = delete; + //node_handle_base& operator=(const node_handle_base&) = delete; + void destroy() { if (!empty()) { PolicyTraits::destroy(alloc(), slot()); @@ -2908,10 +2818,13 @@ class node_handle_base private: phmap::optional alloc_; - mutable phmap::aligned_storage_t - slot_space_; + mutable phmap::aligned_storage_t slot_space_; }; +#ifdef _MSC_VER + #pragma warning(pop) +#endif + // For sets. // --------- template private: friend struct CommonAccess; - node_handle(const Alloc& a, typename Base::slot_type* s) : Base(a, s) {} + using Base::Base; }; // For maps. @@ -2961,7 +2874,7 @@ class node_handle + static void Destroy(Node* node) { + node->destroy(); + } + template static void Reset(Node* node) { node->reset(); @@ -2981,6 +2899,16 @@ struct CommonAccess static T Make(Args&&... args) { return T(std::forward(args)...); } + + template + static T Transfer(Args&&... args) { + return T(typename T::transfer_tag_t{}, std::forward(args)...); + } + + template + static T Move(Args&&... args) { + return T(typename T::move_tag_t{}, std::forward(args)...); + } }; // Implement the insert_return_type<> concept of C++17. @@ -4404,6 +4332,98 @@ class PHMAP_INTERNAL_COMPRESSED_TUPLE_DECLSPEC CompressedTuple<> {}; } // namespace container_internal } // namespace phmap + +namespace phmap { +namespace container_internal { + +#ifdef _MSC_VER + #pragma warning(push) + // warning warning C4324: structure was padded due to alignment specifier + #pragma warning(disable : 4324) +#endif + + +// ---------------------------------------------------------------------------- +// Allocates at least n bytes aligned to the specified alignment. +// Alignment must be a power of 2. It must be positive. +// +// Note that many allocators don't honor alignment requirements above certain +// threshold (usually either alignof(std::max_align_t) or alignof(void*)). +// Allocate() doesn't apply alignment corrections. If the underlying allocator +// returns insufficiently alignment pointer, that's what you are going to get. +// ---------------------------------------------------------------------------- +template +void* Allocate(Alloc* alloc, size_t n) { + static_assert(Alignment > 0, ""); + assert(n && "n must be positive"); + struct alignas(Alignment) M {}; + using A = typename phmap::allocator_traits::template rebind_alloc; + using AT = typename phmap::allocator_traits::template rebind_traits; + A mem_alloc(*alloc); + void* p = AT::allocate(mem_alloc, (n + sizeof(M) - 1) / sizeof(M)); + assert(reinterpret_cast(p) % Alignment == 0 && + "allocator does not respect alignment"); + return p; +} + +// ---------------------------------------------------------------------------- +// The pointer must have been previously obtained by calling +// Allocate(alloc, n). +// ---------------------------------------------------------------------------- +template +void Deallocate(Alloc* alloc, void* p, size_t n) { + static_assert(Alignment > 0, ""); + assert(n && "n must be positive"); + struct alignas(Alignment) M {}; + using A = typename phmap::allocator_traits::template rebind_alloc; + using AT = typename phmap::allocator_traits::template rebind_traits; + A mem_alloc(*alloc); + AT::deallocate(mem_alloc, static_cast(p), + (n + sizeof(M) - 1) / sizeof(M)); +} + +#ifdef _MSC_VER + #pragma warning(pop) +#endif + +// Helper functions for asan and msan. +// ---------------------------------------------------------------------------- +inline void SanitizerPoisonMemoryRegion(const void* m, size_t s) { +#ifdef ADDRESS_SANITIZER + ASAN_POISON_MEMORY_REGION(m, s); +#endif +#ifdef MEMORY_SANITIZER + __msan_poison(m, s); +#endif + (void)m; + (void)s; +} + +inline void SanitizerUnpoisonMemoryRegion(const void* m, size_t s) { +#ifdef ADDRESS_SANITIZER + ASAN_UNPOISON_MEMORY_REGION(m, s); +#endif +#ifdef MEMORY_SANITIZER + __msan_unpoison(m, s); +#endif + (void)m; + (void)s; +} + +template +inline void SanitizerPoisonObject(const T* object) { + SanitizerPoisonMemoryRegion(object, sizeof(T)); +} + +template +inline void SanitizerUnpoisonObject(const T* object) { + SanitizerUnpoisonMemoryRegion(object, sizeof(T)); +} + +} // namespace container_internal +} // namespace phmap + + // --------------------------------------------------------------------------- // thread_annotations.h // --------------------------------------------------------------------------- @@ -4513,6 +4533,221 @@ inline T& ts_unchecked_read(T& v) PHMAP_NO_THREAD_SAFETY_ANALYSIS { } } // namespace thread_safety_analysis + +namespace container_internal { + +namespace memory_internal { + +// ---------------------------------------------------------------------------- +// If Pair is a standard-layout type, OffsetOf::kFirst and +// OffsetOf::kSecond are equivalent to offsetof(Pair, first) and +// offsetof(Pair, second) respectively. Otherwise they are -1. +// +// The purpose of OffsetOf is to avoid calling offsetof() on non-standard-layout +// type, which is non-portable. +// ---------------------------------------------------------------------------- +template +struct OffsetOf { + static constexpr size_t kFirst = (size_t)-1; + static constexpr size_t kSecond = (size_t)-1; +}; + +template +struct OffsetOf::type> +{ + static constexpr size_t kFirst = offsetof(Pair, first); + static constexpr size_t kSecond = offsetof(Pair, second); +}; + +// ---------------------------------------------------------------------------- +template +struct IsLayoutCompatible +{ +private: + struct Pair { + K first; + V second; + }; + + // Is P layout-compatible with Pair? + template + static constexpr bool LayoutCompatible() { + return std::is_standard_layout

() && sizeof(P) == sizeof(Pair) && + alignof(P) == alignof(Pair) && + memory_internal::OffsetOf

::kFirst == + memory_internal::OffsetOf::kFirst && + memory_internal::OffsetOf

::kSecond == + memory_internal::OffsetOf::kSecond; + } + +public: + // Whether pair and pair are layout-compatible. If they are, + // then it is safe to store them in a union and read from either. + static constexpr bool value = std::is_standard_layout() && + std::is_standard_layout() && + memory_internal::OffsetOf::kFirst == 0 && + LayoutCompatible>() && + LayoutCompatible>(); +}; + +} // namespace memory_internal + +// ---------------------------------------------------------------------------- +// The internal storage type for key-value containers like flat_hash_map. +// +// It is convenient for the value_type of a flat_hash_map to be +// pair; the "const K" prevents accidental modification of the key +// when dealing with the reference returned from find() and similar methods. +// However, this creates other problems; we want to be able to emplace(K, V) +// efficiently with move operations, and similarly be able to move a +// pair in insert(). +// +// The solution is this union, which aliases the const and non-const versions +// of the pair. This also allows flat_hash_map to work, even though +// that has the same efficiency issues with move in emplace() and insert() - +// but people do it anyway. +// +// If kMutableKeys is false, only the value member can be accessed. +// +// If kMutableKeys is true, key can be accessed through all slots while value +// and mutable_value must be accessed only via INITIALIZED slots. Slots are +// created and destroyed via mutable_value so that the key can be moved later. +// +// Accessing one of the union fields while the other is active is safe as +// long as they are layout-compatible, which is guaranteed by the definition of +// kMutableKeys. For C++11, the relevant section of the standard is +// https://timsong-cpp.github.io/cppwp/n3337/class.mem#19 (9.2.19) +// ---------------------------------------------------------------------------- +template +union map_slot_type +{ + map_slot_type() {} + ~map_slot_type() = delete; + map_slot_type(const map_slot_type&) = delete; + map_slot_type& operator=(const map_slot_type&) = delete; + + using value_type = std::pair; + using mutable_value_type = std::pair; + + value_type value; + mutable_value_type mutable_value; + K key; +}; + +// ---------------------------------------------------------------------------- +// ---------------------------------------------------------------------------- +template +struct map_slot_policy +{ + using slot_type = map_slot_type; + using value_type = std::pair; + using mutable_value_type = std::pair; + +private: + static void emplace(slot_type* slot) { + // The construction of union doesn't do anything at runtime but it allows us + // to access its members without violating aliasing rules. + new (slot) slot_type; + } + // If pair and pair are layout-compatible, we can accept one + // or the other via slot_type. We are also free to access the key via + // slot_type::key in this case. + using kMutableKeys = memory_internal::IsLayoutCompatible; + +public: + static value_type& element(slot_type* slot) { return slot->value; } + static const value_type& element(const slot_type* slot) { + return slot->value; + } + + static const K& key(const slot_type* slot) { + return kMutableKeys::value ? slot->key : slot->value.first; + } + + template + static void construct(Allocator* alloc, slot_type* slot, Args&&... args) { + emplace(slot); + if (kMutableKeys::value) { + phmap::allocator_traits::construct(*alloc, &slot->mutable_value, + std::forward(args)...); + } else { + phmap::allocator_traits::construct(*alloc, &slot->value, + std::forward(args)...); + } + } + + // Construct this slot by moving from another slot. + template + static void construct(Allocator* alloc, slot_type* slot, slot_type* other) { + emplace(slot); + if (kMutableKeys::value) { + phmap::allocator_traits::construct( + *alloc, &slot->mutable_value, std::move(other->mutable_value)); + } else { + phmap::allocator_traits::construct(*alloc, &slot->value, + std::move(other->value)); + } + } + + template + static void destroy(Allocator* alloc, slot_type* slot) { + if (kMutableKeys::value) { + phmap::allocator_traits::destroy(*alloc, &slot->mutable_value); + } else { + phmap::allocator_traits::destroy(*alloc, &slot->value); + } + } + + template + static void transfer(Allocator* alloc, slot_type* new_slot, + slot_type* old_slot) { + emplace(new_slot); + if (kMutableKeys::value) { + phmap::allocator_traits::construct( + *alloc, &new_slot->mutable_value, std::move(old_slot->mutable_value)); + } else { + phmap::allocator_traits::construct(*alloc, &new_slot->value, + std::move(old_slot->value)); + } + destroy(alloc, old_slot); + } + + template + static void swap(Allocator* alloc, slot_type* a, slot_type* b) { + if (kMutableKeys::value) { + using std::swap; + swap(a->mutable_value, b->mutable_value); + } else { + value_type tmp = std::move(a->value); + phmap::allocator_traits::destroy(*alloc, &a->value); + phmap::allocator_traits::construct(*alloc, &a->value, + std::move(b->value)); + phmap::allocator_traits::destroy(*alloc, &b->value); + phmap::allocator_traits::construct(*alloc, &b->value, + std::move(tmp)); + } + } + + template + static void move(Allocator* alloc, slot_type* src, slot_type* dest) { + if (kMutableKeys::value) { + dest->mutable_value = std::move(src->mutable_value); + } else { + phmap::allocator_traits::destroy(*alloc, &dest->value); + phmap::allocator_traits::construct(*alloc, &dest->value, + std::move(src->value)); + } + } + + template + static void move(Allocator* alloc, slot_type* first, slot_type* last, + slot_type* result) { + for (slot_type *src = first, *dest = result; src != last; ++src, ++dest) + move(alloc, src, dest); + } +}; + +} // namespace container_internal } // phmap @@ -4928,4 +5163,9 @@ class LockableImpl: public phmap::NullMutex } // phmap +#ifdef _MSC_VER + #pragma warning(pop) +#endif + + #endif // phmap_base_h_guard_ diff --git a/include/parallel_hashmap/phmap_bits.h b/include/parallel_hashmap/phmap_bits.h index c48e8eac0..7933d8cb5 100644 --- a/include/parallel_hashmap/phmap_bits.h +++ b/include/parallel_hashmap/phmap_bits.h @@ -50,6 +50,11 @@ #include #include "phmap_config.h" +#ifdef _MSC_VER + #pragma warning(push) + #pragma warning(disable : 4514) // unreferenced inline function has been removed +#endif + // ----------------------------------------------------------------------------- // unaligned APIs // ----------------------------------------------------------------------------- @@ -151,6 +156,102 @@ inline void UnalignedStore64(void *p, uint64_t v) { memcpy(p, &v, sizeof v); } #endif +// ----------------------------------------------------------------------------- +// File: optimization.h +// ----------------------------------------------------------------------------- + +#if defined(__pnacl__) + #define PHMAP_BLOCK_TAIL_CALL_OPTIMIZATION() if (volatile int x = 0) { (void)x; } +#elif defined(__clang__) + // Clang will not tail call given inline volatile assembly. + #define PHMAP_BLOCK_TAIL_CALL_OPTIMIZATION() __asm__ __volatile__("") +#elif defined(__GNUC__) + // GCC will not tail call given inline volatile assembly. + #define PHMAP_BLOCK_TAIL_CALL_OPTIMIZATION() __asm__ __volatile__("") +#elif defined(_MSC_VER) + #include + // The __nop() intrinsic blocks the optimisation. + #define PHMAP_BLOCK_TAIL_CALL_OPTIMIZATION() __nop() +#else + #define PHMAP_BLOCK_TAIL_CALL_OPTIMIZATION() if (volatile int x = 0) { (void)x; } +#endif + +#if defined(__GNUC__) + #pragma GCC diagnostic push + #pragma GCC diagnostic ignored "-Wpedantic" +#endif + +#ifdef PHMAP_HAVE_INTRINSIC_INT128 + __extension__ typedef unsigned __int128 phmap_uint128; + inline uint64_t umul128(uint64_t a, uint64_t b, uint64_t* high) + { + auto result = static_cast(a) * static_cast(b); + *high = static_cast(result >> 64); + return static_cast(result); + } + #define PHMAP_HAS_UMUL128 1 +#elif (defined(_MSC_VER)) + #if defined(_M_X64) + #pragma intrinsic(_umul128) + inline uint64_t umul128(uint64_t a, uint64_t b, uint64_t* high) + { + return _umul128(a, b, high); + } + #define PHMAP_HAS_UMUL128 1 + #endif +#endif + +#if defined(__GNUC__) + #pragma GCC diagnostic pop +#endif + +#if defined(__GNUC__) + // Cache line alignment + #if defined(__i386__) || defined(__x86_64__) + #define PHMAP_CACHELINE_SIZE 64 + #elif defined(__powerpc64__) + #define PHMAP_CACHELINE_SIZE 128 + #elif defined(__aarch64__) + // We would need to read special register ctr_el0 to find out L1 dcache size. + // This value is a good estimate based on a real aarch64 machine. + #define PHMAP_CACHELINE_SIZE 64 + #elif defined(__arm__) + // Cache line sizes for ARM: These values are not strictly correct since + // cache line sizes depend on implementations, not architectures. There + // are even implementations with cache line sizes configurable at boot + // time. + #if defined(__ARM_ARCH_5T__) + #define PHMAP_CACHELINE_SIZE 32 + #elif defined(__ARM_ARCH_7A__) + #define PHMAP_CACHELINE_SIZE 64 + #endif + #endif + + #ifndef PHMAP_CACHELINE_SIZE + // A reasonable default guess. Note that overestimates tend to waste more + // space, while underestimates tend to waste more time. + #define PHMAP_CACHELINE_SIZE 64 + #endif + + #define PHMAP_CACHELINE_ALIGNED __attribute__((aligned(PHMAP_CACHELINE_SIZE))) +#elif defined(_MSC_VER) + #define PHMAP_CACHELINE_SIZE 64 + #define PHMAP_CACHELINE_ALIGNED __declspec(align(PHMAP_CACHELINE_SIZE)) +#else + #define PHMAP_CACHELINE_SIZE 64 + #define PHMAP_CACHELINE_ALIGNED +#endif + + +#if PHMAP_HAVE_BUILTIN(__builtin_expect) || \ + (defined(__GNUC__) && !defined(__clang__)) + #define PHMAP_PREDICT_FALSE(x) (__builtin_expect(x, 0)) + #define PHMAP_PREDICT_TRUE(x) (__builtin_expect(!!(x), 1)) +#else + #define PHMAP_PREDICT_FALSE(x) (x) + #define PHMAP_PREDICT_TRUE(x) (x) +#endif + // ----------------------------------------------------------------------------- // File: bits.h // ----------------------------------------------------------------------------- @@ -183,7 +284,7 @@ PHMAP_BASE_INTERNAL_FORCEINLINE int CountLeadingZeros64(uint64_t n) { // MSVC does not have __buitin_clzll. Use _BitScanReverse64. unsigned long result = 0; // NOLINT(runtime/int) if (_BitScanReverse64(&result, n)) { - return 63 - result; + return (int)(63 - result); } return 64; #elif defined(_MSC_VER) @@ -226,7 +327,7 @@ PHMAP_BASE_INTERNAL_FORCEINLINE int CountLeadingZeros32(uint32_t n) { #if defined(_MSC_VER) unsigned long result = 0; // NOLINT(runtime/int) if (_BitScanReverse(&result, n)) { - return 31 - result; + return (int)(31 - result); } return 32; #elif defined(__GNUC__) @@ -263,7 +364,7 @@ PHMAP_BASE_INTERNAL_FORCEINLINE int CountTrailingZerosNonZero64(uint64_t n) { #if defined(_MSC_VER) && defined(_M_X64) unsigned long result = 0; // NOLINT(runtime/int) _BitScanForward64(&result, n); - return result; + return (int)result; #elif defined(_MSC_VER) unsigned long result = 0; // NOLINT(runtime/int) if (static_cast(n) == 0) { @@ -296,7 +397,7 @@ PHMAP_BASE_INTERNAL_FORCEINLINE int CountTrailingZerosNonZero32(uint32_t n) { #if defined(_MSC_VER) unsigned long result = 0; // NOLINT(runtime/int) _BitScanForward(&result, n); - return result; + return (int)result; #elif defined(__GNUC__) static_assert(sizeof(int) == sizeof(n), "__builtin_ctz does not take 32-bit arg"); @@ -311,101 +412,6 @@ PHMAP_BASE_INTERNAL_FORCEINLINE int CountTrailingZerosNonZero32(uint32_t n) { } // namespace base_internal } // namespace phmap -// ----------------------------------------------------------------------------- -// File: optimization.h -// ----------------------------------------------------------------------------- - -#if defined(__pnacl__) - #define PHMAP_BLOCK_TAIL_CALL_OPTIMIZATION() if (volatile int x = 0) { (void)x; } -#elif defined(__clang__) - // Clang will not tail call given inline volatile assembly. - #define PHMAP_BLOCK_TAIL_CALL_OPTIMIZATION() __asm__ __volatile__("") -#elif defined(__GNUC__) - // GCC will not tail call given inline volatile assembly. - #define PHMAP_BLOCK_TAIL_CALL_OPTIMIZATION() __asm__ __volatile__("") -#elif defined(_MSC_VER) - #include - // The __nop() intrinsic blocks the optimisation. - #define PHMAP_BLOCK_TAIL_CALL_OPTIMIZATION() __nop() -#else - #define PHMAP_BLOCK_TAIL_CALL_OPTIMIZATION() if (volatile int x = 0) { (void)x; } -#endif - -#if defined(__GNUC__) - #pragma GCC diagnostic push - #pragma GCC diagnostic ignored "-Wpedantic" -#endif - -#ifdef PHMAP_HAVE_INTRINSIC_INT128 - inline uint64_t umul128(uint64_t a, uint64_t b, uint64_t* high) - { - auto result = static_cast(a) * static_cast(b); - *high = static_cast(result >> 64); - return static_cast(result); - } - #define PHMAP_HAS_UMUL128 1 -#elif (defined(_MSC_VER)) - #if defined(_M_X64) - #pragma intrinsic(_umul128) - inline uint64_t umul128(uint64_t a, uint64_t b, uint64_t* high) - { - return _umul128(a, b, high); - } - #define PHMAP_HAS_UMUL128 1 - #endif -#endif - -#if defined(__GNUC__) - #pragma GCC diagnostic pop -#endif - -#if defined(__GNUC__) - // Cache line alignment - #if defined(__i386__) || defined(__x86_64__) - #define PHMAP_CACHELINE_SIZE 64 - #elif defined(__powerpc64__) - #define PHMAP_CACHELINE_SIZE 128 - #elif defined(__aarch64__) - // We would need to read special register ctr_el0 to find out L1 dcache size. - // This value is a good estimate based on a real aarch64 machine. - #define PHMAP_CACHELINE_SIZE 64 - #elif defined(__arm__) - // Cache line sizes for ARM: These values are not strictly correct since - // cache line sizes depend on implementations, not architectures. There - // are even implementations with cache line sizes configurable at boot - // time. - #if defined(__ARM_ARCH_5T__) - #define PHMAP_CACHELINE_SIZE 32 - #elif defined(__ARM_ARCH_7A__) - #define PHMAP_CACHELINE_SIZE 64 - #endif - #endif - - #ifndef PHMAP_CACHELINE_SIZE - // A reasonable default guess. Note that overestimates tend to waste more - // space, while underestimates tend to waste more time. - #define PHMAP_CACHELINE_SIZE 64 - #endif - - #define PHMAP_CACHELINE_ALIGNED __attribute__((aligned(PHMAP_CACHELINE_SIZE))) -#elif defined(_MSC_VER) - #define PHMAP_CACHELINE_SIZE 64 - #define PHMAP_CACHELINE_ALIGNED __declspec(align(PHMAP_CACHELINE_SIZE)) -#else - #define PHMAP_CACHELINE_SIZE 64 - #define PHMAP_CACHELINE_ALIGNED -#endif - - -#if PHMAP_HAVE_BUILTIN(__builtin_expect) || \ - (defined(__GNUC__) && !defined(__clang__)) - #define PHMAP_PREDICT_FALSE(x) (__builtin_expect(x, 0)) - #define PHMAP_PREDICT_TRUE(x) (__builtin_expect(!!(x), 1)) -#else - #define PHMAP_PREDICT_FALSE(x) (x) - #define PHMAP_PREDICT_TRUE(x) (x) -#endif - // ----------------------------------------------------------------------------- // File: endian.h // ----------------------------------------------------------------------------- @@ -650,4 +656,8 @@ inline void Store64(void *p, uint64_t v) { } // namespace phmap +#ifdef _MSC_VER + #pragma warning(pop) +#endif + #endif // phmap_bits_h_guard_ diff --git a/include/parallel_hashmap/phmap_config.h b/include/parallel_hashmap/phmap_config.h index ff10b4393..dd5a0bf54 100644 --- a/include/parallel_hashmap/phmap_config.h +++ b/include/parallel_hashmap/phmap_config.h @@ -296,7 +296,8 @@ #endif #ifdef __has_include - #if __has_include() && __cplusplus >= 201703L + #if __has_include() && __cplusplus >= 201703L && \ + (!defined(_MSC_VER) || _MSC_VER >= 1920) // vs2019 #define PHMAP_HAVE_STD_STRING_VIEW 1 #endif #endif @@ -304,14 +305,16 @@ // #pragma message(PHMAP_VAR_NAME_VALUE(_MSVC_LANG)) #if defined(_MSC_VER) && _MSC_VER >= 1910 && \ - ((defined(_MSVC_LANG) && _MSVC_LANG > 201402) || __cplusplus > 201402) + ((defined(_MSVC_LANG) && _MSVC_LANG >= 201703) || __cplusplus >= 201703) // #define PHMAP_HAVE_STD_ANY 1 #define PHMAP_HAVE_STD_OPTIONAL 1 #define PHMAP_HAVE_STD_VARIANT 1 - #define PHMAP_HAVE_STD_STRING_VIEW 1 + #if !defined(PHMAP_HAVE_STD_STRING_VIEW) && _MSC_VER >= 1920 + #define PHMAP_HAVE_STD_STRING_VIEW 1 + #endif #endif -#if (defined(_MSVC_LANG) && _MSVC_LANG >= 201402) || __cplusplus >= 201703 +#if (defined(_MSVC_LANG) && _MSVC_LANG >= 201703) || __cplusplus >= 201703 #define PHMAP_HAVE_SHARED_MUTEX 1 #endif @@ -391,7 +394,7 @@ #define PHMAP_ATTRIBUTE_ALWAYS_INLINE #endif -#if PHMAP_HAVE_ATTRIBUTE(noinline) || (defined(__GNUC__) && !defined(__clang__)) +#if !defined(__INTEL_COMPILER) && (PHMAP_HAVE_ATTRIBUTE(noinline) || (defined(__GNUC__) && !defined(__clang__))) #define PHMAP_ATTRIBUTE_NOINLINE __attribute__((noinline)) #define PHMAP_HAVE_ATTRIBUTE_NOINLINE 1 #else @@ -627,6 +630,15 @@ #endif +// ---------------------------------------------------------------------- +// constexpr if +// ---------------------------------------------------------------------- +#if __cplusplus >= 201703 || (defined(_MSVC_LANG) && _MSVC_LANG >= 201703) + #define PHMAP_IF_CONSTEXPR(expr) if constexpr ((expr)) +#else + #define PHMAP_IF_CONSTEXPR(expr) if ((expr)) +#endif + // ---------------------------------------------------------------------- // base/macros.h // ---------------------------------------------------------------------- diff --git a/include/parallel_hashmap/phmap_fwd_decl.h b/include/parallel_hashmap/phmap_fwd_decl.h index a9fb51418..f90417d53 100644 --- a/include/parallel_hashmap/phmap_fwd_decl.h +++ b/include/parallel_hashmap/phmap_fwd_decl.h @@ -11,13 +11,30 @@ // https://www.apache.org/licenses/LICENSE-2.0 // --------------------------------------------------------------------------- +#ifdef _MSC_VER + #pragma warning(push) + #pragma warning(disable : 4514) // unreferenced inline function has been removed + #pragma warning(disable : 4710) // function not inlined + #pragma warning(disable : 4711) // selected for automatic inline expansion +#endif + #include #include +#if defined(PHMAP_USE_ABSL_HASH) && !defined(ABSL_HASH_HASH_H_) + namespace absl { template struct Hash; }; +#endif + namespace phmap { +#if defined(PHMAP_USE_ABSL_HASH) + template using Hash = ::absl::Hash; +#else template struct Hash; +#endif + template struct EqualTo; + template struct Less; template using Allocator = typename std::allocator; template using Pair = typename std::pair; @@ -29,13 +46,8 @@ namespace phmap { template struct HashEq { -#if defined(PHMAP_USE_ABSL_HASHEQ) - using Hash = absl::Hash; - using Eq = phmap::EqualTo; -#else using Hash = phmap::Hash; using Eq = phmap::EqualTo; -#endif }; template @@ -54,6 +66,7 @@ namespace phmap { } // namespace container_internal + // ------------- forward declarations for hash containers ---------------------------------- template , class Eq = phmap::container_internal::hash_default_eq, @@ -114,9 +127,28 @@ namespace phmap { class Mutex = phmap::NullMutex> // use std::mutex to enable internal locks class parallel_node_hash_map; + // ------------- forward declarations for btree containers ---------------------------------- + template , + typename Alloc = phmap::Allocator> + class btree_set; + template , + typename Alloc = phmap::Allocator> + class btree_multiset; + + template , + typename Alloc = phmap::Allocator>> + class btree_map; + + template , + typename Alloc = phmap::Allocator>> + class btree_multimap; } // namespace phmap +#ifdef _MSC_VER + #pragma warning(pop) +#endif + #endif // phmap_fwd_decl_h_guard_ diff --git a/include/parallel_hashmap/phmap_utils.h b/include/parallel_hashmap/phmap_utils.h index 96fb6ff39..72d4e716f 100644 --- a/include/parallel_hashmap/phmap_utils.h +++ b/include/parallel_hashmap/phmap_utils.h @@ -4,7 +4,9 @@ // --------------------------------------------------------------------------- // Copyright (c) 2019, Gregory Popovitch - greg7mdp@gmail.com // -// minimal header providing phmap::hash_combine +// minimal header providing phmap::HashState +// +// use as: phmap::HashState().combine(0, _first_name, _last_name, _age); // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -19,10 +21,25 @@ // limitations under the License. // --------------------------------------------------------------------------- +#ifdef _MSC_VER + #pragma warning(push) + #pragma warning(disable : 4514) // unreferenced inline function has been removed + #pragma warning(disable : 4710) // function not inlined + #pragma warning(disable : 4711) // selected for automatic inline expansion +#endif + #include #include +#include #include "phmap_bits.h" +// --------------------------------------------------------------- +// Absl forward declaration requires global scope. +// --------------------------------------------------------------- +#if defined(PHMAP_USE_ABSL_HASH) && !defined(phmap_fwd_decl_h_guard_) && !defined(ABSL_HASH_HASH_H_) + namespace absl { template struct Hash; }; +#endif + namespace phmap { @@ -113,14 +130,17 @@ struct has_hash_value typedef std::true_type yes; typedef std::false_type no; - template static auto test(int) -> decltype(U::hash_value(std::declval()) == 1, yes()); + template static auto test(int) -> decltype(hash_value(std::declval()) == 1, yes()); template static no test(...); public: static constexpr bool value = std::is_same(0)), yes>::value; }; - + +#if defined(PHMAP_USE_ABSL_HASH) && !defined(phmap_fwd_decl_h_guard_) + template using Hash = ::absl::Hash; +#elif !defined(PHMAP_USE_ABSL_HASH) // --------------------------------------------------------------- // phmap::Hash // --------------------------------------------------------------- @@ -130,7 +150,7 @@ struct Hash template ::value, int>::type = 0> size_t _hash(const T& val) const { - return U::hash_value(val); + return hash_value(val); } template ::value, int>::type = 0> @@ -262,6 +282,8 @@ struct Hash : public phmap_unary_function } }; +#endif + template struct Combiner { H operator()(H seed, size_t value); @@ -283,7 +305,7 @@ template struct Combiner } }; - +// define HashState to combine member hashes... see example below // ----------------------------------------------------------------------------- template class HashStateBase { @@ -299,12 +321,56 @@ template H HashStateBase::combine(H seed, const T& v, const Ts&... vs) { return HashStateBase::combine(Combiner()( - seed, phmap::Hash()(v)), vs...); + seed, phmap::Hash()(v)), + vs...); } using HashState = HashStateBase; +// ----------------------------------------------------------------------------- + +#if !defined(PHMAP_USE_ABSL_HASH) + +// define Hash for std::pair +// ------------------------- +template +struct Hash> { + size_t operator()(std::pair const& p) const noexcept { + return phmap::HashState().combine(phmap::Hash()(p.first), p.second); + } +}; + +// define Hash for std::tuple +// -------------------------- +template +struct Hash> { + size_t operator()(std::tuple const& t) const noexcept { + return _hash_helper(t); + } + +private: + template + typename std::enable_if::type + _hash_helper(const std::tuple &) const noexcept { return 0; } + + template + typename std::enable_if::type + _hash_helper(const std::tuple &t) const noexcept { + const auto &el = std::get(t); + using el_type = typename std::remove_cv::type>::type; + return Combiner()( + phmap::Hash()(el), _hash_helper(t)); + } +}; + + +#endif + + } // namespace phmap +#ifdef _MSC_VER + #pragma warning(pop) +#endif #endif // phmap_utils_h_guard_ diff --git a/scripts/fetchPufferfish.sh b/scripts/fetchPufferfish.sh index d49d77190..ca299e178 100755 --- a/scripts/fetchPufferfish.sh +++ b/scripts/fetchPufferfish.sh @@ -22,10 +22,10 @@ if [ -d ${INSTALL_DIR}/src/pufferfish ] ; then rm -fr ${INSTALL_DIR}/src/pufferfish fi -SVER=salmon-v1.2.1 +SVER=salmon-v1.3.0 #SVER=develop -EXPECTED_SHA256=da51713e54cf426524a2a1da7de2273cea7bf1f4089abbce22fcaa8f59e493cc +EXPECTED_SHA256=0176b2ec5fc45bbf68c60b5845fead28e63db72f91ff93499d67e7a571167fdf mkdir -p ${EXTERNAL_DIR} curl -k -L https://github.com/COMBINE-lab/pufferfish/archive/${SVER}.zip -o ${EXTERNAL_DIR}/pufferfish.zip diff --git a/src/Alevin.cpp b/src/Alevin.cpp index 7cdd19900..2fb8d4235 100644 --- a/src/Alevin.cpp +++ b/src/Alevin.cpp @@ -823,6 +823,7 @@ void initiatePipeline(AlevinOpts& aopt, std::vector readFiles){ bool isOptionsOk = aut::processAlevinOpts(aopt, sopt, noTgMap, vm); if (!isOptionsOk){ + aopt.jointLog->flush(); exit(1); } @@ -1003,7 +1004,7 @@ salmon-based processing of single-cell RNA-seq data. if (celseq) validate_num_protocols += 1; if (celseq2) validate_num_protocols += 1; if (quartzseq2) validate_num_protocols += 1; - if (custom) validate_num_protocols += 1; + if (custom and !noTgMap) validate_num_protocols += 1; if ( validate_num_protocols != 1 ) { fmt::print(stderr, "ERROR: Please specify one and only one scRNA protocol;"); diff --git a/src/AlevinUtils.cpp b/src/AlevinUtils.cpp index 3dda707b2..5130a5f75 100644 --- a/src/AlevinUtils.cpp +++ b/src/AlevinUtils.cpp @@ -563,6 +563,7 @@ namespace alevin { auto logPath = aopt.outputDirectory / "alevin.log"; auto fileSink = std::make_shared(logPath.string(), true); auto consoleSink = std::make_shared(); + consoleSink->set_color(spdlog::level::warn, consoleSink->magenta); std::vector sinks{consoleSink, fileSink}; aopt.jointLog = spdlog::create("alevinLog", std::begin(sinks), std::end(sinks)); @@ -578,8 +579,10 @@ namespace alevin { aopt.dumpBarcodeEq = vm["dumpBarcodeEq"].as(); aopt.dumpBFH = vm["dumpBfh"].as(); aopt.dumpUmiGraph = vm["dumpUmiGraph"].as(); + aopt.dumpCellEq = vm["dumpCellEq"].as(); aopt.trimRight = vm["trimRight"].as(); aopt.numBootstraps = vm["numCellBootstraps"].as(); + aopt.numGibbsSamples = vm["numCellGibbsSamples"].as(); aopt.lowRegionMinNumBarcodes = vm["lowRegionMinNumBarcodes"].as(); aopt.maxNumBarcodes = vm["maxNumBarcodes"].as(); aopt.freqThreshold = vm["freqThreshold"].as(); @@ -602,6 +605,12 @@ namespace alevin { return false; } + if (aopt.numGibbsSamples > 0 and aopt.numBootstraps > 0) { + aopt.jointLog->error("Either of --numCellGibbsSamples or --numCellBootstraps " + "can be used"); + return false; + } + if ( aopt.numBootstraps > 0 and aopt.noEM ) { aopt.jointLog->error("cannot perform bootstrapping with noEM option."); return false; diff --git a/src/AlignmentModel.cpp b/src/AlignmentModel.cpp index 23dab30e3..5377e423f 100644 --- a/src/AlignmentModel.cpp +++ b/src/AlignmentModel.cpp @@ -13,14 +13,15 @@ #include "UnpairedRead.hpp" AlignmentModel::AlignmentModel(double alpha, uint32_t readBins) - : transitionProbsLeft_(readBins, - AtomicMatrix(numAlignmentStates(), - numAlignmentStates(), alpha)), - transitionProbsRight_(readBins, - AtomicMatrix(numAlignmentStates(), - numAlignmentStates(), alpha)), + : transitionProbsLeft_(readBins), transitionProbsRight_(readBins), isEnabled_(true), readBins_(readBins), burnedIn_(false) { - + + for (size_t i = 0; i < readBins; ++i) { + transitionProbsLeft_[i] = std::move(AtomicMatrix( + numAlignmentStates(), numAlignmentStates(), alpha)); + transitionProbsRight_[i] = std::move(AtomicMatrix( + numAlignmentStates(), numAlignmentStates(), alpha)); + } } bool AlignmentModel::burnedIn() { return burnedIn_; } diff --git a/src/BuildSalmonIndex.cpp b/src/BuildSalmonIndex.cpp index 8e97e69ad..d3e525adc 100644 --- a/src/BuildSalmonIndex.cpp +++ b/src/BuildSalmonIndex.cpp @@ -29,7 +29,6 @@ #include "tbb/parallel_for.h" #include "tbb/parallel_for_each.h" #include "tbb/parallel_sort.h" -#include "tbb/task_scheduler_init.h" #include "GenomicFeature.hpp" #include "SalmonIndex.hpp" diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 5ea2f4146..e0309ec76 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -10,6 +10,7 @@ ${Boost_INCLUDE_DIRS} ${GAT_SOURCE_DIR}/external/install/include ${GAT_SOURCE_DIR}/external/install/include/pufferfish ${GAT_SOURCE_DIR}/external/install/include/pufferfish/digestpp +${LIB_GFF_INCLUDE_DIR} #${GAT_SOURCE_DIR}/external/install/include/rapmap #${GAT_SOURCE_DIR}/external/install/include/rapmap/digestpp ${ICU_INC_DIRS} @@ -102,6 +103,7 @@ ${Boost_LIBRARY_DIRS} ${TBB_LIBRARY_DIRS} ${LAPACK_LIBRARY_DIR} ${BLAS_LIBRARY_DIR} +${LIB_GFF_LIBRARY_DIR} ) message("TBB_LIBRARIES = ${TBB_LIBRARIES}") @@ -206,16 +208,16 @@ target_link_libraries(salmon m ${LIBLZMA_LIBRARIES} ${BZIP2_LIBRARIES} - ${TBB_LIBRARIES} ${LIBSALMON_LINKER_FLAGS} ${NON_APPLECLANG_LIBS} - ${FAST_MALLOC_LIB} ${LIBRT} ksw2pp ## PUFF_INTEGRATION alevin_core - ${CMAKE_DL_LIBS} ${ASAN_LIB} + ${FAST_MALLOC_LIB} + ${TBB_LIBRARIES} + ${CMAKE_DL_LIBS} #ubsan ) @@ -238,8 +240,8 @@ target_link_libraries(unitTests ${LIBSALMON_LINKER_FLAGS} ${NON_APPLECLANG_LIBS} ${LIBRT} - ${CMAKE_DL_LIBS} ${ASAN_LIB} + ${CMAKE_DL_LIBS} #ubsan ) diff --git a/src/CollapsedCellOptimizer.cpp b/src/CollapsedCellOptimizer.cpp index bbdefc781..a7844c1da 100644 --- a/src/CollapsedCellOptimizer.cpp +++ b/src/CollapsedCellOptimizer.cpp @@ -1,6 +1,9 @@ -#include "CollapsedCellOptimizer.hpp" #include #include +#include + +#include "CollapsedCellOptimizer.hpp" +#include "ezETAProgressBar.hpp" CollapsedCellOptimizer::CollapsedCellOptimizer() {} @@ -216,6 +219,216 @@ bool runPerCellEM(double& totalNumFrags, size_t numGenes, return true; } +bool runGibbsSamples(size_t numGenes, + CollapsedCellOptimizer::SerialVecType& geneAlphas, + std::vector& sampleMean, + std::vector& sampleVariance, + std::vector& salmonEqclasses, + std::shared_ptr& jointlog, + uint32_t numSamples, + std::vector>& sampleEstimates, + bool quiet = true){ + + + constexpr double minEQClassWeight = std::numeric_limits::denorm_min(); + constexpr double minWeight = std::numeric_limits::denorm_min(); + + uint32_t numInternalRounds{16}; + size_t numClasses = salmonEqclasses.size(); + + // gibbs related + CollapsedCellOptimizer::SerialVecType alphas(numGenes, 0.0); + CollapsedCellOptimizer::SerialVecType mean(numGenes, 0.0); + CollapsedCellOptimizer::SerialVecType squareMean(numGenes, 0.0); + CollapsedCellOptimizer::SerialVecType alphasIn(numGenes, 0.0); + CollapsedCellOptimizer::SerialVecType alphasInit(numGenes, 0.0); + + //extracting weight of eqclasses for making discrete distribution + uint32_t totalNumFrags = 0; + std::vector eqCounts; + + std::vector active(numGenes, false); + for (auto& eqclass: salmonEqclasses) { + for (size_t j = 0; j < eqclass.labels.size(); ++j) { + auto geneIdx = eqclass.labels[j]; + active[geneIdx] = true; + } + totalNumFrags += eqclass.count; + eqCounts.emplace_back(eqclass.count); + } + + // make active list (genes that are present in equivalence classes) + std::vector activeList; + activeList.reserve(numGenes); + for (size_t i = 0; i < numGenes; ++i) { + if (active[i]) { + activeList.push_back(i); + } + alphasIn[i] = geneAlphas[i]; + alphasInit[i] = geneAlphas[i]; + } + + double prior = 1e-3; + std::vector priorAlphas(numGenes, prior); + + std::vector offsetMap(numClasses, 0); + size_t countMapSize{0}; + for (size_t i = 0; i < numClasses ; ++i){ + countMapSize += salmonEqclasses[i].labels.size(); + if(i < numClasses - 1) { + offsetMap[i + 1] = countMapSize; + } + } + + // hold the estimated counts, active list + std::vector mu(numGenes, 0.0); + std::vector countMap(countMapSize, 0); + std::vector probMap(countMapSize, 0.0); + + uint32_t nchains{1}; + if (numSamples >= 50) { + nchains = 2; + } + if (numSamples >= 100) { + nchains = 4; + } + if (numSamples >= 200) { + nchains = 8; + } + + std::random_device rd; + std::mt19937 gen(rd()); + + std::vector newChainIter{0}; + if (nchains > 1) { + auto step = numSamples / nchains; + for (size_t i = 1; i < nchains; ++i) { + newChainIter.push_back(i * step); + } + } + auto nextChainStart = newChainIter.begin(); + + // For each sample this thread should generate + std::unique_ptr pbar{nullptr}; + if (!quiet) { + pbar.reset(new ez::ezETAProgressBar(numSamples)); + pbar->start(); + } + + for (size_t sampleID = 0; sampleID < numSamples; ++sampleID) { + + if (pbar) { + ++(*pbar); + } + // If we should start a new chain here, then do it! + if (nextChainStart < newChainIter.end() and sampleID == *nextChainStart) { + alphasIn = alphasInit; + ++nextChainStart; + } + + // Since for single cell data we don't estimate the exact fragment length + // essentially it will be treated as a single end read, there for from the + // definition of effective length, + // l_e = l_i - l_f + 1 + // we assume l_e = 1 + // The rest of the gibbs sample would pretty much follow from + // the principle of bulk RNA-seq + + // The mean transcript fraction are sampled from + // ~ Gam( prior[i] + geneAlphas[i], \Beta + 1 ) + // Given these transcript fractions, the reads are + // re-assigned within each equivalence class by sampling from + // a multinomial distribution according to these means + for (size_t roundIdx = 0; roundIdx < numInternalRounds; ++roundIdx) { + double beta = 0.1; + double norm = 0.0; + + // first phase: Calculate mean transcript fraction from Gamma + for(size_t activeIdx = 0; activeIdx < activeList.size(); ++activeIdx) { + auto i = activeList[activeIdx]; + double ci = static_cast(alphas[i] + priorAlphas[i]); + std::gamma_distribution d(ci, 1.0 / (beta + 1.0)); + mu[i] = d(gen); + alphas[i] = 0.0 ; + } + + // second phase: sample from the trandcript fractions + // re-assign them back to equivalence classes + for (size_t eqId = 0; eqId < numClasses; ++eqId) { + size_t offset = offsetMap[eqId]; + size_t classCount = salmonEqclasses[eqId].count ; + const std::vector& geneLabels = salmonEqclasses[eqId].labels; + size_t groupSize = geneLabels.size(); + + double muSum{0.0}; + double denom{0.0}; + if (groupSize > 1) { + double uniformWeight = 1.0 / static_cast(groupSize); + for (size_t i = 0; i < groupSize; ++i) { + auto gid = geneLabels[i]; + size_t globalIndex = offset + i; + probMap[globalIndex] = (1000.0 * mu[gid]) * uniformWeight ; + muSum += probMap[globalIndex]; + denom += probMap[globalIndex]; + } + // we might be working with tiny values and + // the denominator can become very small + if (denom <= minEQClassWeight) { + denom = 0.0; + muSum = 0.0; + for (size_t i = 0; i < groupSize; ++i) { + auto gid = geneLabels[i]; + size_t globalIndex = offset + i; + probMap[globalIndex] = 1.0 ; + muSum += probMap[globalIndex]; + denom += probMap[globalIndex]; + } + } + // Assuming previous step worked + // re-sample from a multinomial + // fill in only subpart of alpha + if (denom > minEQClassWeight) { + std::discrete_distribution dist(probMap.begin() + offset, + probMap.begin() + offset + groupSize + ); + for (size_t s = 0; s < classCount ; ++s){ + auto ind = dist(gen); + ++alphas[geneLabels[ind]]; + } + }else{ + jointlog->warn("the probabilities are too small " + "Make sure you ran salmon correclty."); + jointlog->flush(); + } + }else{ + auto gid = geneLabels[0]; + alphas[gid] += static_cast(classCount); + } + } + } + + // internal rounds are done + // TODO: can extrapolate the counts + // but as effective length is not at play, but imo it + // would hardly affect anything + + // updated values are in alpha let's put them in the estimation vector + // Calculate mean and square mean for later + for (size_t i=0; i& salmonEqclasses, @@ -329,6 +542,7 @@ bool runBootstraps(size_t numGenes, double meanAlpha = mean[i] / numBootstraps; geneAlphas[i] = meanAlpha; variance[i] = (squareMean[i]/numBootstraps) - (meanAlpha*meanAlpha); + variance[i] *= (numBootstraps / static_cast(numBootstraps - 1)); } return true; @@ -343,11 +557,13 @@ void optimizeCell(std::vector& trueBarcodes, std::vector& umiCount, std::vector& skippedCB, bool verbose, GZipWriter& gzw, bool noEM, bool useVBEM, - bool quiet, tbb::atomic& totalDedupCounts, - tbb::atomic& totalExpGeneCounts, double priorWeight, + bool quiet, std::atomic& totalDedupCounts, + std::atomic& totalExpGeneCounts, double priorWeight, spp::sparse_hash_map& txpToGeneMap, - uint32_t numGenes, uint32_t umiLength, uint32_t numBootstraps, - bool naiveEqclass, bool dumpUmiGraph, bool useAllBootstraps, + uint32_t numGenes, uint32_t umiLength, + uint32_t numBootstraps, uint32_t numGibbsSamples, + bool naiveEqclass, bool dumpUmiGraph, + bool dumpCellEq, bool useAllBootstraps, bool initUniform, CFreqMapT& freqCounter, bool dumpArborescences, spp::sparse_hash_set& mRnaGenes, spp::sparse_hash_set& rRnaGenes, @@ -449,8 +665,20 @@ void optimizeCell(std::vector& trueBarcodes, std::exit(74); } - if ( numBootstraps and noEM ) { - jointlog->error("Cannot perform bootstrapping with noEM"); + if ( dumpCellEq ){ + std::vector> labelsEq ; + std::vector countsEq ; + labelsEq.resize(salmonEqclasses.size()); + countsEq.resize(salmonEqclasses.size()); + for(size_t salEqId = 0 ; salEqId < salmonEqclasses.size(); ++salEqId){ + labelsEq[salEqId] = salmonEqclasses[salEqId].labels; + countsEq[salEqId] = salmonEqclasses[salEqId].count; + } + gzw.writeDedupCellEQVec(trueBarcodeIdx, labelsEq, countsEq, true); + } + + if ( (numBootstraps and noEM) or (numGibbsSamples and noEM) ) { + jointlog->error("Cannot perform bootstrapping/gibbs with noEM"); jointlog->flush(); exit(1); } @@ -665,6 +893,35 @@ void optimizeCell(std::vector& trueBarcodes, // maintaining count for total number of predicted UMI salmon::utils::incLoop(totalDedupCounts, totalCount); totalExpGeneCounts += totalExpGenes; + + if ( numGibbsSamples > 0 ) { + std::vector> sampleEstimates; + std::vector sampleVariance(numGenes, 0.0); + std::vector sampleMean(numGenes, 0.0); + + bool isGibbsOk = runGibbsSamples( + numGenes, + geneAlphas, + sampleMean, + sampleVariance, + salmonEqclasses, + jointlog, + numGibbsSamples, + sampleEstimates + ); + + if( not isGibbsOk or (sampleEstimates.size()!=numGibbsSamples)){ + jointlog->error("Gibbs failed failed \n" + "Please Report this on github."); + jointlog->flush(); + std::exit(74); + } + + // write the abundance for the cell + gzw.writeSparseBootstraps( trueBarcodeStr, + sampleMean, sampleVariance, + true, sampleEstimates); + }//end-gibbs-if if ( numBootstraps > 0 ){ std::vector> sampleEstimates; @@ -696,12 +953,16 @@ void optimizeCell(std::vector& trueBarcodes, } else { // doing per eqclass level naive deduplication for (size_t eqId=0; eqId umis; + size_t numFeats = txpGroups[eqId].size(); + if (numFeats > 1) { continue; }; + spp::sparse_hash_set umis; for(auto& it: umiGroups[eqId]) { umis.insert( it.first ); } + totalCount += umis.size(); + geneAlphas[ txpGroups[eqId][0] ] += umis.size(); // filling in the eqclass level deduplicated counts if (verbose) { @@ -709,8 +970,28 @@ void optimizeCell(std::vector& trueBarcodes, } } + std::string emptyString = ""; + // write the abundance for the cell + bool isWriteOk = gzw.writeSparseAbundances( trueBarcodeStr, + emptyString, + emptyString, + 0, + geneAlphas, + tiers, + false, + false ); + + + if( not isWriteOk ){ + jointlog->error("Gzip Writer failed \n" + "Please Report this on github."); + jointlog->flush(); + std::exit(74); + } + // maintaining count for total number of predicted UMI salmon::utils::incLoop(totalDedupCounts, totalCount); + totalExpGeneCounts += totalExpGenes; } if (verbose) { @@ -839,8 +1120,8 @@ bool CollapsedCellOptimizer::optimize(EqMapT& fullEqMap, double priorWeight {1.0}; std::atomic bcount{0}; - tbb::atomic totalDedupCounts{0.0}; - tbb::atomic totalExpGeneCounts{0}; + std::atomic totalDedupCounts{0.0}; + std::atomic totalExpGeneCounts{0}; std::atomic totalBiEdgesCounts{0}; std::atomic totalUniEdgesCounts{0}; @@ -1006,8 +1287,10 @@ bool CollapsedCellOptimizer::optimize(EqMapT& fullEqMap, numGenes, aopt.protocol.umiLength, aopt.numBootstraps, + aopt.numGibbsSamples, aopt.naiveEqclass, aopt.dumpUmiGraph, + aopt.dumpCellEq, aopt.dumpfeatures, aopt.initUniform, std::ref(freqCounter), @@ -1072,7 +1355,7 @@ bool CollapsedCellOptimizer::optimize(EqMapT& fullEqMap, std::copy(geneNames.begin(), geneNames.end(), giterator); gFile.close(); - if( not hasWhitelist and not usingHashMode){ + if( not aopt.naiveEqclass and not hasWhitelist and not usingHashMode){ aopt.jointLog->info("Clearing EqMap; Might take some time."); fullEqMap.clear(); diff --git a/src/CollapsedEMOptimizer.cpp b/src/CollapsedEMOptimizer.cpp index f3a142c73..46fc121f1 100644 --- a/src/CollapsedEMOptimizer.cpp +++ b/src/CollapsedEMOptimizer.cpp @@ -8,7 +8,8 @@ #include "tbb/parallel_for_each.h" #include "tbb/parallel_reduce.h" #include "tbb/partitioner.h" -#include "tbb/task_scheduler_init.h" +// <-- deprecated in TBB --> #include "tbb/task_scheduler_init.h" +#include "tbb/global_control.h" //#include "fastapprox.h" #include @@ -41,7 +42,7 @@ constexpr double minWeight = std::numeric_limits::min(); // A bit more conservative of a minimum as an argument to the digamma function. constexpr double digammaMin = 1e-10; -double normalize(std::vector>& vec) { +double normalize(std::vector>& vec) { double sum{0.0}; for (auto& v : vec) { sum += v; @@ -721,7 +722,8 @@ template bool CollapsedEMOptimizer::optimize(ExpT& readExp, SalmonOpts& sopt, double relDiffTolerance, uint32_t maxIter) { - tbb::task_scheduler_init tbbScheduler(sopt.numThreads); + // <-- deprecated in TBB --> tbb::task_scheduler_init tbbScheduler(sopt.numThreads); + tbb::global_control c(tbb::global_control::max_allowed_parallelism, sopt.numThreads); std::vector& transcripts = readExp.transcripts(); std::vector available(transcripts.size(), false); @@ -736,8 +738,8 @@ bool CollapsedEMOptimizer::optimize(ExpT& readExp, SalmonOpts& sopt, using VecT = CollapsedEMOptimizer::VecType; // With atomics - VecType alphas(transcripts.size(), 0.0); - VecType alphasPrime(transcripts.size(), 0.0); + VecType alphas(transcripts.size()); + VecType alphasPrime(transcripts.size()); VecType expTheta(transcripts.size()); Eigen::VectorXd effLens(transcripts.size()); @@ -798,7 +800,7 @@ bool CollapsedEMOptimizer::optimize(ExpT& readExp, SalmonOpts& sopt, // over to the alphas if (sopt.initUniform) { for (size_t i = 0; i < alphas.size(); ++i) { - alphas[i] = alphasPrime[i]; + alphas[i].store(alphasPrime[i].load()); alphasPrime[i] = 1.0; } } else { // otherwise, initalize with a linear combination of the true and @@ -940,8 +942,8 @@ bool CollapsedEMOptimizer::optimize(ExpT& readExp, SalmonOpts& sopt, converged = false; } } - alphas[i] = alphasPrime[i]; - alphasPrime[i] = 0.0; + alphas[i].store(alphasPrime[i].load()); + alphasPrime[i].store(0.0); } /* -- v0.8.x diff --git a/src/CollapsedGibbsSampler.cpp b/src/CollapsedGibbsSampler.cpp index 60f30f6ad..8ca0b66d0 100644 --- a/src/CollapsedGibbsSampler.cpp +++ b/src/CollapsedGibbsSampler.cpp @@ -12,7 +12,8 @@ #include "tbb/parallel_for_each.h" #include "tbb/parallel_reduce.h" #include "tbb/partitioner.h" -#include "tbb/task_scheduler_init.h" +// <-- deprecated in TBB --> #include "tbb/task_scheduler_init.h" +#include "tbb/global_control.h" //#include "fastapprox.h" #include @@ -308,7 +309,10 @@ bool CollapsedGibbsSampler::sample( namespace bfs = boost::filesystem; auto& jointLog = sopt.jointLog; - tbb::task_scheduler_init tbbScheduler(sopt.numThreads); + + // <-- deprecated in TBB --> tbb::task_scheduler_init tbbScheduler(sopt.numThreads); + tbb::global_control c(tbb::global_control::max_allowed_parallelism, sopt.numThreads); + std::vector& transcripts = readExp.transcripts(); // Fill in the effective length vector diff --git a/src/DistributionUtils.cpp b/src/DistributionUtils.cpp index 538c4a0ff..0924d97e3 100644 --- a/src/DistributionUtils.cpp +++ b/src/DistributionUtils.cpp @@ -41,7 +41,7 @@ void computeSmoothedEffectiveLengths(size_t maxLength, ? correctionFactors[maxLen - 1] : correctionFactors[origLen]; - double effLen = origLen - correctionFactor + 1.0; + double effLen = origLen - correctionFactor; if (effLen < 1.0) { effLen = origLen; } @@ -54,7 +54,7 @@ void computeSmoothedEffectiveLengths(size_t maxLength, } } -std::vector samplesFromLogPMF(FragmentLengthDistribution* fld, +DistSummary samplesFromLogPMF(FragmentLengthDistribution* fld, int32_t numSamples) { std::vector logPMF; size_t minVal; @@ -70,21 +70,29 @@ std::vector samplesFromLogPMF(FragmentLengthDistribution* fld, } // Create the non-logged pmf + double mean = salmon::math::LOG_0; + double var = 0.0; std::vector pmf(maxVal + 1, 0.0); for (size_t i = minVal; i < maxVal; ++i) { + mean = (i > 0) ? (salmon::math::logAdd(mean, std::log(i)+logPMF[i-minVal])) : mean; pmf[i] = std::exp(logPMF[i - minVal]); + var += pmf[i] * (i*i); } + mean = std::exp(mean); + var -= mean*mean; + double sd = std::sqrt(var); // generate samples std::random_device rd; std::mt19937 gen(rd()); std::discrete_distribution dist(pmf.begin(), pmf.end()); + DistSummary ds(mean, sd, pmf.size()); std::vector samples(pmf.size()); for (int32_t i = 0; i < numSamples; ++i) { - ++samples[dist(gen)]; + ++(ds.samples[dist(gen)]); } - return samples; + return ds; } diff --git a/src/EMUtils.cpp b/src/EMUtils.cpp index 09e8c608a..30b42e824 100644 --- a/src/EMUtils.cpp +++ b/src/EMUtils.cpp @@ -74,14 +74,14 @@ void EMUpdate_>(std::vector>& txpGroup std::vector& alphaOut); template -void EMUpdate_>>(std::vector>& txpGroupLabels, +void EMUpdate_>>(std::vector>& txpGroupLabels, std::vector>& txpGroupCombinedWeights, const std::vector& txpGroupCounts, - const std::vector>& alphaIn, - std::vector>& alphaOut); + const std::vector>& alphaIn, + std::vector>& alphaOut); template double truncateCountVector>(std::vector& alphas, double cutoff); template -double truncateCountVector>>(std::vector>& alphas, double cutoff); +double truncateCountVector>>(std::vector>& alphas, double cutoff); diff --git a/src/FASTAParser.cpp b/src/FASTAParser.cpp index 338198ab5..926e66efb 100644 --- a/src/FASTAParser.cpp +++ b/src/FASTAParser.cpp @@ -42,8 +42,8 @@ void FASTAParser::populateTargets(std::vector& refs, constexpr char bases[] = {'A', 'C', 'G', 'T'}; // Create a random uniform distribution - std::random_device rd; - std::default_random_engine eng(rd()); + constexpr const uint64_t randseed{271828}; + std::default_random_engine eng(randseed); std::uniform_int_distribution<> dis(0, 3); uint64_t numNucleotidesReplaced{0}; diff --git a/src/FragmentLengthDistribution.cpp b/src/FragmentLengthDistribution.cpp index 216e442a3..ff4c3dd09 100644 --- a/src/FragmentLengthDistribution.cpp +++ b/src/FragmentLengthDistribution.cpp @@ -9,6 +9,7 @@ #include "FragmentLengthDistribution.hpp" #include "SalmonMath.hpp" +#include "SalmonUtils.hpp" #include #include #include @@ -22,7 +23,7 @@ using namespace std; FragmentLengthDistribution::FragmentLengthDistribution( double alpha, size_t max_val, double prior_mu, double prior_sigma, size_t kernel_n, double kernel_p, size_t bin_size) - : hist_(max_val / bin_size + 1), cachedCMF_(hist_.size()), + : /*hist_(max_val / bin_size + 1),*/ cachedCMF_(hist_.size()), haveCachedCMF_(false), totMass_(salmon::math::LOG_0), sum_(salmon::math::LOG_0), min_(max_val / bin_size), binSize_(bin_size) { @@ -38,6 +39,9 @@ FragmentLengthDistribution::FragmentLengthDistribution( boost::math::normal norm(prior_mu / bin_size, prior_sigma / (bin_size * bin_size)); + std::vector> hist_tmp(max_val / bin_size + 1); + std::swap(hist_, hist_tmp); + for (size_t i = 0; i <= max_val; ++i) { double norm_mass = boost::math::cdf(norm, i + 0.5) - boost::math::cdf(norm, i - 0.5); @@ -45,16 +49,17 @@ FragmentLengthDistribution::FragmentLengthDistribution( if (norm_mass != 0) { mass = tot + log(norm_mass); } - hist_[i].compare_and_swap(mass, hist_[i]); - sum_.compare_and_swap(logAdd(sum_, log((double)i) + mass), sum_); - totMass_.compare_and_swap(logAdd(totMass_, mass), totMass_); + hist_[i].store(mass); + sum_.store(logAdd(sum_, log((double)i) + mass)); + totMass_.store(logAdd(totMass_, mass)); } } else { - hist_ = - vector>(max_val + 1, tot - log((double)max_val)); - hist_[0].compare_and_swap(salmon::math::LOG_0, hist_[0]); - sum_.compare_and_swap( - hist_[1] + log((double)(max_val * (max_val + 1))) - log(2.), sum_); + std::vector> hist_tmp(max_val + 1); + std::swap(hist_, hist_tmp); + std::fill(hist_.begin(), hist_.end(), tot - log((double)max_val)); + + hist_[0].store(salmon::math::LOG_0); + sum_.store(hist_[1] + log((double)(max_val * (max_val + 1))) - log(2.)); totMass_ = tot; } @@ -96,28 +101,9 @@ void FragmentLengthDistribution::addVal(size_t len, double mass) { for (size_t i = 0; i < kernel_.size(); i++) { if (offset > 0 && offset < hist_.size()) { double kMass = mass + kernel_[i]; - double oldVal = hist_[offset]; - double retVal = oldVal; - double newVal = 0.0; - do { - oldVal = retVal; - newVal = logAdd(oldVal, kMass); - retVal = hist_[offset].compare_and_swap(newVal, oldVal); - } while (retVal != oldVal); - - retVal = sum_; - do { - oldVal = retVal; - newVal = logAdd(oldVal, log(static_cast(offset)) + kMass); - retVal = sum_.compare_and_swap(newVal, oldVal); - } while (retVal != oldVal); - - retVal = totMass_; - do { - oldVal = retVal; - newVal = logAdd(oldVal, kMass); - retVal = totMass_.compare_and_swap(newVal, oldVal); - } while (retVal != oldVal); + salmon::utils::incLoopLog(hist_[offset], kMass); + salmon::utils::incLoopLog(sum_, std::log(static_cast(offset)) + kMass); + salmon::utils::incLoopLog(totMass_, kMass); } offset++; } diff --git a/src/FragmentStartPositionDistribution.cpp b/src/FragmentStartPositionDistribution.cpp index 8467c659b..f89e667e7 100644 --- a/src/FragmentStartPositionDistribution.cpp +++ b/src/FragmentStartPositionDistribution.cpp @@ -34,15 +34,12 @@ FragmentStartPositionDistribution::FragmentStartPositionDistribution( } } -inline void logAddMass(tbb::atomic& bin, double newMass) { - double oldVal = bin; - double retVal = oldVal; - double newVal = 0.0; - do { - oldVal = retVal; +inline void logAddMass(std::atomic& bin, double newMass) { + double oldVal = bin.load(); + double newVal; + do{ newVal = salmon::math::logAdd(oldVal, newMass); - retVal = bin.compare_and_swap(newVal, oldVal); - } while (retVal != oldVal); + } while (!bin.compare_exchange_strong(oldVal, newVal)); } void FragmentStartPositionDistribution::addVal(int32_t hitPos, uint32_t txpLen, diff --git a/src/GZipWriter.cpp b/src/GZipWriter.cpp index 1f696c497..461e3365e 100644 --- a/src/GZipWriter.cpp +++ b/src/GZipWriter.cpp @@ -30,6 +30,9 @@ GZipWriter::~GZipWriter() { if (cellEQStream_){ cellEQStream_->reset(); } + if (cellDedupEQStream_){ + cellDedupEQStream_->reset(); + } if (umiGraphStream_){ umiGraphStream_->reset(); } @@ -69,6 +72,9 @@ void GZipWriter::close_all_streams(){ if (cellEQStream_){ cellEQStream_->reset(); } + if (cellDedupEQStream_){ + cellDedupEQStream_->reset(); + } if (umiGraphStream_){ umiGraphStream_->reset(); } @@ -401,6 +407,8 @@ bool GZipWriter::writeEmptyMeta(const SalmonOpts& opts, const ExpT& experiment, oa(cereal::make_nvp("library_types", libStrings)); oa(cereal::make_nvp("frag_dist_length", 0)); + oa(cereal::make_nvp("frag_length_mean", 0.0)); + oa(cereal::make_nvp("frag_length_sd", 0.0)); oa(cereal::make_nvp("seq_bias_correct", false)); oa(cereal::make_nvp("gc_bias_correct", false)); oa(cereal::make_nvp("num_bias_bins", 0)); @@ -565,9 +573,9 @@ bool GZipWriter::writeMeta(const SalmonOpts& opts, const ExpT& experiment, const bfs::path fldPath = auxDir / "fld.gz"; int32_t numFLDSamples{10000}; - auto fldSamples = distribution_utils::samplesFromLogPMF( + auto fldSummary = distribution_utils::samplesFromLogPMF( experiment.fragmentLengthDistribution(), numFLDSamples); - writeVectorToFile(fldPath, fldSamples); + writeVectorToFile(fldPath, fldSummary.samples); bfs::path normBiasPath = auxDir / "expected_bias.gz"; writeVectorToFile(normBiasPath, experiment.expectedSeqBias()); @@ -768,7 +776,9 @@ bool GZipWriter::writeMeta(const SalmonOpts& opts, const ExpT& experiment, const oa(cereal::make_nvp("num_libraries", libStrings.size())); oa(cereal::make_nvp("library_types", libStrings)); - oa(cereal::make_nvp("frag_dist_length", fldSamples.size())); + oa(cereal::make_nvp("frag_dist_length", fldSummary.samples.size())); + oa(cereal::make_nvp("frag_length_mean", fldSummary.mean)); + oa(cereal::make_nvp("frag_length_sd", fldSummary.sd)); oa(cereal::make_nvp("seq_bias_correct", opts.biasCorrect)); oa(cereal::make_nvp("gc_bias_correct", opts.gcBiasCorrect)); oa(cereal::make_nvp("num_bias_bins", bcounts.size())); @@ -1481,6 +1491,49 @@ bool GZipWriter::writeCellEQVec(size_t barcode, const std::vector& off return true; } +// TODO: complete this method +bool GZipWriter::writeDedupCellEQVec( + size_t barcode, + const std::vector>& labels, + const std::vector& counts, + bool quiet + ) { +#if defined __APPLE__ + spin_lock::scoped_lock sl(writeMutex_); +#else + std::lock_guard lock(writeMutex_); +#endif + if (!cellDedupEQStream_) { + cellDedupEQStream_.reset(new boost::iostreams::filtering_ostream); + cellDedupEQStream_->push(boost::iostreams::gzip_compressor(6)); + auto ceqFilename = path_ / "alevin" / "cell_dedup_eq_mat.gz"; + cellDedupEQStream_->push(boost::iostreams::file_sink(ceqFilename.string(), + std::ios_base::out | std::ios_base::binary)); + } + + boost::iostreams::filtering_ostream& ofile = *cellDedupEQStream_; + size_t num = labels.size(); + size_t elSize = sizeof(typename std::vector::value_type); + // write the barcode + ofile.write(reinterpret_cast(&barcode), sizeof(barcode)); + // write the number of elements in the list + ofile.write(reinterpret_cast(&num), sizeof(barcode)); + // write the individual labels + for (size_t i = 0 ; i < num ; ++i) { + size_t labelLength = labels[i].size(); + // write the length of the equivalence class vector + ofile.write(reinterpret_cast(&labelLength), sizeof(barcode)); + // write the label vector + ofile.write(reinterpret_cast(const_cast(labels[i].data())), elSize * labelLength); + } + ofile.write(reinterpret_cast(const_cast(counts.data())), elSize * num); + // write the offsets and counts + if (!quiet) { + logger_->info("wrote EQ vector for barcode ID {}", barcode); + } + return true; +} + bool GZipWriter::writeUmiGraph(alevin::graph::Graph& g, std::string& cbString) { #if defined __APPLE__ diff --git a/src/ProgramOptionsGenerator.cpp b/src/ProgramOptionsGenerator.cpp index 169febc75..0d89c9e88 100644 --- a/src/ProgramOptionsGenerator.cpp +++ b/src/ProgramOptionsGenerator.cpp @@ -84,7 +84,7 @@ namespace salmon { mapspec.add_options() ("discardOrphansQuasi", po::bool_switch(&(sopt.discardOrphansQuasi))->default_value(salmon::defaults::discardOrphansQuasi), - "[selective-alignment mode only] : Discard orphan mappings in quasi-mapping " + "[selective-alignment mode only] : Discard orphan mappings in selective-alignment " "mode. If this flag is passed " "then only paired mappings will be considered toward quantification " "estimates. The default behavior is " @@ -102,12 +102,39 @@ namespace salmon { "[*deprecated* (no effect; selective-alignment is the default)]") ("consensusSlack", po::value(&(sopt.consensusSlack))->default_value(salmon::defaults::consensusSlack), - "[selective-alignment mode only] : The amount of slack allowed in the quasi-mapping consensus " - "mechanism. Normally, a transcript must cover all hits to be considered for mapping. " - "If this is set to a fraction, X, greater than 0 (and in [0,1)), then a transcript can fail to cover up to " - "(100 * X)% of the hits before it is discounted as a mapping candidate. The default value of this option " - "is 0.2 if --validateMappings is given and 0 otherwise." + "[selective-alignment mode only] : The amount of slack allowed in the selective-alignment " + "filtering mechanism. If this is set to a fraction, X, greater than 0 (and in [0,1)), then " + "uniMEM chains with scores below (100 * X)% of the best chain score for a read, and read pairs " + "with a sum of chain scores below (100 * X)% of the best chain score for a read pair " + "will be discounted as a mapping candidates. The default value of this option is 0.35." ) + ("preMergeChainSubThresh", + po::value(&(sopt.pre_merge_chain_sub_thresh))->default_value(salmon::defaults::pre_merge_chain_sub_thresh), + "[selective-alignment mode only] : The threshold of sub-optimal chains, compared to the best chain on a given " + "target, that will be retained and passed to the next phase of mapping. Specifically, if the best chain " + "for a read (or read-end in paired-end mode) to target t has score X_t, then all chains for this read with " + "score >= X_t * preMergeChainSubThresh will be retained and passed to subsequent mapping phases. This value " + "must be in the range [0, 1]." + ) + ("postMergeChainSubThresh", + po::value(&(sopt.post_merge_chain_sub_thresh))->default_value(salmon::defaults::post_merge_chain_sub_thresh), + "[selective-alignment mode only] : The threshold of sub-optimal chain pairs, compared to the best chain pair " + "on a given target, that will be retained and passed to the next phase of mapping. This is different than " + "preMergeChainSubThresh, because this is applied to pairs of chains (from the ends of paired-end reads) after " + "merging (i.e. after checking concordancy constraints etc.). Specifically, if the best chain pair " + "to target t has score X_t, then all chain pairs for this read pair with score " + ">= X_t * postMergeChainSubThresh will be retained and passed to subsequent mapping phases. This value " + "must be in the range [0, 1]. Note: This option is only meaningful for paired-end libraries, and is ignored " + "for single-end libraries." + ) + ("orphanChainSubThresh", + po::value(&(sopt.orphan_chain_sub_thresh))->default_value(salmon::defaults::orphan_chain_sub_thresh), + "[selective-alignment mode only] : This threshold sets a global sub-optimality threshold for chains " + "corresponding to orphan mappings. That is, if the merging procedure results in no concordant mappings " + "then only orphan mappings with a chain score >= orphanChainSubThresh * bestChainScore will be " + "retained and passed to subsequent mapping phases. This value must be in the range [0, 1]. Note: This " + "option is only meaningful for paired-end libraries, and is ignored for single-end libraries." + ) ("scoreExp", po::value(&sopt.scoreExp)->default_value(salmon::defaults::scoreExp), "[selective-alignment mode only] : The factor by which sub-optimal alignment scores are " @@ -237,7 +264,7 @@ namespace salmon { ("writeMappings,z", po::value(&sopt.qmFileName) ->default_value(salmon::defaults::quasiMappingDefaultFile) ->implicit_value(salmon::defaults::quasiMappingImplicitFile), - "If this option is provided, then the quasi-mapping " + "If this option is provided, then the selective-alignment " "results will be written out in SAM-compatible " "format. By default, output will be directed to " "stdout, but an alternative file name can be " @@ -331,7 +358,7 @@ namespace salmon { "Run naive per equivalence class deduplication, generating only total number of UMIs") ( "noDedup", po::bool_switch()->default_value(alevin::defaults::noDedup), - "Stops the pipeline after CB sequence correction and quasi-mapping reads."); + "Stops the pipeline after CB sequence correction and selective-alignment of reads."); return alevindevs; } @@ -390,6 +417,10 @@ namespace salmon { "numCellBootstraps",po::value()->default_value(alevin::defaults::numBootstraps), "Generate mean and variance for cell x gene matrix quantification" " estimates.") + ( + "numCellGibbsSamples",po::value()->default_value(alevin::defaults::numGibbsSamples), + "Generate mean and variance for cell x gene matrix quantification by running gibbs chain" + " estimates.") ( "forceCells",po::value()->default_value(alevin::defaults::forceCells), "Explicitly specify the number of cells.") @@ -409,7 +440,7 @@ namespace salmon { ( "end",po::value(), "Cell-Barcodes end (5 or 3) location in the read sequence from where barcode has to" - "be extracted. (end, umiLength, barcodeLength)" + " be extracted. (end, umiLength, barcodeLength)" " should all be provided if using this option") ( "umiLength",po::value(), @@ -417,7 +448,7 @@ namespace salmon { " should all be provided if using this option") ( "barcodeLength",po::value(), - "umi length Parameter for unknown protocol. (end, umiLength, barcodeLength)" + "barcode length Parameter for unknown protocol. (end, umiLength, barcodeLength)" " should all be provided if using this option") ( "noem",po::bool_switch()->default_value(alevin::defaults::noEM), @@ -441,6 +472,9 @@ namespace salmon { ( "dumpUmiGraph", po::bool_switch()->default_value(alevin::defaults::dumpUmiGraph), "dump the per cell level Umi Graph.") + ( + "dumpCellEq", po::bool_switch()->default_value(alevin::defaults::dumpCellEq), + "dump the per cell level deduplicated equivalence classes.") ( "dumpFeatures", po::bool_switch()->default_value(alevin::defaults::dumpFeatures), "Dump features for whitelist and downstream analysis.") diff --git a/src/SBModel.cpp b/src/SBModel.cpp index 87ee205bc..56b08020c 100644 --- a/src/SBModel.cpp +++ b/src/SBModel.cpp @@ -1,4 +1,5 @@ #include "SBModel.hpp" +#include "jellyfish/mer_dna.hpp" #include #include @@ -44,6 +45,7 @@ SBModel::SBModel() : _trained(false) { _marginals = Eigen::MatrixXd(4, _contextLength); _marginals.setZero(); + _marginals.array() += _prior_prob; _shifts.clear(); _widths.clear(); @@ -61,13 +63,15 @@ SBModel::SBModel() : _trained(false) { } // Set k equal to the size of the contexts we'll parse. - _mer.k(_contextLength); + //_mer.k(_contextLength); + _sbmer.k(_contextLength); // To hold all probabilities the matrix must be 4^{max_order + 1} by // context-length _probs = Eigen::MatrixXd(constExprPow(4, maxOrder + 1), _contextLength); // We have no intial observations _probs.setZero(); + _probs.array() += _prior_prob; } bool SBModel::writeBinary(boost::iostreams::filtering_ostream& out) const { @@ -113,21 +117,26 @@ bool SBModel::writeBinary(boost::iostreams::filtering_ostream& out) const { double SBModel::evaluateLog(const char* seqIn) { double p = 0; - Mer mer; - mer.from_chars(seqIn); + //Mer mer; + //mer.from_chars(seqIn); + SBMer sbmer; + sbmer.fromChars(seqIn); for (int32_t i = 0; i < _contextLength; ++i) { - uint64_t idx = mer.get_bits(_shifts[i], _widths[i]); + //uint64_t idx = mer.get_bits(_shifts[i], _widths[i]); + uint64_t idx = sbmer.get_bits(_shifts[i], _widths[i]); p += _probs(idx, i); } return p; } -double SBModel::evaluateLog(const Mer& mer) { +double SBModel::evaluateLog(const SBMer& sbmer) { double p = 0; for (int32_t i = 0; i < _contextLength; ++i) { - uint64_t idx = mer.get_bits(_shifts[i], _widths[i]); + //uint64_t idx = mer.get_bits(_shifts[i], _widths[i]); + uint64_t idx = sbmer.get_bits(_shifts[i], _widths[i]); + p += _probs(idx, i); } return p; @@ -183,16 +192,19 @@ void SBModel::dumpConditionalProbabilities(std::ostream& os) { } bool SBModel::addSequence(const char* seqIn, bool revCmp, double weight) { - _mer.from_chars(seqIn); + //_mer.from_chars(seqIn); + bool ok = _sbmer.fromChars(seqIn); if (revCmp) { - _mer.reverse_complement(); + //_mer.reverse_complement(); + _sbmer.rc(); } - return addSequence(_mer, weight); + return addSequence(_sbmer, weight); } -bool SBModel::addSequence(const Mer& mer, double weight) { +bool SBModel::addSequence(const SBMer& sbmer, double weight) { for (int32_t i = 0; i < _contextLength; ++i) { - uint64_t idx = mer.get_bits(_shifts[i], _widths[i]); + //uint64_t idx = mer.get_bits(_shifts[i], _widths[i]); + uint64_t idx = sbmer.get_bits(_shifts[i], _widths[i]); _probs(idx, i) += weight; } return true; @@ -234,7 +246,7 @@ bool SBModel::normalize() { // std::cerr << "pos = " << pos << ", marginals = " << _marginals.col(pos) // << '\n'; } - + double logSmall = std::log(1e-5); auto takeLog = [logSmall](double x) -> double { return (x > 0.0) ? std::log(x) : logSmall; diff --git a/src/Salmon.cpp b/src/Salmon.cpp index a2ff3ed1f..4274e0575 100644 --- a/src/Salmon.cpp +++ b/src/Salmon.cpp @@ -56,12 +56,11 @@ int help(const std::vector& /*opts*/) { " salmon -c|--cite or \n" " salmon [--no-version-check] [-h | options]\n\n"); helpMsg.write("Commands:\n"); - helpMsg.write(" index Create a salmon index\n"); - helpMsg.write(" quant Quantify a sample\n"); - helpMsg.write(" alevin single cell analysis\n"); - helpMsg.write(" swim Perform super-secret operation\n"); - helpMsg.write( - " quantmerge Merge multiple quantifications into a single file\n"); + helpMsg.write(" index : create a salmon index\n"); + helpMsg.write(" quant : quantify a sample\n"); + helpMsg.write(" alevin : single cell analysis\n"); + helpMsg.write(" swim : perform super-secret operation\n"); + helpMsg.write(" quantmerge : merge multiple quantifications into a single file\n"); std::cout << helpMsg.str(); return 0; @@ -78,7 +77,7 @@ int dualModeMessage() { alignment-based algorithm will be used, otherwise the algorithm for quantifying from raw reads will be used. - to view the help for salmon's quasi-mapping-based mode, use the command + to view the help for salmon's selective-alignment-based mode, use the command salmon quant --help-reads @@ -319,4 +318,4 @@ int main(int argc, char* argv[]) { } return 0; -} +} \ No newline at end of file diff --git a/src/SalmonAlevin.cpp b/src/SalmonAlevin.cpp index 88a2b3bd0..5446493ce 100644 --- a/src/SalmonAlevin.cpp +++ b/src/SalmonAlevin.cpp @@ -49,7 +49,7 @@ // Jellyfish 2 include -#include "jellyfish/mer_dna.hpp" +// #include "jellyfish/mer_dna.hpp" // Boost Includes #include @@ -70,7 +70,6 @@ #include "tbb/parallel_for_each.h" #include "tbb/parallel_reduce.h" #include "tbb/partitioner.h" -#include "tbb/task_scheduler_init.h" // logger includes #include "spdlog/spdlog.h" @@ -418,8 +417,8 @@ void processReadsQuasi( observedBiasParams .seqBiasModelRC; // readExp.readBias(salmon::utils::Direction::REVERSE_COMPLEMENT); // k-mers for sequence bias context - Mer leftMer; - Mer rightMer; + // Mer leftMer; + // Mer rightMer; //auto expectedLibType = rl.format(); @@ -450,7 +449,6 @@ void processReadsQuasi( std::vector jointHits; PairedAlignmentFormatter formatter(qidx); pufferfish::util::QueryCache qc; - phmap::flat_hash_map> bestScorePerTranscript; bool mimicStrictBT2 = salmonOpts.mimicStrictBT2; bool mimicBT2 = salmonOpts.mimicBT2; @@ -468,9 +466,6 @@ void processReadsQuasi( ////////////////////// // NOTE: validation mapping based new parameters std::string rc1; rc1.reserve(300); - // will hold the permutation to use to put the transcripts in order - std::vector> perm; - //std::vector alnCache; alnCache.reserve(15); AlnCacheMap alnCache; alnCache.reserve(16); /* @@ -482,10 +477,15 @@ void processReadsQuasi( } */ - size_t numDropped{0}; size_t numMappingsDropped{0}; size_t numDecoyFrags{0}; const double decoyThreshold = salmonOpts.decoyThreshold; + + salmon::mapping_utils::MappingScoreInfo msi(decoyThreshold); + // we only collect detailed decoy information if we will be + // writing output to SAM. + msi.collect_decoys(writeQuasimappings); + std::string readSubSeq; ////////////////////// @@ -504,6 +504,12 @@ void processReadsQuasi( LibraryFormat expectedLibraryFormat = rl.format(); + std::string extraBAMtags; + if(writeQuasimappings) { + size_t reserveSize { alevinOpts.protocol.barcodeLength + alevinOpts.protocol.umiLength + 12}; + extraBAMtags.reserve(reserveSize); + } + for (size_t i = 0; i < rangeSize; ++i) { // For all the read in this batch auto& rp = rg[i]; readLenLeft = rp.first.seq.length(); @@ -515,7 +521,6 @@ void processReadsQuasi( jointHitGroup.clearAlignments(); auto& jointAlignments= jointHitGroup.alignments(); - perm.clear(); hits.clear(); jointHits.clear(); memCollector.clear(); @@ -530,6 +535,7 @@ void processReadsQuasi( std::string umi;//, barcode; nonstd::optional barcode; nonstd::optional barcodeIdx; + extraBAMtags.clear(); bool seqOk; if (alevinOpts.protocol.end == bcEnd::FIVE || @@ -588,6 +594,12 @@ void processReadsQuasi( if(isUmiIdxOk){ jointHitGroup.setUMI(umiIdx.word(0)); + if (writeQuasimappings) { + extraBAMtags += "\tCB:Z:"; + extraBAMtags += *barcode; + extraBAMtags += "\tUR:Z:"; + extraBAMtags += umi; + } auto seq_len = rp.second.seq.size(); if (alevinOpts.trimRight > 0) { @@ -651,27 +663,16 @@ void processReadsQuasi( // adding validate mapping code if (tryAlign and !jointHits.empty()) { puffaligner.clear(); - bestScorePerTranscript.clear(); - - //auto* r1 = readSubSeq.data(); - //auto l1 = static_cast(readSubSeq.length()); + puffaligner.getScoreStatus().reset(); + msi.clear(jointHits.size()); - // the best scores start out as invalid - /* - int32_t bestScore = invalidScore; - int32_t secondBestScore = invalidScore; - int32_t bestDecoyScore = invalidScore; - */ - salmon::mapping_utils::MappingScoreInfo msi = {invalidScore, invalidScore, invalidScore, decoyThreshold}; - - std::vector scores(jointHits.size(), invalidScore); size_t idx{0}; - bool isMultimapping = (jointHits.size() > 1); + bool is_multimapping = (jointHits.size() > 1); for (auto &&jointHit : jointHits) { // for alevin, currently, we need these to have a mate status of PAIRED_END_RIGHT jointHit.mateStatus = MateStatus::PAIRED_END_RIGHT; - auto hitScore = puffaligner.calculateAlignments(readSubSeq, jointHit, hctr, isMultimapping, false); + auto hitScore = puffaligner.calculateAlignments(readSubSeq, jointHit, hctr, is_multimapping, false); bool validScore = (hitScore != invalidScore); numMappingsDropped += validScore ? 0 : 1; auto tid = qidx->getRefId(jointHit.tid); @@ -686,18 +687,13 @@ void processReadsQuasi( (!jointHit.orphanClust()->isFw and (expectedLibraryFormat.strandedness == ReadStrandedness::SA)); salmon::mapping_utils::updateRefMappings(tid, hitScore, isCompat, idx, transcripts, invalidScore, - msi, - //bestScore, secondBestScore, bestDecoyScore, - scores, bestScorePerTranscript, perm); + msi); ++idx; } - //bool bestHitDecoy = (msi.bestScore < msi.bestDecoyScore); bool bestHitDecoy = msi.haveOnlyDecoyMappings(); if (msi.bestScore > invalidScore and !bestHitDecoy) { salmon::mapping_utils::filterAndCollectAlignments(jointHits, - scores, - perm, readSubSeq.length(), readSubSeq.length(), false, // true for single-end false otherwise @@ -706,29 +702,38 @@ void processReadsQuasi( salmonOpts.scoreExp, salmonOpts.minAlnProb, msi, - /* - bestScore, - secondBestScore, - bestDecoyScore, - */ jointAlignments); if (!jointAlignments.empty()) { mapType = salmon::utils::MappingType::SINGLE_MAPPED; } } else { numDecoyFrags += bestHitDecoy ? 1 : 0; - ++numDropped; - jointHitGroup.clearAlignments(); mapType = (bestHitDecoy) ? salmon::utils::MappingType::DECOY : salmon::utils::MappingType::UNMAPPED; + if (bestHitDecoy) { + salmon::mapping_utils::filterAndCollectAlignmentsDecoy( + jointHits, readSubSeq.length(), + readSubSeq.length(), + false, // true for single-end false otherwise + tryAlign, hardFilter, salmonOpts.scoreExp, + salmonOpts.minAlnProb, msi, + jointAlignments); + } else { + jointHitGroup.clearAlignments(); + } } } //end-if validate mapping if (writeQuasimappings) { - writeAlignmentsToStream(rp, formatter, jointAlignments, sstream, true, true); - /* - rapmap::utils::writeAlignmentsToStream(rp, formatter, - hctr, jointHits, sstream); - */ + writeAlignmentsToStream(rp, formatter, jointAlignments, sstream, true, true, extraBAMtags); + } + + // We've kept decoy aignments around to this point so that we can + // potentially write these alignments to the SAM file. However, if + // we got to this point and only have decoy mappings, then clear the + // mappings here because none of the procesing below is relevant for + // decoys. + if (mapType == salmon::utils::MappingType::DECOY) { + jointHitGroup.clearAlignments(); } if (writeUnmapped and mapType != salmon::utils::MappingType::SINGLE_MAPPED) { diff --git a/src/SalmonQuantify.cpp b/src/SalmonQuantify.cpp index 16fe0f252..98cf15f35 100644 --- a/src/SalmonQuantify.cpp +++ b/src/SalmonQuantify.cpp @@ -69,7 +69,6 @@ #include "tbb/parallel_for_each.h" #include "tbb/parallel_reduce.h" #include "tbb/partitioner.h" -#include "tbb/task_scheduler_init.h" // logger includes #include "spdlog/spdlog.h" @@ -256,6 +255,13 @@ void processMiniBatch(ReadExperimentT& readExp, ForgettingMassCalculator& fmCalc logCMFCache.refresh(numAssignedFragments.load(), burnedIn.load()); } + const size_t maxCacheLen{salmonOpts.fragLenDistMax}; + // Caches to avoid fld updates _within_ the set of alignments of a fragment + auto fetchPMF = [&fragLengthDist](size_t l) -> double { return fragLengthDist.pmf(l); }; + auto fetchCMF = [&fragLengthDist](size_t l) -> double { return fragLengthDist.cmf(l); }; + distribution_utils::IndexedVersionedCache pmfCache(maxCacheLen); + distribution_utils::IndexedVersionedCache cmfCache(maxCacheLen); + int i{0}; { // Iterate over each group of alignments (a group consists of all alignments @@ -263,6 +269,9 @@ void processMiniBatch(ReadExperimentT& readExp, ForgettingMassCalculator& fmCalc // for a single read). Distribute the read's mass to the transcripts // where it potentially aligns. for (auto& alnGroup : batchHits) { + pmfCache.increment_generation(); + cmfCache.increment_generation(); + // If we had no alignments for this read, then skip it if (alnGroup.size() == 0) { continue; @@ -309,11 +318,10 @@ void processMiniBatch(ReadExperimentT& readExp, ForgettingMassCalculator& fmCalc uint32_t prevTxpID{0}; hasCompatibleMapping = false; + useAuxParams = ((localNumAssignedFragments + numAssignedFragments) >= + salmonOpts.numPreBurninFrags); // For each alignment of this read for (auto& aln : alnGroup.alignments()) { - - useAuxParams = ((localNumAssignedFragments + numAssignedFragments) >= - salmonOpts.numPreBurninFrags); bool considerCondProb{burnedIn or useAuxParams}; auto transcriptID = aln.transcriptID(); @@ -368,11 +376,13 @@ void processMiniBatch(ReadExperimentT& readExp, ForgettingMassCalculator& fmCalc if (flen > 0.0 and useFragLengthDist and considerCondProb) { size_t fl = flen; - double lenProb = fragLengthDist.pmf(fl); + double lenProb = pmfCache.get_or_update(fl, fetchPMF); + if (burnedIn) { /* condition fragment length prob on txp length */ - double refLengthCM = - fragLengthDist.cmf(static_cast(refLength)); + size_t rlen = static_cast(refLength); + double refLengthCM = cmfCache.get_or_update(fl, fetchCMF); + bool computeMass = fl < refLength and !salmon::math::isLog0(refLengthCM); logFragProb = (computeMass) ? (lenProb - refLengthCM) @@ -791,8 +801,11 @@ void processReads( .seqBiasModelRC; // readExp.readBias(salmon::utils::Direction::REVERSE_COMPLEMENT); // k-mers for sequence bias context - Mer leftMer; - Mer rightMer; + //Mer leftMer; + //Mer rightMer; + SBMer leftMer; + SBMer rightMer; + uint64_t firstTimestepOfRound = fmCalc.getCurrentTimestep(); size_t minK = qidx->k(); @@ -810,6 +823,7 @@ void processReads( //******* Setting up pufferfish mapping constexpr const int32_t invalidScore = std::numeric_limits::min(); + constexpr const int32_t invalidIndex = std::numeric_limits::min(); MemCollector memCollector(qidx); ksw2pp::KSW2Aligner aligner; pufferfish::util::AlignmentConfig aconf; @@ -829,7 +843,6 @@ void processReads( bool noDovetail = !salmonOpts.allowDovetail; bool useChainingHeuristic = !salmonOpts.disableChainingHeuristic; size_t numOrphansRescued{0}; - phmap::flat_hash_map> bestScorePerTranscript; uint64_t firstDecoyIndex = qidx->firstDecoyIndex(); //******* @@ -851,15 +864,16 @@ void processReads( } */ - // will hold the permutation to use to put the transcripts in order - std::vector> perm; - size_t numMappingsDropped{0}; size_t numFragsDropped{0}; size_t numDecoyFrags{0}; const double decoyThreshold = salmonOpts.decoyThreshold; uint32_t readLen{0}, mateLen{0}, totLen{0}; + salmon::mapping_utils::MappingScoreInfo msi(decoyThreshold); + // we only collect detailed decoy information if we will be + // writing output to SAM. + msi.collect_decoys(writeQuasimappings); auto rg = parser->getReadGroup(); while (parser->refill(rg)) { @@ -894,7 +908,6 @@ void processReads( ++hctr.numReads; - perm.clear(); jointHits.clear(); leftHits.clear(); rightHits.clear(); @@ -941,6 +954,24 @@ void processReads( ); hctr.numMappedAtLeastAKmer += (leftHits.size() > 0 || rightHits.size() > 0) ? 1 : 0; + /* + salmonOpts.jointLog->info("\n\n mappings for left end \n\n"); + + for (auto&& h : leftHits) { + salmonOpts.jointLog->info("hit to : {}", qidx->refName( h.first )); + for (auto&& mc : *h.second) { + salmonOpts.jointLog->info("\t ori : {}, pos : {}, score : {}", (mc.isFw ? "fw" : "rc"), mc.getTrFirstHitPos(), mc.score); + } + } + + salmonOpts.jointLog->info("\n\n mappings for right end \n\n"); + for (auto&& h : rightHits) { + salmonOpts.jointLog->info("hit to : {}", qidx->refName( h.first )); + for (auto&& mc : *h.second) { + salmonOpts.jointLog->info("\t ori : {}, pos : {}, score : {}", (mc.isFw ? "fw" : "rc"), mc.getTrFirstHitPos(), mc.score); + } + } + */ // TODO : PF_INTEGRATION /* @@ -1001,6 +1032,16 @@ void processReads( upperBoundHits += (jointHits.size() > 0); } + /* + salmonOpts.jointLog->info("\n\n mappings for joined ends \n\n"); + for (auto&& h : jointHits) { + salmonOpts.jointLog->info("hit to : {}", qidx->refName( h.tid )); + salmonOpts.jointLog->info("\t lpos : {}, rpos : {}, score : {}", h.leftClust->getTrFirstHitPos(), + h.rightClust->getTrFirstHitPos(), h.coverage()); + } + */ + + // FIXME: This clears the alignment group, but that contains nothing // at this point. We should either check only once we are at the alignment // phase (and therefore filter nothing based on pre-alignment hits), or @@ -1012,7 +1053,6 @@ void processReads( } } - // TODO: PF_INTEGRATION // NOTE : Under our new definition of orphans, alignments of read ends // can be orphans even if the other read end aligns to the same reference. // It only matters that the alignments were not paired. Thus, it is possible @@ -1064,22 +1104,14 @@ void processReads( if (tryAlign and !jointHits.empty()) { // clear the aligner for this read puffaligner.clear(); - bestScorePerTranscript.clear(); - - // the best scores start out as invalid - /* - int32_t bestScore = invalidScore; - int32_t secondBestScore = invalidScore; - int32_t bestDecoyScore = invalidScore; - */ - - salmon::mapping_utils::MappingScoreInfo msi = {invalidScore, invalidScore, invalidScore, decoyThreshold}; - std::vector scores(jointHits.size(), invalidScore); + puffaligner.getScoreStatus().reset(); + msi.clear(jointHits.size()); + size_t idx{0}; - bool isMultimapping = (jointHits.size() > 1); + bool is_multimapping = (jointHits.size() > 1); for (auto &&jointHit : jointHits) { - auto hitScore = puffaligner.calculateAlignments(rp.first.seq, rp.second.seq, jointHit, hctr, isMultimapping, false); + auto hitScore = puffaligner.calculateAlignments(rp.first.seq, rp.second.seq, jointHit, hctr, is_multimapping, false); bool validScore = (hitScore != invalidScore); numMappingsDropped += validScore ? 0 : 1; auto tid = qidx->getRefId(jointHit.tid); @@ -1166,17 +1198,13 @@ void processReads( **/ // end of alternative compat - salmon::mapping_utils::updateRefMappings(tid, hitScore, isCompat, idx, transcripts, invalidScore, msi, - //bestScore, secondBestScore, bestDecoyScore, - scores, bestScorePerTranscript, perm); + salmon::mapping_utils::updateRefMappings(tid, hitScore, isCompat, idx, transcripts, invalidScore, msi); ++idx; } bool bestHitDecoy = msi.haveOnlyDecoyMappings(); if (msi.bestScore > invalidScore and !bestHitDecoy) { salmon::mapping_utils::filterAndCollectAlignments(jointHits, - scores, - perm, readLen, mateLen, false, // true for single-end false otherwise @@ -1185,11 +1213,6 @@ void processReads( salmonOpts.scoreExp, salmonOpts.minAlnProb, msi, - /* - bestScore, - secondBestScore, - bestDecoyScore, - */ jointAlignments); // if we have alignments if (!jointAlignments.empty()) { @@ -1216,7 +1239,25 @@ void processReads( mapType = bestHitDecoy ? salmon::utils::MappingType::DECOY : salmon::utils::MappingType::UNMAPPED; numDecoyFrags += bestHitDecoy ? 1 : 0; ++numFragsDropped; - jointAlignmentGroup.clearAlignments(); + // TODO: Create alignment objects for the decoys so that we can write + // decoy alignments to file + + if (bestHitDecoy) { + salmon::mapping_utils::filterAndCollectAlignmentsDecoy(jointHits, + readLen, + mateLen, + false, // true for single-end false otherwise + tryAlign, + hardFilter, + salmonOpts.scoreExp, + salmonOpts.minAlnProb, + msi, + jointAlignments); + } else { + jointAlignmentGroup.clearAlignments(); + } + + //jointAlignmentGroup.clearAlignments(); } } else if (isPaired and noDovetail) { salmonOpts.jointLog->critical("This code path is not yet implemented!"); @@ -1235,6 +1276,24 @@ void processReads( jointAlignments.end()); } + if (writeQuasimappings) { + writeAlignmentsToStream(rp, formatter, + jointAlignments, + sstream, + true, // write orphans + true // transcript ID's already decoded (taking care of short refs) + ); + } + + // We've kept decoy aignments around to this point so that we can + // potentially write these alignments to the SAM file. However, if + // we got to this point and only have decoy mappings, then clear the + // mappings here because none of the procesing below is relevant for + // decoys. + if (mapType == salmon::utils::MappingType::DECOY) { + jointAlignmentGroup.clearAlignments(); + } + bool needBiasSample = salmonOpts.biasCorrect; std::uniform_int_distribution<> dis(0, jointAlignments.size()); @@ -1300,14 +1359,14 @@ void processReads( int32_t fwPos = (h.fwd) ? startPos1 : startPos2; int32_t rcPos = (h.fwd) ? startPos2 : startPos1; if (fwPos < rcPos) { - leftMer.from_chars(txpStart + startPos1 - + leftMer.fromChars(txpStart + startPos1 - readBias1.contextBefore(read1RC)); - rightMer.from_chars(txpStart + startPos2 - + rightMer.fromChars(txpStart + startPos2 - readBias2.contextBefore(read2RC)); if (read1RC) { - leftMer.reverse_complement(); + leftMer.rc(); } else { - rightMer.reverse_complement(); + rightMer.rc(); } success = readBias1.addSequence(leftMer, 1.0); @@ -1352,14 +1411,6 @@ void processReads( } } - if (writeQuasimappings) { - writeAlignmentsToStream(rp, formatter, - jointAlignments, - sstream, - true, // write orphans - true // transcript ID's already decoded (taking care of short refs) - ); - } } else { // This read was completely unmapped. mapType = salmon::utils::MappingType::UNMAPPED; @@ -1514,8 +1565,8 @@ void processReads( auto& readBiasFW = observedBiasParams.seqBiasModelFW; auto& readBiasRC = observedBiasParams.seqBiasModelRC; - Mer context; - + //Mer context; + SBMer context; uint64_t firstTimestepOfRound = fmCalc.getCurrentTimestep(); size_t minK = qidx->k(); @@ -1533,6 +1584,7 @@ void processReads( //******* Setting up pufferfish mapping constexpr const int32_t invalidScore = std::numeric_limits::min(); + constexpr const int32_t invalidIndex = std::numeric_limits::min(); MemCollector memCollector(qidx); ksw2pp::KSW2Aligner aligner; pufferfish::util::AlignmentConfig aconf; @@ -1545,7 +1597,6 @@ void processReads( std::vector jointHits; PairedAlignmentFormatter formatter(qidx); pufferfish::util::QueryCache qc; - phmap::flat_hash_map> bestScorePerTranscript; bool mimicStrictBT2 = salmonOpts.mimicStrictBT2; bool mimicBT2 = salmonOpts.mimicBT2; @@ -1566,14 +1617,16 @@ void processReads( std::string rc1; rc1.reserve(300); - // will hold the permutation to use to put the transcripts in order - std::vector> perm; - size_t numMappingsDropped{0}; size_t numFragsDropped{0}; size_t numDecoyFrags{0}; const double decoyThreshold = salmonOpts.decoyThreshold; + salmon::mapping_utils::MappingScoreInfo msi(decoyThreshold); + // we only collect detailed decoy information if we will be + // writing output to SAM. + msi.collect_decoys(writeQuasimappings); + //std::vector alnCache; alnCache.reserve(15); AlnCacheMap alnCache; alnCache.reserve(16); @@ -1602,7 +1655,6 @@ void processReads( auto& jointAlignments = jointHitGroup.alignments(); mapType = salmon::utils::MappingType::UNMAPPED; - perm.clear(); hits.clear(); jointHits.clear(); memCollector.clear(); @@ -1666,22 +1718,14 @@ void processReads( // clear the aligner for this read puffaligner.clear(); - bestScorePerTranscript.clear(); - - // the best scores start out as invalid - /* - int32_t bestScore = invalidScore; - int32_t secondBestScore = invalidScore; - int32_t bestDecoyScore = invalidScore; - */ - salmon::mapping_utils::MappingScoreInfo msi = {invalidScore, invalidScore, invalidScore, decoyThreshold}; - - std::vector scores(jointHits.size(), invalidScore); + puffaligner.getScoreStatus().reset(); + msi.clear(jointHits.size()); + size_t idx{0}; - bool isMultimapping = (jointHits.size() > 1); + bool is_multimapping = (jointHits.size() > 1); for (auto &&jointHit : jointHits) { - auto hitScore = puffaligner.calculateAlignments(rp.seq, jointHit, hctr, isMultimapping, false); + auto hitScore = puffaligner.calculateAlignments(rp.seq, jointHit, hctr, is_multimapping, false); bool validScore = (hitScore != invalidScore); numMappingsDropped += validScore ? 0 : 1; auto tid = qidx->getRefId(jointHit.tid); @@ -1690,19 +1734,14 @@ void processReads( (jointHit.orphanClust()->isFw and (expectedLibraryFormat.strandedness == ReadStrandedness::S)) or (!jointHit.orphanClust()->isFw and (expectedLibraryFormat.strandedness == ReadStrandedness::A)); - salmon::mapping_utils::updateRefMappings(tid, hitScore, isCompat, idx, transcripts, invalidScore, - msi, - //bestScore, secondBestScore, bestDecoyScore, - scores, bestScorePerTranscript, perm); + salmon::mapping_utils::updateRefMappings( + tid, hitScore, isCompat, idx, transcripts, invalidScore, msi); ++idx; } - //bool bestHitDecoy = (msi.bestScore < msi.bestDecoyScore); bool bestHitDecoy = msi.haveOnlyDecoyMappings(); if (msi.bestScore > invalidScore and !bestHitDecoy) { salmon::mapping_utils::filterAndCollectAlignments(jointHits, - scores, - perm, readLen, readLen, true, // true for single-end false otherwise @@ -1711,11 +1750,6 @@ void processReads( salmonOpts.scoreExp, salmonOpts.minAlnProb, msi, - /* - bestScore, - secondBestScore, - bestDecoyScore, - */ jointAlignments); // if we have any alignments, then they are // just single mapped. @@ -1727,10 +1761,32 @@ void processReads( mapType = (bestHitDecoy) ? salmon::utils::MappingType::DECOY : salmon::utils::MappingType::UNMAPPED; numDecoyFrags += bestHitDecoy ? 1 : 0; ++numFragsDropped; - jointHitGroup.clearAlignments(); + if (bestHitDecoy) { + salmon::mapping_utils::filterAndCollectAlignmentsDecoy( + jointHits, readLen, readLen, + true, // true for single-end false otherwise + tryAlign, hardFilter, salmonOpts.scoreExp, + salmonOpts.minAlnProb, msi, + jointAlignments); + } else { + jointHitGroup.clearAlignments(); + } } } + if (writeQuasimappings) { + writeAlignmentsToStreamSingle(rp, formatter, jointAlignments, sstream, false, true); + } + + // We've kept decoy aignments around to this point so that we can + // potentially write these alignments to the SAM file. However, if + // we got to this point and only have decoy mappings, then clear the + // mappings here because none of the procesing below is relevant for + // decoys. + if (mapType == salmon::utils::MappingType::DECOY) { + jointHitGroup.clearAlignments(); + } + bool needBiasSample = salmonOpts.biasCorrect; std::uniform_int_distribution<> dis(0, jointAlignments.size()); @@ -1762,10 +1818,10 @@ void processReads( // read start sequences. if (startPos >= readBias.contextBefore(!h.fwd) and startPos + readBias.contextAfter(!h.fwd) < static_cast(t.RefLength)) { - context.from_chars(txpStart + startPos - + context.fromChars(txpStart + startPos - readBias.contextBefore(!h.fwd)); if (!h.fwd) { - context.reverse_complement(); + context.rc(); } success = readBias.addSequence(context, 1.0); } @@ -1787,10 +1843,6 @@ void processReads( } } - if (writeQuasimappings) { - writeAlignmentsToStreamSingle(rp, formatter, jointAlignments, sstream, false, true); - } - if (writeUnmapped and mapType != salmon::utils::MappingType::SINGLE_MAPPED) { // If we have no mappings --- then there's nothing to do // unless we're outputting names for un-mapped reads @@ -2365,7 +2417,7 @@ int salmonQuantify(int argc, const char* argv[]) { auto hstring = R"( Quant ========== -Perform dual-phase, mapping-based estimation of +Perform dual-phase, selective-alignment-based estimation of transcript abundance from RNA-seq reads )"; std::cout << hstring << std::endl; @@ -2381,7 +2433,7 @@ transcript abundance from RNA-seq reads } std::stringstream commentStream; - commentStream << "### salmon (mapping-based) v" << salmon::version << "\n"; + commentStream << "### salmon (selective-alignment-based) v" << salmon::version << "\n"; commentStream << "### [ program ] => salmon \n"; commentStream << "### [ command ] => quant \n"; for (auto& opt : orderedOptions.options) { diff --git a/src/SalmonQuantifyAlignments.cpp b/src/SalmonQuantifyAlignments.cpp index f9950317b..3df955dab 100644 --- a/src/SalmonQuantifyAlignments.cpp +++ b/src/SalmonQuantifyAlignments.cpp @@ -16,7 +16,6 @@ extern "C" { #include #include #include -#include #include #include #include @@ -174,9 +173,13 @@ void processMiniBatch(AlignmentLibraryT& alnLib, auto orphanProb = salmonOpts.discardOrphansAln ? LOG_0 : LOG_EPSILON; // k-mers for sequence bias - Mer leftMer; - Mer rightMer; - Mer context; + //Mer leftMer; + //Mer rightMer; + //Mer context; + SBMer leftMer; + SBMer rightMer; + SBMer context; + auto& refs = alnLib.transcripts(); auto& clusterForest = alnLib.clusterForest(); @@ -208,6 +211,13 @@ void processMiniBatch(AlignmentLibraryT& alnLib, distribution_utils::LogCMFCache logCMFCache(&fragLengthDist, singleEndLib); + const size_t maxCacheLen{salmonOpts.fragLenDistMax}; + // Caches to avoid fld updates _within_ the set of alignments of a fragment + auto fetchPMF = [&fragLengthDist](size_t l) -> double { return fragLengthDist.pmf(l); }; + auto fetchCMF = [&fragLengthDist](size_t l) -> double { return fragLengthDist.cmf(l); }; + distribution_utils::IndexedVersionedCache pmfCache(maxCacheLen); + distribution_utils::IndexedVersionedCache cmfCache(maxCacheLen); + std::chrono::microseconds sleepTime(1); MiniBatchInfo>* miniBatch = nullptr; bool updateCounts = initialRound; @@ -314,6 +324,8 @@ void processMiniBatch(AlignmentLibraryT& alnLib, // alignments reported for a single read). Distribute the read's mass // proportionally dependent on the current for (auto& alnGroup : alignmentGroups) { + pmfCache.increment_generation(); + cmfCache.increment_generation(); // EQCLASS std::vector txpIDs; @@ -392,12 +404,15 @@ void processMiniBatch(AlignmentLibraryT& alnLib, if (flen > 0.0 and aln->isPaired() and useFragLengthDist and considerCondProb) { + size_t fl = flen; - double lenProb = fragLengthDist.pmf(fl); + double lenProb = pmfCache.get_or_update(fl, fetchPMF); + if (burnedIn) { /* condition fragment length prob on txp length */ - double refLengthCM = - fragLengthDist.cmf(static_cast(refLength)); + size_t rlen = static_cast(refLength); + double refLengthCM = cmfCache.get_or_update(fl, fetchCMF); + bool computeMass = fl < refLength and !salmon::math::isLog0(refLengthCM); logFragProb = (computeMass) ? (lenProb - refLengthCM) @@ -656,14 +671,15 @@ void processMiniBatch(AlignmentLibraryT& alnLib, int32_t fwPos = (fwd1) ? startPos1 : startPos2; int32_t rcPos = (fwd1) ? startPos2 : startPos1; if (fwPos < rcPos) { - leftMer.from_chars(txpStart + startPos1 - + leftMer.fromChars(txpStart + startPos1 - readBias1.contextBefore(read1RC)); - rightMer.from_chars(txpStart + startPos2 - + rightMer.fromChars(txpStart + startPos2 - readBias2.contextBefore(read2RC)); + if (read1RC) { - leftMer.reverse_complement(); + leftMer.rc(); } else { - rightMer.reverse_complement(); + rightMer.rc(); } success = readBias1.addSequence(leftMer, 1.0); @@ -692,10 +708,11 @@ void processMiniBatch(AlignmentLibraryT& alnLib, if (startPos1 >= readBias1.contextBefore(!fwd1) and startPos1 + readBias1.contextAfter(!fwd1) < static_cast(transcript.RefLength)) { - context.from_chars(txpStart + startPos1 - + context.fromChars(txpStart + startPos1 - readBias1.contextBefore(!fwd1)); + if (!fwd1) { - context.reverse_complement(); + context.rc(); } success = readBias1.addSequence(context, 1.0); } diff --git a/src/SalmonUtils.cpp b/src/SalmonUtils.cpp index 27eec4578..6c55178d7 100644 --- a/src/SalmonUtils.cpp +++ b/src/SalmonUtils.cpp @@ -37,7 +37,7 @@ #include "gff.h" #include "FastxParser.hpp" -#include "jellyfish/mer_dna.hpp" +//#include "jellyfish/mer_dna.hpp" #include "GenomicFeature.hpp" #include "SGSmooth.hpp" @@ -1075,7 +1075,7 @@ TranscriptGeneMap transcriptGeneMapFromGTF(const std::string& fname, auto logger = spdlog::get("jointLog"); // Use GffReader to read the file - GffReader reader(const_cast(fname.c_str())); + GffReader reader(const_cast(fname.c_str()), true, false); // Remember the optional attributes reader.readAll(true); @@ -1423,23 +1423,24 @@ std::string getCurrentTimeAsString() { "`--validateMappings` is generally recommended.\n"); } + bool is_pe_library = (numLeft + numRight > 0); + bool is_se_library = (numUnpaired > 0); + // currently there is some strange use for this in alevin, I think ... // check with avi. - if (numLeft + numRight > 0 and numUnpaired > 0) { + if (is_pe_library and is_se_library) { sopt.jointLog->warn("You seem to have passed in both un-paired reads and paired-end reads. " "It is not currently possible to quantify hybrid library types in salmon."); } - - if (numLeft + numRight > 0) { + if (is_pe_library) { if (numLeft != numRight) { sopt.jointLog->error("You passed paired-end files to salmon, but you passed {} files to --mates1 " "and {} files to --mates2. You must pass the same number of files to both flags", numLeft, numRight); return false; } - } - + } auto checkScoreValue = [&sopt](int16_t score, std::string sname) -> bool { using score_t = int8_t; @@ -1510,7 +1511,7 @@ std::string getCurrentTimeAsString() { sopt.useRangeFactorization = true; } - // If the consensus slack was not set explicitly, then it defaults to 0.2 with + // If the consensus slack was not set explicitly, then it defaults to 0.35 with // validateMappings bool consensusSlackExplicit = !vm["consensusSlack"].defaulted(); if (!consensusSlackExplicit) { @@ -1520,6 +1521,55 @@ std::string getCurrentTimeAsString() { "Setting consensusSlack to {}.", sopt.consensusSlack); } + bool pre_merge_chain_sub_thresh_explicit = !vm["preMergeChainSubThresh"].defaulted(); + bool post_merge_chain_sub_thresh_explicit = !vm["postMergeChainSubThresh"].defaulted(); + bool orphan_chain_sub_thresh_explicit = !vm["orphanChainSubThresh"].defaulted(); + + // for a single-end library (or effectively so by being single-cell), we set + // pre_merge_chain_sub_thresh to 1.0 by default + if ( is_se_library or sopt.alevinMode ) { + + // The default of preMergeChainSubThresh for single-end libraries is 1.0, so set that here + if (!pre_merge_chain_sub_thresh_explicit) { + sopt.pre_merge_chain_sub_thresh = 1.0; + } + + // for single-end libraries, postMergeChainSubThresh and orphanChainSubThresh are meaningless + if (post_merge_chain_sub_thresh_explicit) { + sopt.jointLog->warn("The postMergeChainSubThresh is not meaningful for single-end " + "(or effectively single-end — e.g. tagged-end single-cell) libraries. Setting this value " + "to 1.0 and ignoring"); + } + if (orphan_chain_sub_thresh_explicit) { + sopt.jointLog->warn("The orphanChainSubThresh is not meaningful for single-end " + "(or effectively single-end — e.g. tagged-end single-cell) libraries. Setting this value " + "to 1.0 and ignoring"); + } + sopt.post_merge_chain_sub_thresh = 1.0; + sopt.orphan_chain_sub_thresh = 1.0; + } + + // value range check for filters + // pre-merge + if (sopt.pre_merge_chain_sub_thresh < 0 or sopt.pre_merge_chain_sub_thresh > 1.0) { + sopt.jointLog->error("You set preMergeChainSubThresh as {}, but it must in [0,1].", + sopt.pre_merge_chain_sub_thresh); + return false; + } + // post-merge + if (sopt.post_merge_chain_sub_thresh < 0 or sopt.post_merge_chain_sub_thresh > 1.0) { + sopt.jointLog->error("You set postMergeChainSubThresh as {}, but it must in [0,1].", + sopt.post_merge_chain_sub_thresh); + return false; + } + // orphan + if (sopt.orphan_chain_sub_thresh < 0 or sopt.orphan_chain_sub_thresh > 1.0) { + sopt.jointLog->error("You set orphanChainSubThresh as {}, but it must in [0,1].", + sopt.orphan_chain_sub_thresh); + return false; + } + + if (sopt.mimicBT2 and sopt.mimicStrictBT2) { sopt.jointLog->error("You passed both the --mimicBT2 and --mimicStrictBT2 parameters. These are mutually exclusive. " "Please select only one of these flags."); @@ -1655,7 +1705,7 @@ bool createAuxMapLoggers_(SalmonOpts& sopt, sopt.orphanLinkLog = outLog; } - // Determine what we'll do with quasi-mapping results + // Determine what we'll do with selective-alignment results bool writeQuasimappings = (sopt.qmFileName != ""); if (writeQuasimappings) { @@ -1682,7 +1732,7 @@ bool createAuxMapLoggers_(SalmonOpts& sopt, // Make sure file opened successfully. if (!sopt.qmFile.is_open()) { jointLog->error( - "Could not create file for writing quasi-mappings [{}]", + "Could not create file for writing selective-alignments [{}]", sopt.qmFileName); return false; } @@ -1856,7 +1906,7 @@ bool processQuantOptions(SalmonOpts& sopt, bfs::path indexDirectory(vm["index"].as()); sopt.indexDirectory = indexDirectory; - // Determine what we'll do with quasi-mapping results + // Determine what we'll do with selective-alignment results bool writeQuasimappings = (sopt.qmFileName != ""); // make it larger if we're writing mappings or @@ -1875,6 +1925,7 @@ bool processQuantOptions(SalmonOpts& sopt, // std::make_shared(rawConsoleSink); auto consoleSink = std::make_shared(); + consoleSink->set_color(spdlog::level::warn, consoleSink->magenta); auto consoleLog = spdlog::create("stderrLog", {consoleSink}); auto fileLog = spdlog::create("fileLog", {fileSink}); std::vector sinks{consoleSink, fileSink}; @@ -2652,14 +2703,21 @@ int contextSize = outsideContext + insideContext; windowLensTP.setZero(); // This transcript's sequence - const char* tseq = txp.Sequence(); - revComplement(tseq, refLen, rcSeq); + bool have_seq = txp.have_sequence(); + SBMer fwmer; + SBMer rcmer; + + const char* tseq = have_seq ? txp.Sequence() : nullptr; + // only do this if we have the sequence, we may not if just pos. bias. + if (have_seq) { + revComplement(tseq, refLen, rcSeq); + fwmer.fromChars(tseq); + rcmer.fromChars(rcSeq); + } + // may be empty, but we shouldn't actually + // use it unless it is meaningful. const char* rseq = rcSeq.c_str(); - Mer fwmer; - fwmer.from_chars(tseq); - Mer rcmer; - rcmer.from_chars(rseq); int32_t contextLength{expectSeqFW.getContextLength()}; if (gcBiasCorrect and seqBiasCorrect) { @@ -2692,8 +2750,10 @@ int contextSize = outsideContext + insideContext; } // shift the context one nucleotide to the right - fwmer.shift_left(tseq[fragStartPos + contextLength]); - rcmer.shift_left(rseq[fragStartPos + contextLength]); + //fwmer.shift_left(tseq[fragStartPos + contextLength]); + //rcmer.shift_left(rseq[fragStartPos + contextLength]); + fwmer.append(tseq[fragStartPos + contextLength]); + rcmer.append(rseq[fragStartPos + contextLength]); } // end: Seq-specific bias // fragment-GC bias @@ -2887,8 +2947,15 @@ int contextSize = outsideContext + insideContext; std::vector posFactorsRC(refLen, 1.0); // This transcript's sequence - const char* tseq = txp.Sequence(); - revComplement(tseq, refLen, rcSeq); + bool have_seq = txp.have_sequence(); + const char* tseq = have_seq ? txp.Sequence() : nullptr; + // only do this if we have the sequence, we may not if just pos. + // bias. + if (have_seq) { + revComplement(tseq, refLen, rcSeq); + } + // may be empty, but we shouldn't actually + // use it unless it is meaningful. const char* rseq = rcSeq.c_str(); int32_t fl = locFLDLow; @@ -2929,10 +2996,15 @@ int contextSize = outsideContext + insideContext; // and seqFactorsRC will contain the sequence-specific bias for each // position on the 3' strand. if (seqBiasCorrect) { - Mer mer; - Mer rcmer; - mer.from_chars(tseq); - rcmer.from_chars(rseq); + //Mer mer; + //Mer rcmer; + //mer.from_chars(tseq); + //rcmer.from_chars(rseq); + SBMer mer; + SBMer rcmer; + mer.fromChars(tseq); + rcmer.fromChars(rseq); + int32_t contextLength{exp5.getContextLength()}; for (int32_t fragStart = 0; fragStart < refLen - K; ++fragStart) { @@ -2948,8 +3020,10 @@ int contextSize = outsideContext + insideContext; exp3.evaluateLog(rcmer)); } // shift the context one nucleotide to the right - mer.shift_left(tseq[fragStart + contextLength]); - rcmer.shift_left(rseq[fragStart + contextLength]); + //mer.shift_left(tseq[fragStart + contextLength]); + //rcmer.shift_left(rseq[fragStart + contextLength]); + mer.append(tseq[fragStart + contextLength]); + rcmer.append(rseq[fragStart + contextLength]); } // We need these in 5' -> 3' order, so reverse them seqFactorsRC.reverseInPlace(); @@ -3309,59 +3383,19 @@ template void salmon::utils::normalizeAlphas>( template void salmon::utils::normalizeAlphas>( const SalmonOpts& sopt, BulkAlignLibT& alnLib); -// explicit instantiations for effective length updates --- -/* -template Eigen::VectorXd -salmon::utils::updateEffectiveLengths>, - ReadExperiment>( - SalmonOpts& sopt, ReadExperiment& readExp, Eigen::VectorXd& effLensIn, - std::vector>& alphas, bool finalRound); - -template Eigen::VectorXd -salmon::utils::updateEffectiveLengths, ReadExperiment>( - SalmonOpts& sopt, ReadExperiment& readExp, Eigen::VectorXd& effLensIn, - std::vector& alphas, bool finalRound); - -template Eigen::VectorXd -salmon::utils::updateEffectiveLengths>, - AlignmentLibrary>( - SalmonOpts& sopt, AlignmentLibrary& readExp, - Eigen::VectorXd& effLensIn, std::vector>& alphas, - bool finalRound); - -template Eigen::VectorXd -salmon::utils::updateEffectiveLengths, - AlignmentLibrary>( - SalmonOpts& sopt, AlignmentLibrary& readExp, - Eigen::VectorXd& effLensIn, std::vector& alphas, bool finalRound); - -template Eigen::VectorXd -salmon::utils::updateEffectiveLengths>, - AlignmentLibrary>( - SalmonOpts& sopt, AlignmentLibrary& readExp, - Eigen::VectorXd& effLensIn, std::vector>& alphas, - bool finalRound); - -template Eigen::VectorXd -salmon::utils::updateEffectiveLengths, - AlignmentLibrary>( - SalmonOpts& sopt, AlignmentLibrary& readExp, - Eigen::VectorXd& effLensIn, std::vector& alphas, bool finalRound); -*/ - // explicit instantiations for effective length updates --- template Eigen::VectorXd -salmon::utils::updateEffectiveLengths>, +salmon::utils::updateEffectiveLengths>, BulkExpT>( SalmonOpts& sopt, BulkExpT& readExp, Eigen::VectorXd& effLensIn, - std::vector>& alphas, std::vector& available, + std::vector>& alphas, std::vector& available, bool finalRound); template Eigen::VectorXd -salmon::utils::updateEffectiveLengths>, +salmon::utils::updateEffectiveLengths>, SCExpT>( SalmonOpts& sopt, SCExpT& readExp, Eigen::VectorXd& effLensIn, - std::vector>& alphas, std::vector& available, + std::vector>& alphas, std::vector& available, bool finalRound); template Eigen::VectorXd @@ -3375,10 +3409,10 @@ salmon::utils::updateEffectiveLengths, SCExpT>( std::vector& alphas, std::vector& available, bool finalRound); template Eigen::VectorXd -salmon::utils::updateEffectiveLengths>, +salmon::utils::updateEffectiveLengths>, BulkAlignLibT>( SalmonOpts& sopt, BulkAlignLibT& readExp, - Eigen::VectorXd& effLensIn, std::vector>& alphas, + Eigen::VectorXd& effLensIn, std::vector>& alphas, std::vector& available, bool finalRound); template Eigen::VectorXd @@ -3389,10 +3423,10 @@ salmon::utils::updateEffectiveLengths, std::vector& available, bool finalRound); template Eigen::VectorXd -salmon::utils::updateEffectiveLengths>, +salmon::utils::updateEffectiveLengths>, BulkAlignLibT>( SalmonOpts& sopt, BulkAlignLibT& readExp, - Eigen::VectorXd& effLensIn, std::vector>& alphas, + Eigen::VectorXd& effLensIn, std::vector>& alphas, std::vector& available, bool finalRound); template Eigen::VectorXd diff --git a/src/VersionChecker.cpp b/src/VersionChecker.cpp index 64d2d0e4d..4cc5634a4 100644 --- a/src/VersionChecker.cpp +++ b/src/VersionChecker.cpp @@ -1,187 +1,28 @@ -// based off of -// async_client.cpp -// ~~~~~~~~~~~~~~~~ -// -// Copyright (c) 2003-2012 Christopher M. Kohlhoff (chris at kohlhoff dot com) -// -// Distributed under the Boost Software License, Version 1.0. (See accompanying -// file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt) -// - #include "VersionChecker.hpp" #include "SalmonConfig.hpp" - -using boost::asio::ip::tcp; - -VersionChecker::VersionChecker(boost::asio::io_service& io_service, - const std::string& server, - const std::string& path) - : resolver_(io_service), socket_(io_service), deadline_(io_service) { - // Form the request. We specify the "Connection: close" header so that the - // server will close the socket after transmitting the response. This will - // allow us to treat all data up until the EOF as the content. - std::ostream request_stream(&request_); - request_stream << "GET " << path << " HTTP/1.0\r\n"; - request_stream << "Host: " << server << "\r\n"; - request_stream << "Accept: */*\r\n"; - request_stream << "Connection: close\r\n\r\n"; - - - deadline_.expires_from_now(boost::posix_time::seconds(1)); - deadline_.async_wait(boost::bind(&VersionChecker::cancel_upgrade_check, this, - boost::asio::placeholders::error)); - - // Start an asynchronous resolve to translate the server and service names - // into a list of endpoints. - tcp::resolver::query query(server, "http"); - resolver_.async_resolve(query, - boost::bind(&VersionChecker::handle_resolve, this, - boost::asio::placeholders::error, - boost::asio::placeholders::iterator)); -} - -std::string VersionChecker::message() { return messageStream_.str(); } - -void VersionChecker::cancel_upgrade_check( - const boost::system::error_code& err) { - if (err != boost::asio::error::operation_aborted) { - deadline_.cancel(); - messageStream_ - << "Could not resolve upgrade information in the alotted time.\n"; - messageStream_ << "Check for upgrades manually at " - "https://combine-lab.github.io/salmon\n"; - socket_.close(); - } -} - -void VersionChecker::handle_resolve(const boost::system::error_code& err, - tcp::resolver::iterator endpoint_iterator) { - if (!err) { - // Attempt a connection to each endpoint in the list until we - // successfully establish a connection. - boost::asio::async_connect(socket_, endpoint_iterator, - boost::bind(&VersionChecker::handle_connect, - this, - boost::asio::placeholders::error)); - } else { - deadline_.cancel(); - cancel_upgrade_check(err); - } -} - -void VersionChecker::handle_connect(const boost::system::error_code& err) { - if (!err) { - // The connection was successful. Send the request. - boost::asio::async_write(socket_, request_, - boost::bind(&VersionChecker::handle_write_request, - this, - boost::asio::placeholders::error)); - } else { - deadline_.cancel(); - cancel_upgrade_check(err); - } -} - -void VersionChecker::handle_write_request( - const boost::system::error_code& err) { - if (!err) { - // Read the response status line. The response_ streambuf will - // automatically grow to accommodate the entire line. The growth may be - // limited by passing a maximum size to the streambuf constructor. - boost::asio::async_read_until( - socket_, response_, "\r\n", - boost::bind(&VersionChecker::handle_read_status_line, this, - boost::asio::placeholders::error)); - } else { - deadline_.cancel(); - cancel_upgrade_check(err); - } -} - -void VersionChecker::handle_read_status_line( - const boost::system::error_code& err) { - if (!err) { - // Check that response is OK. - std::istream response_stream(&response_); - std::string http_version; - response_stream >> http_version; - unsigned int status_code; - response_stream >> status_code; - std::string status_message; - std::getline(response_stream, status_message); - if (!response_stream || http_version.substr(0, 5) != "HTTP/") { - deadline_.cancel(); - cancel_upgrade_check(err); - return; - } - if (status_code != 200) { - deadline_.cancel(); - cancel_upgrade_check(err); - return; - } - - // Read the response headers, which are terminated by a blank line. - boost::asio::async_read_until( - socket_, response_, "\r\n\r\n", - boost::bind(&VersionChecker::handle_read_headers, this, - boost::asio::placeholders::error)); - } else { - deadline_.cancel(); - cancel_upgrade_check(err); - } -} - -void VersionChecker::handle_read_headers(const boost::system::error_code& err) { - if (!err) { - deadline_.cancel(); - // Process the response headers. - std::istream response_stream(&response_); - std::string header; - while (std::getline(response_stream, header) && header != "\r") { - } - - // Write whatever content we already have to output. - if (response_.size() > 0) - messageStream_ << &response_; - - // Start reading remaining data until EOF. - boost::asio::async_read( - socket_, response_, boost::asio::transfer_at_least(1), - boost::bind(&VersionChecker::handle_read_content, this, - boost::asio::placeholders::error)); - } else { - deadline_.cancel(); - cancel_upgrade_check(err); - } -} - -void VersionChecker::handle_read_content(const boost::system::error_code& err) { - if (!err) { - // Write all of the data that has been read so far. - messageStream_ << &response_; - - // Continue reading remaining data until EOF. - boost::asio::async_read( - socket_, response_, boost::asio::transfer_at_least(1), - boost::bind(&VersionChecker::handle_read_content, this, - boost::asio::placeholders::error)); - } else if (err != boost::asio::error::eof) { - deadline_.cancel(); - cancel_upgrade_check(err); - } -} +#include "httplib.hpp" std::string getVersionMessage() { - std::string baseSite{"combine-lab.github.io"}; - std::string path{"/salmon/version_info/"}; - path += salmon::version; - std::stringstream ss; try { - boost::asio::io_service io_service; - VersionChecker c(io_service, baseSite, path); - io_service.run(); - ss << "Version Info: " << c.message(); + // NOTE: getaddrinfo / freeaddrinfo will cause a "memory leak" + // once per address, program pair. This is a known issue + // https://lists.debian.org/debian-glibc/2016/03/msg00243.html. + // If valgrind leads you here, best not to worry about it. + httplib::Client cli("combine-lab.github.io"); + std::string path{"/salmon/version_info/"}; + path += salmon::version; + cli.set_timeout_sec(2); // timeouts in 2 seconds + auto res = cli.Get(path.c_str()); + if (res) { // non-null response + if (res->status == 200) { // response OK + ss << "Version Info: " << res->body; + } else { // response something else + ss << "Version Server Response: " << httplib::detail::status_message(res->status) << "\n"; + } + } else { // null response + ss << "Version Info Exception: server did not respond before timeout\n"; + } } catch (std::exception& e) { ss << "Version Info Exception: " << e.what() << "\n"; } diff --git a/tests/basic_alevin_test.sh b/tests/basic_alevin_test.sh index f1464dd7c..2bf1fd9e3 100755 --- a/tests/basic_alevin_test.sh +++ b/tests/basic_alevin_test.sh @@ -1,9 +1,10 @@ ALVBIN=$1 #"/mnt/scratch6/salmon_ci/COMBINE-lab/salmon/master/build/salmon-latest_linux_x86_64/bin/salmon" +BASEDIR="/mnt/scratch6/avi/alevin/alevin" OUT=$PWD tfile=$(mktemp /tmp/foo.XXXXXXXXX) -/usr/bin/time -o $tfile $ALVBIN alevin -lISR --chromium -1 /mnt/scratch5/avi/alevin/data/10x/v2/mohu/100/all_bcs.fq -2 /mnt/scratch5/avi/alevin/data/10x/v2/mohu/100/all_reads.fq -o $OUT/prediction -i /mnt/scratch5/avi/alevin/data/mohu/salmon_index -p 20 --tgMap /mnt/scratch5/avi/alevin/data/mohu/gtf/txp2gene.tsv --dumpMtx --no-version-check --dumpFeatures --dumpArborescence #--whitelist ./alevin_test_data/alevin/quants_mat_rows.txt +/usr/bin/time -o $tfile $ALVBIN alevin -lISR --chromium -1 ${BASEDIR}/data/10x/v2/mohu/100/all_bcs.fq -2 ${BASEDIR}/data/10x/v2/mohu/100/all_reads.fq -o $OUT/prediction -i ${BASEDIR}/data/mohu/salmon_index -p 20 --tgMap ${BASEDIR}/data/mohu/gtf/txp2gene.tsv --dumpMtx --no-version-check --dumpFeatures --dumpArborescence --writeMappings=$OUT/prediction/with_bug.sam #--whitelist ./alevin_test_data/alevin/quants_mat_rows.txt #--dumpBfh --whitelist /mnt/scratch5/avi/alevin/bin/salmon/tests/whitelist.txt #--dumpUmiGraph --numCellBootstraps 100 --dumpBfh --dumpBarcodeEq --dumpMtx --expectCells 1001 --end 6 --umiLength 10 --barcodeLength 16