diff --git a/.drone.yml b/.drone.yml deleted file mode 100644 index 87686c61066..00000000000 --- a/.drone.yml +++ /dev/null @@ -1,36 +0,0 @@ - kind: pipeline - type: docker - name: amd64-gcc - platform: - arch: amd64 - - steps: - - name: build-gcc - image: ubuntu:bionic - commands: - - apt-get update -y - - apt-get install -y g++ cmake libtbb-dev libcurl4-openssl-dev libace-dev git libmysql++-dev - - export TBB_ROOT_DIR=/usr/include/tbb/ - - export ACE_ROOT=/usr/include/ace/ - - mkdir build - - cd build - - cmake -DDEBUG=0 -DUSE_LIBCURL=1 .. - - make -j$(nproc) - - --- - kind: pipeline - type: docker - name: amd64-clang - - steps: - - name: build-clang - image: ubuntu:bionic - commands: - - apt-get update -y - - apt-get install -y clang cmake libtbb-dev libcurl4-openssl-dev libace-dev git libmysql++-dev - - export CC=/usr/bin/clang - - export CXX=/usr/bin/clang++ - - mkdir build - - cd build - - cmake -DDEBUG=0 -DUSE_LIBCURL=1 .. - - make -j$(nproc) diff --git a/.github/workflows/dev-release.yml b/.github/workflows/dev-release.yml index 9d22ddbbf14..7dfbdbb94bf 100644 --- a/.github/workflows/dev-release.yml +++ b/.github/workflows/dev-release.yml @@ -11,7 +11,6 @@ on: - '.github/workflows/db_check.yml' - '.github/workflows/db_dump.yml' - 'sql/**' - - '.drone.yml' - 'README.md' - 'LICENSE' - '.gitignore' @@ -26,23 +25,13 @@ jobs: - uses: actions/checkout@v4 - name: windows dependencies - #Sets versions for ACE/TBB + #Sets versions for TBB env: - ACE_VERSION: 6.5.11 - ACE_VERSION2: 6_5_11 TBB_VERSION: 2020.3 run: | - #directory variables - export ACE_ROOT=$GITHUB_WORKSPACE/ACE_wrappers + # Setup TBB export TBB_ROOT_DIR=$GITHUB_WORKSPACE/tbb - #ACE package download - curl -LOJ http://github.com/DOCGroup/ACE_TAO/releases/download/ACE%2BTAO-$ACE_VERSION2/ACE-$ACE_VERSION.zip - unzip ACE-$ACE_VERSION.zip - rm ACE-$ACE_VERSION.zip - #configuration of ACE header - echo "#include \"ace/config-win32.h\"" >> $ACE_ROOT/ace/config.h - #TBB package download curl -LOJ https://github.com/oneapi-src/oneTBB/releases/download/v$TBB_VERSION/tbb-$TBB_VERSION-win.zip unzip tbb-$TBB_VERSION-win.zip rm tbb-$TBB_VERSION-win.zip @@ -52,11 +41,7 @@ jobs: #build and install - name: windows build & install run: | - #directory variables - export ACE_ROOT=$GITHUB_WORKSPACE/ACE_wrappers - cd $GITHUB_WORKSPACE/ACE_wrappers - /c/Program\ Files\ \(x86\)/Microsoft\ Visual\ Studio/2019/Enterprise/MSBuild/Current/Bin/MSBuild.exe "ACE_wrappers_vs2019.sln" //p:Configuration=Release //p:Platform=x64 //t:ACE //m:2 - cd $GITHUB_WORKSPACE + # Setup TBB mkdir build cd build cmake -D TBB_ROOT_DIR=$GITHUB_WORKSPACE/tbb -DWITH_WARNINGS=0 -DBUILD_FOR_HOST_CPU=0 -DUSE_EXTRACTORS=1 -G "Visual Studio 16 2019" -A x64 .. @@ -79,7 +64,6 @@ jobs: copy ${{github.workspace}}/tbb/bin/intel64/vc14/tbbmalloc_proxy.dll ${{github.workspace}}/bin/Release/tbbmalloc_proxy.dll copy ${{github.workspace}}/tbb/bin/intel64/vc14/tbbmalloc_proxy_debug.dll ${{github.workspace}}/bin/Release/tbbmalloc_proxy_debug.dll - copy ${{github.workspace}}/ACE_wrappers/lib/ACE.dll ${{github.workspace}}/bin/Release/ACE.dll copy ${{github.workspace}}/dep/windows/lib/x64_release/libmySQL.dll ${{github.workspace}}/bin/Release/libmySQL.dll # copy "c:/Program Files/OpenSSL-Win64/bin/libssl-1_1-x64.dll" ${{github.workspace}}/bin/Release/libssl-1_1-x64.dll # copy "c:/Program Files/OpenSSL-Win64/bin/libcrypto-1_1-x64.dll" ${{github.workspace}}/bin/Release/libcrypto-1_1-x64.dll diff --git a/.github/workflows/vmangos.yml b/.github/workflows/vmangos.yml index 8eb76bf9259..38c53db7ac5 100644 --- a/.github/workflows/vmangos.yml +++ b/.github/workflows/vmangos.yml @@ -15,7 +15,6 @@ on: - '.github/workflows/db_check.yml' - '.github/workflows/db_dump.yml' - 'sql/**' - - '.drone.yml' - 'README.md' - 'LICENSE' - '.gitignore' @@ -29,7 +28,6 @@ on: - '.github/workflows/db_check.yml' - '.github/workflows/db_dump.yml' - 'sql/**' - - '.drone.yml' - 'README.md' - 'LICENSE' - '.gitignore' @@ -57,27 +55,17 @@ jobs: if: matrix.os == 'ubuntu-latest' run: | sudo apt-get -qq update - sudo apt-get -qq install build-essential cmake libace-dev libmysql++-dev libtbb-dev libcurl4-openssl-dev openssl + sudo apt-get -qq install build-essential cmake libmysql++-dev libtbb-dev libcurl4-openssl-dev openssl #windows dependencies - name: windows dependencies if: matrix.os == 'windows-2019' - #Sets versions for ACE/TBB + #Sets versions TBB env: - ACE_VERSION: 6.5.11 - ACE_VERSION2: 6_5_11 TBB_VERSION: 2020.3 run: | - #directory variables - export ACE_ROOT=$GITHUB_WORKSPACE/ACE_wrappers + # Setup TBB export TBB_ROOT_DIR=$GITHUB_WORKSPACE/tbb - #ACE package download - curl -LOJ http://github.com/DOCGroup/ACE_TAO/releases/download/ACE%2BTAO-$ACE_VERSION2/ACE-$ACE_VERSION.zip - unzip ACE-$ACE_VERSION.zip - rm ACE-$ACE_VERSION.zip - #configuration of ACE header - echo "#include \"ace/config-win32.h\"" >> $ACE_ROOT/ace/config.h - #TBB package download curl -LOJ https://github.com/oneapi-src/oneTBB/releases/download/v$TBB_VERSION/tbb-$TBB_VERSION-win.zip unzip tbb-$TBB_VERSION-win.zip rm tbb-$TBB_VERSION-win.zip @@ -99,10 +87,6 @@ jobs: - name: windows build & install if: matrix.os == 'windows-2019' run: | - # Build ACE - export ACE_ROOT=$GITHUB_WORKSPACE/ACE_wrappers - cd $GITHUB_WORKSPACE/ACE_wrappers - /c/Program\ Files\ \(x86\)/Microsoft\ Visual\ Studio/2019/Enterprise/MSBuild/Current/Bin/MSBuild.exe "ACE_wrappers_vs2019.sln" //p:Configuration=Release //p:Platform=x64 //t:ACE //m:2 # Build CURL cd $GITHUB_WORKSPACE/dep/windows/optional_dependencies/ ./curl_download_and_build.bat diff --git a/CMakeLists.txt b/CMakeLists.txt index b19b6327ea3..5950aa7b311 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -20,7 +20,7 @@ cmake_minimum_required(VERSION 3.1...3.20) project(MaNGOS) -# Allow -DACE_ROOT, -DTBB_ROOT, etc. +# Allow -DTBB_ROOT, etc. if(${CMAKE_VERSION} VERSION_GREATER "3.11") cmake_policy(SET CMP0074 NEW) endif() @@ -83,19 +83,6 @@ set(CMAKE_INSTALL_RPATH ${LIBS_DIR}) set(CMAKE_INSTALL_RPATH_USE_LINK_PATH ON) # Find needed packages and if necessery abort if something important is missing -unset(ACE_INCLUDE_DIR CACHE) -unset(ACE_LIBRARIES CACHE) -unset(ACE_LIBRARIES_DIR CACHE) -unset(ACE_INCLUDE_DIR) -unset(ACE_LIBRARIES) -unset(ACE_LIBRARIES_DIR) - -find_package(ACE) -if(NOT ACE_FOUND) - message(FATAL_ERROR - "This project requires ACE installed. Please download the ACE Micro Release Kit from http://download.dre.vanderbilt.edu/ and install it. If this script didn't find ACE and it was correctly installed please set ACE_ROOT to the correct path." - ) -endif() if(NOT USE_STD_MALLOC) unset(TBB_INCLUDE_DIRS CACHE) diff --git a/cmake/find/FindACE.cmake b/cmake/find/FindACE.cmake deleted file mode 100644 index 5411be723b1..00000000000 --- a/cmake/find/FindACE.cmake +++ /dev/null @@ -1,81 +0,0 @@ -# -# Find the ACE client includes and library -# - -# This module defines -# ACE_INCLUDE_DIR, where to find ace.h -# ACE_LIBRARIES, the libraries to link against -# ACE_FOUND, if false, you cannot build anything that requires ACE - -# also defined, but not for general use are -# ACE_LIBRARY, where to find the ACE library. - -set(ACE_FOUND 0) - -if (UNIX) - - FIND_PATH(ACE_INCLUDE_DIR - NAMES - ace/ACE.h - PATHS - /usr/include - /usr/include/ace - /usr/local/include - /usr/local/include/ace - ${ACE_ROOT} - ${ACE_ROOT}/include - $ENV{ACE_ROOT} - $ENV{ACE_ROOT}/include - DOC "Specify include-directories that might contain ace.h here.") - - FIND_LIBRARY(ACE_LIBRARIES - NAMES - ace ACE - PATHS - /usr/lib - /usr/lib/ace - /usr/local/lib - /usr/local/lib/ace - /usr/local/ace/lib - ${ACE_ROOT} - ${ACE_ROOT}/lib - $ENV{ACE_ROOT}/lib - $ENV{ACE_ROOT} - DOC "Specify library-locations that might contain the ACE library here.") -endif (UNIX) - -if (WIN32) - - FIND_PATH(ACE_INCLUDE_DIR - NAMES - ace/ACE.h - PATHS - ${ACE_ROOT} - ${ACE_ROOT}/include - $ENV{ACE_ROOT} - $ENV{ACE_ROOT}/include - DOC "Specify include-directories that might contain ace.h here.") - - FIND_LIBRARY(ACE_LIBRARIES - NAMES - ace ACE ACEd - PATHS - ${ACE_ROOT} - ${ACE_ROOT}/lib - $ENV{ACE_ROOT}/lib - $ENV{ACE_ROOT} - DOC "Specify library-locations that might contain the ACE library here.") - -endif (WIN32) - -if (ACE_LIBRARIES) - if (ACE_INCLUDE_DIR) - set(ACE_FOUND 1) - message(STATUS "Found ACE library: ${ACE_LIBRARIES}") - message( STATUS "Found ACE headers: ${ACE_INCLUDE_DIR}") - else (ACE_INCLUDE_DIR) - message(FATAL_ERROR "Could not find ACE headers! Please install ACE libraries and headers") - endif (ACE_INCLUDE_DIR) -endif (ACE_LIBRARIES) - -mark_as_advanced(ACE_FOUND ACE_LIBRARIES ACE_INCLUDE_DIR) diff --git a/contrib/mmap/src/MMapCommon.h b/contrib/mmap/src/MMapCommon.h index 6b8ec860017..881b3709848 100644 --- a/contrib/mmap/src/MMapCommon.h +++ b/contrib/mmap/src/MMapCommon.h @@ -19,11 +19,6 @@ #ifndef _MMAP_COMMON_H #define _MMAP_COMMON_H -// stop warning spam from ACE includes -#ifdef _WIN32 -# pragma warning( disable : 4996 ) -#endif - #include #include diff --git a/contrib/mmap/src/MapBuilder.cpp b/contrib/mmap/src/MapBuilder.cpp index 291518aeb5f..25f2f4e8d4f 100644 --- a/contrib/mmap/src/MapBuilder.cpp +++ b/contrib/mmap/src/MapBuilder.cpp @@ -24,6 +24,7 @@ #include "Maps/GridMapDefines.h" #include "DetourNavMeshBuilder.h" #include "DetourCommon.h" +#include using namespace VMAP; diff --git a/contrib/mmap/src/MapBuilder.h b/contrib/mmap/src/MapBuilder.h index 3a3ab850f69..2ff195662c5 100644 --- a/contrib/mmap/src/MapBuilder.h +++ b/contrib/mmap/src/MapBuilder.h @@ -30,7 +30,6 @@ #include "TileWorker.h" using namespace VMAP; -// G3D namespace typedefs conflicts with ACE typedefs using json = nlohmann::json; diff --git a/dep/CMakeLists.txt b/dep/CMakeLists.txt index 03e0dd8913e..0644da3d755 100644 --- a/dep/CMakeLists.txt +++ b/dep/CMakeLists.txt @@ -21,3 +21,12 @@ add_subdirectory(gsoap) if (USE_EXTRACTORS) add_subdirectory(libmpq) endif() + +add_subdirectory(cpptrace) +set_target_properties(cpptrace-lib PROPERTIES FOLDER "3rd Party") +get_cmake_property(variables CACHE_VARIABLES) # hide "CPPTRACE_*" variables in default cmake config +foreach(variable ${variables}) + if(variable MATCHES "^CPPTRACE_") + mark_as_advanced(${variable}) + endif() +endforeach() diff --git a/dep/cpptrace/CMakeLists.txt b/dep/cpptrace/CMakeLists.txt new file mode 100644 index 00000000000..2c73da03f13 --- /dev/null +++ b/dep/cpptrace/CMakeLists.txt @@ -0,0 +1,619 @@ +cmake_minimum_required(VERSION 3.12) + +include(cmake/PreventInSourceBuilds.cmake) + +# ---- Initialize Project ---- + +# Used to support find_package +set(package_name "cpptrace") + +project( + cpptrace + VERSION 0.7.1 + DESCRIPTION "Simple, portable, and self-contained stacktrace library for C++11 and newer " + HOMEPAGE_URL "https://github.com/jeremy-rifkin/cpptrace" + LANGUAGES C CXX +) + +# Don't change include order, OptionVariables checks if project is top level +include(cmake/ProjectIsTopLevel.cmake) +include(cmake/OptionVariables.cmake) + +include(GNUInstallDirs) +include(CheckCXXSourceCompiles) +include(CheckCXXCompilerFlag) + +if(PROJECT_IS_TOP_LEVEL) + find_program(CCACHE_FOUND ccache) + if(CCACHE_FOUND) + set_property(GLOBAL PROPERTY RULE_LAUNCH_COMPILE ccache) + endif() +endif() + +if(PROJECT_IS_TOP_LEVEL) + if(CMAKE_GENERATOR STREQUAL "Ninja") + if("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU") + SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fdiagnostics-color=always") + elseif("${CMAKE_CXX_COMPILER_ID}" STREQUAL "Clang" OR "${CMAKE_CXX_COMPILER_ID}" STREQUAL "AppleClang") + SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fcolor-diagnostics") + endif() + if("${CMAKE_C_COMPILER_ID}" STREQUAL "GNU") + SET(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -fdiagnostics-color=always") + elseif("${CMAKE_C_COMPILER_ID}" STREQUAL "Clang" OR "${CMAKE_CXX_COMPILER_ID}" STREQUAL "AppleClang") + SET(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -fcolor-diagnostics") + endif() + endif() +endif() + +if(CPPTRACE_SANITIZER_BUILD) + add_compile_options(-fsanitize=address) + add_link_options(-fsanitize=address) +endif() + +if(NOT "${CPPTRACE_BACKTRACE_PATH}" STREQUAL "") + # quotes used over <> because of a macro substitution issue where + # + # is expanded to + # + string(CONCAT CPPTRACE_BACKTRACE_PATH "\"" ${CPPTRACE_BACKTRACE_PATH}) + string(CONCAT CPPTRACE_BACKTRACE_PATH ${CPPTRACE_BACKTRACE_PATH} "\"") + #message(STATUS ${CPPTRACE_BACKTRACE_PATH}) + string(CONCAT CPPTRACE_BACKTRACE_PATH_DEFINITION "-DCPPTRACE_BACKTRACE_PATH=" ${CPPTRACE_BACKTRACE_PATH}) + #message(STATUS ${CPPTRACE_BACKTRACE_PATH_DEFINITION}) +else() + set(CPPTRACE_BACKTRACE_PATH_DEFINITION "") +endif() + +# =============================================== Platform Support =============================================== +function(check_support var source includes libraries definitions) + set(CMAKE_REQUIRED_INCLUDES "${includes}") + list(APPEND CMAKE_REQUIRED_INCLUDES "${CMAKE_CURRENT_SOURCE_DIR}/cmake") + set(CMAKE_REQUIRED_LIBRARIES "${libraries}") + set(CMAKE_REQUIRED_DEFINITIONS "${definitions}") + string(CONCAT full_source "#include \"${source}\"" ${nonce}) + check_cxx_source_compiles(${full_source} ${var}) + set(${var} ${${var}} PARENT_SCOPE) +endfunction() + +if(NOT CMAKE_CXX_COMPILER_ID STREQUAL "MSVC") + check_support(HAS_CXXABI has_cxxabi.cpp "" "" "") +endif() + +if(NOT WIN32) + check_support(HAS_UNWIND has_unwind.cpp "" "" "") + check_support(HAS_EXECINFO has_execinfo.cpp "" "" "") +else() + check_support(HAS_STACKWALK has_stackwalk.cpp "" "dbghelp" "") +endif() + +if(NOT WIN32 OR MINGW) + check_support(HAS_BACKTRACE has_backtrace.cpp "" "backtrace" "${CPPTRACE_BACKTRACE_PATH_DEFINITION}") + set(STACKTRACE_LINK_LIB "stdc++_libbacktrace") + check_support(HAS_CXX_EXCEPTION_TYPE has_cxx_exception_type.cpp "" "" "") +endif() + +if(UNIX AND NOT APPLE) + check_support(HAS_DL_FIND_OBJECT has_dl_find_object.cpp "" "dl" "") + if(NOT HAS_DL_FIND_OBJECT) + check_support(HAS_DLADDR1 has_dladdr1.cpp "" "dl" "") + endif() +endif() + +if(APPLE) + check_support(HAS_MACH_VM has_mach_vm.cpp "" "" "") +endif() + +# =============================================== Autoconfig unwinding =============================================== +# Unwind back-ends +if( + NOT ( + CPPTRACE_UNWIND_WITH_UNWIND OR + CPPTRACE_UNWIND_WITH_LIBUNWIND OR + CPPTRACE_UNWIND_WITH_EXECINFO OR + CPPTRACE_UNWIND_WITH_WINAPI OR + CPPTRACE_UNWIND_WITH_DBGHELP OR + CPPTRACE_UNWIND_WITH_NOTHING + ) +) + # Attempt to auto-config + if(APPLE AND ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "Clang" OR "${CMAKE_CXX_COMPILER_ID}" STREQUAL "AppleClang")) + if(HAS_EXECINFO) + set(CPPTRACE_UNWIND_WITH_EXECINFO On) + message(STATUS "Cpptrace auto config: Using execinfo.h for unwinding") + else() + set(CPPTRACE_UNWIND_WITH_NOTHING On) + message(FATAL_ERROR "Cpptrace auto config: No unwinding back-end seems to be supported, stack tracing will not work. To compile anyway set CPPTRACE_UNWIND_WITH_NOTHING.") + endif() + elseif(UNIX) + if(HAS_UNWIND) + set(CPPTRACE_UNWIND_WITH_UNWIND On) + message(STATUS "Cpptrace auto config: Using libgcc unwind for unwinding") + elseif(HAS_EXECINFO) + set(CPPTRACE_UNWIND_WITH_EXECINFO On) + message(STATUS "Cpptrace auto config: Using execinfo.h for unwinding") + else() + set(CPPTRACE_UNWIND_WITH_NOTHING On) + message(FATAL_ERROR "Cpptrace auto config: No unwinding back-end seems to be supported, stack tracing will not work. To compile anyway set CPPTRACE_UNWIND_WITH_NOTHING.") + endif() + elseif(MINGW OR WIN32) + if(HAS_STACKWALK) + set(CPPTRACE_UNWIND_WITH_DBGHELP On) + message(STATUS "Cpptrace auto config: Using dbghelp for unwinding") + else() + set(CPPTRACE_UNWIND_WITH_WINAPI On) + message(STATUS "Cpptrace auto config: Using winapi for unwinding") + endif() + endif() +else() + #message(STATUS "MANUAL CONFIG SPECIFIED") +endif() + +# =============================================== Autoconfig symbols =============================================== +if( + NOT ( + CPPTRACE_GET_SYMBOLS_WITH_LIBBACKTRACE OR + CPPTRACE_GET_SYMBOLS_WITH_LIBDL OR + CPPTRACE_GET_SYMBOLS_WITH_ADDR2LINE OR + CPPTRACE_GET_SYMBOLS_WITH_LIBDWARF OR + CPPTRACE_GET_SYMBOLS_WITH_DBGHELP OR + CPPTRACE_GET_SYMBOLS_WITH_NOTHING + ) +) + if(UNIX) + message(STATUS "Cpptrace auto config: Using libdwarf for symbols") + set(CPPTRACE_GET_SYMBOLS_WITH_LIBDWARF On) + elseif(MINGW) + message(STATUS "Cpptrace auto config: Using libdwarf + dbghelp for symbols") + # Use both dbghelp and libdwarf under mingw: Some files may use pdb symbols, e.g. system dlls like KERNEL32.dll and + # ntdll.dll at the very least, but also other libraries linked with may have pdb symbols. + set(CPPTRACE_GET_SYMBOLS_WITH_LIBDWARF On) + set(CPPTRACE_GET_SYMBOLS_WITH_DBGHELP On) + else() + message(STATUS "Cpptrace auto config: Using dbghelp for symbols") + set(CPPTRACE_GET_SYMBOLS_WITH_DBGHELP On) + endif() +endif() + +# =============================================== Autoconfig demangling =============================================== +# Handle demangle configuration +if( + NOT ( + CPPTRACE_DEMANGLE_WITH_CXXABI OR + CPPTRACE_DEMANGLE_WITH_WINAPI OR + CPPTRACE_DEMANGLE_WITH_NOTHING + ) +) + if(HAS_CXXABI) + message(STATUS "Cpptrace auto config: Using cxxabi for demangling") + set(CPPTRACE_DEMANGLE_WITH_CXXABI On) + elseif(WIN32 AND NOT MINGW) + message(STATUS "Cpptrace auto config: Using dbghelp for demangling") + set(CPPTRACE_DEMANGLE_WITH_WINAPI On) + else() + set(CPPTRACE_DEMANGLE_WITH_NOTHING On) + endif() +else() + #message(STATUS "Manual demangling back-end specified") +endif() + +# =============================================== Now define the library =============================================== + +# Target that we can modify (can't modify ALIAS targets) +# Target name should not be the same as ${PROJECT_NAME}, causes add_subdirectory issues +set(target_name "cpptrace-lib") +add_library(${target_name} ${build_type}) + +# Alias to cause error at configuration time instead of link time if target is missing +add_library(cpptrace::cpptrace ALIAS ${target_name}) + +# Add /include files to target +# This is solely for IDE benefit, doesn't affect building +target_sources( + ${target_name} PRIVATE + include/cpptrace/cpptrace.hpp + include/ctrace/ctrace.h +) + +# add /src files to target +target_sources( + ${target_name} PRIVATE + # src + src/binary/elf.cpp + src/binary/mach-o.cpp + src/binary/module_base.cpp + src/binary/object.cpp + src/binary/pe.cpp + src/binary/safe_dl.cpp + src/cpptrace.cpp + src/ctrace.cpp + src/from_current.cpp + src/demangle/demangle_with_cxxabi.cpp + src/demangle/demangle_with_nothing.cpp + src/demangle/demangle_with_winapi.cpp + src/snippets/snippet.cpp + src/symbols/dwarf/debug_map_resolver.cpp + src/symbols/dwarf/dwarf_resolver.cpp + src/symbols/symbols_core.cpp + src/symbols/symbols_with_addr2line.cpp + src/symbols/symbols_with_dbghelp.cpp + src/symbols/symbols_with_dl.cpp + src/symbols/symbols_with_libbacktrace.cpp + src/symbols/symbols_with_libdwarf.cpp + src/symbols/symbols_with_nothing.cpp + src/unwind/unwind_with_dbghelp.cpp + src/unwind/unwind_with_execinfo.cpp + src/unwind/unwind_with_libunwind.cpp + src/unwind/unwind_with_nothing.cpp + src/unwind/unwind_with_unwind.cpp + src/unwind/unwind_with_winapi.cpp +) + +target_include_directories( + ${target_name} + PUBLIC + $ + $ +) + +target_include_directories( + ${target_name} + PRIVATE + src +) + +set( + warning_options + $<$>:-Wall -Wextra -Werror=return-type -Wundef> + $<$:-Wuseless-cast -Wmaybe-uninitialized> + $<$:/W4 /permissive-> +) + +if(CPPTRACE_WERROR_BUILD) + set( + warning_options + ${warning_options} + $<$>:-Werror -Wpedantic> + $<$:/WX> + ) +endif() + +target_compile_options( + ${target_name} + PRIVATE + ${warning_options} +) + +# ---- Generate Build Info Headers ---- + +if(build_type STREQUAL "STATIC") + target_compile_definitions(${target_name} PUBLIC CPPTRACE_STATIC_DEFINE) + set(CPPTRACE_STATIC_DEFINE TRUE) +endif() + +# ---- Library Properties ---- + +# Hide all symbols by default +# Use SameMajorVersion versioning for shared library runtime linker lookup +set_target_properties( + ${target_name} PROPERTIES + CXX_VISIBILITY_PRESET hidden + VISIBILITY_INLINES_HIDDEN YES + VERSION "${PROJECT_VERSION}" + SOVERSION "${PROJECT_VERSION_MAJOR}" + EXPORT_NAME "cpptrace" + OUTPUT_NAME "cpptrace" + POSITION_INDEPENDENT_CODE ${CPPTRACE_POSITION_INDEPENDENT_CODE} +) + +# Header files generated by CMake +target_include_directories( + ${target_name} SYSTEM PUBLIC + "$" +) + +# Header files from /include +target_include_directories( + ${target_name} ${warning_guard} PUBLIC + "$" +) + +# Require C++11 support +target_compile_features( + ${target_name} + PRIVATE cxx_std_11 +) + +target_compile_definitions(${target_name} PRIVATE NOMINMAX) + +if(NOT CPPTRACE_STD_FORMAT) + target_compile_definitions(${target_name} PUBLIC CPPTRACE_NO_STD_FORMAT) +endif() + +if(CPPTRACE_UNPREFIXED_TRY_CATCH) + target_compile_definitions(${target_name} PUBLIC CPPTRACE_UNPREFIXED_TRY_CATCH) +endif() + +if(CMAKE_CXX_COMPILER_ID STREQUAL "AppleClang") + SET(CMAKE_C_ARCHIVE_FINISH " -no_warning_for_no_symbols -c ") + SET(CMAKE_CXX_ARCHIVE_FINISH " -no_warning_for_no_symbols -c ") +endif() + +# =============================================== Apply options to build =============================================== + +if(HAS_CXX_EXCEPTION_TYPE) + target_compile_definitions(${target_name} PUBLIC CPPTRACE_HAS_CXX_EXCEPTION_TYPE) +endif() + +if(HAS_DL_FIND_OBJECT) + target_compile_definitions(${target_name} PUBLIC CPPTRACE_HAS_DL_FIND_OBJECT) +endif() + +if(HAS_DLADDR1) + target_compile_definitions(${target_name} PUBLIC CPPTRACE_HAS_DLADDR1) +endif() + +if(HAS_MACH_VM) + target_compile_definitions(${target_name} PUBLIC HAS_MACH_VM) +endif() + +# Symbols +if(CPPTRACE_GET_SYMBOLS_WITH_LIBBACKTRACE) + if(NOT HAS_BACKTRACE) + if(NOT "${CPPTRACE_BACKTRACE_PATH}" STREQUAL "") + message(WARNING "Cpptrace: Using libbacktrace for symbols but libbacktrace doesn't appear installed or configured properly.") + else() + message(WARNING "Cpptrace: Using libbacktrace for symbols but libbacktrace doesn't appear installed or configured properly. You may need to specify CPPTRACE_BACKTRACE_PATH.") + endif() + endif() + target_compile_definitions(${target_name} PUBLIC CPPTRACE_GET_SYMBOLS_WITH_LIBBACKTRACE) + target_link_libraries(${target_name} PRIVATE backtrace ${CMAKE_DL_LIBS}) +endif() + +if(CPPTRACE_GET_SYMBOLS_WITH_LIBDL) + target_compile_definitions(${target_name} PUBLIC CPPTRACE_GET_SYMBOLS_WITH_LIBDL) + target_link_libraries(${target_name} PRIVATE ${CMAKE_DL_LIBS}) +endif() + +if(CPPTRACE_GET_SYMBOLS_WITH_ADDR2LINE) + # set(CPPTRACE_ADDR2LINE_PATH "" CACHE STRING "Absolute path to the addr2line executable you want to use.") + # option(CPPTRACE_ADDR2LINE_SEARCH_SYSTEM_PATH "" OFF) + if(CPPTRACE_ADDR2LINE_SEARCH_SYSTEM_PATH) + target_compile_definitions(${target_name} PUBLIC CPPTRACE_ADDR2LINE_SEARCH_SYSTEM_PATH) + else() + if("${CPPTRACE_ADDR2LINE_PATH}" STREQUAL "") + if(APPLE) + find_program(CPPTRACE_ADDR2LINE_PATH_FINAL atos PATHS ENV PATH REQUIRED) + else() + find_program(CPPTRACE_ADDR2LINE_PATH_FINAL addr2line PATHS ENV PATH REQUIRED) + endif() + else() + set(CPPTRACE_ADDR2LINE_PATH_FINAL "${CPPTRACE_ADDR2LINE_PATH}") + endif() + message(STATUS "Cpptrace: Using ${CPPTRACE_ADDR2LINE_PATH_FINAL} for addr2line path") + target_compile_definitions(${target_name} PUBLIC CPPTRACE_ADDR2LINE_PATH="${CPPTRACE_ADDR2LINE_PATH_FINAL}") + endif() + target_compile_definitions(${target_name} PUBLIC CPPTRACE_GET_SYMBOLS_WITH_ADDR2LINE) + if(UNIX) + target_link_libraries(${target_name} PRIVATE ${CMAKE_DL_LIBS}) + endif() +endif() + +if(CPPTRACE_GET_SYMBOLS_WITH_LIBDWARF) + target_compile_definitions(${target_name} PUBLIC CPPTRACE_GET_SYMBOLS_WITH_LIBDWARF) + if(CPPTRACE_USE_EXTERNAL_LIBDWARF) + if(NOT CPPTRACE_FIND_LIBDWARF_WITH_PKGCONFIG) + find_package(libdwarf REQUIRED) + else() + find_package(PkgConfig) + pkg_check_modules(LIBDWARF REQUIRED libdwarf) + endif() + else() + include(FetchContent) + # First, dependencies: Zstd and zlib (currently relying on system zlib) + if(CPPTRACE_USE_EXTERNAL_ZSTD) + find_package(zstd) + else() + cmake_policy(SET CMP0074 NEW) + set(ZSTD_BUILD_PROGRAMS OFF) + set(ZSTD_BUILD_CONTRIB OFF) + set(ZSTD_BUILD_TESTS OFF) + set(ZSTD_BUILD_STATIC ON) + set(ZSTD_BUILD_SHARED OFF) + set(ZSTD_LEGACY_SUPPORT OFF) + FetchContent_Declare( + zstd + SOURCE_SUBDIR build/cmake + DOWNLOAD_EXTRACT_TIMESTAMP TRUE + URL "${CPPTRACE_ZSTD_URL}" + ) + FetchContent_MakeAvailable(zstd) + endif() + # Libdwarf itself + set(CMAKE_POLICY_DEFAULT_CMP0077 NEW) + set(PIC_ALWAYS TRUE) + set(BUILD_DWARFDUMP FALSE) + FetchContent_Declare( + libdwarf + GIT_REPOSITORY ${CPPTRACE_LIBDWARF_REPO} + GIT_TAG ${CPPTRACE_LIBDWARF_TAG} + GIT_SHALLOW ${CPPTRACE_LIBDWARF_SHALLOW} + ) + FetchContent_MakeAvailable(libdwarf) + target_include_directories( + dwarf + PRIVATE + ${zstd_SOURCE_DIR}/lib + ) + endif() + if(CPPTRACE_CONAN) + target_link_libraries(${target_name} PRIVATE libdwarf::libdwarf) + elseif(CPPTRACE_VCPKG) + target_link_libraries(${target_name} PRIVATE libdwarf::dwarf) + elseif(CPPTRACE_USE_EXTERNAL_LIBDWARF) + if(DEFINED LIBDWARF_LIBRARIES) + target_link_libraries(${target_name} PRIVATE ${LIBDWARF_LIBRARIES}) + else() + # if LIBDWARF_LIBRARIES wasn't set by find_package, try looking for libdwarf::dwarf-static, + # libdwarf::dwarf-shared, libdwarf::dwarf, then libdwarf + # libdwarf v0.8.0 installs with the target libdwarf::dwarf somehow, despite creating libdwarf::dwarf-static or + # libdwarf::dwarf-shared under fetchcontent + if(TARGET libdwarf::dwarf-static) + set(LIBDWARF_LIBRARIES libdwarf::dwarf-static) + elseif(TARGET libdwarf::dwarf-shared) + set(LIBDWARF_LIBRARIES libdwarf::dwarf-shared) + elseif(TARGET libdwarf::dwarf) + set(LIBDWARF_LIBRARIES libdwarf::dwarf) + elseif(TARGET libdwarf) + set(LIBDWARF_LIBRARIES libdwarf) + else() + message(FATAL_ERROR "Couldn't find libdwarf target name to link against") + endif() + target_link_libraries(${target_name} PRIVATE ${LIBDWARF_LIBRARIES}) + endif() + # There seems to be no consistency at all about where libdwarf decides to place its headers........ Figure out if + # it's libdwarf/libdwarf.h and libdwarf/dwarf.h or just libdwarf.h and dwarf.h + include(CheckIncludeFileCXX) + # libdwarf's cmake doesn't properly set variables to indicate where its libraries live + if(NOT CPPTRACE_FIND_LIBDWARF_WITH_PKGCONFIG) + get_target_property(LIBDWARF_INCLUDE_DIRS ${LIBDWARF_LIBRARIES} INTERFACE_INCLUDE_DIRECTORIES) + else() + target_include_directories(${target_name} PRIVATE ${LIBDWARF_INCLUDE_DIRS}) + endif() + set(CMAKE_REQUIRED_INCLUDES ${LIBDWARF_INCLUDE_DIRS}) + CHECK_INCLUDE_FILE_CXX("libdwarf/libdwarf.h" LIBDWARF_IS_NESTED) + CHECK_INCLUDE_FILE_CXX("libdwarf.h" LIBDWARF_IS_NOT_NESTED) + # check_include_file("libdwarf/libdwarf.h" LIBDWARF_IS_NESTED) + # check_support(LIBDWARF_IS_NESTED nested_libdwarf_include.cpp "" "" "") + if(${LIBDWARF_IS_NESTED}) + target_compile_definitions(${target_name} PRIVATE CPPTRACE_USE_NESTED_LIBDWARF_HEADER_PATH) + elseif(NOT LIBDWARF_IS_NOT_NESTED) + message(FATAL_ERROR "Couldn't find libdwarf.h") + endif() + else() + target_link_libraries(${target_name} PRIVATE libdwarf::dwarf-static) + endif() + if(UNIX) + target_link_libraries(${target_name} PRIVATE ${CMAKE_DL_LIBS}) + endif() +endif() + +if(CPPTRACE_GET_SYMBOLS_WITH_DBGHELP) + target_compile_definitions(${target_name} PUBLIC CPPTRACE_GET_SYMBOLS_WITH_DBGHELP) + target_link_libraries(${target_name} PRIVATE dbghelp) +endif() + +if(CPPTRACE_GET_SYMBOLS_WITH_NOTHING) + target_compile_definitions(${target_name} PUBLIC CPPTRACE_GET_SYMBOLS_WITH_NOTHING) +endif() + +# Unwinding +if(CPPTRACE_UNWIND_WITH_UNWIND) + if(NOT HAS_UNWIND) + message(WARNING "Cpptrace: CPPTRACE_UNWIND_WITH_UNWIND specified but libgcc unwind doesn't seem to be available.") + endif() + target_compile_definitions(${target_name} PUBLIC CPPTRACE_UNWIND_WITH_UNWIND) +endif() + +if(CPPTRACE_UNWIND_WITH_LIBUNWIND) + find_package(PkgConfig) + if(PkgConfig_FOUND) + pkg_check_modules(LIBUNWIND QUIET libunwind) + if(libunwind_FOUND) + target_compile_options(${target_name} PRIVATE ${LIBUNWIND_CFLAGS_OTHER}) + target_include_directories(${target_name} PRIVATE ${LIBUNWIND_INCLUDE_DIRS}) + target_link_libraries(${target_name} PRIVATE ${LIBUNWIND_LDFLAGS}) + endif() + endif() + if(NOT libunwind_FOUND) + if (NOT APPLE) + # set_property(GLOBAL PROPERTY FIND_LIBRARY_USE_LIB64_PATHS ON) + # set_property(GLOBAL PROPERTY FIND_LIBRARY_USE_LIB32_PATHS ON) + find_path(LIBUNWIND_INCLUDE_DIRS NAMES "libunwind.h") + find_library(LIBUNWIND NAMES unwind libunwind libunwind8 libunwind.so.8 REQUIRED PATHS "/usr/lib/x86_64-linux-gnu/") + if(LIBUNWIND) + set(libunwind_FOUND TRUE) + endif() + if(NOT libunwind_FOUND) + # message(FATAL_ERROR "Unable to locate libunwind") + # Try to link with it if it's where it should be + # This path can be entered if libunwind was installed via the system package manager, sometimes. I probably messed + # up the find_library above. + set(LIBUNWIND_LDFLAGS "-lunwind") + endif() + if(NOT LIBUNWIND_LDFLAGS) + set(LIBUNWIND_LDFLAGS "${LIBUNWIND}") + endif() + target_compile_options(${target_name} PRIVATE ${LIBUNWIND_CFLAGS_OTHER}) + target_include_directories(${target_name} PRIVATE ${LIBUNWIND_INCLUDE_DIRS}) + target_link_libraries(${target_name} PRIVATE ${LIBUNWIND_LDFLAGS}) + endif() + target_compile_definitions(${target_name} PUBLIC CPPTRACE_UNWIND_WITH_LIBUNWIND UNW_LOCAL_ONLY) + endif() +endif() + +if(CPPTRACE_UNWIND_WITH_EXECINFO) + if(NOT HAS_EXECINFO) + message(WARNING "Cpptrace: CPPTRACE_UNWIND_WITH_EXECINFO specified but execinfo.h doesn't seem to be available.") + endif() + target_compile_definitions(${target_name} PUBLIC CPPTRACE_UNWIND_WITH_EXECINFO) +endif() + +if(CPPTRACE_UNWIND_WITH_WINAPI) + target_compile_definitions(${target_name} PUBLIC CPPTRACE_UNWIND_WITH_WINAPI) +endif() + +if(CPPTRACE_UNWIND_WITH_DBGHELP) + if(NOT HAS_STACKWALK) + message(WARNING "Cpptrace: CPPTRACE_UNWIND_WITH_DBGHELP specified but dbghelp stackwalk64 doesn't seem to be available.") + endif() + target_compile_definitions(${target_name} PUBLIC CPPTRACE_UNWIND_WITH_DBGHELP) + target_link_libraries(${target_name} PRIVATE dbghelp) +endif() + +if(CPPTRACE_UNWIND_WITH_NOTHING) + target_compile_definitions(${target_name} PUBLIC CPPTRACE_UNWIND_WITH_NOTHING) +endif() + +# Demangling +if(CPPTRACE_DEMANGLE_WITH_CXXABI) + if(NOT HAS_CXXABI) + message(WARNING "Cpptrace: CPPTRACE_DEMANGLE_WITH_CXXABI specified but cxxabi.h doesn't seem to be available.") + endif() + target_compile_definitions(${target_name} PUBLIC CPPTRACE_DEMANGLE_WITH_CXXABI) +endif() + +if(CPPTRACE_DEMANGLE_WITH_WINAPI) + target_compile_definitions(${target_name} PUBLIC CPPTRACE_DEMANGLE_WITH_WINAPI) + target_link_libraries(${target_name} PRIVATE dbghelp) +endif() + +if(CPPTRACE_DEMANGLE_WITH_NOTHING) + target_compile_definitions(${target_name} PUBLIC CPPTRACE_DEMANGLE_WITH_NOTHING) +endif() + +if(NOT "${CPPTRACE_BACKTRACE_PATH}" STREQUAL "") + target_compile_definitions(${target_name} PUBLIC CPPTRACE_BACKTRACE_PATH=${CPPTRACE_BACKTRACE_PATH}) +endif() + +if(NOT "${CPPTRACE_HARD_MAX_FRAMES}" STREQUAL "") + target_compile_definitions(${target_name} PUBLIC CPPTRACE_HARD_MAX_FRAMES=${CPPTRACE_HARD_MAX_FRAMES}) +endif() + +# =============================================== Install =============================================== + +if(NOT CMAKE_SKIP_INSTALL_RULES) + include(cmake/InstallRules.cmake) +endif() + +# =============================================== Demo/test =============================================== + +if(CPPTRACE_BUILD_TESTING) + if(PROJECT_IS_TOP_LEVEL) + enable_testing() + endif() + add_subdirectory(test) +endif() + +if(CPPTRACE_BUILD_BENCHMARKING) + add_subdirectory(benchmarking) +endif() diff --git a/dep/cpptrace/LICENSE b/dep/cpptrace/LICENSE new file mode 100644 index 00000000000..299bf1fac8c --- /dev/null +++ b/dep/cpptrace/LICENSE @@ -0,0 +1,18 @@ +The MIT License (MIT) + +Copyright (c) 2023-2024 Jeremy Rifkin + +Permission is hereby granted, free of charge, to any person obtaining a copy of this software and +associated documentation files (the "Software"), to deal in the Software without restriction, +including without limitation the rights to use, copy, modify, merge, publish, distribute, +sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all copies or substantial +portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT +NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES +OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/dep/cpptrace/README.md b/dep/cpptrace/README.md new file mode 100644 index 00000000000..191f2ee946a --- /dev/null +++ b/dep/cpptrace/README.md @@ -0,0 +1,43 @@ +# jeremy-rifkin/cpptrace + +> Cpptrace is a simple, portable, and self-contained C++ stacktrace library supporting C++11 and greater on Linux, macOS, and Windows including MinGW and Cygwin environments. The goal: Make stack traces simple for once. + +It is used in MaNGOS to print the stack trace on failure. + +## Source +Commit: https://github.com/jeremy-rifkin/cpptrace/commit/54a3e6fdf7837c44d20436c77d3469f4524bf6a1 +Date: 2024-09-17T13:25:23Z + +## Copied files +``` +cmake/* +include/* +src/* +CMakeLists.txt +LICENSE +README.md -> README_original.md +``` + +## Manual changes +### Changed `cmake_minimum_required` to `3.12`. +```diff +--- a/dep/cpptrace/CMakeLists.txt ++++ b/dep/cpptrace/CMakeLists.txt +@@ -1,4 +1,4 @@ +-cmake_minimum_required(VERSION 3.14) ++cmake_minimum_required(VERSION 3.12) +``` + +### Adjusted InstallRules for older cmake version +```diff +--- a/dep/cpptrace/cmake/InstallRules.cmake ++++ b/dep/cpptrace/cmake/InstallRules.cmake +@@ -25,6 +25,7 @@ install( + COMPONENT ${package_name}-development + INCLUDES # + DESTINATION "${CMAKE_INSTALL_INCLUDEDIR}" ++ ARCHIVE DESTINATION "${CMAKE_INSTALL_INCLUDEDIR}" + ) + + # create config file that points to targets file +``` diff --git a/dep/cpptrace/README_original.md b/dep/cpptrace/README_original.md new file mode 100644 index 00000000000..ebfe37e2c92 --- /dev/null +++ b/dep/cpptrace/README_original.md @@ -0,0 +1,1197 @@ +# Cpptrace + +[![build](https://github.com/jeremy-rifkin/cpptrace/actions/workflows/build.yml/badge.svg?branch=main)](https://github.com/jeremy-rifkin/cpptrace/actions/workflows/build.yml) +[![test](https://github.com/jeremy-rifkin/cpptrace/actions/workflows/test.yml/badge.svg?branch=main)](https://github.com/jeremy-rifkin/cpptrace/actions/workflows/test.yml) +[![Quality Gate Status](https://sonarcloud.io/api/project_badges/measure?project=jeremy-rifkin_cpptrace&metric=alert_status)](https://sonarcloud.io/summary/new_code?id=jeremy-rifkin_cpptrace) +
+[![Community Discord Link](https://img.shields.io/badge/Chat%20on%20the%20(very%20small)-Community%20Discord-blue?labelColor=2C3239&color=7289DA&style=flat&logo=discord&logoColor=959DA5)](https://discord.gg/frjaAZvqUZ) +
+[![Try on Compiler Explorer](https://img.shields.io/badge/-Compiler%20Explorer-brightgreen?logo=&labelColor=2C3239&style=flat&label=Try+it+on&color=30C452)](https://godbolt.org/z/c6TqTzqcf) + +Cpptrace is a simple, portable, and self-contained C++ stacktrace library supporting C++11 and greater on Linux, macOS, +and Windows including MinGW and Cygwin environments. The goal: Make stack traces simple for once. + +Cpptrace also has a C API, docs [here](docs/c-api.md). + +## Table of Contents + +- [30-Second Overview](#30-second-overview) + - [CMake FetchContent Usage](#cmake-fetchcontent-usage) +- [FAQ](#faq) + - [What about C++23 ``?](#what-about-c23-stacktrace) + - [What does cpptrace have over other C++ stacktrace libraries?](#what-does-cpptrace-have-over-other-c-stacktrace-libraries) +- [In-Depth Documentation](#in-depth-documentation) + - [Prerequisites](#prerequisites) + - [`namespace cpptrace`](#namespace-cpptrace) + - [Stack Traces](#stack-traces) + - [Object Traces](#object-traces) + - [Raw Traces](#raw-traces) + - [Utilities](#utilities) + - [Configuration](#configuration) + - [Traces From All Exceptions](#traces-from-all-exceptions) + - [Removing the `CPPTRACE_` prefix](#removing-the-cpptrace_-prefix) + - [How it works](#how-it-works) + - [Performance](#performance) + - [Traced Exception Objects](#traced-exception-objects) + - [Wrapping std::exceptions](#wrapping-stdexceptions) + - [Exception handling with cpptrace](#exception-handling-with-cpptrace) + - [Signal-Safe Tracing](#signal-safe-tracing) + - [Utility Types](#utility-types) +- [Supported Debug Formats](#supported-debug-formats) +- [Usage](#usage) + - [CMake FetchContent](#cmake-fetchcontent) + - [System-Wide Installation](#system-wide-installation) + - [Local User Installation](#local-user-installation) + - [Use Without CMake](#use-without-cmake) + - [Installation Without Package Managers or FetchContent](#installation-without-package-managers-or-fetchcontent) + - [Package Managers](#package-managers) + - [Conan](#conan) + - [Vcpkg](#vcpkg) +- [Platform Logistics](#platform-logistics) + - [Windows](#windows) + - [macOS](#macos) +- [Library Back-Ends](#library-back-ends) + - [Summary of Library Configurations](#summary-of-library-configurations) +- [Testing Methodology](#testing-methodology) +- [Notes About the Library](#notes-about-the-library) +- [Contributing](#contributing) +- [License](#license) + +# 30-Second Overview + +Generating stack traces is as easy as: + +```cpp +#include + +void trace() { + cpptrace::generate_trace().print(); +} +``` + +![Demo](res/demo.png) + +Cpptrace can also retrieve function inlining information on optimized release builds: + +![Inlining](res/inlining.png) + +Cpptrace provides access to resolved stack traces as well as lightweight raw traces (just addresses) that can be +resolved later: + +```cpp +const auto raw_trace = cpptrace::generate_raw_trace(); +// then later +raw_trace.resolve().print(); +``` + +Cpptrace provides a way to produce stack traces on arbitrary exceptions. More information on this system +[below](#traces-from-all-exceptions). +```cpp +#include +void foo() { + throw std::runtime_error("foo failed"); +} +int main() { + CPPTRACE_TRY { + foo(); + } CPPTRACE_CATCH(const std::exception& e) { + std::cerr<<"Exception: "< + +void trace() { + throw cpptrace::logic_error("This wasn't supposed to happen!"); +} +``` + +![Exception](res/exception.png) + +Additional notable features: + +- Utilities for demangling +- Utilities for catching `std::exception`s and wrapping them in traced exceptions +- Signal-safe stack tracing +- Source code snippets in traces + +![Snippets](res/snippets.png) + +## CMake FetchContent Usage + +```cmake +include(FetchContent) +FetchContent_Declare( + cpptrace + GIT_REPOSITORY https://github.com/jeremy-rifkin/cpptrace.git + GIT_TAG v0.7.1 # +) +FetchContent_MakeAvailable(cpptrace) +target_link_libraries(your_target cpptrace::cpptrace) + +# Needed for shared library builds on windows: copy cpptrace.dll to the same directory as the +# executable for your_target +if(WIN32) + add_custom_command( + TARGET your_target POST_BUILD + COMMAND ${CMAKE_COMMAND} -E copy_if_different + $ + $ + ) +endif() +``` + +Be sure to configure with `-DCMAKE_BUILD_TYPE=Debug` or `-DCMAKE_BUILD_TYPE=RelWithDebInfo` for symbols and line +information. + +On macOS it is recommended to generate a `.dSYM` file, see [Platform Logistics](#platform-logistics) below. + +For other ways to use the library, such as through package managers, a system-wide installation, or on a platform +without internet access see [Usage](#usage) below. + +# FAQ + +## What about C++23 ``? + +Some day C++23's `` will be ubiquitous. And maybe one day the msvc implementation will be acceptable. +The original motivation for cpptrace was to support projects using older C++ standards and as the library has grown its +functionality has extended beyond the standard library's implementation. + +Cpptrace provides functionality beyond what the standard library provides and what implementations provide, such as: +- Walking inlined function calls +- Providing a lightweight interface for "raw traces" +- Resolving function parameter types +- Providing traced exception objects +- Providing an API for signal-safe stacktrace generation +- Providing a way to retrieve stack traces from arbitrary exceptions, not just special cpptrace traced exception + objects. This is a feature coming to C++26, but cpptrace provides a solution for C++11. + +## What does cpptrace have over other C++ stacktrace libraries? + +Other C++ stacktrace libraries, such as boost stacktrace and backward-cpp, fall short when it comes to portability and +ease of use. In testing, I found neither to provide adaquate coverage of various environments. Even when they can be +made to work in an environment they require manual configuration from the end-user, possibly requiring manual +installation of third-party dependencies. This is a highly undesirable burden to impose on users, especially when it is +for a software package which just provides diagnostics as opposed to core functionality. Additionally, cpptrace provides +support for resolving inlined calls by default for DWARF symbols (boost does not do this, backward-cpp can do this but +only for some back-ends), better support for resolving full function signatures, and nicer API, among other features. + +# In-Depth Documentation + +## Prerequisites + +> [!IMPORTANT] +> Debug info (`-g`/`/Z7`/`/Zi`/`/DEBUG`/`-DBUILD_TYPE=Debug`/`-DBUILD_TYPE=RelWithDebInfo`) is required for complete +> trace information. + +## `namespace cpptrace` + +`cpptrace::generate_trace()` can be used to generate a stacktrace object at the current call site. Resolved frames can +be accessed from this object with `.frames` and also the trace can be printed with `.print()`. Cpptrace also provides a +method to get lightweight raw traces, which are just vectors of program counters, which can be resolved at a later time. + +All functions are thread-safe unless otherwise noted. + +### Stack Traces + +The core resolved stack trace object. Generate a trace with `cpptrace::generate_trace()` or +`cpptrace::stacktrace::current()`. On top of a set of helper functions `struct stacktrace` allows +direct access to frames as well as iterators. + +`cpptrace::stacktrace::print` can be used to print a stacktrace. `cpptrace::stacktrace::print_with_snippets` can be used +to print a stack trace with source code snippets. + +```cpp +namespace cpptrace { + // Some type sufficient for an instruction pointer, currently always an alias to std::uintptr_t + using frame_ptr = std::uintptr_t; + + struct stacktrace_frame { + frame_ptr raw_address; // address in memory + frame_ptr object_address; // address in the object file + // nullable represents a nullable integer. More docs later. + nullable line; + nullable column; + std::string filename; + std::string symbol; + bool is_inline; + bool operator==(const stacktrace_frame& other) const; + bool operator!=(const stacktrace_frame& other) const; + object_frame get_object_info() const; // object_address is stored but if the object_path is needed this can be used + std::string to_string() const; + /* operator<<(ostream, ..) and std::format support exist for this object */ + }; + + struct stacktrace { + std::vector frames; + // here as a drop-in for std::stacktrace + static stacktrace current(std::size_t skip = 0); + static stacktrace current(std::size_t skip, std::size_t max_depth); + void print() const; + void print(std::ostream& stream) const; + void print(std::ostream& stream, bool color) const; + void print_with_snippets() const; + void print_with_snippets(std::ostream& stream) const; + void print_with_snippets(std::ostream& stream, bool color) const; + std::string to_string(bool color = false) const; + void clear(); + bool empty() const noexcept; + /* operator<<(ostream, ..), std::format support, and iterators exist for this object */ + }; + + stacktrace generate_trace(std::size_t skip = 0); + stacktrace generate_trace(std::size_t skip, std::size_t max_depth); +} +``` + +### Object Traces + +Object traces contain the most basic information needed to construct a stack trace outside the currently running +executable. It contains the raw address, the address in the binary (ASLR and the object file's memory space and whatnot +is resolved), and the path to the object the instruction pointer is located in. + +```cpp +namespace cpptrace { + struct object_frame { + std::string object_path; + frame_ptr raw_address; + frame_ptr object_address; + }; + + struct object_trace { + std::vector frames; + static object_trace current(std::size_t skip = 0); + static object_trace current(std::size_t skip, std::size_t max_depth); + stacktrace resolve() const; + void clear(); + bool empty() const noexcept; + /* iterators exist for this object */ + }; + + object_trace generate_object_trace(std::size_t skip = 0); + object_trace generate_object_trace(std::size_t skip, std::size_t max_depth); +} +``` + +### Raw Traces + +Raw trace access: A vector of program counters. These are ideal for fast and cheap traces you want to resolve later. + +Note it is important executables and shared libraries in memory aren't somehow unmapped otherwise libdl calls (and +`GetModuleFileName` in windows) will fail to figure out where the program counter corresponds to. + +```cpp +namespace cpptrace { + struct raw_trace { + std::vector frames; + static raw_trace current(std::size_t skip = 0); + static raw_trace current(std::size_t skip, std::size_t max_depth); + object_trace resolve_object_trace() const; + stacktrace resolve() const; + void clear(); + bool empty() const noexcept; + /* iterators exist for this object */ + }; + + raw_trace generate_raw_trace(std::size_t skip = 0); + raw_trace generate_raw_trace(std::size_t skip, std::size_t max_depth); +} +``` + +### Utilities + +`cpptrace::demangle` provides a helper function for name demangling, since it has to implement that helper internally +anyways. + +`cpptrace::get_snippet` gets a text snippet, if possible, from for the given source file for +/- `context_size` lines +around `line`. + +`cpptrace::isatty` and the fileno definitions are useful for deciding whether to use color when printing stack traces. + +`cpptrace::register_terminate_handler()` is a helper function to set a custom `std::terminate` handler that prints a +stack trace from a cpptrace exception (more info below) and otherwise behaves like the normal terminate handler. + +```cpp +namespace cpptrace { + std::string demangle(const std::string& name); + std::string get_snippet( + const std::string& path, + std::size_t line, + std::size_t context_size, + bool color = false + ); + bool isatty(int fd); + + extern const int stdin_fileno; + extern const int stderr_fileno; + extern const int stdout_fileno; + + void register_terminate_handler(); +} +``` + +### Configuration + +`cpptrace::absorb_trace_exceptions`: Configure whether the library silently absorbs internal exceptions and continues. +Default is true. + +`cpptrace::ctrace_enable_inlined_call_resolution`: Configure whether the library will attempt to resolve inlined call +information for release builds. Default is true. + +`cpptrace::experimental::set_cache_mode`: Control time-memory tradeoffs within the library. By default speed is +prioritized. If using this function, set the cache mode at the very start of your program before any traces are +performed. + +```cpp +namespace cpptrace { + void absorb_trace_exceptions(bool absorb); + void ctrace_enable_inlined_call_resolution(bool enable); + + enum class cache_mode { + // Only minimal lookup tables + prioritize_memory, + // Build lookup tables but don't keep them around between trace calls + hybrid, + // Build lookup tables as needed + prioritize_speed + }; + + namespace experimental { + void set_cache_mode(cache_mode mode); + } +} +``` + +### Traces From All Exceptions + +Cpptrace provides `CPPTRACE_TRY` and `CPPTRACE_CATCH` macros that allow a stack trace to be collected from the current +thrown exception object, with minimal or no overhead in the non-throwing path: + +```cpp +#include +void foo() { + throw std::runtime_error("foo failed"); +} +int main() { + CPPTRACE_TRY { + foo(); + } CPPTRACE_CATCH(const std::exception& e) { + std::cerr<<"Exception: "<`. + +Any declarator `catch` accepts works with `CPPTRACE_CATCH`, including `...`. This works with any thrown object, not just +`std::exceptions`, it even works with `throw 0;` + +![from_current](res/from_current.png) + +There are a few extraneous frames at the top of the stack corresponding to standard library exception handling +internals. These are a small price to pay for stack traces on all exceptions. + +API functions: +- `cpptrace::raw_trace_from_current_exception`: Returns `const raw_trace&` from the current exception. +- `cpptrace::from_current_exception`: Returns a resolved `const stacktrace&` from the current exception. Invalidates + references to traces returned by `cpptrace::raw_trace_from_current_exception`. + +There is a performance tradeoff with this functionality: Either the try-block can be zero overhead in the +non-throwing path with potential expense in the throwing path, or the try-block can have very minimal overhead +in the non-throwing path due to bookkeeping with guarantees about the expense of the throwing path. More details on +this tradeoff [below](#performance). Cpptrace provides macros for both sides of this tradeoff: +- `CPPTRACE_TRY`/`CPPTRACE_CATCH`: Minimal overhead in the non-throwing path (one `mov` on x86, and this may be + optimized out if the compiler is able) +- `CPPTRACE_TRYZ`/`CPPTRACE_CATCHZ`: Zero overhead in the non-throwing path, potential extra cost in the throwing path + +Note: It's important to not mix the `Z` variants with the non-`Z` variants. + +Unfortunately the try/catch macros are needed to insert some magic to perform a trace during the unwinding search phase. +In order to have multiple catch alternatives, either `CPPTRACE_CATCH_ALT` or a normal `catch` must be used: +```cpp +CPPTRACE_TRY { + foo(); +} CPPTRACE_CATCH(const std::exception&) { // <- First catch must be CPPTRACE_CATCH + // ... +} CPPTRACE_CATCH_ALT(const std::exception&) { // <- Ok + // ... +} catch(const std::exception&) { // <- Also Ok + // ... +} CPPTRACE_CATCH(const std::exception&) { // <- Not Ok + // ... +} +``` + +Note: The current exception is the exception most recently seen by a cpptrace try-catch macro block. + +```cpp +CPPTRACE_TRY { + throw std::runtime_error("foo"); +} CPPTRACE_CATCH(const std::exception& e) { + cpptrace::from_current_exception().print(); // the trace for std::runtime_error("foo") + CPPTRACE_TRY { + throw std::runtime_error("bar"); + } CPPTRACE_CATCH(const std::exception& e) { + cpptrace::from_current_exception().print(); // the trace for std::runtime_error("bar") + } + cpptrace::from_current_exception().print(); // the trace for std::runtime_error("bar"), again +} +``` + +#### Removing the `CPPTRACE_` prefix + +`CPPTRACE_TRY` is a little cumbersome to type. To remove the `CPPTRACE_` prefix you can use the +`CPPTRACE_UNPREFIXED_TRY_CATCH` cmake option or the `CPPTRACE_UNPREFIXED_TRY_CATCH` preprocessor definition: + +```cpp +TRY { + foo(); +} CATCH(const std::exception& e) { + std::cerr<<"Exception: "< [!TIP] +> The choice between the `Z` and non-`Z` (zero-overhead and non-zero-overhead) variants of the exception handlers should +> not matter 99% of the time, however, both are provided in the rare case that it does. +> +> `CPPTRACE_TRY`/`CPPTRACE_CATCH` could only hurt performance if used in a hot loop where the compiler can't optimize +> away the internal bookkeeping, otherwise the bookkeeping should be completely negligible. +> +> `CPPTRACE_TRYZ`/`CPPTRACE_CATCHZ` could only hurt performance when there is an exceptionally deep nesting of exception +> handlers in a call stack before a matching handler. + +More information on performance considerations with the zero-overhead variant: + +Tracing the stack multiple times in throwing paths should not matter for the vast majority applications given that: +1. Performance very rarely is critical in throwing paths and exceptions should be exceptionally rare +2. Exception handling is not usually used in such a way that you could have a deep nesting of handlers before finding a + matching handler +3. Most call stacks are fairly shallow + +To put the scale of this performance consideration into perspective: In my benchmarking I have found generation of raw +traces to take on the order of `100ns` per frame. Thus, even if there were 100 non-matching handlers before a matching +handler in a 100-deep call stack the total time would stil be on the order of one millisecond. + +Nonetheless, I chose a default bookkeeping behavior for `CPPTRACE_TRY`/`CPPTRACE_CATCH` since it is safer with better +performance guarantees for the most general possible set of users. + +### Traced Exception Objects + +Cpptrace provides a handful of traced exception classes which automatically collect stack traces when thrown. These +are useful when throwing exceptions that may not be caught by `CPPTRACE_CATCH`. + +The base traced exception class is `cpptrace::exception` and cpptrace provides a handful of helper classes for working +with traced exceptions. These exceptions generate relatively lightweight raw traces and resolve symbols and line numbers +lazily if and when requested. + +These are provided both as a useful utility and as a reference implementation for traced exceptions. + +The basic interface is: +```cpp +namespace cpptrace { + class exception : public std::exception { + public: + virtual const char* what() const noexcept = 0; // The what string both the message and trace + virtual const char* message() const noexcept = 0; + virtual const stacktrace& trace() const noexcept = 0; + }; +} +``` + +There are two ways to go about traced exception objects: Traces can be resolved eagerly or lazily. Cpptrace provides the +basic implementation of exceptions as lazy exceptions. I hate to have anything about the implementation exposed in the +interface or type system but this seems to be the best way to do this. + +```cpp +namespace cpptrace { + class lazy_exception : public exception { + // lazy_trace_holder is basically a std::variant, more docs later + mutable detail::lazy_trace_holder trace_holder; + mutable std::string what_string; + public: + explicit lazy_exception( + raw_trace&& trace = detail::get_raw_trace_and_absorb() + ) noexcept : trace_holder(std::move(trace)) {} + const char* what() const noexcept override; + const char* message() const noexcept override; + const stacktrace& trace() const noexcept override; + }; +} +``` + +`cpptrace::lazy_exception` can be freely thrown or overridden. Generally `message()` is the only field to override. + +Lastly cpptrace provides an exception class that takes a user-provided message, `cpptrace::exception_with_message`, as +well as a number of traced exception classes resembling ``: + +```cpp +namespace cpptrace { + class exception_with_message : public lazy_exception { + mutable std::string user_message; + public: + explicit exception_with_message( + std::string&& message_arg, + raw_trace&& trace = detail::get_raw_trace_and_absorb() + ) noexcept : lazy_exception(std::move(trace)), user_message(std::move(message_arg)) {} + const char* message() const noexcept override; + }; + + // All stdexcept errors have analogs here. All but system_error have the constructor: + // explicit the_error( + // std::string&& message_arg, + // raw_trace&& trace = detail::get_raw_trace_and_absorb() + // ) noexcept + // : exception_with_message(std::move(message_arg), std::move(trace)) {} + class logic_error : public exception_with_message { ... }; + class domain_error : public exception_with_message { ... }; + class invalid_argument : public exception_with_message { ... }; + class length_error : public exception_with_message { ... }; + class out_of_range : public exception_with_message { ... }; + class runtime_error : public exception_with_message { ... }; + class range_error : public exception_with_message { ... }; + class overflow_error : public exception_with_message { ... }; + class underflow_error : public exception_with_message { ... }; + class system_error : public runtime_error { + public: + explicit system_error( + int error_code, + std::string&& message_arg, + raw_trace&& trace = detail::get_raw_trace_and_absorb() + ) noexcept; + const std::error_code& code() const noexcept; + }; +} +``` + +## Wrapping std::exceptions + +> [!NOTE] +> This section is largely obsolete now that cpptrace provides a better mechanism for collecting +> [traces from exceptions](#traces-from-exceptions) + +Cpptrace exceptions can provide great information for user-controlled exceptions. For non-cpptrace::exceptions that may +originate outside of code you control, e.g. the standard library, cpptrace provides some wrapper utilities that can +rethrow these exceptions nested in traced cpptrace exceptions. The trace won't be perfect, the trace will start where +the wrapper caught it, but these utilities can provide good diagnostic information. Unfortunately this is the best +solution for this problem, as far as I know. + +```cpp +std::vector foo = {1, 2, 3}; +CPPTRACE_WRAP_BLOCK( + foo.at(4) = 2; + foo.at(5)++; +); +std::cout< [!NOTE] +> This section pertains to cpptrace traced exception objects and not the mechanism for collecting +> [traces from arbitrary exceptions](#traces-from-exceptions) + +Working with cpptrace exceptions in your code: +```cpp +try { + foo(); +} catch(cpptrace::exception& e) { + // Prints the exception info and stack trace, conditionally enabling color codes depending on + // whether stderr is a terminal + std::cerr << "Error: " << e.message() << '\n'; + e.trace().print(std::cerr, cpptrace::isatty(cpptrace::stderr_fileno)); +} catch(std::exception& e) { + std::cerr << "Error: " << e.what() << '\n'; +} +``` + +Additionally cpptrace provides a custom `std::terminate` handler that prints a stack trace from a cpptrace exception and otherwise behaves like the normal terminate handler and prints the stack trace involved in reaching `std::terminate`. +The stack trace to `std::terminate` may be helpful or it may not, it depends on the implementation, but often if an +implementation can't find an appropriate `catch` while unwinding it will jump directly to `std::terminate` giving +good information. + +To register this custom handler: + +```cpp +cpptrace::register_terminate_handler(); +``` + +## Signal-Safe Tracing + +Signal-safe stack tracing is very useful for debugging application crashes, e.g. SIGSEGVs or +SIGTRAPs, but it's very difficult to do correctly and most implementations I see online do this +incorrectly. + +In order to do this full process safely the way to go is collecting basic information in the signal +handler and then either resolving later or handing that information to another process to resolve. + +It's not as simple as calling `cpptrace::generate_trace().print()`, though you might be able to get +away with that, but this is what is needed to really do this safely. + +The safe API is as follows: + +```cpp +namespace cpptrace { + std::size_t safe_generate_raw_trace(frame_ptr* buffer, std::size_t size, std::size_t skip = 0); + std::size_t safe_generate_raw_trace(frame_ptr* buffer, std::size_t size, std::size_t skip, std::size_t max_depth); + struct safe_object_frame { + frame_ptr raw_address; + frame_ptr address_relative_to_object_start; // object base address must yet be added + char object_path[CPPTRACE_PATH_MAX + 1]; + object_frame resolve() const; // To be called outside a signal handler. Not signal safe. + }; + void get_safe_object_frame(frame_ptr address, safe_object_frame* out); + bool can_signal_safe_unwind(); +} +``` + +> [!IMPORTANT] +> Currently signal-safe stack unwinding is only possible with `libunwind`, which must be +> [manually enabled](#library-back-ends). If signal-safe unwinding isn't supported, `safe_generate_raw_trace` will just +> produce an empty trace. `can_signal_safe_unwind` can be used to check for signal-safe unwinding support. If object +> information can't be resolved in a signal-safe way then `get_safe_object_frame` will not populate fields beyond the +> `raw_address`. + +> [!IMPORTANT] +> `_dl_find_object` is required for signal-safe stack tracing. This is a relatively recent addition to glibc, added in +> glibc 2.35. + +> [!CAUTION] +> Calls to shared objects can be lazy-loaded where the first call to the shared object invokes non-signal-safe functions +> such as `malloc()`. To avoid this, call these routines in `main()` ahead of a signal handler to "warm up" the library. + +Because signal-safe tracing is an involved process, I have written up a comprehensive overview of +what is involved at [signal-safe-tracing.md](docs/signal-safe-tracing.md). + +## Utility Types + +A couple utility types are used to provide the library with a good interface. + +`nullable` is used for a nullable integer type. Internally the maximum value for `T` is used as a +sentinel. `std::optional` would be used if this library weren't c++11. But, `nullable` provides +an `std::optional`-like interface and it's less heavy-duty for this use than an `std::optional`. + +`detail::lazy_trace_holder` is a utility type for `lazy_exception` used in place of an +`std::variant`. + +```cpp +namespace cpptrace { + template::value, int>::type = 0> + struct nullable { + T raw_value; + nullable& operator=(T value) + bool has_value() const noexcept; + T& value() noexcept; + const T& value() const noexcept; + T value_or(T alternative) const noexcept; + void swap(nullable& other) noexcept; + void reset() noexcept; + bool operator==(const nullable& other) const noexcept; + bool operator!=(const nullable& other) const noexcept; + constexpr static nullable null() noexcept; // returns a null instance + }; + + namespace detail { + class lazy_trace_holder { + bool resolved; + union { + raw_trace trace; + stacktrace resolved_trace; + }; + public: + // constructors + lazy_trace_holder() : trace() {} + explicit lazy_trace_holder(raw_trace&& _trace); + explicit lazy_trace_holder(stacktrace&& _resolved_trace); + // logistics + lazy_trace_holder(const lazy_trace_holder& other); + lazy_trace_holder(lazy_trace_holder&& other) noexcept; + lazy_trace_holder& operator=(const lazy_trace_holder& other); + lazy_trace_holder& operator=(lazy_trace_holder&& other) noexcept; + ~lazy_trace_holder(); + // access + const raw_trace& get_raw_trace() const; + stacktrace& get_resolved_trace(); + const stacktrace& get_resolved_trace() const; // throws if not already resolved + private: + void clear(); + }; + } +} +``` + +# Supported Debug Formats + +| Format | Supported | +| --------------------------------- | --------- | +| DWARF in binary | ✔️ | +| GNU debug link | ️️✔️ | +| Split dwarf (debug fission) | ✔️* | +| DWARF in dSYM | ✔️ | +| DWARF via Mach-O debug map | ✔️ | +| Windows debug symbols in PDB | ✔️ | + +*There seem to be a couple issues upstream with libdwarf however they will hopefully be resolved soon. + +DWARF5 added DWARF package files. As far as I can tell no compiler implements these yet. + +# Usage + +## CMake FetchContent + +With CMake FetchContent: + +```cmake +include(FetchContent) +FetchContent_Declare( + cpptrace + GIT_REPOSITORY https://github.com/jeremy-rifkin/cpptrace.git + GIT_TAG v0.7.1 # +) +FetchContent_MakeAvailable(cpptrace) +target_link_libraries(your_target cpptrace::cpptrace) +``` + +It's as easy as that. Cpptrace will automatically configure itself for your system. Note: On windows and macos some +extra work is required, see [Platform Logistics](#platform-logistics) below. + +Be sure to configure with `-DCMAKE_BUILD_TYPE=Debug` or `-DCMAKE_BUILD_TYPE=RelWithDebInfo` for symbols and line +information. + +## System-Wide Installation + +```sh +git clone https://github.com/jeremy-rifkin/cpptrace.git +git checkout v0.7.1 +mkdir cpptrace/build +cd cpptrace/build +cmake .. -DCMAKE_BUILD_TYPE=Release +make -j +sudo make install +``` + +Using through cmake: +```cmake +find_package(cpptrace REQUIRED) +target_link_libraries( cpptrace::cpptrace) +``` +Be sure to configure with `-DCMAKE_BUILD_TYPE=Debug` or `-DCMAKE_BUILD_TYPE=RelWithDebInfo` for symbols and line +information. + +Or compile with `-lcpptrace`: + +```sh +g++ main.cpp -o main -g -Wall -lcpptrace +./main +``` + +> [!IMPORTANT] +> If you aren't using cmake and are linking statically you must manually specify `-DCPPTRACE_STATIC_DEFINE`. + +If you get an error along the lines of +``` +error while loading shared libraries: libcpptrace.so: cannot open shared object file: No such file or directory +``` +You may have to run `sudo /sbin/ldconfig` to create any necessary links and update caches so the system can find +libcpptrace.so (I had to do this on Ubuntu). Only when installing system-wide. Usually your package manager does this for +you when installing new libraries. + +> [!NOTE] +> Libdwarf requires a relatively new version of libdwarf. Sometimes a previously-installed system-wide libdwarf may +> cause issues due to being too old. Libdwarf 8 and newer is known to work. + +
+ System-wide install on windows + +```ps1 +git clone https://github.com/jeremy-rifkin/cpptrace.git +git checkout v0.7.1 +mkdir cpptrace/build +cd cpptrace/build +cmake .. -DCMAKE_BUILD_TYPE=Release +msbuild cpptrace.sln +msbuild INSTALL.vcxproj +``` + +Note: You'll need to run as an administrator in a developer powershell, or use vcvarsall.bat distributed with visual +studio to get the correct environment variables set. +
+ +## Local User Installation + +To install just for the local user (or any custom prefix): + +```sh +git clone https://github.com/jeremy-rifkin/cpptrace.git +git checkout v0.7.1 +mkdir cpptrace/build +cd cpptrace/build +cmake .. -DCMAKE_BUILD_TYPE=Release -DCMAKE_INSTALL_PREFIX=$HOME/wherever +make -j +make install +``` + +Using through cmake: +```cmake +find_package(cpptrace REQUIRED PATHS $ENV{HOME}/wherever) +target_link_libraries( cpptrace::cpptrace) +``` + +Using manually: +``` +g++ main.cpp -o main -g -Wall -I$HOME/wherever/include -L$HOME/wherever/lib -lcpptrace +``` + +> [!IMPORTANT] +> If you aren't using cmake and are linking statically you must manually specify `-DCPPTRACE_STATIC_DEFINE`. + +## Use Without CMake + +To use the library without cmake first follow the installation instructions at +[System-Wide Installation](#system-wide-installation), [Local User Installation](#local-user-installation), +or [Package Managers](#package-managers). + +In addition to any include or library paths you'll need to specify to tell the compiler where cpptrace was installed. +The typical dependencies for cpptrace are: + +| Compiler | Platform | Dependencies | +| ----------------------- | ---------------- | ----------------------------------------- | +| gcc, clang, intel, etc. | Linux/macos/unix | `-lcpptrace -ldwarf -lz -lzstd -ldl` | +| gcc | Windows | `-lcpptrace -ldbghelp -ldwarf -lz -lzstd` | +| msvc | Windows | `cpptrace.lib dbghelp.lib` | +| clang | Windows | `-lcpptrace -ldbghelp` | + +Note: Newer libdwarf requires `-lzstd`, older libdwarf does not. + +> [!IMPORTANT] +> If you are linking statically, you will additionally need to specify `-DCPPTRACE_STATIC_DEFINE`. + +Dependencies may differ if different back-ends are manually selected. + +## Installation Without Package Managers or FetchContent + +Some users may prefer, or need to, to install cpptrace without package managers or fetchcontent (e.g. if their system +does not have internet access). Below are instructions for how to install libdwarf and cpptrace. + +
+ Installation Without Package Managers or FetchContent + +Here is an example for how to build cpptrace and libdwarf. `~/scratch/cpptrace-test` is used as a working directory and +the libraries are installed to `~/scratch/cpptrace-test/resources`. + +```sh +mkdir -p ~/scratch/cpptrace-test/resources + +cd ~/scratch/cpptrace-test +git clone https://github.com/facebook/zstd.git +cd zstd +git checkout 63779c798237346c2b245c546c40b72a5a5913fe +cd build/cmake +mkdir build +cd build +cmake .. -DCMAKE_INSTALL_PREFIX=~/scratch/cpptrace-test/resources -DZSTD_BUILD_PROGRAMS=On -DZSTD_BUILD_CONTRIB=On -DZSTD_BUILD_TESTS=On -DZSTD_BUILD_STATIC=On -DZSTD_BUILD_SHARED=On -DZSTD_LEGACY_SUPPORT=On +make -j +make install + +cd ~/scratch/cpptrace-test +git clone https://github.com/jeremy-rifkin/libdwarf-lite.git +cd libdwarf-lite +git checkout 6dbcc23dba6ffd230063bda4b9d7298bf88d9d41 +mkdir build +cd build +cmake .. -DPIC_ALWAYS=On -DBUILD_DWARFDUMP=Off -DCMAKE_PREFIX_PATH=~/scratch/cpptrace-test/resources -DCMAKE_INSTALL_PREFIX=~/scratch/cpptrace-test/resources +make -j +make install + +cd ~/scratch/cpptrace-test +git clone https://github.com/jeremy-rifkin/cpptrace.git +cd cpptrace +git checkout v0.7.1 +mkdir build +cd build +cmake .. -DCMAKE_BUILD_TYPE=Release -DBUILD_SHARED_LIBS=On -DCPPTRACE_USE_EXTERNAL_LIBDWARF=On -DCMAKE_PREFIX_PATH=~/scratch/cpptrace-test/resources -DCMAKE_INSTALL_PREFIX=~/scratch/cpptrace-test/resources +make -j +make install +``` + +The `~/scratch/cpptrace-test/resources` directory also serves as a bundle you can ship with all the installed files for +cpptrace and its dependencies. + +
+ +## Package Managers + +### Conan + +Cpptrace is available through conan at https://conan.io/center/recipes/cpptrace. +``` +[requires] +cpptrace/0.7.1 +[generators] +CMakeDeps +CMakeToolchain +[layout] +cmake_layout +``` +```cmake +# ... +find_package(cpptrace REQUIRED) +# ... +target_link_libraries(YOUR_TARGET cpptrace::cpptrace) +``` + +### Vcpkg + +``` +vcpkg install cpptrace +``` +```cmake +find_package(cpptrace CONFIG REQUIRED) +target_link_libraries(main PRIVATE cpptrace::cpptrace) +``` + +# Platform Logistics + +Windows and macOS require a little extra work to get everything in the right place. + +## Windows + +Copying the library `.dll` on Windows: + +```cmake +# Copy the cpptrace.dll on windows to the same directory as the executable for your_target. +# Not required if static linking. +if(WIN32) + add_custom_command( + TARGET your_target POST_BUILD + COMMAND ${CMAKE_COMMAND} -E copy_if_different + $ + $ + ) +endif() +``` + +## macOS + +On macOS, it is recommended to generate a `dSYM` file containing debug information for your program. +This is not required as cpptrace makes a good effort at finding and reading the debug information +without this, but having a `dSYM` file is the most robust method. + +When using Xcode with CMake, this can be done with: + +```cmake +set_target_properties(your_target PROPERTIES XCODE_ATTRIBUTE_DEBUG_INFORMATION_FORMAT "dwarf-with-dsym") +``` + +Outside of Xcode, this can be done with `dsymutil yourbinary`: + +```cmake +# Create a .dSYM file on macOS +if(APPLE) + add_custom_command( + TARGET your_target + POST_BUILD + COMMAND dsymutil $ + ) +endif() +``` + +# Library Back-Ends + +Cpptrace supports a number of back-ends to produce stack traces. Stack traces are produced in roughly three steps: +Unwinding, symbol resolution, and demangling. + +The library's CMake automatically configures itself for what your system supports. The ideal configuration is as +follows: + +| Platform | Unwinding | Symbols | Demangling | +| -------- | ------------------------------------------------------- | ------------------ | -------------------- | +| Linux | `_Unwind` | libdwarf | cxxabi.h | +| MacOS | `_Unwind` for gcc, execinfo.h for clang and apple clang | libdwarf | cxxabi.h | +| Windows | `StackWalk64` | dbghelp | No demangling needed | +| MinGW | `StackWalk64` | libdwarf + dbghelp | cxxabi.h | + +Support for these back-ends is the main development focus and they should work well. If you want to use a different +back-end such as addr2line, for example, you can configure the library to do so. + +**Unwinding** + +| Library | CMake config | Platforms | Info | +| ------------- | -------------------------------- | ---------------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| libgcc unwind | `CPPTRACE_UNWIND_WITH_UNWIND` | linux, macos, mingw | Frames are captured with libgcc's `_Unwind_Backtrace`, which currently produces the most accurate stack traces on gcc/clang/mingw. Libgcc is often linked by default, and llvm has something equivalent. | +| execinfo.h | `CPPTRACE_UNWIND_WITH_EXECINFO` | linux, macos | Frames are captured with `execinfo.h`'s `backtrace`, part of libc on linux/unix systems. | +| winapi | `CPPTRACE_UNWIND_WITH_WINAPI` | windows, mingw | Frames are captured with `CaptureStackBackTrace`. | +| dbghelp | `CPPTRACE_UNWIND_WITH_DBGHELP` | windows, mingw | Frames are captured with `StackWalk64`. | +| libunwind | `CPPTRACE_UNWIND_WITH_LIBUNWIND` | linux, macos, windows, mingw | Frames are captured with [libunwind](https://github.com/libunwind/libunwind). **Note:** This is the only back-end that requires a library to be installed by the user, and a `CMAKE_PREFIX_PATH` may also be needed. | +| N/A | `CPPTRACE_UNWIND_WITH_NOTHING` | all | Unwinding is not done, stack traces will be empty. | + +Some back-ends (execinfo and `CaptureStackBackTrace`) require a fixed buffer has to be created to read addresses into +while unwinding. By default the buffer can hold addresses for 400 frames (beyond the `skip` frames). This is +configurable with `CPPTRACE_HARD_MAX_FRAMES`. + +**Symbol resolution** + +| Library | CMake config | Platforms | Info | +| ------------ | ---------------------------------------- | --------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| libdwarf | `CPPTRACE_GET_SYMBOLS_WITH_LIBDWARF` | linux, macos, mingw | Libdwarf is the preferred method for symbol resolution for cpptrace. Cpptrace will get it via FetchContent or find_package depending on `CPPTRACE_USE_EXTERNAL_LIBDWARF`. | +| dbghelp | `CPPTRACE_GET_SYMBOLS_WITH_DBGHELP` | windows | Dbghelp.h is the preferred method for symbol resolution on windows under msvc/clang and is supported on all windows machines. | +| libbacktrace | `CPPTRACE_GET_SYMBOLS_WITH_LIBBACKTRACE` | linux, macos*, mingw* | Libbacktrace is already installed on most systems or available through the compiler directly. For clang you must specify the absolute path to `backtrace.h` using `CPPTRACE_BACKTRACE_PATH`. | +| addr2line | `CPPTRACE_GET_SYMBOLS_WITH_ADDR2LINE` | linux, macos, mingw | Symbols are resolved by invoking `addr2line` (or `atos` on mac) via `fork()` (on linux/unix, and `popen` under mingw). | +| libdl | `CPPTRACE_GET_SYMBOLS_WITH_LIBDL` | linux, macos | Libdl uses dynamic export information. Compiling with `-rdynamic` is needed for symbol information to be retrievable. Line numbers won't be retrievable. | +| N/A | `CPPTRACE_GET_SYMBOLS_WITH_NOTHING` | all | No attempt is made to resolve symbols. | + +*: Requires installation + +One back-end should be used. For MinGW `CPPTRACE_GET_SYMBOLS_WITH_LIBDWARF` and `CPPTRACE_GET_SYMBOLS_WITH_DBGHELP` can +be used in conjunction. + +Note for addr2line: By default cmake will resolve an absolute path to addr2line to bake into the library. This path can +be configured with `CPPTRACE_ADDR2LINE_PATH`, or `CPPTRACE_ADDR2LINE_SEARCH_SYSTEM_PATH` can be used to have the library +search the system path for `addr2line` at runtime. This is not the default to prevent against path injection attacks. + +**Demangling** + +Lastly, depending on other back-ends used a demangler back-end may be needed. + +| Library | CMake config | Platforms | Info | +| --------- | -------------------------------- | ------------------- | ---------------------------------------------------------------------------------- | +| cxxabi.h | `CPPTRACE_DEMANGLE_WITH_CXXABI` | Linux, macos, mingw | Should be available everywhere other than [msvc](https://godbolt.org/z/93ca9rcdz). | +| dbghelp.h | `CPPTRACE_DEMANGLE_WITH_WINAPI` | Windows | Demangle with `UnDecorateSymbolName`. | +| N/A | `CPPTRACE_DEMANGLE_WITH_NOTHING` | all | Don't attempt to do anything beyond what the symbol resolution back-end does. | + +**More?** + +There are plenty more libraries that can be used for unwinding, parsing debug information, and demangling. In the future +more back-ends can be added. Ideally this library can "just work" on systems, without additional installation work. + +## Summary of Library Configurations + +Summary of all library configuration options: + +Back-ends: +- `CPPTRACE_GET_SYMBOLS_WITH_LIBDWARF=On/Off` +- `CPPTRACE_GET_SYMBOLS_WITH_DBGHELP=On/Off` +- `CPPTRACE_GET_SYMBOLS_WITH_LIBBACKTRACE=On/Off` +- `CPPTRACE_GET_SYMBOLS_WITH_ADDR2LINE=On/Off` +- `CPPTRACE_GET_SYMBOLS_WITH_LIBDL=On/Off` +- `CPPTRACE_GET_SYMBOLS_WITH_NOTHING=On/Off` +- `CPPTRACE_UNWIND_WITH_UNWIND=On/Off` +- `CPPTRACE_UNWIND_WITH_LIBUNWIND=On/Off` +- `CPPTRACE_UNWIND_WITH_EXECINFO=On/Off` +- `CPPTRACE_UNWIND_WITH_WINAPI=On/Off` +- `CPPTRACE_UNWIND_WITH_DBGHELP=On/Off` +- `CPPTRACE_UNWIND_WITH_NOTHING=On/Off` +- `CPPTRACE_DEMANGLE_WITH_CXXABI=On/Off` +- `CPPTRACE_DEMANGLE_WITH_WINAPI=On/Off` +- `CPPTRACE_DEMANGLE_WITH_NOTHING=On/Off` + +Back-end configuration: +- `CPPTRACE_BACKTRACE_PATH=`: Path to libbacktrace backtrace.h, needed when compiling with clang/ +- `CPPTRACE_HARD_MAX_FRAMES=`: Some back-ends write to a fixed-size buffer. This is the size of that buffer. + Default is `400`. +- `CPPTRACE_ADDR2LINE_PATH=`: Specify the absolute path to the addr2line binary for cpptrace to invoke. By + default the config script will search for a binary and use that absolute path (this is to prevent against path + injection). +- `CPPTRACE_ADDR2LINE_SEARCH_SYSTEM_PATH=On/Off`: Specifies whether cpptrace should let the system search the PATH + environment variable directories for the binary. + +Other useful configurations: +- `CPPTRACE_BUILD_SHARED=On/Off`: Override for `BUILD_SHARED_LIBS`. +- `CPPTRACE_INCLUDES_WITH_SYSTEM=On/Off`: Marks cpptrace headers as `SYSTEM` which will hide any warnings that aren't + the fault of your project. Defaults to On. +- `CPPTRACE_INSTALL_CMAKEDIR`: Override for the installation path for the cmake configs. +- `CPPTRACE_USE_EXTERNAL_LIBDWARF=On/Off`: Get libdwarf from `find_package` rather than `FetchContent`. +- `CPPTRACE_POSITION_INDEPENDENT_CODE=On/Off`: Compile the library as a position independent code (PIE). Defaults to On. +- `CPPTRACE_STD_FORMAT=On/Off`: Control inclusion of `` and provision of `std::formatter` specializations by + cpptrace.hpp. This can also be controlled with the macro `CPPTRACE_NO_STD_FORMAT`. + +Testing: +- `CPPTRACE_BUILD_TESTING` Build small demo and test program +- `CPPTRACE_BUILD_TEST_RDYNAMIC` Use `-rdynamic` when compiling the test program + +# Testing Methodology + +Cpptrace currently uses integration and functional testing, building and running under every combination of back-end +options. The implementation is based on [github actions matrices][1] and driven by python scripts located in the +[`ci/`](ci/) folder. Testing used to be done by github actions matrices directly, however, launching hundreds of two +second jobs was extremely inefficient. Test outputs are compared against expected outputs located in +[`test/expected/`](test/expected/). Stack trace addresses may point to the address after an instruction depending on the +unwinding back-end, and the python script will check for an exact or near-match accordingly. + +[1]: https://docs.github.com/en/actions/using-jobs/using-a-matrix-for-your-jobs + +# Notes About the Library + +For the most part I'm happy with the state of the library. But I'm sure that there is room for improvement and issues +will exist. If you encounter any issue, please let me know! If you find any pain-points in the library, please let me +know that too. + +A note about performance: For handling of DWARF symbols there is a lot of room to explore for performance optimizations +and time-memory tradeoffs. If you find the current implementation is either slow or using too much memory, I'd be happy +to explore some of these options. + +A couple things I'd like to improve in the future: +- On Windows when collecting symbols with dbghelp (msvc/clang) parameter types are almost perfect but due to limitations + in dbghelp the library cannot accurately show const and volatile qualifiers or rvalue references (these appear as + pointers). + +# Contributing + +I'm grateful for the help I've received with this library and I welcome contributions! For information on contributing +please refer to [CONTRIBUTING.md](./CONTRIBUTING.md). + +# License + +This library is under the MIT license. + +Cpptrace uses libdwarf on linux, macos, and mingw/cygwin unless configured to use something else. If this library is +statically linked with libdwarf then the library's binary will itself be LGPL. + +[P2490R3]: https://www.open-std.org/jtc1/sc22/wg21/docs/papers/2022/p2490r3.html diff --git a/dep/cpptrace/cmake/Findzstd.cmake b/dep/cpptrace/cmake/Findzstd.cmake new file mode 100644 index 00000000000..fc8eb529c1e --- /dev/null +++ b/dep/cpptrace/cmake/Findzstd.cmake @@ -0,0 +1,51 @@ +# Libdwarf needs zstd, cpptrace doesn't, and libdwarf has its own Findzstd but it doesn't define zstd::libzstd_static / +# zstd::libzstd_shared targets which leads to issues, necessitating a find_dependency(zstd) in cpptrace's cmake config +# and in order to support non-cmake-module installs we need to provide a Findzstd script. +# https://github.com/jeremy-rifkin/cpptrace/issues/112 + +# This will define +# zstd_FOUND +# zstd_INCLUDE_DIR +# zstd_LIBRARY + +find_path(zstd_INCLUDE_DIR NAMES zstd.h) + +find_library(zstd_LIBRARY_DEBUG NAMES zstdd zstd_staticd) +find_library(zstd_LIBRARY_RELEASE NAMES zstd zstd_static) + +include(SelectLibraryConfigurations) +SELECT_LIBRARY_CONFIGURATIONS(zstd) + +include(FindPackageHandleStandardArgs) +FIND_PACKAGE_HANDLE_STANDARD_ARGS( + zstd DEFAULT_MSG + zstd_LIBRARY zstd_INCLUDE_DIR +) + +if(zstd_FOUND) + message(STATUS "Found Zstd: ${zstd_LIBRARY}") +endif() + +mark_as_advanced(zstd_INCLUDE_DIR zstd_LIBRARY) + +if(zstd_FOUND) + # just defining them the same... cmake will figure it out + if(NOT TARGET zstd::libzstd_static) + add_library(zstd::libzstd_static UNKNOWN IMPORTED) + set_target_properties( + zstd::libzstd_static + PROPERTIES + IMPORTED_LOCATION "${zstd_LIBRARIES}" + INTERFACE_INCLUDE_DIRECTORIES "${zstd_INCLUDE_DIR}" + ) + endif() + if(NOT TARGET zstd::libzstd_shared) + add_library(zstd::libzstd_shared UNKNOWN IMPORTED) + set_target_properties( + zstd::libzstd_shared + PROPERTIES + IMPORTED_LOCATION "${zstd_LIBRARIES}" + INTERFACE_INCLUDE_DIRECTORIES "${zstd_INCLUDE_DIR}" + ) + endif() +endif() diff --git a/dep/cpptrace/cmake/InstallRules.cmake b/dep/cpptrace/cmake/InstallRules.cmake new file mode 100644 index 00000000000..14dbf59ed30 --- /dev/null +++ b/dep/cpptrace/cmake/InstallRules.cmake @@ -0,0 +1,78 @@ +include(CMakePackageConfigHelpers) + +# copy header files to CMAKE_INSTALL_INCLUDEDIR +# don't include third party header files +install( + DIRECTORY + "${PROJECT_SOURCE_DIR}/include/" # our header files + DESTINATION "${CMAKE_INSTALL_INCLUDEDIR}" + COMPONENT ${package_name}-development + # PATTERN "**/third_party" EXCLUDE # skip third party directory + # PATTERN "**/third_party/**" EXCLUDE # skip third party files +) + +# copy target build output artifacts to OS dependent locations +# (Except includes, that just sets a compiler flag with the path) +install( + TARGETS ${target_name} + EXPORT ${package_name}-targets + RUNTIME # + COMPONENT ${package_name}-runtime + LIBRARY # + COMPONENT ${package_name}-runtime + NAMELINK_COMPONENT ${package_name}-development + ARCHIVE # + COMPONENT ${package_name}-development + INCLUDES # + DESTINATION "${CMAKE_INSTALL_INCLUDEDIR}" + ARCHIVE DESTINATION "${CMAKE_INSTALL_INCLUDEDIR}" +) + +# create config file that points to targets file +configure_file( + "${PROJECT_SOURCE_DIR}/cmake/in/cpptrace-config-cmake.in" + "${PROJECT_BINARY_DIR}/cmake/${package_name}-config.cmake" + @ONLY +) + +# copy config file for find_package to find +install( + FILES "${PROJECT_BINARY_DIR}/cmake/${package_name}-config.cmake" + DESTINATION "${CPPTRACE_INSTALL_CMAKEDIR}" + COMPONENT ${package_name}-development +) + +# create version file for consumer to check version in CMake +write_basic_package_version_file( + "${package_name}-config-version.cmake" + COMPATIBILITY SameMajorVersion # a.k.a SemVer +) + +# copy version file for find_package to find for version check +install( + FILES "${PROJECT_BINARY_DIR}/${package_name}-config-version.cmake" + DESTINATION "${CPPTRACE_INSTALL_CMAKEDIR}" + COMPONENT ${package_name}-development +) + +# create targets file included by config file with targets for consumers +install( + EXPORT ${package_name}-targets + NAMESPACE cpptrace:: + DESTINATION "${CPPTRACE_INSTALL_CMAKEDIR}" + COMPONENT ${package_name}-development +) + +# Findzstd.cmake +# vcpkg doesn't like anything being put in share/, which is where this goes apparently on their setup +if(NOT CPPTRACE_VCPKG) + install( + FILES "${PROJECT_SOURCE_DIR}/cmake/Findzstd.cmake" + DESTINATION "${CMAKE_INSTALL_LIBDIR}/cmake/${package_name}" + ) +endif() + +# support packaging library +if(PROJECT_IS_TOP_LEVEL) + include(CPack) +endif() diff --git a/dep/cpptrace/cmake/OptionVariables.cmake b/dep/cpptrace/cmake/OptionVariables.cmake new file mode 100644 index 00000000000..e4c7e477c09 --- /dev/null +++ b/dep/cpptrace/cmake/OptionVariables.cmake @@ -0,0 +1,200 @@ +# Included further down to avoid interfering with our cache variables +# include(GNUInstallDirs) + +# ---- Options Summary ---- + +# --------------------------------------------------------------------------------------------------- +# | Option | Availability | Default | +# |=================================|===============|===============================================| +# | BUILD_SHARED_LIBS | Top-Level | OFF | +# | BUILD_TESTING | Top-Level | OFF | +# |---------------------------------|---------------|-----------------------------------------------| +# | CPPTRACE_BUILD_SHARED | Always | ${BUILD_SHARED_LIBS} | +# | CPPTRACE_BUILD_TESTING | Always | ${BUILD_TESTING} AND ${PROJECT_IS_TOP_LEVEL} | +# | CPPTRACE_INCLUDES_WITH_SYSTEM | Not Top-Level | ON | +# | CPPTRACE_INSTALL_CMAKEDIR | Always | ${CMAKE_INSTALL_LIBDIR}/cmake/${package_name} | +# | CPPTRACE_USE_EXTERNAL_LIBDWARF | Always | OFF | +# | CPPTRACE_USE_EXTERNAL_ZSTD | Always | OFF | +# | ... | | | +# --------------------------------------------------------------------------------------------------- + +# ---- Build Shared ---- + +# Sometimes it's useful to be able to single out a dependency to be built as +# static or shared, even if obtained from source +if(PROJECT_IS_TOP_LEVEL) + option(BUILD_SHARED_LIBS "Build shared libs" OFF) +endif() +option( + CPPTRACE_BUILD_SHARED + "Override BUILD_SHARED_LIBS for ${package_name} library" + ${BUILD_SHARED_LIBS} +) +mark_as_advanced(CPPTRACE_BUILD_SHARED) +set(build_type STATIC) +if(CPPTRACE_BUILD_SHARED) + set(build_type SHARED) +endif() + +# ---- Warning Guard ---- + +# target_include_directories with SYSTEM modifier will request the compiler to +# omit warnings from the provided paths, if the compiler supports that. +# This is to provide a user experience similar to find_package when +# add_subdirectory or FetchContent is used to consume this project. +set(warning_guard ) +if(NOT PROJECT_IS_TOP_LEVEL) + option( + CPPTRACE_INCLUDES_WITH_SYSTEM + "Use SYSTEM modifier for ${package_name}'s includes, disabling warnings" + ON + ) + mark_as_advanced(CPPTRACE_INCLUDES_WITH_SYSTEM) + if(CPPTRACE_INCLUDES_WITH_SYSTEM) + set(warning_guard SYSTEM) + endif() +endif() + +# ---- Enable Testing ---- + +# By default tests aren't enabled even with BUILD_TESTING=ON unless the library +# is built as a top level project. +# This is in order to cut down on unnecessary compile times, since it's unlikely +# for users to want to run the tests of their dependencies. +if(PROJECT_IS_TOP_LEVEL) + option(BUILD_TESTING "Build tests" OFF) +endif() +if(PROJECT_IS_TOP_LEVEL AND BUILD_TESTING) + set(build_testing ON) +endif() +option( + CPPTRACE_BUILD_TESTING + "Override BUILD_TESTING for ${package_name} library" + ${build_testing} +) +set(build_testing ) +mark_as_advanced(CPPTRACE_BUILD_TESTING) + +# ---- Install Include Directory ---- + +# Adds an extra directory to the include path by default, so that when you link +# against the target, you get `/include/` added to your +# include paths rather than `/include`. +# This doesn't affect include paths used by consumers of this project, but helps +# prevent consumers having access to other projects in the same include +# directory (e.g. usr/include). +# The variable type is STRING rather than PATH, because otherwise passing +# -DCMAKE_INSTALL_INCLUDEDIR=include on the command line would expand to an +# absolute path with the base being the current CMake directory, leading to +# unexpected errors. +# if(PROJECT_IS_TOP_LEVEL) +# set( +# CMAKE_INSTALL_INCLUDEDIR "include/${package_name}-${PROJECT_VERSION}" +# CACHE STRING "" +# ) +# # marked as advanced in GNUInstallDirs version, so we follow their lead +# mark_as_advanced(CMAKE_INSTALL_INCLUDEDIR) +# endif() + +# do not include earlier or we can't set CMAKE_INSTALL_INCLUDEDIR above +# include required for CMAKE_INSTALL_LIBDIR below +include(GNUInstallDirs) + +# ---- Install CMake Directory ---- + +# This allows package maintainers to freely override the installation path for +# the CMake configs. +# This doesn't affects include paths used by consumers of this project. +# The variable type is STRING rather than PATH, because otherwise passing +# -DCPPTRACE_INSTALL_CMAKEDIR=lib/cmake on the command line would expand to an +# absolute path with the base being the current CMake directory, leading to +# unexpected errors. +set( + CPPTRACE_INSTALL_CMAKEDIR "${CMAKE_INSTALL_LIBDIR}/cmake/${package_name}" + CACHE STRING "CMake package config location relative to the install prefix" +) +# depends on CMAKE_INSTALL_LIBDIR which is marked as advanced in GNUInstallDirs +mark_as_advanced(CPPTRACE_INSTALL_CMAKEDIR) + +# ---- Symbol Options ---- + +option(CPPTRACE_GET_SYMBOLS_WITH_LIBBACKTRACE "" OFF) +option(CPPTRACE_GET_SYMBOLS_WITH_LIBDWARF "" OFF) +option(CPPTRACE_GET_SYMBOLS_WITH_LIBDL "" OFF) +option(CPPTRACE_GET_SYMBOLS_WITH_ADDR2LINE "" OFF) +option(CPPTRACE_GET_SYMBOLS_WITH_DBGHELP "" OFF) +option(CPPTRACE_GET_SYMBOLS_WITH_NOTHING "" OFF) + +# ---- Unwinding Options ---- + +option(CPPTRACE_UNWIND_WITH_UNWIND "" OFF) +option(CPPTRACE_UNWIND_WITH_LIBUNWIND "" OFF) +option(CPPTRACE_UNWIND_WITH_EXECINFO "" OFF) +option(CPPTRACE_UNWIND_WITH_WINAPI "" OFF) +option(CPPTRACE_UNWIND_WITH_DBGHELP "" OFF) +option(CPPTRACE_UNWIND_WITH_NOTHING "" OFF) + +# ---- Demangling Options ---- + +option(CPPTRACE_DEMANGLE_WITH_CXXABI "" OFF) +option(CPPTRACE_DEMANGLE_WITH_WINAPI "" OFF) +option(CPPTRACE_DEMANGLE_WITH_NOTHING "" OFF) + +# ---- Back-end configurations ---- + +set(CPPTRACE_BACKTRACE_PATH "" CACHE STRING "Path to backtrace.h, if the compiler doesn't already know it. Check /usr/lib/gcc/x86_64-linux-gnu/*/include.") +set(CPPTRACE_HARD_MAX_FRAMES "" CACHE STRING "Hard limit on unwinding depth. Default is 400.") +set(CPPTRACE_ADDR2LINE_PATH "" CACHE STRING "Absolute path to the addr2line executable you want to use.") +option(CPPTRACE_ADDR2LINE_SEARCH_SYSTEM_PATH "" OFF) + +# ---- Other configurations ---- + +if(PROJECT_IS_TOP_LEVEL) + option(CPPTRACE_BUILD_TESTING "" OFF) + option(CPPTRACE_BUILD_BENCHMARK "" OFF) + option(CPPTRACE_BUILD_TESTING_SPLIT_DWARF "" OFF) + set(CPPTRACE_BUILD_TESTING_DWARF_VERSION "0" CACHE STRING "") + option(CPPTRACE_BUILD_TEST_RDYNAMIC "" OFF) + mark_as_advanced( + CPPTRACE_BUILD_TESTING + CPPTRACE_BUILD_BENCHMARKING + CPPTRACE_BUILD_TESTING_SPLIT_DWARF + CPPTRACE_BUILD_TESTING_DWARF_VERSION + CPPTRACE_BUILD_TEST_RDYNAMIC + ) +endif() + +option(CPPTRACE_USE_EXTERNAL_LIBDWARF "" OFF) +option(CPPTRACE_FIND_LIBDWARF_WITH_PKGCONFIG "" OFF) +option(CPPTRACE_USE_EXTERNAL_ZSTD "" OFF) +option(CPPTRACE_CONAN "" OFF) +option(CPPTRACE_VCPKG "" OFF) +option(CPPTRACE_SANITIZER_BUILD "" OFF) +option(CPPTRACE_WERROR_BUILD "" OFF) +option(CPPTRACE_POSITION_INDEPENDENT_CODE "" ON) +option(CPPTRACE_SKIP_UNIT "" OFF) +option(CPPTRACE_STD_FORMAT "" ON) +option(CPPTRACE_UNPREFIXED_TRY_CATCH "" OFF) +option(CPPTRACE_USE_EXTERNAL_GTEST "" OFF) +set(CPPTRACE_ZSTD_URL "https://github.com/facebook/zstd/releases/download/v1.5.6/zstd-1.5.6.tar.gz" CACHE STRING "") +set(CPPTRACE_LIBDWARF_REPO "https://github.com/jeremy-rifkin/libdwarf-lite.git" CACHE STRING "") +set(CPPTRACE_LIBDWARF_TAG "97fd68c6026c0237943106d6bc3e83f3661d39e8" CACHE STRING "") # v0.11.0 +set(CPPTRACE_LIBDWARF_SHALLOW "1" CACHE STRING "") + +mark_as_advanced( + CPPTRACE_BACKTRACE_PATH + CPPTRACE_ADDR2LINE_PATH + CPPTRACE_ADDR2LINE_SEARCH_SYSTEM_PATH + CPPTRACE_SANITIZER_BUILD + CPPTRACE_WERROR_BUILD + CPPTRACE_CONAN + CPPTRACE_VCPKG + CPPTRACE_SKIP_UNIT + CPPTRACE_USE_EXTERNAL_GTEST + CPPTRACE_ZSTD_REPO + CPPTRACE_ZSTD_TAG + CPPTRACE_ZSTD_SHALLOW + CPPTRACE_LIBDWARF_REPO + CPPTRACE_LIBDWARF_TAG + CPPTRACE_LIBDWARF_SHALLOW +) diff --git a/dep/cpptrace/cmake/PreventInSourceBuilds.cmake b/dep/cpptrace/cmake/PreventInSourceBuilds.cmake new file mode 100644 index 00000000000..b9eedc1ad8a --- /dev/null +++ b/dep/cpptrace/cmake/PreventInSourceBuilds.cmake @@ -0,0 +1,8 @@ +# In-source build guard +if(CMAKE_SOURCE_DIR STREQUAL CMAKE_BINARY_DIR) + message( + FATAL_ERROR + "In-source builds are not supported. " + "You may need to delete 'CMakeCache.txt' and 'CMakeFiles/' before rebuilding this project." + ) +endif() diff --git a/dep/cpptrace/cmake/ProjectIsTopLevel.cmake b/dep/cpptrace/cmake/ProjectIsTopLevel.cmake new file mode 100644 index 00000000000..34f3e6bb599 --- /dev/null +++ b/dep/cpptrace/cmake/ProjectIsTopLevel.cmake @@ -0,0 +1,6 @@ +# This variable is set by project() in CMake 3.21+ +string( + COMPARE EQUAL + "${CMAKE_SOURCE_DIR}" "${PROJECT_SOURCE_DIR}" + PROJECT_IS_TOP_LEVEL +) diff --git a/dep/cpptrace/cmake/has_backtrace.cpp b/dep/cpptrace/cmake/has_backtrace.cpp new file mode 100644 index 00000000000..784994a742b --- /dev/null +++ b/dep/cpptrace/cmake/has_backtrace.cpp @@ -0,0 +1,9 @@ +#ifdef CPPTRACE_BACKTRACE_PATH +#include CPPTRACE_BACKTRACE_PATH +#else +#include +#endif + +int main() { + backtrace_state* state = backtrace_create_state(nullptr, true, nullptr, nullptr); +} diff --git a/dep/cpptrace/cmake/has_cxx_exception_type.cpp b/dep/cpptrace/cmake/has_cxx_exception_type.cpp new file mode 100644 index 00000000000..a5ea526265b --- /dev/null +++ b/dep/cpptrace/cmake/has_cxx_exception_type.cpp @@ -0,0 +1,6 @@ +#include + +int main() { + std::type_info* t = abi::__cxa_current_exception_type(); + (void*) t; +} diff --git a/dep/cpptrace/cmake/has_cxxabi.cpp b/dep/cpptrace/cmake/has_cxxabi.cpp new file mode 100644 index 00000000000..9c53baf9fb2 --- /dev/null +++ b/dep/cpptrace/cmake/has_cxxabi.cpp @@ -0,0 +1,6 @@ +#include + +int main() { + int status; + abi::__cxa_demangle("_Z3foov", nullptr, nullptr, &status); +} diff --git a/dep/cpptrace/cmake/has_dl.cpp b/dep/cpptrace/cmake/has_dl.cpp new file mode 100644 index 00000000000..efa25e5d9d4 --- /dev/null +++ b/dep/cpptrace/cmake/has_dl.cpp @@ -0,0 +1,6 @@ +#include + +int main() { + Dl_info info; + dladdr(nullptr, &info); +} diff --git a/dep/cpptrace/cmake/has_dl_find_object.cpp b/dep/cpptrace/cmake/has_dl_find_object.cpp new file mode 100644 index 00000000000..a75ca537666 --- /dev/null +++ b/dep/cpptrace/cmake/has_dl_find_object.cpp @@ -0,0 +1,6 @@ +#include + +int main() { + dl_find_object result; + _dl_find_object(reinterpret_cast(main), &result); +} diff --git a/dep/cpptrace/cmake/has_dladdr1.cpp b/dep/cpptrace/cmake/has_dladdr1.cpp new file mode 100644 index 00000000000..fd069a5ba77 --- /dev/null +++ b/dep/cpptrace/cmake/has_dladdr1.cpp @@ -0,0 +1,8 @@ +#include +#include + +int main() { + Dl_info info; + link_map* link_map_info; + dladdr1(reinterpret_cast(&main), &info, reinterpret_cast(&link_map_info), RTLD_DL_LINKMAP); +} diff --git a/dep/cpptrace/cmake/has_execinfo.cpp b/dep/cpptrace/cmake/has_execinfo.cpp new file mode 100644 index 00000000000..0ad09ce905a --- /dev/null +++ b/dep/cpptrace/cmake/has_execinfo.cpp @@ -0,0 +1,6 @@ +#include + +int main() { + void* frames[10]; + backtrace(frames, 10); +} diff --git a/dep/cpptrace/cmake/has_mach_vm.cpp b/dep/cpptrace/cmake/has_mach_vm.cpp new file mode 100644 index 00000000000..9361e676537 --- /dev/null +++ b/dep/cpptrace/cmake/has_mach_vm.cpp @@ -0,0 +1,23 @@ +#include +#include +#include + +int main() { + mach_vm_size_t vmsize; + uintptr_t addr = reinterpret_cast(&vmsize); + uintptr_t page_addr = addr & ~(4096 - 1); + mach_vm_address_t address = (mach_vm_address_t)page_addr; + vm_region_basic_info_data_t info; + mach_msg_type_number_t info_count = + sizeof(size_t) == 8 ? VM_REGION_BASIC_INFO_COUNT_64 : VM_REGION_BASIC_INFO_COUNT; + memory_object_name_t object; + mach_vm_region( + mach_task_self(), + &address, + &vmsize, + VM_REGION_BASIC_INFO, + (vm_region_info_t)&info, + &info_count, + &object + ); +} diff --git a/dep/cpptrace/cmake/has_stackwalk.cpp b/dep/cpptrace/cmake/has_stackwalk.cpp new file mode 100644 index 00000000000..bd9127c67c0 --- /dev/null +++ b/dep/cpptrace/cmake/has_stackwalk.cpp @@ -0,0 +1,101 @@ +#include +#include + +#define IS_CLANG 0 +#define IS_GCC 0 +#define IS_MSVC 0 + +#if defined(__clang__) + #undef IS_CLANG + #define IS_CLANG 1 +#elif defined(__GNUC__) || defined(__GNUG__) + #undef IS_GCC + #define IS_GCC 1 +#elif defined(_MSC_VER) + #undef IS_MSVC + #define IS_MSVC 1 +#else + #error "Unsupported compiler" +#endif + +int main() { + HANDLE proc = GetCurrentProcess(); + HANDLE thread = GetCurrentThread(); + // https://jpassing.com/2008/03/12/walking-the-stack-of-the-current-thread/ + + // Get current thread context + // GetThreadContext cannot be used on the current thread. + // RtlCaptureContext doesn't work on i386 + CONTEXT context; + #if defined(_M_IX86) || defined(__i386__) + ZeroMemory(&context, sizeof(CONTEXT)); + context.ContextFlags = CONTEXT_CONTROL; + #if IS_MSVC + __asm { + label: + mov [context.Ebp], ebp; + mov [context.Esp], esp; + mov eax, [label]; + mov [context.Eip], eax; + } + #else + asm( + "label:\n\t" + "mov{l %%ebp, %[cEbp] | %[cEbp], ebp};\n\t" + "mov{l %%esp, %[cEsp] | %[cEsp], esp};\n\t" + "mov{l $label, %%eax | eax, OFFSET label};\n\t" + "mov{l %%eax, %[cEip] | %[cEip], eax};\n\t" + : [cEbp] "=r" (context.Ebp), + [cEsp] "=r" (context.Esp), + [cEip] "=r" (context.Eip) + ); + #endif + #else + RtlCaptureContext(&context); + #endif + // Setup current frame + STACKFRAME64 frame; + ZeroMemory(&frame, sizeof(STACKFRAME64)); + DWORD machine_type; + #if defined(_M_IX86) || defined(__i386__) + machine_type = IMAGE_FILE_MACHINE_I386; + frame.AddrPC.Offset = context.Eip; + frame.AddrPC.Mode = AddrModeFlat; + frame.AddrFrame.Offset = context.Ebp; + frame.AddrFrame.Mode = AddrModeFlat; + frame.AddrStack.Offset = context.Esp; + frame.AddrStack.Mode = AddrModeFlat; + #elif defined(_M_X64) || defined(__x86_64__) + machine_type = IMAGE_FILE_MACHINE_AMD64; + frame.AddrPC.Offset = context.Rip; + frame.AddrPC.Mode = AddrModeFlat; + frame.AddrFrame.Offset = context.Rsp; + frame.AddrFrame.Mode = AddrModeFlat; + frame.AddrStack.Offset = context.Rsp; + frame.AddrStack.Mode = AddrModeFlat; + #elif defined(_M_IA64) || defined(__aarch64__) + machine_type = IMAGE_FILE_MACHINE_IA64; + frame.AddrPC.Offset = context.StIIP; + frame.AddrPC.Mode = AddrModeFlat; + frame.AddrFrame.Offset = context.IntSp; + frame.AddrFrame.Mode = AddrModeFlat; + frame.AddrBStore.Offset= context.RsBSP; + frame.AddrBStore.Mode = AddrModeFlat; + frame.AddrStack.Offset = context.IntSp; + frame.AddrStack.Mode = AddrModeFlat; + #else + #error "Cpptrace: StackWalk64 not supported for this platform yet" + #endif + ZeroMemory(&context, sizeof(CONTEXT)); + StackWalk64( + machine_type, + proc, + thread, + &frame, + machine_type == IMAGE_FILE_MACHINE_I386 ? NULL : &context, + NULL, + SymFunctionTableAccess64, + SymGetModuleBase64, + NULL + ); +} diff --git a/dep/cpptrace/cmake/has_unwind.cpp b/dep/cpptrace/cmake/has_unwind.cpp new file mode 100644 index 00000000000..503e306331c --- /dev/null +++ b/dep/cpptrace/cmake/has_unwind.cpp @@ -0,0 +1,14 @@ +#include + +#include + +_Unwind_Reason_Code unwind_callback(_Unwind_Context* context, void* arg) { + _Unwind_GetIP(context); + int is_before_instruction = 0; + uintptr_t ip = _Unwind_GetIPInfo(context, &is_before_instruction); + return _URC_END_OF_STACK; +} + +int main() { + _Unwind_Backtrace(unwind_callback, nullptr); +} diff --git a/dep/cpptrace/cmake/in/cpptrace-config-cmake.in b/dep/cpptrace/cmake/in/cpptrace-config-cmake.in new file mode 100644 index 00000000000..ef6e4a0adae --- /dev/null +++ b/dep/cpptrace/cmake/in/cpptrace-config-cmake.in @@ -0,0 +1,32 @@ +# Init @ variables before doing anything else +@PACKAGE_INIT@ + +# Dependencies +if(@CPPTRACE_GET_SYMBOLS_WITH_LIBDWARF@) + include(CMakeFindDependencyMacro) + # we don't go the Findzstd.cmake route on vcpkg + if(@CPPTRACE_VCPKG@) + find_dependency(zstd CONFIG REQUIRED) + else() + set(CMAKE_MODULE_PATH_OLD "${CMAKE_MODULE_PATH}") + set(CMAKE_MODULE_PATH "${CMAKE_MODULE_PATH};${CMAKE_CURRENT_LIST_DIR}") + find_dependency(zstd) + set(CMAKE_MODULE_PATH "${CMAKE_MODULE_PATH_OLD}") + unset(CMAKE_MODULE_PATH_OLD) + endif() + if(NOT @CPPTRACE_FIND_LIBDWARF_WITH_PKGCONFIG@) + find_dependency(libdwarf REQUIRED) + endif() +endif() + +# We cannot modify an existing IMPORT target +if(NOT TARGET cpptrace::cpptrace) + + # import targets + include("${CMAKE_CURRENT_LIST_DIR}/@package_name@-targets.cmake") + +endif() + +if(@CPPTRACE_STATIC_DEFINE@) + target_compile_definitions(cpptrace::cpptrace INTERFACE CPPTRACE_STATIC_DEFINE) +endif() diff --git a/dep/cpptrace/include/cpptrace/cpptrace.hpp b/dep/cpptrace/include/cpptrace/cpptrace.hpp new file mode 100644 index 00000000000..667510a0dbd --- /dev/null +++ b/dep/cpptrace/include/cpptrace/cpptrace.hpp @@ -0,0 +1,498 @@ +#ifndef CPPTRACE_HPP +#define CPPTRACE_HPP + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#ifdef _WIN32 +#define CPPTRACE_EXPORT_ATTR __declspec(dllexport) +#define CPPTRACE_IMPORT_ATTR __declspec(dllimport) +#else +#define CPPTRACE_EXPORT_ATTR __attribute__((visibility("default"))) +#define CPPTRACE_IMPORT_ATTR __attribute__((visibility("default"))) +#endif + +#ifdef CPPTRACE_STATIC_DEFINE +# define CPPTRACE_EXPORT +# define CPPTRACE_NO_EXPORT +#else +# ifndef CPPTRACE_EXPORT +# ifdef cpptrace_lib_EXPORTS + /* We are building this library */ +# define CPPTRACE_EXPORT CPPTRACE_EXPORT_ATTR +# else + /* We are using this library */ +# define CPPTRACE_EXPORT CPPTRACE_IMPORT_ATTR +# endif +# endif +#endif + +#ifndef CPPTRACE_NO_STD_FORMAT + #if __cplusplus >= 202002L + #ifdef __has_include + #if __has_include() + #define CPPTRACE_STD_FORMAT + #include + #endif + #endif + #endif +#endif + +#ifdef _MSC_VER + #define CPPTRACE_FORCE_NO_INLINE __declspec(noinline) +#else + #define CPPTRACE_FORCE_NO_INLINE __attribute__((noinline)) +#endif + +#ifdef _MSC_VER +#pragma warning(push) +// warning C4251: using non-dll-exported type in dll-exported type, firing on std::vector and others for some +// reason +// 4275 is the same thing but for base classes +#pragma warning(disable: 4251; disable: 4275) +#endif + +namespace cpptrace { + struct object_trace; + struct stacktrace; + + // Some type sufficient for an instruction pointer, currently always an alias to std::uintptr_t + using frame_ptr = std::uintptr_t; + + struct CPPTRACE_EXPORT raw_trace { + std::vector frames; + static raw_trace current(std::size_t skip = 0); + static raw_trace current(std::size_t skip, std::size_t max_depth); + object_trace resolve_object_trace() const; + stacktrace resolve() const; + void clear(); + bool empty() const noexcept; + + using iterator = std::vector::iterator; + using const_iterator = std::vector::const_iterator; + inline iterator begin() noexcept { return frames.begin(); } + inline iterator end() noexcept { return frames.end(); } + inline const_iterator begin() const noexcept { return frames.begin(); } + inline const_iterator end() const noexcept { return frames.end(); } + inline const_iterator cbegin() const noexcept { return frames.cbegin(); } + inline const_iterator cend() const noexcept { return frames.cend(); } + }; + + struct CPPTRACE_EXPORT object_frame { + frame_ptr raw_address; + frame_ptr object_address; + std::string object_path; + }; + + struct CPPTRACE_EXPORT object_trace { + std::vector frames; + static object_trace current(std::size_t skip = 0); + static object_trace current(std::size_t skip, std::size_t max_depth); + stacktrace resolve() const; + void clear(); + bool empty() const noexcept; + + using iterator = std::vector::iterator; + using const_iterator = std::vector::const_iterator; + inline iterator begin() noexcept { return frames.begin(); } + inline iterator end() noexcept { return frames.end(); } + inline const_iterator begin() const noexcept { return frames.begin(); } + inline const_iterator end() const noexcept { return frames.end(); } + inline const_iterator cbegin() const noexcept { return frames.cbegin(); } + inline const_iterator cend() const noexcept { return frames.cend(); } + }; + + // This represents a nullable integer type + // The max value of the type is used as a sentinel + // This is used over std::optional because the library is C++11 and also std::optional is a bit heavy-duty for this + // use. + template::value, int>::type = 0> + struct nullable { + T raw_value; + nullable& operator=(T value) { + raw_value = value; + return *this; + } + bool has_value() const noexcept { + return raw_value != (std::numeric_limits::max)(); + } + T& value() noexcept { + return raw_value; + } + const T& value() const noexcept { + return raw_value; + } + T value_or(T alternative) const noexcept { + return has_value() ? raw_value : alternative; + } + void swap(nullable& other) noexcept { + std::swap(raw_value, other.raw_value); + } + void reset() noexcept { + raw_value = (std::numeric_limits::max)(); + } + bool operator==(const nullable& other) const noexcept { + return raw_value == other.raw_value; + } + bool operator!=(const nullable& other) const noexcept { + return raw_value != other.raw_value; + } + constexpr static nullable null() noexcept { + return { (std::numeric_limits::max)() }; + } + }; + + struct CPPTRACE_EXPORT stacktrace_frame { + frame_ptr raw_address; + frame_ptr object_address; + nullable line; + nullable column; + std::string filename; + std::string symbol; + bool is_inline; + + bool operator==(const stacktrace_frame& other) const { + return raw_address == other.raw_address + && object_address == other.object_address + && line == other.line + && column == other.column + && filename == other.filename + && symbol == other.symbol; + } + + bool operator!=(const stacktrace_frame& other) const { + return !operator==(other); + } + + object_frame get_object_info() const; + + std::string to_string() const; + friend std::ostream& operator<<(std::ostream& stream, const stacktrace_frame& frame); + }; + + struct CPPTRACE_EXPORT stacktrace { + std::vector frames; + static stacktrace current(std::size_t skip = 0); + static stacktrace current(std::size_t skip, std::size_t max_depth); + void print() const; + void print(std::ostream& stream) const; + void print(std::ostream& stream, bool color) const; + void print_with_snippets() const; + void print_with_snippets(std::ostream& stream) const; + void print_with_snippets(std::ostream& stream, bool color) const; + void clear(); + bool empty() const noexcept; + std::string to_string(bool color = false) const; + friend std::ostream& operator<<(std::ostream& stream, const stacktrace& trace); + + using iterator = std::vector::iterator; + using const_iterator = std::vector::const_iterator; + inline iterator begin() noexcept { return frames.begin(); } + inline iterator end() noexcept { return frames.end(); } + inline const_iterator begin() const noexcept { return frames.begin(); } + inline const_iterator end() const noexcept { return frames.end(); } + inline const_iterator cbegin() const noexcept { return frames.cbegin(); } + inline const_iterator cend() const noexcept { return frames.cend(); } + private: + void print(std::ostream& stream, bool color, bool newline_at_end, const char* header) const; + void print_with_snippets(std::ostream& stream, bool color, bool newline_at_end, const char* header) const; + friend void print_terminate_trace(); + }; + + CPPTRACE_EXPORT raw_trace generate_raw_trace(std::size_t skip = 0); + CPPTRACE_EXPORT raw_trace generate_raw_trace(std::size_t skip, std::size_t max_depth); + CPPTRACE_EXPORT object_trace generate_object_trace(std::size_t skip = 0); + CPPTRACE_EXPORT object_trace generate_object_trace(std::size_t skip, std::size_t max_depth); + CPPTRACE_EXPORT stacktrace generate_trace(std::size_t skip = 0); + CPPTRACE_EXPORT stacktrace generate_trace(std::size_t skip, std::size_t max_depth); + + // Path max isn't so simple, so I'm choosing 4096 which seems to encompass what all major OS's expect and should be + // fine in all reasonable cases. + // https://eklitzke.org/path-max-is-tricky + // https://insanecoding.blogspot.com/2007/11/pathmax-simply-isnt.html + #define CPPTRACE_PATH_MAX 4096 + + // safe tracing interface + // signal-safe + CPPTRACE_EXPORT std::size_t safe_generate_raw_trace( + frame_ptr* buffer, + std::size_t size, + std::size_t skip = 0 + ); + // signal-safe + CPPTRACE_EXPORT std::size_t safe_generate_raw_trace( + frame_ptr* buffer, + std::size_t size, + std::size_t skip, + std::size_t max_depth + ); + struct CPPTRACE_EXPORT safe_object_frame { + frame_ptr raw_address; + // This ends up being the real object address. It was named at a time when I thought the object base address + // still needed to be added in + frame_ptr address_relative_to_object_start; + char object_path[CPPTRACE_PATH_MAX + 1]; + // To be called outside a signal handler. Not signal safe. + object_frame resolve() const; + }; + // signal-safe + CPPTRACE_EXPORT void get_safe_object_frame(frame_ptr address, safe_object_frame* out); + CPPTRACE_EXPORT bool can_signal_safe_unwind(); + + // utilities: + CPPTRACE_EXPORT std::string demangle(const std::string& name); + CPPTRACE_EXPORT std::string get_snippet( + const std::string& path, + std::size_t line, + std::size_t context_size, + bool color = false + ); + CPPTRACE_EXPORT bool isatty(int fd); + + CPPTRACE_EXPORT extern const int stdin_fileno; + CPPTRACE_EXPORT extern const int stderr_fileno; + CPPTRACE_EXPORT extern const int stdout_fileno; + + CPPTRACE_EXPORT void register_terminate_handler(); + + // configuration: + CPPTRACE_EXPORT void absorb_trace_exceptions(bool absorb); + CPPTRACE_EXPORT void enable_inlined_call_resolution(bool enable); + + enum class cache_mode { + // Only minimal lookup tables + prioritize_memory = 0, + // Build lookup tables but don't keep them around between trace calls + hybrid = 1, + // Build lookup tables as needed + prioritize_speed = 2 + }; + + namespace experimental { + CPPTRACE_EXPORT void set_cache_mode(cache_mode mode); + } + + // tracing exceptions: + namespace detail { + // This is a helper utility, if the library weren't C++11 an std::variant would be used + class CPPTRACE_EXPORT lazy_trace_holder { + bool resolved; + union { + raw_trace trace; + stacktrace resolved_trace; + }; + public: + // constructors + lazy_trace_holder() : resolved(false), trace() {} + explicit lazy_trace_holder(raw_trace&& _trace) : resolved(false), trace(std::move(_trace)) {} + explicit lazy_trace_holder(stacktrace&& _resolved_trace) : resolved(true), resolved_trace(std::move(_resolved_trace)) {} + // logistics + lazy_trace_holder(const lazy_trace_holder& other); + lazy_trace_holder(lazy_trace_holder&& other) noexcept; + lazy_trace_holder& operator=(const lazy_trace_holder& other); + lazy_trace_holder& operator=(lazy_trace_holder&& other) noexcept; + ~lazy_trace_holder(); + // access + const raw_trace& get_raw_trace() const; + stacktrace& get_resolved_trace(); + const stacktrace& get_resolved_trace() const; + private: + void clear(); + }; + + CPPTRACE_EXPORT raw_trace get_raw_trace_and_absorb(std::size_t skip, std::size_t max_depth); + CPPTRACE_EXPORT raw_trace get_raw_trace_and_absorb(std::size_t skip = 0); + } + + // Interface for a traced exception object + class CPPTRACE_EXPORT exception : public std::exception { + public: + const char* what() const noexcept override = 0; + virtual const char* message() const noexcept = 0; + virtual const stacktrace& trace() const noexcept = 0; + }; + + // Cpptrace traced exception object + // I hate to have to expose anything about implementation detail but the idea here is that + class CPPTRACE_EXPORT lazy_exception : public exception { + mutable detail::lazy_trace_holder trace_holder; + mutable std::string what_string; + + public: + explicit lazy_exception( + raw_trace&& trace = detail::get_raw_trace_and_absorb() + ) : trace_holder(std::move(trace)) {} + // std::exception + const char* what() const noexcept override; + // cpptrace::exception + const char* message() const noexcept override; + const stacktrace& trace() const noexcept override; + }; + + class CPPTRACE_EXPORT exception_with_message : public lazy_exception { + mutable std::string user_message; + + public: + explicit exception_with_message( + std::string&& message_arg, + raw_trace&& trace = detail::get_raw_trace_and_absorb() + ) noexcept : lazy_exception(std::move(trace)), user_message(std::move(message_arg)) {} + + const char* message() const noexcept override; + }; + + class CPPTRACE_EXPORT logic_error : public exception_with_message { + public: + explicit logic_error( + std::string&& message_arg, + raw_trace&& trace = detail::get_raw_trace_and_absorb() + ) noexcept + : exception_with_message(std::move(message_arg), std::move(trace)) {} + }; + + class CPPTRACE_EXPORT domain_error : public exception_with_message { + public: + explicit domain_error( + std::string&& message_arg, + raw_trace&& trace = detail::get_raw_trace_and_absorb() + ) noexcept + : exception_with_message(std::move(message_arg), std::move(trace)) {} + }; + + class CPPTRACE_EXPORT invalid_argument : public exception_with_message { + public: + explicit invalid_argument( + std::string&& message_arg, + raw_trace&& trace = detail::get_raw_trace_and_absorb() + ) noexcept + : exception_with_message(std::move(message_arg), std::move(trace)) {} + }; + + class CPPTRACE_EXPORT length_error : public exception_with_message { + public: + explicit length_error( + std::string&& message_arg, + raw_trace&& trace = detail::get_raw_trace_and_absorb() + ) noexcept + : exception_with_message(std::move(message_arg), std::move(trace)) {} + }; + + class CPPTRACE_EXPORT out_of_range : public exception_with_message { + public: + explicit out_of_range( + std::string&& message_arg, + raw_trace&& trace = detail::get_raw_trace_and_absorb() + ) noexcept + : exception_with_message(std::move(message_arg), std::move(trace)) {} + }; + + class CPPTRACE_EXPORT runtime_error : public exception_with_message { + public: + explicit runtime_error( + std::string&& message_arg, + raw_trace&& trace = detail::get_raw_trace_and_absorb() + ) noexcept + : exception_with_message(std::move(message_arg), std::move(trace)) {} + }; + + class CPPTRACE_EXPORT range_error : public exception_with_message { + public: + explicit range_error( + std::string&& message_arg, + raw_trace&& trace = detail::get_raw_trace_and_absorb() + ) noexcept + : exception_with_message(std::move(message_arg), std::move(trace)) {} + }; + + class CPPTRACE_EXPORT overflow_error : public exception_with_message { + public: + explicit overflow_error( + std::string&& message_arg, + raw_trace&& trace = detail::get_raw_trace_and_absorb() + ) noexcept + : exception_with_message(std::move(message_arg), std::move(trace)) {} + }; + + class CPPTRACE_EXPORT underflow_error : public exception_with_message { + public: + explicit underflow_error( + std::string&& message_arg, + raw_trace&& trace = detail::get_raw_trace_and_absorb() + ) noexcept + : exception_with_message(std::move(message_arg), std::move(trace)) {} + }; + + class CPPTRACE_EXPORT nested_exception : public lazy_exception { + std::exception_ptr ptr; + mutable std::string message_value; + public: + explicit nested_exception( + const std::exception_ptr& exception_ptr, + raw_trace&& trace = detail::get_raw_trace_and_absorb() + ) noexcept + : lazy_exception(std::move(trace)), ptr(exception_ptr) {} + + const char* message() const noexcept override; + std::exception_ptr nested_ptr() const noexcept; + }; + + class CPPTRACE_EXPORT system_error : public runtime_error { + std::error_code ec; + public: + explicit system_error( + int error_code, + std::string&& message_arg, + raw_trace&& trace = detail::get_raw_trace_and_absorb() + ) noexcept; + const std::error_code& code() const noexcept; + }; + + // [[noreturn]] must come first due to old clang + [[noreturn]] CPPTRACE_EXPORT void rethrow_and_wrap_if_needed(std::size_t skip = 0); +} + +#if defined(CPPTRACE_STD_FORMAT) && defined(__cpp_lib_format) + template <> + struct std::formatter : std::formatter { + auto format(cpptrace::stacktrace_frame frame, format_context& ctx) const { + return formatter::format(frame.to_string(), ctx); + } + }; + + template <> + struct std::formatter : std::formatter { + auto format(cpptrace::stacktrace trace, format_context& ctx) const { + return formatter::format(trace.to_string(), ctx); + } + }; +#endif + +// Exception wrapper utilities +#define CPPTRACE_WRAP_BLOCK(statements) do { \ + try { \ + statements \ + } catch(...) { \ + ::cpptrace::rethrow_and_wrap_if_needed(); \ + } \ + } while(0) + +#define CPPTRACE_WRAP(expression) [&] () -> decltype((expression)) { \ + try { \ + return expression; \ + } catch(...) { \ + ::cpptrace::rethrow_and_wrap_if_needed(1); \ + } \ + } () + +#ifdef _MSC_VER +#pragma warning(pop) +#endif + +#endif diff --git a/dep/cpptrace/include/cpptrace/from_current.hpp b/dep/cpptrace/include/cpptrace/from_current.hpp new file mode 100644 index 00000000000..91579992bb1 --- /dev/null +++ b/dep/cpptrace/include/cpptrace/from_current.hpp @@ -0,0 +1,131 @@ +#ifndef CPPTRACE_FROM_CURRENT_HPP +#define CPPTRACE_FROM_CURRENT_HPP + +#include + +namespace cpptrace { + CPPTRACE_EXPORT const raw_trace& raw_trace_from_current_exception(); + CPPTRACE_EXPORT const stacktrace& from_current_exception(); + + namespace detail { + // Trace switch is to prevent multiple tracing of stacks on call stacks with multiple catches that don't + // immediately match + inline bool& get_trace_switch() { + static thread_local bool trace_switch = true; + return trace_switch; + } + + class CPPTRACE_EXPORT try_canary { + public: + ~try_canary() { + // Fires when we exit a try block, either via normal means or during unwinding. + // Either way: Flip the switch. + get_trace_switch() = true; + } + }; + + CPPTRACE_EXPORT CPPTRACE_FORCE_NO_INLINE void collect_current_trace(std::size_t skip); + + // this function can be void, however, a char return is used to prevent TCO of the collect_current_trace + CPPTRACE_FORCE_NO_INLINE inline char exception_unwind_interceptor(std::size_t skip) { + if(get_trace_switch()) { + // Done during a search phase. Flip the switch off, no more traces until an unwind happens + get_trace_switch() = false; + collect_current_trace(skip + 1); + } + return 42; + } + + #ifdef _MSC_VER + CPPTRACE_FORCE_NO_INLINE inline int exception_filter() { + exception_unwind_interceptor(1); + return 0; // EXCEPTION_CONTINUE_SEARCH + } + CPPTRACE_FORCE_NO_INLINE inline int unconditional_exception_filter() { + collect_current_trace(1); + return 0; // EXCEPTION_CONTINUE_SEARCH + } + #else + class CPPTRACE_EXPORT unwind_interceptor { + public: + virtual ~unwind_interceptor(); + }; + class CPPTRACE_EXPORT unconditional_unwind_interceptor { + public: + virtual ~unconditional_unwind_interceptor(); + }; + + CPPTRACE_EXPORT void do_prepare_unwind_interceptor(char(*)(std::size_t)); + + #ifndef CPPTRACE_DONT_PREPARE_UNWIND_INTERCEPTOR_ON + __attribute__((constructor)) inline void prepare_unwind_interceptor() { + // __attribute__((constructor)) inline functions can be called for every source file they're #included in + // there is still only one copy of the inline function in the final executable, though + // LTO can make the redundant constructs fire only once + // do_prepare_unwind_interceptor prevents against multiple preparations however it makes sense to guard + // against it here too as a fast path, not that this should matter for performance + static bool did_prepare = false; + if(!did_prepare) { + do_prepare_unwind_interceptor(exception_unwind_interceptor); + did_prepare = true; + } + } + #endif + #endif + } +} + +#ifdef _MSC_VER + // this awful double-IILE is due to C2713 "You can't use structured exception handling (__try/__except) and C++ + // exception handling (try/catch) in the same function." + #define CPPTRACE_TRY \ + try { \ + ::cpptrace::detail::try_canary cpptrace_try_canary; \ + [&]() { \ + __try { \ + [&]() { + #define CPPTRACE_CATCH(param) \ + }(); \ + } __except(::cpptrace::detail::exception_filter()) {} \ + }(); \ + } catch(param) + #define CPPTRACE_TRYZ \ + try { \ + [&]() { \ + __try { \ + [&]() { + #define CPPTRACE_CATCHZ(param) \ + }(); \ + } __except(::cpptrace::detail::unconditional_exception_filter()) {} \ + }(); \ + } catch(param) +#else + #define CPPTRACE_TRY \ + try { \ + _Pragma("GCC diagnostic push") \ + _Pragma("GCC diagnostic ignored \"-Wshadow\"") \ + ::cpptrace::detail::try_canary cpptrace_try_canary; \ + _Pragma("GCC diagnostic pop") \ + try { + #define CPPTRACE_CATCH(param) \ + } catch(::cpptrace::detail::unwind_interceptor&) {} \ + } catch(param) + #define CPPTRACE_TRYZ \ + try { \ + try { + #define CPPTRACE_CATCHZ(param) \ + } catch(::cpptrace::detail::unconditional_unwind_interceptor&) {} \ + } catch(param) +#endif + +#define CPPTRACE_CATCH_ALT(param) catch(param) + +#ifdef CPPTRACE_UNPREFIXED_TRY_CATCH + #define TRY CPPTRACE_TRY + #define CATCH(param) CPPTRACE_CATCH(param) + #define TRYZ CPPTRACE_TRYZ + #define CATCHZ(param) CPPTRACE_CATCHZ(param) + #define CATCH_ALT(param) CPPTRACE_CATCH_ALT(param) +#endif + +#endif diff --git a/dep/cpptrace/include/ctrace/ctrace.h b/dep/cpptrace/include/ctrace/ctrace.h new file mode 100644 index 00000000000..e4924beee13 --- /dev/null +++ b/dep/cpptrace/include/ctrace/ctrace.h @@ -0,0 +1,163 @@ +#ifndef CTRACE_H +#define CTRACE_H + +#include +#include +#include + +#ifdef _WIN32 +#define CPPTRACE_EXPORT_ATTR __declspec(dllexport) +#define CPPTRACE_IMPORT_ATTR __declspec(dllimport) +#else +#define CPPTRACE_EXPORT_ATTR __attribute__((visibility("default"))) +#define CPPTRACE_IMPORT_ATTR __attribute__((visibility("default"))) +#endif + +#ifdef CPPTRACE_STATIC_DEFINE +# define CPPTRACE_EXPORT +# define CPPTRACE_NO_EXPORT +#else +# ifndef CPPTRACE_EXPORT +# ifdef cpptrace_lib_EXPORTS + /* We are building this library */ +# define CPPTRACE_EXPORT CPPTRACE_EXPORT_ATTR +# else + /* We are using this library */ +# define CPPTRACE_EXPORT CPPTRACE_IMPORT_ATTR +# endif +# endif +#endif + +#if defined(__cplusplus) + #define CTRACE_BEGIN_DEFINITIONS extern "C" { + #define CTRACE_END_DEFINITIONS } +#else + #define CTRACE_BEGIN_DEFINITIONS + #define CTRACE_END_DEFINITIONS +#endif + +#ifdef _MSC_VER + #define CTRACE_FORCE_NO_INLINE __declspec(noinline) +#else + #define CTRACE_FORCE_NO_INLINE __attribute__((noinline)) +#endif + +#ifdef _MSC_VER + #define CTRACE_FORCE_INLINE __forceinline +#elif defined(__clang__) || defined(__GNUC__) + #define CTRACE_FORCE_INLINE __attribute__((always_inline)) inline +#else + #define CTRACE_FORCE_INLINE inline +#endif + +/* See `CPPTRACE_PATH_MAX` for more info. */ +#define CTRACE_PATH_MAX 4096 + +CTRACE_BEGIN_DEFINITIONS + + typedef struct ctrace_raw_trace ctrace_raw_trace; + typedef struct ctrace_object_trace ctrace_object_trace; + typedef struct ctrace_stacktrace ctrace_stacktrace; + + /* Represents a boolean value, ensures a consistent ABI. */ + typedef int8_t ctrace_bool; + /* A type that can represent a pointer, alias for `uintptr_t`. */ + typedef uintptr_t ctrace_frame_ptr; + typedef struct ctrace_object_frame ctrace_object_frame; + typedef struct ctrace_stacktrace_frame ctrace_stacktrace_frame; + typedef struct ctrace_safe_object_frame ctrace_safe_object_frame; + + /* Type-safe null-terminated string wrapper */ + typedef struct { + const char* data; + } ctrace_owning_string; + + struct ctrace_object_frame { + ctrace_frame_ptr raw_address; + ctrace_frame_ptr obj_address; + const char* obj_path; + }; + + struct ctrace_stacktrace_frame { + ctrace_frame_ptr raw_address; + ctrace_frame_ptr object_address; + uint32_t line; + uint32_t column; + const char* filename; + const char* symbol; + ctrace_bool is_inline; + }; + + struct ctrace_safe_object_frame { + ctrace_frame_ptr raw_address; + ctrace_frame_ptr relative_obj_address; + char object_path[CTRACE_PATH_MAX + 1]; + }; + + struct ctrace_raw_trace { + ctrace_frame_ptr* frames; + size_t count; + }; + + struct ctrace_object_trace { + ctrace_object_frame* frames; + size_t count; + }; + + struct ctrace_stacktrace { + ctrace_stacktrace_frame* frames; + size_t count; + }; + + /* ctrace::string: */ + CPPTRACE_EXPORT ctrace_owning_string ctrace_generate_owning_string(const char* raw_string); + CPPTRACE_EXPORT void ctrace_free_owning_string(ctrace_owning_string* string); + + /* ctrace::generation: */ + CPPTRACE_EXPORT ctrace_raw_trace ctrace_generate_raw_trace(size_t skip, size_t max_depth); + CPPTRACE_EXPORT ctrace_object_trace ctrace_generate_object_trace(size_t skip, size_t max_depth); + CPPTRACE_EXPORT ctrace_stacktrace ctrace_generate_trace(size_t skip, size_t max_depth); + + /* ctrace::freeing: */ + CPPTRACE_EXPORT void ctrace_free_raw_trace(ctrace_raw_trace* trace); + CPPTRACE_EXPORT void ctrace_free_object_trace(ctrace_object_trace* trace); + CPPTRACE_EXPORT void ctrace_free_stacktrace(ctrace_stacktrace* trace); + + /* ctrace::resolve: */ + CPPTRACE_EXPORT ctrace_stacktrace ctrace_resolve_raw_trace(const ctrace_raw_trace* trace); + CPPTRACE_EXPORT ctrace_object_trace ctrace_resolve_raw_trace_to_object_trace(const ctrace_raw_trace* trace); + CPPTRACE_EXPORT ctrace_stacktrace ctrace_resolve_object_trace(const ctrace_object_trace* trace); + + /* ctrace::safe: */ + CPPTRACE_EXPORT size_t ctrace_safe_generate_raw_trace(ctrace_frame_ptr* buffer, size_t size, size_t skip, size_t max_depth); + CPPTRACE_EXPORT void ctrace_get_safe_object_frame(ctrace_frame_ptr address, ctrace_safe_object_frame* out); + CPPTRACE_EXPORT ctrace_bool can_signal_safe_unwind(void); + + /* ctrace::io: */ + CPPTRACE_EXPORT ctrace_owning_string ctrace_stacktrace_to_string(const ctrace_stacktrace* trace, ctrace_bool use_color); + CPPTRACE_EXPORT void ctrace_print_stacktrace(const ctrace_stacktrace* trace, FILE* to, ctrace_bool use_color); + + /* ctrace::utility: */ + CPPTRACE_EXPORT ctrace_owning_string ctrace_demangle(const char* mangled); + CPPTRACE_EXPORT int ctrace_stdin_fileno(void); + CPPTRACE_EXPORT int ctrace_stderr_fileno(void); + CPPTRACE_EXPORT int ctrace_stdout_fileno(void); + CPPTRACE_EXPORT ctrace_bool ctrace_isatty(int fd); + + CPPTRACE_EXPORT ctrace_object_frame ctrace_get_object_info(const ctrace_stacktrace_frame* frame); + + /* ctrace::config: */ + typedef enum { + /* Only minimal lookup tables */ + ctrace_prioritize_memory = 0, + /* Build lookup tables but don't keep them around between trace calls */ + ctrace_hybrid = 1, + /* Build lookup tables as needed */ + ctrace_prioritize_speed = 2 + } ctrace_cache_mode; + CPPTRACE_EXPORT void ctrace_set_cache_mode(ctrace_cache_mode mode); + CPPTRACE_EXPORT void ctrace_enable_inlined_call_resolution(ctrace_bool enable); + +CTRACE_END_DEFINITIONS + +#endif diff --git a/dep/cpptrace/src/binary/elf.cpp b/dep/cpptrace/src/binary/elf.cpp new file mode 100644 index 00000000000..27c2a2096e5 --- /dev/null +++ b/dep/cpptrace/src/binary/elf.cpp @@ -0,0 +1,100 @@ +#include "binary/elf.hpp" + +#if IS_LINUX + +#include +#include +#include +#include +#include + +#include + +namespace cpptrace { +namespace detail { + template::value, int>::type = 0> + T elf_byteswap_if_needed(T value, bool elf_is_little) { + if(is_little_endian() == elf_is_little) { + return value; + } else { + return byteswap(value); + } + } + + template + static Result elf_get_module_image_base_from_program_table( + const std::string& object_path, + std::FILE* file, + bool is_little_endian + ) { + static_assert(Bits == 32 || Bits == 64, "Unexpected Bits argument"); + using Header = typename std::conditional::type; + using PHeader = typename std::conditional::type; + auto loaded_header = load_bytes
(file, 0); + if(loaded_header.is_error()) { + return std::move(loaded_header).unwrap_error(); + } + const Header& file_header = loaded_header.unwrap_value(); + if(file_header.e_ehsize != sizeof(Header)) { + return internal_error("ELF file header size mismatch" + object_path); + } + // PT_PHDR will occur at most once + // Should be somewhat reliable https://stackoverflow.com/q/61568612/15675011 + // It should occur at the beginning but may as well loop just in case + for(int i = 0; i < file_header.e_phnum; i++) { + auto loaded_ph = load_bytes(file, file_header.e_phoff + file_header.e_phentsize * i); + if(loaded_ph.is_error()) { + return std::move(loaded_ph).unwrap_error(); + } + const PHeader& program_header = loaded_ph.unwrap_value(); + if(elf_byteswap_if_needed(program_header.p_type, is_little_endian) == PT_PHDR) { + return elf_byteswap_if_needed(program_header.p_vaddr, is_little_endian) - + elf_byteswap_if_needed(program_header.p_offset, is_little_endian); + } + } + // Apparently some objects like shared objects can end up missing this header. 0 as a base seems correct. + return 0; + } + + Result elf_get_module_image_base(const std::string& object_path) { + auto file = raii_wrap(std::fopen(object_path.c_str(), "rb"), file_deleter); + if(file == nullptr) { + return internal_error("Unable to read object file " + object_path); + } + // Initial checks/metadata + auto magic = load_bytes>(file, 0); + if(magic.is_error()) { + return std::move(magic).unwrap_error(); + } + if(magic.unwrap_value() != (std::array{0x7F, 'E', 'L', 'F'})) { + return internal_error("File is not ELF " + object_path); + } + auto ei_class = load_bytes(file, 4); + if(ei_class.is_error()) { + return std::move(ei_class).unwrap_error(); + } + bool is_64 = ei_class.unwrap_value() == 2; + auto ei_data = load_bytes(file, 5); + if(ei_data.is_error()) { + return std::move(ei_data).unwrap_error(); + } + bool is_little_endian = ei_data.unwrap_value() == 1; + auto ei_version = load_bytes(file, 6); + if(ei_version.is_error()) { + return std::move(ei_version).unwrap_error(); + } + if(ei_version.unwrap_value() != 1) { + return internal_error("Unexpected ELF version " + object_path); + } + // get image base + if(is_64) { + return elf_get_module_image_base_from_program_table<64>(object_path, file, is_little_endian); + } else { + return elf_get_module_image_base_from_program_table<32>(object_path, file, is_little_endian); + } + } + +} +} + +#endif diff --git a/dep/cpptrace/src/binary/elf.hpp b/dep/cpptrace/src/binary/elf.hpp new file mode 100644 index 00000000000..f6387569683 --- /dev/null +++ b/dep/cpptrace/src/binary/elf.hpp @@ -0,0 +1,20 @@ +#ifndef ELF_HPP +#define ELF_HPP + +#include "utils/common.hpp" +#include "utils/utils.hpp" + +#if IS_LINUX + +#include +#include + +namespace cpptrace { +namespace detail { + Result elf_get_module_image_base(const std::string& object_path); +} +} + +#endif + +#endif diff --git a/dep/cpptrace/src/binary/mach-o.cpp b/dep/cpptrace/src/binary/mach-o.cpp new file mode 100644 index 00000000000..7a8cda0e304 --- /dev/null +++ b/dep/cpptrace/src/binary/mach-o.cpp @@ -0,0 +1,641 @@ +#include "binary/mach-o.hpp" + +#include "utils/common.hpp" +#include "utils/utils.hpp" + +#if IS_APPLE + +// A number of mach-o functions are deprecated as of macos 13 +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wdeprecated-declarations" + +#include +#include +#include +#include +#include +#include + +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +namespace cpptrace { +namespace detail { + bool is_mach_o(std::uint32_t magic) { + switch(magic) { + case FAT_MAGIC: + case FAT_CIGAM: + case MH_MAGIC: + case MH_CIGAM: + case MH_MAGIC_64: + case MH_CIGAM_64: + return true; + default: + return false; + } + } + + bool file_is_mach_o(const std::string& object_path) noexcept { + auto file = raii_wrap(std::fopen(object_path.c_str(), "rb"), file_deleter); + if(file == nullptr) { + return false; + } + auto magic = load_bytes(file, 0); + if(magic) { + return is_mach_o(magic.unwrap_value()); + } else { + return false; + } + } + + bool is_fat_magic(std::uint32_t magic) { + return magic == FAT_MAGIC || magic == FAT_CIGAM; + } + + // Based on https://github.com/AlexDenisov/segment_dumper/blob/master/main.c + // and https://lowlevelbits.org/parsing-mach-o-files/ + bool is_magic_64(std::uint32_t magic) { + return magic == MH_MAGIC_64 || magic == MH_CIGAM_64; + } + + bool should_swap_bytes(std::uint32_t magic) { + return magic == MH_CIGAM || magic == MH_CIGAM_64 || magic == FAT_CIGAM; + } + + void swap_mach_header(mach_header_64& header) { + swap_mach_header_64(&header, NX_UnknownByteOrder); + } + + void swap_mach_header(mach_header& header) { + swap_mach_header(&header, NX_UnknownByteOrder); + } + + void swap_segment_command(segment_command_64& segment) { + swap_segment_command_64(&segment, NX_UnknownByteOrder); + } + + void swap_segment_command(segment_command& segment) { + swap_segment_command(&segment, NX_UnknownByteOrder); + } + + void swap_nlist(struct nlist& entry) { + swap_nlist(&entry, 1, NX_UnknownByteOrder); + } + + void swap_nlist(struct nlist_64& entry) { + swap_nlist_64(&entry, 1, NX_UnknownByteOrder); + } + + #ifdef __LP64__ + #define LP(x) x##_64 + #else + #define LP(x) x + #endif + + Result mach_o::symtab_info_data::get_string(std::size_t index) const { + if(stringtab && index < symtab.strsize) { + return stringtab.get() + index; + } else { + return internal_error("can't retrieve symbol from symtab"); + } + } + + Result mach_o::load() { + if(magic == FAT_MAGIC || magic == FAT_CIGAM) { + return load_fat_mach(); + } else { + fat_index = 0; + if(is_magic_64(magic)) { + return load_mach<64>(); + } else { + return load_mach<32>(); + } + } + } + + Result mach_o::open_mach_o(const std::string& object_path) { + auto file = raii_wrap(std::fopen(object_path.c_str(), "rb"), file_deleter); + if(file == nullptr) { + return internal_error("Unable to read object file {}", object_path); + } + auto magic = load_bytes(file, 0); + if(!magic) { + return magic.unwrap_error(); + } + if(!is_mach_o(magic.unwrap_value())) { + return internal_error("File is not mach-o {}", object_path); + } + mach_o obj(std::move(file), object_path, magic.unwrap_value()); + auto result = obj.load(); + if(result.is_error()) { + return result.unwrap_error(); + } else { + return obj; + } + } + + Result mach_o::get_text_vmaddr() { + for(const auto& command : load_commands) { + if(command.cmd == LC_SEGMENT_64 || command.cmd == LC_SEGMENT) { + auto segment = command.cmd == LC_SEGMENT_64 + ? load_segment_command<64>(command.file_offset) + : load_segment_command<32>(command.file_offset); + if(segment.is_error()) { + return std::move(segment).unwrap_error(); + } + if(std::strcmp(segment.unwrap_value().segname, "__TEXT") == 0) { + return segment.unwrap_value().vmaddr; + } + } + } + // somehow no __TEXT section was found... + return internal_error("Couldn't find __TEXT section while parsing Mach-O object"); + } + + std::size_t mach_o::get_fat_index() const { + VERIFY(fat_index != std::numeric_limits::max()); + return fat_index; + } + + void mach_o::print_segments() const { + int i = 0; + for(const auto& command : load_commands) { + if(command.cmd == LC_SEGMENT_64 || command.cmd == LC_SEGMENT) { + auto segment_load = command.cmd == LC_SEGMENT_64 + ? load_segment_command<64>(command.file_offset) + : load_segment_command<32>(command.file_offset); + fprintf(stderr, "Load command %d\n", i); + if(segment_load.is_error()) { + fprintf(stderr, " error\n"); + segment_load.drop_error(); + continue; + } + auto& segment = segment_load.unwrap_value(); + fprintf(stderr, " cmd %u\n", segment.cmd); + fprintf(stderr, " cmdsize %u\n", segment.cmdsize); + fprintf(stderr, " segname %s\n", segment.segname); + fprintf(stderr, " vmaddr 0x%llx\n", segment.vmaddr); + fprintf(stderr, " vmsize 0x%llx\n", segment.vmsize); + fprintf(stderr, " off 0x%llx\n", segment.fileoff); + fprintf(stderr, " filesize %llu\n", segment.filesize); + fprintf(stderr, " nsects %u\n", segment.nsects); + } + i++; + } + } + + Result>, internal_error> mach_o::get_symtab_info() { + if(!symtab_info.has_value() && !tried_to_load_symtab) { + // don't try to load the symtab again if for some reason loading here fails + tried_to_load_symtab = true; + for(const auto& command : load_commands) { + if(command.cmd == LC_SYMTAB) { + symtab_info_data info; + auto symtab = load_symbol_table_command(command.file_offset); + if(!symtab) { + return std::move(symtab).unwrap_error(); + } + info.symtab = symtab.unwrap_value(); + auto string = load_string_table(info.symtab.stroff, info.symtab.strsize); + if(!string) { + return std::move(string).unwrap_error(); + } + info.stringtab = std::move(string).unwrap_value(); + symtab_info = std::move(info); + break; + } + } + } + return std::reference_wrapper>{symtab_info}; + } + + void mach_o::print_symbol_table_entry( + const nlist_64& entry, + const std::unique_ptr& stringtab, + std::size_t stringsize, + std::size_t j + ) const { + const char* type = ""; + if(entry.n_type & N_STAB) { + switch(entry.n_type) { + case N_SO: type = "N_SO"; break; + case N_OSO: type = "N_OSO"; break; + case N_BNSYM: type = "N_BNSYM"; break; + case N_ENSYM: type = "N_ENSYM"; break; + case N_FUN: type = "N_FUN"; break; + } + } else if((entry.n_type & N_TYPE) == N_SECT) { + type = "N_SECT"; + } + fprintf( + stderr, + "%5llu %8llx %2llx %7s %2llu %4llx %16llx %s\n", + to_ull(j), + to_ull(entry.n_un.n_strx), + to_ull(entry.n_type), + type, + to_ull(entry.n_sect), + to_ull(entry.n_desc), + to_ull(entry.n_value), + stringtab == nullptr + ? "Stringtab error" + : entry.n_un.n_strx < stringsize + ? stringtab.get() + entry.n_un.n_strx + : "String index out of bounds" + ); + } + + void mach_o::print_symbol_table() { + int i = 0; + for(const auto& command : load_commands) { + if(command.cmd == LC_SYMTAB) { + auto symtab_load = load_symbol_table_command(command.file_offset); + fprintf(stderr, "Load command %d\n", i); + if(symtab_load.is_error()) { + fprintf(stderr, " error\n"); + symtab_load.drop_error(); + continue; + } + auto& symtab = symtab_load.unwrap_value(); + fprintf(stderr, " cmd %llu\n", to_ull(symtab.cmd)); + fprintf(stderr, " cmdsize %llu\n", to_ull(symtab.cmdsize)); + fprintf(stderr, " symoff 0x%llu\n", to_ull(symtab.symoff)); + fprintf(stderr, " nsyms %llu\n", to_ull(symtab.nsyms)); + fprintf(stderr, " stroff 0x%llu\n", to_ull(symtab.stroff)); + fprintf(stderr, " strsize %llu\n", to_ull(symtab.strsize)); + auto stringtab = load_string_table(symtab.stroff, symtab.strsize); + if(!stringtab) { + stringtab.drop_error(); + } + for(std::size_t j = 0; j < symtab.nsyms; j++) { + auto entry = bits == 32 + ? load_symtab_entry<32>(symtab.symoff, j) + : load_symtab_entry<64>(symtab.symoff, j); + if(!entry) { + fprintf(stderr, "error loading symtab entry\n"); + entry.drop_error(); + continue; + } + print_symbol_table_entry( + entry.unwrap_value(), + std::move(stringtab).value_or(std::unique_ptr(nullptr)), + symtab.strsize, + j + ); + } + } + i++; + } + } + + // produce information similar to dsymutil -dump-debug-map + Result mach_o::get_debug_map() { + // we have a bunch of symbols in our binary we need to pair up with symbols from various .o files + // first collect symbols and the objects they come from + debug_map debug_map; + auto symtab_info_res = get_symtab_info(); + if(!symtab_info_res) { + return std::move(symtab_info_res).unwrap_error(); + } + if(!symtab_info_res.unwrap_value().get()) { + return internal_error("No symtab info"); + } + const auto& symtab_info = symtab_info_res.unwrap_value().get().unwrap(); + const auto& symtab = symtab_info.symtab; + // TODO: Take timestamp into account? + std::string current_module; + optional current_function; + for(std::size_t j = 0; j < symtab.nsyms; j++) { + auto load_entry = bits == 32 + ? load_symtab_entry<32>(symtab.symoff, j) + : load_symtab_entry<64>(symtab.symoff, j); + if(!load_entry) { + return std::move(load_entry).unwrap_error(); + } + auto& entry = load_entry.unwrap_value(); + // entry.n_type & N_STAB indicates symbolic debug info + if(!(entry.n_type & N_STAB)) { + continue; + } + switch(entry.n_type) { + case N_SO: + // pass - these encode path and filename for the module, if applicable + break; + case N_OSO: + { + // sets the module + auto str = symtab_info.get_string(entry.n_un.n_strx); + if(!str) { + return std::move(str).unwrap_error(); + } + current_module = str.unwrap_value(); + } + break; + case N_BNSYM: break; // pass + case N_ENSYM: break; // pass + case N_FUN: + { + auto str = symtab_info.get_string(entry.n_un.n_strx); + if(!str) { + return std::move(str).unwrap_error(); + } + if(str.unwrap_value()[0] == 0) { + // end of function scope + if(!current_function) { /**/ } + current_function.unwrap().size = entry.n_value; + debug_map[current_module].push_back(std::move(current_function).unwrap()); + } else { + current_function = debug_map_entry{}; + current_function.unwrap().source_address = entry.n_value; + current_function.unwrap().name = str.unwrap_value(); + } + } + break; + } + } + return debug_map; + } + + Result, internal_error> mach_o::symbol_table() { + // we have a bunch of symbols in our binary we need to pair up with symbols from various .o files + // first collect symbols and the objects they come from + std::vector symbols; + auto symtab_info_res = get_symtab_info(); + if(!symtab_info_res) { + return std::move(symtab_info_res).unwrap_error(); + } + if(!symtab_info_res.unwrap_value().get()) { + return internal_error("No symtab info"); + } + const auto& symtab_info = symtab_info_res.unwrap_value().get().unwrap(); + const auto& symtab = symtab_info.symtab; + // TODO: Take timestamp into account? + for(std::size_t j = 0; j < symtab.nsyms; j++) { + auto load_entry = bits == 32 + ? load_symtab_entry<32>(symtab.symoff, j) + : load_symtab_entry<64>(symtab.symoff, j); + if(!load_entry) { + return std::move(load_entry).unwrap_error(); + } + auto& entry = load_entry.unwrap_value(); + if(entry.n_type & N_STAB) { + continue; + } + if((entry.n_type & N_TYPE) == N_SECT) { + auto str = symtab_info.get_string(entry.n_un.n_strx); + if(!str) { + return std::move(str).unwrap_error(); + } + symbols.push_back({ + entry.n_value, + str.unwrap_value() + }); + } + } + return symbols; + } + + // produce information similar to dsymutil -dump-debug-map + void mach_o::print_debug_map(const debug_map& debug_map) { + for(const auto& entry : debug_map) { + std::cout< + Result mach_o::load_mach() { + static_assert(Bits == 32 || Bits == 64, "Unexpected Bits argument"); + bits = Bits; + using Mach_Header = typename std::conditional::type; + std::size_t header_size = sizeof(Mach_Header); + auto load_header = load_bytes(file, load_base); + if(!load_header) { + return load_header.unwrap_error(); + } + Mach_Header& header = load_header.unwrap_value(); + magic = header.magic; + if(should_swap()) { + swap_mach_header(header); + } + cputype = header.cputype; + cpusubtype = header.cpusubtype; + filetype = header.filetype; + n_load_commands = header.ncmds; + sizeof_load_commands = header.sizeofcmds; + flags = header.flags; + // handle load commands + std::uint32_t ncmds = header.ncmds; + std::uint32_t load_commands_offset = load_base + header_size; + // iterate load commands + std::uint32_t actual_offset = load_commands_offset; + for(std::uint32_t i = 0; i < ncmds; i++) { + auto load_cmd = load_bytes(file, actual_offset); + if(!load_cmd) { + return load_cmd.unwrap_error(); + } + load_command& cmd = load_cmd.unwrap_value(); + if(should_swap()) { + swap_load_command(&cmd, NX_UnknownByteOrder); + } + load_commands.push_back({ actual_offset, cmd.cmd, cmd.cmdsize }); + actual_offset += cmd.cmdsize; + } + return monostate{}; + } + + Result mach_o::load_fat_mach() { + std::size_t header_size = sizeof(fat_header); + std::size_t arch_size = sizeof(fat_arch); + auto load_header = load_bytes(file, 0); + if(!load_header) { + return load_header.unwrap_error(); + } + fat_header& header = load_header.unwrap_value(); + if(should_swap()) { + swap_fat_header(&header, NX_UnknownByteOrder); + } + // thread_local static struct LP(mach_header)* mhp = _NSGetMachExecuteHeader(); + // off_t arch_offset = (off_t)header_size; + // for(std::size_t i = 0; i < header.nfat_arch; i++) { + // fat_arch arch = load_bytes(file, arch_offset); + // if(should_swap()) { + // swap_fat_arch(&arch, 1, NX_UnknownByteOrder); + // } + // off_t mach_header_offset = (off_t)arch.offset; + // arch_offset += arch_size; + // std::uint32_t magic = load_bytes(file, mach_header_offset); + // std::cerr<<"xxx: "<cputype<(mhp->cpusubtype & ~CPU_SUBTYPE_MASK)<cputype && + // static_cast(mhp->cpusubtype & ~CPU_SUBTYPE_MASK) == arch.cpusubtype + // ) { + // load_base = mach_header_offset; + // fat_index = i; + // if(is_magic_64(magic)) { + // load_mach<64>(true); + // } else { + // load_mach<32>(true); + // } + // return; + // } + // } + std::vector fat_arches; + fat_arches.reserve(header.nfat_arch); + off_t arch_offset = (off_t)header_size; + for(std::size_t i = 0; i < header.nfat_arch; i++) { + auto load_arch = load_bytes(file, arch_offset); + if(!load_arch) { + return load_arch.unwrap_error(); + } + fat_arch& arch = load_arch.unwrap_value(); + if(should_swap()) { + swap_fat_arch(&arch, 1, NX_UnknownByteOrder); + } + fat_arches.push_back(arch); + arch_offset += arch_size; + } + thread_local static struct LP(mach_header)* mhp = _NSGetMachExecuteHeader(); + fat_arch* best = NXFindBestFatArch( + mhp->cputype, + mhp->cpusubtype, + fat_arches.data(), + header.nfat_arch + ); + if(best) { + off_t mach_header_offset = (off_t)best->offset; + auto magic = load_bytes(file, mach_header_offset); + if(!magic) { + return magic.unwrap_error(); + } + load_base = mach_header_offset; + fat_index = best - fat_arches.data(); + if(is_magic_64(magic.unwrap_value())) { + load_mach<64>(); + } else { + load_mach<32>(); + } + return monostate{}; + } + // If this is reached... something went wrong. The cpu we're on wasn't found. + return internal_error("Couldn't find appropriate architecture in fat Mach-O"); + } + + template + Result mach_o::load_segment_command(std::uint32_t offset) const { + using Segment_Command = typename std::conditional::type; + auto load_segment = load_bytes(file, offset); + if(!load_segment) { + return load_segment.unwrap_error(); + } + Segment_Command& segment = load_segment.unwrap_value(); + ASSERT(segment.cmd == LC_SEGMENT_64 || segment.cmd == LC_SEGMENT); + if(should_swap()) { + swap_segment_command(segment); + } + // fields match just u64 instead of u32 + segment_command_64 common; + common.cmd = segment.cmd; + common.cmdsize = segment.cmdsize; + static_assert(sizeof common.segname == 16 && sizeof segment.segname == 16, "xx"); + memcpy(common.segname, segment.segname, 16); + common.vmaddr = segment.vmaddr; + common.vmsize = segment.vmsize; + common.fileoff = segment.fileoff; + common.filesize = segment.filesize; + common.maxprot = segment.maxprot; + common.initprot = segment.initprot; + common.nsects = segment.nsects; + common.flags = segment.flags; + return common; + } + + Result mach_o::load_symbol_table_command(std::uint32_t offset) const { + auto load_symtab = load_bytes(file, offset); + if(!load_symtab) { + return load_symtab.unwrap_error(); + } + symtab_command& symtab = load_symtab.unwrap_value(); + ASSERT(symtab.cmd == LC_SYMTAB); + if(should_swap()) { + swap_symtab_command(&symtab, NX_UnknownByteOrder); + } + return symtab; + } + + template + Result mach_o::load_symtab_entry(std::uint32_t symbol_base, std::size_t index) const { + using Nlist = typename std::conditional::type; + uint32_t offset = load_base + symbol_base + index * sizeof(Nlist); + auto load_entry = load_bytes(file, offset); + if(!load_entry) { + return load_entry.unwrap_error(); + } + Nlist& entry = load_entry.unwrap_value(); + if(should_swap()) { + swap_nlist(entry); + } + // fields match just u64 instead of u32 + nlist_64 common; + common.n_un.n_strx = entry.n_un.n_strx; + common.n_type = entry.n_type; + common.n_sect = entry.n_sect; + common.n_desc = entry.n_desc; + common.n_value = entry.n_value; + return common; + } + + Result, internal_error> mach_o::load_string_table(std::uint32_t offset, std::uint32_t byte_count) const { + std::unique_ptr buffer(new char[byte_count + 1]); + if(std::fseek(file, load_base + offset, SEEK_SET) != 0) { + return internal_error("fseek error while loading mach-o symbol table"); + } + if(std::fread(buffer.get(), sizeof(char), byte_count, file) != byte_count) { + return internal_error("fread error while loading mach-o symbol table"); + } + buffer[byte_count] = 0; // just out of an abundance of caution + return buffer; + } + + bool mach_o::should_swap() const { + return should_swap_bytes(magic); + } + + Result macho_is_fat(const std::string& object_path) { + auto file = raii_wrap(std::fopen(object_path.c_str(), "rb"), file_deleter); + if(file == nullptr) { + return internal_error("Unable to read object file {}", object_path); + } + auto magic = load_bytes(file, 0); + if(!magic) { + return magic.unwrap_error(); + } else { + return is_fat_magic(magic.unwrap_value()); + } + } +} +} + +#pragma GCC diagnostic pop + +#endif diff --git a/dep/cpptrace/src/binary/mach-o.hpp b/dep/cpptrace/src/binary/mach-o.hpp new file mode 100644 index 00000000000..62a88620dc2 --- /dev/null +++ b/dep/cpptrace/src/binary/mach-o.hpp @@ -0,0 +1,137 @@ +#ifndef MACHO_HPP +#define MACHO_HPP + +#include "utils/common.hpp" +#include "utils/utils.hpp" + +#if IS_APPLE + +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +namespace cpptrace { +namespace detail { + bool file_is_mach_o(const std::string& object_path) noexcept; + + struct load_command_entry { + std::uint32_t file_offset; + std::uint32_t cmd; + std::uint32_t cmdsize; + }; + + class mach_o { + file_wrapper file; + std::string object_path; + std::uint32_t magic; + cpu_type_t cputype; + cpu_subtype_t cpusubtype; + std::uint32_t filetype; + std::uint32_t n_load_commands; + std::uint32_t sizeof_load_commands; + std::uint32_t flags; + std::size_t bits = 0; // 32 or 64 once load_mach is called + + std::size_t load_base = 0; + std::size_t fat_index = std::numeric_limits::max(); + + std::vector load_commands; + + struct symtab_info_data { + symtab_command symtab; + std::unique_ptr stringtab; + Result get_string(std::size_t index) const; + }; + + bool tried_to_load_symtab = false; + optional symtab_info; + + mach_o( + file_wrapper file, + const std::string& object_path, + std::uint32_t magic + ) : + file(std::move(file)), + object_path(object_path), + magic(magic) {} + + Result load(); + + public: + static NODISCARD Result open_mach_o(const std::string& object_path); + + mach_o(mach_o&&) = default; + ~mach_o() = default; + + Result get_text_vmaddr(); + + std::size_t get_fat_index() const; + + void print_segments() const; + + Result>, internal_error> get_symtab_info(); + + void print_symbol_table_entry( + const nlist_64& entry, + const std::unique_ptr& stringtab, + std::size_t stringsize, + std::size_t j + ) const; + + void print_symbol_table(); + + struct debug_map_entry { + uint64_t source_address; + uint64_t size; + std::string name; + }; + + struct symbol_entry { + uint64_t address; + std::string name; + }; + + // map from object file to a vector of symbols to resolve + using debug_map = std::unordered_map>; + + // produce information similar to dsymutil -dump-debug-map + Result get_debug_map(); + + Result, internal_error> symbol_table(); + + // produce information similar to dsymutil -dump-debug-map + static void print_debug_map(const debug_map& debug_map); + + private: + template + Result load_mach(); + + Result load_fat_mach(); + + template + Result load_segment_command(std::uint32_t offset) const; + + Result load_symbol_table_command(std::uint32_t offset) const; + + template + Result load_symtab_entry(std::uint32_t symbol_base, std::size_t index) const; + + Result, internal_error> load_string_table(std::uint32_t offset, std::uint32_t byte_count) const; + + bool should_swap() const; + }; + + Result macho_is_fat(const std::string& object_path); +} +} + +#endif + +#endif diff --git a/dep/cpptrace/src/binary/module_base.cpp b/dep/cpptrace/src/binary/module_base.cpp new file mode 100644 index 00000000000..12e2f7c2aaf --- /dev/null +++ b/dep/cpptrace/src/binary/module_base.cpp @@ -0,0 +1,95 @@ +#include "binary/module_base.hpp" + +#include "utils/common.hpp" +#include "utils/utils.hpp" + +#include +#include +#include +#include + +#if IS_LINUX || IS_APPLE + #include + #include + #if IS_APPLE + #include "binary/mach-o.hpp" + #else + #include "binary/elf.hpp" + #endif +#elif IS_WINDOWS + #include + #include "binary/pe.hpp" +#endif + +namespace cpptrace { +namespace detail { + #if IS_LINUX + Result get_module_image_base(const std::string& object_path) { + static std::mutex mutex; + std::lock_guard lock(mutex); + static std::unordered_map cache; + auto it = cache.find(object_path); + if(it == cache.end()) { + // arguably it'd be better to release the lock while computing this, but also arguably it's good to not + // have two threads try to do the same computation + auto base = elf_get_module_image_base(object_path); + // TODO: Cache the error + if(base.is_error()) { + return std::move(base).unwrap_error(); + } + cache.insert(it, {object_path, base.unwrap_value()}); + return base; + } else { + return it->second; + } + } + #elif IS_APPLE + Result get_module_image_base(const std::string& object_path) { + // We have to parse the Mach-O to find the offset of the text section..... + // I don't know how addresses are handled if there is more than one __TEXT load command. I'm assuming for + // now that there is only one, and I'm using only the first section entry within that load command. + static std::mutex mutex; + std::lock_guard lock(mutex); + static std::unordered_map cache; + auto it = cache.find(object_path); + if(it == cache.end()) { + // arguably it'd be better to release the lock while computing this, but also arguably it's good to not + // have two threads try to do the same computation + auto obj = mach_o::open_mach_o(object_path); + // TODO: Cache the error + if(!obj) { + return obj.unwrap_error(); + } + auto base = obj.unwrap_value().get_text_vmaddr(); + if(!base) { + return std::move(base).unwrap_error(); + } + cache.insert(it, {object_path, base.unwrap_value()}); + return base; + } else { + return it->second; + } + } + #else // Windows + Result get_module_image_base(const std::string& object_path) { + static std::mutex mutex; + std::lock_guard lock(mutex); + static std::unordered_map cache; + auto it = cache.find(object_path); + if(it == cache.end()) { + // arguably it'd be better to release the lock while computing this, but also arguably it's good to not + // have two threads try to do the same computation + auto base = pe_get_module_image_base(object_path); + // TODO: Cache the error + if(!base) { + return std::move(base).unwrap_error(); + } + cache.insert(it, {object_path, base.unwrap_value()}); + return base; + } else { + return it->second; + } + } + #endif +} +} diff --git a/dep/cpptrace/src/binary/module_base.hpp b/dep/cpptrace/src/binary/module_base.hpp new file mode 100644 index 00000000000..5053be0632f --- /dev/null +++ b/dep/cpptrace/src/binary/module_base.hpp @@ -0,0 +1,16 @@ +#ifndef IMAGE_MODULE_BASE_HPP +#define IMAGE_MODULE_BASE_HPP + +#include "utils/common.hpp" +#include "utils/utils.hpp" + +#include +#include + +namespace cpptrace { +namespace detail { + Result get_module_image_base(const std::string& object_path); +} +} + +#endif diff --git a/dep/cpptrace/src/binary/object.cpp b/dep/cpptrace/src/binary/object.cpp new file mode 100644 index 00000000000..d0849576589 --- /dev/null +++ b/dep/cpptrace/src/binary/object.cpp @@ -0,0 +1,179 @@ +#include "binary/object.hpp" + +#include "utils/common.hpp" +#include "utils/utils.hpp" +#include "binary/module_base.hpp" + +#include +#include +#include +#include + +#if IS_LINUX || IS_APPLE + #include + #include + #if IS_LINUX + #include // needed for dladdr1's link_map info + #endif +#elif IS_WINDOWS + #include +#endif + +namespace cpptrace { +namespace detail { + #if IS_LINUX || IS_APPLE + #if defined(CPPTRACE_HAS_DL_FIND_OBJECT) || defined(CPPTRACE_HAS_DLADDR1) + std::string resolve_l_name(const char* l_name) { + if(l_name != nullptr && l_name[0] != 0) { + return l_name; + } else { + // empty l_name, this means it's the currently running executable + // TODO: Caching and proper handling + char buffer[CPPTRACE_PATH_MAX + 1]{}; + auto res = readlink("/proc/self/exe", buffer, CPPTRACE_PATH_MAX); + if(res == -1) { + return ""; // TODO + } else { + return buffer; + } + } + } + #endif + // dladdr queries are needed to get pre-ASLR addresses and targets to run symbol resolution on + // _dl_find_object is preferred if at all possible as it is much faster (added in glibc 2.35) + // dladdr1 is preferred if possible because it allows for a more accurate object path to be resolved (glibc 2.3.3) + #ifdef CPPTRACE_HAS_DL_FIND_OBJECT // we don't even check for this on apple + object_frame get_frame_object_info(frame_ptr address) { + // Use _dl_find_object when we can, it's orders of magnitude faster + object_frame frame; + frame.raw_address = address; + frame.object_address = 0; + dl_find_object result; + if(_dl_find_object(reinterpret_cast(address), &result) == 0) { // thread safe + frame.object_path = resolve_l_name(result.dlfo_link_map->l_name); + frame.object_address = address - to_frame_ptr(result.dlfo_link_map->l_addr); + } + return frame; + } + #elif defined(CPPTRACE_HAS_DLADDR1) + object_frame get_frame_object_info(frame_ptr address) { + // https://github.com/bminor/glibc/blob/91695ee4598b39d181ab8df579b888a8863c4cab/elf/dl-addr.c#L26 + Dl_info info; + link_map* link_map_info; + object_frame frame; + frame.raw_address = address; + frame.object_address = 0; + if( + // thread safe + dladdr1(reinterpret_cast(address), &info, reinterpret_cast(&link_map_info), RTLD_DL_LINKMAP) + ) { + frame.object_path = resolve_l_name(link_map_info->l_name); + auto base = get_module_image_base(frame.object_path); + if(base.has_value()) { + frame.object_address = address + - reinterpret_cast(info.dli_fbase) + + base.unwrap_value(); + } else { + base.drop_error(); + } + } + return frame; + } + #else + // glibc dladdr may not return an accurate dli_fname as it uses argv[0] for addresses in the main executable + // https://github.com/bminor/glibc/blob/caed1f5c0b2e31b5f4e0f21fea4b2c9ecd3b5b30/elf/dl-addr.c#L33-L36 + // macos doesn't have dladdr1 but its dli_fname behaves more sensibly, same with some other libc's like musl + object_frame get_frame_object_info(frame_ptr address) { + // reference: https://github.com/bminor/glibc/blob/master/debug/backtracesyms.c + Dl_info info; + object_frame frame; + frame.raw_address = address; + frame.object_address = 0; + if(dladdr(reinterpret_cast(address), &info)) { // thread safe + frame.object_path = info.dli_fname; + auto base = get_module_image_base(info.dli_fname); + if(base.has_value()) { + frame.object_address = address + - reinterpret_cast(info.dli_fbase) + + base.unwrap_value(); + } else { + base.drop_error(); + } + } + return frame; + } + #endif + #else + std::string get_module_name(HMODULE handle) { + static std::mutex mutex; + std::lock_guard lock(mutex); + static std::unordered_map cache; + auto it = cache.find(handle); + if(it == cache.end()) { + char path[MAX_PATH]; + if(GetModuleFileNameA(handle, path, sizeof(path))) { + cache.insert(it, {handle, path}); + return path; + } else { + std::fprintf(stderr, "%s\n", std::system_error(GetLastError(), std::system_category()).what()); + cache.insert(it, {handle, ""}); + return ""; + } + } else { + return it->second; + } + } + + object_frame get_frame_object_info(frame_ptr address) { + object_frame frame; + frame.raw_address = address; + frame.object_address = 0; + HMODULE handle; + // Multithread safe as long as another thread doesn't come along and free the module + if(GetModuleHandleExA( + GET_MODULE_HANDLE_EX_FLAG_UNCHANGED_REFCOUNT | GET_MODULE_HANDLE_EX_FLAG_FROM_ADDRESS, + reinterpret_cast(address), + &handle + )) { + frame.object_path = get_module_name(handle); + auto base = get_module_image_base(frame.object_path); + if(base.has_value()) { + frame.object_address = address + - reinterpret_cast(handle) + + base.unwrap_value(); + } else { + base.drop_error(); + } + } else { + std::fprintf(stderr, "%s\n", std::system_error(GetLastError(), std::system_category()).what()); + } + return frame; + } + #endif + + std::vector get_frames_object_info(const std::vector& addresses) { + std::vector frames; + frames.reserve(addresses.size()); + for(const frame_ptr address : addresses) { + frames.push_back(get_frame_object_info(address)); + } + return frames; + } + + object_frame resolve_safe_object_frame(const safe_object_frame& frame) { + std::string object_path = frame.object_path; + if(object_path.empty()) { + return { + frame.raw_address, + 0, + "" + }; + } + return { + frame.raw_address, + frame.address_relative_to_object_start, + std::move(object_path) + }; + } +} +} diff --git a/dep/cpptrace/src/binary/object.hpp b/dep/cpptrace/src/binary/object.hpp new file mode 100644 index 00000000000..9200642e1b1 --- /dev/null +++ b/dep/cpptrace/src/binary/object.hpp @@ -0,0 +1,21 @@ +#ifndef OBJECT_HPP +#define OBJECT_HPP + +#include "utils/common.hpp" +#include "utils/utils.hpp" +#include "binary/module_base.hpp" + +#include +#include + +namespace cpptrace { +namespace detail { + object_frame get_frame_object_info(frame_ptr address); + + std::vector get_frames_object_info(const std::vector& addresses); + + object_frame resolve_safe_object_frame(const safe_object_frame& frame); +} +} + +#endif diff --git a/dep/cpptrace/src/binary/pe.cpp b/dep/cpptrace/src/binary/pe.cpp new file mode 100644 index 00000000000..04d9961c29d --- /dev/null +++ b/dep/cpptrace/src/binary/pe.cpp @@ -0,0 +1,95 @@ +#include "binary/pe.hpp" + +#include "utils/common.hpp" +#include "utils/error.hpp" +#include "utils/utils.hpp" + +#if IS_WINDOWS +#include +#include +#include +#include +#include + +#include + +namespace cpptrace { +namespace detail { + template::value, int>::type = 0> + T pe_byteswap_if_needed(T value) { + // PE header values are little endian, I think dos e_lfanew should be too + if(!is_little_endian()) { + return byteswap(value); + } else { + return value; + } + } + + Result pe_get_module_image_base(const std::string& object_path) { + // https://drive.google.com/file/d/0B3_wGJkuWLytbnIxY1J5WUs4MEk/view?pli=1&resourcekey=0-n5zZ2UW39xVTH8ZSu6C2aQ + // https://0xrick.github.io/win-internals/pe3/ + // Endianness should always be little for dos and pe headers + std::FILE* file_ptr; + errno_t ret = fopen_s(&file_ptr, object_path.c_str(), "rb"); + auto file = raii_wrap(std::move(file_ptr), file_deleter); + if(ret != 0 || file == nullptr) { + return internal_error("Unable to read object file {}", object_path); + } + auto magic = load_bytes>(file, 0); + if(!magic) { + return std::move(magic).unwrap_error(); + } + if(std::memcmp(magic.unwrap_value().data(), "MZ", 2) != 0) { + return internal_error("File is not a PE file {}", object_path); + } + auto e_lfanew = load_bytes(file, 0x3c); // dos header + 0x3c + if(!e_lfanew) { + return std::move(e_lfanew).unwrap_error(); + } + DWORD nt_header_offset = pe_byteswap_if_needed(e_lfanew.unwrap_value()); + auto signature = load_bytes>(file, nt_header_offset); // nt header + 0 + if(!signature) { + return std::move(signature).unwrap_error(); + } + if(std::memcmp(signature.unwrap_value().data(), "PE\0\0", 4) != 0) { + return internal_error("File is not a PE file {}", object_path); + } + auto size_of_optional_header_raw = load_bytes(file, nt_header_offset + 4 + 0x10); // file header + 0x10 + if(!size_of_optional_header_raw) { + return std::move(size_of_optional_header_raw).unwrap_error(); + } + WORD size_of_optional_header = pe_byteswap_if_needed(size_of_optional_header_raw.unwrap_value()); + if(size_of_optional_header == 0) { + return internal_error("Unexpected optional header size for PE file"); + } + auto optional_header_magic_raw = load_bytes(file, nt_header_offset + 0x18); // optional header + 0x0 + if(!optional_header_magic_raw) { + return std::move(optional_header_magic_raw).unwrap_error(); + } + WORD optional_header_magic = pe_byteswap_if_needed(optional_header_magic_raw.unwrap_value()); + VERIFY( + optional_header_magic == IMAGE_NT_OPTIONAL_HDR_MAGIC, + ("PE file does not match expected bit-mode " + object_path).c_str() + ); + // finally get image base + if(optional_header_magic == IMAGE_NT_OPTIONAL_HDR32_MAGIC) { + // 32 bit + auto bytes = load_bytes(file, nt_header_offset + 0x18 + 0x1c); // optional header + 0x1c + if(!bytes) { + return std::move(bytes).unwrap_error(); + } + return to(pe_byteswap_if_needed(bytes.unwrap_value())); + } else { + // 64 bit + // I get an "error: 'QWORD' was not declared in this scope" for some reason when using QWORD + auto bytes = load_bytes(file, nt_header_offset + 0x18 + 0x18); // optional header + 0x18 + if(!bytes) { + return std::move(bytes).unwrap_error(); + } + return to(pe_byteswap_if_needed(bytes.unwrap_value())); + } + } +} +} + +#endif diff --git a/dep/cpptrace/src/binary/pe.hpp b/dep/cpptrace/src/binary/pe.hpp new file mode 100644 index 00000000000..c81d335c581 --- /dev/null +++ b/dep/cpptrace/src/binary/pe.hpp @@ -0,0 +1,19 @@ +#ifndef PE_HPP +#define PE_HPP + +#include "utils/common.hpp" +#include "utils/utils.hpp" + +#if IS_WINDOWS +#include +#include + +namespace cpptrace { +namespace detail { + Result pe_get_module_image_base(const std::string& object_path); +} +} + +#endif + +#endif diff --git a/dep/cpptrace/src/binary/safe_dl.cpp b/dep/cpptrace/src/binary/safe_dl.cpp new file mode 100644 index 00000000000..4ed6db7b3b4 --- /dev/null +++ b/dep/cpptrace/src/binary/safe_dl.cpp @@ -0,0 +1,68 @@ +#include "binary/safe_dl.hpp" + +#include "utils/common.hpp" +#include "utils/utils.hpp" +#include "platform/program_name.hpp" + +#include +#include +#include +#include +#include +#include + +#ifdef CPPTRACE_HAS_DL_FIND_OBJECT +#if IS_LINUX || IS_APPLE + #include + #include + #include +#endif + +namespace cpptrace { +namespace detail { + void get_safe_object_frame(frame_ptr address, safe_object_frame* out) { + out->raw_address = address; + dl_find_object result; + if(_dl_find_object(reinterpret_cast(address), &result) == 0) { // thread-safe, signal-safe + out->address_relative_to_object_start = address - to_frame_ptr(result.dlfo_link_map->l_addr); + if(result.dlfo_link_map->l_name != nullptr && result.dlfo_link_map->l_name[0] != 0) { + std::size_t path_length = std::strlen(result.dlfo_link_map->l_name); + std::memcpy( + out->object_path, + result.dlfo_link_map->l_name, + std::min(path_length + 1, std::size_t(CPPTRACE_PATH_MAX + 1)) + ); + } else { + // empty l_name, this means it's the currently running executable + memset(out->object_path, 0, CPPTRACE_PATH_MAX + 1); + // signal-safe + auto res = readlink("/proc/self/exe", out->object_path, CPPTRACE_PATH_MAX); + if(res == -1) { + // error handling? + } + // TODO: Special handling for /proc/pid/exe unlink edge case + } + } else { + // std::cout<<"error"<address_relative_to_object_start = 0; + out->object_path[0] = 0; + } + // TODO: Handle this part of the documentation? + // The address can be a code address or data address. On architectures using function descriptors, no attempt is + // made to decode the function descriptor. Depending on how these descriptors are implemented, _dl_find_object + // may return the object that defines the function descriptor (and not the object that contains the code + // implementing the function), or fail to find any object at all. + } +} +} +#else +namespace cpptrace { +namespace detail { + void get_safe_object_frame(frame_ptr address, safe_object_frame* out) { + out->raw_address = address; + out->address_relative_to_object_start = 0; + out->object_path[0] = 0; + } +} +} +#endif diff --git a/dep/cpptrace/src/binary/safe_dl.hpp b/dep/cpptrace/src/binary/safe_dl.hpp new file mode 100644 index 00000000000..714bfff25ae --- /dev/null +++ b/dep/cpptrace/src/binary/safe_dl.hpp @@ -0,0 +1,12 @@ +#ifndef SAFE_DL_HPP +#define SAFE_DL_HPP + +#include "utils/common.hpp" + +namespace cpptrace { +namespace detail { + void get_safe_object_frame(frame_ptr address, safe_object_frame* out); +} +} + +#endif diff --git a/dep/cpptrace/src/cpptrace.cpp b/dep/cpptrace/src/cpptrace.cpp new file mode 100644 index 00000000000..cbb81b0172b --- /dev/null +++ b/dep/cpptrace/src/cpptrace.cpp @@ -0,0 +1,698 @@ +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "symbols/symbols.hpp" +#include "unwind/unwind.hpp" +#include "demangle/demangle.hpp" +#include "platform/exception_type.hpp" +#include "utils/common.hpp" +#include "utils/utils.hpp" +#include "binary/object.hpp" +#include "binary/safe_dl.hpp" +#include "snippets/snippet.hpp" + +namespace cpptrace { + CPPTRACE_FORCE_NO_INLINE + raw_trace raw_trace::current(std::size_t skip) { + try { // try/catch can never be hit but it's needed to prevent TCO + return generate_raw_trace(skip + 1); + } catch(...) { + if(!detail::should_absorb_trace_exceptions()) { + throw; + } + return raw_trace{}; + } + } + + CPPTRACE_FORCE_NO_INLINE + raw_trace raw_trace::current(std::size_t skip, std::size_t max_depth) { + try { // try/catch can never be hit but it's needed to prevent TCO + return generate_raw_trace(skip + 1, max_depth); + } catch(...) { + if(!detail::should_absorb_trace_exceptions()) { + throw; + } + return raw_trace{}; + } + } + + object_trace raw_trace::resolve_object_trace() const { + try { + return object_trace{detail::get_frames_object_info(frames)}; + } catch(...) { // NOSONAR + if(!detail::should_absorb_trace_exceptions()) { + throw; + } + return object_trace{}; + } + } + + stacktrace raw_trace::resolve() const { + try { + std::vector trace = detail::resolve_frames(frames); + for(auto& frame : trace) { + frame.symbol = detail::demangle(frame.symbol); + } + return {std::move(trace)}; + } catch(...) { // NOSONAR + if(!detail::should_absorb_trace_exceptions()) { + throw; + } + return stacktrace{}; + } + } + + void raw_trace::clear() { + frames.clear(); + } + + bool raw_trace::empty() const noexcept { + return frames.empty(); + } + + CPPTRACE_FORCE_NO_INLINE + object_trace object_trace::current(std::size_t skip) { + try { // try/catch can never be hit but it's needed to prevent TCO + return generate_object_trace(skip + 1); + } catch(...) { + if(!detail::should_absorb_trace_exceptions()) { + throw; + } + return object_trace{}; + } + } + + CPPTRACE_FORCE_NO_INLINE + object_trace object_trace::current(std::size_t skip, std::size_t max_depth) { + try { // try/catch can never be hit but it's needed to prevent TCO + return generate_object_trace(skip + 1, max_depth); + } catch(...) { + if(!detail::should_absorb_trace_exceptions()) { + throw; + } + return object_trace{}; + } + } + + stacktrace object_trace::resolve() const { + try { + std::vector trace = detail::resolve_frames(frames); + for(auto& frame : trace) { + frame.symbol = detail::demangle(frame.symbol); + } + return {std::move(trace)}; + } catch(...) { // NOSONAR + if(!detail::should_absorb_trace_exceptions()) { + throw; + } + return stacktrace(); + } + } + + void object_trace::clear() { + frames.clear(); + } + + bool object_trace::empty() const noexcept { + return frames.empty(); + } + + object_frame stacktrace_frame::get_object_info() const { + return detail::get_frame_object_info(raw_address); + } + + std::string stacktrace_frame::to_string() const { + std::string str; + if(is_inline) { + str += microfmt::format("{<{}}", 2 * sizeof(frame_ptr) + 2, "(inlined)"); + } else { + str += microfmt::format("0x{>{}:0h}", 2 * sizeof(frame_ptr), raw_address); + } + if(!symbol.empty()) { + str += microfmt::format(" in {}", symbol); + } + if(!filename.empty()) { + str += microfmt::format(" at {}", filename); + if(line.has_value()) { + str += microfmt::format(":{}", line.value()); + if(column.has_value()) { + str += microfmt::format(":{}", column.value()); + } + } + } + return str; + } + + std::ostream& operator<<(std::ostream& stream, const stacktrace_frame& frame) { + return stream << frame.to_string(); + } + + CPPTRACE_FORCE_NO_INLINE + stacktrace stacktrace::current(std::size_t skip) { + try { // try/catch can never be hit but it's needed to prevent TCO + return generate_trace(skip + 1); + } catch(...) { + if(!detail::should_absorb_trace_exceptions()) { + throw; + } + return stacktrace{}; + } + } + + CPPTRACE_FORCE_NO_INLINE + stacktrace stacktrace::current(std::size_t skip, std::size_t max_depth) { + try { // try/catch can never be hit but it's needed to prevent TCO + return generate_trace(skip + 1, max_depth); + } catch(...) { + if(!detail::should_absorb_trace_exceptions()) { + throw; + } + return stacktrace{}; + } + } + + void stacktrace::print() const { + print(std::cerr, true); + } + + void stacktrace::print(std::ostream& stream) const { + print(stream, true); + } + + void stacktrace::print(std::ostream& stream, bool color) const { + print(stream, color, true, nullptr); + } + + void print_frame( + std::ostream& stream, + bool color, + unsigned frame_number_width, + std::size_t counter, + const stacktrace_frame& frame + ) { + const auto reset = color ? RESET : ""; + const auto green = color ? GREEN : ""; + const auto yellow = color ? YELLOW : ""; + const auto blue = color ? BLUE : ""; + std::string line = microfmt::format("#{<{}} ", frame_number_width, counter); + if(frame.is_inline) { + line += microfmt::format("{<{}}", 2 * sizeof(frame_ptr) + 2, "(inlined)"); + } else { + line += microfmt::format("{}0x{>{}:0h}{}", blue, 2 * sizeof(frame_ptr), frame.raw_address, reset); + } + if(!frame.symbol.empty()) { + line += microfmt::format(" in {}{}{}", yellow, frame.symbol, reset); + } + if(!frame.filename.empty()) { + line += microfmt::format(" at {}{}{}", green, frame.filename, reset); + if(frame.line.has_value()) { + line += microfmt::format(":{}{}{}", blue, frame.line.value(), reset); + if(frame.column.has_value()) { + line += microfmt::format(":{}{}{}", blue, frame.column.value(), reset); + } + } + } + stream << line; + } + + void stacktrace::print(std::ostream& stream, bool color, bool newline_at_end, const char* header) const { + if( + color && ( + (&stream == &std::cout && isatty(stdout_fileno)) || (&stream == &std::cerr && isatty(stderr_fileno)) + ) + ) { + detail::enable_virtual_terminal_processing_if_needed(); + } + stream << (header ? header : "Stack trace (most recent call first):") << '\n'; + std::size_t counter = 0; + if(frames.empty()) { + stream << "\n"; + return; + } + const auto frame_number_width = detail::n_digits(static_cast(frames.size()) - 1); + for(const auto& frame : frames) { + print_frame(stream, color, frame_number_width, counter, frame); + if(newline_at_end || &frame != &frames.back()) { + stream << '\n'; + } + counter++; + } + } + + void stacktrace::print_with_snippets() const { + print_with_snippets(std::cerr, true); + } + + void stacktrace::print_with_snippets(std::ostream& stream) const { + print_with_snippets(stream, true); + } + + void stacktrace::print_with_snippets(std::ostream& stream, bool color) const { + print_with_snippets(stream, color, true, nullptr); + } + + void stacktrace::print_with_snippets(std::ostream& stream, bool color, bool newline_at_end, const char* header) const { + if( + color && ( + (&stream == &std::cout && isatty(stdout_fileno)) || (&stream == &std::cerr && isatty(stderr_fileno)) + ) + ) { + detail::enable_virtual_terminal_processing_if_needed(); + } + stream << (header ? header : "Stack trace (most recent call first):") << '\n'; + std::size_t counter = 0; + if(frames.empty()) { + stream << "" << '\n'; + return; + } + const auto frame_number_width = detail::n_digits(static_cast(frames.size()) - 1); + for(const auto& frame : frames) { + print_frame(stream, color, frame_number_width, counter, frame); + if(newline_at_end || &frame != &frames.back()) { + stream << '\n'; + } + if(frame.line.has_value() && !frame.filename.empty()) { + stream << detail::get_snippet(frame.filename, frame.line.value(), 2, color); + } + counter++; + } + } + + void stacktrace::clear() { + frames.clear(); + } + + bool stacktrace::empty() const noexcept { + return frames.empty(); + } + + std::string stacktrace::to_string(bool color) const { + std::ostringstream oss; + print(oss, color, false, nullptr); + return std::move(oss).str(); + } + + std::ostream& operator<<(std::ostream& stream, const stacktrace& trace) { + return stream << trace.to_string(); + } + + CPPTRACE_FORCE_NO_INLINE + raw_trace generate_raw_trace(std::size_t skip) { + try { + return raw_trace{detail::capture_frames(skip + 1, SIZE_MAX)}; + } catch(...) { // NOSONAR + if(!detail::should_absorb_trace_exceptions()) { + throw; + } + return raw_trace{}; + } + } + + CPPTRACE_FORCE_NO_INLINE + raw_trace generate_raw_trace(std::size_t skip, std::size_t max_depth) { + try { + return raw_trace{detail::capture_frames(skip + 1, max_depth)}; + } catch(...) { // NOSONAR + if(!detail::should_absorb_trace_exceptions()) { + throw; + } + return raw_trace{}; + } + } + + CPPTRACE_FORCE_NO_INLINE + std::size_t safe_generate_raw_trace(frame_ptr* buffer, std::size_t size, std::size_t skip) { + try { // try/catch can never be hit but it's needed to prevent TCO + return detail::safe_capture_frames(buffer, size, skip + 1, SIZE_MAX); + } catch(...) { + if(!detail::should_absorb_trace_exceptions()) { + throw; + } + return 0; + } + } + + CPPTRACE_FORCE_NO_INLINE + std::size_t safe_generate_raw_trace( + frame_ptr* buffer, + std::size_t size, + std::size_t skip, + std::size_t max_depth + ) { + try { // try/catch can never be hit but it's needed to prevent TCO + return detail::safe_capture_frames(buffer, size, skip + 1, max_depth); + } catch(...) { + if(!detail::should_absorb_trace_exceptions()) { + throw; + } + return 0; + } + } + + CPPTRACE_FORCE_NO_INLINE + object_trace generate_object_trace(std::size_t skip) { + try { + return object_trace{detail::get_frames_object_info(detail::capture_frames(skip + 1, SIZE_MAX))}; + } catch(...) { // NOSONAR + if(!detail::should_absorb_trace_exceptions()) { + throw; + } + return object_trace{}; + } + } + + CPPTRACE_FORCE_NO_INLINE + object_trace generate_object_trace(std::size_t skip, std::size_t max_depth) { + try { + return object_trace{detail::get_frames_object_info(detail::capture_frames(skip + 1, max_depth))}; + } catch(...) { // NOSONAR + if(!detail::should_absorb_trace_exceptions()) { + throw; + } + return object_trace{}; + } + } + + CPPTRACE_FORCE_NO_INLINE + stacktrace generate_trace(std::size_t skip) { + try { // try/catch can never be hit but it's needed to prevent TCO + return generate_trace(skip + 1, SIZE_MAX); + } catch(...) { + if(!detail::should_absorb_trace_exceptions()) { + throw; + } + return stacktrace{}; + } + } + + CPPTRACE_FORCE_NO_INLINE + stacktrace generate_trace(std::size_t skip, std::size_t max_depth) { + try { + std::vector frames = detail::capture_frames(skip + 1, max_depth); + std::vector trace = detail::resolve_frames(frames); + for(auto& frame : trace) { + frame.symbol = detail::demangle(frame.symbol); + } + return {std::move(trace)}; + } catch(...) { // NOSONAR + if(!detail::should_absorb_trace_exceptions()) { + throw; + } + return stacktrace(); + } + } + + object_frame safe_object_frame::resolve() const { + return detail::resolve_safe_object_frame(*this); + } + + void get_safe_object_frame(frame_ptr address, safe_object_frame* out) { + detail::get_safe_object_frame(address, out); + } + + bool can_signal_safe_unwind() { + return detail::has_safe_unwind(); + } + + std::string demangle(const std::string& name) { + return detail::demangle(name); + } + + std::string get_snippet(const std::string& path, std::size_t line, std::size_t context_size, bool color) { + return detail::get_snippet(path, line, context_size, color); + } + + bool isatty(int fd) { + return detail::isatty(fd); + } + + extern const int stdin_fileno = detail::fileno(stdin); + extern const int stdout_fileno = detail::fileno(stdout); + extern const int stderr_fileno = detail::fileno(stderr); + + CPPTRACE_FORCE_NO_INLINE void print_terminate_trace() { + try { // try/catch can never be hit but it's needed to prevent TCO + generate_trace(1).print( + std::cerr, + isatty(stderr_fileno), + true, + "Stack trace to reach terminate handler (most recent call first):" + ); + } catch(...) { + if(!detail::should_absorb_trace_exceptions()) { + throw; + } + } + } + + [[noreturn]] void terminate_handler() { + // TODO: Support std::nested_exception? + try { + auto ptr = std::current_exception(); + if(ptr == nullptr) { + fputs("terminate called without an active exception", stderr); + print_terminate_trace(); + } else { + std::rethrow_exception(ptr); + } + } catch(cpptrace::exception& e) { + microfmt::print( + stderr, + "Terminate called after throwing an instance of {}: {}\n", + demangle(typeid(e).name()), + e.message() + ); + e.trace().print(std::cerr, isatty(stderr_fileno)); + } catch(std::exception& e) { + microfmt::print( + stderr, "Terminate called after throwing an instance of {}: {}\n", demangle(typeid(e).name()), e.what() + ); + print_terminate_trace(); + } catch(...) { + microfmt::print( + stderr, "Terminate called after throwing an instance of {}\n", detail::exception_type_name() + ); + print_terminate_trace(); + } + std::flush(std::cerr); + abort(); + } + + void register_terminate_handler() { + std::set_terminate(terminate_handler); + } + + namespace detail { + std::atomic_bool absorb_trace_exceptions(true); // NOSONAR + std::atomic_bool resolve_inlined_calls(true); // NOSONAR + std::atomic current_cache_mode(cache_mode::prioritize_speed); // NOSONAR + } + + void absorb_trace_exceptions(bool absorb) { + detail::absorb_trace_exceptions = absorb; + } + + void enable_inlined_call_resolution(bool enable) { + detail::resolve_inlined_calls = enable; + } + + namespace experimental { + void set_cache_mode(cache_mode mode) { + detail::current_cache_mode = mode; + } + } + + namespace detail { + bool should_absorb_trace_exceptions() { + return absorb_trace_exceptions; + } + + bool should_resolve_inlined_calls() { + return resolve_inlined_calls; + } + + cache_mode get_cache_mode() { + return current_cache_mode; + } + + CPPTRACE_FORCE_NO_INLINE + raw_trace get_raw_trace_and_absorb(std::size_t skip, std::size_t max_depth) { + try { + return generate_raw_trace(skip + 1, max_depth); + } catch(const std::exception& e) { + if(!detail::should_absorb_trace_exceptions()) { + // TODO: Append to message somehow + std::fprintf( + stderr, + "Cpptrace: Exception occurred while resolving trace in cpptrace::exception object:\n%s\n", + e.what() + ); + } + return raw_trace{}; + } + } + + CPPTRACE_FORCE_NO_INLINE + raw_trace get_raw_trace_and_absorb(std::size_t skip) { + try { // try/catch can never be hit but it's needed to prevent TCO + return get_raw_trace_and_absorb(skip + 1, SIZE_MAX); + } catch(...) { + if(!detail::should_absorb_trace_exceptions()) { + throw; + } + return raw_trace{}; + } + } + + lazy_trace_holder::lazy_trace_holder(const lazy_trace_holder& other) : resolved(other.resolved) { + if(other.resolved) { + new (&resolved_trace) stacktrace(other.resolved_trace); + } else { + new (&trace) raw_trace(other.trace); + } + } + lazy_trace_holder::lazy_trace_holder(lazy_trace_holder&& other) noexcept : resolved(other.resolved) { + if(other.resolved) { + new (&resolved_trace) stacktrace(std::move(other.resolved_trace)); + } else { + new (&trace) raw_trace(std::move(other.trace)); + } + } + lazy_trace_holder& lazy_trace_holder::operator=(const lazy_trace_holder& other) { + clear(); + resolved = other.resolved; + if(other.resolved) { + new (&resolved_trace) stacktrace(other.resolved_trace); + } else { + new (&trace) raw_trace(other.trace); + } + return *this; + } + lazy_trace_holder& lazy_trace_holder::operator=(lazy_trace_holder&& other) noexcept { + clear(); + resolved = other.resolved; + if(other.resolved) { + new (&resolved_trace) stacktrace(std::move(other.resolved_trace)); + } else { + new (&trace) raw_trace(std::move(other.trace)); + } + return *this; + } + lazy_trace_holder::~lazy_trace_holder() { + clear(); + } + // access + const raw_trace& lazy_trace_holder::get_raw_trace() const { + if(resolved) { + throw std::logic_error( + "cpptrace::detail::lazy_trace_holder::get_resolved_trace called on resolved holder" + ); + } + return trace; + } + stacktrace& lazy_trace_holder::get_resolved_trace() { + if(!resolved) { + raw_trace old_trace = std::move(trace); + *this = lazy_trace_holder(stacktrace{}); + try { + if(!old_trace.empty()) { + resolved_trace = old_trace.resolve(); + } + } catch(const std::exception& e) { + if(!detail::should_absorb_trace_exceptions()) { + // TODO: Append to message somehow? + std::fprintf( + stderr, + "Exception occurred while resolving trace in cpptrace::detail::lazy_trace_holder:\n%s\n", + e.what() + ); + } + } + } + return resolved_trace; + } + const stacktrace& lazy_trace_holder::get_resolved_trace() const { + if(!resolved) { + throw std::logic_error( + "cpptrace::detail::lazy_trace_holder::get_resolved_trace called on unresolved const holder" + ); + } + return resolved_trace; + } + void lazy_trace_holder::clear() { + if(resolved) { + resolved_trace.~stacktrace(); + } else { + trace.~raw_trace(); + } + } + } + + const char* lazy_exception::what() const noexcept { + if(what_string.empty()) { + what_string = message() + std::string(":\n") + trace_holder.get_resolved_trace().to_string(); + } + return what_string.c_str(); + } + + const char* lazy_exception::message() const noexcept { + return "cpptrace::lazy_exception"; + } + + const stacktrace& lazy_exception::trace() const noexcept { + return trace_holder.get_resolved_trace(); + } + + const char* exception_with_message::message() const noexcept { + return user_message.c_str(); + } + + system_error::system_error(int error_code, std::string&& message_arg, raw_trace&& trace) noexcept + : runtime_error( + message_arg + ": " + std::error_code(error_code, std::generic_category()).message(), + std::move(trace) + ), + ec(std::error_code(error_code, std::generic_category())) {} + + const std::error_code& system_error::code() const noexcept { + return ec; + } + + const char* nested_exception::message() const noexcept { + if(message_value.empty()) { + try { + std::rethrow_exception(ptr); + } catch(std::exception& e) { + message_value = std::string("Nested exception: ") + e.what(); + } catch(...) { + message_value = "Nested exception holding instance of " + detail::exception_type_name(); + } + } + return message_value.c_str(); + } + + std::exception_ptr nested_exception::nested_ptr() const noexcept { + return ptr; + } + + CPPTRACE_FORCE_NO_INLINE + void rethrow_and_wrap_if_needed(std::size_t skip) { + try { + std::rethrow_exception(std::current_exception()); + } catch(cpptrace::exception&) { + throw; // already a cpptrace::exception + } catch(...) { + throw nested_exception(std::current_exception(), detail::get_raw_trace_and_absorb(skip + 1)); + } + } +} diff --git a/dep/cpptrace/src/ctrace.cpp b/dep/cpptrace/src/ctrace.cpp new file mode 100644 index 00000000000..eacf15e2f5e --- /dev/null +++ b/dep/cpptrace/src/ctrace.cpp @@ -0,0 +1,442 @@ +#include +#include +#include + +#include "symbols/symbols.hpp" +#include "unwind/unwind.hpp" +#include "demangle/demangle.hpp" +#include "platform/exception_type.hpp" +#include "utils/common.hpp" +#include "utils/utils.hpp" +#include "binary/object.hpp" +#include "binary/safe_dl.hpp" + +#define ESC "\033[" +#define RESET ESC "0m" +#define RED ESC "31m" +#define GREEN ESC "32m" +#define YELLOW ESC "33m" +#define BLUE ESC "34m" +#define MAGENTA ESC "35m" +#define CYAN ESC "36m" + +#if defined(__GNUC__) && ((__GNUC__ > 2) || (__GNUC__ == 2 && __GNUC_MINOR__ >= 6)) +# define CTRACE_GNU_FORMAT(...) __attribute__((format(__VA_ARGS__))) +#elif defined(__clang__) +// Probably requires llvm >3.5? Not exactly sure. +# define CTRACE_GNU_FORMAT(...) __attribute__((format(__VA_ARGS__))) +#else +# define CTRACE_GNU_FORMAT(...) +#endif + +#if defined(__clang__) +# define CTRACE_FORMAT_PROLOGUE \ + _Pragma("clang diagnostic push") \ + _Pragma("clang diagnostic ignored \"-Wformat-security\"") +# define CTRACE_FORMAT_EPILOGUE \ + _Pragma("clang diagnostic pop") +#elif defined(__GNUC_MINOR__) +# define CTRACE_FORMAT_PROLOGUE \ + _Pragma("GCC diagnostic push") \ + _Pragma("GCC diagnostic ignored \"-Wformat-security\"") +# define CTRACE_FORMAT_EPILOGUE \ + _Pragma("GCC diagnostic pop") +#else +# define CTRACE_FORMAT_PROLOGUE +# define CTRACE_FORMAT_EPILOGUE +#endif + +namespace ctrace { + static constexpr std::uint32_t invalid_pos = ~0U; + +CTRACE_FORMAT_PROLOGUE + template + CTRACE_GNU_FORMAT(printf, 2, 0) + static void ffprintf(std::FILE* f, const char fmt[], Args&&...args) { + (void)std::fprintf(f, fmt, args...); + (void)fflush(f); + } +CTRACE_FORMAT_EPILOGUE + + static bool is_empty(std::uint32_t pos) noexcept { + return pos == invalid_pos; + } + + static bool is_empty(const char* str) noexcept { + return !str || std::char_traits::length(str) == 0; + } + + static ctrace_owning_string generate_owning_string(const char* raw_string) noexcept { + // Returns length to the null terminator. + std::size_t count = std::char_traits::length(raw_string); + char* new_string = new char[count + 1]; + std::char_traits::copy(new_string, raw_string, count); + new_string[count] = '\0'; + return { new_string }; + } + + static ctrace_owning_string generate_owning_string(const std::string& std_string) { + return generate_owning_string(std_string.c_str()); + } + + static void free_owning_string(const char* owned_string) noexcept { + if(!owned_string) return; // Not necessary but eh + delete[] owned_string; + } + + static void free_owning_string(ctrace_owning_string& owned_string) noexcept { + free_owning_string(owned_string.data); + } + + static ctrace_object_frame convert_object_frame(const cpptrace::object_frame& frame) { + const char* new_path = generate_owning_string(frame.object_path).data; + return { frame.raw_address, frame.object_address, new_path }; + } + + static ctrace_object_trace c_convert(const std::vector& trace) { + std::size_t count = trace.size(); + auto* frames = new ctrace_object_frame[count]; + std::transform(trace.begin(), trace.end(), frames, convert_object_frame); + return { frames, count }; + } + + static ctrace_stacktrace_frame convert_stacktrace_frame(const cpptrace::stacktrace_frame& frame) { + ctrace_stacktrace_frame new_frame; + new_frame.raw_address = frame.raw_address; + new_frame.object_address = frame.object_address; + new_frame.line = frame.line.value_or(invalid_pos); + new_frame.column = frame.column.value_or(invalid_pos); + new_frame.filename = generate_owning_string(frame.filename).data; + new_frame.symbol = generate_owning_string(cpptrace::detail::demangle(frame.symbol)).data; + new_frame.is_inline = ctrace_bool(frame.is_inline); + return new_frame; + } + + static cpptrace::stacktrace_frame convert_stacktrace_frame(const ctrace_stacktrace_frame& frame) { + using nullable_type = cpptrace::nullable; + static constexpr auto null_v = nullable_type::null().raw_value; + cpptrace::stacktrace_frame new_frame; + new_frame.raw_address = frame.raw_address; + new_frame.object_address = frame.object_address; + new_frame.line = nullable_type{is_empty(frame.line) ? null_v : frame.line}; + new_frame.column = nullable_type{is_empty(frame.column) ? null_v : frame.column}; + new_frame.filename = frame.filename; + new_frame.symbol = frame.symbol; + new_frame.is_inline = bool(frame.is_inline); + return new_frame; + } + + static ctrace_stacktrace c_convert(const std::vector& trace) { + std::size_t count = trace.size(); + auto* frames = new ctrace_stacktrace_frame[count]; + std::transform( + trace.begin(), + trace.end(), frames, + static_cast(convert_stacktrace_frame) + ); + return { frames, count }; + } + + static cpptrace::stacktrace cpp_convert(const ctrace_stacktrace* ptrace) { + if(!ptrace || !ptrace->frames) { + return { }; + } + std::vector new_frames; + new_frames.reserve(ptrace->count); + for(std::size_t i = 0; i < ptrace->count; ++i) { + new_frames.push_back(convert_stacktrace_frame(ptrace->frames[i])); + } + return cpptrace::stacktrace{std::move(new_frames)}; + } +} + +extern "C" { + // ctrace::string + ctrace_owning_string ctrace_generate_owning_string(const char* raw_string) { + return ctrace::generate_owning_string(raw_string); + } + + void ctrace_free_owning_string(ctrace_owning_string* string) { + if(!string) { + return; + } + ctrace::free_owning_string(*string); + string->data = nullptr; + } + + // ctrace::generation: + CTRACE_FORCE_NO_INLINE + ctrace_raw_trace ctrace_generate_raw_trace(size_t skip, size_t max_depth) { + try { + std::vector trace = cpptrace::detail::capture_frames(skip + 1, max_depth); + std::size_t count = trace.size(); + auto* frames = new ctrace_frame_ptr[count]; + std::copy(trace.data(), trace.data() + count, frames); + return { frames, count }; + } catch(...) { + // Don't check rethrow condition, it's risky. + return { nullptr, 0 }; + } + } + + CTRACE_FORCE_NO_INLINE + ctrace_object_trace ctrace_generate_object_trace(size_t skip, size_t max_depth) { + try { + std::vector trace = cpptrace::detail::get_frames_object_info( + cpptrace::detail::capture_frames(skip + 1, max_depth) + ); + return ctrace::c_convert(trace); + } catch(...) { // NOSONAR + // Don't check rethrow condition, it's risky. + return { nullptr, 0 }; + } + } + + CTRACE_FORCE_NO_INLINE + ctrace_stacktrace ctrace_generate_trace(size_t skip, size_t max_depth) { + try { + std::vector frames = cpptrace::detail::capture_frames(skip + 1, max_depth); + std::vector trace = cpptrace::detail::resolve_frames(frames); + return ctrace::c_convert(trace); + } catch(...) { // NOSONAR + // Don't check rethrow condition, it's risky. + return { nullptr, 0 }; + } + } + + + // ctrace::freeing: + void ctrace_free_raw_trace(ctrace_raw_trace* trace) { + if(!trace) { + return; + } + ctrace_frame_ptr* frames = trace->frames; + delete[] frames; + trace->frames = nullptr; + trace->count = 0; + } + + void ctrace_free_object_trace(ctrace_object_trace* trace) { + if(!trace || !trace->frames) { + return; + } + ctrace_object_frame* frames = trace->frames; + for(std::size_t i = 0; i < trace->count; ++i) { + const char* path = frames[i].obj_path; + ctrace::free_owning_string(path); + } + + delete[] frames; + trace->frames = nullptr; + trace->count = 0; + } + + void ctrace_free_stacktrace(ctrace_stacktrace* trace) { + if(!trace || !trace->frames) { + return; + } + ctrace_stacktrace_frame* frames = trace->frames; + for(std::size_t i = 0; i < trace->count; ++i) { + ctrace::free_owning_string(frames[i].filename); + ctrace::free_owning_string(frames[i].symbol); + } + + delete[] frames; + trace->frames = nullptr; + trace->count = 0; + } + + // ctrace::resolve: + ctrace_stacktrace ctrace_resolve_raw_trace(const ctrace_raw_trace* trace) { + if(!trace || !trace->frames) { + return { nullptr, 0 }; + } + try { + std::vector frames(trace->count, 0); + std::copy(trace->frames, trace->frames + trace->count, frames.begin()); + std::vector resolved = cpptrace::detail::resolve_frames(frames); + return ctrace::c_convert(resolved); + } catch(...) { // NOSONAR + // Don't check rethrow condition, it's risky. + return { nullptr, 0 }; + } + } + + ctrace_object_trace ctrace_resolve_raw_trace_to_object_trace(const ctrace_raw_trace* trace) { + if(!trace || !trace->frames) { + return { nullptr, 0 }; + } + try { + std::vector frames(trace->count, 0); + std::copy(trace->frames, trace->frames + trace->count, frames.begin()); + std::vector obj = cpptrace::detail::get_frames_object_info(frames); + return ctrace::c_convert(obj); + } catch(...) { // NOSONAR + // Don't check rethrow condition, it's risky. + return { nullptr, 0 }; + } + } + + ctrace_stacktrace ctrace_resolve_object_trace(const ctrace_object_trace* trace) { + if(!trace || !trace->frames) { + return { nullptr, 0 }; + } + try { + std::vector frames(trace->count, 0); + std::transform( + trace->frames, + trace->frames + trace->count, + frames.begin(), + [] (const ctrace_object_frame& frame) -> cpptrace::frame_ptr { + return frame.raw_address; + } + ); + std::vector resolved = cpptrace::detail::resolve_frames(frames); + return ctrace::c_convert(resolved); + } catch(...) { // NOSONAR + // Don't check rethrow condition, it's risky. + return { nullptr, 0 }; + } + } + + // ctrace::safe: + size_t ctrace_safe_generate_raw_trace(ctrace_frame_ptr* buffer, size_t size, size_t skip, size_t max_depth) { + return cpptrace::safe_generate_raw_trace(buffer, size, skip, max_depth); + } + + void ctrace_get_safe_object_frame(ctrace_frame_ptr address, ctrace_safe_object_frame* out) { + // TODO: change this? + static_assert(sizeof(cpptrace::safe_object_frame) == sizeof(ctrace_safe_object_frame), ""); + cpptrace::get_safe_object_frame(address, reinterpret_cast(out)); + } + + ctrace_bool can_signal_safe_unwind() { + return cpptrace::can_signal_safe_unwind(); + } + + // ctrace::io: + ctrace_owning_string ctrace_stacktrace_to_string(const ctrace_stacktrace* trace, ctrace_bool use_color) { + if(!trace || !trace->frames) { + return ctrace::generate_owning_string(""); + } + auto cpp_trace = ctrace::cpp_convert(trace); + std::string trace_string = cpp_trace.to_string(bool(use_color)); + return ctrace::generate_owning_string(trace_string); + } + + void ctrace_print_stacktrace(const ctrace_stacktrace* trace, FILE* to, ctrace_bool use_color) { + if( + use_color && ( + (to == stdout && cpptrace::isatty(cpptrace::stdout_fileno)) || + (to == stderr && cpptrace::isatty(cpptrace::stderr_fileno)) + ) + ) { + cpptrace::detail::enable_virtual_terminal_processing_if_needed(); + } + ctrace::ffprintf(to, "Stack trace (most recent call first):\n"); + if(trace->count == 0 || !trace->frames) { + ctrace::ffprintf(to, "\n"); + return; + } + const auto reset = use_color ? ESC "0m" : ""; + const auto green = use_color ? ESC "32m" : ""; + const auto yellow = use_color ? ESC "33m" : ""; + const auto blue = use_color ? ESC "34m" : ""; + const auto frame_number_width = cpptrace::detail::n_digits(unsigned(trace->count - 1)); + ctrace_stacktrace_frame* frames = trace->frames; + for(std::size_t i = 0; i < trace->count; ++i) { + static constexpr auto ptr_len = 2 * sizeof(cpptrace::frame_ptr); + ctrace::ffprintf(to, "#%-*llu ", int(frame_number_width), i); + if(frames[i].is_inline) { + (void)std::fprintf(to, "%*s", + int(ptr_len + 2), + "(inlined)"); + } else { + (void)std::fprintf(to, "%s0x%0*llx%s", + blue, + int(ptr_len), + cpptrace::detail::to_ull(frames[i].raw_address), + reset); + } + if(!ctrace::is_empty(frames[i].symbol)) { + (void)std::fprintf(to, " in %s%s%s", + yellow, + frames[i].symbol, + reset); + } + if(!ctrace::is_empty(frames[i].filename)) { + (void)std::fprintf(to, " at %s%s%s", + green, + frames[i].filename, + reset); + if(ctrace::is_empty(frames[i].line)) { + ctrace::ffprintf(to, "\n"); + continue; + } + (void)std::fprintf(to, ":%s%llu%s", + blue, + cpptrace::detail::to_ull(frames[i].line), + reset); + if(ctrace::is_empty(frames[i].column)) { + ctrace::ffprintf(to, "\n"); + continue; + } + (void)std::fprintf(to, ":%s%llu%s", + blue, + cpptrace::detail::to_ull(frames[i].column), + reset); + } + // always print newline at end :M + ctrace::ffprintf(to, "\n"); + } + } + + // utility::demangle: + ctrace_owning_string ctrace_demangle(const char* mangled) { + if(!mangled) { + return ctrace::generate_owning_string(""); + } + std::string demangled = cpptrace::demangle(mangled); + return ctrace::generate_owning_string(demangled); + } + + // utility::io + int ctrace_stdin_fileno(void) { + return cpptrace::stdin_fileno; + } + + int ctrace_stderr_fileno(void) { + return cpptrace::stderr_fileno; + } + + int ctrace_stdout_fileno(void) { + return cpptrace::stdout_fileno; + } + + ctrace_bool ctrace_isatty(int fd) { + return cpptrace::isatty(fd); + } + + // utility::cache: + void ctrace_set_cache_mode(ctrace_cache_mode mode) { + static constexpr auto cache_max = cpptrace::cache_mode::prioritize_speed; + if(mode > unsigned(cache_max)) { + return; + } + auto cache_mode = static_cast(mode); + cpptrace::experimental::set_cache_mode(cache_mode); + } + + void ctrace_enable_inlined_call_resolution(ctrace_bool enable) { + cpptrace::enable_inlined_call_resolution(enable); + } + + ctrace_object_frame ctrace_get_object_info(const ctrace_stacktrace_frame* frame) { + try { + cpptrace::object_frame new_frame = cpptrace::detail::get_frame_object_info(frame->raw_address); + return ctrace::convert_object_frame(new_frame); + } catch(...) { + return {0, 0, nullptr}; + } + } +} diff --git a/dep/cpptrace/src/demangle/demangle.hpp b/dep/cpptrace/src/demangle/demangle.hpp new file mode 100644 index 00000000000..9aba59df24c --- /dev/null +++ b/dep/cpptrace/src/demangle/demangle.hpp @@ -0,0 +1,12 @@ +#ifndef DEMANGLE_HPP +#define DEMANGLE_HPP + +#include + +namespace cpptrace { +namespace detail { + std::string demangle(const std::string&); +} +} + +#endif diff --git a/dep/cpptrace/src/demangle/demangle_with_cxxabi.cpp b/dep/cpptrace/src/demangle/demangle_with_cxxabi.cpp new file mode 100644 index 00000000000..15dfb635b5d --- /dev/null +++ b/dep/cpptrace/src/demangle/demangle_with_cxxabi.cpp @@ -0,0 +1,31 @@ +#ifdef CPPTRACE_DEMANGLE_WITH_CXXABI + +#include "demangle/demangle.hpp" + +#include + +#include +#include + +namespace cpptrace { +namespace detail { + std::string demangle(const std::string& name) { + int status; + // presumably thread-safe + // it appears safe to pass nullptr for status however the docs don't explicitly say it's safe so I don't + // want to rely on it + char* const demangled = abi::__cxa_demangle(name.c_str(), nullptr, nullptr, &status); + // demangled will always be nullptr on non-zero status, and if __cxa_demangle ever fails for any reason + // we'll just quietly return the mangled name + if(demangled) { + std::string str = demangled; + std::free(demangled); + return str; + } else { + return name; + } + } +} +} + +#endif diff --git a/dep/cpptrace/src/demangle/demangle_with_nothing.cpp b/dep/cpptrace/src/demangle/demangle_with_nothing.cpp new file mode 100644 index 00000000000..a3fcd983594 --- /dev/null +++ b/dep/cpptrace/src/demangle/demangle_with_nothing.cpp @@ -0,0 +1,15 @@ +#ifdef CPPTRACE_DEMANGLE_WITH_NOTHING + +#include "demangle/demangle.hpp" + +#include + +namespace cpptrace { +namespace detail { + std::string demangle(const std::string& name) { + return name; + } +} +} + +#endif diff --git a/dep/cpptrace/src/demangle/demangle_with_winapi.cpp b/dep/cpptrace/src/demangle/demangle_with_winapi.cpp new file mode 100644 index 00000000000..a41f93bb13a --- /dev/null +++ b/dep/cpptrace/src/demangle/demangle_with_winapi.cpp @@ -0,0 +1,25 @@ +#ifdef CPPTRACE_DEMANGLE_WITH_WINAPI + +#include "demangle/demangle.hpp" + +#include + +#include +#include + +namespace cpptrace { +namespace detail { + std::string demangle(const std::string& name) { + char buffer[500]; + auto ret = UnDecorateSymbolName(name.c_str(), buffer, sizeof(buffer) - 1, 0); + if(ret == 0) { + return name; + } else { + buffer[ret] = 0; // just in case, ms' docs unclear if null terminator inserted + return buffer; + } + } +} +} + +#endif diff --git a/dep/cpptrace/src/from_current.cpp b/dep/cpptrace/src/from_current.cpp new file mode 100644 index 00000000000..308735f7b42 --- /dev/null +++ b/dep/cpptrace/src/from_current.cpp @@ -0,0 +1,323 @@ +#include +#define CPPTRACE_DONT_PREPARE_UNWIND_INTERCEPTOR_ON +#include + +#include +#include + +#include "platform/platform.hpp" +#include "utils/common.hpp" +#include "utils/microfmt.hpp" +#include "utils/utils.hpp" + +#ifndef _MSC_VER + #include + #include + #if IS_WINDOWS + #include + #else + #include + #include + #if IS_APPLE + #include + #ifdef HAS_MACH_VM + #include + #endif + #else + #include + #include + #endif + #endif +#endif + +namespace cpptrace { + namespace detail { + thread_local lazy_trace_holder current_exception_trace; + + CPPTRACE_FORCE_NO_INLINE void collect_current_trace(std::size_t skip) { + current_exception_trace = lazy_trace_holder(cpptrace::generate_raw_trace(skip + 1)); + } + + #ifndef _MSC_VER + // set only once by do_prepare_unwind_interceptor + char (*intercept_unwind_handler)(std::size_t) = nullptr; + + CPPTRACE_FORCE_NO_INLINE + bool intercept_unwind(const std::type_info*, const std::type_info*, void**, unsigned) { + if(intercept_unwind_handler) { + intercept_unwind_handler(1); + } + return false; + } + + CPPTRACE_FORCE_NO_INLINE + bool unconditional_exception_unwind_interceptor(const std::type_info*, const std::type_info*, void**, unsigned) { + collect_current_trace(1); + return false; + } + + using do_catch_fn = decltype(intercept_unwind); + + unwind_interceptor::~unwind_interceptor() = default; + unconditional_unwind_interceptor::~unconditional_unwind_interceptor() = default; + + #if IS_LIBSTDCXX + constexpr size_t vtable_size = 11; + #elif IS_LIBCXX + constexpr size_t vtable_size = 10; + #else + #warning "Cpptrace from_current: Unrecognized C++ standard library, from_current() won't be supported" + constexpr size_t vtable_size = 0; + #endif + + #if IS_WINDOWS + int get_page_size() { + SYSTEM_INFO info; + GetSystemInfo(&info); + return info.dwPageSize; + } + constexpr auto memory_readonly = PAGE_READONLY; + constexpr auto memory_readwrite = PAGE_READWRITE; + int mprotect_page_and_return_old_protections(void* page, int page_size, int protections) { + DWORD old_protections; + if(!VirtualProtect(page, page_size, protections, &old_protections)) { + throw std::runtime_error( + microfmt::format( + "VirtualProtect call failed: {}", + std::system_error(GetLastError(), std::system_category()).what() + ) + ); + } + return old_protections; + } + void mprotect_page(void* page, int page_size, int protections) { + mprotect_page_and_return_old_protections(page, page_size, protections); + } + void* allocate_page(int page_size) { + auto page = VirtualAlloc(nullptr, page_size, MEM_COMMIT | MEM_RESERVE, memory_readwrite); + if(!page) { + throw std::runtime_error( + microfmt::format( + "VirtualAlloc call failed: {}", + std::system_error(GetLastError(), std::system_category()).what() + ) + ); + } + return page; + } + #else + int get_page_size() { + return getpagesize(); + } + constexpr auto memory_readonly = PROT_READ; + constexpr auto memory_readwrite = PROT_READ | PROT_WRITE; + #if IS_APPLE + int get_page_protections(void* page) { + // https://stackoverflow.com/a/12627784/15675011 + #ifdef HAS_MACH_VM + mach_vm_size_t vmsize; + mach_vm_address_t address = (mach_vm_address_t)page; + #else + vm_size_t vmsize; + vm_address_t address = (vm_address_t)page; + #endif + vm_region_basic_info_data_t info; + mach_msg_type_number_t info_count = + sizeof(size_t) == 8 ? VM_REGION_BASIC_INFO_COUNT_64 : VM_REGION_BASIC_INFO_COUNT; + memory_object_name_t object; + kern_return_t status = + #ifdef HAS_MACH_VM + mach_vm_region + #else + vm_region_64 + #endif + ( + mach_task_self(), + &address, + &vmsize, + VM_REGION_BASIC_INFO, + (vm_region_info_t)&info, + &info_count, + &object + ); + if(status == KERN_INVALID_ADDRESS) { + throw std::runtime_error("vm_region failed with KERN_INVALID_ADDRESS"); + } + int perms = 0; + if(info.protection & VM_PROT_READ) { + perms |= PROT_READ; + } + if(info.protection & VM_PROT_WRITE) { + perms |= PROT_WRITE; + } + if(info.protection & VM_PROT_EXECUTE) { + perms |= PROT_EXEC; + } + return perms; + } + #else + int get_page_protections(void* page) { + auto page_addr = reinterpret_cast(page); + std::ifstream stream("/proc/self/maps"); + stream>>std::hex; + while(!stream.eof()) { + uintptr_t start; + uintptr_t stop; + stream>>start; + stream.ignore(1); // dash + stream>>stop; + if(stream.eof()) { + break; + } + if(stream.fail()) { + throw std::runtime_error("Failure reading /proc/self/maps"); + } + if(page_addr >= start && page_addr < stop) { + stream.ignore(1); // space + char r, w, x; // there's a private/shared flag after these but we don't need it + stream>>r>>w>>x; + if(stream.fail() || stream.eof()) { + throw std::runtime_error("Failure reading /proc/self/maps"); + } + int perms = 0; + if(r == 'r') { + perms |= PROT_READ; + } + if(w == 'w') { + perms |= PROT_WRITE; + } + if(x == 'x') { + perms |= PROT_EXEC; + } + // std::cerr<<"--parsed: "<::max(), '\n'); + } + throw std::runtime_error("Failed to find mapping with page in /proc/self/maps"); + } + #endif + void mprotect_page(void* page, int page_size, int protections) { + if(mprotect(page, page_size, protections) != 0) { + throw std::runtime_error(microfmt::format("mprotect call failed: {}", strerror(errno))); + } + } + int mprotect_page_and_return_old_protections(void* page, int page_size, int protections) { + auto old_protections = get_page_protections(page); + mprotect_page(page, page_size, protections); + return old_protections; + } + void* allocate_page(int page_size) { + auto page = mmap(nullptr, page_size, memory_readwrite, MAP_ANONYMOUS | MAP_PRIVATE, -1, 0); + if(page == MAP_FAILED) { + throw std::runtime_error(microfmt::format("mmap call failed: {}", strerror(errno))); + } + return page; + } + #endif + + void perform_typeinfo_surgery(const std::type_info& info, do_catch_fn* do_catch_function) { + if(vtable_size == 0) { // set to zero if we don't know what standard library we're working with + return; + } + void* type_info_pointer = const_cast(static_cast(&info)); + void* type_info_vtable_pointer = *static_cast(type_info_pointer); + // the type info vtable pointer points to two pointers inside the vtable, adjust it back + type_info_vtable_pointer = static_cast(static_cast(type_info_vtable_pointer) - 2); + + // for libstdc++ the class type info vtable looks like + // 0x7ffff7f89d18 <_ZTVN10__cxxabiv117__class_type_infoE>: 0x0000000000000000 0x00007ffff7f89d00 + // [offset ][typeinfo pointer ] + // 0x7ffff7f89d28 <_ZTVN10__cxxabiv117__class_type_infoE+16>: 0x00007ffff7dd65a0 0x00007ffff7dd65c0 + // [base destructor ][deleting dtor ] + // 0x7ffff7f89d38 <_ZTVN10__cxxabiv117__class_type_infoE+32>: 0x00007ffff7dd8f10 0x00007ffff7dd8f10 + // [__is_pointer_p ][__is_function_p ] + // 0x7ffff7f89d48 <_ZTVN10__cxxabiv117__class_type_infoE+48>: 0x00007ffff7dd6640 0x00007ffff7dd6500 + // [__do_catch ][__do_upcast ] + // 0x7ffff7f89d58 <_ZTVN10__cxxabiv117__class_type_infoE+64>: 0x00007ffff7dd65e0 0x00007ffff7dd66d0 + // [__do_upcast ][__do_dyncast ] + // 0x7ffff7f89d68 <_ZTVN10__cxxabiv117__class_type_infoE+80>: 0x00007ffff7dd6580 0x00007ffff7f8abe8 + // [__do_find_public_src][other ] + // In libc++ the layout is + // [offset ][typeinfo pointer ] + // [base destructor ][deleting dtor ] + // [noop1 ][noop2 ] + // [can_catch ][search_above_dst ] + // [search_below_dst ][has_unambiguous_public_base] + // Relevant documentation/implementation: + // https://itanium-cxx-abi.github.io/cxx-abi/abi.html + // libstdc++ + // https://github.com/gcc-mirror/gcc/blob/b13e34699c7d27e561fcfe1b66ced1e50e69976f/libstdc%252B%252B-v3/libsupc%252B%252B/typeinfo + // https://github.com/gcc-mirror/gcc/blob/b13e34699c7d27e561fcfe1b66ced1e50e69976f/libstdc%252B%252B-v3/libsupc%252B%252B/class_type_info.cc + // libc++ + // https://github.com/llvm/llvm-project/blob/648f4d0658ab00cf1e95330c8811aaea9481a274/libcxx/include/typeinfo + // https://github.com/llvm/llvm-project/blob/648f4d0658ab00cf1e95330c8811aaea9481a274/libcxxabi/src/private_typeinfo.h + + // shouldn't be anything other than 4096 but out of an abundance of caution + auto page_size = get_page_size(); + if(page_size <= 0 && (page_size & (page_size - 1)) != 0) { + throw std::runtime_error( + microfmt::format("getpagesize() is not a power of 2 greater than zero (was {})", page_size) + ); + } + + // allocate a page for the new vtable so it can be made read-only later + // the OS cleans this up, no cleanup done here for it + void* new_vtable_page = allocate_page(page_size); + // make our own copy of the vtable + memcpy(new_vtable_page, type_info_vtable_pointer, vtable_size * sizeof(void*)); + // ninja in the custom __do_catch interceptor + auto new_vtable = static_cast(new_vtable_page); + new_vtable[6] = reinterpret_cast(do_catch_function); + // make the page read-only + mprotect_page(new_vtable_page, page_size, memory_readonly); + + // make the vtable pointer for unwind_interceptor's type_info point to the new vtable + auto type_info_addr = reinterpret_cast(type_info_pointer); + auto page_addr = type_info_addr & ~(page_size - 1); + // make sure the memory we're going to set is within the page + if(type_info_addr - page_addr + sizeof(void*) > static_cast(page_size)) { + throw std::runtime_error("pointer crosses page boundaries"); + } + auto old_protections = mprotect_page_and_return_old_protections( + reinterpret_cast(page_addr), + page_size, + memory_readwrite + ); + *static_cast(type_info_pointer) = static_cast(new_vtable + 2); + mprotect_page(reinterpret_cast(page_addr), page_size, old_protections); + } + + void do_prepare_unwind_interceptor(char(*intercept_unwind_handler)(std::size_t)) { + static bool did_prepare = false; + if(!did_prepare) { + cpptrace::detail::intercept_unwind_handler = intercept_unwind_handler; + try { + perform_typeinfo_surgery(typeid(cpptrace::detail::unwind_interceptor), intercept_unwind); + perform_typeinfo_surgery( + typeid(cpptrace::detail::unconditional_unwind_interceptor), + unconditional_exception_unwind_interceptor + ); + } catch(std::exception& e) { + std::fprintf( + stderr, + "Cpptrace: Exception occurred while preparing from_current support: %s\n", + e.what() + ); + } catch(...) { + std::fprintf(stderr, "Cpptrace: Unknown exception occurred while preparing from_current support\n"); + } + did_prepare = true; + } + } + #endif + } + + const raw_trace& raw_trace_from_current_exception() { + return detail::current_exception_trace.get_raw_trace(); + } + + const stacktrace& from_current_exception() { + return detail::current_exception_trace.get_resolved_trace(); + } +} diff --git a/dep/cpptrace/src/platform/dbghelp_syminit_manager.hpp b/dep/cpptrace/src/platform/dbghelp_syminit_manager.hpp new file mode 100644 index 00000000000..7a1ca935a37 --- /dev/null +++ b/dep/cpptrace/src/platform/dbghelp_syminit_manager.hpp @@ -0,0 +1,43 @@ +#ifndef DBGHELP_SYMINIT_MANAGER_HPP +#define DBGHELP_SYMINIT_MANAGER_HPP + +#include "utils/common.hpp" +#include "utils/utils.hpp" + +#include + +#include +#include + +namespace cpptrace { +namespace detail { + struct dbghelp_syminit_manager { + std::unordered_set set; + + ~dbghelp_syminit_manager() { + for(auto handle : set) { + if(!SymCleanup(handle)) { + ASSERT(false, microfmt::format("Cpptrace SymCleanup failed with code {}\n", GetLastError()).c_str()); + } + } + } + + void init(HANDLE proc) { + if(set.count(proc) == 0) { + if(!SymInitialize(proc, NULL, TRUE)) { + throw internal_error("SymInitialize failed {}", GetLastError()); + } + set.insert(proc); + } + } + }; + + // Thread-safety: Must only be called from symbols_with_dbghelp while the dbghelp_lock lock is held + inline dbghelp_syminit_manager& get_syminit_manager() { + static dbghelp_syminit_manager syminit_manager; + return syminit_manager; + } +} +} + +#endif diff --git a/dep/cpptrace/src/platform/exception_type.hpp b/dep/cpptrace/src/platform/exception_type.hpp new file mode 100644 index 00000000000..5a989a3929e --- /dev/null +++ b/dep/cpptrace/src/platform/exception_type.hpp @@ -0,0 +1,27 @@ +#ifndef EXCEPTION_TYPE_HPP +#define EXCEPTION_TYPE_HPP + +#include + +#include "platform/platform.hpp" + +// libstdc++ and libc++ +#if defined(CPPTRACE_HAS_CXX_EXCEPTION_TYPE) && (IS_LIBSTDCXX || IS_LIBCXX) + #include + #include "demangle/demangle.hpp" +#endif + +namespace cpptrace { +namespace detail { + inline std::string exception_type_name() { + #if defined(CPPTRACE_HAS_CXX_EXCEPTION_TYPE) && (IS_LIBSTDCXX || IS_LIBCXX) + const std::type_info* t = abi::__cxa_current_exception_type(); + return t ? detail::demangle(t->name()) : ""; + #else + return ""; + #endif + } +} +} + +#endif diff --git a/dep/cpptrace/src/platform/path.hpp b/dep/cpptrace/src/platform/path.hpp new file mode 100644 index 00000000000..f19320ad7b7 --- /dev/null +++ b/dep/cpptrace/src/platform/path.hpp @@ -0,0 +1,42 @@ +#ifndef PATH_HPP +#define PATH_HPP + +#include "utils/common.hpp" +#include "platform/platform.hpp" + +#if IS_WINDOWS +#include +#endif + +namespace cpptrace { +namespace detail { + #if IS_WINDOWS + constexpr char PATH_SEP = '\\'; + inline bool is_absolute(const std::string& path) { + // I don't want to bring in shlwapi as a dependency just for PathIsRelativeA so I'm following the guidance of + // https://stackoverflow.com/a/71941552/15675011 and + // https://github.com/wine-mirror/wine/blob/b210a204137dec8d2126ca909d762454fd47e963/dlls/kernelbase/path.c#L982 + if(path.empty() || IsDBCSLeadByte(path[0])) { + return false; + } + if(path[0] == '\\') { + return true; + } + if(path.size() >= 2 && std::isalpha(path[0]) && path[1] == ':') { + return true; + } + return false; + } + #else + constexpr char PATH_SEP = '/'; + inline bool is_absolute(const std::string& path) { + if(path.empty()) { + return false; + } + return path[0] == '/'; + } + #endif +} +} + +#endif diff --git a/dep/cpptrace/src/platform/platform.hpp b/dep/cpptrace/src/platform/platform.hpp new file mode 100644 index 00000000000..89206e826c3 --- /dev/null +++ b/dep/cpptrace/src/platform/platform.hpp @@ -0,0 +1,48 @@ +#ifndef PLATFORM_HPP +#define PLATFORM_HPP + +#define IS_WINDOWS 0 +#define IS_LINUX 0 +#define IS_APPLE 0 + +#if defined(_WIN32) + #undef IS_WINDOWS + #define IS_WINDOWS 1 +#elif defined(__linux) + #undef IS_LINUX + #define IS_LINUX 1 +#elif defined(__APPLE__) + #undef IS_APPLE + #define IS_APPLE 1 +#else + #error "Unexpected platform" +#endif + +#define IS_CLANG 0 +#define IS_GCC 0 +#define IS_MSVC 0 + +#if defined(__clang__) + #undef IS_CLANG + #define IS_CLANG 1 +#elif defined(__GNUC__) || defined(__GNUG__) + #undef IS_GCC + #define IS_GCC 1 +#elif defined(_MSC_VER) + #undef IS_MSVC + #define IS_MSVC 1 +#else + #error "Unsupported compiler" +#endif + +#define IS_LIBSTDCXX 0 +#define IS_LIBCXX 0 +#if defined(__GLIBCXX__) || defined(__GLIBCPP__) +#undef IS_LIBSTDCXX +#define IS_LIBSTDCXX 1 +#elif defined(_LIBCPP_VERSION) +#undef IS_LIBCXX +#define IS_LIBCXX 1 +#endif + +#endif diff --git a/dep/cpptrace/src/platform/program_name.hpp b/dep/cpptrace/src/platform/program_name.hpp new file mode 100644 index 00000000000..a51f677a243 --- /dev/null +++ b/dep/cpptrace/src/platform/program_name.hpp @@ -0,0 +1,100 @@ +#ifndef PROGRAM_NAME_HPP +#define PROGRAM_NAME_HPP + +#include +#include + +#include "platform/platform.hpp" + +#if IS_WINDOWS +#include + +#define CPPTRACE_MAX_PATH MAX_PATH + +namespace cpptrace { +namespace detail { + inline const char* program_name() { + static std::mutex mutex; + const std::lock_guard lock(mutex); + static std::string name; + static bool did_init = false; + static bool valid = false; + if(!did_init) { + did_init = true; + char buffer[MAX_PATH + 1]; + int res = GetModuleFileNameA(nullptr, buffer, MAX_PATH); + if(res) { + name = buffer; + valid = true; + } + } + return valid && !name.empty() ? name.c_str() : nullptr; + } +} +} + +#elif IS_APPLE + +#include +#include +#include + +#define CPPTRACE_MAX_PATH CPPTRACE_PATH_MAX + +namespace cpptrace { +namespace detail { + inline const char* program_name() { + static std::mutex mutex; + const std::lock_guard lock(mutex); + static std::string name; + static bool did_init = false; + static bool valid = false; + if(!did_init) { + did_init = true; + char buffer[CPPTRACE_PATH_MAX + 1]; + std::uint32_t bufferSize = sizeof buffer; + if(_NSGetExecutablePath(buffer, &bufferSize) == 0) { + name.assign(buffer, bufferSize); + valid = true; + } + } + return valid && !name.empty() ? name.c_str() : nullptr; + } +} +} + +#elif IS_LINUX + +#include +#include +#include + +#define CPPTRACE_MAX_PATH CPPTRACE_PATH_MAX + +namespace cpptrace { +namespace detail { + inline const char* program_name() { + static std::mutex mutex; + const std::lock_guard lock(mutex); + static std::string name; + static bool did_init = false; + static bool valid = false; + if(!did_init) { + did_init = true; + char buffer[CPPTRACE_PATH_MAX + 1]; + const ssize_t size = readlink("/proc/self/exe", buffer, CPPTRACE_PATH_MAX); + if(size == -1) { + return nullptr; + } + buffer[size] = 0; + name = buffer; + valid = true; + } + return valid && !name.empty() ? name.c_str() : nullptr; + } +} +} + +#endif + +#endif diff --git a/dep/cpptrace/src/snippets/snippet.cpp b/dep/cpptrace/src/snippets/snippet.cpp new file mode 100644 index 00000000000..ffa19d3d5f5 --- /dev/null +++ b/dep/cpptrace/src/snippets/snippet.cpp @@ -0,0 +1,142 @@ +#include "snippets/snippet.hpp" + +#include +#include +#include +#include +#include +#include +#include + +#include "utils/common.hpp" +#include "utils/utils.hpp" + +namespace cpptrace { +namespace detail { + constexpr std::int64_t max_size = 1024 * 1024 * 10; // 10 MiB + + struct line_range { + std::size_t begin; + std::size_t end; // one past the end + }; + + class snippet_manager { + bool loaded_contents; + std::string contents; + // 1-based indexing + std::vector line_table; + public: + snippet_manager(const std::string& path) : loaded_contents(false) { + std::ifstream file; + try { + file.open(path, std::ios::ate); + if(file.is_open()) { + std::ifstream::pos_type size = file.tellg(); + if(size == std::ifstream::pos_type(-1) || size > max_size) { + return; + } + // else load file + file.seekg(0, std::ios::beg); + contents.resize(to(size)); + if(!file.read(&contents[0], size)) { + // error ... + } + build_line_table(); + loaded_contents = true; + } + } catch(const std::ifstream::failure&) { + // ... + } + } + + // takes a 1-index line + std::string get_line(std::size_t line) const { + if(!loaded_contents || line > line_table.size()) { + return ""; + } else { + return contents.substr(line_table[line].begin, line_table[line].end - line_table[line].begin); + } + } + + std::size_t num_lines() const { + return line_table.size(); + } + + bool ok() const { + return loaded_contents; + } + private: + void build_line_table() { + line_table.push_back({0, 0}); + std::size_t pos = 0; // stores the start of the current line + while(true) { + // find the end of the current line + std::size_t terminator_pos = contents.find('\n', pos); + if(terminator_pos == std::string::npos) { + line_table.push_back({pos, contents.size()}); + break; + } else { + std::size_t end_pos = terminator_pos; // one past the end of the current line + if(end_pos > 0 && contents[end_pos - 1] == '\r') { + end_pos--; + } + line_table.push_back({pos, end_pos}); + pos = terminator_pos + 1; + } + } + } + }; + + std::mutex snippet_manager_mutex; + std::unordered_map snippet_managers; + + const snippet_manager& get_manager(const std::string& path) { + std::unique_lock lock(snippet_manager_mutex); + auto it = snippet_managers.find(path); + if(it == snippet_managers.end()) { + return snippet_managers.insert({path, snippet_manager(path)}).first->second; + } else { + return it->second; + } + } + + // how wide the margin for the line number should be + constexpr std::size_t margin_width = 8; + + // 1-indexed line + std::string get_snippet(const std::string& path, std::size_t target_line, std::size_t context_size, bool color) { + const auto& manager = get_manager(path); + if(!manager.ok()) { + return ""; + } + auto begin = target_line <= context_size + 1 ? 1 : target_line - context_size; + auto original_begin = begin; + auto end = std::min(target_line + context_size, manager.num_lines() - 1); + std::vector lines; + for(auto line = begin; line <= end; line++) { + lines.push_back(manager.get_line(line)); + } + // trim blank lines + while(begin < target_line && lines[begin - original_begin].empty()) { + begin++; + } + while(end > target_line && lines[end - original_begin].empty()) { + end--; + } + // make the snippet + std::string snippet; + for(auto line = begin; line <= end; line++) { + if(color && line == target_line) { + snippet += YELLOW; + } + auto line_str = std::to_string(line); + snippet += microfmt::format("{>{}}: ", margin_width, line_str); + if(color && line == target_line) { + snippet += RESET; + } + snippet += lines[line - original_begin] + "\n"; + } + return snippet; + } +} +} diff --git a/dep/cpptrace/src/snippets/snippet.hpp b/dep/cpptrace/src/snippets/snippet.hpp new file mode 100644 index 00000000000..cad7cc582c6 --- /dev/null +++ b/dep/cpptrace/src/snippets/snippet.hpp @@ -0,0 +1,14 @@ +#ifndef SNIPPET_HPP +#define SNIPPET_HPP + +#include +#include + +namespace cpptrace { +namespace detail { + // 1-indexed line + std::string get_snippet(const std::string& path, std::size_t line, std::size_t context_size, bool color); +} +} + +#endif diff --git a/dep/cpptrace/src/symbols/dwarf/debug_map_resolver.cpp b/dep/cpptrace/src/symbols/dwarf/debug_map_resolver.cpp new file mode 100644 index 00000000000..1bc6a6ea41d --- /dev/null +++ b/dep/cpptrace/src/symbols/dwarf/debug_map_resolver.cpp @@ -0,0 +1,207 @@ +#ifdef CPPTRACE_GET_SYMBOLS_WITH_LIBDWARF + +#include "symbols/dwarf/resolver.hpp" + +#include +#include "symbols/symbols.hpp" +#include "utils/common.hpp" +#include "utils/error.hpp" +#include "binary/object.hpp" +#include "binary/mach-o.hpp" +#include "utils/utils.hpp" + +#include +#include +#include +#include +#include +#include +#include + +namespace cpptrace { +namespace detail { +namespace libdwarf { + #if IS_APPLE + struct target_object { + std::string object_path; + bool path_ok = true; + optional> symbols; + std::unique_ptr resolver; + + target_object(std::string object_path) : object_path(std::move(object_path)) {} + + std::unique_ptr& get_resolver() { + if(!resolver) { + // this seems silly but it's an attempt to not repeatedly try to initialize new dwarf_resolvers if + // exceptions are thrown, e.g. if the path doesn't exist + resolver = std::unique_ptr(new null_resolver); + resolver = make_dwarf_resolver(object_path); + } + return resolver; + } + + std::unordered_map& get_symbols() { + if(!symbols) { + // this is an attempt to not repeatedly try to reprocess mach-o files if exceptions are thrown, e.g. if + // the path doesn't exist + std::unordered_map symbols; + this->symbols = symbols; + auto obj = mach_o::open_mach_o(object_path); + if(!obj) { + return this->symbols.unwrap(); + } + auto symbol_table = obj.unwrap_value().symbol_table(); + if(!symbol_table) { + return this->symbols.unwrap(); + } + for(const auto& symbol : symbol_table.unwrap_value()) { + symbols[symbol.name] = symbol.address; + } + this->symbols = std::move(symbols); + } + return symbols.unwrap(); + } + + CPPTRACE_FORCE_NO_INLINE_FOR_PROFILING + frame_with_inlines resolve_frame( + const object_frame& frame_info, + const std::string& symbol_name, + std::size_t offset + ) { + const auto& symbol_table = get_symbols(); + auto it = symbol_table.find(symbol_name); + if(it != symbol_table.end()) { + auto frame = frame_info; + // substitute a translated address object for the target file in + frame.object_address = it->second + offset; + auto res = get_resolver()->resolve_frame(frame); + // replace the translated address with the object address in the binary + res.frame.object_address = frame_info.object_address; + return res; + } else { + return { + { + frame_info.raw_address, + frame_info.object_address, + nullable::null(), + nullable::null(), + frame_info.object_path, + symbol_name, + false + }, + {} + }; + } + } + }; + + struct debug_map_symbol_info { + uint64_t source_address; + uint64_t size; + std::string name; + nullable target_address; // T(-1) is used as a sentinel + std::size_t object_index; + }; + + class debug_map_resolver : public symbol_resolver { + std::vector target_objects; + std::vector symbols; + public: + debug_map_resolver(const std::string& source_object_path) { + // load mach-o + // TODO: Cache somehow? + auto obj = mach_o::open_mach_o(source_object_path); + if(!obj) { + return; + } + mach_o& source_mach = obj.unwrap_value(); + auto source_debug_map = source_mach.get_debug_map(); + if(!source_debug_map) { + return; + } + // get symbol entries from debug map, as well as the various object files used to make this binary + for(auto& entry : source_debug_map.unwrap_value()) { + // object it came from + target_objects.push_back({entry.first}); + // push the symbols + auto& map_entry_symbols = entry.second; + symbols.reserve(symbols.size() + map_entry_symbols.size()); + for(auto& symbol : map_entry_symbols) { + symbols.push_back({ + symbol.source_address, + symbol.size, + std::move(symbol.name), + nullable::null(), + target_objects.size() - 1 + }); + } + } + // sort for binary lookup later + std::sort( + symbols.begin(), + symbols.end(), + [] ( + const debug_map_symbol_info& a, + const debug_map_symbol_info& b + ) { + return a.source_address < b.source_address; + } + ); + } + CPPTRACE_FORCE_NO_INLINE_FOR_PROFILING + frame_with_inlines resolve_frame(const object_frame& frame_info) override { + // resolve object frame: + // find the symbol in this executable corresponding to the object address + // resolve the symbol in the object it came from, based on the symbol name + auto closest_symbol_it = first_less_than_or_equal( + symbols.begin(), + symbols.end(), + frame_info.object_address, + [] ( + uint64_t pc, + const debug_map_symbol_info& symbol + ) { + return pc < symbol.source_address; + } + ); + if(closest_symbol_it != symbols.end()) { + if(frame_info.object_address <= closest_symbol_it->source_address + closest_symbol_it->size) { + return target_objects[closest_symbol_it->object_index].resolve_frame( + { + frame_info.raw_address, + // the resolver doesn't care about the object address here, only the offset from the start + // of the symbol and it'll lookup the symbol's base-address + frame_info.object_address, + frame_info.object_path + }, + closest_symbol_it->name, + frame_info.object_address - closest_symbol_it->source_address + ); + } + } + // There was either no closest symbol or the closest symbol didn't end up containing the address we're + // looking for, so just return a blank frame + return { + { + frame_info.raw_address, + frame_info.object_address, + nullable::null(), + nullable::null(), + frame_info.object_path, + "", + false + }, + {} + }; + }; + }; + + std::unique_ptr make_debug_map_resolver(const std::string& object_path) { + return std::unique_ptr(new debug_map_resolver(object_path)); + } + #endif +} +} +} + +#endif diff --git a/dep/cpptrace/src/symbols/dwarf/dwarf.hpp b/dep/cpptrace/src/symbols/dwarf/dwarf.hpp new file mode 100644 index 00000000000..073bae4cd88 --- /dev/null +++ b/dep/cpptrace/src/symbols/dwarf/dwarf.hpp @@ -0,0 +1,539 @@ +#ifndef DWARF_HPP +#define DWARF_HPP + +#include +#include "utils/error.hpp" +#include "utils/utils.hpp" + +#include +#include +#include + +#ifdef CPPTRACE_USE_NESTED_LIBDWARF_HEADER_PATH + #include + #include +#else + #include + #include +#endif + +namespace cpptrace { +namespace detail { +namespace libdwarf { + static_assert(std::is_pointer::value, "Dwarf_Die not a pointer"); + static_assert(std::is_pointer::value, "Dwarf_Debug not a pointer"); + + using rangelist_entries = std::vector>; + + [[noreturn]] inline void handle_dwarf_error(Dwarf_Debug dbg, Dwarf_Error error) { + Dwarf_Unsigned ev = dwarf_errno(error); + char* msg = dwarf_errmsg(error); + (void)dbg; + // dwarf_dealloc_error(dbg, error); + throw internal_error("dwarf error {} {}", ev, msg); + } + + struct die_object { + Dwarf_Debug dbg = nullptr; + Dwarf_Die die = nullptr; + + // Error handling helper + // For some reason R (*f)(Args..., void*)-style deduction isn't possible, seems like a bug in all compilers + // https://gcc.gnu.org/bugzilla/show_bug.cgi?id=56190 + template< + typename... Args, + typename... Args2, + typename std::enable_if< + std::is_same< + decltype( + (void)std::declval()(std::forward(std::declval())..., nullptr) + ), + void + >::value, + int + >::type = 0 + > + int wrap(int (*f)(Args...), Args2&&... args) const { + Dwarf_Error error = nullptr; + int ret = f(std::forward(args)..., &error); + if(ret == DW_DLV_ERROR) { + handle_dwarf_error(dbg, error); + } + return ret; + } + + die_object(Dwarf_Debug dbg, Dwarf_Die die) : dbg(dbg), die(die) { + ASSERT(dbg != nullptr); + } + + ~die_object() { + if(die) { + dwarf_dealloc_die(die); + } + } + + die_object(const die_object&) = delete; + + die_object& operator=(const die_object&) = delete; + + die_object(die_object&& other) noexcept : dbg(other.dbg), die(other.die) { + // done for finding mistakes, attempts to use the die_object after this should segfault + // a valid use otherwise would be moved_from.get_sibling() which would get the next CU + other.dbg = nullptr; + other.die = nullptr; + } + + die_object& operator=(die_object&& other) noexcept { + std::swap(dbg, other.dbg); + std::swap(die, other.die); + return *this; + } + + die_object clone() const { + Dwarf_Off global_offset = get_global_offset(); + Dwarf_Bool is_info = dwarf_get_die_infotypes_flag(die); + Dwarf_Die die_copy = nullptr; + VERIFY(wrap(dwarf_offdie_b, dbg, global_offset, is_info, &die_copy) == DW_DLV_OK); + return {dbg, die_copy}; + } + + die_object get_child() const { + Dwarf_Die child = nullptr; + int ret = wrap(dwarf_child, die, &child); + if(ret == DW_DLV_OK) { + return die_object(dbg, child); + } else if(ret == DW_DLV_NO_ENTRY) { + return die_object(dbg, nullptr); + } else { + PANIC(); + } + } + + die_object get_sibling() const { + Dwarf_Die sibling = nullptr; + int ret = wrap(dwarf_siblingof_b, dbg, die, true, &sibling); + if(ret == DW_DLV_OK) { + return die_object(dbg, sibling); + } else if(ret == DW_DLV_NO_ENTRY) { + return die_object(dbg, nullptr); + } else { + PANIC(); + } + } + + operator bool() const { + return die != nullptr; + } + + Dwarf_Die get() const { + return die; + } + + std::string get_name() const { + char empty[] = ""; + char* name = empty; + int ret = wrap(dwarf_diename, die, &name); + auto wrapper = raii_wrap(name, [this] (char* str) { dwarf_dealloc(dbg, str, DW_DLA_STRING); }); + std::string str; + if(ret != DW_DLV_NO_ENTRY) { + str = name; + } + return name; + } + + optional get_string_attribute(Dwarf_Half attr_num) const { + Dwarf_Attribute attr; + if(wrap(dwarf_attr, die, attr_num, &attr) == DW_DLV_OK) { + auto attwrapper = raii_wrap(attr, [] (Dwarf_Attribute attr) { dwarf_dealloc_attribute(attr); }); + char* raw_str; + VERIFY(wrap(dwarf_formstring, attr, &raw_str) == DW_DLV_OK); + auto strwrapper = raii_wrap(raw_str, [this] (char* str) { dwarf_dealloc(dbg, str, DW_DLA_STRING); }); + std::string str = raw_str; + return str; + } else { + return nullopt; + } + } + + optional get_unsigned_attribute(Dwarf_Half attr_num) const { + Dwarf_Attribute attr; + if(wrap(dwarf_attr, die, attr_num, &attr) == DW_DLV_OK) { + auto attwrapper = raii_wrap(attr, [] (Dwarf_Attribute attr) { dwarf_dealloc_attribute(attr); }); + // Dwarf_Half form = 0; + // VERIFY(wrap(dwarf_whatform, attr, &form) == DW_DLV_OK); + Dwarf_Unsigned val; + VERIFY(wrap(dwarf_formudata, attr, &val) == DW_DLV_OK); + return val; + } else { + return nullopt; + } + } + + bool has_attr(Dwarf_Half attr_num) const { + Dwarf_Bool present = false; + VERIFY(wrap(dwarf_hasattr, die, attr_num, &present) == DW_DLV_OK); + return present; + } + + Dwarf_Half get_tag() const { + Dwarf_Half tag = 0; + VERIFY(wrap(dwarf_tag, die, &tag) == DW_DLV_OK); + return tag; + } + + const char* get_tag_name() const { + const char* tag_name; + if(dwarf_get_TAG_name(get_tag(), &tag_name) == DW_DLV_OK) { + return tag_name; + } else { + return ""; + } + } + + Dwarf_Off get_global_offset() const { + Dwarf_Off off; + VERIFY(wrap(dwarf_dieoffset, die, &off) == DW_DLV_OK); + return off; + } + + die_object resolve_reference_attribute(Dwarf_Half attr_num) const { + Dwarf_Attribute attr; + VERIFY(dwarf_attr(die, attr_num, &attr, nullptr) == DW_DLV_OK); + auto wrapper = raii_wrap(attr, [] (Dwarf_Attribute attr) { dwarf_dealloc_attribute(attr); }); + Dwarf_Half form = 0; + VERIFY(wrap(dwarf_whatform, attr, &form) == DW_DLV_OK); + switch(form) { + case DW_FORM_ref1: + case DW_FORM_ref2: + case DW_FORM_ref4: + case DW_FORM_ref8: + case DW_FORM_ref_udata: + { + Dwarf_Off off = 0; + Dwarf_Bool is_info = dwarf_get_die_infotypes_flag(die); + VERIFY(wrap(dwarf_formref, attr, &off, &is_info) == DW_DLV_OK); + Dwarf_Off global_offset = 0; + VERIFY(wrap(dwarf_convert_to_global_offset, attr, off, &global_offset) == DW_DLV_OK); + Dwarf_Die target = nullptr; + VERIFY(wrap(dwarf_offdie_b, dbg, global_offset, is_info, &target) == DW_DLV_OK); + return die_object(dbg, target); + } + case DW_FORM_ref_addr: + { + Dwarf_Off off; + VERIFY(wrap(dwarf_global_formref, attr, &off) == DW_DLV_OK); + int is_info = dwarf_get_die_infotypes_flag(die); + Dwarf_Die target = nullptr; + VERIFY(wrap(dwarf_offdie_b, dbg, off, is_info, &target) == DW_DLV_OK); + return die_object(dbg, target); + } + case DW_FORM_ref_sig8: + { + Dwarf_Sig8 signature; + VERIFY(wrap(dwarf_formsig8, attr, &signature) == DW_DLV_OK); + Dwarf_Die target = nullptr; + Dwarf_Bool targ_is_info = false; + VERIFY(wrap(dwarf_find_die_given_sig8, dbg, &signature, &target, &targ_is_info) == DW_DLV_OK); + return die_object(dbg, target); + } + default: + PANIC(microfmt::format("unknown form for attribute {} {}\n", attr_num, form)); + } + } + + Dwarf_Unsigned get_ranges_base_address(const die_object& cu_die) const { + // After libdwarf v0.11.0 this can use dwarf_get_ranges_baseaddress, however, in the interest of not + // requiring v0.11.0 just yet the logic is implemented here too. + // The base address is: + // - If the die has a rangelist, use the low_pc for that die + // - Otherwise use the low_pc from the CU if present + // - Otherwise 0 + if(has_attr(DW_AT_ranges)) { + if(has_attr(DW_AT_low_pc)) { + Dwarf_Addr lowpc; + if(wrap(dwarf_lowpc, die, &lowpc) == DW_DLV_OK) { + return lowpc; + } + } + } + if(cu_die.has_attr(DW_AT_low_pc)) { + Dwarf_Addr lowpc; + if(wrap(dwarf_lowpc, cu_die.get(), &lowpc) == DW_DLV_OK) { + return lowpc; + } + } + return 0; + } + + Dwarf_Unsigned get_ranges_offset(Dwarf_Attribute attr) const { + Dwarf_Unsigned off = 0; + Dwarf_Half form = 0; + VERIFY(wrap(dwarf_whatform, attr, &form) == DW_DLV_OK); + if (form == DW_FORM_rnglistx) { + VERIFY(wrap(dwarf_formudata, attr, &off) == DW_DLV_OK); + } else { + VERIFY(wrap(dwarf_global_formref, attr, &off) == DW_DLV_OK); + } + return off; + } + + template + // callback should return true to keep going + void dwarf5_ranges(F callback) const { + Dwarf_Attribute attr = nullptr; + if(wrap(dwarf_attr, die, DW_AT_ranges, &attr) != DW_DLV_OK) { + return; + } + auto attrwrapper = raii_wrap(attr, [] (Dwarf_Attribute attr) { dwarf_dealloc_attribute(attr); }); + Dwarf_Unsigned offset = get_ranges_offset(attr); + Dwarf_Half form = 0; + VERIFY(wrap(dwarf_whatform, attr, &form) == DW_DLV_OK); + // get .debug_rnglists info + Dwarf_Rnglists_Head head = nullptr; + Dwarf_Unsigned rnglists_entries = 0; + Dwarf_Unsigned dw_global_offset_of_rle_set = 0; + int res = wrap( + dwarf_rnglists_get_rle_head, + attr, + form, + offset, + &head, + &rnglists_entries, + &dw_global_offset_of_rle_set + ); + auto headwrapper = raii_wrap(head, [] (Dwarf_Rnglists_Head head) { dwarf_dealloc_rnglists_head(head); }); + if(res == DW_DLV_NO_ENTRY) { + return; + } + VERIFY(res == DW_DLV_OK); + for(std::size_t i = 0 ; i < rnglists_entries; i++) { + unsigned entrylen = 0; + unsigned rle_value_out = 0; + Dwarf_Unsigned raw1 = 0; + Dwarf_Unsigned raw2 = 0; + Dwarf_Bool unavailable = 0; + Dwarf_Unsigned cooked1 = 0; + Dwarf_Unsigned cooked2 = 0; + res = wrap( + dwarf_get_rnglists_entry_fields_a, + head, + i, + &entrylen, + &rle_value_out, + &raw1, + &raw2, + &unavailable, + &cooked1, + &cooked2 + ); + if(res == DW_DLV_NO_ENTRY) { + continue; + } + VERIFY(res == DW_DLV_OK); + if(unavailable) { + continue; + } + switch(rle_value_out) { + // Following the same scheme from libdwarf-addr2line + case DW_RLE_end_of_list: + case DW_RLE_base_address: + case DW_RLE_base_addressx: + // Already handled + break; + case DW_RLE_offset_pair: + case DW_RLE_startx_endx: + case DW_RLE_start_end: + case DW_RLE_startx_length: + case DW_RLE_start_length: + if(!callback(cooked1, cooked2)) { + return; + } + break; + default: + PANIC("Something is wrong"); + break; + } + } + } + + template + // callback should return true to keep going + void dwarf4_ranges(Dwarf_Addr baseaddr, F callback) const { + Dwarf_Attribute attr = nullptr; + if(wrap(dwarf_attr, die, DW_AT_ranges, &attr) != DW_DLV_OK) { + return; + } + auto attrwrapper = raii_wrap(attr, [] (Dwarf_Attribute attr) { dwarf_dealloc_attribute(attr); }); + Dwarf_Unsigned offset; + if(wrap(dwarf_global_formref, attr, &offset) != DW_DLV_OK) { + return; + } + Dwarf_Addr baseaddr_original = baseaddr; + Dwarf_Ranges* ranges = nullptr; + Dwarf_Signed count = 0; + VERIFY( + wrap( + dwarf_get_ranges_b, + dbg, + offset, + die, + nullptr, + &ranges, + &count, + nullptr + ) == DW_DLV_OK + ); + auto rangeswrapper = raii_wrap( + ranges, + [this, count] (Dwarf_Ranges* ranges) { dwarf_dealloc_ranges(dbg, ranges, count); } + ); + for(int i = 0; i < count; i++) { + if(ranges[i].dwr_type == DW_RANGES_ENTRY) { + if(!callback(baseaddr + ranges[i].dwr_addr1, baseaddr + ranges[i].dwr_addr2)) { + return; + } + } else if(ranges[i].dwr_type == DW_RANGES_ADDRESS_SELECTION) { + baseaddr = ranges[i].dwr_addr2; + } else { + ASSERT(ranges[i].dwr_type == DW_RANGES_END); + baseaddr = baseaddr_original; + } + } + } + + template + // callback should return true to keep going + void dwarf_ranges(const die_object& cu_die, int version, F callback) const { + Dwarf_Addr lowpc; + if(wrap(dwarf_lowpc, die, &lowpc) == DW_DLV_OK) { + Dwarf_Addr highpc = 0; + enum Dwarf_Form_Class return_class; + if(wrap(dwarf_highpc_b, die, &highpc, nullptr, &return_class) == DW_DLV_OK) { + if(return_class == DW_FORM_CLASS_CONSTANT) { + highpc += lowpc; + } + if(!callback(lowpc, highpc)) { + return; + } + } + } + if(version >= 5) { + dwarf5_ranges(callback); + } else { + dwarf4_ranges(get_ranges_base_address(cu_die), callback); + } + } + + rangelist_entries get_rangelist_entries(const die_object& cu_die, int version) const { + rangelist_entries vec; + dwarf_ranges(cu_die, version, [&vec] (Dwarf_Addr low, Dwarf_Addr high) { + // Simple coalescing optimization: + // Sometimes the range list entries are really continuous: [100, 200), [200, 300) + // Other times there's just one byte of separation [300, 399), [400, 500) + // Those are the main two cases I've observed. + // This will not catch all cases, presumably, as the range lists aren't sorted. But compilers/linkers + // seem to like to emit the ranges in sorted order. + if(!vec.empty() && low - vec.back().second <= 1) { + vec.back().second = high; + } else { + vec.push_back({low, high}); + } + return true; + }); + return vec; + } + + Dwarf_Bool pc_in_die(const die_object& cu_die, int version, Dwarf_Addr pc) const { + bool found = false; + dwarf_ranges(cu_die, version, [&found, pc] (Dwarf_Addr low, Dwarf_Addr high) { + if(pc >= low && pc < high) { + found = true; + return false; + } + return true; + }); + return found; + } + + void print() const { + std::fprintf( + stderr, + "%08llx %s %s\n", + to_ull(get_global_offset()), + get_tag_name(), + get_name().c_str() + ); + } + }; + + // walk die list, callback is called on each die and should return true to + // continue traversal + // returns true if traversal should continue + inline bool walk_die_list( + const die_object& die, + const std::function& fn + ) { + // TODO: Refactor so there is only one fn call + bool continue_traversal = true; + if(fn(die)) { + die_object current = die.get_sibling(); + while(current) { + if(fn(current)) { + current = current.get_sibling(); + } else { + continue_traversal = false; + break; + } + } + } + return continue_traversal; + } + + // walk die list, recursing into children, callback is called on each die + // and should return true to continue traversal + // returns true if traversal should continue + inline bool walk_die_list_recursive( + const die_object& die, + const std::function& fn + ) { + return walk_die_list( + die, + [&fn](const die_object& die) { + auto child = die.get_child(); + if(child) { + if(!walk_die_list_recursive(child, fn)) { + return false; + } + } + return fn(die); + } + ); + } + + class maybe_owned_die_object { + // Hacky... I wish std::variant existed. + optional owned_die; + optional ref_die; + maybe_owned_die_object(die_object&& die) : owned_die(std::move(die)) {} + maybe_owned_die_object(const die_object& die) : ref_die(&die) {} + public: + static maybe_owned_die_object owned(die_object&& die) { + return maybe_owned_die_object{std::move(die)}; + } + static maybe_owned_die_object ref(const die_object& die) { + return maybe_owned_die_object{die}; + } + const die_object& get() { + ASSERT(owned_die || ref_die, "Mal-formed maybe_owned_die_object"); + if(owned_die) { + return owned_die.unwrap(); + } else { + return *ref_die.unwrap(); + } + } + }; +} +} +} + +#endif diff --git a/dep/cpptrace/src/symbols/dwarf/dwarf_resolver.cpp b/dep/cpptrace/src/symbols/dwarf/dwarf_resolver.cpp new file mode 100644 index 00000000000..76edc0efbbd --- /dev/null +++ b/dep/cpptrace/src/symbols/dwarf/dwarf_resolver.cpp @@ -0,0 +1,1086 @@ +#ifdef CPPTRACE_GET_SYMBOLS_WITH_LIBDWARF + +#include "symbols/dwarf/resolver.hpp" + +#include +#include "symbols/dwarf/dwarf.hpp" // has dwarf #includes +#include "symbols/symbols.hpp" +#include "utils/common.hpp" +#include "utils/error.hpp" +#include "utils/utils.hpp" +#include "platform/path.hpp" +#include "platform/program_name.hpp" // For CPPTRACE_MAX_PATH +#include "binary/mach-o.hpp" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +// It's been tricky to piece together how to handle all this dwarf stuff. Some resources I've used are +// https://www.prevanders.net/libdwarf.pdf +// https://github.com/davea42/libdwarf-addr2line +// https://github.com/ruby/ruby/blob/master/addr2line.c + +namespace cpptrace { +namespace detail { +namespace libdwarf { + // printbugging as we go + constexpr bool dump_dwarf = false; + constexpr bool trace_dwarf = false; + + struct subprogram_entry { + die_object die; + Dwarf_Addr low; + Dwarf_Addr high; + }; + + struct cu_entry { + die_object die; + Dwarf_Half dwversion; + Dwarf_Addr low; + Dwarf_Addr high; + }; + + struct line_entry { + Dwarf_Addr low; + // Dwarf_Addr high; + // int i; + Dwarf_Line line; + optional path; + optional line_number; + optional column_number; + line_entry(Dwarf_Addr low, Dwarf_Line line) : low(low), line(line) {} + }; + + struct line_table_info { + Dwarf_Unsigned version; + Dwarf_Line_Context line_context; + // sorted by low_addr + // TODO: Make this optional at some point, it may not be generated if cache mode switches during program exec... + std::vector line_entries; + }; + + class dwarf_resolver; + + // used to describe data from an upstream binary to a resolver for the .dwo + struct skeleton_info { + die_object cu_die; + Dwarf_Half dwversion; + dwarf_resolver& resolver; + }; + + class dwarf_resolver : public symbol_resolver { + std::string object_path; + Dwarf_Debug dbg = nullptr; + bool ok = false; + // .debug_aranges cache + Dwarf_Arange* aranges = nullptr; + Dwarf_Signed arange_count = 0; + // Map from CU -> Line context + std::unordered_map line_tables; + // Map from CU -> Sorted subprograms vector + std::unordered_map> subprograms_cache; + // Vector of ranges and their corresponding CU offsets + std::vector cu_cache; + bool generated_cu_cache = false; + // Map from CU -> {srcfiles, count} + std::unordered_map> srcfiles_cache; + // Map from CU -> split full cu resolver + std::unordered_map> split_full_cu_resolvers; + // info for resolving a dwo object + optional skeleton; + + private: + // Error handling helper + // For some reason R (*f)(Args..., void*)-style deduction isn't possible, seems like a bug in all compilers + // https://gcc.gnu.org/bugzilla/show_bug.cgi?id=56190 + template< + typename... Args, + typename... Args2, + typename std::enable_if< + std::is_same< + decltype( + (void)std::declval()(std::forward(std::declval())..., nullptr) + ), + void + >::value, + int + >::type = 0 + > + int wrap(int (*f)(Args...), Args2&&... args) const { + Dwarf_Error error = nullptr; + int ret = f(std::forward(args)..., &error); + if(ret == DW_DLV_ERROR) { + handle_dwarf_error(dbg, error); + } + return ret; + } + + public: + CPPTRACE_FORCE_NO_INLINE_FOR_PROFILING + dwarf_resolver(const std::string& object_path_, optional split_ = nullopt) + : object_path(object_path_), + skeleton(std::move(split_)) + { + // use a buffer when invoking dwarf_init_path, which allows it to automatically find debuglink or dSYM + // sources + bool use_buffer = true; + // for universal / fat mach-o files + unsigned universal_number = 0; + #if IS_APPLE + if(directory_exists(object_path + ".dSYM")) { + // Possibly depends on the build system but a obj.cpp.o.dSYM/Contents/Resources/DWARF/obj.cpp.o can be + // created alongside .o files. These are text files containing directives, as opposed to something we + // can actually use + std::string dsym_resource = object_path + ".dSYM/Contents/Resources/DWARF/" + basename(object_path); + if(file_is_mach_o(dsym_resource)) { + object_path = std::move(dsym_resource); + } + use_buffer = false; // we resolved dSYM above as appropriate + } + auto result = macho_is_fat(object_path); + if(result.is_error()) { + result.drop_error(); + } else if(result.unwrap_value()) { + auto obj = mach_o::open_mach_o(object_path); + if(!obj) { + ok = false; + return; + } + universal_number = obj.unwrap_value().get_fat_index(); + } + #endif + + // Giving libdwarf a buffer for a true output path is needed for its automatic resolution of debuglink and + // dSYM files. We don't utilize the dSYM logic here, we just care about debuglink. + std::unique_ptr buffer; + if(use_buffer) { + buffer = std::unique_ptr(new char[CPPTRACE_MAX_PATH]); + } + auto ret = wrap( + dwarf_init_path_a, + object_path.c_str(), + buffer.get(), + CPPTRACE_MAX_PATH, + DW_GROUPNUMBER_ANY, + universal_number, + nullptr, + nullptr, + &dbg + ); + if(ret == DW_DLV_OK) { + ok = true; + } else if(ret == DW_DLV_NO_ENTRY) { + // fail, no debug info + ok = false; + } else { + ok = false; + PANIC("Unknown return code from dwarf_init_path"); + } + + if(skeleton) { + VERIFY(wrap(dwarf_set_tied_dbg, dbg, skeleton.unwrap().resolver.dbg) == DW_DLV_OK); + } + + if(ok) { + // Check for .debug_aranges for fast lookup + wrap(dwarf_get_aranges, dbg, &aranges, &arange_count); + } + } + + CPPTRACE_FORCE_NO_INLINE_FOR_PROFILING + ~dwarf_resolver() override { + // TODO: Maybe redundant since dwarf_finish(dbg); will clean up the line stuff anyway but may as well just + // for thoroughness + for(auto& entry : line_tables) { + dwarf_srclines_dealloc_b(entry.second.line_context); + } + for(auto& entry : srcfiles_cache) { + dwarf_dealloc(dbg, entry.second.first, DW_DLA_LIST); + } + // subprograms_cache needs to be destroyed before dbg otherwise there will be another use after free + subprograms_cache.clear(); + split_full_cu_resolvers.clear(); + skeleton.reset(); + if(aranges) { + dwarf_dealloc(dbg, aranges, DW_DLA_LIST); + } + cu_cache.clear(); + dwarf_finish(dbg); + } + + dwarf_resolver(const dwarf_resolver&) = delete; + dwarf_resolver& operator=(const dwarf_resolver&) = delete; + dwarf_resolver(dwarf_resolver&&) = delete; + dwarf_resolver& operator=(dwarf_resolver&&) = delete; + + private: + // walk all CU's in a dbg, callback is called on each die and should return true to + // continue traversal + void walk_compilation_units(const std::function& fn) { + // libdwarf keeps track of where it is in the file, dwarf_next_cu_header_d is statefull + Dwarf_Unsigned next_cu_header; + Dwarf_Half header_cu_type; + while(true) { + int ret = wrap( + dwarf_next_cu_header_d, + dbg, + true, + nullptr, + nullptr, + nullptr, + nullptr, + nullptr, + nullptr, + nullptr, + nullptr, + &next_cu_header, + &header_cu_type + ); + if(ret == DW_DLV_NO_ENTRY) { + if(dump_dwarf) { + std::fprintf(stderr, "End walk_dbg\n"); + } + return; + } + if(ret != DW_DLV_OK) { + PANIC("Unexpected return code from dwarf_next_cu_header_d"); + return; + } + // 0 passed as the die to the first call of dwarf_siblingof_b immediately after dwarf_next_cu_header_d + // to fetch the cu die + die_object cu_die(dbg, nullptr); + cu_die = cu_die.get_sibling(); + if(!cu_die) { + break; + } + if(!walk_die_list(cu_die, fn)) { + break; + } + } + if(dump_dwarf) { + std::fprintf(stderr, "End walk_compilation_units\n"); + } + } + + void lazy_generate_cu_cache() { + if(!generated_cu_cache) { + walk_compilation_units([this] (const die_object& cu_die) { + Dwarf_Half offset_size = 0; + Dwarf_Half dwversion = 0; + dwarf_get_version_of_die(cu_die.get(), &dwversion, &offset_size); + if(skeleton) { + // NOTE: If we have a corresponding skeleton, we assume we have one CU matching the skeleton CU + // Precedence for this assumption is https://dwarfstd.org/doc/DWARF5.pdf#subsection.3.1.3 + // TODO: Also assuming same dwversion + const auto& skeleton_cu = skeleton.unwrap().cu_die; + auto ranges_vec = skeleton_cu.get_rangelist_entries(skeleton_cu, dwversion); + for(auto range : ranges_vec) { + // TODO: Reduce cloning here + cu_cache.push_back({ cu_die.clone(), dwversion, range.first, range.second }); + } + return false; + } else { + auto ranges_vec = cu_die.get_rangelist_entries(cu_die, dwversion); + for(auto range : ranges_vec) { + // TODO: Reduce cloning here + cu_cache.push_back({ cu_die.clone(), dwversion, range.first, range.second }); + } + return true; + } + }); + std::sort(cu_cache.begin(), cu_cache.end(), [] (const cu_entry& a, const cu_entry& b) { + return a.low < b.low; + }); + generated_cu_cache = true; + } + } + + std::string subprogram_symbol( + const die_object& die, + Dwarf_Half dwversion + ) { + ASSERT(die.get_tag() == DW_TAG_subprogram || die.get_tag() == DW_TAG_inlined_subroutine); + optional name; + if(auto linkage_name = die.get_string_attribute(DW_AT_linkage_name)) { + name = std::move(linkage_name); + } else if(auto linkage_name = die.get_string_attribute(DW_AT_MIPS_linkage_name)) { + name = std::move(linkage_name); + } else if(auto linkage_name = die.get_string_attribute(DW_AT_name)) { + name = std::move(linkage_name); + } + if(name.has_value()) { + return std::move(name).unwrap(); + } else { + if(die.has_attr(DW_AT_specification)) { + die_object spec = die.resolve_reference_attribute(DW_AT_specification); + return subprogram_symbol(spec, dwversion); + } else if(die.has_attr(DW_AT_abstract_origin)) { + die_object spec = die.resolve_reference_attribute(DW_AT_abstract_origin); + return subprogram_symbol(spec, dwversion); + } + } + return ""; + } + + // despite (some) dwarf using 1-indexing, file_i should be the 0-based index + std::string resolve_filename(const die_object& cu_die, Dwarf_Unsigned file_i) { + // for split-dwarf line resolution happens in the skeleton + if(skeleton) { + return skeleton.unwrap().resolver.resolve_filename(skeleton.unwrap().cu_die, file_i); + } + std::string filename; + if(get_cache_mode() == cache_mode::prioritize_memory) { + char** dw_srcfiles; + Dwarf_Signed dw_filecount; + VERIFY(wrap(dwarf_srcfiles, cu_die.get(), &dw_srcfiles, &dw_filecount) == DW_DLV_OK); + if(Dwarf_Signed(file_i) < dw_filecount) { + // dwarf is using 1-indexing + filename = dw_srcfiles[file_i]; + } + dwarf_dealloc(cu_die.dbg, dw_srcfiles, DW_DLA_LIST); + } else { + auto off = cu_die.get_global_offset(); + auto it = srcfiles_cache.find(off); + if(it == srcfiles_cache.end()) { + char** dw_srcfiles; + Dwarf_Signed dw_filecount; + VERIFY(wrap(dwarf_srcfiles, cu_die.get(), &dw_srcfiles, &dw_filecount) == DW_DLV_OK); + it = srcfiles_cache.insert(it, {off, {dw_srcfiles, dw_filecount}}); + } + char** dw_srcfiles = it->second.first; + Dwarf_Signed dw_filecount = it->second.second; + if(Dwarf_Signed(file_i) < dw_filecount) { + // dwarf is using 1-indexing + filename = dw_srcfiles[file_i]; + } + } + return filename; + } + + void get_inlines_info( + const die_object& cu_die, + const die_object& die, + Dwarf_Addr pc, + Dwarf_Half dwversion, + std::vector& inlines + ) { + ASSERT(die.get_tag() == DW_TAG_subprogram || die.get_tag() == DW_TAG_inlined_subroutine); + // get_inlines_info is recursive and recurses into dies with pc ranges matching the pc we're looking for, + // however, because I wouldn't want anything stack overflowing I'm breaking the recursion out into a loop + // while looping when we find the target die we need to be able to store a die somewhere that doesn't die + // at the end of the list traversal, we'll use this as a holder for it + die_object current_obj_holder(dbg, nullptr); + optional> current_die = die; + while(current_die.has_value()) { + auto child = current_die.unwrap().get().get_child(); + if(!child) { + break; + } + optional> target_die; + walk_die_list( + child, + [this, &cu_die, pc, dwversion, &inlines, &target_die, ¤t_obj_holder] (const die_object& die) { + if(die.get_tag() == DW_TAG_inlined_subroutine && die.pc_in_die(cu_die, dwversion, pc)) { + const auto name = subprogram_symbol(die, dwversion); + auto file_i = die.get_unsigned_attribute(DW_AT_call_file); + // TODO: Refactor.... Probably put logic in resolve_filename. + if(file_i) { + // for dwarf 2, 3, 4, and experimental line table version 0xfe06 1-indexing is used + // for dwarf 5 0-indexing is used + optional> line_table_opt; + if(skeleton) { + line_table_opt = skeleton.unwrap().resolver.get_line_table( + skeleton.unwrap().cu_die + ); + } else { + line_table_opt = get_line_table(cu_die); + } + if(line_table_opt) { + auto& line_table = line_table_opt.unwrap().get(); + if(line_table.version != 5) { + if(file_i.unwrap() == 0) { + file_i.reset(); // 0 means no name to be found + } else { + // decrement to 0-based index + file_i.unwrap()--; + } + } + } else { + // silently continue + } + } + std::string file = file_i ? resolve_filename(cu_die, file_i.unwrap()) : ""; + const auto line = die.get_unsigned_attribute(DW_AT_call_line); + const auto col = die.get_unsigned_attribute(DW_AT_call_column); + inlines.push_back(stacktrace_frame{ + 0, + 0, // TODO: Could put an object address here... + {static_cast(line.value_or(0))}, + {static_cast(col.value_or(0))}, + file, + name, + true + }); + current_obj_holder = die.clone(); + target_die = current_obj_holder; + return false; + } else { + return true; + } + } + ); + // recursing into the found target as-if by get_inlines_info(cu_die, die, pc, dwversion, inlines); + current_die = target_die; + } + } + + std::string retrieve_symbol_for_subprogram( + const die_object& cu_die, + const die_object& die, + Dwarf_Addr pc, + Dwarf_Half dwversion, + std::vector& inlines + ) { + ASSERT(die.get_tag() == DW_TAG_subprogram); + const auto name = subprogram_symbol(die, dwversion); + if(detail::should_resolve_inlined_calls()) { + get_inlines_info(cu_die, die, pc, dwversion, inlines); + } + return name; + } + + // returns true if this call found the symbol + CPPTRACE_FORCE_NO_INLINE_FOR_PROFILING + bool retrieve_symbol_walk( + const die_object& cu_die, + const die_object& die, + Dwarf_Addr pc, + Dwarf_Half dwversion, + stacktrace_frame& frame, + std::vector& inlines + ) { + bool found = false; + walk_die_list( + die, + [this, &cu_die, pc, dwversion, &frame, &inlines, &found] (const die_object& die) { + if(dump_dwarf) { + std::fprintf( + stderr, + "-------------> %08llx %s %s\n", + to_ull(die.get_global_offset()), + die.get_tag_name(), + die.get_name().c_str() + ); + } + if(!(die.get_tag() == DW_TAG_namespace || die.pc_in_die(cu_die, dwversion, pc))) { + if(dump_dwarf) { + std::fprintf(stderr, "pc not in die\n"); + } + } else { + if(trace_dwarf) { + std::fprintf( + stderr, + "%s %08llx %s\n", + die.get_tag() == DW_TAG_namespace ? "pc maybe in die (namespace)" : "pc in die", + to_ull(die.get_global_offset()), + die.get_tag_name() + ); + } + if(die.get_tag() == DW_TAG_subprogram) { + frame.symbol = retrieve_symbol_for_subprogram(cu_die, die, pc, dwversion, inlines); + found = true; + return false; + } + auto child = die.get_child(); + if(child) { + if(retrieve_symbol_walk(cu_die, child, pc, dwversion, frame, inlines)) { + found = true; + return false; + } + } else { + if(dump_dwarf) { + std::fprintf(stderr, "(no child)\n"); + } + } + } + return true; + } + ); + if(dump_dwarf) { + std::fprintf(stderr, "End walk_die_list\n"); + } + return found; + } + + CPPTRACE_FORCE_NO_INLINE_FOR_PROFILING + void preprocess_subprograms( + const die_object& cu_die, + const die_object& die, + Dwarf_Half dwversion, + std::vector& vec + ) { + walk_die_list( + die, + [this, &cu_die, dwversion, &vec] (const die_object& die) { + switch(die.get_tag()) { + case DW_TAG_subprogram: + { + auto ranges_vec = die.get_rangelist_entries(cu_die, dwversion); + // TODO: Feels super inefficient and some day should maybe use an interval tree. + for(auto range : ranges_vec) { + // TODO: Reduce cloning here + vec.push_back({ die.clone(), range.first, range.second }); + } + // Walk children to get things like lambdas + // TODO: Somehow find a way to get better names here? For gcc it's just "operator()" + // On clang it's better + auto child = die.get_child(); + if(child) { + preprocess_subprograms(cu_die, child, dwversion, vec); + } + } + break; + case DW_TAG_namespace: + case DW_TAG_structure_type: + case DW_TAG_class_type: + case DW_TAG_module: + case DW_TAG_imported_module: + case DW_TAG_compile_unit: + { + auto child = die.get_child(); + if(child) { + preprocess_subprograms(cu_die, child, dwversion, vec); + } + } + break; + default: + break; + } + return true; + } + ); + if(dump_dwarf) { + std::fprintf(stderr, "End walk_die_list\n"); + } + } + + CPPTRACE_FORCE_NO_INLINE_FOR_PROFILING + void retrieve_symbol( + const die_object& cu_die, + Dwarf_Addr pc, + Dwarf_Half dwversion, + stacktrace_frame& frame, + std::vector& inlines + ) { + if(get_cache_mode() == cache_mode::prioritize_memory) { + retrieve_symbol_walk(cu_die, cu_die, pc, dwversion, frame, inlines); + } else { + auto off = cu_die.get_global_offset(); + auto it = subprograms_cache.find(off); + if(it == subprograms_cache.end()) { + // TODO: Refactor. Do the sort in the preprocess function and return the vec directly. + std::vector vec; + preprocess_subprograms(cu_die, cu_die, dwversion, vec); + std::sort(vec.begin(), vec.end(), [] (const subprogram_entry& a, const subprogram_entry& b) { + return a.low < b.low; + }); + subprograms_cache.emplace(off, std::move(vec)); + it = subprograms_cache.find(off); + } + auto& vec = it->second; + auto vec_it = first_less_than_or_equal( + vec.begin(), + vec.end(), + pc, + [] (Dwarf_Addr pc, const subprogram_entry& entry) { + return pc < entry.low; + } + ); + // If the vector has been empty this can happen + if(vec_it != vec.end()) { + if(vec_it->die.pc_in_die(cu_die, dwversion, pc)) { + frame.symbol = retrieve_symbol_for_subprogram(cu_die, vec_it->die, pc, dwversion, inlines); + } + } else { + ASSERT(vec.size() == 0, "Vec should be empty?"); + } + } + } + + // returns a reference to a CU's line table, may be invalidated if the line_tables map is modified + CPPTRACE_FORCE_NO_INLINE_FOR_PROFILING + optional> get_line_table(const die_object& cu_die) { + auto off = cu_die.get_global_offset(); + auto it = line_tables.find(off); + if(it != line_tables.end()) { + return it->second; + } else { + Dwarf_Unsigned version; + Dwarf_Small table_count; + Dwarf_Line_Context line_context; + int ret = wrap( + dwarf_srclines_b, + cu_die.get(), + &version, + &table_count, + &line_context + ); + static_assert(std::is_unsigned::value, "Expected unsigned Dwarf_Small"); + VERIFY(/*table_count >= 0 &&*/ table_count <= 2, "Unknown dwarf line table count"); + if(ret == DW_DLV_NO_ENTRY) { + // TODO: Failing silently for now + return nullopt; + } + VERIFY(ret == DW_DLV_OK); + + std::vector line_entries; + + if(get_cache_mode() == cache_mode::prioritize_speed) { + // build lookup table + Dwarf_Line* line_buffer = nullptr; + Dwarf_Signed line_count = 0; + Dwarf_Line* linebuf_actuals = nullptr; + Dwarf_Signed linecount_actuals = 0; + VERIFY( + wrap( + dwarf_srclines_two_level_from_linecontext, + line_context, + &line_buffer, + &line_count, + &linebuf_actuals, + &linecount_actuals + ) == DW_DLV_OK + ); + + // TODO: Make any attempt to note PC ranges? Handle line end sequence? + line_entries.reserve(line_count); + for(int i = 0; i < line_count; i++) { + Dwarf_Line line = line_buffer[i]; + Dwarf_Addr low_addr = 0; + VERIFY(wrap(dwarf_lineaddr, line, &low_addr) == DW_DLV_OK); + // scan ahead for the last line entry matching this pc + int j; + for(j = i + 1; j < line_count; j++) { + Dwarf_Addr addr = 0; + VERIFY(wrap(dwarf_lineaddr, line_buffer[j], &addr) == DW_DLV_OK); + if(addr != low_addr) { + break; + } + } + line = line_buffer[j - 1]; + // { + // Dwarf_Unsigned line_number = 0; + // VERIFY(wrap(dwarf_lineno, line, &line_number) == DW_DLV_OK); + // frame.line = static_cast(line_number); + // char* filename = nullptr; + // VERIFY(wrap(dwarf_linesrc, line, &filename) == DW_DLV_OK); + // auto wrapper = raii_wrap( + // filename, + // [this] (char* str) { if(str) dwarf_dealloc(dbg, str, DW_DLA_STRING); } + // ); + // frame.filename = filename; + // printf("%s : %d\n", filename, line_number); + // Dwarf_Bool is_line_end; + // VERIFY(wrap(dwarf_lineendsequence, line, &is_line_end) == DW_DLV_OK); + // if(is_line_end) { + // puts("Line end"); + // } + // } + line_entries.push_back({ + low_addr, + line + }); + i = j - 1; + } + // sort lines + std::sort(line_entries.begin(), line_entries.end(), [] (const line_entry& a, const line_entry& b) { + return a.low < b.low; + }); + } + + it = line_tables.insert({off, {version, line_context, std::move(line_entries)}}).first; + return it->second; + } + } + + CPPTRACE_FORCE_NO_INLINE_FOR_PROFILING + void retrieve_line_info( + const die_object& cu_die, + Dwarf_Addr pc, + stacktrace_frame& frame + ) { + // For debug fission the skeleton debug info will have the line table + if(skeleton) { + return skeleton.unwrap().resolver.retrieve_line_info(skeleton.unwrap().cu_die, pc, frame); + } + auto table_info_opt = get_line_table(cu_die); + if(!table_info_opt) { + return; // failing silently for now + } + auto& table_info = table_info_opt.unwrap().get(); + if(get_cache_mode() == cache_mode::prioritize_speed) { + // Lookup in the table + auto& line_entries = table_info.line_entries; + auto table_it = first_less_than_or_equal( + line_entries.begin(), + line_entries.end(), + pc, + [] (Dwarf_Addr pc, const line_entry& entry) { + return pc < entry.low; + } + ); + // If the vector has been empty this can happen + if(table_it != line_entries.end()) { + Dwarf_Line line = table_it->line; + // line number + if(!table_it->line_number) { + Dwarf_Unsigned line_number = 0; + VERIFY(wrap(dwarf_lineno, line, &line_number) == DW_DLV_OK); + table_it->line_number = static_cast(line_number); + } + frame.line = table_it->line_number.unwrap(); + // column number + if(!table_it->column_number) { + Dwarf_Unsigned column_number = 0; + VERIFY(wrap(dwarf_lineoff_b, line, &column_number) == DW_DLV_OK); + table_it->column_number = static_cast(column_number); + } + frame.column = table_it->column_number.unwrap(); + // filename + if(!table_it->path) { + char* filename = nullptr; + VERIFY(wrap(dwarf_linesrc, line, &filename) == DW_DLV_OK); + auto wrapper = raii_wrap( + filename, + [this] (char* str) { if(str) dwarf_dealloc(dbg, str, DW_DLA_STRING); } + ); + table_it->path = filename; + } + frame.filename = table_it->path.unwrap(); + } + } else { + Dwarf_Line_Context line_context = table_info.line_context; + // walk for it + Dwarf_Line* line_buffer = nullptr; + Dwarf_Signed line_count = 0; + Dwarf_Line* linebuf_actuals = nullptr; + Dwarf_Signed linecount_actuals = 0; + VERIFY( + wrap( + dwarf_srclines_two_level_from_linecontext, + line_context, + &line_buffer, + &line_count, + &linebuf_actuals, + &linecount_actuals + ) == DW_DLV_OK + ); + Dwarf_Addr last_lineaddr = 0; + Dwarf_Line last_line = nullptr; + for(int i = 0; i < line_count; i++) { + Dwarf_Line line = line_buffer[i]; + Dwarf_Addr lineaddr = 0; + VERIFY(wrap(dwarf_lineaddr, line, &lineaddr) == DW_DLV_OK); + Dwarf_Line found_line = nullptr; + if(pc == lineaddr) { + // Multiple PCs may correspond to a line, find the last one + found_line = line; + for(int j = i + 1; j < line_count; j++) { + Dwarf_Line line = line_buffer[j]; + Dwarf_Addr lineaddr = 0; + VERIFY(wrap(dwarf_lineaddr, line, &lineaddr) == DW_DLV_OK); + if(pc == lineaddr) { + found_line = line; + } + } + } else if(last_line && pc > last_lineaddr && pc < lineaddr) { + // Guess that the last line had it + found_line = last_line; + } + if(found_line) { + Dwarf_Unsigned line_number = 0; + VERIFY(wrap(dwarf_lineno, found_line, &line_number) == DW_DLV_OK); + frame.line = static_cast(line_number); + char* filename = nullptr; + VERIFY(wrap(dwarf_linesrc, found_line, &filename) == DW_DLV_OK); + auto wrapper = raii_wrap( + filename, + [this] (char* str) { if(str) dwarf_dealloc(dbg, str, DW_DLA_STRING); } + ); + frame.filename = filename; + } else { + Dwarf_Bool is_line_end; + VERIFY(wrap(dwarf_lineendsequence, line, &is_line_end) == DW_DLV_OK); + if(is_line_end) { + last_lineaddr = 0; + last_line = nullptr; + } else { + last_lineaddr = lineaddr; + last_line = line; + } + } + } + } + } + + struct cu_info { + maybe_owned_die_object cu_die; + Dwarf_Half dwversion; + }; + + // CU resolution has three paths: + // - If aranges are present, the pc is looked up in aranges (falls through to next cases if not in aranges) + // - If cache mode is prioritize memory, the CUs are walked for a match + // - Otherwise a CU cache is built up and CUs are looked up in the map + CPPTRACE_FORCE_NO_INLINE_FOR_PROFILING + optional lookup_cu(Dwarf_Addr pc) { + // Check for .debug_aranges for fast lookup + if(aranges && !skeleton) { // don't bother under split dwarf + // Try to find pc in aranges + Dwarf_Arange arange; + if(wrap(dwarf_get_arange, aranges, arange_count, pc, &arange) == DW_DLV_OK) { + // Address in table, load CU die + Dwarf_Off cu_die_offset; + VERIFY(wrap(dwarf_get_cu_die_offset, arange, &cu_die_offset) == DW_DLV_OK); + Dwarf_Die raw_die; + // Setting is_info = true for now, assuming in .debug_info rather than .debug_types + VERIFY(wrap(dwarf_offdie_b, dbg, cu_die_offset, true, &raw_die) == DW_DLV_OK); + die_object cu_die(dbg, raw_die); + Dwarf_Half offset_size = 0; + Dwarf_Half dwversion = 0; + VERIFY(dwarf_get_version_of_die(cu_die.get(), &dwversion, &offset_size) == DW_DLV_OK); + if(trace_dwarf) { + std::fprintf(stderr, "Found CU in aranges\n"); + cu_die.print(); + } + return cu_info{maybe_owned_die_object::owned(std::move(cu_die)), dwversion}; + } + } + // otherwise, or if not in aranges + // one reason to fallback here is if the compilation has dwarf generated from different compilers and only + // some of them generate aranges (e.g. static linking with cpptrace after specifying clang++ as the c++ + // compiler while the C compiler defaults to an older gcc) + if(get_cache_mode() == cache_mode::prioritize_memory) { + // walk for the cu and go from there + optional info; + walk_compilation_units([this, pc, &info] (const die_object& cu_die) { + Dwarf_Half offset_size = 0; + Dwarf_Half dwversion = 0; + dwarf_get_version_of_die(cu_die.get(), &dwversion, &offset_size); + //auto p = cu_die.get_pc_range(dwversion); + //cu_die.print(); + //fprintf(stderr, " %llx, %llx\n", p.first, p.second); + if(trace_dwarf) { + std::fprintf(stderr, "CU: %d %s\n", dwversion, cu_die.get_name().c_str()); + } + // NOTE: If we have a corresponding skeleton, we assume we have one CU matching the skeleton CU + if( + ( + skeleton + && skeleton.unwrap().cu_die.pc_in_die( + skeleton.unwrap().cu_die, + skeleton.unwrap().dwversion, + pc + ) + ) || cu_die.pc_in_die(cu_die, dwversion, pc) + ) { + if(trace_dwarf) { + std::fprintf( + stderr, + "pc in die %08llx %s (now searching for %08llx)\n", + to_ull(cu_die.get_global_offset()), + cu_die.get_tag_name(), + to_ull(pc) + ); + } + info = cu_info{maybe_owned_die_object::owned(cu_die.clone()), dwversion}; + return false; + } + return true; + }); + return info; + } else { + lazy_generate_cu_cache(); + // look up the cu + auto vec_it = first_less_than_or_equal( + cu_cache.begin(), + cu_cache.end(), + pc, + [] (Dwarf_Addr pc, const cu_entry& entry) { + return pc < entry.low; + } + ); + // TODO: Vec-it is already range-based, this range check is redundant + // If the vector has been empty this can happen + if(vec_it != cu_cache.end()) { + // TODO: Cache the range list? + // NOTE: If we have a corresponding skeleton, we assume we have one CU matching the skeleton CU + if( + ( + skeleton + && skeleton.unwrap().cu_die.pc_in_die( + skeleton.unwrap().cu_die, + skeleton.unwrap().dwversion, + pc + ) + ) || vec_it->die.pc_in_die(vec_it->die, vec_it->dwversion, pc) + ) { + return cu_info{maybe_owned_die_object::ref(vec_it->die), vec_it->dwversion}; + } + } else { + // I've had this happen for _start, where there is a cached CU for the object but _start is outside + // of the CU's PC range + // ASSERT(cu_cache.size() == 0, "Vec should be empty?"); + } + return nullopt; + } + } + + optional get_dwo_name(const die_object& cu_die) { + if(auto dwo_name = cu_die.get_string_attribute(DW_AT_GNU_dwo_name)) { + return dwo_name; + } else if(auto dwo_name = cu_die.get_string_attribute(DW_AT_dwo_name)) { + return dwo_name; + } else { + return nullopt; + } + } + + void perform_dwarf_fission_resolution( + const die_object& cu_die, + const optional& dwo_name, + const object_frame& object_frame_info, + stacktrace_frame& frame, + std::vector& inlines + ) { + // Split dwarf / debug fission / dwo is handled here + // Location of the split full CU is a combination of DW_AT_dwo_name/DW_AT_GNU_dwo_name and DW_AT_comp_dir + // https://gcc.gnu.org/wiki/DebugFission + if(dwo_name) { + // TODO: DWO ID? + auto comp_dir = cu_die.get_string_attribute(DW_AT_comp_dir); + Dwarf_Half offset_size = 0; + Dwarf_Half dwversion = 0; + dwarf_get_version_of_die(cu_die.get(), &dwversion, &offset_size); + std::string path; + if(is_absolute(dwo_name.unwrap())) { + path = dwo_name.unwrap(); + } else if(comp_dir) { + path = comp_dir.unwrap() + PATH_SEP + dwo_name.unwrap(); + } else { + // maybe default to dwo_name but for now not doing anything + return; + } + // todo: slight inefficiency in this copy-back strategy due to other frame members + frame_with_inlines res; + if(get_cache_mode() == cache_mode::prioritize_memory) { + dwarf_resolver resolver( + path, + skeleton_info{cu_die.clone(), dwversion, *this} + ); + res = resolver.resolve_frame(object_frame_info); + } else { + auto off = cu_die.get_global_offset(); + auto it = split_full_cu_resolvers.find(off); + if(it == split_full_cu_resolvers.end()) { + it = split_full_cu_resolvers.emplace( + off, + std::unique_ptr( + new dwarf_resolver( + path, + skeleton_info{cu_die.clone(), dwversion, *this} + ) + ) + ).first; + } + res = it->second->resolve_frame(object_frame_info); + } + frame = std::move(res.frame); + inlines = std::move(res.inlines); + } + } + + CPPTRACE_FORCE_NO_INLINE_FOR_PROFILING + void resolve_frame_core( + const object_frame& object_frame_info, + stacktrace_frame& frame, + std::vector& inlines + ) { + auto pc = object_frame_info.object_address; + if(dump_dwarf) { + std::fprintf(stderr, "%s\n", object_path.c_str()); + std::fprintf(stderr, "%llx\n", to_ull(pc)); + } + optional cu = lookup_cu(pc); + if(cu) { + const auto& cu_die = cu.unwrap().cu_die.get(); + // gnu non-standard debug-fission may create non-skeleton CU DIEs and just add dwo attributes + // clang emits dwo names in the split CUs, so guard against going down the dwarf fission path (which + // doesn't infinitely recurse because it's not emitted as an absolute path and there's no comp dir but + // it's good to guard against the infinite recursion anyway) + auto dwo_name = get_dwo_name(cu_die); + if(cu_die.get_tag() == DW_TAG_skeleton_unit || (dwo_name && !skeleton)) { + perform_dwarf_fission_resolution(cu_die, dwo_name, object_frame_info, frame, inlines); + } else { + retrieve_line_info(cu_die, pc, frame); + retrieve_symbol(cu_die, pc, cu.unwrap().dwversion, frame, inlines); + } + } + } + + public: + CPPTRACE_FORCE_NO_INLINE_FOR_PROFILING + frame_with_inlines resolve_frame(const object_frame& frame_info) override { + if(!ok) { + return { + { + frame_info.raw_address, + frame_info.object_address, + nullable::null(), + nullable::null(), + frame_info.object_path, + "", + false + }, + {} + }; + } + stacktrace_frame frame = null_frame; + frame.filename = frame_info.object_path; + frame.raw_address = frame_info.raw_address; + frame.object_address = frame_info.object_address; + if(trace_dwarf) { + std::fprintf( + stderr, + "Starting resolution for %s %08llx\n", + object_path.c_str(), + to_ull(frame_info.object_address) + ); + } + std::vector inlines; + resolve_frame_core( + frame_info, + frame, + inlines + ); + return {std::move(frame), std::move(inlines)}; + } + }; + + std::unique_ptr make_dwarf_resolver(const std::string& object_path) { + return std::unique_ptr(new dwarf_resolver(object_path)); + } +} +} +} + +#endif diff --git a/dep/cpptrace/src/symbols/dwarf/resolver.hpp b/dep/cpptrace/src/symbols/dwarf/resolver.hpp new file mode 100644 index 00000000000..20734d3ec36 --- /dev/null +++ b/dep/cpptrace/src/symbols/dwarf/resolver.hpp @@ -0,0 +1,56 @@ +#ifndef SYMBOL_RESOLVER_HPP +#define SYMBOL_RESOLVER_HPP + +#include +#include "symbols/symbols.hpp" +#include "utils/common.hpp" + +#include + +#if false + #define CPPTRACE_FORCE_NO_INLINE_FOR_PROFILING CPPTRACE_FORCE_NO_INLINE +#else + #define CPPTRACE_FORCE_NO_INLINE_FOR_PROFILING +#endif + +namespace cpptrace { +namespace detail { +namespace libdwarf { + class symbol_resolver { + public: + virtual ~symbol_resolver() = default; + CPPTRACE_FORCE_NO_INLINE_FOR_PROFILING + virtual frame_with_inlines resolve_frame(const object_frame& frame_info) = 0; + }; + + class null_resolver : public symbol_resolver { + public: + null_resolver() = default; + null_resolver(const std::string&) {} + + CPPTRACE_FORCE_NO_INLINE_FOR_PROFILING + frame_with_inlines resolve_frame(const object_frame& frame_info) override { + return { + { + frame_info.raw_address, + frame_info.object_address, + nullable::null(), + nullable::null(), + frame_info.object_path, + "", + false + }, + {} + }; + }; + }; + + std::unique_ptr make_dwarf_resolver(const std::string& object_path); + #if IS_APPLE + std::unique_ptr make_debug_map_resolver(const std::string& object_path); + #endif +} +} +} + +#endif diff --git a/dep/cpptrace/src/symbols/symbols.hpp b/dep/cpptrace/src/symbols/symbols.hpp new file mode 100644 index 00000000000..dfe4a7b7269 --- /dev/null +++ b/dep/cpptrace/src/symbols/symbols.hpp @@ -0,0 +1,72 @@ +#ifndef SYMBOLS_HPP +#define SYMBOLS_HPP + +#include + +#include +#include +#include + +#include "binary/object.hpp" + +namespace cpptrace { +namespace detail { + using collated_vec = std::vector< + std::pair, std::reference_wrapper> + >; + struct frame_with_inlines { + stacktrace_frame frame; + std::vector inlines; + }; + using collated_vec_with_inlines = std::vector< + std::pair, std::reference_wrapper> + >; + + // These two helpers create a map from a target object to a vector of frames to resolve + std::unordered_map collate_frames( + const std::vector& frames, + std::vector& trace + ); + std::unordered_map collate_frames( + const std::vector& frames, + std::vector& trace + ); + + #ifdef CPPTRACE_GET_SYMBOLS_WITH_LIBBACKTRACE + namespace libbacktrace { + std::vector resolve_frames(const std::vector& frames); + } + #endif + #ifdef CPPTRACE_GET_SYMBOLS_WITH_LIBDWARF + namespace libdwarf { + std::vector resolve_frames(const std::vector& frames); + } + #endif + #ifdef CPPTRACE_GET_SYMBOLS_WITH_LIBDL + namespace libdl { + std::vector resolve_frames(const std::vector& frames); + } + #endif + #ifdef CPPTRACE_GET_SYMBOLS_WITH_ADDR2LINE + namespace addr2line { + std::vector resolve_frames(const std::vector& frames); + } + #endif + #ifdef CPPTRACE_GET_SYMBOLS_WITH_DBGHELP + namespace dbghelp { + std::vector resolve_frames(const std::vector& frames); + } + #endif + #ifdef CPPTRACE_GET_SYMBOLS_WITH_NOTHING + namespace nothing { + std::vector resolve_frames(const std::vector& frames); + std::vector resolve_frames(const std::vector& frames); + } + #endif + + std::vector resolve_frames(const std::vector& frames); + std::vector resolve_frames(const std::vector& frames); +} +} + +#endif diff --git a/dep/cpptrace/src/symbols/symbols_core.cpp b/dep/cpptrace/src/symbols/symbols_core.cpp new file mode 100644 index 00000000000..a77896e727e --- /dev/null +++ b/dep/cpptrace/src/symbols/symbols_core.cpp @@ -0,0 +1,151 @@ +#include "symbols/symbols.hpp" + +#include +#include + +#include "utils/common.hpp" +#include "binary/object.hpp" + +namespace cpptrace { +namespace detail { + template + std::unordered_map collate_frames( + const std::vector& frames, + std::vector& trace + ) { + std::unordered_map entries; + for(std::size_t i = 0; i < frames.size(); i++) { + const auto& entry = frames[i]; + // If libdl fails to find the shared object for a frame, the path will be empty. I've observed this + // on macos when looking up the shared object containing `start`. + if(!entry.object_path.empty()) { + entries[entry.object_path].emplace_back( + entry, + trace[i] + ); + } + } + return entries; + } + + std::unordered_map collate_frames( + const std::vector& frames, + std::vector& trace + ) { + return collate_frames(frames, trace); + } + std::unordered_map collate_frames( + const std::vector& frames, + std::vector& trace + ) { + return collate_frames(frames, trace); + } + + /* + * + * + * All the code here is awful and I'm not proud of it. + * + * + * + */ + + // Resolver must not support walking inlines + void fill_blanks( + std::vector& vec, + std::vector (*resolver)(const std::vector&) + ) { + std::vector addresses; + for(const auto& frame : vec) { + if(frame.symbol.empty() || frame.filename.empty()) { + addresses.push_back(frame.raw_address); + } + } + std::vector new_frames = resolver(addresses); + std::size_t i = 0; + for(auto& frame : vec) { + if(frame.symbol.empty() || frame.filename.empty()) { + // three cases to handle, either partially overwrite or fully overwrite + if(frame.symbol.empty() && frame.filename.empty()) { + frame = new_frames[i]; + } else if(frame.symbol.empty() && !frame.filename.empty()) { + frame.symbol = new_frames[i].symbol; + } else { + ASSERT(!frame.symbol.empty() && frame.filename.empty()); + frame.filename = new_frames[i].filename; + frame.line = new_frames[i].line; + frame.column = new_frames[i].column; + } + i++; + } + } + } + + std::vector resolve_frames(const std::vector& frames) { + #if defined(CPPTRACE_GET_SYMBOLS_WITH_LIBDWARF) && defined(CPPTRACE_GET_SYMBOLS_WITH_DBGHELP) + std::vector trace = libdwarf::resolve_frames(frames); + fill_blanks(trace, dbghelp::resolve_frames); + return trace; + #else + #if defined(CPPTRACE_GET_SYMBOLS_WITH_LIBDL) \ + || defined(CPPTRACE_GET_SYMBOLS_WITH_DBGHELP) \ + || defined(CPPTRACE_GET_SYMBOLS_WITH_LIBBACKTRACE) + // actually need to go backwards to a void* + std::vector raw_frames(frames.size()); + for(std::size_t i = 0; i < frames.size(); i++) { + raw_frames[i] = frames[i].raw_address; + } + #endif + #ifdef CPPTRACE_GET_SYMBOLS_WITH_LIBDL + return libdl::resolve_frames(raw_frames); + #endif + #ifdef CPPTRACE_GET_SYMBOLS_WITH_LIBDWARF + return libdwarf::resolve_frames(frames); + #endif + #ifdef CPPTRACE_GET_SYMBOLS_WITH_DBGHELP + return dbghelp::resolve_frames(raw_frames); + #endif + #ifdef CPPTRACE_GET_SYMBOLS_WITH_ADDR2LINE + return addr2line::resolve_frames(frames); + #endif + #ifdef CPPTRACE_GET_SYMBOLS_WITH_LIBBACKTRACE + return libbacktrace::resolve_frames(raw_frames); + #endif + #ifdef CPPTRACE_GET_SYMBOLS_WITH_NOTHING + return nothing::resolve_frames(frames); + #endif + #endif + } + + std::vector resolve_frames(const std::vector& frames) { + #if defined(CPPTRACE_GET_SYMBOLS_WITH_LIBDWARF) \ + || defined(CPPTRACE_GET_SYMBOLS_WITH_ADDR2LINE) + auto dlframes = get_frames_object_info(frames); + #endif + #if defined(CPPTRACE_GET_SYMBOLS_WITH_LIBDWARF) && defined(CPPTRACE_GET_SYMBOLS_WITH_DBGHELP) + std::vector trace = libdwarf::resolve_frames(dlframes); + fill_blanks(trace, dbghelp::resolve_frames); + return trace; + #else + #ifdef CPPTRACE_GET_SYMBOLS_WITH_LIBDL + return libdl::resolve_frames(frames); + #endif + #ifdef CPPTRACE_GET_SYMBOLS_WITH_LIBDWARF + return libdwarf::resolve_frames(dlframes); + #endif + #ifdef CPPTRACE_GET_SYMBOLS_WITH_DBGHELP + return dbghelp::resolve_frames(frames); + #endif + #ifdef CPPTRACE_GET_SYMBOLS_WITH_ADDR2LINE + return addr2line::resolve_frames(dlframes); + #endif + #ifdef CPPTRACE_GET_SYMBOLS_WITH_LIBBACKTRACE + return libbacktrace::resolve_frames(frames); + #endif + #ifdef CPPTRACE_GET_SYMBOLS_WITH_NOTHING + return nothing::resolve_frames(frames); + #endif + #endif + } +} +} diff --git a/dep/cpptrace/src/symbols/symbols_with_addr2line.cpp b/dep/cpptrace/src/symbols/symbols_with_addr2line.cpp new file mode 100644 index 00000000000..8a4723860c4 --- /dev/null +++ b/dep/cpptrace/src/symbols/symbols_with_addr2line.cpp @@ -0,0 +1,322 @@ +#ifdef CPPTRACE_GET_SYMBOLS_WITH_ADDR2LINE + +#include +#include "symbols/symbols.hpp" +#include "utils/common.hpp" +#include "utils/utils.hpp" + +#include +#include +#include +#include +#include +#include +#include +#include + +#if IS_LINUX || IS_APPLE + #include + #include + #include +#endif + +#include "binary/object.hpp" + +namespace cpptrace { +namespace detail { +namespace addr2line { + #if IS_LINUX || IS_APPLE + bool has_addr2line() { + static std::mutex mutex; + static bool has_addr2line = false; + static bool checked = false; + std::lock_guard lock(mutex); + if(!checked) { + checked = true; + // Detects if addr2line exists by trying to invoke addr2line --help + constexpr int magic = 42; + const pid_t pid = fork(); + if(pid == -1) { return false; } + if(pid == 0) { // child + close(STDOUT_FILENO); + close(STDERR_FILENO); // atos --help writes to stderr + #ifdef CPPTRACE_ADDR2LINE_SEARCH_SYSTEM_PATH + #if !IS_APPLE + execlp("addr2line", "addr2line", "--help", nullptr); + #else + execlp("atos", "atos", "--help", nullptr); + #endif + #else + #ifndef CPPTRACE_ADDR2LINE_PATH + #error "CPPTRACE_ADDR2LINE_PATH must be defined if CPPTRACE_ADDR2LINE_SEARCH_SYSTEM_PATH is not" + #endif + execl(CPPTRACE_ADDR2LINE_PATH, CPPTRACE_ADDR2LINE_PATH, "--help", nullptr); + #endif + _exit(magic); + } + int status; + waitpid(pid, &status, 0); + has_addr2line = WEXITSTATUS(status) == 0; + } + return has_addr2line; + } + + struct pipe_ends { + int read; + int write; + }; + + struct pipe_t { + union { + pipe_ends end; + int data[2]; + }; + }; + static_assert(sizeof(pipe_t) == 2 * sizeof(int), "Unexpected struct packing"); + + std::string resolve_addresses(const std::string& addresses, const std::string& executable) { + pipe_t output_pipe; + pipe_t input_pipe; + VERIFY(pipe(output_pipe.data) == 0); + VERIFY(pipe(input_pipe.data) == 0); + const pid_t pid = fork(); + if(pid == -1) { return ""; } // error? TODO: Diagnostic + if(pid == 0) { // child + dup2(output_pipe.end.write, STDOUT_FILENO); + dup2(input_pipe.end.read, STDIN_FILENO); + close(output_pipe.end.read); + close(output_pipe.end.write); + close(input_pipe.end.read); + close(input_pipe.end.write); + close(STDERR_FILENO); // TODO: Might be worth conditionally enabling or piping + #ifdef CPPTRACE_ADDR2LINE_SEARCH_SYSTEM_PATH + #if !IS_APPLE + execlp("addr2line", "addr2line", "-e", executable.c_str(), "-f", "-C", "-p", nullptr); + #else + execlp("atos", "atos", "-o", executable.c_str(), "-fullPath", nullptr); + #endif + #else + #ifndef CPPTRACE_ADDR2LINE_PATH + #error "CPPTRACE_ADDR2LINE_PATH must be defined if CPPTRACE_ADDR2LINE_SEARCH_SYSTEM_PATH is not" + #endif + #if !IS_APPLE + execl( + CPPTRACE_ADDR2LINE_PATH, + CPPTRACE_ADDR2LINE_PATH, + "-e", + executable.c_str(), + "-f", + "-C", + "-p", + nullptr + ); + #else + execl( + CPPTRACE_ADDR2LINE_PATH, + CPPTRACE_ADDR2LINE_PATH, + "-o", executable.c_str(), + "-fullPath", + nullptr + ); + #endif + #endif + _exit(1); // TODO: Diagnostic? + } + VERIFY(write(input_pipe.end.write, addresses.data(), addresses.size()) != -1); + close(input_pipe.end.read); + close(input_pipe.end.write); + close(output_pipe.end.write); + std::string output; + constexpr int buffer_size = 4096; + char buffer[buffer_size]; + std::size_t count = 0; + while((count = read(output_pipe.end.read, buffer, buffer_size)) > 0) { + output.insert(output.end(), buffer, buffer + count); + } + // TODO: check status from addr2line? + waitpid(pid, nullptr, 0); + return output; + } + #elif IS_WINDOWS + bool has_addr2line() { + static std::mutex mutex; + static bool has_addr2line = false; + static bool checked = false; + std::lock_guard lock(mutex); + if(!checked) { + // TODO: Popen is a hack. Implement properly with CreateProcess and pipes later. + checked = true; + #ifdef CPPTRACE_ADDR2LINE_SEARCH_SYSTEM_PATH + std::FILE* p = popen("addr2line --version", "r"); + #else + #ifndef CPPTRACE_ADDR2LINE_PATH + #error "CPPTRACE_ADDR2LINE_PATH must be defined if CPPTRACE_ADDR2LINE_SEARCH_SYSTEM_PATH is not" + #endif + std::FILE* p = popen(CPPTRACE_ADDR2LINE_PATH " --version", "r"); + #endif + if(p) { + has_addr2line = pclose(p) == 0; + } + } + return has_addr2line; + } + + std::string resolve_addresses(const std::string& addresses, const std::string& executable) { + // TODO: Popen is a hack. Implement properly with CreateProcess and pipes later. + ///fprintf(stderr, ("addr2line -e " + executable + " -fCp " + addresses + "\n").c_str()); + #ifdef CPPTRACE_ADDR2LINE_SEARCH_SYSTEM_PATH + std::FILE* p = popen(("addr2line -e \"" + executable + "\" -fCp " + addresses).c_str(), "r"); + #else + #ifndef CPPTRACE_ADDR2LINE_PATH + #error "CPPTRACE_ADDR2LINE_PATH must be defined if CPPTRACE_ADDR2LINE_SEARCH_SYSTEM_PATH is not" + #endif + std::FILE* p = popen( + (CPPTRACE_ADDR2LINE_PATH " -e \"" + executable + "\" -fCp " + addresses).c_str(), + "r" + ); + #endif + std::string output; + constexpr int buffer_size = 4096; + char buffer[buffer_size]; + std::size_t count = 0; + while((count = std::fread(buffer, 1, buffer_size, p)) > 0) { + output.insert(output.end(), buffer, buffer + count); + } + pclose(p); + ///fprintf(stderr, "%s\n", output.c_str()); + return output; + } + #endif + + void update_trace(const std::string& line, std::size_t entry_index, const collated_vec& entries_vec) { + #if !IS_APPLE + // Result will be of the form " at path:line" + // The path may be ?? if addr2line cannot resolve, line may be ? + // Edge cases: + // ?? ??:0 + // symbol :? + const std::size_t at_location = line.find(" at "); + std::size_t symbol_end; + std::size_t filename_start; + if(at_location != std::string::npos) { + symbol_end = at_location; + filename_start = at_location + 4; + } else { + VERIFY(line.find("?? ") == 0, "Unexpected edge case while processing addr2line output"); + symbol_end = 2; + filename_start = 3; + } + auto symbol = line.substr(0, symbol_end); + auto colon = line.rfind(':'); + VERIFY(colon != std::string::npos); + VERIFY(colon >= filename_start); // :? to deal with "symbol :?" edge case + auto filename = line.substr(filename_start, colon - filename_start); + auto line_number = line.substr(colon + 1); + if(line_number != "?") { + entries_vec[entry_index].second.get().line = std::stoi(line_number); + } + if(!filename.empty() && filename != "??") { + entries_vec[entry_index].second.get().filename = filename; + } + if(!symbol.empty()) { + entries_vec[entry_index].second.get().symbol = symbol; + } + #else + // Result will be of the form " (in ) (file:line)" + // The symbol may just be the given address if atos can't resolve it + // Examples: + // trace() (in demo) (demo.cpp:8) + // 0x100003b70 (in demo) + // 0xffffffffffffffff + // foo (in bar) + 14 + // I'm making some assumptions here. Support may need to be improved later. This is tricky output to + // parse. + const std::size_t in_location = line.find(" (in "); + if(in_location == std::string::npos) { + // presumably the 0xffffffffffffffff case + return; + } + const std::size_t symbol_end = in_location; + entries_vec[entry_index].second.get().symbol = line.substr(0, symbol_end); + const std::size_t object_end = line.find(")", in_location); + VERIFY( + object_end != std::string::npos, + "Unexpected edge case while processing addr2line/atos output" + ); + const std::size_t filename_start = line.find(") (", object_end); + if(filename_start == std::string::npos) { + // presumably something like 0x100003b70 (in demo) or foo (in bar) + 14 + return; + } + const std::size_t filename_end = line.find(":", filename_start); + VERIFY( + filename_end != std::string::npos, + "Unexpected edge case while processing addr2line/atos output" + ); + entries_vec[entry_index].second.get().filename = line.substr( + filename_start + 3, + filename_end - filename_start - 3 + ); + const std::size_t line_start = filename_end + 1; + const std::size_t line_end = line.find(")", filename_end); + VERIFY( + line_end == line.size() - 1, + "Unexpected edge case while processing addr2line/atos output" + ); + entries_vec[entry_index].second.get().line = std::stoi(line.substr(line_start, line_end - line_start)); + #endif + } + + std::vector resolve_frames(const std::vector& frames) { + // TODO: Refactor better + std::vector trace(frames.size(), null_frame); + for(std::size_t i = 0; i < frames.size(); i++) { + trace[i].raw_address = frames[i].raw_address; + trace[i].object_address = frames[i].object_address; + // Set what is known for now, and resolutions from addr2line should overwrite + trace[i].filename = frames[i].object_path; + } + if(has_addr2line()) { + const auto entries = collate_frames(frames, trace); + for(const auto& entry : entries) { + try { + const auto& object_name = entry.first; + const auto& entries_vec = entry.second; + // You may ask why it'd ever happen that there could be an empty entries_vec array, if there're + // no addresses why would get_addr2line_targets do anything? The reason is because if things in + // get_addr2line_targets fail it will silently skip. This is partly an optimization but also an + // assertion below will fail if addr2line is given an empty input. + if(entries_vec.empty()) { + continue; + } + std::string address_input; + for(const auto& pair : entries_vec) { + address_input += microfmt::format( + "{:h}{}", + pair.first.get().object_address, + #if !IS_WINDOWS + '\n' + #else + ' ' + #endif + ); + } + auto output = split(trim(resolve_addresses(address_input, object_name)), "\n"); + VERIFY(output.size() == entries_vec.size()); + for(std::size_t i = 0; i < output.size(); i++) { + update_trace(output[i], i, entries_vec); + } + } catch(...) { // NOSONAR + if(!should_absorb_trace_exceptions()) { + throw; + } + } + } + } + return trace; + } +} +} +} + +#endif diff --git a/dep/cpptrace/src/symbols/symbols_with_dbghelp.cpp b/dep/cpptrace/src/symbols/symbols_with_dbghelp.cpp new file mode 100644 index 00000000000..981cb109447 --- /dev/null +++ b/dep/cpptrace/src/symbols/symbols_with_dbghelp.cpp @@ -0,0 +1,456 @@ +#ifdef CPPTRACE_GET_SYMBOLS_WITH_DBGHELP + +#include +#include "symbols/symbols.hpp" +#include "platform/dbghelp_syminit_manager.hpp" + +#include +#include +#include +#include +#include +#include + +#include +#include + +namespace cpptrace { +namespace detail { +namespace dbghelp { + // SymFromAddr only returns the function's name. In order to get information about parameters, + // important for C++ stack traces where functions may be overloaded, we have to manually use + // Windows DIA to walk debug info structures. Resources: + // https://web.archive.org/web/20201027025750/http://www.debuginfo.com/articles/dbghelptypeinfo.html + // https://web.archive.org/web/20201203160805/http://www.debuginfo.com/articles/dbghelptypeinfofigures.html + // https://github.com/DynamoRIO/dynamorio/blob/master/ext/drsyms/drsyms_windows.c#L1370-L1439 + // TODO: Currently unable to detect rvalue references + // TODO: Currently unable to detect const + enum class SymTagEnum { + SymTagNull, SymTagExe, SymTagCompiland, SymTagCompilandDetails, SymTagCompilandEnv, + SymTagFunction, SymTagBlock, SymTagData, SymTagAnnotation, SymTagLabel, SymTagPublicSymbol, + SymTagUDT, SymTagEnum, SymTagFunctionType, SymTagPointerType, SymTagArrayType, + SymTagBaseType, SymTagTypedef, SymTagBaseClass, SymTagFriend, SymTagFunctionArgType, + SymTagFuncDebugStart, SymTagFuncDebugEnd, SymTagUsingNamespace, SymTagVTableShape, + SymTagVTable, SymTagCustom, SymTagThunk, SymTagCustomType, SymTagManagedType, + SymTagDimension, SymTagCallSite, SymTagInlineSite, SymTagBaseInterface, SymTagVectorType, + SymTagMatrixType, SymTagHLSLType, SymTagCaller, SymTagCallee, SymTagExport, + SymTagHeapAllocationSite, SymTagCoffGroup, SymTagMax + }; + + enum class IMAGEHLP_SYMBOL_TYPE_INFO { + TI_GET_SYMTAG, TI_GET_SYMNAME, TI_GET_LENGTH, TI_GET_TYPE, TI_GET_TYPEID, TI_GET_BASETYPE, + TI_GET_ARRAYINDEXTYPEID, TI_FINDCHILDREN, TI_GET_DATAKIND, TI_GET_ADDRESSOFFSET, + TI_GET_OFFSET, TI_GET_VALUE, TI_GET_COUNT, TI_GET_CHILDRENCOUNT, TI_GET_BITPOSITION, + TI_GET_VIRTUALBASECLASS, TI_GET_VIRTUALTABLESHAPEID, TI_GET_VIRTUALBASEPOINTEROFFSET, + TI_GET_CLASSPARENTID, TI_GET_NESTED, TI_GET_SYMINDEX, TI_GET_LEXICALPARENT, TI_GET_ADDRESS, + TI_GET_THISADJUST, TI_GET_UDTKIND, TI_IS_EQUIV_TO, TI_GET_CALLING_CONVENTION, + TI_IS_CLOSE_EQUIV_TO, TI_GTIEX_REQS_VALID, TI_GET_VIRTUALBASEOFFSET, + TI_GET_VIRTUALBASEDISPINDEX, TI_GET_IS_REFERENCE, TI_GET_INDIRECTVIRTUALBASECLASS, + TI_GET_VIRTUALBASETABLETYPE, TI_GET_OBJECTPOINTERTYPE, IMAGEHLP_SYMBOL_TYPE_INFO_MAX + }; + + enum class BasicType { + btNoType = 0, btVoid = 1, btChar = 2, btWChar = 3, btInt = 6, btUInt = 7, btFloat = 8, + btBCD = 9, btBool = 10, btLong = 13, btULong = 14, btCurrency = 25, btDate = 26, + btVariant = 27, btComplex = 28, btBit = 29, btBSTR = 30, btHresult = 31 + }; + + // SymGetTypeInfo utility + template + T get_info(ULONG type_index, HANDLE proc, ULONG64 modbase) { + T info; + if( + !SymGetTypeInfo( + proc, + modbase, + type_index, + static_cast<::IMAGEHLP_SYMBOL_TYPE_INFO>(SymType), + &info + ) + ) { + if(FAILABLE) { + return (T)-1; + } else { + throw internal_error( + "SymGetTypeInfo failed: {}", std::system_error(GetLastError(), std::system_category()).what() + ); + } + } + return info; + } + + template + std::string get_info_wchar(ULONG type_index, HANDLE proc, ULONG64 modbase) { + WCHAR* info; + if( + !SymGetTypeInfo(proc, modbase, type_index, static_cast<::IMAGEHLP_SYMBOL_TYPE_INFO>(SymType), &info) + ) { + throw internal_error( + "SymGetTypeInfo failed: {}", std::system_error(GetLastError(), std::system_category()).what() + ); + } + // special case to properly free a buffer and convert string to narrow chars, only used for + // TI_GET_SYMNAME + static_assert( + SymType == IMAGEHLP_SYMBOL_TYPE_INFO::TI_GET_SYMNAME, + "get_info_wchar called with unexpected IMAGEHLP_SYMBOL_TYPE_INFO" + ); + std::wstring wstr(info); + std::string str; + str.reserve(wstr.size()); + for(const auto c : wstr) { + str.push_back(static_cast(c)); + } + LocalFree(info); + return str; + } + + // Translate basic types to string + static std::string get_basic_type(ULONG type_index, HANDLE proc, ULONG64 modbase) { + auto basic_type = get_info( + type_index, + proc, + modbase + ); + //auto length = get_info(type_index, proc, modbase); + switch(basic_type) { + case BasicType::btNoType: + return ""; + case BasicType::btVoid: + return "void"; + case BasicType::btChar: + return "char"; + case BasicType::btWChar: + return "wchar_t"; + case BasicType::btInt: + return "int"; + case BasicType::btUInt: + return "unsigned int"; + case BasicType::btFloat: + return "float"; + case BasicType::btBool: + return "bool"; + case BasicType::btLong: + return "long"; + case BasicType::btULong: + return "unsigned long"; + default: + return ""; + } + } + + static std::string resolve_type(ULONG type_index, HANDLE proc, ULONG64 modbase); + + struct class_name_result { + bool has_class_name; + std::string name; + }; + // Helper for member pointers + static class_name_result lookup_class_name(ULONG type_index, HANDLE proc, ULONG64 modbase) { + DWORD class_parent_id = get_info( + type_index, + proc, + modbase + ); + if(class_parent_id == (DWORD)-1) { + return {false, ""}; + } else { + return {true, resolve_type(class_parent_id, proc, modbase)}; + } + } + + struct type_result { + std::string base; + std::string extent; + }; + // Resolve more complex types + // returns [base, extent] + static type_result lookup_type(ULONG type_index, HANDLE proc, ULONG64 modbase) { + auto tag = get_info(type_index, proc, modbase); + switch(tag) { + case SymTagEnum::SymTagBaseType: + return {get_basic_type(type_index, proc, modbase), ""}; + case SymTagEnum::SymTagPointerType: { + DWORD underlying_type_id = get_info( + type_index, + proc, + modbase + ); + bool is_ref = get_info( + type_index, + proc, + modbase + ); + std::string pp = is_ref ? "&" : "*"; // pointer punctuator + auto class_name_res = lookup_class_name(type_index, proc, modbase); + if(class_name_res.has_class_name) { + pp = class_name_res.name + "::" + pp; + } + const auto type = lookup_type(underlying_type_id, proc, modbase); + if(type.extent.empty()) { + return {type.base + (pp.size() > 1 ? " " : "") + pp, ""}; + } else { + return {type.base + "(" + pp, ")" + type.extent}; + } + } + case SymTagEnum::SymTagArrayType: { + DWORD underlying_type_id = get_info( + type_index, + proc, + modbase + ); + DWORD length = get_info( + type_index, + proc, + modbase + ); + const auto type = lookup_type(underlying_type_id, proc, modbase); + return {type.base, "[" + std::to_string(length) + "]" + type.extent}; + } + case SymTagEnum::SymTagFunctionType: { + DWORD return_type_id = get_info( + type_index, + proc, + modbase + ); + DWORD n_children = get_info( + type_index, + proc, + modbase + ); + DWORD class_parent_id = get_info( + type_index, + proc, + modbase + ); + int n_ignore = class_parent_id != (DWORD)-1; // ignore this param + // this must be ignored before TI_FINDCHILDREN_PARAMS::Count is set, else error + n_children -= n_ignore; + // return type + const auto return_type = lookup_type(return_type_id, proc, modbase); + if(n_children == 0) { + return {return_type.base, "()" + return_type.extent}; + } else { + // alignment should be fine + std::size_t sz = sizeof(TI_FINDCHILDREN_PARAMS) + + (n_children) * sizeof(TI_FINDCHILDREN_PARAMS::ChildId[0]); + TI_FINDCHILDREN_PARAMS* children = (TI_FINDCHILDREN_PARAMS*) new char[sz]; + children->Start = 0; + children->Count = n_children; + if( + !SymGetTypeInfo( + proc, modbase, type_index, + static_cast<::IMAGEHLP_SYMBOL_TYPE_INFO>( + IMAGEHLP_SYMBOL_TYPE_INFO::TI_FINDCHILDREN + ), + children + ) + ) { + throw internal_error( + "SymGetTypeInfo failed: {}", + std::system_error(GetLastError(), std::system_category()).what() + ); + } + // get children type + std::string extent = "("; + if(children->Start != 0) { + throw internal_error("Error: children->Start == 0"); + } + for(std::size_t i = 0; i < n_children; i++) { + extent += (i == 0 ? "" : ", ") + resolve_type(children->ChildId[i], proc, modbase); + } + extent += ")"; + delete[] (char*) children; + return {return_type.base, extent + return_type.extent}; + } + } + case SymTagEnum::SymTagFunctionArgType: { + DWORD underlying_type_id = + get_info(type_index, proc, modbase); + return {resolve_type(underlying_type_id, proc, modbase), ""}; + } + case SymTagEnum::SymTagTypedef: + case SymTagEnum::SymTagEnum: + case SymTagEnum::SymTagUDT: + case SymTagEnum::SymTagBaseClass: + return { + get_info_wchar(type_index, proc, modbase), "" + }; + default: + return { + "::type>(tag)) + + ">", + "" + }; + }; + } + + static std::string resolve_type(ULONG type_index, HANDLE proc, ULONG64 modbase) { + const auto type = lookup_type(type_index, proc, modbase); + return type.base + type.extent; + } + + struct function_info { + HANDLE proc; + ULONG64 modbase; + int counter; + int n_children; + int n_ignore; + std::string str; + }; + + // Enumerates function parameters + static BOOL __stdcall enumerator_callback( + PSYMBOL_INFO symbol_info, + ULONG, + PVOID data + ) { + function_info* ctx = (function_info*)data; + if(ctx->counter++ >= ctx->n_children) { + return false; + } + if(ctx->n_ignore-- > 0) { + return true; // just skip + } + ctx->str += resolve_type(symbol_info->TypeIndex, ctx->proc, ctx->modbase); + if(ctx->counter < ctx->n_children) { + ctx->str += ", "; + } + return true; + } + + std::recursive_mutex dbghelp_lock; + + // TODO: Handle backtrace_pcinfo calling the callback multiple times on inlined functions + stacktrace_frame resolve_frame(HANDLE proc, frame_ptr addr) { + // The get_frame_object_info() ends up being inexpensive, at on my machine + // debug release + // uncached trace resolution (29 frames) 1.9-2.1 ms 1.4-1.8 ms + // cached trace resolution (29 frames) 1.1-1.2 ms 0.2-0.4 ms + // get_frame_object_info() 0.001-0.002 ms 0.0003-0.0006 ms + // At some point it might make sense to make an option to control this. + auto object_frame = get_frame_object_info(addr); + const std::lock_guard lock(dbghelp_lock); // all dbghelp functions are not thread safe + alignas(SYMBOL_INFO) char buffer[sizeof(SYMBOL_INFO) + MAX_SYM_NAME * sizeof(TCHAR)]; + SYMBOL_INFO* symbol = (SYMBOL_INFO*)buffer; + symbol->SizeOfStruct = sizeof(SYMBOL_INFO); + symbol->MaxNameLen = MAX_SYM_NAME; + union { DWORD64 a; DWORD b; } displacement; + IMAGEHLP_LINE line; + bool got_line = SymGetLineFromAddr(proc, addr, &displacement.b, &line); + if(SymFromAddr(proc, addr, &displacement.a, symbol)) { + if(got_line) { + IMAGEHLP_STACK_FRAME frame; + frame.InstructionOffset = symbol->Address; + // https://docs.microsoft.com/en-us/windows/win32/api/dbghelp/nf-dbghelp-symsetcontext + // "If you call SymSetContext to set the context to its current value, the + // function fails but GetLastError returns ERROR_SUCCESS." + // This is the stupidest fucking api I've ever worked with. + if(SymSetContext(proc, &frame, nullptr) == FALSE && GetLastError() != ERROR_SUCCESS) { + std::fprintf(stderr, "Stack trace: Internal error while calling SymSetContext\n"); + return { + addr, + object_frame.object_address, + { static_cast(line.LineNumber) }, + nullable::null(), + line.FileName, + symbol->Name, + false + }; + } + DWORD n_children = get_info( + symbol->TypeIndex, + proc, + symbol->ModBase + ); + DWORD class_parent_id = get_info( + symbol->TypeIndex, + proc, + symbol->ModBase + ); + function_info fi { + proc, + symbol->ModBase, + 0, + int(n_children), + class_parent_id != (DWORD)-1, + "" + }; + SymEnumSymbols(proc, 0, nullptr, enumerator_callback, &fi); + std::string signature = symbol->Name + std::string("(") + fi.str + ")"; + // There's a phenomina with DIA not inserting commas after template parameters. Fix them here. + static std::regex comma_re(R"(,(?=\S))"); + signature = std::regex_replace(signature, comma_re, ", "); + return { + addr, + object_frame.object_address, + { static_cast(line.LineNumber) }, + nullable::null(), + line.FileName, + signature, + false, + }; + } else { + return { + addr, + object_frame.object_address, + nullable::null(), + nullable::null(), + "", + symbol->Name, + false + }; + } + } else { + return { + addr, + object_frame.object_address, + nullable::null(), + nullable::null(), + "", + "", + false + }; + } + } + + std::vector resolve_frames(const std::vector& frames) { + const std::lock_guard lock(dbghelp_lock); // all dbghelp functions are not thread safe + std::vector trace; + trace.reserve(frames.size()); + + // TODO: When does this need to be called? Can it be moved to the symbolizer? + SymSetOptions(SYMOPT_ALLOW_ABSOLUTE_SYMBOLS); + HANDLE proc = GetCurrentProcess(); + if(get_cache_mode() == cache_mode::prioritize_speed) { + get_syminit_manager().init(proc); + } else { + if(!SymInitialize(proc, NULL, TRUE)) { + throw internal_error("SymInitialize failed"); + } + } + for(const auto frame : frames) { + try { + trace.push_back(resolve_frame(proc, frame)); + } catch(...) { // NOSONAR + if(!detail::should_absorb_trace_exceptions()) { + throw; + } + auto entry = null_frame; + entry.raw_address = frame; + trace.push_back(entry); + } + } + if(get_cache_mode() != cache_mode::prioritize_speed) { + if(!SymCleanup(proc)) { + throw internal_error("SymCleanup failed"); + } + } + return trace; + } +} +} +} + +#endif diff --git a/dep/cpptrace/src/symbols/symbols_with_dl.cpp b/dep/cpptrace/src/symbols/symbols_with_dl.cpp new file mode 100644 index 00000000000..8b3d35a1ab3 --- /dev/null +++ b/dep/cpptrace/src/symbols/symbols_with_dl.cpp @@ -0,0 +1,55 @@ +#ifdef CPPTRACE_GET_SYMBOLS_WITH_LIBDL + +#include +#include "symbols/symbols.hpp" + +#include +#include +#include + +#include + +namespace cpptrace { +namespace detail { +namespace libdl { + stacktrace_frame resolve_frame(const frame_ptr addr) { + Dl_info info; + if(dladdr(reinterpret_cast(addr), &info)) { // thread-safe + auto base = get_module_image_base(info.dli_fname); + return { + addr, + base.has_value() + ? addr - reinterpret_cast(info.dli_fbase) + base.unwrap_value() + : 0, + nullable::null(), + nullable::null(), + info.dli_fname ? info.dli_fname : "", + info.dli_sname ? info.dli_sname : "", + false + }; + } else { + return { + addr, + 0, + nullable::null(), + nullable::null(), + "", + "", + false + }; + } + } + + std::vector resolve_frames(const std::vector& frames) { + std::vector trace; + trace.reserve(frames.size()); + for(const auto frame : frames) { + trace.push_back(resolve_frame(frame)); + } + return trace; + } +} +} +} + +#endif diff --git a/dep/cpptrace/src/symbols/symbols_with_libbacktrace.cpp b/dep/cpptrace/src/symbols/symbols_with_libbacktrace.cpp new file mode 100644 index 00000000000..360ff945c8e --- /dev/null +++ b/dep/cpptrace/src/symbols/symbols_with_libbacktrace.cpp @@ -0,0 +1,106 @@ +#ifdef CPPTRACE_GET_SYMBOLS_WITH_LIBBACKTRACE + +#include +#include "symbols/symbols.hpp" +#include "platform/program_name.hpp" + +#include +#include +#include +#include +#include +#include + +#ifdef CPPTRACE_BACKTRACE_PATH +#include CPPTRACE_BACKTRACE_PATH +#else +#include +#endif + +namespace cpptrace { +namespace detail { +namespace libbacktrace { + int full_callback(void* data, std::uintptr_t address, const char* file, int line, const char* symbol) { + stacktrace_frame& frame = *static_cast(data); + frame.raw_address = address; + frame.line = line; + frame.filename = file ? file : ""; + frame.symbol = symbol ? symbol : ""; + return 0; + } + + void syminfo_callback(void* data, std::uintptr_t address, const char* symbol, std::uintptr_t, std::uintptr_t) { + stacktrace_frame& frame = *static_cast(data); + frame.raw_address = address; + frame.line = 0; + frame.filename = ""; + frame.symbol = symbol ? symbol : ""; + } + + void error_callback(void*, const char* msg, int errnum) { + if(msg == std::string("no debug info in ELF executable")) { + // https://github.com/jeremy-rifkin/cpptrace/issues/114 + // https://github.com/ianlancetaylor/libbacktrace/blob/ae1e707dbacd4a5cc82fcf2d3816f410e9c5fec4/elf.c#L592 + // not a critical error, just return + return; + } + throw internal_error("Libbacktrace error: {}, code {}", msg, errnum); + } + + backtrace_state* get_backtrace_state() { + static std::mutex mutex; + const std::lock_guard lock(mutex); + // backtrace_create_state must be called only one time per program + static backtrace_state* state = nullptr; + static bool called = false; + if(!called) { + state = backtrace_create_state(program_name(), true, error_callback, nullptr); + called = true; + } + return state; + } + + // TODO: Handle backtrace_pcinfo calling the callback multiple times on inlined functions + stacktrace_frame resolve_frame(const frame_ptr addr) { + try { + stacktrace_frame frame = null_frame; + frame.raw_address = addr; + backtrace_pcinfo( + get_backtrace_state(), + addr, + full_callback, + error_callback, + &frame + ); + if(frame.symbol.empty()) { + // fallback, try to at least recover the symbol name with backtrace_syminfo + backtrace_syminfo( + get_backtrace_state(), + addr, + syminfo_callback, + error_callback, + &frame + ); + } + return frame; + } catch(...) { // NOSONAR + if(!should_absorb_trace_exceptions()) { + throw; + } + return null_frame; + } + } + + std::vector resolve_frames(const std::vector& frames) { + std::vector trace; + trace.reserve(frames.size()); + for(const auto frame : frames) { + trace.push_back(resolve_frame(frame)); + } + return trace; + } +} +} +} + +#endif diff --git a/dep/cpptrace/src/symbols/symbols_with_libdwarf.cpp b/dep/cpptrace/src/symbols/symbols_with_libdwarf.cpp new file mode 100644 index 00000000000..5456057ed8d --- /dev/null +++ b/dep/cpptrace/src/symbols/symbols_with_libdwarf.cpp @@ -0,0 +1,145 @@ +#ifdef CPPTRACE_GET_SYMBOLS_WITH_LIBDWARF + +#include "symbols/symbols.hpp" + +#include +#include "dwarf/resolver.hpp" +#include "utils/common.hpp" +#include "utils/error.hpp" +#include "utils/utils.hpp" + +#include +#include +#include +#include +#include +#include + +#include +#include + +namespace cpptrace { +namespace detail { +namespace libdwarf { + std::unique_ptr get_resolver_for_object(const std::string& object_path) { + #if IS_APPLE + // Check if dSYM exist, if not fallback to debug map + if(!directory_exists(object_path + ".dSYM")) { + return make_debug_map_resolver(object_path); + } + #endif + return make_dwarf_resolver(object_path); + } + + // not thread-safe, replies on caller to lock + maybe_owned get_resolver(const std::string& object_name) { + // cache resolvers since objects are likely to be traced more than once + static std::unordered_map> resolver_map; + auto it = resolver_map.find(object_name); + if(it != resolver_map.end()) { + return it->second.get(); + } else { + std::unique_ptr resolver_object = get_resolver_for_object(object_name); + if(get_cache_mode() == cache_mode::prioritize_speed) { + // .emplace needed, for some reason .insert tries to copy <= gcc 7.2 + return resolver_map.emplace(object_name, std::move(resolver_object)).first->second.get(); + } else { + // gcc 4 has trouble with automatic moves of locals here https://godbolt.org/z/9oWdWjbf8 + return maybe_owned{std::move(resolver_object)}; + } + } + } + + // flatten trace with inlines + std::vector flatten_inlines(std::vector& trace) { + std::vector final_trace; + for(auto& entry : trace) { + // most recent call first + if(!entry.inlines.empty()) { + // insert in reverse order + final_trace.insert( + final_trace.end(), + std::make_move_iterator(entry.inlines.rbegin()), + std::make_move_iterator(entry.inlines.rend()) + ); + } + final_trace.push_back(std::move(entry.frame)); + if(!entry.inlines.empty()) { + // rotate line info due to quirk of how dwarf stores this stuff + // inclusive range + auto begin = final_trace.end() - (1 + entry.inlines.size()); + auto end = final_trace.end() - 1; + auto carry_line = end->line; + auto carry_column = end->column; + std::string carry_filename = std::move(end->filename); + for(auto it = end; it != begin; it--) { + it->line = (it - 1)->line; + it->column = (it - 1)->column; + it->filename = std::move((it - 1)->filename); + } + begin->line = carry_line; + begin->column = carry_column; + begin->filename = std::move(carry_filename); + } + } + return final_trace; + } + + CPPTRACE_FORCE_NO_INLINE_FOR_PROFILING + std::vector resolve_frames(const std::vector& frames) { + std::vector trace(frames.size(), {null_frame, {}}); + // Locking around all libdwarf interaction per https://github.com/davea42/libdwarf-code/discussions/184 + // And also locking for interactions with get_resolver + static std::mutex mutex; + const std::lock_guard lock(mutex); + for(const auto& group : collate_frames(frames, trace)) { + try { + const auto& object_name = group.first; + auto resolver = get_resolver(object_name); + for(const auto& entry : group.second) { + const auto& dlframe = entry.first.get(); + auto& frame = entry.second.get(); + try { + frame = resolver->resolve_frame(dlframe); + } catch(...) { + frame.frame.raw_address = dlframe.raw_address; + frame.frame.object_address = dlframe.object_address; + frame.frame.filename = dlframe.object_path; + if(!should_absorb_trace_exceptions()) { + throw; + } + } + } + } catch(...) { // NOSONAR + if(!should_absorb_trace_exceptions()) { + throw; + } + } + } + // fill in basic info for any frames where there were resolution issues + for(std::size_t i = 0; i < frames.size(); i++) { + const auto& dlframe = frames[i]; + auto& frame = trace[i]; + if(frame.frame == null_frame) { + frame = { + { + dlframe.raw_address, + dlframe.object_address, + nullable::null(), + nullable::null(), + dlframe.object_path, + "", + false + }, + {} + }; + } + } + // flatten and finish + return flatten_inlines(trace); + } +} +} +} + +#endif diff --git a/dep/cpptrace/src/symbols/symbols_with_nothing.cpp b/dep/cpptrace/src/symbols/symbols_with_nothing.cpp new file mode 100644 index 00000000000..c1c4c6c69d2 --- /dev/null +++ b/dep/cpptrace/src/symbols/symbols_with_nothing.cpp @@ -0,0 +1,22 @@ +#ifdef CPPTRACE_GET_SYMBOLS_WITH_NOTHING + +#include +#include "symbols/symbols.hpp" + +#include + +namespace cpptrace { +namespace detail { +namespace nothing { + std::vector resolve_frames(const std::vector& frames) { + return std::vector(frames.size(), null_frame); + } + + std::vector resolve_frames(const std::vector& frames) { + return std::vector(frames.size(), null_frame); + } +} +} +} + +#endif diff --git a/dep/cpptrace/src/unwind/unwind.hpp b/dep/cpptrace/src/unwind/unwind.hpp new file mode 100644 index 00000000000..285ed4316fc --- /dev/null +++ b/dep/cpptrace/src/unwind/unwind.hpp @@ -0,0 +1,28 @@ +#ifndef UNWIND_HPP +#define UNWIND_HPP + +#include "utils/common.hpp" +#include "utils/utils.hpp" + +#include +#include + +namespace cpptrace { +namespace detail { + #ifdef CPPTRACE_HARD_MAX_FRAMES + constexpr std::size_t hard_max_frames = CPPTRACE_HARD_MAX_FRAMES; + #else + constexpr std::size_t hard_max_frames = 400; + #endif + + CPPTRACE_FORCE_NO_INLINE + std::vector capture_frames(std::size_t skip, std::size_t max_depth); + + CPPTRACE_FORCE_NO_INLINE + std::size_t safe_capture_frames(frame_ptr* buffer, std::size_t size, std::size_t skip, std::size_t max_depth); + + bool has_safe_unwind(); +} +} + +#endif diff --git a/dep/cpptrace/src/unwind/unwind_with_dbghelp.cpp b/dep/cpptrace/src/unwind/unwind_with_dbghelp.cpp new file mode 100644 index 00000000000..bf74e7b172e --- /dev/null +++ b/dep/cpptrace/src/unwind/unwind_with_dbghelp.cpp @@ -0,0 +1,171 @@ +#ifdef CPPTRACE_UNWIND_WITH_DBGHELP + +#include +#include "unwind/unwind.hpp" +#include "utils/common.hpp" +#include "utils/utils.hpp" +#include "platform/dbghelp_syminit_manager.hpp" + +#include +#include +#include +#include + +#include +#include + +// Fucking windows headers +#ifdef min + #undef min +#endif + +namespace cpptrace { +namespace detail { + #if IS_MSVC + #pragma warning(push) + #pragma warning(disable: 4740) // warning C4740: flow in or out of inline asm code suppresses global optimization + #endif + CPPTRACE_FORCE_NO_INLINE + std::vector capture_frames(std::size_t skip, std::size_t max_depth) { + skip++; + // https://jpassing.com/2008/03/12/walking-the-stack-of-the-current-thread/ + + // Get current thread context + // GetThreadContext cannot be used on the current thread. + // RtlCaptureContext doesn't work on i386 + CONTEXT context; + #if defined(_M_IX86) || defined(__i386__) + ZeroMemory(&context, sizeof(CONTEXT)); + context.ContextFlags = CONTEXT_CONTROL; + #if IS_MSVC + __asm { + label: + mov [context.Ebp], ebp; + mov [context.Esp], esp; + mov eax, [label]; + mov [context.Eip], eax; + } + #else + asm( + "label:\n\t" + "mov{l %%ebp, %[cEbp] | %[cEbp], ebp};\n\t" + "mov{l %%esp, %[cEsp] | %[cEsp], esp};\n\t" + "mov{l $label, %%eax | eax, OFFSET label};\n\t" + "mov{l %%eax, %[cEip] | %[cEip], eax};\n\t" + : [cEbp] "=r" (context.Ebp), + [cEsp] "=r" (context.Esp), + [cEip] "=r" (context.Eip) + ); + #endif + #else + RtlCaptureContext(&context); + #endif + // Setup current frame + STACKFRAME64 frame; + ZeroMemory(&frame, sizeof(STACKFRAME64)); + DWORD machine_type; + #if defined(_M_IX86) || defined(__i386__) + machine_type = IMAGE_FILE_MACHINE_I386; + frame.AddrPC.Offset = context.Eip; + frame.AddrPC.Mode = AddrModeFlat; + frame.AddrFrame.Offset = context.Ebp; + frame.AddrFrame.Mode = AddrModeFlat; + frame.AddrStack.Offset = context.Esp; + frame.AddrStack.Mode = AddrModeFlat; + #elif defined(_M_X64) || defined(__x86_64__) + machine_type = IMAGE_FILE_MACHINE_AMD64; + frame.AddrPC.Offset = context.Rip; + frame.AddrPC.Mode = AddrModeFlat; + frame.AddrFrame.Offset = context.Rsp; + frame.AddrFrame.Mode = AddrModeFlat; + frame.AddrStack.Offset = context.Rsp; + frame.AddrStack.Mode = AddrModeFlat; + #elif defined(_M_IA64) || defined(__aarch64__) + machine_type = IMAGE_FILE_MACHINE_IA64; + frame.AddrPC.Offset = context.StIIP; + frame.AddrPC.Mode = AddrModeFlat; + frame.AddrFrame.Offset = context.IntSp; + frame.AddrFrame.Mode = AddrModeFlat; + frame.AddrBStore.Offset= context.RsBSP; + frame.AddrBStore.Mode = AddrModeFlat; + frame.AddrStack.Offset = context.IntSp; + frame.AddrStack.Mode = AddrModeFlat; + #else + #error "Cpptrace: StackWalk64 not supported for this platform yet" + #endif + + std::vector trace; + + // Dbghelp is is single-threaded, so acquire a lock. + static std::mutex mutex; + std::lock_guard lock(mutex); + // For some reason SymInitialize must be called before StackWalk64 + // Note that the code assumes that + // SymInitialize( GetCurrentProcess(), NULL, TRUE ) has + // already been called. + // + HANDLE proc = GetCurrentProcess(); + HANDLE thread = GetCurrentThread(); + if(get_cache_mode() == cache_mode::prioritize_speed) { + get_syminit_manager().init(proc); + } else { + if(!SymInitialize(proc, NULL, TRUE)) { + throw internal_error("SymInitialize failed"); + } + } + while(trace.size() < max_depth) { + if( + !StackWalk64( + machine_type, + proc, + thread, + &frame, + machine_type == IMAGE_FILE_MACHINE_I386 ? NULL : &context, + NULL, + SymFunctionTableAccess64, + SymGetModuleBase64, + NULL + ) + ) { + // Either failed or finished walking + break; + } + if(frame.AddrPC.Offset != 0) { + // Valid frame + if(skip) { + skip--; + } else { + // On x86/x64/arm, as far as I can tell, the frame return address is always one after the call + // So we just decrement to get the pc back inside the `call` / `bl` + // This is done with _Unwind too but conditionally based on info from _Unwind_GetIPInfo. + trace.push_back(to_frame_ptr(frame.AddrPC.Offset) - 1); + } + } else { + // base + break; + } + } + if(get_cache_mode() != cache_mode::prioritize_speed) { + if(!SymCleanup(proc)) { + throw internal_error("SymCleanup failed"); + } + } + return trace; + } + + CPPTRACE_FORCE_NO_INLINE + std::size_t safe_capture_frames(frame_ptr*, std::size_t, std::size_t, std::size_t) { + // Can't safe trace with dbghelp + return 0; + } + #if IS_MSVC + #pragma warning(pop) + #endif + + bool has_safe_unwind() { + return false; + } +} +} + +#endif diff --git a/dep/cpptrace/src/unwind/unwind_with_execinfo.cpp b/dep/cpptrace/src/unwind/unwind_with_execinfo.cpp new file mode 100644 index 00000000000..d5734ba5a60 --- /dev/null +++ b/dep/cpptrace/src/unwind/unwind_with_execinfo.cpp @@ -0,0 +1,45 @@ +#ifdef CPPTRACE_UNWIND_WITH_EXECINFO + +#include "unwind/unwind.hpp" +#include "utils/common.hpp" +#include "utils/utils.hpp" + +#include +#include +#include +#include + +#include + +namespace cpptrace { +namespace detail { + CPPTRACE_FORCE_NO_INLINE + std::vector capture_frames(std::size_t skip, std::size_t max_depth) { + skip++; + std::vector addrs(skip + std::min(hard_max_frames, max_depth), nullptr); + // thread safe + const int n_frames = backtrace(addrs.data(), static_cast(addrs.size())); + // I hate the copy here but it's the only way that isn't UB + std::vector frames(n_frames - skip, 0); + for(int i = skip; i < n_frames; i++) { + // On x86/x64/arm, as far as I can tell, the frame return address is always one after the call + // So we just decrement to get the pc back inside the `call` / `bl` + // This is done with _Unwind too but conditionally based on info from _Unwind_GetIPInfo. + frames[i - skip] = reinterpret_cast(addrs[i]) - 1; + } + return frames; + } + + CPPTRACE_FORCE_NO_INLINE + std::size_t safe_capture_frames(frame_ptr*, std::size_t, std::size_t, std::size_t) { + // Can't safe trace with execinfo + return 0; + } + + bool has_safe_unwind() { + return false; + } +} +} + +#endif diff --git a/dep/cpptrace/src/unwind/unwind_with_libunwind.cpp b/dep/cpptrace/src/unwind/unwind_with_libunwind.cpp new file mode 100644 index 00000000000..f537abd26b1 --- /dev/null +++ b/dep/cpptrace/src/unwind/unwind_with_libunwind.cpp @@ -0,0 +1,87 @@ +#ifdef CPPTRACE_UNWIND_WITH_LIBUNWIND + +#include "unwind/unwind.hpp" +#include "utils/common.hpp" +#include "utils/error.hpp" +#include "utils/utils.hpp" + +#include +#include +#include +#include + +#include + +namespace cpptrace { +namespace detail { + CPPTRACE_FORCE_NO_INLINE + std::vector capture_frames(std::size_t skip, std::size_t max_depth) { + skip++; + std::vector frames; + unw_context_t context; + unw_cursor_t cursor; + unw_getcontext(&context); + unw_init_local(&cursor, &context); + do { + unw_word_t pc; + unw_word_t sp; + unw_get_reg(&cursor, UNW_REG_IP, &pc); + unw_get_reg(&cursor, UNW_REG_SP, &sp); + if(skip) { + skip--; + } else { + // pc is the instruction after the `call`, adjust back to the previous instruction + frames.push_back(to_frame_ptr(pc) - 1); + } + } while(unw_step(&cursor) > 0 && frames.size() < max_depth); + return frames; + } + + CPPTRACE_FORCE_NO_INLINE + std::size_t safe_capture_frames(frame_ptr* buffer, std::size_t size, std::size_t skip, std::size_t max_depth) { + // some code duplication, but whatever + skip++; + unw_context_t context; + unw_cursor_t cursor; + // thread and signal-safe https://www.nongnu.org/libunwind/man/unw_getcontext(3).html + unw_getcontext(&context); + // thread and signal-safe https://www.nongnu.org/libunwind/man/unw_init_local(3).html + unw_init_local(&cursor, &context); + size_t i = 0; + while(i < size && i < max_depth) { + unw_word_t pc; + unw_word_t sp; + // thread and signal-safe https://www.nongnu.org/libunwind/man/unw_get_reg(3).html + unw_get_reg(&cursor, UNW_REG_IP, &pc); + unw_get_reg(&cursor, UNW_REG_SP, &sp); + if(skip) { + skip--; + } else { + // thread and signal-safe + if(unw_is_signal_frame(&cursor)) { + // pc is the instruction that caused the signal + // just a cast, thread and signal safe + buffer[i] = to_frame_ptr(pc); + } else { + // pc is the instruction after the `call`, adjust back to the previous instruction + // just a cast, thread and signal safe + buffer[i] = to_frame_ptr(pc) - 1; + } + i++; + } + // thread and signal-safe as long as the cursor is in the local address space, which it is + // https://www.nongnu.org/libunwind/man/unw_step(3).html + if(unw_step(&cursor) <= 0) { + break; + } + } + return i; + } + + bool has_safe_unwind() { + return true; + } +} +} + +#endif diff --git a/dep/cpptrace/src/unwind/unwind_with_nothing.cpp b/dep/cpptrace/src/unwind/unwind_with_nothing.cpp new file mode 100644 index 00000000000..8099d4a1c2a --- /dev/null +++ b/dep/cpptrace/src/unwind/unwind_with_nothing.cpp @@ -0,0 +1,25 @@ +#ifdef CPPTRACE_UNWIND_WITH_NOTHING + +#include "unwind/unwind.hpp" + +#include +#include + +namespace cpptrace { +namespace detail { + std::vector capture_frames(std::size_t, std::size_t) { + return {}; + } + + CPPTRACE_FORCE_NO_INLINE + std::size_t safe_capture_frames(frame_ptr*, std::size_t, std::size_t, std::size_t) { + return 0; + } + + bool has_safe_unwind() { + return false; + } +} +} + +#endif diff --git a/dep/cpptrace/src/unwind/unwind_with_unwind.cpp b/dep/cpptrace/src/unwind/unwind_with_unwind.cpp new file mode 100644 index 00000000000..da3cc57031a --- /dev/null +++ b/dep/cpptrace/src/unwind/unwind_with_unwind.cpp @@ -0,0 +1,75 @@ +#ifdef CPPTRACE_UNWIND_WITH_UNWIND + +#include "unwind/unwind.hpp" +#include "utils/common.hpp" +#include "utils/error.hpp" +#include "utils/utils.hpp" + +#include +#include +#include +#include + +#include + +namespace cpptrace { +namespace detail { + struct unwind_state { + std::size_t skip; + std::size_t max_depth; + std::vector& vec; + }; + + _Unwind_Reason_Code unwind_callback(_Unwind_Context* context, void* arg) { + unwind_state& state = *static_cast(arg); + if(state.skip) { + state.skip--; + if(_Unwind_GetIP(context) == frame_ptr(0)) { + return _URC_END_OF_STACK; + } else { + return _URC_NO_REASON; + } + } + + ASSERT( + state.vec.size() < state.max_depth, + "Somehow cpptrace::detail::unwind_callback is being called beyond the max_depth" + ); + int is_before_instruction = 0; + frame_ptr ip = _Unwind_GetIPInfo(context, &is_before_instruction); + if(!is_before_instruction && ip != frame_ptr(0)) { + ip--; + } + if (ip == frame_ptr(0)) { + return _URC_END_OF_STACK; + } else { + state.vec.push_back(ip); + if(state.vec.size() >= state.max_depth) { + return _URC_END_OF_STACK; + } else { + return _URC_NO_REASON; + } + } + } + + CPPTRACE_FORCE_NO_INLINE + std::vector capture_frames(std::size_t skip, std::size_t max_depth) { + std::vector frames; + unwind_state state{skip + 1, max_depth, frames}; + _Unwind_Backtrace(unwind_callback, &state); // presumably thread-safe + return frames; + } + + CPPTRACE_FORCE_NO_INLINE + std::size_t safe_capture_frames(frame_ptr*, std::size_t, std::size_t, std::size_t) { + // Can't safe trace with _Unwind + return 0; + } + + bool has_safe_unwind() { + return false; + } +} +} + +#endif diff --git a/dep/cpptrace/src/unwind/unwind_with_winapi.cpp b/dep/cpptrace/src/unwind/unwind_with_winapi.cpp new file mode 100644 index 00000000000..e5888e9aa5f --- /dev/null +++ b/dep/cpptrace/src/unwind/unwind_with_winapi.cpp @@ -0,0 +1,53 @@ +#ifdef CPPTRACE_UNWIND_WITH_WINAPI + +#include +#include "unwind/unwind.hpp" +#include "utils/common.hpp" +#include "utils/utils.hpp" + +#include +#include +#include + +#include + +// Fucking windows headers +#ifdef min + #undef min +#endif + +namespace cpptrace { +namespace detail { + CPPTRACE_FORCE_NO_INLINE + std::vector capture_frames(std::size_t skip, std::size_t max_depth) { + std::vector addrs(skip + std::min(hard_max_frames, max_depth), nullptr); + std::size_t n_frames = CaptureStackBackTrace( + static_cast(skip + 1), + static_cast(addrs.size()), + addrs.data(), + NULL + ); + // I hate the copy here but it's the only way that isn't UB + std::vector frames(n_frames, 0); + for(std::size_t i = 0; i < n_frames; i++) { + // On x86/x64/arm, as far as I can tell, the frame return address is always one after the call + // So we just decrement to get the pc back inside the `call` / `bl` + // This is done with _Unwind too but conditionally based on info from _Unwind_GetIPInfo. + frames[i] = reinterpret_cast(addrs[i]) - 1; + } + return frames; + } + + CPPTRACE_FORCE_NO_INLINE + std::size_t safe_capture_frames(frame_ptr*, std::size_t, std::size_t, std::size_t) { + // Can't safe trace with winapi + return 0; + } + + bool has_safe_unwind() { + return false; + } +} +} + +#endif diff --git a/dep/cpptrace/src/utils/common.hpp b/dep/cpptrace/src/utils/common.hpp new file mode 100644 index 00000000000..a323b48a618 --- /dev/null +++ b/dep/cpptrace/src/utils/common.hpp @@ -0,0 +1,47 @@ +#ifndef COMMON_HPP +#define COMMON_HPP + +#include +#include +#include + +#include + +#include "platform/platform.hpp" + +#define ESC "\033[" +#define RESET ESC "0m" +#define RED ESC "31m" +#define GREEN ESC "32m" +#define YELLOW ESC "33m" +#define BLUE ESC "34m" +#define MAGENTA ESC "35m" +#define CYAN ESC "36m" + +#if IS_GCC || IS_CLANG + #define NODISCARD __attribute__((warn_unused_result)) +// #elif IS_MSVC && _MSC_VER >= 1700 +// #define NODISCARD _Check_return_ +#else + #define NODISCARD +#endif + +namespace cpptrace { +namespace detail { + static const stacktrace_frame null_frame { + 0, + 0, + nullable::null(), + nullable::null(), + "", + "", + false + }; + + bool should_absorb_trace_exceptions(); + bool should_resolve_inlined_calls(); + cache_mode get_cache_mode(); +} +} + +#endif diff --git a/dep/cpptrace/src/utils/error.hpp b/dep/cpptrace/src/utils/error.hpp new file mode 100644 index 00000000000..6594c1fad4e --- /dev/null +++ b/dep/cpptrace/src/utils/error.hpp @@ -0,0 +1,171 @@ +#ifndef ERROR_HPP +#define ERROR_HPP + +#include +#include +#include +#include + +#include "utils/common.hpp" +#include "utils/microfmt.hpp" + +#if IS_MSVC + #define CPPTRACE_PFUNC __FUNCSIG__ +#else + #define CPPTRACE_PFUNC __extension__ __PRETTY_FUNCTION__ +#endif + +namespace cpptrace { +namespace detail { + class internal_error : public std::exception { + std::string msg; + public: + internal_error(std::string message) : msg("Cpptrace internal error: " + std::move(message)) {} + template + internal_error(const char* format, Args&&... args) : internal_error(microfmt::format(format, args...)) {} + const char* what() const noexcept override { + return msg.c_str(); + } + }; + + // Lightweight std::source_location. + struct source_location { + const char* const file; + const int line; + constexpr source_location( + const char* _file, + int _line + ) : file(_file), line(_line) {} + }; + + #define CPPTRACE_CURRENT_LOCATION ::cpptrace::detail::source_location(__FILE__, __LINE__) + + enum class assert_type { + assert, + verify, + panic, + }; + + constexpr const char* assert_actions[] = {"assertion", "verification", "panic"}; + constexpr const char* assert_names[] = {"ASSERT", "VERIFY", "PANIC"}; + + [[noreturn]] inline void assert_fail( + assert_type type, + const char* expression, + const char* signature, + source_location location, + const char* message + ) { + const char* action = assert_actions[static_cast::type>(type)]; + const char* name = assert_names[static_cast::type>(type)]; + if(message == nullptr) { + throw internal_error( + "Cpptrace {} failed at {}:{}: {}\n" + " {}({});\n", + action, location.file, location.line, signature, + name, expression + ); + } else { + throw internal_error( + "Cpptrace {} failed at {}:{}: {}: {}\n" + " {}({});\n", + action, location.file, location.line, signature, message, + name, expression + ); + } + } + + [[noreturn]] inline void panic( + const char* signature, + source_location location, + const std::string& message = "" + ) { + if(message == "") { + throw internal_error( + "Cpptrace panic {}:{}: {}\n", + location.file, location.line, signature + ); + } else { + throw internal_error( + "Cpptrace panic {}:{}: {}: {}\n", + location.file, location.line, signature, message.c_str() + ); + } + } + + template + void nullfn() { + // this method doesn't do anything and is never called. + } + + #define PHONY_USE(...) (nullfn()) + + // Work around a compiler warning + template + bool as_bool(T&& value) { + return static_cast(std::forward(value)); + } + + // Work around a compiler warning + template + std::string as_string(T&& value) { + return std::string(std::forward(value)); + } + + inline std::string as_string() { + return ""; + } + + // Check condition in both debug and release. std::runtime_error on failure. + #define PANIC(...) ((::cpptrace::detail::panic)(CPPTRACE_PFUNC, CPPTRACE_CURRENT_LOCATION, ::cpptrace::detail::as_string(__VA_ARGS__))) + + template + void assert_impl( + T condition, + const char* message, + assert_type type, + const char* args, + const char* signature, + source_location location + ) { + if(!as_bool(condition)) { + assert_fail(type, args, signature, location, message); + } + } + + template + void assert_impl( + T condition, + assert_type type, + const char* args, + const char* signature, + source_location location + ) { + assert_impl( + condition, + nullptr, + type, + args, + signature, + location + ); + } + + // Check condition in both debug and release. std::runtime_error on failure. + #define VERIFY(...) ( \ + assert_impl(__VA_ARGS__, ::cpptrace::detail::assert_type::verify, #__VA_ARGS__, CPPTRACE_PFUNC, CPPTRACE_CURRENT_LOCATION) \ + ) + + #ifndef NDEBUG + // Check condition in both debug. std::runtime_error on failure. + #define ASSERT(...) ( \ + assert_impl(__VA_ARGS__, ::cpptrace::detail::assert_type::assert, #__VA_ARGS__, CPPTRACE_PFUNC, CPPTRACE_CURRENT_LOCATION) \ + ) + #else + // Check condition in both debug. std::runtime_error on failure. + #define ASSERT(...) PHONY_USE(__VA_ARGS__) + #endif +} +} + +#endif diff --git a/dep/cpptrace/src/utils/microfmt.hpp b/dep/cpptrace/src/utils/microfmt.hpp new file mode 100644 index 00000000000..66bb34fa1b2 --- /dev/null +++ b/dep/cpptrace/src/utils/microfmt.hpp @@ -0,0 +1,307 @@ +#ifndef MICROFMT_HPP +#define MICROFMT_HPP + +#include +#include +#include +#include +#include +#include +#if ((defined(_MSVC_LANG) && _MSVC_LANG >= 201703L) || __cplusplus >= 201703L) + #include +#endif +#ifdef _MSC_VER + #include +#endif + +// https://github.com/jeremy-rifkin/microfmt +// Format: {[align][width][:[fill][base]]} # width: number or {} + +namespace cpptrace { +namespace microfmt { + namespace detail { + inline std::uint64_t clz(std::uint64_t value) { + #ifdef _MSC_VER + unsigned long out = 0; + #ifdef _WIN64 + _BitScanReverse64(&out, value); + #else + if(_BitScanReverse(&out, std::uint32_t(value >> 32))) { + return 63 - int(out + 32); + } + _BitScanReverse(&out, std::uint32_t(value)); + #endif + return 63 - out; + #else + return __builtin_clzll(value); + #endif + } + + template U to(V v) { + return static_cast(v); // A way to cast to U without "warning: useless cast to type" + } + + enum class alignment { left, right }; + + struct format_options { + alignment align = alignment::left; + char fill = ' '; + size_t width = 0; + char base = 'd'; + }; + + template void do_write(std::string& out, It begin, It end, const format_options& options) { + auto size = end - begin; + if(static_cast(size) >= options.width) { + out.append(begin, end); + } else { + auto out_size = out.size(); + out.resize(out_size + options.width); + if(options.align == alignment::left) { + std::copy(begin, end, out.begin() + out_size); + std::fill(out.begin() + out_size + size, out.end(), options.fill); + } else { + std::fill(out.begin() + out_size, out.begin() + out_size + (options.width - size), options.fill); + std::copy(begin, end, out.begin() + out_size + (options.width - size)); + } + } + } + + template + std::string to_string(std::uint64_t value, const char* digits = "0123456789abcdef") { + if(value == 0) { + return "0"; + } else { + // digits = floor(1 + log_base(x)) + // log_base(x) = log_2(x) / log_2(base) + // log_2(x) == 63 - clz(x) + // 1 + (63 - clz(value)) / (63 - clz(1 << shift)) + // 63 - clz(1 << shift) is the same as shift + auto n_digits = to(1 + (63 - clz(value)) / shift); + std::string number; + number.resize(n_digits); + std::size_t i = n_digits - 1; + while(value > 0) { + number[i--] = digits[value & mask]; + value >>= shift; + } + return number; + } + } + + inline std::string to_string(std::uint64_t value, const format_options& options) { + switch(options.base) { + case 'H': return to_string<4, 0xf>(value, "0123456789ABCDEF"); + case 'h': return to_string<4, 0xf>(value); + case 'o': return to_string<3, 0x7>(value); + case 'b': return to_string<1, 0x1>(value); + default: return std::to_string(value); // failure: decimal + } + } + + class format_value { + enum class value_type { + char_value, + int64_value, + uint64_value, + string_value, + #if ((defined(_MSVC_LANG) && _MSVC_LANG >= 201703L) || __cplusplus >= 201703L) + string_view_value, + #endif + c_string_value, + }; + union { + char char_value; + std::int64_t int64_value; + std::uint64_t uint64_value; + const std::string* string_value; + #if ((defined(_MSVC_LANG) && _MSVC_LANG >= 201703L) || __cplusplus >= 201703L) + std::string_view string_view_value; + #endif + const char* c_string_value; + }; + value_type value; + + public: + format_value(char c) : char_value(c), value(value_type::char_value) {} + format_value(short int_val) : int64_value(int_val), value(value_type::int64_value) {} + format_value(int int_val) : int64_value(int_val), value(value_type::int64_value) {} + format_value(long int_val) : int64_value(int_val), value(value_type::int64_value) {} + format_value(long long int_val) : int64_value(int_val), value(value_type::int64_value) {} + format_value(unsigned char int_val) : uint64_value(int_val), value(value_type::uint64_value) {} + format_value(unsigned short int_val) : uint64_value(int_val), value(value_type::uint64_value) {} + format_value(unsigned int int_val) : uint64_value(int_val), value(value_type::uint64_value) {} + format_value(unsigned long int_val) : uint64_value(int_val), value(value_type::uint64_value) {} + format_value(unsigned long long int_val) : uint64_value(int_val), value(value_type::uint64_value) {} + format_value(const std::string& string) : string_value(&string), value(value_type::string_value) {} + #if ((defined(_MSVC_LANG) && _MSVC_LANG >= 201703L) || __cplusplus >= 201703L) + format_value(std::string_view sv) : string_view_value(sv), value(value_type::string_view_value) {} + #endif + format_value(const char* c_string) : c_string_value(c_string), value(value_type::c_string_value) {} + + int unwrap_int() const { + switch(value) { + case value_type::int64_value: return static_cast(int64_value); + case value_type::uint64_value: return static_cast(uint64_value); + default: return 0; // failure: just 0 + } + } + + public: + void write(std::string& out, const format_options& options) const { + switch(value) { + case value_type::char_value: + do_write(out, &char_value, &char_value + 1, options); + break; + case value_type::int64_value: + { + std::string str; + std::int64_t val = int64_value; + if(val < 0) { + str += '-'; + val *= -1; + } + str += to_string(static_cast(val), options); + do_write(out, str.begin(), str.end(), options); + } + break; + case value_type::uint64_value: + { + std::string str = to_string(uint64_value, options); + do_write(out, str.begin(), str.end(), options); + } + break; + case value_type::string_value: + do_write(out, string_value->begin(), string_value->end(), options); + break; + #if ((defined(_MSVC_LANG) && _MSVC_LANG >= 201703L) || __cplusplus >= 201703L) + case value_type::string_view_value: + do_write(out, string_view_value.begin(), string_view_value.end(), options); + break; + #endif + case value_type::c_string_value: + do_write(out, c_string_value, c_string_value + std::strlen(c_string_value), options); + break; + } // failure: nop + } + }; + + template + std::string format(It fmt_begin, It fmt_end, std::array args) { + std::string str; + std::size_t arg_i = 0; + auto it = fmt_begin; + auto peek = [&] (std::size_t dist) -> char { // 0 on failure + return fmt_end - it > signed(dist) ? *(it + dist) : 0; + }; + auto read_number = [&] () -> int { // -1 on failure + auto scan = it; + int num = 0; + while(scan != fmt_end && isdigit(*scan)) { + num *= 10; + num += *scan - '0'; + scan++; + } + if(scan != it) { + it = scan; + return num; + } else { + return -1; + } + }; + for(; it != fmt_end; it++) { + if((*it == '{' || *it == '}') && peek(1) == *it) { // parse {{ and }} escapes + it++; + } else if(*it == '{' && it + 1 != fmt_end) { + auto saved_it = it; + auto handle_formatter = [&] () { + it++; + format_options options; + // try to parse alignment + if(*it == '<' || *it == '>') { + options.align = *it++ == '<' ? alignment::left : alignment::right; + } + // try to parse width + auto width = read_number(); // handles fmt_end check + if(width != -1) { + options.width = width; + } else if(it != fmt_end && *it == '{') { // try to parse variable width + if(peek(1) != '}') { + return false; + } + it += 2; + options.width = arg_i < args.size() ? args[arg_i++].unwrap_int() : 0; + } + // try to parse fill/base + if(it != fmt_end && *it == ':') { + it++; + if(fmt_end - it > 1 && *it != '}' && peek(1) != '}') { // two chars before the }, fill+base + options.fill = *it++; + options.base = *it++; + } else if(it != fmt_end && *it != '}') { // one char before the }, just base + if(*it == 'd' || *it == 'h' || *it == 'H' || *it == 'o' || *it == 'b') { + options.base = *it++; + } else { + options.fill = *it++; + } + } + } + if(it == fmt_end || *it != '}') { + return false; + } + if(arg_i < args.size()) { + args[arg_i++].write(str, options); + } + return true; + }; + if(handle_formatter()) { + continue; // If reached here, successfully parsed and wrote a formatter. Don't write *it. + } + it = saved_it; // go back + } + str += *it; + } + return str; + } + } + + #if ((defined(_MSVC_LANG) && _MSVC_LANG >= 201703L) || __cplusplus >= 201703L) + template + std::string format(std::string_view fmt, Args&&... args) { + return detail::format(fmt.begin(), fmt.end(), {detail::format_value(args)...}); + } + + // working around an old msvc bug https://godbolt.org/z/88T8hrzzq mre: https://godbolt.org/z/drd8echbP + inline std::string format(std::string_view fmt) { + return detail::format<1>(fmt.begin(), fmt.end(), {detail::format_value(1)}); + } + #endif + + template + std::string format(const char* fmt, Args&&... args) { + return detail::format(fmt, fmt + std::strlen(fmt), {detail::format_value(args)...}); + } + + inline std::string format(const char* fmt) { + return detail::format<1>(fmt, fmt + std::strlen(fmt), {detail::format_value(1)}); + } + + template + void print(const S& fmt, Args&&... args) { + std::cout< + void print(std::ostream& ostream, const S& fmt, Args&&... args) { + ostream< + void print(std::FILE* stream, const S& fmt, Args&&... args) { + auto str = format(fmt, args...); + fwrite(str.data(), 1, str.size(), stream); + } +} +} + +#endif diff --git a/dep/cpptrace/src/utils/utils.hpp b/dep/cpptrace/src/utils/utils.hpp new file mode 100644 index 00000000000..ecea6920660 --- /dev/null +++ b/dep/cpptrace/src/utils/utils.hpp @@ -0,0 +1,590 @@ +#ifndef UTILS_HPP +#define UTILS_HPP + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "utils/common.hpp" +#include "utils/error.hpp" +#include "utils/microfmt.hpp" + +#if IS_WINDOWS + #include + #include +#else + #include + #include +#endif + +namespace cpptrace { +namespace detail { + inline bool isatty(int fd) { + #if IS_WINDOWS + return _isatty(fd); + #else + return ::isatty(fd); + #endif + } + + inline int fileno(std::FILE* stream) { + #if IS_WINDOWS + return _fileno(stream); + #else + return ::fileno(stream); + #endif + } + + inline std::vector split(const std::string& str, const std::string& delims) { + std::vector vec; + std::size_t old_pos = 0; + std::size_t pos = 0; + while((pos = str.find_first_of(delims, old_pos)) != std::string::npos) { + vec.emplace_back(str.substr(old_pos, pos - old_pos)); + old_pos = pos + 1; + } + vec.emplace_back(str.substr(old_pos)); + return vec; + } + + template + inline std::string join(const C& container, const std::string& delim) { + auto iter = std::begin(container); + auto end = std::end(container); + std::string str; + if(std::distance(iter, end) > 0) { + str += *iter; + while(++iter != end) { + str += delim; + str += *iter; + } + } + return str; + } + + // first value in a sorted range such that *it <= value + template + ForwardIt first_less_than_or_equal(ForwardIt begin, ForwardIt end, const T& value) { + auto it = std::upper_bound(begin, end, value); + // it is first > value, we want first <= value + if(it != begin) { + return --it; + } + return end; + } + + // first value in a sorted range such that *it <= value + template + ForwardIt first_less_than_or_equal(ForwardIt begin, ForwardIt end, const T& value, Compare compare) { + auto it = std::upper_bound(begin, end, value, compare); + // it is first > value, we want first <= value + if(it != begin) { + return --it; + } + return end; + } + + constexpr const char* const whitespace = " \t\n\r\f\v"; + + inline std::string trim(const std::string& str) { + if(str.empty()) { + return ""; + } + const std::size_t left = str.find_first_not_of(whitespace); + const std::size_t right = str.find_last_not_of(whitespace) + 1; + return str.substr(left, right - left); + } + + inline bool is_little_endian() { + std::uint16_t num = 0x1; + const auto* ptr = (std::uint8_t*)# + return ptr[0] == 1; + } + + // Modified from + // https://stackoverflow.com/questions/105252/how-do-i-convert-between-big-endian-and-little-endian-values-in-c + template + struct byte_swapper; + + template + struct byte_swapper { + T operator()(T val) { + return val; + } + }; + + template + struct byte_swapper { + T operator()(T val) { + return (((val >> 8) & 0xff) | ((val & 0xff) << 8)); + } + }; + + template + struct byte_swapper { + T operator()(T val) { + return (((val & 0xff000000) >> 24) | + ((val & 0x00ff0000) >> 8) | + ((val & 0x0000ff00) << 8) | + ((val & 0x000000ff) << 24)); + } + }; + + template + struct byte_swapper { + T operator()(T val) { + return (((val & 0xff00000000000000ULL) >> 56) | + ((val & 0x00ff000000000000ULL) >> 40) | + ((val & 0x0000ff0000000000ULL) >> 24) | + ((val & 0x000000ff00000000ULL) >> 8 ) | + ((val & 0x00000000ff000000ULL) << 8 ) | + ((val & 0x0000000000ff0000ULL) << 24) | + ((val & 0x000000000000ff00ULL) << 40) | + ((val & 0x00000000000000ffULL) << 56)); + } + }; + + template::value, int>::type = 0> + T byteswap(T value) { + return byte_swapper{}(value); + } + + inline void enable_virtual_terminal_processing_if_needed() noexcept { + // enable colors / ansi processing if necessary + #if IS_WINDOWS + // https://docs.microsoft.com/en-us/windows/console/console-virtual-terminal-sequences#example-of-enabling-virtual-terminal-processing + #ifndef ENABLE_VIRTUAL_TERMINAL_PROCESSING + constexpr DWORD ENABLE_VIRTUAL_TERMINAL_PROCESSING = 0x4; + #endif + HANDLE hOut = GetStdHandle(STD_OUTPUT_HANDLE); + DWORD dwMode = 0; + if(hOut == INVALID_HANDLE_VALUE) return; + if(!GetConsoleMode(hOut, &dwMode)) return; + if(dwMode != (dwMode | ENABLE_VIRTUAL_TERMINAL_PROCESSING)) + if(!SetConsoleMode(hOut, dwMode | ENABLE_VIRTUAL_TERMINAL_PROCESSING)) return; + #endif + } + + constexpr unsigned n_digits(unsigned value) noexcept { + return value < 10 ? 1 : 1 + n_digits(value / 10); + } + static_assert(n_digits(1) == 1, "n_digits utility producing the wrong result"); + static_assert(n_digits(9) == 1, "n_digits utility producing the wrong result"); + static_assert(n_digits(10) == 2, "n_digits utility producing the wrong result"); + static_assert(n_digits(11) == 2, "n_digits utility producing the wrong result"); + static_assert(n_digits(1024) == 4, "n_digits utility producing the wrong result"); + + struct nullopt_t {}; + + static constexpr nullopt_t nullopt; + + template< + typename T, + typename std::enable_if::type, void>::value, int>::type = 0 + > + class optional { + union { + char x; + T uvalue; + }; + + bool holds_value = false; + + public: + optional() noexcept {} + + optional(nullopt_t) noexcept {} + + ~optional() { + reset(); + } + + optional(const optional& other) : holds_value(other.holds_value) { + if(holds_value) { + new (static_cast(std::addressof(uvalue))) T(other.uvalue); + } + } + + optional(optional&& other) + noexcept(std::is_nothrow_move_constructible::value) + : holds_value(other.holds_value) + { + if(holds_value) { + new (static_cast(std::addressof(uvalue))) T(std::move(other.uvalue)); + } + } + + optional& operator=(const optional& other) { + optional copy(other); + swap(copy); + return *this; + } + + optional& operator=(optional&& other) + noexcept(std::is_nothrow_move_assignable::value && std::is_nothrow_move_constructible::value) + { + reset(); + if(other.holds_value) { + new (static_cast(std::addressof(uvalue))) T(std::move(other.uvalue)); + holds_value = true; + } + return *this; + } + + template< + typename U = T, + typename std::enable_if::type, optional>::value, int>::type = 0 + > + optional(U&& value) : holds_value(true) { + new (static_cast(std::addressof(uvalue))) T(std::forward(value)); + } + + template< + typename U = T, + typename std::enable_if::type, optional>::value, int>::type = 0 + > + optional& operator=(U&& value) { + optional o(std::forward(value)); + swap(o); + return *this; + } + + optional& operator=(nullopt_t) noexcept { + reset(); + return *this; + } + + void swap(optional& other) noexcept { + if(holds_value && other.holds_value) { + std::swap(uvalue, other.uvalue); + } else if(holds_value && !other.holds_value) { + new (&other.uvalue) T(std::move(uvalue)); + uvalue.~T(); + } else if(!holds_value && other.holds_value) { + new (static_cast(std::addressof(uvalue))) T(std::move(other.uvalue)); + other.uvalue.~T(); + } + std::swap(holds_value, other.holds_value); + } + + bool has_value() const { + return holds_value; + } + + explicit operator bool() const { + return holds_value; + } + + void reset() { + if(holds_value) { + uvalue.~T(); + } + holds_value = false; + } + + NODISCARD T& unwrap() & { + ASSERT(holds_value, "Optional does not contain a value"); + return uvalue; + } + + NODISCARD const T& unwrap() const & { + ASSERT(holds_value, "Optional does not contain a value"); + return uvalue; + } + + NODISCARD T&& unwrap() && { + ASSERT(holds_value, "Optional does not contain a value"); + return std::move(uvalue); + } + + NODISCARD const T&& unwrap() const && { + ASSERT(holds_value, "Optional does not contain a value"); + return std::move(uvalue); + } + + template + NODISCARD T value_or(U&& default_value) const & { + return holds_value ? uvalue : static_cast(std::forward(default_value)); + } + + template + NODISCARD T value_or(U&& default_value) && { + return holds_value ? std::move(uvalue) : static_cast(std::forward(default_value)); + } + }; + + extern std::atomic_bool absorb_trace_exceptions; + + template::value, int>::type = 0> + class Result { + union { + T value_; + E error_; + }; + enum class member { value, error }; + member active; + public: + Result(T&& value) : value_(std::move(value)), active(member::value) {} + Result(E&& error) : error_(std::move(error)), active(member::error) { + if(!absorb_trace_exceptions.load()) { + std::fprintf(stderr, "%s\n", unwrap_error().what()); + } + } + Result(T& value) : value_(T(value)), active(member::value) {} + Result(E& error) : error_(E(error)), active(member::error) { + if(!absorb_trace_exceptions.load()) { + std::fprintf(stderr, "%s\n", unwrap_error().what()); + } + } + Result(Result&& other) : active(other.active) { + if(other.active == member::value) { + new (&value_) T(std::move(other.value_)); + } else { + new (&error_) E(std::move(other.error_)); + } + } + ~Result() { + if(active == member::value) { + value_.~T(); + } else { + error_.~E(); + } + } + + bool has_value() const { + return active == member::value; + } + + bool is_error() const { + return active == member::error; + } + + explicit operator bool() const { + return has_value(); + } + + NODISCARD optional value() const & { + return has_value() ? value_ : nullopt; + } + + NODISCARD optional error() const & { + return is_error() ? error_ : nullopt; + } + + NODISCARD optional value() && { + return has_value() ? std::move(value_) : nullopt; + } + + NODISCARD optional error() && { + return is_error() ? std::move(error_) : nullopt; + } + + NODISCARD T& unwrap_value() & { + ASSERT(has_value(), "Result does not contain a value"); + return value_; + } + + NODISCARD const T& unwrap_value() const & { + ASSERT(has_value(), "Result does not contain a value"); + return value_; + } + + NODISCARD T unwrap_value() && { + ASSERT(has_value(), "Result does not contain a value"); + return std::move(value_); + } + + NODISCARD E& unwrap_error() & { + ASSERT(is_error(), "Result does not contain an error"); + return error_; + } + + NODISCARD const E& unwrap_error() const & { + ASSERT(is_error(), "Result does not contain an error"); + return error_; + } + + NODISCARD E unwrap_error() && { + ASSERT(is_error(), "Result does not contain an error"); + return std::move(error_); + } + + template + NODISCARD T value_or(U&& default_value) const & { + return has_value() ? value_ : static_cast(std::forward(default_value)); + } + + template + NODISCARD T value_or(U&& default_value) && { + return has_value() ? std::move(value_) : static_cast(std::forward(default_value)); + } + + void drop_error() const { + if(is_error()) { + std::fprintf(stderr, "%s\n", unwrap_error().what()); + } + } + }; + + struct monostate {}; + + // TODO: Re-evaluate use of off_t + template::value, int>::type = 0> + Result load_bytes(std::FILE* object_file, off_t offset) { + T object; + if(std::fseek(object_file, offset, SEEK_SET) != 0) { + return internal_error("fseek error"); + } + if(std::fread(&object, sizeof(T), 1, object_file) != 1) { + return internal_error("fread error"); + } + return object; + } + + // shamelessly stolen from stackoverflow + inline bool directory_exists(const std::string& path) { + #if IS_WINDOWS + DWORD dwAttrib = GetFileAttributesA(path.c_str()); + return dwAttrib != INVALID_FILE_ATTRIBUTES && (dwAttrib & FILE_ATTRIBUTE_DIRECTORY); + #else + struct stat sb; + return stat(path.c_str(), &sb) == 0 && S_ISDIR(sb.st_mode); + #endif + } + + inline std::string basename(const std::string& path) { + // Assumes no trailing /'s + auto pos = path.rfind('/'); + if(pos == std::string::npos) { + return path; + } else { + return path.substr(pos + 1); + } + } + + // A way to cast to unsigned long long without "warning: useless cast to type" + template + unsigned long long to_ull(T t) { + return static_cast(t); + } + template + frame_ptr to_frame_ptr(T t) { + return static_cast(t); + } + + // A way to cast to U without "warning: useless cast to type" + template + U to(V v) { + return static_cast(v); + } + + // TODO: Rework some stuff here. Not sure deleters should be optional or moved. + // Also allow file_wrapper file = std::fopen(object_path.c_str(), "rb"); + template< + typename T, + typename D + // workaround for: + // == 19.38-specific msvc bug https://developercommunity.visualstudio.com/t/MSVC-1938331290-preview-fails-to-comp/10505565 + // <= 19.23 msvc also appears to fail (but for a different reason https://godbolt.org/z/6Y5EvdWPK) + #if !defined(_MSC_VER) || !(_MSC_VER <= 1923 || _MSC_VER == 1938) + , + typename std::enable_if< + std::is_same()(std::declval())), void>::value, + int + >::type = 0, + typename std::enable_if< + std::is_standard_layout::value && std::is_trivial::value, + int + >::type = 0, + typename std::enable_if< + std::is_nothrow_move_constructible::value, + int + >::type = 0 + #endif + > + class raii_wrapper { + T obj; + optional deleter; + public: + raii_wrapper(T obj, D deleter) : obj(obj), deleter(deleter) {} + raii_wrapper(raii_wrapper&& other) noexcept : obj(std::move(other.obj)), deleter(std::move(other.deleter)) { + other.deleter = nullopt; + } + raii_wrapper(const raii_wrapper&) = delete; + raii_wrapper& operator=(raii_wrapper&&) = delete; + raii_wrapper& operator=(const raii_wrapper&) = delete; + ~raii_wrapper() { + if(deleter.has_value()) { + deleter.unwrap()(obj); + } + } + operator T&() { + return obj; + } + operator const T&() const { + return obj; + } + T& get() { + return obj; + } + const T& get() const { + return obj; + } + }; + + template< + typename T, + typename D + // workaround a msvc bug https://developercommunity.visualstudio.com/t/MSVC-1938331290-preview-fails-to-comp/10505565 + #if !defined(_MSC_VER) || _MSC_VER != 1938 + , + typename std::enable_if< + std::is_same()(std::declval())), void>::value, + int + >::type = 0, + typename std::enable_if< + std::is_standard_layout::value && std::is_trivial::value, + int + >::type = 0 + #endif + > + raii_wrapper::type, D> raii_wrap(T obj, D deleter) { + return raii_wrapper::type, D>(obj, deleter); + } + + inline void file_deleter(std::FILE* ptr) { + if(ptr) { + fclose(ptr); + } + } + + using file_wrapper = raii_wrapper; + + template + class maybe_owned { + std::unique_ptr owned; + T* ptr; + public: + maybe_owned(T* ptr) : ptr(ptr) {} + maybe_owned(std::unique_ptr&& owned) : owned(std::move(owned)), ptr(this->owned.get()) {} + T* operator->() { + return ptr; + } + }; +} +} + +#endif diff --git a/src/framework/CMakeLists.txt b/src/framework/CMakeLists.txt index 579180475ef..174ffadd7c7 100644 --- a/src/framework/CMakeLists.txt +++ b/src/framework/CMakeLists.txt @@ -29,10 +29,6 @@ set(framework_SRCS GameSystem/TypeContainerFunctions.h GameSystem/TypeContainerFunctionsPtr.h GameSystem/TypeContainerVisitor.h - Network/MangosSocket.h - Network/MangosSocketImpl.h - Network/MangosSocketMgr.h - Network/MangosSocketMgrImpl.h Platform/CompilerDefs.h Platform/Define.h Policies/CreationPolicy.h @@ -40,6 +36,7 @@ set(framework_SRCS Policies/Singleton.h Policies/SingletonImp.h Policies/ThreadingModel.h + Policies/ObjectConstructorTraits.h Utilities/ByteConverter.h Utilities/EventProcessor.h Utilities/EventMap.h @@ -86,7 +83,6 @@ source_group("Dynamic" include_directories( ${CMAKE_CURRENT_SOURCE_DIR} ${TBB_INCLUDE_DIRS} - ${ACE_INCLUDE_DIR} ${CMAKE_SOURCE_DIR}/src/shared ) diff --git a/src/framework/Network/MangosSocket.h b/src/framework/Network/MangosSocket.h deleted file mode 100644 index fc969398c48..00000000000 --- a/src/framework/Network/MangosSocket.h +++ /dev/null @@ -1,232 +0,0 @@ -#ifndef MANGOSSOCKET_H -#define MANGOSSOCKET_H - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#if !defined (ACE_LACKS_PRAGMA_ONCE) -#pragma once -#endif /* ACE_LACKS_PRAGMA_ONCE */ - -#include "Common.h" - -class ACE_Message_Block; -class WorldPacket; -class WorldSession; - - -#if defined( __GNUC__ ) -#pragma pack(1) -#else -#pragma pack(push,1) -#endif - -struct ServerPktHeader -{ - uint16 size; - uint16 cmd; -}; - -struct ClientPktHeader -{ - uint16 size; - uint32 cmd; -}; - -#if defined( __GNUC__ ) -#pragma pack() -#else -#pragma pack(pop) -#endif - -// Handler that can communicate over stream sockets. -typedef ACE_Svc_Handler WorldHandler; - -/** - * MangosSocket. - * - * This class is responsible for the communication with - * remote clients. - * Most methods return -1 on failure. - * The class uses reference counting. - * - * For output the class uses one buffer (64K usually) and - * a queue where it stores packet if there is no place on - * the queue. The reason this is done, is because the server - * does really a lot of small-size writes to it, and it doesn't - * scale well to allocate memory for every. When something is - * written to the output buffer the socket is not immediately - * activated for output (again for the same reason), there - * is 10ms celling (thats why there is Update() method). - * This concept is similar to TCP_CORK, but TCP_CORK - * uses 200ms celling. As result overhead generated by - * sending packets from "producer" threads is minimal, - * and doing a lot of writes with small size is tolerated. - * - * The calls to Update () method are managed by WorldSocketMgr - * and ReactorRunnable. - * - * For input ,the class uses one 1024 bytes buffer on stack - * to which it does recv() calls. And then received data is - * distributed where its needed. 1024 matches pretty well the - * traffic generated by client for now. - * - * The input/output do speculative reads/writes (AKA it tryes - * to read all data available in the kernel buffer or tryes to - * write everything available in userspace buffer), - * which is ok for using with Level and Edge Triggered IO - * notification. - * - */ -template -class MangosSocket : public WorldHandler -{ - public: - // things called by ACE framework. - MangosSocket(); - virtual ~MangosSocket(void); - - // Declare the acceptor for this class - typedef ACE_Connector Connector; - // Declare some friends - friend class ACE_Connector; - friend class ACE_NonBlocking_Connect_Handler; - - // Mutex type used for various synchronizations. - using LockType = std::mutex; - typedef std::unique_lock GuardType; - - // Queue for storing packets for which there is no space. - typedef ACE_Unbounded_Queue PacketQueueT; - - // Check if socket is closed. - bool IsClosed() const { return closing_; } - - // Close the socket. - void CloseSocket (void); - - // Called on open ,the void* is the acceptor. - virtual int open(void *); - - // Called on failures inside of the acceptor, don't call from your code. - virtual int close(int); - - // Get address of connected peer. - const std::string& GetRemoteAddress () const { return m_Address; } - - // Send A packet on the socket, this function is reentrant. - // @param pct packet to send - // @return -1 of failure - int SendPacket (const WorldPacket& pct); - - // Add reference to this object. - long AddReference() { return static_cast(add_reference()); } - - // Remove reference to this object. - long RemoveReference() { return static_cast(remove_reference()); } - - void SetSession(SessionType* t) { m_Session = t; } - void SetClientSocket() { m_isServerSocket = false; } - /** - * @brief returns true iif the socket is connected TO a client (ie we are the server) - */ - bool IsServerSide() { return m_isServerSocket; } - protected: - // process one incoming packet. - // @param new_pct received packet ,note that you need to delete it. - int ProcessIncoming (WorldPacket* new_pct) { delete new_pct; return 0; } - int OnSocketOpen() { return 0; } - - // Called when we can read from the socket. - virtual int handle_input (ACE_HANDLE = ACE_INVALID_HANDLE); - - // Called when the socket can write. - virtual int handle_output (ACE_HANDLE = ACE_INVALID_HANDLE); - - // Called when connection is closed or error happens. - virtual int handle_close (ACE_HANDLE = ACE_INVALID_HANDLE, - ACE_Reactor_Mask = ACE_Event_Handler::ALL_EVENTS_MASK); - - // Called by WorldSocketMgr/ReactorRunnable. - int Update (void); - - // Helper functions for processing incoming data. - int handle_input_header (void); - int handle_input_payload (void); - int handle_input_missing_data (void); - - // Help functions to mark/unmark the socket for output. - // @param g the guard is for m_OutBufferLock, the function will release it - int cancel_wakeup_output (GuardType& g); - int schedule_wakeup_output (GuardType& g); - - // Try to write WorldPacket to m_OutBuffer ,return -1 if no space - // Need to be called with m_OutBufferLock lock held - int iSendPacket (const WorldPacket& pct); - - // Flush m_PacketQueue if there are packets in it - // Need to be called with m_OutBufferLock lock held - // @return true if it wrote to the buffer ( AKA you need - // to mark the socket for output ). - bool iFlushPacketQueue (); - - // Time in which the last ping was received - ACE_Time_Value m_LastPingTime; - - // Keep track of over-speed pings ,to prevent ping flood. - uint32 m_OverSpeedPings; - - // Address of the remote peer - std::string m_Address; - - // Class used for managing encryption of the headers - Crypt m_Crypt; - - // Mutex lock to protect m_Session - LockType m_SessionLock; - - // Session to which received packets are routed - SessionType* m_Session; - - // here are stored the fragments of the received data - WorldPacket* m_RecvWPct; - - // This block actually refers to m_RecvWPct contents, - // which allows easy and safe writing to it. - // It wont free memory when its deleted. m_RecvWPct takes care of freeing. - ACE_Message_Block m_RecvPct; - - // Fragment of the received header. - ACE_Message_Block m_Header; - - // Mutex for protecting output related data. - LockType m_OutBufferLock; - - // Buffer used for writing output. - ACE_Message_Block *m_OutBuffer; - - // Size of the m_OutBuffer. - size_t m_OutBufferSize; - - // Here are stored packets for which there was no space on m_OutBuffer, - // this allows not-to kick player if its buffer is overflowed. - PacketQueueT m_PacketQueue; - - // True if the socket is registered with the reactor for output - bool m_OutActive; - - uint32 m_Seed; - - bool m_isServerSocket; -}; - -#endif // MANGOSSOCKET_H diff --git a/src/framework/Network/MangosSocketImpl.h b/src/framework/Network/MangosSocketImpl.h deleted file mode 100644 index 85444bbc2f7..00000000000 --- a/src/framework/Network/MangosSocketImpl.h +++ /dev/null @@ -1,543 +0,0 @@ -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "MangosSocket.h" -#include "Common.h" - -#include "Util.h" -#include "WorldPacket.h" -#include "SharedDefines.h" -#include "ByteBuffer.h" -#include "Database/DatabaseEnv.h" -#include "WorldSession.h" -#include "Log.h" -#include "DBCStores.h" - - -template -MangosSocket::MangosSocket() : - WorldHandler(), - m_LastPingTime(ACE_Time_Value::zero), - m_OverSpeedPings(0), - m_Session(0), - m_RecvWPct(0), - m_RecvPct(), - m_Header(sizeof(ClientPktHeader)), - m_OutBuffer(0), - m_OutBufferSize(65536), - m_OutActive(false), - m_Seed(static_cast(rand32())), - m_isServerSocket(true) -{ - reference_counting_policy().value(ACE_Event_Handler::Reference_Counting_Policy::ENABLED); -} - -template -MangosSocket::~MangosSocket(void) -{ - delete m_RecvWPct; - - if (m_OutBuffer) - m_OutBuffer->release(); - - closing_ = true; - - peer().close(); - - WorldPacket* pct; - while (m_PacketQueue.dequeue_head(pct) == 0) - delete pct; -} - -template -void MangosSocket::CloseSocket(void) -{ - { - GuardType lock(m_OutBufferLock); - - if (closing_) - return; - - closing_ = true; - peer().close_writer(); - } - - { - GuardType lock(m_SessionLock); - - m_Session = nullptr; - } -} - -template -int MangosSocket::SendPacket(const WorldPacket& pct) -{ - GuardType lock(m_OutBufferLock); - - if (closing_) - return -1; - - if (((SocketName*)this)->iSendPacket(pct) == -1) - { - WorldPacket* npct; - - ACE_NEW_RETURN(npct, WorldPacket(pct), -1); - - // NOTE maybe check of the size of the queue can be good ? - // to make it bounded instead of unbounded - if (m_PacketQueue.enqueue_tail(npct) == -1) - { - delete npct; - sLog.Out(LOG_BASIC, LOG_LVL_ERROR, "MangosSocket::SendPacket: m_PacketQueue.enqueue_tail failed"); - return -1; - } - } - - return 0; -} - -template -int MangosSocket::open(void *a) -{ - ACE_UNUSED_ARG(a); - - // Prevent double call to this func. - if (m_OutBuffer) - return -1; - - // This will also prevent the socket from being Updated - // while we are initializing it. - m_OutActive = true; - - // Hook for the manager. - if (((SocketName*)this)->OnSocketOpen() == -1) - return -1; - - // Allocate the buffer. - ACE_NEW_RETURN(m_OutBuffer, ACE_Message_Block(m_OutBufferSize), -1); - - // Store peer address. - ACE_INET_Addr remote_addr; - - if (peer().get_remote_addr(remote_addr) == -1) - { - sLog.Out(LOG_BASIC, LOG_LVL_ERROR, "WorldSocket::open: peer ().get_remote_addr errno = %s", ACE_OS::strerror(errno)); - return -1; - } - - m_Address = remote_addr.get_host_addr(); - - if (((SocketName*)this)->SendStartupPacket() == -1) - return -1; - - // Register with ACE Reactor - if (reactor()->register_handler(this, ACE_Event_Handler::READ_MASK | ACE_Event_Handler::WRITE_MASK) == -1) - { - sLog.Out(LOG_BASIC, LOG_LVL_ERROR, "WorldSocket::open: unable to register client handler errno = %s", ACE_OS::strerror(errno)); - return -1; - } - - // reactor takes care of the socket from now on - remove_reference(); - - return 0; -} - -template -int MangosSocket::close(int) -{ - shutdown(); - - closing_ = true; - - remove_reference(); - - return 0; -} - -template -int MangosSocket::handle_input(ACE_HANDLE) -{ - if (closing_) - return -1; - - switch (handle_input_missing_data()) - { - case -1 : - { - if ((errno == EWOULDBLOCK) || - (errno == EAGAIN)) - { - return Update(); // interesting line ,isn't it ? - } - - sLog.Out(LOG_BASIC, LOG_LVL_DEBUG, "WorldSocket::handle_input: Peer error closing connection errno = %s", ACE_OS::strerror(errno)); - - errno = ECONNRESET; - return -1; - } - case 0: - { - sLog.Out(LOG_BASIC, LOG_LVL_DEBUG, "WorldSocket::handle_input: Peer has closed connection"); - - errno = ECONNRESET; - return -1; - } - case 1: - return 1; - default: - return Update(); // another interesting line ;) - } - - ACE_NOTREACHED(return -1); -} - -template -int MangosSocket::handle_output(ACE_HANDLE) -{ - GuardType lock(m_OutBufferLock); - - if (closing_) - return -1; - - const size_t send_len = m_OutBuffer->length(); - - if (send_len == 0) - return cancel_wakeup_output(lock); - -#ifdef MSG_NOSIGNAL - ssize_t n = peer().send(m_OutBuffer->rd_ptr(), send_len, MSG_NOSIGNAL); -#else - ssize_t n = peer().send(m_OutBuffer->rd_ptr(), send_len); -#endif // MSG_NOSIGNAL - - if (n == 0) - return -1; - else if (n == -1) - { -#ifdef _WIN32 - if (WSAGetLastError() == WSAEWOULDBLOCK) - return schedule_wakeup_output(lock); -#endif - - if (errno == EWOULDBLOCK || errno == EAGAIN) - return schedule_wakeup_output(lock); - - return -1; - } - else if (n < (ssize_t)send_len) //now n > 0 - { - m_OutBuffer->rd_ptr(static_cast(n)); - - // move the data to the base of the buffer - m_OutBuffer->crunch(); - - return schedule_wakeup_output(lock); - } - else //now n == send_len - { - m_OutBuffer->reset(); - - if (!iFlushPacketQueue()) - return cancel_wakeup_output(lock); - else - return schedule_wakeup_output(lock); - } - - ACE_NOTREACHED(return 0); -} - -template -int MangosSocket::handle_close(ACE_HANDLE h, ACE_Reactor_Mask) -{ - // Critical section - { - GuardType lock(m_OutBufferLock); - - closing_ = true; - - if (h == ACE_INVALID_HANDLE) - peer().close_writer(); - } - - // Critical section - { - GuardType lock(m_SessionLock); - - m_Session = nullptr; - } - - reactor()->remove_handler(this, ACE_Event_Handler::DONT_CALL | ACE_Event_Handler::ALL_EVENTS_MASK); - return 0; -} - -template -int MangosSocket::Update(void) -{ - if (closing_) - return -1; - - if (m_OutActive || m_OutBuffer->length() == 0) - return 0; - - return handle_output(get_handle()); -} - -template -int MangosSocket::handle_input_header(void) -{ - MANGOS_ASSERT(m_RecvWPct == nullptr); - - MANGOS_ASSERT(m_Header.length() == sizeof(ClientPktHeader)); - - m_Crypt.DecryptRecv((uint8*) m_Header.rd_ptr(), sizeof(ClientPktHeader)); - - ClientPktHeader& header = *((ClientPktHeader*) m_Header.rd_ptr()); - - EndianConvertReverse(header.size); - EndianConvert(header.cmd); - - if ((header.size < 4) || (header.size > 10240) || (header.cmd > 10240)) - { - sLog.Out(LOG_BASIC, LOG_LVL_ERROR, "WorldSocket::handle_input_header: client %s sent malformed packet size = %d , cmd = %d", - GetRemoteAddress().c_str(), header.size, header.cmd); - - errno = EINVAL; - return -1; - } - - header.size -= 4; - - ACE_NEW_RETURN(m_RecvWPct, WorldPacket((uint16) header.cmd, header.size), -1); - - if (header.size > 0) - { - m_RecvWPct->resize(header.size); - m_RecvPct.base((char*) m_RecvWPct->contents(), m_RecvWPct->size()); - } - else - MANGOS_ASSERT(m_RecvPct.space() == 0); - - return 0; -} - -template -int MangosSocket::handle_input_payload(void) -{ - // set errno properly here on error !!! - // now have a header and payload - - MANGOS_ASSERT(m_RecvPct.space() == 0); - MANGOS_ASSERT(m_Header.space() == 0); - MANGOS_ASSERT(m_RecvWPct != nullptr); - - const int ret = ((SocketName*)this)->ProcessIncoming(m_RecvWPct); - - m_RecvPct.base(nullptr, 0); - m_RecvPct.reset(); - m_RecvWPct = nullptr; - - m_Header.reset(); - - if (ret == -1) - errno = EINVAL; - - return ret; -} - -template -int MangosSocket::handle_input_missing_data(void) -{ - char buf [4096]; - - ACE_Data_Block db(sizeof(buf), - ACE_Message_Block::MB_DATA, - buf, - 0, - 0, - ACE_Message_Block::DONT_DELETE, - 0); - - ACE_Message_Block message_block(&db, - ACE_Message_Block::DONT_DELETE, - 0); - - const size_t recv_size = message_block.space(); - - const ssize_t n = peer().recv(message_block.wr_ptr(), - recv_size); - - if (n <= 0) - return (int)n; - - message_block.wr_ptr(n); - - while (message_block.length() > 0) - { - if (m_Header.space() > 0) - { - //need to receive the header - const size_t to_header = (message_block.length() > m_Header.space() ? m_Header.space() : message_block.length()); - m_Header.copy(message_block.rd_ptr(), to_header); - message_block.rd_ptr(to_header); - - if (m_Header.space() > 0) - { - // Couldn't receive the whole header this time. - MANGOS_ASSERT(message_block.length() == 0); - errno = EWOULDBLOCK; - return -1; - } - - // We just received nice new header - if (handle_input_header() == -1) - { - MANGOS_ASSERT((errno != EWOULDBLOCK) && (errno != EAGAIN)); - return -1; - } - } - - // Its possible on some error situations that this happens - // for example on closing when epoll receives more chunked data and stuff - // hope this is not hack ,as proper m_RecvWPct is asserted around - if (!m_RecvWPct) - { - sLog.Out(LOG_BASIC, LOG_LVL_ERROR, "Forcing close on input m_RecvWPct = nullptr"); - errno = EINVAL; - return -1; - } - - // We have full read header, now check the data payload - if (m_RecvPct.space() > 0) - { - //need more data in the payload - const size_t to_data = (message_block.length() > m_RecvPct.space() ? m_RecvPct.space() : message_block.length()); - m_RecvPct.copy(message_block.rd_ptr(), to_data); - message_block.rd_ptr(to_data); - - if (m_RecvPct.space() > 0) - { - // Couldn't receive the whole data this time. - MANGOS_ASSERT(message_block.length() == 0); - errno = EWOULDBLOCK; - return -1; - } - } - - //just received fresh new payload - if (handle_input_payload() == -1) - { - MANGOS_ASSERT((errno != EWOULDBLOCK) && (errno != EAGAIN)); - return -1; - } - } - - return size_t(n) == recv_size ? 1 : 2; -} - -template -int MangosSocket::cancel_wakeup_output(GuardType& g) -{ - if (!m_OutActive) - return 0; - - m_OutActive = false; - - g.unlock(); - - if (reactor()->cancel_wakeup - (this, ACE_Event_Handler::WRITE_MASK) == -1) - { - // would be good to store errno from reactor with errno guard - sLog.Out(LOG_BASIC, LOG_LVL_ERROR, "MangosSocket::cancel_wakeup_output"); - return -1; - } - - return 0; -} - -template -int MangosSocket::schedule_wakeup_output(GuardType& g) -{ - if (m_OutActive) - return 0; - - m_OutActive = true; - - g.unlock(); - - if (reactor()->schedule_wakeup - (this, ACE_Event_Handler::WRITE_MASK) == -1) - { - sLog.Out(LOG_BASIC, LOG_LVL_ERROR, "MangosSocket::schedule_wakeup_output"); - return -1; - } - - return 0; -} - -template -int MangosSocket::iSendPacket(const WorldPacket& pct) -{ - if (m_OutBuffer->space() < pct.size() + sizeof(ServerPktHeader)) - { - errno = ENOBUFS; - return -1; - } - - ServerPktHeader header; - - header.cmd = pct.GetOpcode(); - - header.size = (uint16) pct.size() + 2; - - EndianConvertReverse(header.size); - EndianConvert(header.cmd); - - m_Crypt.EncryptSend((uint8*) & header, sizeof(header)); - - if (m_OutBuffer->copy((char*) & header, sizeof(header)) == -1) - ACE_ASSERT(false); - - if (!pct.empty()) - if (m_OutBuffer->copy((char*) pct.contents(), pct.size()) == -1) - ACE_ASSERT(false); - - return 0; -} - -template -bool MangosSocket::iFlushPacketQueue() -{ - WorldPacket *pct; - bool haveone = false; - - while (m_PacketQueue.dequeue_head(pct) == 0) - { - if (((SocketName*)this)->iSendPacket(*pct) == -1) - { - if (m_PacketQueue.enqueue_head(pct) == -1) - { - delete pct; - sLog.Out(LOG_BASIC, LOG_LVL_ERROR, "MangosSocket::iFlushPacketQueue m_PacketQueue->enqueue_head"); - return false; - } - - break; - } - else - { - haveone = true; - delete pct; - } - } - - return haveone; -} diff --git a/src/framework/Network/MangosSocketMgr.h b/src/framework/Network/MangosSocketMgr.h deleted file mode 100644 index b90fbb85a36..00000000000 --- a/src/framework/Network/MangosSocketMgr.h +++ /dev/null @@ -1,58 +0,0 @@ -#ifndef MANGOSSOCKETMGR_H -#define MANGOSSOCKETMGR_H - -#include - -#include - -template -class MangosSocketAcceptor; -template -class ReactorRunnable; -class ACE_Event_Handler; - -// Manages all sockets connected to peers and network threads -template -class MangosSocketMgr -{ - public: - - // Start network, listen at address:port . - int StartNetwork(ACE_UINT16 port, std::string& address); - - // Stops all network threads, It will wait for all running threads . - void StopNetwork(); - - // Wait untill all network threads have "joined" . - void Wait(); - - void SetOutKBuff(int v) { m_SockOutKBuff = v; } - void SetOutUBuff(int v) { m_SockOutUBuff = v; } - void SetThreads(int v) { m_NetThreadsCount = v; } - void SetTcpNodelay(bool v) { m_UseNoDelay = v; } - void SetInterval(int v) { m_Interval = v * 1000; /* to microseconds */ } - - int Connect(int port, std::string const& address, SocketType*& sock); - protected: - int OnSocketOpen(SocketType* sock); - int StartReactiveIO(ACE_UINT16 port, const char* address); - int StartThreadsIfNeeded(); - - MangosSocketMgr(); - ~MangosSocketMgr(); - - ReactorRunnable* m_NetThreads; - size_t m_NetThreadsCount; - - int m_SockOutKBuff; - int m_SockOutUBuff; - bool m_UseNoDelay; - int m_Interval; - - std::string m_addr; - ACE_UINT16 m_port; - - MangosSocketAcceptor* m_Acceptor; -}; - -#endif // MANGOSSOCKETMGR_H diff --git a/src/framework/Network/MangosSocketMgrImpl.h b/src/framework/Network/MangosSocketMgrImpl.h deleted file mode 100644 index 843daf5f9c3..00000000000 --- a/src/framework/Network/MangosSocketMgrImpl.h +++ /dev/null @@ -1,385 +0,0 @@ -#include "MangosSocketMgr.h" - - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include -#include - -#include "Log.h" -#include "Common.h" -#include "Config/Config.h" -#include "Database/DatabaseEnv.h" - -template -class MangosSocketAcceptor : public ACE_Acceptor -{ -public: - MangosSocketAcceptor(void) { } - virtual ~MangosSocketAcceptor(void) - { - if (this->reactor()) - this->reactor()->cancel_timer(this, 1); - } - -protected: - - virtual int handle_timeout(const ACE_Time_Value ¤t_time, const void *act = 0) - { - sLog.Out(LOG_BASIC, LOG_LVL_BASIC, "Resuming acceptor"); - this->reactor()->cancel_timer(this, 1); - return this->reactor()->register_handler(this, ACE_Event_Handler::ACCEPT_MASK); - } - - virtual int handle_accept_error(void) - { -#if defined(ENFILE) && defined(EMFILE) - if (errno == ENFILE || errno == EMFILE) - { - sLog.Out(LOG_BASIC, LOG_LVL_ERROR, "Out of file descriptors, suspending incoming connections for 10 seconds"); - this->reactor()->remove_handler(this, ACE_Event_Handler::ACCEPT_MASK | ACE_Event_Handler::DONT_CALL); - this->reactor()->schedule_timer(this, NULL, ACE_Time_Value(10)); - } -#endif - return 0; - } -}; - -/** -* This is a helper class to WorldSocketMgr ,that manages -* network threads, and assigning connections from acceptor thread -* to other network threads -*/ -template -class ReactorRunnable : protected ACE_Task_Base -{ -public: - ReactorRunnable() : - m_Reactor(0), - m_Connections(0), - m_ThreadId(-1), - m_Interval(0) - { - ACE_Reactor_Impl* imp = 0; - -#if defined (ACE_HAS_EVENT_POLL) || defined (ACE_HAS_DEV_POLL) - - imp = new ACE_Dev_Poll_Reactor(); - - imp->max_notify_iterations(128); - imp->restart(1); - -#else - - imp = new ACE_TP_Reactor(); - imp->max_notify_iterations(128); - -#endif - - m_Reactor = new ACE_Reactor(imp, 1); - } - - virtual ~ReactorRunnable() - { - Stop(); - Wait(); - - delete m_Reactor; - } - - void Stop() - { - m_Reactor->end_reactor_event_loop(); - } - - int Start(int interval) - { - m_Interval = interval; - - if (m_ThreadId != -1) - return -1; - - return (m_ThreadId = activate()); - } - - void Wait() - { - ACE_Task_Base::wait(); - } - - long Connections() - { - return m_Connections; - } - - int AddSocket(SocketType* sock) - { - std::unique_lock lock(m_NewSockets_Lock); - - ++m_Connections; - sock->AddReference(); - sock->reactor(m_Reactor); - m_NewSockets.insert(sock); - - return 0; - } - - ACE_Reactor* GetReactor() - { - return m_Reactor; - } - -protected: - void AddNewSockets() - { - std::unique_lock lock(m_NewSockets_Lock); - - if (m_NewSockets.empty()) - return; - - for (typename SocketSet::const_iterator i = m_NewSockets.begin(); i != m_NewSockets.end(); ++i) - { - SocketType* sock = (*i); - - if (sock->IsClosed()) - { - sock->RemoveReference(); - --m_Connections; - } - else - m_Sockets.insert(sock); - } - - m_NewSockets.clear(); - } - - virtual int svc() - { - sLog.Out(LOG_BASIC, LOG_LVL_DEBUG, "Network Thread Starting"); - - WorldDatabase.ThreadStart(); - - MANGOS_ASSERT(m_Reactor); - - typename SocketSet::iterator i, t; - - while (!m_Reactor->reactor_event_loop_done()) - { - // dont be too smart to move this outside the loop - // the run_reactor_event_loop will modify interval - ACE_Time_Value interval(0, m_Interval); - - if (m_Reactor->run_reactor_event_loop(interval) == -1) - break; - - AddNewSockets(); - - for (i = m_Sockets.begin(); i != m_Sockets.end();) - { - if ((*i)->Update() == -1) - { - t = i; - ++i; - (*t)->CloseSocket(); - (*t)->RemoveReference(); - --m_Connections; - m_Sockets.erase(t); - } - else - ++i; - } - } - - WorldDatabase.ThreadEnd(); - - sLog.Out(LOG_BASIC, LOG_LVL_DEBUG, "Network Thread Exitting"); - - return 0; - } - -private: - using AtomicInt = std::atomic; - typedef std::set SocketSet; - - ACE_Reactor* m_Reactor; - AtomicInt m_Connections; - int m_ThreadId; - int m_Interval; - - SocketSet m_Sockets; - - SocketSet m_NewSockets; - std::mutex m_NewSockets_Lock; -}; - -template -MangosSocketMgr::MangosSocketMgr(): - m_NetThreads(0), - m_NetThreadsCount(0), - m_SockOutKBuff(-1), - m_SockOutUBuff(65536), - m_UseNoDelay(true), - m_Interval(10000), - m_port(0), - m_Acceptor(0) -{ -} - -template -MangosSocketMgr::~MangosSocketMgr() -{ - Wait(); - - delete [] m_NetThreads; - delete m_Acceptor; -} - -template -int MangosSocketMgr::StartThreadsIfNeeded() -{ - if (m_NetThreads) - return 0; - m_NetThreads = new ReactorRunnable[m_NetThreadsCount]; - for (size_t i = 0; i < m_NetThreadsCount; ++i) - m_NetThreads[i].Start(m_Interval); - return 0; -} - -template -int MangosSocketMgr::StartReactiveIO(ACE_UINT16 port, const char* address) -{ - if (StartThreadsIfNeeded() == -1) - return -1; - - sLog.Out(LOG_BASIC, LOG_LVL_BASIC, "Max allowed socket connections %d", ACE::max_handles()); - - if (m_SockOutUBuff <= 0) - { - sLog.Out(LOG_BASIC, LOG_LVL_ERROR, "Network.OutUBuff is wrong in your config file"); - return -1; - } - - m_Acceptor = new MangosSocketAcceptor(); - - ACE_INET_Addr listen_addr(port, address); - - if (m_Acceptor->open(listen_addr, m_NetThreads[0].GetReactor(), ACE_NONBLOCK) == -1) - { - sLog.Out(LOG_BASIC, LOG_LVL_ERROR, "Failed to open acceptor, check if the port is free"); - return -1; - } - - return 0; -} - -template -int MangosSocketMgr::StartNetwork(ACE_UINT16 port, std::string& address) -{ - if (!sLog.HasLogLevelOrHigher(LOG_LVL_DEBUG)) - ACE_Log_Msg::instance()->priority_mask(LM_ERROR, ACE_Log_Msg::PROCESS); - - if (StartReactiveIO(port, address.c_str()) == -1) - return -1; - - return 0; -} - -template -void MangosSocketMgr::StopNetwork() -{ - if (m_Acceptor) - m_Acceptor->close(); - - if (m_NetThreadsCount != 0) - { - for (size_t i = 0; i < m_NetThreadsCount; ++i) - m_NetThreads[i].Stop(); - } -} - -template -void MangosSocketMgr::Wait() -{ - if (m_NetThreadsCount != 0) - { - for (size_t i = 0; i < m_NetThreadsCount; ++i) - m_NetThreads[i].Wait(); - } -} - -template -int MangosSocketMgr::OnSocketOpen(SocketType* sock) -{ - // set some options here - if (m_SockOutKBuff >= 0) - { - if (sock->peer().set_option(SOL_SOCKET, SO_SNDBUF, (void*)&m_SockOutKBuff, sizeof(int)) == -1 && errno != ENOTSUP) - { - sLog.Out(LOG_BASIC, LOG_LVL_ERROR, "MangosSocketMgr::OnSocketOpen set_option SO_SNDBUF"); - return -1; - } - } - - static const int ndoption = 1; - - // Set TCP_NODELAY. - if (m_UseNoDelay) - { - if (sock->peer().set_option(ACE_IPPROTO_TCP, TCP_NODELAY, (void*)&ndoption, sizeof(int)) == -1) - { - sLog.Out(LOG_BASIC, LOG_LVL_ERROR, "MangosSocketMgr::OnSocketOpen: peer().set_option TCP_NODELAY errno = %s", ACE_OS::strerror(errno)); - return -1; - } - } - - sock->m_OutBufferSize = static_cast(m_SockOutUBuff); - - // we skip the Acceptor Thread - size_t min = 1; - - MANGOS_ASSERT(m_NetThreadsCount >= 1); - - for (size_t i = 1; i < m_NetThreadsCount; ++i) - if (m_NetThreads[i].Connections() < m_NetThreads[min].Connections()) - min = i; - - return m_NetThreads[min].AddSocket(sock); -} - -template -int MangosSocketMgr::Connect(int port, std::string const& address, SocketType*& handler) -{ - if (StartThreadsIfNeeded() == -1) - return -1; - - ACE_INET_Addr addr(port, address.c_str()); - handler = new SocketType(); - handler->SetClientSocket(); - - // Create the connector - typename SocketType::Connector connector; - - //Connects to remote machine - if (connector.connect(handler,addr) == -1) - { - // Handler is already deleted. - handler = nullptr; - return -1; - } - // Now add a reactor so our connnection gets updated - OnSocketOpen(handler); - - return 0; -} diff --git a/src/framework/Platform/Define.h b/src/framework/Platform/Define.h index 6145c1f365a..8061a1becfb 100644 --- a/src/framework/Platform/Define.h +++ b/src/framework/Platform/Define.h @@ -22,14 +22,28 @@ #ifndef MANGOS_DEFINE_H #define MANGOS_DEFINE_H -#include +#include "Platform/CompilerDefs.h" +#include -#include -#include -#include -#include +#if PLATFORM == PLATFORM_WINDOWS +// Unfortunately, every library (e.g. MySQL, G3D) includes in their **HEADER** +// and will break parts of the code, since Windows adds so many marcos and stuff. +// So if we can't beat them, join them! +// We include Windows.h first and remove all the conflicting definitions after. -#include "Platform/CompilerDefs.h" +#ifndef WIN32_LEAN_AND_MEAN +#define WIN32_LEAN_AND_MEAN +#endif +#ifndef NOMINMAX +#define NOMINMAX +#endif +#include +// Now we need to remove all the stuff that comes included with Windows headers and will most probably cause errors... +#undef WIN32_LEAN_AND_MEAN +#undef NOMINMAX +#undef ERROR +#undef IGNORE +#endif // WINDOWS #define MANGOS_LITTLEENDIAN 0 #define MANGOS_BIGENDIAN 1 @@ -42,8 +56,6 @@ # endif //ACE_BYTE_ORDER #endif //MANGOS_ENDIAN -#define MANGOS_PATH_MAX PATH_MAX // ace/os_include/os_limits.h -> ace/Basic_Types.h - #if PLATFORM == PLATFORM_WINDOWS # ifndef DECLSPEC_NORETURN # define DECLSPEC_NORETURN __declspec(noreturn) @@ -60,14 +72,14 @@ # define ATTR_PRINTF(F,V) #endif //COMPILER == COMPILER_GNU -typedef ACE_INT64 int64; -typedef ACE_INT32 int32; -typedef ACE_INT16 int16; -typedef ACE_INT8 int8; -typedef ACE_UINT64 uint64; -typedef ACE_UINT32 uint32; -typedef ACE_UINT16 uint16; -typedef ACE_UINT8 uint8; +typedef std::int64_t int64; +typedef std::int32_t int32; +typedef std::int16_t int16; +typedef std::int8_t int8; +typedef std::uint64_t uint64; +typedef std::uint32_t uint32; +typedef std::uint16_t uint16; +typedef std::uint8_t uint8; #ifndef _WIN32 typedef uint16 WORD; diff --git a/src/framework/Policies/ObjectConstructorTraits.h b/src/framework/Policies/ObjectConstructorTraits.h new file mode 100644 index 00000000000..8a349f496b4 --- /dev/null +++ b/src/framework/Policies/ObjectConstructorTraits.h @@ -0,0 +1,37 @@ +#ifndef MANGOS_POLICIES_CONSTRUCTOR_TRAITS_H +#define MANGOS_POLICIES_CONSTRUCTOR_TRAITS_H + +namespace MaNGOS { namespace Policies +{ + /// Disallows the copy constructor and **allows** the move constructor + struct NoCopyButAllowMove { + protected: + NoCopyButAllowMove() = default; + ~NoCopyButAllowMove() = default; + public: + // remove copy + NoCopyButAllowMove(NoCopyButAllowMove const&) = delete; + NoCopyButAllowMove& operator=(NoCopyButAllowMove const&) = delete; + + // but allow move + NoCopyButAllowMove(NoCopyButAllowMove&&) = default; + NoCopyButAllowMove& operator=(NoCopyButAllowMove&&) = default; + }; + + /// Disallows the copy constructor and also disallows the move constructor + struct NoCopyNoMove { + protected: + NoCopyNoMove() = default; + ~NoCopyNoMove() = default; + public: + // remove copy + NoCopyNoMove(NoCopyNoMove const&) = delete; + NoCopyNoMove& operator=(NoCopyNoMove const&) = delete; + + // remove move + NoCopyNoMove(NoCopyNoMove&&) = delete; + NoCopyNoMove& operator=(NoCopyNoMove&&) = delete; + }; +}} // namespace MaNGOS::Policies + +#endif // MANGOS_POLICIES_CONSTRUCTOR_TRAITS_H diff --git a/src/framework/Utilities/EventProcessor.cpp b/src/framework/Utilities/EventProcessor.cpp index 050abb99821..e52505db1da 100644 --- a/src/framework/Utilities/EventProcessor.cpp +++ b/src/framework/Utilities/EventProcessor.cpp @@ -19,7 +19,8 @@ */ #include "EventProcessor.h" -#include "Log.h" // Zerix: For MANGOS_ASSERT. No idea. +#include "Log.h" +#include "Errors.h" void BasicEvent::ScheduleAbort() { diff --git a/src/game/Anticheat/Anticheat.cpp b/src/game/Anticheat/Anticheat.cpp index 0dde3f9e0f6..043e6583a61 100644 --- a/src/game/Anticheat/Anticheat.cpp +++ b/src/game/Anticheat/Anticheat.cpp @@ -15,6 +15,7 @@ */ #include "Anticheat.h" +#include "IO/Multithreading/CreateThread.h" AnticheatManager* AnticheatManager::instance() { @@ -93,7 +94,7 @@ Warden* AnticheatManager::CreateWardenFor(WorldSession* client, BigNumber* K) void AnticheatManager::StartWardenUpdateThread() { - m_wardenUpdateThread = std::thread(&AnticheatManager::UpdateWardenSessions, this); + m_wardenUpdateThread = IO::Multithreading::CreateThread("WardenSessions", [this]() { UpdateWardenSessions(); }); } void AnticheatManager::StopWardenUpdateThread() diff --git a/src/game/Anticheat/WardenAnticheat/WardenModuleMgr.cpp b/src/game/Anticheat/WardenAnticheat/WardenModuleMgr.cpp index 8a9884e15c6..10b41498f1d 100644 --- a/src/game/Anticheat/WardenAnticheat/WardenModuleMgr.cpp +++ b/src/game/Anticheat/WardenAnticheat/WardenModuleMgr.cpp @@ -30,7 +30,7 @@ #include "World.h" #include "Log.h" -#include +#include "IO/Filesystem/FileSystem.h" #include #include @@ -38,31 +38,16 @@ INSTANTIATE_SINGLETON_1(WardenModuleMgr); -namespace -{ std::vector GetModuleNames(std::string const& moduleDir) { - ACE_DIR* dirp = ACE_OS::opendir(ACE_TEXT(moduleDir.c_str())); - - std::vector results; - - if (dirp) - { - ACE_DIRENT* dp; + // Get all the files in warden folder, might also include ".cr" or ".key" files + std::vector result = IO::Filesystem::GetAllFilesInFolder(moduleDir, IO::Filesystem::OutputFilePath::FullFilePath); - // look only for .bin files, and assume (for now) that the corresponding .key and .cr files exist - while (!!(dp = ACE_OS::readdir(dirp))) - if (!memcmp(&dp->d_name[strlen(dp->d_name) - 4], ".bin", 4)) - results.emplace_back(moduleDir + "/" + dp->d_name); + // Remove all elements that don't end with ".bin" + std::function MustEndWithBin = [](const std::string &s) { return s.size() < 4 || s.substr(s.size() - 4) != ".bin"; }; + result.erase(std::remove_if(result.begin(), result.end(), MustEndWithBin), result.end()); -#ifndef _WIN32 - // this causes a crash on Windows, so just accept a minor memory leak for now - ACE_OS::closedir(dirp); -#endif - } - - return results; -} + return result; } WardenModuleMgr::WardenModuleMgr() diff --git a/src/game/Anticheat/WardenAnticheat/WardenScan.hpp b/src/game/Anticheat/WardenAnticheat/WardenScan.hpp index cf70814272f..9e98ca02a8e 100644 --- a/src/game/Anticheat/WardenAnticheat/WardenScan.hpp +++ b/src/game/Anticheat/WardenAnticheat/WardenScan.hpp @@ -27,6 +27,7 @@ #include "ByteBuffer.h" #include "World.h" #include "Log.h" +#include "Errors.h" #include "Crypto/Hash/SHA1.h" #include diff --git a/src/game/AuraRemovalMgr.cpp b/src/game/AuraRemovalMgr.cpp index 142dfeaafac..be45b5f3669 100644 --- a/src/game/AuraRemovalMgr.cpp +++ b/src/game/AuraRemovalMgr.cpp @@ -41,7 +41,7 @@ void AuraRemovalManager::LoadFromDB() } else { - BarGoLink bar((int)result->GetRowCount()); + BarGoLink bar(result->GetRowCount()); do { bar.step(); diff --git a/src/game/Battlegrounds/BattleGroundAV.cpp b/src/game/Battlegrounds/BattleGroundAV.cpp index e18625a12f7..3bd05a5cf91 100644 --- a/src/game/Battlegrounds/BattleGroundAV.cpp +++ b/src/game/Battlegrounds/BattleGroundAV.cpp @@ -1079,7 +1079,7 @@ void BattleGroundAV::EventPlayerDestroyedPoint(BG_AV_Nodes node) { sLog.Out(LOG_BASIC, LOG_LVL_DEBUG, "BattleGroundAV: player destroyed point node %i", node); - MANGOS_ASSERT(m_nodes[node].owner != BG_AV_TEAM_NEUTRAL) + MANGOS_ASSERT(m_nodes[node].owner != BG_AV_TEAM_NEUTRAL); BattleGroundTeamIndex ownerTeamIdx = BattleGroundTeamIndex(m_nodes[node].owner); Team ownerTeam = ownerTeamIdx == BG_TEAM_ALLIANCE ? ALLIANCE : HORDE; diff --git a/src/game/CMakeLists.txt b/src/game/CMakeLists.txt index b5434480623..1afd571853f 100644 --- a/src/game/CMakeLists.txt +++ b/src/game/CMakeLists.txt @@ -195,7 +195,9 @@ set (game_SRCS PlayerBots/PlayerBotAI.cpp PlayerBots/PlayerBotMgr.cpp Protocol/Opcodes.cpp + Protocol/WorldSocket.h Protocol/WorldSocket.cpp + Protocol/WorldSocketMgr.h Protocol/WorldSocketMgr.cpp Spells/Spell.cpp Spells/SpellAuras.cpp @@ -411,8 +413,6 @@ set (game_SRCS PlayerBots/PlayerBotAI.h PlayerBots/PlayerBotMgr.h Protocol/Opcodes.h - Protocol/WorldSocket.h - Protocol/WorldSocketMgr.h Spells/Spell.h Spells/SpellAuraDefines.h Spells/SpellAuras.h @@ -454,7 +454,6 @@ if(WIN32) #allow exceptions in destructors if(MSVC) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /Zc:implicitNoexcept-") - set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} /D__ACE_INLINE__") elseif(MINGW) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fpermissive") endif() diff --git a/src/game/Commands/CharacterCommands.cpp b/src/game/Commands/CharacterCommands.cpp index 9535c933a2c..0275d8f13bc 100644 --- a/src/game/Commands/CharacterCommands.cpp +++ b/src/game/Commands/CharacterCommands.cpp @@ -36,6 +36,7 @@ #include "Config/Config.h" #include +#include bool ChatHandler::HandleCharacterAIInfoCommand(char* /*args*/) { diff --git a/src/game/CreatureGroups.cpp b/src/game/CreatureGroups.cpp index 33249f86c55..a9e9f9fb663 100644 --- a/src/game/CreatureGroups.cpp +++ b/src/game/CreatureGroups.cpp @@ -487,7 +487,7 @@ void CreatureGroupsManager::Load() int32 maxCount = fields[3].GetInt32(); if (maxCount <= 0) - maxCount = INT_MAX; + maxCount = std::numeric_limits::max(); else if (minCount > maxCount) { sLog.Out(LOG_DBERROR, LOG_LVL_MINIMAL, "CREATURE GROUPS: Min count %u is bigger than Max count %u for id %u in group with leader guid %u", minCount, maxCount, creatureId, fields[0].GetUInt32()); diff --git a/src/game/Group/CreatureLinkingMgr.cpp b/src/game/Group/CreatureLinkingMgr.cpp index b9df0b9b480..f1b605d7f8a 100644 --- a/src/game/Group/CreatureLinkingMgr.cpp +++ b/src/game/Group/CreatureLinkingMgr.cpp @@ -82,7 +82,7 @@ void CreatureLinkingMgr::LoadFromDB() } else { - BarGoLink bar((int)result->GetRowCount()); + BarGoLink bar(result->GetRowCount()); do { bar.step(); @@ -127,7 +127,7 @@ void CreatureLinkingMgr::LoadFromDB() return; } - BarGoLink guidBar((int)result->GetRowCount()); + BarGoLink guidBar(result->GetRowCount()); do { guidBar.step(); diff --git a/src/game/InstanceStatistics.cpp b/src/game/InstanceStatistics.cpp index b9a2081afb7..9b49c6ab635 100644 --- a/src/game/InstanceStatistics.cpp +++ b/src/game/InstanceStatistics.cpp @@ -43,7 +43,7 @@ void InstanceStatisticsMgr::LoadFromDB() } else { - BarGoLink bar((int)result->GetRowCount()); + BarGoLink bar(result->GetRowCount()); do { bar.step(); @@ -76,7 +76,7 @@ void InstanceStatisticsMgr::LoadFromDB() } else { - BarGoLink bar((int)result->GetRowCount()); + BarGoLink bar(result->GetRowCount()); do { bar.step(); @@ -122,7 +122,7 @@ void InstanceStatisticsMgr::LoadFromDB() } else { - BarGoLink bar((int)result->GetRowCount()); + BarGoLink bar(result->GetRowCount()); do { bar.step(); diff --git a/src/game/InstanceStatistics.h b/src/game/InstanceStatistics.h index 9b6292d4127..d3f8316835b 100644 --- a/src/game/InstanceStatistics.h +++ b/src/game/InstanceStatistics.h @@ -19,6 +19,7 @@ #ifndef INSTANCE_STATISTICS_H #define INSTANCE_STATISTICS_H +#include #include "Common.h" enum eInstanceCustomCounter : int diff --git a/src/game/Maps/Map.cpp b/src/game/Maps/Map.cpp index a5f6d71e3b9..0eaca5fb6dd 100644 --- a/src/game/Maps/Map.cpp +++ b/src/game/Maps/Map.cpp @@ -161,12 +161,12 @@ Map::Map(uint32 id, time_t expiry, uint32 InstanceId) int numObjThreads = (int)sWorld.getConfig(CONFIG_UINT32_MAP_OBJECTSUPDATE_THREADS); if (numObjThreads > 1) { - m_objectThreads.reset(new ThreadPool(numObjThreads -1)); + m_objectThreads.reset(new ThreadPool("MapObj", numObjThreads -1)); m_objectThreads->start>(); } - m_motionThreads.reset(new ThreadPool(sWorld.getConfig(CONFIG_UINT32_CONTINENTS_MOTIONUPDATE_THREADS))); - m_visibilityThreads.reset(new ThreadPool(std::max((int)sWorld.getConfig(CONFIG_UINT32_MAP_VISIBILITYUPDATE_THREADS) -1,0))); - m_cellThreads.reset(new ThreadPool(std::max((int)sWorld.getConfig(CONFIG_UINT32_MTCELLS_THREADS) - 1, 0))); + m_motionThreads.reset(new ThreadPool("MapMotion", sWorld.getConfig(CONFIG_UINT32_CONTINENTS_MOTIONUPDATE_THREADS))); + m_visibilityThreads.reset(new ThreadPool("MapVis", std::max((int)sWorld.getConfig(CONFIG_UINT32_MAP_VISIBILITYUPDATE_THREADS) -1,0))); + m_cellThreads.reset(new ThreadPool("MapCell", std::max((int)sWorld.getConfig(CONFIG_UINT32_MTCELLS_THREADS) - 1, 0))); m_visibilityThreads->start>(); m_cellThreads->start(); m_motionThreads->start(); diff --git a/src/game/Maps/MapManager.cpp b/src/game/Maps/MapManager.cpp index 010973a6dd1..556f8f01398 100644 --- a/src/game/Maps/MapManager.cpp +++ b/src/game/Maps/MapManager.cpp @@ -33,6 +33,7 @@ #include "Map.h" #include "BattleGround.h" #include "ThreadPool.h" +#include "IO/Multithreading/CreateThread.h" typedef MaNGOS::ClassLevelLockable MapManagerLock; INSTANTIATE_SINGLETON_2(MapManager, MapManagerLock); @@ -42,8 +43,8 @@ MapManager::MapManager() : i_gridCleanUpDelay(sWorld.getConfig(CONFIG_UINT32_INTERVAL_GRIDCLEAN)), i_MaxInstanceId(RESERVED_INSTANCES_LAST), - m_threads(new ThreadPool(sWorld.getConfig(CONFIG_UINT32_MAPUPDATE_INSTANCED_UPDATE_THREADS))), - m_instanceCreationThreads(new ThreadPool(1)) + m_threads(new ThreadPool("MapManager", sWorld.getConfig(CONFIG_UINT32_MAPUPDATE_INSTANCED_UPDATE_THREADS))), + m_instanceCreationThreads(new ThreadPool("NewMapForPlayer", 1)) { i_timer.SetInterval(sWorld.getConfig(CONFIG_UINT32_INTERVAL_MAPUPDATE)); m_threads->start>(); @@ -343,13 +344,13 @@ void MapManager::Update(uint32 diff) instanceCreators.emplace_back([this]() {CreateNewInstancesForPlayers();}); std::future instances = m_instanceCreationThreads->processWorkload(std::move(instanceCreators), ThreadPool::Callable()); - + i_maxContinentThread = continentsIdx; i_continentUpdateFinished.store(0); if (!m_continentThreads || m_continentThreads->size() < continentsUpdaters.size()) { - m_continentThreads.reset(new ThreadPool(continentsUpdaters.size())); + m_continentThreads.reset(new ThreadPool("MapContinent", continentsUpdaters.size())); m_continentThreads->start<>(); } std::future continents = m_continentThreads->processWorkload(std::move(continentsUpdaters), diff --git a/src/game/Maps/MoveMap.cpp b/src/game/Maps/MoveMap.cpp index 8d1ea1ccf1d..f04c83a8474 100644 --- a/src/game/Maps/MoveMap.cpp +++ b/src/game/Maps/MoveMap.cpp @@ -21,6 +21,7 @@ #include "VMapFactory.h" #include "MoveMap.h" #include "MoveMapSharedDefines.h" +#include "Errors.h" namespace MMAP { diff --git a/src/game/Maps/ZoneScriptMgr.h b/src/game/Maps/ZoneScriptMgr.h index 86a0fcbef60..a4a02975a25 100644 --- a/src/game/Maps/ZoneScriptMgr.h +++ b/src/game/Maps/ZoneScriptMgr.h @@ -21,7 +21,6 @@ #define OUTDOORPVP_OBJECTIVE_UPDATE_INTERVAL 1000 #include "ZoneScript.h" -#include class Player; class GameObject; diff --git a/src/game/Movement/spline/spline.h b/src/game/Movement/spline/spline.h index 036ecdb0fd8..215f781f28e 100644 --- a/src/game/Movement/spline/spline.h +++ b/src/game/Movement/spline/spline.h @@ -21,6 +21,7 @@ #include "typedefs.h" #include "Log.h" +#include "Errors.h" #include #include diff --git a/src/game/ObjectGuid.h b/src/game/ObjectGuid.h index 671ab8b9db1..5bacdcb03ef 100644 --- a/src/game/ObjectGuid.h +++ b/src/game/ObjectGuid.h @@ -80,8 +80,8 @@ enum HighGuid // NOSTALRIUS : Code supprime par MaNGOS. Eviter de l'utiliser. #define GUID_HIPART(x) (uint32)((uint64(x) >> 48) & 0x0000FFFF) // We have different low and middle part size for different guid types -#define _GUID_LOPART_2(x) (uint32)(uint64(x) & UI64LIT(0x00000000FFFFFFFF)) -#define _GUID_LOPART_3(x) (uint32)(uint64(x) & UI64LIT(0x0000000000FFFFFF)) +#define _GUID_LOPART_2(x) (uint32)(uint64(x) & uint64(0x00000000FFFFFFFF)) +#define _GUID_LOPART_3(x) (uint32)(uint64(x) & uint64(0x0000000000FFFFFF)) // Pour les codes TrinityCore #define IS_EMPTY_GUID(g) (g == 0) @@ -148,7 +148,7 @@ class ObjectGuid static HighGuid GetHigh(uint64 guid) { return HighGuid((guid >> 48) & 0x0000FFFF); } static void ClampPlayerGuid(uint64& value); HighGuid GetHigh() const { return GetHigh(m_guid); } - uint32 GetEntry() const { return HasEntry() ? uint32((m_guid >> 24) & UI64LIT(0x0000000000FFFFFF)) : 0; } + uint32 GetEntry() const { return HasEntry() ? uint32((m_guid >> 24) & uint64(0x0000000000FFFFFF)) : 0; } uint32 GetCounter() const { return GetCounter(m_guid, HasEntry()); @@ -157,8 +157,8 @@ class ObjectGuid static uint32 GetCounter(uint64 guid, bool hasEntry) { return hasEntry - ? uint32(guid & UI64LIT(0x0000000000FFFFFF)) - : uint32(guid & UI64LIT(0x00000000FFFFFFFF)); + ? uint32(guid & uint64(0x0000000000FFFFFF)) + : uint32(guid & uint64(0x00000000FFFFFFFF)); } static uint32 GetMaxCounter(HighGuid high) diff --git a/src/game/Objects/Object.h b/src/game/Objects/Object.h index 6df526b381e..ee747560bf2 100644 --- a/src/game/Objects/Object.h +++ b/src/game/Objects/Object.h @@ -24,6 +24,7 @@ #include "Common.h" #include "Log.h" +#include "Errors.h" #include "ByteBuffer.h" #include "UpdateFields.h" #include "UpdateData.h" diff --git a/src/game/Objects/Player.cpp b/src/game/Objects/Player.cpp index 252673e3112..ca8ebb3dcac 100644 --- a/src/game/Objects/Player.cpp +++ b/src/game/Objects/Player.cpp @@ -81,6 +81,8 @@ #include "world/scourge_invasion.h" #include "world/world_event_wareffort.h" +#include + #define ZONE_UPDATE_INTERVAL (1*IN_MILLISECONDS) #define PLAYER_SKILL_INDEX(x) (PLAYER_SKILL_INFO_1_1 + ((x)*3)) @@ -22803,7 +22805,8 @@ static char const* type_strings[] = "GM", "GMCritical", "Anticheat", - "Scripts" + "Scripts", + "Network", }; static_assert(sizeof(type_strings) / sizeof(type_strings[0]) == LOG_TYPE_MAX, "type_strings must be updated"); diff --git a/src/game/Objects/Totem.cpp b/src/game/Objects/Totem.cpp index 1f9f23e3382..4dcae460b15 100644 --- a/src/game/Objects/Totem.cpp +++ b/src/game/Objects/Totem.cpp @@ -174,18 +174,11 @@ void Totem::SetTypeBySummonSpell(SpellEntry const* spellProto) bool Totem::IsImmuneToSpellEffect(SpellEntry const* spellInfo, SpellEffectIndex index, bool castOnSelf) const { - // Check for Mana Spring & Healing Stream totems - switch (spellInfo->SpellFamilyName) - { - case SPELLFAMILY_SHAMAN: - if (spellInfo->IsFitToFamilyMask(UI64LIT(0x00000002000)) || - spellInfo->IsFitToFamilyMask(UI64LIT(0x00000004000)) || - spellInfo->IsFitToFamilyMask(UI64LIT(0x00004000000))) - return false; - break; - default: - break; - } + // Totem may affected by some specific spells + // Mana Spring, Healing stream, Mana tide + // Flags : 0x00000002000 | 0x00000004000 | 0x00004000000 -> 0x00004006000 + if (spellInfo->SpellFamilyName == SPELLFAMILY_SHAMAN && spellInfo->IsFitToFamilyMask(uint64(0x00004006000))) + return false; // Totems should not be immune to self casted spells. if (castOnSelf) diff --git a/src/game/Objects/Unit.cpp b/src/game/Objects/Unit.cpp index 456b11af05b..0a6469bb99a 100644 --- a/src/game/Objects/Unit.cpp +++ b/src/game/Objects/Unit.cpp @@ -56,6 +56,7 @@ #include "Anticheat.h" #include "InstanceStatistics.h" #include "MovementPacketSender.h" +#include "Errors.h" //#define DEBUG_DEBUFF_LIMIT @@ -10164,9 +10165,8 @@ void Unit::CleanupDeletedAuras() // - Player::SetDeathState // - Pet::AddObjectToRemoveList // Seen happening with spells like [Health Funnel], [Tainted Blood] - ACE_Stack_Trace st; sLog.Out(LOG_BASIC, LOG_LVL_MINIMAL, "[Crash/Auras] Deleting aura holder %u in use (%s)", iter->GetId(), GetObjectGuid().GetString().c_str()); - sLog.Out(LOG_BASIC, LOG_LVL_MINIMAL, "%s", st.c_str()); + MaNGOS::Errors::PrintStacktrace(); } else delete iter; @@ -10178,9 +10178,8 @@ void Unit::CleanupDeletedAuras() { if (iter->IsInUse()) { - ACE_Stack_Trace st; sLog.Out(LOG_BASIC, LOG_LVL_MINIMAL, "[Crash/Auras] Deleting aura %u in use (%s)", iter->GetId(), GetObjectGuid().GetString().c_str()); - sLog.Out(LOG_BASIC, LOG_LVL_MINIMAL, "%s", st.c_str()); + MaNGOS::Errors::PrintStacktrace(); } else delete iter; diff --git a/src/game/PacketBroadcast/MovementBroadcaster.h b/src/game/PacketBroadcast/MovementBroadcaster.h index 3e147c3765a..be760edbbdc 100644 --- a/src/game/PacketBroadcast/MovementBroadcaster.h +++ b/src/game/PacketBroadcast/MovementBroadcaster.h @@ -13,7 +13,7 @@ class PlayerBroadcaster; class MovementBroadcaster final { - typedef std::set > PlayersBCastSet; + typedef std::set> PlayersBCastSet; std::size_t m_num_threads; diff --git a/src/game/PacketBroadcast/PlayerBroadcaster.cpp b/src/game/PacketBroadcast/PlayerBroadcaster.cpp index c7f908e5346..f5a8cda1b96 100644 --- a/src/game/PacketBroadcast/PlayerBroadcaster.cpp +++ b/src/game/PacketBroadcast/PlayerBroadcaster.cpp @@ -1,4 +1,5 @@ #include "PlayerBroadcaster.h" + #include "MovementBroadcaster.h" #include "WorldPacket.h" #include "WorldSocket.h" @@ -7,22 +8,15 @@ uint32 PlayerBroadcaster::num_bcaster_created = 0; uint32 PlayerBroadcaster::num_bcaster_deleted = 0; -PlayerBroadcaster::PlayerBroadcaster(WorldSocket* w_socket, ObjectGuid const& self, std::size_t max_queue) - : MAX_QUEUE_SIZE(max_queue), m_socket(w_socket), m_self(self), instanceId(0), lastUpdatePackets(0) +PlayerBroadcaster::PlayerBroadcaster(std::shared_ptr socket, ObjectGuid const& self, std::size_t max_queue) + : MAX_QUEUE_SIZE(max_queue), m_socket(std::move(socket)), m_self(self), instanceId(0), lastUpdatePackets(0) { - if (m_socket) - m_socket->AddReference(); - m_queue.reserve(max_queue); ++num_bcaster_created; } -void PlayerBroadcaster::ChangeSocket(WorldSocket* new_socket) +void PlayerBroadcaster::ChangeSocket(std::shared_ptr const& new_socket) { - if (m_socket) - m_socket->RemoveReference(); - if (new_socket) - new_socket->AddReference(); m_socket = new_socket; } @@ -117,11 +111,7 @@ ObjectGuid PlayerBroadcaster::GetGUID() const void PlayerBroadcaster::FreeAtLogout() { - if (m_socket) - { - m_socket->RemoveReference(); - m_socket = nullptr; - } + m_socket = nullptr; std::unique_lock q_g(m_queue_lock), v_g(m_listeners_lock); m_queue.clear(); m_listeners.clear(); @@ -129,7 +119,6 @@ void PlayerBroadcaster::FreeAtLogout() PlayerBroadcaster::~PlayerBroadcaster() { - if (m_socket) - m_socket->RemoveReference(); + m_socket = nullptr; ++num_bcaster_deleted; } diff --git a/src/game/PacketBroadcast/PlayerBroadcaster.h b/src/game/PacketBroadcast/PlayerBroadcaster.h index 91daf4a0500..93ae06a3d51 100644 --- a/src/game/PacketBroadcast/PlayerBroadcaster.h +++ b/src/game/PacketBroadcast/PlayerBroadcaster.h @@ -23,7 +23,7 @@ class PlayerBroadcaster final std::size_t const MAX_QUEUE_SIZE; - WorldSocket* m_socket; + std::shared_ptr m_socket; ObjectGuid m_self; std::map > m_listeners; @@ -45,13 +45,13 @@ class PlayerBroadcaster final uint32 lastUpdatePackets; public: - PlayerBroadcaster(WorldSocket* socket, ObjectGuid const& self, std::size_t max_queue = 500); + PlayerBroadcaster(std::shared_ptr socket, ObjectGuid const& self, std::size_t max_queue = 500); ~PlayerBroadcaster(); static uint32 num_bcaster_created; static uint32 num_bcaster_deleted; - void ChangeSocket(WorldSocket* new_socket); + void ChangeSocket(std::shared_ptr const& new_socket); void FreeAtLogout(); ObjectGuid GetGUID() const; diff --git a/src/game/Protocol/WorldSocket.cpp b/src/game/Protocol/WorldSocket.cpp index 4a5e76e82e3..26605a2afad 100644 --- a/src/game/Protocol/WorldSocket.cpp +++ b/src/game/Protocol/WorldSocket.cpp @@ -3,6 +3,7 @@ * Copyright (C) 2009-2011 MaNGOSZero * Copyright (C) 2011-2016 Nostalrius * Copyright (C) 2016-2017 Elysium Project + * Copyright (C) 2017-2024 VMaNGOS Project * * This program is free software; you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by @@ -25,123 +26,208 @@ #include "SharedDefines.h" #include "WorldSession.h" #include "WorldSocket.h" -#include "WorldSocketMgr.h" + +#include #include "AddonHandler.h" #include "Opcodes.h" -#include "MangosSocketImpl.h" -#include "ace/OS_NS_netdb.h" #include "Crypto/Hash/SHA1.h" +#include "Database/SqlPreparedStatement.h" +#include "Database/DatabaseEnv.h" +#include "DBCStores.h" +#include "IO/Networking/DNS.h" +#include "WorldSocketMgr.h" -template class MangosSocket; +#if defined( __GNUC__ ) +#pragma pack(1) +#else +#pragma pack(push,1) +#endif +struct ServerPktHeader +{ + uint16 size; + uint16 cmd; -int WorldSocket::ProcessIncoming(WorldPacket* new_pct) + char const* data() const + { + return reinterpret_cast(this); + } + + std::size_t headerSize() const + { + return sizeof(ServerPktHeader); + } +}; +#if defined( __GNUC__ ) +#pragma pack() +#else +#pragma pack(pop) +#endif + +WorldSocket::WorldSocket(IO::Networking::AsyncSocket socket) + : m_socket(std::move(socket)), + m_lastPingTime(std::chrono::system_clock::time_point::min()), + m_overSpeedPings(0), + m_Session(nullptr), + m_authSeed(static_cast(rand32())), + m_remoteIpAddressStringAfterProxy(m_socket.GetRemoteIpString()) { - MANGOS_ASSERT(new_pct); + m_sendQueueIsRunning.clear(); // there is no atomic_flag::constructor on windows to initialize it with false by default (and if left out, linux is uninitialized and will fail randomly) +} - // manage memory ;) - std::unique_ptr aptr(new_pct); +WorldSocket::~WorldSocket() +{ + CloseSocket(); + sLog.Out(LOG_NETWORK, LOG_LVL_BASIC, "[%s] Connection closed", GetRemoteIpString().c_str()); +} + +void WorldSocket::DoRecvIncomingData() +{ + std::shared_ptr header = std::make_shared(); + + m_socket.Read((char*)header.get(), sizeof(ClientPktHeader), [self = shared_from_this(), header](IO::NetworkError const& error, std::size_t) -> void + { + if (error) + { + if (error.GetErrorType() != IO::NetworkError::ErrorType::SocketClosed || !self->IsClosing()) // only print error if it's not "normal close" related + { + sLog.Out(LOG_NETWORK, LOG_LVL_BASIC, "[%s] WorldSocket::DoRecvIncomingData: IoError: %s", self->m_socket.GetRemoteIpString().c_str(), error.ToString().c_str()); + self->CloseSocket(); // This call to CloseSocket is actually necessary for once, so that others can see that this socket is not usable anymore + } + return; + } + + // thread safe due to always being called from service context + self->m_Crypt.DecryptRecv((uint8*)header.get(), sizeof(ClientPktHeader)); - const ACE_UINT16 opcode = new_pct->GetOpcode(); + EndianConvertReverse(header->size); + EndianConvert(header->cmd); + + if ((header->size < 4) || (header->size > 0x2800) || (header->cmd >= NUM_MSG_TYPES)) + { + sLog.Out(LOG_NETWORK, LOG_LVL_BASIC, "[%s] WorldSocket::DoRecvIncomingData: client sent malformed packet size = %u, cmd = %u", self->m_socket.GetRemoteIpString().c_str(), header->size, header->cmd); + return; + } + + size_t remainingPacketSize = header->size - sizeof(header->cmd); + if (remainingPacketSize == 0) + { // Fastpath, it's probably an OpCode without any data + auto packet = std::make_unique(header->cmd, 0); + if (self->_HandleCompleteReceivedPacket(std::move(packet)) == HandlerResult::Okay) + self->DoRecvIncomingData(); + } + else + { + // Allocate WorldPacket once and write into the memory inplace, no need to move or copy stuff + // Cannot move std::unique_ptr into function capture, so it's wrapped into std::shared_ptr + std::shared_ptr> packetTmpSharedPtr(new std::unique_ptr(new WorldPacket(header->cmd, remainingPacketSize))); + (*packetTmpSharedPtr)->resize(remainingPacketSize); + self->m_socket.Read((char*)((*packetTmpSharedPtr)->contents()), (*packetTmpSharedPtr)->size(), [self, packetTmpSharedPtr](IO::NetworkError const& error, std::size_t) -> void + { + if (error) + { + sLog.Out(LOG_BASIC, LOG_LVL_BASIC, "WorldSocket::DoRecvIncomingData: Error %s", error.ToString().c_str()); + self->CloseSocket(); + return; + } + + // by std::moving the content of the shared_ptr, we will separate the unique_ptr out of the shared_ptr. + if (self->_HandleCompleteReceivedPacket(std::move(*packetTmpSharedPtr)) == HandlerResult::Okay) + self->DoRecvIncomingData(); + }); + } + }); +} + +WorldSocket::HandlerResult WorldSocket::_HandleCompleteReceivedPacket(std::unique_ptr packet) +{ + uint16 const opcode = packet->GetOpcode(); if (opcode >= NUM_MSG_TYPES) { sLog.Out(LOG_BASIC, LOG_LVL_ERROR, "SESSION: received nonexistent opcode 0x%.4X", opcode); - return -1; + return HandlerResult::Fail; } - if (closing_) - return -1; + if (IsClosing()) + return HandlerResult::Fail; - new_pct->FillPacketTime(WorldTimer::getMSTime()); + packet->FillPacketTime(WorldTimer::getMSTime()); try { switch (opcode) { case CMSG_PING: - return HandlePing(*new_pct); + return _HandlePing(*packet); case CMSG_AUTH_SESSION: - if (m_Session) + if (m_Session != nullptr) { sLog.Out(LOG_BASIC, LOG_LVL_ERROR, "WorldSocket::ProcessIncoming: Player send CMSG_AUTH_SESSION again"); - return -1; + return HandlerResult::Fail; } - - return HandleAuthSession(*new_pct); + return _HandleAuthSession(*packet); default: - { - GuardType lock(m_SessionLock); - - if (m_Session != nullptr) - { - // WARNINIG here we call it with locks held. - // Its possible to cause deadlock if QueuePacket calls back - m_Session->QueuePacket(std::move(aptr)); - return 0; - } - else + if (m_Session == nullptr) { sLog.Out(LOG_BASIC, LOG_LVL_ERROR, "WorldSocket::ProcessIncoming: Client not authed opcode = %u", uint32(opcode)); - return -1; + return HandlerResult::Fail; } - } + + m_Session->QueuePacket(std::move(packet)); + return HandlerResult::Okay; } } - catch (ByteBufferException &) + catch (ByteBufferException&) { - sLog.Out(LOG_BASIC, LOG_LVL_ERROR, "WorldSocket::ProcessIncoming ByteBufferException occured while parsing an instant handled packet (opcode: %u) from client %s, accountid=%i.", - opcode, GetRemoteAddress().c_str(), m_Session ? m_Session->GetAccountId() : -1); + sLog.Out(LOG_BASIC, LOG_LVL_ERROR, "WorldSocket::ProcessIncoming ByteBufferException occured while parsing an instant handled packet (opcode: %u) from client %s, accountid=%i.", opcode, GetRemoteIpString().c_str(), m_Session ? m_Session->GetAccountId() : -1); + if (sLog.HasLogLevelOrHigher(LOG_LVL_DEBUG)) { sLog.Out(LOG_BASIC, LOG_LVL_DEBUG, "Dumping error-causing packet:"); - new_pct->hexlike(); + packet->hexlike(); } if (sWorld.getConfig(CONFIG_BOOL_KICK_PLAYER_ON_BAD_PACKET)) { sLog.Out(LOG_BASIC, LOG_LVL_DETAIL, "Disconnecting session [account id %i / address %s] for badly formatted packet.", - m_Session ? m_Session->GetAccountId() : -1, GetRemoteAddress().c_str()); + m_Session ? m_Session->GetAccountId() : -1, GetRemoteIpString().c_str()); - return -1; + return HandlerResult::Fail; } - else - return 0; + + return HandlerResult::Okay; } - ACE_NOTREACHED(return 0); + MANGOS_ASSERT(false); // This should never be reached } +/// This function will resolve the ip-addresse of the current host +/// For example if you hostname is called "world.mycoolserver.com" and it points to 123.45.66.7 it will be added to the server list +/// Also 127.0.0.1 will be added as a fallback +/// This list is later used to determine if clients try to connect to this server without registering at realmd first static std::set GetServerAddresses() { std::set addresses; - char hostName[MAXHOSTNAMELEN] = {}; + addresses.insert("127.0.0.1"); - if (ACE_OS::hostname(hostName, MAXHOSTNAMELEN) != -1) + std::string myHostname = IO::Networking::DNS::GetOwnHostname(); + std::vector ipAddresses = IO::Networking::DNS::ResolveDomainAll(myHostname, IO::Networking::IpAddress::Type::IPv4); + for (auto const& ipAddress : ipAddresses) { - if (hostent* hp = ACE_OS::gethostbyname(hostName)) - { - for (int i = 0; hp->h_addr_list[i] != 0; ++i) - { - in_addr addr; - memcpy(&addr, hp->h_addr_list[i], sizeof(in_addr)); - addresses.insert(ACE_OS::inet_ntoa(addr)); - } - } + addresses.insert(ipAddress.ToString()); } - addresses.insert("127.0.0.1"); - return addresses; } -int WorldSocket::HandleAuthSession(WorldPacket& recvPacket) +WorldSocket::HandlerResult WorldSocket::_HandleAuthSession(WorldPacket& recvPacket) { - // NOTE: ATM the socket is singlethread, have this in mind ... Crypto::Hash::SHA1::Digest digest; uint32 clientSeed; uint32 serverId; uint32 clientBuild; - uint32 id, security; + uint32 accountId; + AccountTypes security; LocaleConstant locale; std::string account, os, platform; BigNumber K; @@ -171,21 +257,36 @@ int WorldSocket::HandleAuthSession(WorldPacket& recvPacket) SendPacket(packet); sLog.Out(LOG_BASIC, LOG_LVL_ERROR, "WorldSocket::HandleAuthSession: Sent Auth Response (version mismatch)."); - return -1; + return HandlerResult::Fail; } // Get the account information from the realmd database std::string safe_account = account; // Duplicate, else will screw the SHA hash verification below LoginDatabase.escape_string(safe_account); // No SQL injection, username escaped. - // 0 1 2 3 4 5 6 7 8 9 10 - std::unique_ptr result(LoginDatabase.PQuery("SELECT a.`id`, aa.`gmLevel`, a.`sessionkey`, a.`last_ip`, a.`v`, a.`s`, a.`mutetime`, a.`locale`, a.`os`, a.`platform`, a.`flags`, " - // 11 12 13 - "a.`email`, a.`email_verif`, ab.`unbandate` > UNIX_TIMESTAMP() OR ab.`unbandate` = ab.`bandate` FROM `account` a LEFT JOIN `account_access` aa ON a.`id` = aa.`id` AND aa.`RealmID` IN (-1, %u) " - "LEFT JOIN `account_banned` ab ON a.`id` = ab.`id` AND ab.`active` = 1 WHERE a.`username` = '%s' && DATEDIFF(NOW(), a.`last_login`) < 1 ORDER BY aa.`RealmID` DESC LIMIT 1", realmID, safe_account.c_str())); + auto accountQueryResult = + LoginDatabase.PQuery("SELECT " + "a.`id`, " // 0 + "aa.`gmLevel`, " // 1 + "a.`sessionkey`, " // 2 + "a.`last_ip`, " // 3 + "a.`v`, " // 4 + "a.`s`, " // 5 + "a.`mutetime`, " // 6 + "a.`locale`, " // 7 + "a.`os`, " // 8 + "a.`platform`, " // 9 + "a.`flags`, " // 10 + "a.`email`, " // 11 + "a.`email_verif`, " // 12 + "ab.`unbandate` > UNIX_TIMESTAMP() OR ab.`unbandate` = ab.`bandate` " // 13 + "FROM `account` a " + "LEFT JOIN `account_access` aa ON a.`id` = aa.`id` AND aa.`RealmID` IN (-1, %u) " + "LEFT JOIN `account_banned` ab ON a.`id` = ab.`id` AND ab.`active` = 1 WHERE a.`username` = '%s' && DATEDIFF(NOW(), a.`last_login`) < 1 " + "ORDER BY aa.`RealmID` DESC LIMIT 1", realmID, safe_account.c_str()); // Stop if the account is not found - if (!result) + if (!accountQueryResult) { packet.Initialize(SMSG_AUTH_RESPONSE, 1); packet << uint8(AUTH_UNKNOWN_ACCOUNT); @@ -193,34 +294,33 @@ int WorldSocket::HandleAuthSession(WorldPacket& recvPacket) SendPacket(packet); sLog.Out(LOG_BASIC, LOG_LVL_ERROR, "WorldSocket::HandleAuthSession: Sent Auth Response (unknown account)."); - return -1; + return HandlerResult::Fail; } - Field* fields = result->Fetch(); + Field* fields = accountQueryResult->Fetch(); // Prevent connecting directly to mangosd by checking // that same ip connected to realmd previously. - if (strcmp(fields[3].GetString(), GetRemoteAddress().c_str()) && - serverAddressList.find(GetRemoteAddress()) == serverAddressList.end()) + if (fields[3].GetCppString() != GetRemoteIpString() && serverAddressList.find(GetRemoteIpString()) == serverAddressList.end()) { packet.Initialize(SMSG_AUTH_RESPONSE, 1); packet << uint8(AUTH_FAILED); SendPacket(packet); - sLog.Out(LOG_BASIC, LOG_LVL_BASIC, "WorldSocket::HandleAuthSession: Sent Auth Response (Account IP differs)."); - return -1; + sLog.Out(LOG_BASIC, LOG_LVL_BASIC, "WorldSocket::HandleAuthSession: Sent Auth Response (Account IP differs from realmd)."); + return HandlerResult::Fail; } - id = fields[0].GetUInt32(); - security = fields[1].GetString() ? fields[1].GetUInt32() : SEC_PLAYER; + accountId = fields[0].GetUInt32(); + security = fields[1].GetString() ? (AccountTypes)(fields[1].GetUInt32()) : SEC_PLAYER; if (security > SEC_ADMINISTRATOR) // prevent invalid security settings in DB security = SEC_ADMINISTRATOR; K.SetHexStr(fields[2].GetString()); if (K.AsByteArray().empty()) - return -1; + return HandlerResult::Fail; - time_t mutetime = time_t (fields[6].GetUInt64()); + time_t mutetime = time_t(fields[6].GetUInt64()); locale = LocaleConstant(fields[7].GetUInt8()); if (locale >= MAX_LOCALE) locale = LOCALE_enUS; @@ -231,34 +331,34 @@ int WorldSocket::HandleAuthSession(WorldPacket& recvPacket) bool verifiedEmail = fields[12].GetBool() || email.empty(); // treat no email as verified (created from console) bool isBanned = fields[13].GetBool(); - if (isBanned || sAccountMgr.IsIPBanned(GetRemoteAddress())) + if (isBanned || sAccountMgr.IsIPBanned(GetRemoteIpString())) { packet.Initialize(SMSG_AUTH_RESPONSE, 1); packet << uint8(AUTH_BANNED); SendPacket(packet); sLog.Out(LOG_BASIC, LOG_LVL_ERROR, "WorldSocket::HandleAuthSession: Sent Auth Response (Account banned)."); - return -1; + return HandlerResult::Fail; } // Check locked state for server AccountTypes allowedAccountType = sWorld.GetPlayerSecurityLimit(); - if (allowedAccountType > SEC_PLAYER && AccountTypes(security) < allowedAccountType) + if (allowedAccountType > SEC_PLAYER && security < allowedAccountType) { packet.Initialize(SMSG_AUTH_RESPONSE, 1); packet << uint8(AUTH_UNAVAILABLE); SendPacket(packet); sLog.Out(LOG_BASIC, LOG_LVL_BASIC, "WorldSocket::HandleAuthSession: User tries to login but his security level is not enough"); - return -1; + return HandlerResult::Fail; } // Check that Key and account name are the same on client and server Crypto::Hash::SHA1::Generator sha; uint32 t = 0; - uint32 seed = m_Seed; + uint32 seed = m_authSeed; sha.UpdateData(account); sha.UpdateData((uint8 *) &t, 4); @@ -273,11 +373,11 @@ int WorldSocket::HandleAuthSession(WorldPacket& recvPacket) packet << uint8(AUTH_FAILED); SendPacket(packet); - sLog.Out(LOG_BASIC, LOG_LVL_ERROR, "WorldSocket::HandleAuthSession: Sent Auth Response (authentification failed)."); - return -1; + sLog.Out(LOG_BASIC, LOG_LVL_ERROR, "WorldSocket::HandleAuthSession: Sent Auth Response (authentification failed), account ID: %u.", accountId); + return HandlerResult::Fail; } - std::string address = GetRemoteAddress(); + std::string address = GetRemoteIpString(); sLog.Out(LOG_BASIC, LOG_LVL_DEBUG, "WorldSocket::HandleAuthSession: Client '%s' authenticated successfully from %s.", account.c_str(), @@ -298,7 +398,7 @@ int WorldSocket::HandleAuthSession(WorldPacket& recvPacket) else { sLog.Out(LOG_BASIC, LOG_LVL_ERROR, "WorldSocket::HandleAuthSession: Unrecognized OS '%s' for account '%s' from %s", os.c_str(), account.c_str(), address.c_str()); - return -1; + return HandlerResult::Fail; } ClientPlatformType clientPlatform; @@ -309,11 +409,10 @@ int WorldSocket::HandleAuthSession(WorldPacket& recvPacket) else { sLog.Out(LOG_BASIC, LOG_LVL_ERROR, "WorldSocket::HandleAuthSession: Unrecognized Platform '%s' for account '%s' from %s", platform.c_str(), account.c_str(), address.c_str()); - return -1; + return HandlerResult::Fail; } - // NOTE ATM the socket is single-threaded, have this in mind ... - ACE_NEW_RETURN(m_Session, WorldSession(id, this, AccountTypes(security), mutetime, locale), -1); + m_Session = new WorldSession(accountId, this->shared_from_this(), security, mutetime, locale); m_Crypt.SetKey(K.AsByteArray()); m_Crypt.Init(); @@ -327,7 +426,7 @@ int WorldSocket::HandleAuthSession(WorldPacket& recvPacket) m_Session->SetSessionKey(K); m_Session->LoadGlobalAccountData(); m_Session->LoadTutorialsData(); - sAccountMgr.UpdateAccountData(id, account, email, verifiedEmail, AccountTypes(security)); + sAccountMgr.UpdateAccountData(accountId, account, email, verifiedEmail, security); sWorld.AddSession(m_Session); @@ -335,10 +434,10 @@ int WorldSocket::HandleAuthSession(WorldPacket& recvPacket) if (sAddOnHandler.BuildAddonPacket(&recvPacket, &addonPacket)) SendPacket(addonPacket); - return 0; + return HandlerResult::Okay; } -int WorldSocket::HandlePing(WorldPacket& recvPacket) +WorldSocket::HandlerResult WorldSocket::_HandlePing(WorldPacket& recvPacket) { uint32 ping; #if SUPPORTED_CLIENT_BUILD > CLIENT_BUILD_1_8_4 @@ -351,74 +450,155 @@ int WorldSocket::HandlePing(WorldPacket& recvPacket) recvPacket >> latency; #endif - if (m_LastPingTime == ACE_Time_Value::zero) - m_LastPingTime = ACE_OS::gettimeofday(); // for 1st ping + if (m_lastPingTime == std::chrono::system_clock::time_point::min()) + m_lastPingTime = std::chrono::system_clock::now(); // for 1st ping else { - ACE_Time_Value cur_time = ACE_OS::gettimeofday(); - ACE_Time_Value diff_time(cur_time); - diff_time -= m_LastPingTime; - m_LastPingTime = cur_time; + auto now = std::chrono::system_clock::now(); + std::chrono::seconds seconds = std::chrono::duration_cast(now - m_lastPingTime); + m_lastPingTime = now; - if (diff_time < ACE_Time_Value(27)) + if (seconds.count() < 27) { - ++m_OverSpeedPings; - - uint32 max_count = sWorld.getConfig(CONFIG_UINT32_MAX_OVERSPEED_PINGS); + ++m_overSpeedPings; - if (max_count && m_OverSpeedPings > max_count) + uint32 maxAllowedOverspeedPings = sWorld.getConfig(CONFIG_UINT32_MAX_OVERSPEED_PINGS); + if (maxAllowedOverspeedPings && m_overSpeedPings > maxAllowedOverspeedPings) { - GuardType lock(m_SessionLock); - if (m_Session && m_Session->GetSecurity() == SEC_PLAYER) { - sLog.Out(LOG_BASIC, LOG_LVL_ERROR, "WorldSocket::HandlePing: Player kicked for " - "overspeeded pings address = %s", - GetRemoteAddress().c_str()); - - return -1; + sLog.Out(LOG_BASIC, LOG_LVL_ERROR, "WorldSocket::HandlePing: Player kicked for overspeeded pings address = %s", GetRemoteIpString().c_str()); + return HandlerResult::Fail; } } } else - m_OverSpeedPings = 0; + { + m_overSpeedPings = 0; + } } // critical section { - GuardType lock(m_SessionLock); - -#if SUPPORTED_CLIENT_BUILD > CLIENT_BUILD_1_8_4 if (m_Session) + { +#if SUPPORTED_CLIENT_BUILD > CLIENT_BUILD_1_8_4 m_Session->SetLatency(latency); - else -#else - if (!m_Session) #endif + } + else { - sLog.Out(LOG_BASIC, LOG_LVL_ERROR, "WorldSocket::HandlePing: peer sent CMSG_PING, " - "but is not authenticated or got recently kicked," - " address = %s", - GetRemoteAddress().c_str()); - return -1; + sLog.Out(LOG_BASIC, LOG_LVL_ERROR, "WorldSocket::HandlePing: peer sent CMSG_PING, but is not authenticated or got recently kicked, address = %s", GetRemoteIpString().c_str()); + return HandlerResult::Fail; } } WorldPacket packet(SMSG_PONG, 4); packet << ping; - return SendPacket(packet); -} + SendPacket(packet); -int WorldSocket::OnSocketOpen() -{ - return sWorldSocketMgr->OnSocketOpen(this); + return HandlerResult::Okay; } -int WorldSocket::SendStartupPacket() +void WorldSocket::SendInitialPacketAndStartRecvLoop() { // Send startup packet. WorldPacket packet(SMSG_AUTH_CHALLENGE, 4); - packet << m_Seed; + packet << m_authSeed; + + SendPacket(packet); + + DoRecvIncomingData(); +} + +void WorldSocket::SendPacket(WorldPacket packet) +{ + if (IsClosing()) + return; + + // We don't want to allocate or encrypt anything inside the world thread, so we move everything to the IO thread. + m_sendQueueLock.lock(); + if (m_sendQueue.size() > 1024) // There should never be so many packets queued up. The socket is probably not responding. + { + m_sendQueueLock.unlock(); + sLog.Out(LOG_NETWORK, LOG_LVL_ERROR, "[%s] Send queue is full. Disconnecting.", GetRemoteIpString().c_str()); + CloseSocket(); + return; + } + m_sendQueue.push(std::move(packet)); + m_sendQueueLock.unlock(); - return SendPacket(packet); + // Start AsyncProcessingSendQueue which take things from the queue + if (m_sendQueueIsRunning.test_and_set()) + return; // already running + + m_socket.EnterIoContext([self = shared_from_this()](IO::NetworkError error) + { + self->HandleResultOfAsyncWrite(error, std::make_shared()); + }); +} + +void WorldSocket::HandleResultOfAsyncWrite(IO::NetworkError const& error, std::shared_ptr const& alreadyAllocatedBuffer) +{ + if (error) + { + if (error.GetErrorType() != IO::NetworkError::ErrorType::SocketClosed || !IsClosing()) // only print error if it's not "normal close" related + { + sLog.Out(LOG_NETWORK, LOG_LVL_ERROR, "[%s] WorldSocket::HandleResultOfAsyncWrite: IoError: %s", GetRemoteIpString().c_str(), error.ToString().c_str()); + CloseSocket(); // This call to CloseSocket is actually necessary for once, so that others can see that this socket is not usable anymore + } + + m_sendQueueIsRunning.clear(); + return; + } + + if (m_sendQueue.empty()) + { + m_sendQueueIsRunning.clear(); + return; + } + + // Combine all packets into `alreadyAllocatedBuffer` + alreadyAllocatedBuffer->clear(); + while (!m_sendQueue.empty()) + { + m_sendQueueLock.lock(); + if (m_sendQueue.empty()) // re-check after we locked the queue if it's really not empty + { + m_sendQueueLock.unlock(); + break; + } + WorldPacket packet = std::move(m_sendQueue.front()); + m_sendQueue.pop(); + m_sendQueueLock.unlock(); + + ServerPktHeader header{}; + + header.cmd = packet.GetOpcode(); + EndianConvert(header.cmd); + + header.size = static_cast(packet.size() + 2); + EndianConvertReverse(header.size); + + m_Crypt.EncryptSend(reinterpret_cast(&header), sizeof(header)); // in vanilla versions of the game only the header is encrypted + + alreadyAllocatedBuffer->append(header.data(), header.headerSize()); + if (!packet.empty()) + alreadyAllocatedBuffer->append(packet.contents(), packet.size()); + } + + m_socket.Write({ alreadyAllocatedBuffer /* dont move, re-used in lambda */ }, [self = shared_from_this(), alreadyAllocatedBuffer](IO::NetworkError const& error) + { + self->HandleResultOfAsyncWrite(error, alreadyAllocatedBuffer); + }); +} + +void WorldSocket::Start() +{ + SendInitialPacketAndStartRecvLoop(); +} + +void WorldSocket::CloseSocket() +{ + m_socket.CloseSocket(); } diff --git a/src/game/Protocol/WorldSocket.h b/src/game/Protocol/WorldSocket.h index 0f4e5375d9e..49d0e4c4a6c 100644 --- a/src/game/Protocol/WorldSocket.h +++ b/src/game/Protocol/WorldSocket.h @@ -3,6 +3,7 @@ * Copyright (C) 2009-2011 MaNGOSZero * Copyright (C) 2011-2016 Nostalrius * Copyright (C) 2016-2017 Elysium Project + * Copyright (C) 2017-2024 VMaNGOS Project * * This program is free software; you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by @@ -19,42 +20,108 @@ * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA */ -/** \addtogroup u2w User to World Communication - * @{ - * \file WorldSocket.h - * \author Derex - */ - -#ifndef _WORLDSOCKET_H -#define _WORLDSOCKET_H +#ifndef MANGOS_GAME_SERVER_WORLDSOCKET_H +#define MANGOS_GAME_SERVER_WORLDSOCKET_H -#include "MangosSocket.h" +#include "IO/Networking/AsyncSocket.h" #include "Auth/AuthCrypt.h" +#include "WorldPacket.h" +#include "WorldSession.h" -template -class ReactorRunnable; -template -class MangosSocketMgr; +class WorldSocketMgr; -class WorldSocket: public MangosSocket +class WorldSocket final : public std::enable_shared_from_this { - friend class MangosSocket; - friend class MangosSocketMgr; - friend class WorldSocketMgr; - friend class ReactorRunnable< WorldSocket >; - protected: - int OnSocketOpen(); - int SendStartupPacket(); + friend WorldSocketMgr; - int ProcessIncoming (WorldPacket* new_pct); +private: + enum class HandlerResult + { + Okay, + Fail, + }; - // Called by ProcessIncoming() on CMSG_AUTH_SESSION. - int HandleAuthSession (WorldPacket& recvPacket); +#if defined( __GNUC__ ) +#pragma pack(1) +#else +#pragma pack(push,1) +#endif + struct ClientPktHeader + { + uint16 size; + uint32 cmd; + }; +#if defined( __GNUC__ ) +#pragma pack() +#else +#pragma pack(pop) +#endif - // Called by ProcessIncoming() on CMSG_PING. - int HandlePing (WorldPacket& recvPacket); -}; + /// Time in which the last ping was received + std::chrono::system_clock::time_point m_lastPingTime; + + /// Keep track of over-speed pings, to prevent ping flood. + uint32 m_overSpeedPings; + + /// Class used for managing encryption of the headers + AuthCrypt m_Crypt; // TODO: Rename me to m_crypt + + /// Session to which received packets are routed + WorldSession* m_Session; // TODO: Rename me to m_session + + /// Random seed used in SMSG_AUTH_CHALLENGE and CMSG_AUTH_SESSION + uint32 const m_authSeed; + + /// Session key used to authenticate the client (value from db `account` table) + //BigNumber m_authSessionKey; + + /// Starting the recv loop + void Start(); + + /// Called by WorldSocketMgr when a new connection is made + void SendInitialPacketAndStartRecvLoop(); + + /// process one incoming packet. + void DoRecvIncomingData(); -#endif /* _WORLDSOCKET_H */ + /// Encrypt and write to queue + void HandleResultOfAsyncWrite(IO::NetworkError const& error, std::shared_ptr const& alreadyAllocatedBuffer); + + HandlerResult _HandleCompleteReceivedPacket(std::unique_ptr packet); + + /// Called by ProcessIncoming() on CMSG_AUTH_SESSION. + HandlerResult _HandleAuthSession(WorldPacket& recvPacket); + + /// Called by ProcessIncoming() on CMSG_PING. + HandlerResult _HandlePing(WorldPacket& recvPacket); + + std::mutex m_sendQueueLock; + std::queue m_sendQueue; + std::atomic_flag m_sendQueueIsRunning; + + IO::Networking::AsyncSocket m_socket; + std::string m_remoteIpAddressStringAfterProxy; // might differ from `m_socket.m_descriptor` if behind proxy + +public: + explicit WorldSocket(IO::Networking::AsyncSocket socket); + /// The destructor will automatically close the socket + ~WorldSocket(); + WorldSocket(WorldSocket const&) = delete; + WorldSocket& operator=(WorldSocket const&) = delete; + WorldSocket(WorldSocket&&) = delete; + WorldSocket& operator=(WorldSocket&&) = delete; + + void SendPacket(WorldPacket packet); + + void FinalizeSession() + { + m_Session = nullptr; + } + + // ----- Exposing `m_socket` features ----- + inline std::string const& GetRemoteIpString() const { return m_remoteIpAddressStringAfterProxy; } + inline bool IsClosing() const { return m_socket.IsClosing(); } + void CloseSocket(); +}; -// @} +#endif // MANGOS_GAME_SERVER_WORLDSOCKET_H diff --git a/src/game/Protocol/WorldSocketMgr.cpp b/src/game/Protocol/WorldSocketMgr.cpp index 6a7735b27b8..5a193de088e 100644 --- a/src/game/Protocol/WorldSocketMgr.cpp +++ b/src/game/Protocol/WorldSocketMgr.cpp @@ -3,6 +3,7 @@ * Copyright (C) 2009-2011 MaNGOSZero * Copyright (C) 2011-2016 Nostalrius * Copyright (C) 2016-2017 Elysium Project + * Copyright (C) 2017-2024 VMaNGOS Project * * This program is free software; you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by @@ -19,18 +20,104 @@ * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA */ -/** \file WorldSocketMgr.cpp -* \ingroup u2w -* \author Derex -*/ - -#include "WorldSocket.h" #include "WorldSocketMgr.h" -#include "MangosSocketMgrImpl.h" +#include "WorldSocket.h" +#include "Policies/SingletonImp.h" +#include "IO/Networking/AsyncSocketAcceptor.h" +#include "IO/Multithreading/CreateThread.h" +#include "ProxyProtocol/ProxyV2Reader.h" + +INSTANTIATE_SINGLETON_1(WorldSocketMgr); + +bool WorldSocketMgr::StartWorldNetworking(IO::IoContext* ioCtx, WorldSocketMgrOptions const& options) +{ + m_ioContext = ioCtx; + m_settings = options; + + // Launch the listening network socket + m_listener = IO::Networking::AsyncSocketAcceptor::CreateAndBindServer(ioCtx, options.bindIp, options.bindPort); + if (m_listener == nullptr) + { + sLog.Out(LOG_BASIC, LOG_LVL_ERROR, "Failed to start WorldSocket network"); + return false; + } + m_listener->AutoAcceptSocketsUntilClose([this](IO::Networking::SocketDescriptor socketDescriptor) + { + this->OnNewClientConnected(std::move(socketDescriptor)); + }); + + return true; +} -template class MangosSocketMgr; +void WorldSocketMgr::StopWorldNetworking() +{ + sLog.Out(LOG_BASIC, LOG_LVL_MINIMAL, "Stop world networking..."); + if (m_listener != nullptr) + { + m_listener->ClosePortAndStopAcceptingNewConnections(); + m_listener = nullptr; + } +} + +void WorldSocketMgr::OnNewClientConnected(IO::Networking::SocketDescriptor socketDescriptor) +{ + // Attach descriptor to AsyncSocket and configure it before attaching it to the WorldSocket + IO::IoContext* ioContext = GetLeastUsedIoContext(); + auto worldSocket = std::make_shared(std::move(IO::Networking::AsyncSocket(ioContext, std::move(socketDescriptor)))); + std::string const& socketIp = worldSocket->m_socket.GetRemoteIpString(); + + if (IO::NetworkError initError = worldSocket->m_socket.InitializeAndFixateMemoryLocation()) + { + sLog.Out(LOG_BASIC, LOG_LVL_ERROR, "[%s] Failed to InitializeAndFixateMemoryLocation %s", socketIp.c_str(), initError.ToString().c_str()); + return; // implicit close() + } + + if (m_settings.socketOutByteBufferSize >= 0) + { + IO::NetworkError error = worldSocket->m_socket.SetNativeSocketOption_SystemOutgoingSendBuffer(m_settings.socketOutByteBufferSize); + if (error) + { // We don't close the socket, since its basically just a "warning" I guess. + sLog.Out(LOG_NETWORK, LOG_LVL_ERROR, "[%s] Failed to set SystemOutgoingSendBuffer option on socket. Error: %s", socketIp.c_str(), error.ToString().c_str()); + } + } + + if (m_settings.doExplicitTcpNoDelay) // Set TCP_NODELAY. + { + IO::NetworkError error = worldSocket->m_socket.SetNativeSocketOption_NoDelay(true); + if (error) + { // We don't close the socket, since its basically just a "warning" I guess. + sLog.Out(LOG_NETWORK, LOG_LVL_ERROR, "[%s] Failed to set NoDelay option on socket. Error: %s", socketIp.c_str(), error.ToString().c_str()); + } + } + + // Check if the remote endpoint is actually a trusted proxy, so we can retrieve the real client ip + if (!m_settings.trustedProxyIps.empty() && std::find(m_settings.trustedProxyIps.begin(), m_settings.trustedProxyIps.end(), socketIp) != m_settings.trustedProxyIps.end()) + { + // parse proxy header + ProxyProtocol::ReadProxyV2Handshake(&(worldSocket->m_socket), [worldSocket](nonstd::expected const& maybeIp) + { + if (!maybeIp.has_value()) + { + sLog.Out(LOG_NETWORK, LOG_LVL_ERROR, "[%s] Failed to parse proxy header. Error: %s", worldSocket->m_socket.GetRemoteIpString().c_str(), maybeIp.error().ToString().c_str()); + return; // implicit close() + } + worldSocket->m_remoteIpAddressStringAfterProxy = maybeIp.value().ToString(); + sLog.Out(LOG_NETWORK, LOG_LVL_BASIC, "[%s] Connection accepted (proxy ip: %s)", worldSocket->GetRemoteIpString().c_str(), worldSocket->m_socket.GetRemoteIpString().c_str()); + worldSocket->Start(); + }); + } + else + { + // no proxy, we can start directly + sLog.Out(LOG_NETWORK, LOG_LVL_BASIC, "[%s] Connection accepted", worldSocket->GetRemoteIpString().c_str()); + worldSocket->Start(); + } +} -WorldSocketMgr* WorldSocketMgr::Instance() +IO::IoContext* WorldSocketMgr::GetLeastUsedIoContext() { - return ACE_Singleton::instance(); + // TODO: Currently the main shared ioCtx is used + // but we could do a thread affinity here, just like TrinityCore does it. + // See `Trinity::SocketMgr::SelectThreadWithMinConnections()` + return m_ioContext; } diff --git a/src/game/Protocol/WorldSocketMgr.h b/src/game/Protocol/WorldSocketMgr.h index b0288ea41ee..e245112a469 100644 --- a/src/game/Protocol/WorldSocketMgr.h +++ b/src/game/Protocol/WorldSocketMgr.h @@ -3,6 +3,7 @@ * Copyright (C) 2009-2011 MaNGOSZero * Copyright (C) 2011-2016 Nostalrius * Copyright (C) 2016-2017 Elysium Project + * Copyright (C) 2017-2024 VMaNGOS Project * * This program is free software; you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by @@ -19,31 +20,44 @@ * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA */ -/** \addtogroup u2w User to World Communication - * @{ - * \file WorldSocketMgr.h - * \author Derex - */ +#ifndef MANGOS_GAME_SERVER_WORLDSOCKETMGR_H +#define MANGOS_GAME_SERVER_WORLDSOCKETMGR_H -#ifndef __WORLDSOCKETMGR_H -#define __WORLDSOCKETMGR_H +#include +#include +#include "Policies/Singleton.h" +#include "IO/Context/IoContext.h" +#include "IO/Networking/AsyncSocketAcceptor.h" -#include "MangosSocketMgr.h" -#include "ace/Singleton.h" -#include "ace/Thread_Mutex.h" class WorldSocket; -// Manages all sockets connected to peers and network threads -class WorldSocketMgr: public MangosSocketMgr +struct WorldSocketMgrOptions +{ + std::string bindIp; + uint16 bindPort; + int socketOutByteBufferSize; + bool doExplicitTcpNoDelay; + std::vector trustedProxyIps; +}; + +class WorldSocketMgr : public MaNGOS::Singleton> { - public: - friend class ACE_Singleton; - friend class WorldSocket; +public: + explicit WorldSocketMgr() = default; + + /// Will return true start was okay + bool StartWorldNetworking(IO::IoContext* ioCtx, WorldSocketMgrOptions const& options); + void StopWorldNetworking(); + void OnNewClientConnected(IO::Networking::SocketDescriptor socketDescriptor); + +private: + IO::IoContext* GetLeastUsedIoContext(); - static WorldSocketMgr* Instance(); + IO::IoContext* m_ioContext{nullptr}; + std::unique_ptr m_listener{nullptr}; + WorldSocketMgrOptions m_settings{}; }; -#define sWorldSocketMgr WorldSocketMgr::Instance() +#define sWorldSocketMgr MaNGOS::Singleton::Instance() -#endif -// @} +#endif //MANGOS_GAME_SERVER_WORLDSOCKETMGR_H diff --git a/src/game/ScriptMgr.cpp b/src/game/ScriptMgr.cpp index 3b91fc90437..4442ff9da54 100644 --- a/src/game/ScriptMgr.cpp +++ b/src/game/ScriptMgr.cpp @@ -2332,7 +2332,7 @@ void ScriptMgr::LoadEscortData() if (pResult) { - barGoLink bar(pResult->GetRowCount()); + BarGoLink bar(pResult->GetRowCount()); do { bar.step(); @@ -2377,7 +2377,7 @@ void ScriptMgr::LoadEscortData() } else { - barGoLink bar(1); + BarGoLink bar(1); bar.step(); sLog.Out(LOG_BASIC, LOG_LVL_MINIMAL, ""); sLog.Out(LOG_BASIC, LOG_LVL_MINIMAL, ">> Loaded 0 Escort Creature Data. DB table `script_escort_data` is empty."); diff --git a/src/game/SniffFile.cpp b/src/game/SniffFile.cpp index 0aed8772041..9e6688b2ca9 100644 --- a/src/game/SniffFile.cpp +++ b/src/game/SniffFile.cpp @@ -15,6 +15,7 @@ */ #include "SniffFile.h" +#include "Errors.h" #include "Log.h" SniffFile::SniffFile(FILE* pFile) : m_file(pFile) diff --git a/src/game/Spells/Spell.cpp b/src/game/Spells/Spell.cpp index bd0ffc60cf9..bf12287dcc8 100644 --- a/src/game/Spells/Spell.cpp +++ b/src/game/Spells/Spell.cpp @@ -6034,7 +6034,7 @@ SpellCastResult Spell::CheckCast(bool strict) // Swiftmend if (m_spellInfo->Id == 18562) // future versions have special aura state for this { - if (!target->GetAura(SPELL_AURA_PERIODIC_HEAL, SPELLFAMILY_DRUID, UI64LIT(0x50))) + if (!target->GetAura(SPELL_AURA_PERIODIC_HEAL, SPELLFAMILY_DRUID, uint64(0x50))) return SPELL_FAILED_TARGET_AURASTATE; } #endif diff --git a/src/game/Spells/SpellAuras.cpp b/src/game/Spells/SpellAuras.cpp index 7417d796cfa..2f48e9d0255 100644 --- a/src/game/Spells/SpellAuras.cpp +++ b/src/game/Spells/SpellAuras.cpp @@ -2240,7 +2240,7 @@ void Aura::HandleAuraDummy(bool apply, bool Real) { if (apply) { - SpellModifier* mod = new SpellModifier(SPELLMOD_RESIST_MISS_CHANCE, SPELLMOD_FLAT, m_modifier.m_amount, GetId(), UI64LIT(0x0000000000000100)); + SpellModifier* mod = new SpellModifier(SPELLMOD_RESIST_MISS_CHANCE, SPELLMOD_FLAT, m_modifier.m_amount, GetId(), uint64(0x0000000000000100)); pPlayer->AddSpellMod(mod, true); } else @@ -2259,7 +2259,7 @@ void Aura::HandleAuraDummy(bool apply, bool Real) { if (apply) { - SpellModifier *mod = new SpellModifier(SPELLMOD_RESIST_MISS_CHANCE, SPELLMOD_FLAT, m_modifier.m_amount, GetId(), UI64LIT(0x0000000000000008)); + SpellModifier *mod = new SpellModifier(SPELLMOD_RESIST_MISS_CHANCE, SPELLMOD_FLAT, m_modifier.m_amount, GetId(), uint64(0x0000000000000008)); pPlayer->AddSpellMod(mod, true); } else @@ -6949,7 +6949,7 @@ SpellAuraHolder::SpellAuraHolder(SpellEntry const* spellproto, Unit* target, Uni else { // remove this assert when not unit casters will be supported - MANGOS_ASSERT(caster->isType(TYPEMASK_UNIT)) + MANGOS_ASSERT(caster->isType(TYPEMASK_UNIT)); m_casterGuid = caster->GetObjectGuid(); } diff --git a/src/game/Spells/SpellEntry.cpp b/src/game/Spells/SpellEntry.cpp index eb82de95f5f..de02fde4f7a 100644 --- a/src/game/Spells/SpellEntry.cpp +++ b/src/game/Spells/SpellEntry.cpp @@ -62,7 +62,7 @@ SpellSpecific Spells::GetSpellSpecific(uint32 spellId) case SPELLFAMILY_MAGE: { // family flags 18(Molten), 25(Frost/Ice), 28(Mage) - if (spellInfo->SpellFamilyFlags & UI64LIT(0x12000000)) + if (spellInfo->SpellFamilyFlags & uint64(0x12000000)) return SPELL_MAGE_ARMOR; if (spellInfo->EffectApplyAuraName[EFFECT_INDEX_0] == SPELL_AURA_MOD_CONFUSE && spellInfo->PreventionType == SPELL_PREVENTION_TYPE_SILENCE) @@ -72,7 +72,7 @@ SpellSpecific Spells::GetSpellSpecific(uint32 spellId) } case SPELLFAMILY_WARRIOR: { - if (spellInfo->SpellFamilyFlags & UI64LIT(0x00008000010000)) + if (spellInfo->SpellFamilyFlags & uint64(0x00008000010000)) return SPELL_POSITIVE_SHOUT; break; @@ -110,10 +110,10 @@ SpellSpecific Spells::GetSpellSpecific(uint32 spellId) if (spellInfo->IsSealSpell()) return SPELL_SEAL; - if (spellInfo->IsFitToFamilyMask(UI64LIT(0x0000000010000100))) + if (spellInfo->IsFitToFamilyMask(uint64(0x0000000010000100))) return SPELL_BLESSING; - if ((spellInfo->IsFitToFamilyMask(UI64LIT(0x0000000020180400))) && spellInfo->baseLevel != 0) + if ((spellInfo->IsFitToFamilyMask(uint64(0x0000000020180400))) && spellInfo->baseLevel != 0) return SPELL_JUDGEMENT; // Old Judgement of Command @@ -640,7 +640,7 @@ float SpellEntry::CalculateCustomCoefficient(WorldObject const* caster, DamageEf case SPELLFAMILY_PALADIN: { // Seal of Righteousness - if (IsFitToFamilyMask(UI64LIT(0x0000000008000000)) && SpellIconID == 25) + if (IsFitToFamilyMask(uint64(0x0000000008000000)) && SpellIconID == 25) { coeff = 0.10f; float speed = BASE_ATTACK_TIME; @@ -670,7 +670,7 @@ float SpellEntry::CalculateCustomCoefficient(WorldObject const* caster, DamageEf return coeff; // Chain Lightning / Chain Heal / Healing Wave (T1 8/8 bonus) - if (IsFitToFamilyMask(UI64LIT(0x00000000142))) + if (IsFitToFamilyMask(uint64(0x00000000142))) { float multiplier = DmgMultiplier[0]; diff --git a/src/game/Spells/SpellMgr.cpp b/src/game/Spells/SpellMgr.cpp index a6bf5f86285..d79f2c1e30f 100644 --- a/src/game/Spells/SpellMgr.cpp +++ b/src/game/Spells/SpellMgr.cpp @@ -999,7 +999,7 @@ bool SpellMgr::IsNoStackSpellDueToSpell(uint32 spellId_1, uint32 spellId_2) cons return false; // Improved Hamstring -> Hamstring (multi-family check) - if ((spellInfo_2->SpellFamilyFlags & UI64LIT(0x2)) && spellInfo_1->Id == 23694) + if ((spellInfo_2->SpellFamilyFlags & uint64(0x2)) && spellInfo_1->Id == 23694) return false; break; } @@ -1025,7 +1025,7 @@ bool SpellMgr::IsNoStackSpellDueToSpell(uint32 spellId_1, uint32 spellId_2) cons return false; // Improved Wing Clip -> Wing Clip (multi-family check) - if ((spellInfo_2->SpellFamilyFlags & UI64LIT(0x40)) && spellInfo_1->Id == 19229) + if ((spellInfo_2->SpellFamilyFlags & uint64(0x40)) && spellInfo_1->Id == 19229) return false; break; } @@ -1047,18 +1047,18 @@ bool SpellMgr::IsNoStackSpellDueToSpell(uint32 spellId_1, uint32 spellId_2) cons return false; // Blizzard & Chilled (and some other stacked with blizzard spells - if (((spellInfo_1->SpellFamilyFlags & UI64LIT(0x80)) && (spellInfo_2->SpellFamilyFlags & UI64LIT(0x100000))) || - ((spellInfo_2->SpellFamilyFlags & UI64LIT(0x80)) && (spellInfo_1->SpellFamilyFlags & UI64LIT(0x100000)))) + if (((spellInfo_1->SpellFamilyFlags & uint64(0x80)) && (spellInfo_2->SpellFamilyFlags & uint64(0x100000))) || + ((spellInfo_2->SpellFamilyFlags & uint64(0x80)) && (spellInfo_1->SpellFamilyFlags & uint64(0x100000)))) return false; // Blink & Improved Blink - if (((spellInfo_1->SpellFamilyFlags & UI64LIT(0x0000000000010000)) && (spellInfo_2->SpellVisual == 72 && spellInfo_2->SpellIconID == 1499)) || - ((spellInfo_2->SpellFamilyFlags & UI64LIT(0x0000000000010000)) && (spellInfo_1->SpellVisual == 72 && spellInfo_1->SpellIconID == 1499))) + if (((spellInfo_1->SpellFamilyFlags & uint64(0x0000000000010000)) && (spellInfo_2->SpellVisual == 72 && spellInfo_2->SpellIconID == 1499)) || + ((spellInfo_2->SpellFamilyFlags & uint64(0x0000000000010000)) && (spellInfo_1->SpellVisual == 72 && spellInfo_1->SpellIconID == 1499))) return false; // Fireball & Pyroblast (Dots) - if (((spellInfo_1->SpellFamilyFlags & UI64LIT(0x1)) && (spellInfo_2->SpellFamilyFlags & UI64LIT(0x400000))) || - ((spellInfo_2->SpellFamilyFlags & UI64LIT(0x1)) && (spellInfo_1->SpellFamilyFlags & UI64LIT(0x400000)))) + if (((spellInfo_1->SpellFamilyFlags & uint64(0x1)) && (spellInfo_2->SpellFamilyFlags & uint64(0x400000))) || + ((spellInfo_2->SpellFamilyFlags & uint64(0x1)) && (spellInfo_1->SpellFamilyFlags & uint64(0x400000)))) return false; // Arcane Missiles @@ -1099,8 +1099,8 @@ bool SpellMgr::IsNoStackSpellDueToSpell(uint32 spellId_1, uint32 spellId_2) cons if (spellInfo_2->SpellFamilyName == SPELLFAMILY_WARRIOR) { // Rend and Deep Wound - if (((spellInfo_1->SpellFamilyFlags & UI64LIT(0x20)) && (spellInfo_2->SpellFamilyFlags & UI64LIT(0x1000000000))) || - ((spellInfo_2->SpellFamilyFlags & UI64LIT(0x20)) && (spellInfo_1->SpellFamilyFlags & UI64LIT(0x1000000000)))) + if (((spellInfo_1->SpellFamilyFlags & uint64(0x20)) && (spellInfo_2->SpellFamilyFlags & uint64(0x1000000000))) || + ((spellInfo_2->SpellFamilyFlags & uint64(0x20)) && (spellInfo_1->SpellFamilyFlags & uint64(0x1000000000)))) return false; // Battle Shout and Rampage @@ -1114,7 +1114,7 @@ bool SpellMgr::IsNoStackSpellDueToSpell(uint32 spellId_1, uint32 spellId_2) cons } // Hamstring -> Improved Hamstring (multi-family check) - if ((spellInfo_1->SpellFamilyFlags & UI64LIT(0x2)) && spellInfo_2->Id == 23694) + if ((spellInfo_1->SpellFamilyFlags & uint64(0x2)) && spellInfo_2->Id == 23694) return false; // Defensive Stance and Scroll of Protection (multi-family check) @@ -1140,13 +1140,13 @@ bool SpellMgr::IsNoStackSpellDueToSpell(uint32 spellId_1, uint32 spellId_2) cons return false; //Devouring Plague and Shadow Vulnerability - if (((spellInfo_1->SpellFamilyFlags & UI64LIT(0x2000000)) && (spellInfo_2->SpellFamilyFlags & UI64LIT(0x800000000))) || - ((spellInfo_2->SpellFamilyFlags & UI64LIT(0x2000000)) && (spellInfo_1->SpellFamilyFlags & UI64LIT(0x800000000)))) + if (((spellInfo_1->SpellFamilyFlags & uint64(0x2000000)) && (spellInfo_2->SpellFamilyFlags & uint64(0x800000000))) || + ((spellInfo_2->SpellFamilyFlags & uint64(0x2000000)) && (spellInfo_1->SpellFamilyFlags & uint64(0x800000000)))) return false; //StarShards and Shadow Word: Pain - if (((spellInfo_1->SpellFamilyFlags & UI64LIT(0x200000)) && (spellInfo_2->SpellFamilyFlags & UI64LIT(0x8000))) || - ((spellInfo_2->SpellFamilyFlags & UI64LIT(0x200000)) && (spellInfo_1->SpellFamilyFlags & UI64LIT(0x8000)))) + if (((spellInfo_1->SpellFamilyFlags & uint64(0x200000)) && (spellInfo_2->SpellFamilyFlags & uint64(0x8000))) || + ((spellInfo_2->SpellFamilyFlags & uint64(0x200000)) && (spellInfo_1->SpellFamilyFlags & uint64(0x8000)))) return false; } break; @@ -1158,8 +1158,8 @@ bool SpellMgr::IsNoStackSpellDueToSpell(uint32 spellId_1, uint32 spellId_2) cons return false; //Omen of Clarity and Blood Frenzy - if (((spellInfo_1->SpellFamilyFlags == UI64LIT(0x0) && spellInfo_1->SpellIconID == 108) && (spellInfo_2->SpellFamilyFlags & UI64LIT(0x20000000000000))) || - ((spellInfo_2->SpellFamilyFlags == UI64LIT(0x0) && spellInfo_2->SpellIconID == 108) && (spellInfo_1->SpellFamilyFlags & UI64LIT(0x20000000000000)))) + if (((spellInfo_1->SpellFamilyFlags == uint64(0x0) && spellInfo_1->SpellIconID == 108) && (spellInfo_2->SpellFamilyFlags & uint64(0x20000000000000))) || + ((spellInfo_2->SpellFamilyFlags == uint64(0x0) && spellInfo_2->SpellIconID == 108) && (spellInfo_1->SpellFamilyFlags & uint64(0x20000000000000)))) return false; } @@ -1177,13 +1177,13 @@ bool SpellMgr::IsNoStackSpellDueToSpell(uint32 spellId_1, uint32 spellId_2) cons if (spellInfo_2->SpellFamilyName == SPELLFAMILY_HUNTER) { // Rapid Fire & Quick Shots - if (((spellInfo_1->SpellFamilyFlags & UI64LIT(0x20)) && (spellInfo_2->SpellFamilyFlags & UI64LIT(0x20000000000))) || - ((spellInfo_2->SpellFamilyFlags & UI64LIT(0x20)) && (spellInfo_1->SpellFamilyFlags & UI64LIT(0x20000000000)))) + if (((spellInfo_1->SpellFamilyFlags & uint64(0x20)) && (spellInfo_2->SpellFamilyFlags & uint64(0x20000000000))) || + ((spellInfo_2->SpellFamilyFlags & uint64(0x20)) && (spellInfo_1->SpellFamilyFlags & uint64(0x20000000000)))) return false; // Serpent Sting & (Immolation/Explosive Trap Effect) - if (((spellInfo_1->SpellFamilyFlags & UI64LIT(0x4)) && (spellInfo_2->SpellFamilyFlags & UI64LIT(0x00000004000))) || - ((spellInfo_2->SpellFamilyFlags & UI64LIT(0x4)) && (spellInfo_1->SpellFamilyFlags & UI64LIT(0x00000004000)))) + if (((spellInfo_1->SpellFamilyFlags & uint64(0x4)) && (spellInfo_2->SpellFamilyFlags & uint64(0x00000004000))) || + ((spellInfo_2->SpellFamilyFlags & uint64(0x4)) && (spellInfo_1->SpellFamilyFlags & uint64(0x00000004000)))) return false; // Bestial Wrath @@ -1192,7 +1192,7 @@ bool SpellMgr::IsNoStackSpellDueToSpell(uint32 spellId_1, uint32 spellId_2) cons } // Wing Clip -> Improved Wing Clip (multi-family check) - if ((spellInfo_1->SpellFamilyFlags & UI64LIT(0x40)) && spellInfo_2->Id == 19229) + if ((spellInfo_1->SpellFamilyFlags & uint64(0x40)) && spellInfo_2->Id == 19229) return false; // Concussive Shot and Imp. Concussive Shot (multi-family check) @@ -2843,9 +2843,9 @@ void SpellMgr::CheckUsedSpells(char const* table) continue; } - if (familyMask != UI64LIT(0xFFFFFFFFFFFFFFFF)) + if (familyMask != uint64(0xFFFFFFFFFFFFFFFF)) { - if (familyMask == UI64LIT(0x0000000000000000)) + if (familyMask == uint64(0x0000000000000000)) { if (spellEntry->SpellFamilyFlags) { @@ -2928,9 +2928,9 @@ void SpellMgr::CheckUsedSpells(char const* table) if (family >= 0 && spellEntry->SpellFamilyName != uint32(family)) continue; - if (familyMask != UI64LIT(0xFFFFFFFFFFFFFFFF)) + if (familyMask != uint64(0xFFFFFFFFFFFFFFFF)) { - if (familyMask == UI64LIT(0x0000000000000000)) + if (familyMask == uint64(0x0000000000000000)) { if (spellEntry->SpellFamilyFlags) continue; diff --git a/src/game/Spells/SpellMgr.h b/src/game/Spells/SpellMgr.h index dc437d0d83d..15c0c72c5dc 100644 --- a/src/game/Spells/SpellMgr.h +++ b/src/game/Spells/SpellMgr.h @@ -31,6 +31,7 @@ #include "DBCStores.h" #include "SQLStorages.h" #include "SpellEntry.h" +#include "Errors.h" #include #include diff --git a/src/game/World.cpp b/src/game/World.cpp index 9e0a23108e6..98d71aa5b6e 100644 --- a/src/game/World.cpp +++ b/src/game/World.cpp @@ -82,6 +82,8 @@ #include "GuardMgr.h" #include "TransportMgr.h" #include "RealmZone.h" +#include "IO/Multithreading/CreateThread.h" + #include INSTANTIATE_SINGLETON_1(World); @@ -210,10 +212,10 @@ void World::Shutdown() sAnticheatMgr->StopWardenUpdateThread(); } -// Find a session by its id -WorldSession* World::FindSession(uint32 id) const +/// Find a session by its accountId. Might return nullptr if not found. +WorldSession* World::FindSession(uint32 accountId) const { - SessionMap::const_iterator itr = m_sessions.find(id); + SessionMap::const_iterator itr = m_sessions.find(accountId); if (itr != m_sessions.end()) return itr->second; // also can return nullptr for kicked session @@ -221,11 +223,11 @@ WorldSession* World::FindSession(uint32 id) const return nullptr; } -// Remove a given session -bool World::RemoveSession(uint32 id) +/// Remove a given session by its accountId +bool World::RemoveSession(uint32 accountId) { // Find the session, kick the user, but we can't delete session at this moment to prevent iterator invalidation - SessionMap::const_iterator itr = m_sessions.find(id); + SessionMap::const_iterator itr = m_sessions.find(accountId); if (itr != m_sessions.end() && itr->second) { @@ -321,7 +323,7 @@ void World::AddSession_(WorldSession* s) packet << uint8(AUTH_OK); packet << uint32(0); // BillingTimeRemaining // BillingPlanFlags - packet << uint8(s->HasTrialRestrictions() ? (BILLING_FLAG_TRIAL | BILLING_FLAG_RESTRICTED) : BILLING_FLAG_NONE); + packet << uint8(s->HasTrialRestrictions() ? (BILLING_FLAG_TRIAL | BILLING_FLAG_RESTRICTED) : BILLING_FLAG_NONE); packet << uint32(0); // BillingTimeRested s->SendPacket(&packet); @@ -368,7 +370,7 @@ void World::AddQueuedSession(WorldSession* sess) packet << uint8(AUTH_WAIT_QUEUE); packet << uint32(0); // BillingTimeRemaining // BillingPlanFlags - packet << uint8(sess->HasTrialRestrictions() ? (BILLING_FLAG_TRIAL | BILLING_FLAG_RESTRICTED) : BILLING_FLAG_NONE); + packet << uint8(sess->HasTrialRestrictions() ? (BILLING_FLAG_TRIAL | BILLING_FLAG_RESTRICTED) : BILLING_FLAG_NONE); packet << uint32(0); // BillingTimeRested packet << uint32(GetQueuedSessionPos(sess)); // position in queue sess->SendPacket(&packet); @@ -408,7 +410,7 @@ bool World::RemoveQueuedSession(WorldSession* sess) uint32 loggedInSessions = uint32(m_sessions.size() - m_QueuedSessions.size()); if (loggedInSessions > getConfig(CONFIG_UINT32_PLAYER_HARD_LIMIT)) return found; - + // accept first in queue if ((!m_playerLimit || (int32)sessions <= m_playerLimit) && !m_QueuedSessions.empty()) { @@ -1065,7 +1067,7 @@ void World::LoadConfigSettings(bool reload) setConfig(CONFIG_UINT32_PACKET_BCAST_THREADS, "Network.PacketBroadcast.Threads", 0); setConfig(CONFIG_UINT32_PACKET_BCAST_FREQUENCY, "Network.PacketBroadcast.Frequency", 50); setConfig(CONFIG_UINT32_PBCAST_DIFF_LOWER_VISIBILITY_DISTANCE, "Network.PacketBroadcast.ReduceVisDistance.DiffAbove", 0); - + sLog.Out(LOG_BASIC, LOG_LVL_MINIMAL, "* Anticrash : options 0x%x rearm after %usec", getConfig(CONFIG_UINT32_ANTICRASH_OPTIONS), getConfig(CONFIG_UINT32_ANTICRASH_REARM_TIMER) / 1000); sLog.Out(LOG_BASIC, LOG_LVL_MINIMAL, "* Pathfinding : [%s]", getConfig(CONFIG_BOOL_MMAP_ENABLED) ? "ON" : "OFF"); @@ -1866,10 +1868,7 @@ void World::SetInitialWorldSettings() if (GetWowPatch() >= WOW_PATCH_103 || !getConfig(CONFIG_BOOL_ACCURATE_LFG)) { - m_lfgQueueThread.reset(new std::thread([&]() - { - m_lfgQueue.Update(); - })); + m_lfgQueueThread = IO::Multithreading::CreateThreadPtr("LfgUpdate", [&] { m_lfgQueue.Update(); }); } sAnticheatMgr->StartWardenUpdateThread(); @@ -1878,8 +1877,8 @@ void World::SetInitialWorldSettings() std::make_unique(getConfig(CONFIG_UINT32_PACKET_BCAST_THREADS), std::chrono::milliseconds(getConfig(CONFIG_UINT32_PACKET_BCAST_FREQUENCY))); - m_charDbWorkerThread.reset(new std::thread(&CharactersDatabaseWorkerThread)); - m_asyncPacketsThread.reset(new std::thread(&World::ProcessAsyncPackets, this)); + m_charDbWorkerThread = IO::Multithreading::CreateThreadPtr("CharDB", [](){ CharactersDatabaseWorkerThread(); }); + m_asyncPacketsThread = IO::Multithreading::CreateThreadPtr("AsyncPacket", [this](){ ProcessAsyncPackets(); }); sLog.Out(LOG_BASIC, LOG_LVL_MINIMAL, ""); sLog.Out(LOG_BASIC, LOG_LVL_MINIMAL, "=========================================================="); @@ -1972,7 +1971,7 @@ void World::Update(uint32 diff) m_currentMSTime = WorldTimer::getMSTime(); m_currentTime = std::chrono::time_point_cast(Clock::now()); m_currentDiff = diff; - + // Update the different timers for (auto& timer : m_timers) { @@ -2027,14 +2026,14 @@ void World::Update(uint32 diff) // Update objects (maps, transport, creatures,...) uint32 updateMapSystemTime = WorldTimer::getMSTime(); - + // TODO: find a better place for this if (!m_updateThreads) { - m_updateThreads = std::unique_ptr( new ThreadPool( + m_updateThreads = std::unique_ptr(new ThreadPool( + "WorldUpdate", getConfig(CONFIG_UINT32_ASYNC_TASKS_THREADS_COUNT), - ThreadPool::ClearMode::UPPON_COMPLETION) - ); + ThreadPool::ClearMode::UPPON_COMPLETION)); m_updateThreads->start>(); } std::unique_lock lock(m_asyncTaskQueueMutex); @@ -2042,7 +2041,7 @@ void World::Update(uint32 diff) std::future job = m_updateThreads->processWorkload(_asyncTasksBusy); _asyncTasks.clear(); lock.unlock(); - + sMapMgr.Update(diff); sBattleGroundMgr.Update(diff); sGuardMgr.Update(diff); @@ -2492,7 +2491,7 @@ class BanQueryHolder : public SqlQueryHolder public: BanQueryHolder(BanMode mode, std::string banTarget, uint32 duration, std::string reason, uint32 realmId, std::string author, uint32 authorAccountId) - : m_mode(mode), m_duration(duration), m_reason(reason), m_realmId(realmId), + : m_mode(mode), m_duration(duration), m_reason(reason), m_realmId(realmId), m_author(author), m_banTarget(banTarget), m_accountId(authorAccountId) { } @@ -2828,7 +2827,7 @@ void World::UpdateSessions(uint32 diff) { if (pSession->PlayerLoading()) sLog.Out(LOG_BASIC, LOG_LVL_MINIMAL, "[CRASH] World::UpdateSession attempt to delete session %u loading a player.", pSession->GetAccountId()); - + AccountPlayHistory& history = m_accountsPlayHistory[pSession->GetAccountId()]; if (!RemoveQueuedSession(pSession)) history.logoutTime = timeNow; diff --git a/src/game/World.h b/src/game/World.h index 5941df91fd3..50209fbbd5f 100644 --- a/src/game/World.h +++ b/src/game/World.h @@ -35,6 +35,7 @@ #include "WorldPacket.h" #include "Multithreading/Messager.h" #include "LFGQueue.h" +#include "LockedQueue.h" #include #include @@ -707,9 +708,9 @@ class World typedef std::unordered_map SessionMap; typedef std::set SessionSet; SessionMap GetAllSessions() { return m_sessions; } - WorldSession* FindSession(uint32 id) const; - void AddSession(WorldSession* s); - bool RemoveSession(uint32 id); + WorldSession* FindSession(uint32 accountId) const; + void AddSession(WorldSession* session); + bool RemoveSession(uint32 accountId); // Get the number of current active sessions void UpdateMaxSessionCounters(); uint32 GetActiveAndQueuedSessionCount() const { return m_sessions.size(); } @@ -956,7 +957,7 @@ class World int32 m_timeZoneOffset; IntervalTimer m_timers[WUPDATE_COUNT]; - SessionMap m_sessions; + SessionMap m_sessions; // Sessions by accountId SessionSet m_disconnectedSessions; std::map m_accountsPlayHistory; bool CanSkipQueue(WorldSession const* session); diff --git a/src/game/WorldSession.cpp b/src/game/WorldSession.cpp index 0c89ab9e6db..956c15b5e51 100644 --- a/src/game/WorldSession.cpp +++ b/src/game/WorldSession.cpp @@ -23,7 +23,7 @@ \ingroup u2w */ -#include "WorldSocket.h" // must be first to make ACE happy with ACE includes in it +#include "WorldSocket.h" #include "Common.h" #include "Database/DatabaseEnv.h" #include "Log.h" @@ -70,22 +70,17 @@ bool MapSessionFilter::Process(std::unique_ptr const& packet) static uint32 g_sessionCounter = 0; // WorldSession constructor -WorldSession::WorldSession(uint32 id, WorldSocket *sock, AccountTypes sec, time_t mute_time, LocaleConstant locale) : +WorldSession::WorldSession(uint32 id, std::shared_ptr sock, AccountTypes sec, time_t mute_time, LocaleConstant locale) : m_guid(g_sessionCounter++), m_muteTime(mute_time), m_connected(true), m_disconnectTimer(0), m_who_recvd(false), m_ah_list_recvd(false), m_accountFlags(0), m_idleTime(WorldTimer::getMSTime()), _player(nullptr), m_socket(sock), m_security(sec), m_accountId(id), m_exhaustionState(0), m_createTime(time(nullptr)), m_previousPlayTime(0), m_logoutTime(0), m_inQueue(false), m_playerLoading(false), m_playerLogout(false), m_playerRecentlyLogout(false), m_playerSave(false), m_sessionDbcLocale(sWorld.GetAvailableDbcLocale(locale)), m_sessionDbLocaleIndex(sObjectMgr.GetIndexForLocale(locale)), m_latency(0), m_tutorialState(TUTORIALDATA_UNCHANGED), m_warden(nullptr), m_cheatData(nullptr), m_bot(nullptr), m_clientOS(CLIENT_OS_UNKNOWN), m_clientPlatform(CLIENT_PLATFORM_UNKNOWN), m_gameBuild(0), m_verifiedEmail(true), - m_charactersCount(10), m_characterMaxLevel(0), m_lastPubChannelMsgTime(0), m_moveRejectTime(0), m_masterPlayer(nullptr) + m_charactersCount(10), m_characterMaxLevel(0), m_lastPubChannelMsgTime(0), m_moveRejectTime(0), m_masterPlayer(nullptr), m_receivedPacketType{}, + m_floodPacketsCount{}, m_tutorials{} { - if (sock) - { - m_address = sock->GetRemoteAddress(); - sock->AddReference(); - } - else - m_address = ""; + m_remoteIpAddress = sock ? sock->GetRemoteIpString() : ""; } // WorldSession destructor @@ -98,9 +93,8 @@ WorldSession::~WorldSession() // If have unclosed socket, close it if (m_socket) { - m_socket->CloseSocket(); - m_socket->RemoveReference(); - m_socket = nullptr; + m_socket->FinalizeSession(); + m_socket = nullptr; // <-- technically this is unnecessary, since we are in the destructor that will destruct all other members soon anyway } // empty incoming packet queue @@ -176,14 +170,13 @@ void WorldSession::SendPacketImpl(WorldPacket const* packet) sendLastPacketBytes = packet->wpos(); // wpos is real written size } -#endif // !_DEBUG +#endif // _DEBUG // sLog.Out(LOG_BASIC, LOG_LVL_MINIMAL, "[%s]Send packet : %u|0x%x (%s)", GetPlayerName(), packet->GetOpcode(), packet->GetOpcode(), LookupOpcodeName(packet->GetOpcode())); if (m_sniffFile) m_sniffFile->WritePacket(*packet, false, time(nullptr)); - if (m_socket->SendPacket(*packet) == -1) - m_socket->CloseSocket(); + m_socket->SendPacket(*packet); } #if SUPPORTED_CLIENT_BUILD > CLIENT_BUILD_1_7_1 @@ -432,10 +425,10 @@ bool WorldSession::Update(PacketFilter& updater) return false; } - // Cleanup socket pointer if need - if (m_socket && m_socket->IsClosed()) + // Cleanup socket pointer if needed + if (m_socket && m_socket->IsClosing()) { - m_socket->RemoveReference(); + m_socket->FinalizeSession(); m_socket = nullptr; if (m_warden) @@ -492,7 +485,7 @@ bool WorldSession::Update(PacketFilter& updater) bool WorldSession::CanProcessPackets() const { - return ((m_socket && !m_socket->IsClosed()) || (_player && (m_bot || sPlayerBotMgr.IsChatBot(_player->GetGUIDLow())))); + return ((m_socket && !m_socket->IsClosing()) || (_player && (m_bot || sPlayerBotMgr.IsChatBot(_player->GetGUIDLow())))); } void WorldSession::ProcessPackets(PacketFilter& updater) diff --git a/src/game/WorldSession.h b/src/game/WorldSession.h index c7669aab5d5..e42a0dcf9f5 100644 --- a/src/game/WorldSession.h +++ b/src/game/WorldSession.h @@ -33,6 +33,7 @@ #include "ClientDefines.h" #include "Crypto/BigNumber.h" #include "AccountData.h" +#include "LockedQueue.h" struct ItemPrototype; struct AuctionEntry; @@ -278,7 +279,7 @@ class WorldSession { friend class CharacterHandler; public: - WorldSession(uint32 id, WorldSocket *sock, AccountTypes sec, time_t mute_time, LocaleConstant locale); + WorldSession(uint32 id, std::shared_ptr sock, AccountTypes sec, time_t mute_time, LocaleConstant locale); ~WorldSession(); uint32 GetGUID() const { return m_guid; } @@ -304,17 +305,19 @@ class WorldSession Player* GetPlayer() const { return _player; } char const* GetPlayerName() const; void SetSecurity(AccountTypes security) { m_security = security; } - std::string const& GetRemoteAddress() const { return m_address; } + /// Might return "" if player bot + /// TODO rename me to GetRemoteIpString() when all the commits for native branch are done (otherwise too many files will be touched) + std::string const& GetRemoteAddress() const { return m_remoteIpAddress; } void SetPlayer(Player* plr) { _player = plr; } void SetMasterPlayer(MasterPlayer* plr) { m_masterPlayer = plr; } void LoginPlayer(ObjectGuid playerGuid); - WorldSocket* GetSocket() { return m_socket; } + std::shared_ptr GetSocket() { return m_socket; } // Session in auth.queue currently void SetInQueue(bool state) { m_inQueue = state; } // Player online / socket offline system - void SetDisconnectedSession(); // Remove from World::m_session. Used when an account gets disconnected. + void SetDisconnectedSession(); // Remove from World::m_Session. Used when an account gets disconnected. bool UpdateDisconnected(uint32 diff); bool IsConnected() const { return m_connected; } void KickDisconnectedFromWorld() { m_disconnectTimer = 0; } @@ -883,8 +886,8 @@ class WorldSession void LogUnprocessedTail(WorldPacket* packet); uint32 const m_guid; // unique identifier for each session - WorldSocket* m_socket; - std::string m_address; + std::shared_ptr m_socket; + std::string m_remoteIpAddress; // might also be "" LockedQueue, std::mutex> m_recvQueue[PACKET_PROCESS_MAX_TYPE]; bool m_receivedPacketType[PACKET_PROCESS_MAX_TYPE]; uint32 m_floodPacketsCount[FLOOD_MAX_OPCODES_TYPE]; diff --git a/src/game/pchdef.h b/src/game/pchdef.h index a272a801c87..e995e828cdd 100644 --- a/src/game/pchdef.h +++ b/src/game/pchdef.h @@ -1,5 +1,5 @@ //add here most rarely modified headers to speed up debug build compilation -#include "WorldSocket.h" // must be first to make ACE happy with ACE includes in it +#include "WorldSocket.h" #include "Common.h" #include "MapManager.h" @@ -11,4 +11,4 @@ #include "SharedDefines.h" #include "GuildMgr.h" #include "ObjectMgr.h" -#include "ScriptMgr.h" \ No newline at end of file +#include "ScriptMgr.h" diff --git a/src/game/vmap/DynamicTree.cpp b/src/game/vmap/DynamicTree.cpp index 5065416d320..65d38826aa6 100644 --- a/src/game/vmap/DynamicTree.cpp +++ b/src/game/vmap/DynamicTree.cpp @@ -22,6 +22,7 @@ #include "BIHWrap.h" #include "RegularGrid.h" #include "GameObjectModel.h" +#include "Errors.h" template<> struct HashTrait< GameObjectModel> { diff --git a/src/game/vmap/MapTree.cpp b/src/game/vmap/MapTree.cpp index 4f14618d932..c9132f80f7f 100644 --- a/src/game/vmap/MapTree.cpp +++ b/src/game/vmap/MapTree.cpp @@ -21,6 +21,7 @@ #include "VMapManager2.h" #include "VMapDefinitions.h" #include "WorldModel.h" +#include "Errors.h" #include #include diff --git a/src/game/vmap/RegularGrid.h b/src/game/vmap/RegularGrid.h index e438b6f947a..4bd91905559 100644 --- a/src/game/vmap/RegularGrid.h +++ b/src/game/vmap/RegularGrid.h @@ -24,6 +24,7 @@ #include #include #include +#include "Errors.h" using G3D::Vector2; using G3D::Vector3; diff --git a/src/mangosd/CMakeLists.txt b/src/mangosd/CMakeLists.txt index 5e3d398d614..afb9369d930 100644 --- a/src/mangosd/CMakeLists.txt +++ b/src/mangosd/CMakeLists.txt @@ -18,16 +18,16 @@ set(EXECUTABLE_NAME mangosd) set(EXECUTABLE_SRCS CliRunnable.h - MaNGOSsoap.h Master.h - RASocket.h WorldRunnable.h CliRunnable.cpp Main.cpp - MaNGOSsoap.cpp Master.cpp - RASocket.cpp WorldRunnable.cpp + remote/RemoteAccess/RASocket.h + remote/RemoteAccess/RASocket.cpp + remote/soap/MaNGOSsoap.h + remote/soap/MaNGOSsoap.cpp ) @@ -68,7 +68,6 @@ include_directories( ${CMAKE_SOURCE_DIR}/src/game/Threat ${CMAKE_SOURCE_DIR}/src/game/Transports ${CMAKE_SOURCE_DIR}/src/game/vmap - ${ACE_INCLUDE_DIR} ${MYSQL_INCLUDE_DIR} ${OPENSSL_INCLUDE_DIR} ) @@ -98,7 +97,6 @@ if(USE_SCRIPTS) framework g3dlite gsoap - ${ACE_LIBRARIES} ) else() target_link_libraries(${EXECUTABLE_NAME} @@ -108,7 +106,6 @@ else() framework g3dlite gsoap - ${ACE_LIBRARIES} ) endif() diff --git a/src/mangosd/CliRunnable.cpp b/src/mangosd/CliRunnable.cpp index 0a83cddd5cb..2234c22693e 100644 --- a/src/mangosd/CliRunnable.cpp +++ b/src/mangosd/CliRunnable.cpp @@ -30,6 +30,10 @@ #include "CliRunnable.h" #include "Database/DatabaseEnv.h" +#if PLATFORM == PLATFORM_APPLE +#include +#endif + void utf8print(void* /*arg*/, const char* str) { #if PLATFORM == PLATFORM_WINDOWS diff --git a/src/mangosd/MaNGOSsoap.cpp b/src/mangosd/MaNGOSsoap.cpp deleted file mode 100644 index 7a3739e24f3..00000000000 --- a/src/mangosd/MaNGOSsoap.cpp +++ /dev/null @@ -1,175 +0,0 @@ -/* - * Copyright (C) 2005-2011 MaNGOS - * Copyright (C) 2009-2011 MaNGOSZero - * Copyright (C) 2011-2016 Nostalrius - * Copyright (C) 2016-2017 Elysium Project - * - * This program is free software; you can redistribute it and/or modify - * it under the terms of the GNU General Public License as published by - * the Free Software Foundation; either version 2 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU General Public License for more details. - * - * You should have received a copy of the GNU General Public License - * along with this program; if not, write to the Free Software - * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA - */ - -#include "MaNGOSsoap.h" - -#define POOL_SIZE 5 - -void MaNGOSsoapRunnable::run() -{ - // create pool - SOAPWorkingThread pool; - pool.activate (THR_NEW_LWP | THR_JOINABLE, POOL_SIZE); - - struct soap soap; - int m, s; - soap_init(&soap); - soap_set_imode(&soap, SOAP_C_UTFSTRING); - soap_set_omode(&soap, SOAP_C_UTFSTRING); - m = soap_bind(&soap, m_host.c_str(), m_port, 100); - - // check every 3 seconds if world ended - soap.accept_timeout = 3; - - soap.recv_timeout = 5; - soap.send_timeout = 5; - if (m < 0) - { - sLog.Out(LOG_BASIC, LOG_LVL_ERROR, "MaNGOSsoap: couldn't bind to %s:%d", m_host.c_str(), m_port); - exit(-1); - } - - sLog.Out(LOG_BASIC, LOG_LVL_MINIMAL, "MaNGOSsoap: bound to http://%s:%d", m_host.c_str(), m_port); - - while(!World::IsStopped()) - { - s = soap_accept(&soap); - - if (s < 0) - { - // ran into an accept timeout - continue; - } - - sLog.Out(LOG_BASIC, LOG_LVL_DEBUG, "MaNGOSsoap: accepted connection from IP=%d.%d.%d.%d", (int)(soap.ip>>24)&0xFF, (int)(soap.ip>>16)&0xFF, (int)(soap.ip>>8)&0xFF, (int)soap.ip&0xFF); - struct soap* thread_soap = soap_copy(&soap);// make a safe copy - - ACE_Message_Block *mb = new ACE_Message_Block(sizeof(struct soap*)); - ACE_OS::memcpy(mb->wr_ptr (), &thread_soap, sizeof(struct soap*)); - pool.putq(mb); - } - pool.msg_queue ()->deactivate (); - pool.wait (); - - soap_done(&soap); -} - -void SOAPWorkingThread::process_message (ACE_Message_Block *mb) -{ - ACE_TRACE (ACE_TEXT ("SOAPWorkingThread::process_message")); - - soap* thread_soap; // local copy of the soap instance for handling this message - ACE_OS::memcpy(&thread_soap, mb->rd_ptr(), sizeof(soap*)); - mb->release(); - - soap_serve(thread_soap); // handle request - - soap_free(thread_soap); // calls done() and clear/frees everything else -} - -/* -Code used for generating stubs: - -int ns1__executeCommand(char* command, char** result); -*/ -int ns1__executeCommand(soap* soap, char* command, char** result) -{ - // security check - if (!soap->userid || !soap->passwd) - { - sLog.Out(LOG_BASIC, LOG_LVL_DEBUG, "MaNGOSsoap: Client didn't provide login information"); - return 401; - } - - uint32 accountId = sAccountMgr.GetId(soap->userid); - if(!accountId) - { - sLog.Out(LOG_BASIC, LOG_LVL_DEBUG, "MaNGOSsoap: Client used invalid username '%s'", soap->userid); - return 401; - } - - if(!sAccountMgr.CheckPassword(accountId, soap->passwd)) - { - sLog.Out(LOG_BASIC, LOG_LVL_DEBUG, "MaNGOSsoap: invalid password for account '%s'", soap->userid); - return 401; - } - - if(sAccountMgr.GetSecurity(accountId) < SEC_ADMINISTRATOR) - { - sLog.Out(LOG_BASIC, LOG_LVL_DEBUG, "MaNGOSsoap: %s's gmlevel is too low", soap->userid); - return 403; - } - - if(!command || !*command) - return soap_sender_fault(soap, "Command mustn't be empty", "The supplied command was an empty string"); - - sLog.Out(LOG_BASIC, LOG_LVL_DEBUG, "MaNGOSsoap: got command '%s'", command); - SOAPCommand connection; - - // commands are executed in the world thread. We have to wait for them to be completed - { - // CliCommandHolder will be deleted from world, accessing after queueing is NOT save - CliCommandHolder* cmd = new CliCommandHolder(accountId, SEC_CONSOLE, &connection, command, &SOAPCommand::print, &SOAPCommand::commandFinished); - sWorld.QueueCliCommand(cmd); - } - - // wait for callback to complete command - - int acc = connection.pendingCommands.acquire(); - if(acc) - { - sLog.Out(LOG_BASIC, LOG_LVL_ERROR, "MaNGOSsoap: Error while acquiring lock, acc = %i, errno = %u", acc, errno); - } - - // alright, command finished - - char* printBuffer = soap_strdup(soap, connection.m_printBuffer.c_str()); - if(connection.hasCommandSucceeded()) - { - *result = printBuffer; - return SOAP_OK; - } - else - return soap_sender_fault(soap, printBuffer, printBuffer); -} - - -void SOAPCommand::commandFinished(void* soapconnection, bool success) -{ - SOAPCommand* con = (SOAPCommand*)soapconnection; - con->setCommandSuccess(success); - con->pendingCommands.release(); -} - -//////////////////////////////////////////////////////////////////////////////// -// -// Namespace Definition Table -// -//////////////////////////////////////////////////////////////////////////////// - -struct Namespace namespaces[] = -{ { "SOAP-ENV", "http://schemas.xmlsoap.org/soap/envelope/" }, // must be first - { "SOAP-ENC", "http://schemas.xmlsoap.org/soap/encoding/" }, // must be second - { "xsi", "http://www.w3.org/1999/XMLSchema-instance", "http://www.w3.org/*/XMLSchema-instance" }, - { "xsd", "http://www.w3.org/1999/XMLSchema", "http://www.w3.org/*/XMLSchema" }, - { "ns1", "urn:MaNGOS" }, // "ns1" namespace prefix - { nullptr, nullptr } -}; diff --git a/src/mangosd/MaNGOSsoap.h b/src/mangosd/MaNGOSsoap.h deleted file mode 100644 index a1db7d8c9af..00000000000 --- a/src/mangosd/MaNGOSsoap.h +++ /dev/null @@ -1,119 +0,0 @@ -/* - * Copyright (C) 2005-2011 MaNGOS - * Copyright (C) 2009-2011 MaNGOSZero - * Copyright (C) 2011-2016 Nostalrius - * Copyright (C) 2016-2017 Elysium Project - * - * This program is free software; you can redistribute it and/or modify - * it under the terms of the GNU General Public License as published by - * the Free Software Foundation; either version 2 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU General Public License for more details. - * - * You should have received a copy of the GNU General Public License - * along with this program; if not, write to the Free Software - * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA - */ - -#ifndef _MANGOSSOAP_H -#define _MANGOSSOAP_H - -#include "Common.h" -#include "World.h" -#include "AccountMgr.h" -#include "Log.h" - -#include "soapH.h" -#include "soapStub.h" - -#include -#include - - -class MaNGOSsoapRunnable -{ - public: - void run(); - void setListenArguments(std::string host, uint16 port) - { - m_host = host; - m_port = port; - } - private: - std::string m_host; - uint16 m_port; -}; - -class SOAPWorkingThread : public ACE_Task -{ - public: - SOAPWorkingThread () - { } - - virtual int svc (void) - { - while (1) - { - ACE_Message_Block *mb = 0; - if (this->getq (mb) == -1) - { - ACE_DEBUG ((LM_INFO, - ACE_TEXT ("(%t) Shutting down\n"))); - break; - } - - // Process the message. - process_message (mb); - } - - return 0; - } - private: - void process_message (ACE_Message_Block *mb); -}; - - -class SOAPCommand -{ - public: - SOAPCommand(): - pendingCommands(0, USYNC_THREAD, "pendingCommands") - { - - } - ~SOAPCommand() - { - } - - void appendToPrintBuffer(const char* msg) - { - m_printBuffer += msg; - } - - ACE_Semaphore pendingCommands; - - void setCommandSuccess(bool val) - { - m_success = val; - } - bool hasCommandSucceeded() - { - return m_success; - } - - static void print(void* callbackArg, const char* msg) - { - ((SOAPCommand*)callbackArg)->appendToPrintBuffer(msg); - } - - static void commandFinished(void* callbackArg, bool success); - - bool m_success; - std::string m_printBuffer; -}; - -#endif diff --git a/src/mangosd/Main.cpp b/src/mangosd/Main.cpp index cc02792f482..3113dd4cb17 100644 --- a/src/mangosd/Main.cpp +++ b/src/mangosd/Main.cpp @@ -33,8 +33,7 @@ #include "revision.h" #include #include -#include -#include +#include "ArgparserForServer.h" #ifdef WIN32 #include "ServiceWin32.h" @@ -47,7 +46,7 @@ char serviceDescription[] = "Massive Network Game Object Server"; * 1 - running * 2 - paused */ -int m_ServiceStatus = -1; +volatile int m_ServiceStatus = -1; #else #include "PosixDaemon.h" #endif @@ -60,125 +59,58 @@ DatabaseType LogsDatabase; // Accessor to the l uint32 realmID; // Id of the realm std::string realmName; // Name of the realm -// Print out the usage string for this program on the console. -void usage(const char *prog) -{ - sLog.Out(LOG_BASIC, LOG_LVL_MINIMAL, "Usage: \n %s []\n" - " -v, --version print version and exist\n\r" - " -c config_file use config_file as configuration file\n\r" - #ifdef WIN32 - " Running as service functions:\n\r" - " -s run run as service\n\r" - " -s install install service\n\r" - " -s uninstall uninstall service\n\r" - #else - " Running as daemon functions:\n\r" - " -s run run as daemon\n\r" - " -s stop stop daemon\n\r" - #endif - ,prog); -} - char const* g_mainLogFileName = "Server.log"; // Launch the mangos server extern int main(int argc, char **argv) { - // Command line parsing - char const* cfg_file = _MANGOSD_CONFIG; - char const *options = ":c:s:"; - - ACE_Get_Opt cmd_opts(argc, argv, options); - cmd_opts.long_option("version", 'v'); - - char serviceDaemonMode = '\0'; - - int option; - while ((option = cmd_opts()) != EOF) + ServerStartupArguments args; { - switch (option) - { - case 'c': - cfg_file = cmd_opts.opt_arg(); - break; - case 'v': - printf("Core revision: %s\n", _FULLVERSION); - return 0; - case 's': - { - const char *mode = cmd_opts.opt_arg(); - - if (!strcmp(mode, "run")) - serviceDaemonMode = 'r'; -#ifdef WIN32 - else if (!strcmp(mode, "install")) - serviceDaemonMode = 'i'; - else if (!strcmp(mode, "uninstall")) - serviceDaemonMode = 'u'; -#else - else if (!strcmp(mode, "stop")) - serviceDaemonMode = 's'; -#endif - else - { - sLog.Out(LOG_BASIC, LOG_LVL_ERROR, "Runtime-Error: -%c unsupported argument %s", cmd_opts.opt_opt(), mode); - usage(argv[0]); - Log::WaitBeforeContinueIfNeed(); - return 1; - } - break; - } - case ':': - sLog.Out(LOG_BASIC, LOG_LVL_ERROR, "Runtime-Error: -%c option requires an input argument", cmd_opts.opt_opt()); - usage(argv[0]); - Log::WaitBeforeContinueIfNeed(); - return 1; - default: - sLog.Out(LOG_BASIC, LOG_LVL_ERROR, "Runtime-Error: bad format of commandline arguments"); - usage(argv[0]); - Log::WaitBeforeContinueIfNeed(); - return 1; - } - } + // parseResult is std::expected, where the error is the return code, that might be present when invalid args or "--help" is given + auto parseResult = ParseServerStartupArguments(argc, argv); + if (!parseResult) + return parseResult.error(); -#ifdef WIN32 // windows service command need execute before config read - switch (serviceDaemonMode) - { - case 'i': - if (WinServiceInstall()) - sLog.Out(LOG_BASIC, LOG_LVL_MINIMAL, "Installing service"); - return 1; - case 'u': - if (WinServiceUninstall()) - sLog.Out(LOG_BASIC, LOG_LVL_MINIMAL, "Uninstalling service"); - return 1; - case 'r': - WinServiceRun(); - break; + args = parseResult.value(); } -#endif - if (!sConfig.SetSource(cfg_file)) + if (args.configFilePath.empty()) + args.configFilePath = _MANGOSD_CONFIG; + + if (!sConfig.LoadFromFile(args.configFilePath)) // must be done before (linux) service init { - sLog.Out(LOG_BASIC, LOG_LVL_ERROR, "Could not find configuration file %s.", cfg_file); + sLog.Out(LOG_BASIC, LOG_LVL_ERROR, "Could not find or parse configuration file %s", args.configFilePath.c_str()); Log::WaitBeforeContinueIfNeed(); - return 1; + return EXIT_FAILURE; } - // Reads config for file names so needs to be after we set the config. sLog.OpenWorldLogFiles(); -#ifndef WIN32 // posix daemon commands need apply after config read - switch (serviceDaemonMode) + switch (args.inputServiceMode) { - case 'r': + case ServiceDaemonAction::NotSet: + break; +#ifdef WIN32 + // windows service command need execute before config read + case ServiceDaemonAction::Install: + sLog.Out(LOG_BASIC, LOG_LVL_MINIMAL, "Installing service..."); + return WinServiceInstall() ? EXIT_SUCCESS : EXIT_FAILURE; + case ServiceDaemonAction::Uninstall: + sLog.Out(LOG_BASIC, LOG_LVL_MINIMAL, "Uninstalling service..."); + return WinServiceUninstall() ? EXIT_SUCCESS : EXIT_FAILURE; + case ServiceDaemonAction::Start: + sLog.Out(LOG_BASIC, LOG_LVL_MINIMAL, "Starting service..."); + return WinServiceRun() ? EXIT_SUCCESS : EXIT_FAILURE; +#else + // posix daemon commands need apply after config read + case ServiceDaemonAction::Start: startDaemon(); break; - case 's': + case ServiceDaemonAction::Stop: stopDaemon(); break; - } #endif + } sLog.Out(LOG_BASIC, LOG_LVL_MINIMAL, "Core revision: %s [world-daemon]", _FULLVERSION); sLog.Out(LOG_BASIC, LOG_LVL_BASIC, " to stop." ); @@ -195,7 +127,7 @@ extern int main(int argc, char **argv) " MM MMM http://getmangos.com\n" " MMMMMM\n\n"); sLog.Out(LOG_BASIC, LOG_LVL_MINIMAL, "VMaNGOS : https://github.com/vmangos"); - sLog.Out(LOG_BASIC, LOG_LVL_MINIMAL, "Using configuration file %s.", cfg_file); + sLog.Out(LOG_BASIC, LOG_LVL_MINIMAL, "Using configuration file %s", sConfig.GetFilename().c_str()); sLog.Out(LOG_BASIC, LOG_LVL_MINIMAL, "Alloc library: " MANGOS_ALLOC_LIB ""); sLog.Out(LOG_BASIC, LOG_LVL_MINIMAL, "Core Revision: " _FULLVERSION); @@ -207,8 +139,6 @@ extern int main(int argc, char **argv) sLog.Out(LOG_BASIC, LOG_LVL_DETAIL, "WARNING: Minimal required version [OpenSSL 0.9.8k]"); } - sLog.Out(LOG_BASIC, LOG_LVL_DETAIL, "Using ACE: %s", ACE_VERSION); - // Set progress bars show mode BarGoLink::SetOutputState(sConfig.GetBoolDefault("ShowProgressBars", true)); diff --git a/src/mangosd/Master.cpp b/src/mangosd/Master.cpp index 988c19f60fc..01ecb661e07 100644 --- a/src/mangosd/Master.cpp +++ b/src/mangosd/Master.cpp @@ -23,11 +23,8 @@ \ingroup mangosd */ -#ifndef WIN32 - #include "PosixDaemon.h" -#endif +#include -#include "WorldSocketMgr.h" #include "Common.h" #include "Master.h" #include "WorldSocket.h" @@ -37,40 +34,46 @@ #include "Timer.h" #include "Policies/SingletonImp.h" #include "SystemConfig.h" -#include "revision.h" #include "Config/Config.h" #include "Database/DatabaseEnv.h" #include "CliRunnable.h" -#include "RASocket.h" +#include "remote/RemoteAccess/RASocket.h" +#include "remote/soap/MaNGOSsoap.h" #include "Util.h" -#include "MaNGOSsoap.h" #include "MassMailMgr.h" #include "DBCStores.h" +#include "WorldSocketMgr.h" +#include "IO/Context/IoContext.h" +#include "IO/Multithreading/CreateThread.h" +#include "IO/Networking/AsyncSocketAcceptor.h" +#include "IO/Timer/AsyncSystemTimer.h" + +#include "revision.h" #include "migrations_list.h" -#include -#include -#include +#ifndef WIN32 +#include "PosixDaemon.h" +#endif #include #ifdef WIN32 #include "ServiceWin32.h" -extern int m_ServiceStatus; +extern volatile int m_ServiceStatus; #endif -INSTANTIATE_SINGLETON_1( Master ); +INSTANTIATE_SINGLETON_1(Master); volatile uint32 Master::m_masterLoopCounter = 0; volatile bool Master::m_handleSigvSignals = false; void freezeDetector(uint32 _delaytime) { - if(!_delaytime) + if (!_delaytime) return; - sLog.Out(LOG_BASIC, LOG_LVL_MINIMAL, "Starting up anti-freeze thread (%u seconds max stuck time)...",_delaytime/1000); + sLog.Out(LOG_BASIC, LOG_LVL_MINIMAL, "Starting up anti-freeze thread (%u seconds max stuck time)...", _delaytime / 1000); uint32 loops = 0; uint32 lastchange = 0; - while(!World::IsStopped()) + while (!World::IsStopped()) { std::this_thread::sleep_for(std::chrono::seconds(1)); @@ -96,52 +99,25 @@ void freezeDetector(uint32 _delaytime) //sLog.Out(LOG_BASIC, LOG_LVL_MINIMAL, "Anti-freeze thread exiting without problems."); } -void remoteAccess() +std::unique_ptr SetupRemoteAccessServer(IO::IoContext* ioCtx) { - #if defined (ACE_HAS_EVENT_POLL) || defined (ACE_HAS_DEV_POLL) - - ACE_Dev_Poll_Reactor imp; - - imp.max_notify_iterations (128); - imp.restart (1); - - #else - - ACE_TP_Reactor imp; - imp.max_notify_iterations (128); - - #endif - - ACE_Reactor reactor(&imp, 1 /* 1= delete implementation so we don't have to care */); - - RASocket::Acceptor acceptor; - - uint16 raport = sConfig.GetIntDefault ("Ra.Port", 3443); - std::string stringip = sConfig.GetStringDefault ("Ra.IP", "0.0.0.0"); + std::string raBindIp = sConfig.GetStringDefault("Ra.IP", "0.0.0.0"); + uint16 raBindPort = sConfig.GetIntDefault("Ra.Port", 3443); - ACE_INET_Addr listen_addr(raport, stringip.c_str()); - - if (acceptor.open (listen_addr, &reactor, ACE_NONBLOCK) == -1) + std::unique_ptr raServer = IO::Networking::AsyncSocketAcceptor::CreateAndBindServer(ioCtx, raBindIp, raBindPort); + if (!raServer) { - sLog.Out(LOG_BASIC, LOG_LVL_ERROR, "MaNGOS RA can not bind to port %d on %s", raport, stringip.c_str ()); + sLog.Out(LOG_BASIC, LOG_LVL_ERROR, "MaNGOS RA can not bind to port %d on %s", raBindPort, raBindIp.c_str()); + return nullptr; } - - sLog.Out(LOG_BASIC, LOG_LVL_BASIC, "Starting Remote access listner on port %d on %s", raport, stringip.c_str ()); - - while (!reactor.reactor_event_loop_done()) + raServer->AutoAcceptSocketsUntilClose([ioCtx](IO::Networking::SocketDescriptor socketDescriptor) { - ACE_Time_Value interval (0, 10000); - - if (reactor.run_reactor_event_loop (interval) == -1) - break; + // Create a socket and attach it to our global ioCtx + std::make_shared(std::move(IO::Networking::AsyncSocket(ioCtx, std::move(socketDescriptor))))->Start(); + }); + sLog.Out(LOG_BASIC, LOG_LVL_BASIC, "Starting Remote access listener on %s:%d", raBindIp.c_str(), raBindPort); - if(World::IsStopped()) - { - acceptor.close(); - break; - } - } - sLog.Out(LOG_BASIC, LOG_LVL_MINIMAL, "RARunnable thread ended"); + return raServer; } Master::Master() @@ -156,18 +132,18 @@ Master::~Master() int Master::Run() { // worldd PID file creation - std::string pidfile = sConfig.GetStringDefault("PidFile", ""); - if(!pidfile.empty()) + std::string pidFilePath = sConfig.GetStringDefault("PidFile", ""); + if (!pidFilePath.empty()) { - uint32 pid = CreatePIDFile(pidfile); - if( !pid ) + uint32 pid = CreatePIDFile(pidFilePath); + if (!pid) { - sLog.Out(LOG_BASIC, LOG_LVL_ERROR, "Cannot create PID file %s.\n", pidfile.c_str() ); + sLog.Out(LOG_BASIC, LOG_LVL_ERROR, "Cannot create PID file %s.\n", pidFilePath.c_str()); Log::WaitBeforeContinueIfNeed(); return 1; } - sLog.Out(LOG_BASIC, LOG_LVL_BASIC, "Daemon PID: %u\n", pid ); + sLog.Out(LOG_BASIC, LOG_LVL_BASIC, "Daemon PID: %u\n", pid); } // Start the databases @@ -191,12 +167,30 @@ int Master::Run() sLog.Out(LOG_BASIC, LOG_LVL_MINIMAL, "World server is running realm ID: %d Name: \"%s\"", realmID, realmName.c_str()); sLog.Out(LOG_BASIC, LOG_LVL_MINIMAL, ""); + std::unique_ptr ioCtxUniquePtr = IO::IoContext::CreateIoContext(); + IO::IoContext* ioCtx = ioCtxUniquePtr.get(); + std::vector ioCtxRunners; + int ioNetworkThreadCount = sConfig.GetIntDefault("Network.Threads", 1); + if (ioNetworkThreadCount <= 0) + { + sLog.Out(LOG_BASIC, LOG_LVL_ERROR, "Config 'Network.Threads' must be greater than 0"); + World::StopNow(ERROR_EXIT_CODE); + return 1; + } + for (int32 i = 0; i < ioNetworkThreadCount; ++i) + { + ioCtxRunners.emplace_back(IO::Multithreading::CreateThread("IO[" + std::to_string(i) + "]", [ioCtx]() + { + ioCtx->RunUntilShutdown(); + })); + } + // Initialize the World sWorld.SetInitialWorldSettings(); - #ifndef WIN32 +#ifndef WIN32 detachDaemon(); - #endif +#endif // server loaded successfully => enable async DB requests // this is done to forbid any async transactions during server startup! CharacterDatabase.AllowAsyncTransactions(); @@ -208,8 +202,8 @@ int Master::Run() _HookSignals(); // Launch WorldRunnable thread - std::thread world_thread{WorldRunnable()}; - // world_thread.setPriority(ACE_Based::Highest); + std::thread world_thread = IO::Multithreading::CreateThread("WorldRunnable", WorldRunnable()); + // world_thread.setPriority(ACE_Based::Highest); // TODO // set realmbuilds depend on mangosd expected builds, and set server online { @@ -219,7 +213,7 @@ int Master::Run() LoginDatabase.PExecute("UPDATE `realmlist` SET `realmflags` = `realmflags` & ~(%u), `population` = 0, `realmbuilds` = '%s' WHERE `id` = '%u'", REALM_FLAG_OFFLINE, builds.c_str(), realmID); } - std::thread* cliThread = nullptr; + std::unique_ptr cliThread = nullptr; #ifdef WIN32 if (sConfig.GetBoolDefault("Console.Enable", true) && (m_ServiceStatus == -1)/* need disable console in service mode*/) @@ -228,38 +222,38 @@ int Master::Run() #endif { // Launch CliRunnable thread - cliThread = new std::thread(CliRunnable()); + cliThread = IO::Multithreading::CreateThreadPtr("CLI", CliRunnable()); } - std::thread* rar_thread = nullptr; - if (sConfig.GetBoolDefault ("Ra.Enable", false)) - rar_thread = new std::thread(&remoteAccess); + std::unique_ptr remoteAccessServer = nullptr; + if (sConfig.GetBoolDefault("Ra.Enable", false)) + remoteAccessServer = SetupRemoteAccessServer(ioCtx); // Handle affinity for multiple processors and process priority on Windows - #ifdef WIN32 +#ifdef WIN32 { - HANDLE hProcess = GetCurrentProcess(); + HANDLE hProcess = ::GetCurrentProcess(); uint32 Aff = sConfig.GetIntDefault("UseProcessors", 0); - if(Aff > 0) + if (Aff > 0) { ULONG_PTR appAff; ULONG_PTR sysAff; - if(GetProcessAffinityMask(hProcess,&appAff,&sysAff)) + if (::GetProcessAffinityMask(hProcess, &appAff, &sysAff)) { ULONG_PTR curAff = Aff & appAff; // remove non accessible processors - if(!curAff ) + if (!curAff) { sLog.Out(LOG_BASIC, LOG_LVL_ERROR, "Processors marked in UseProcessors bitmask (hex) %x not accessible for mangosd. Accessible processors bitmask (hex): %x",Aff,appAff); } else { - if(SetProcessAffinityMask(hProcess,curAff)) + if (::SetProcessAffinityMask(hProcess, curAff)) sLog.Out(LOG_BASIC, LOG_LVL_MINIMAL, "Using processors (bitmask, hex): %x", curAff); else - sLog.Out(LOG_BASIC, LOG_LVL_ERROR, "Can't set used processors (hex): %x",curAff); + sLog.Out(LOG_BASIC, LOG_LVL_ERROR, "Can't set used processors (hex): %x", curAff); } } sLog.Out(LOG_BASIC, LOG_LVL_MINIMAL, ""); @@ -270,85 +264,86 @@ int Master::Run() // if(Prio && (m_ServiceStatus == -1)/* need set to default process priority class in service mode*/) if(Prio) { - if(SetPriorityClass(hProcess,HIGH_PRIORITY_CLASS)) + if (::SetPriorityClass(hProcess,HIGH_PRIORITY_CLASS)) sLog.Out(LOG_BASIC, LOG_LVL_MINIMAL, "mangosd process priority class set to HIGH"); else sLog.Out(LOG_BASIC, LOG_LVL_ERROR, "Can't set mangosd process priority class."); } } - #endif +#endif + + (void)sAsyncSystemTimer; // <-- Pre-Initialize SystemTimer + IO::Multithreading::RenameCurrentThread("Main"); // Start soap serving thread - std::thread* soap_thread = nullptr; + std::unique_ptr soap_thread = nullptr; - if(sConfig.GetBoolDefault("SOAP.Enabled", false)) + if (sConfig.GetBoolDefault("SOAP.Enabled", false)) { - soap_thread = new std::thread([](){ - MaNGOSsoapRunnable runnable; - runnable.setListenArguments(sConfig.GetStringDefault("SOAP.IP", "127.0.0.1"), sConfig.GetIntDefault("SOAP.Port", 7878)); - runnable.run(); - }); + std::string soapBindIp = sConfig.GetStringDefault("SOAP.IP", "127.0.0.1"); + uint16 soapBindPort = sConfig.GetIntDefault("SOAP.Port", 7878); + soap_thread = StartSoapThread(soapBindIp, soapBindPort); } // Start up freeze catcher thread - std::thread* freeze_thread = nullptr; - if(uint32 freeze_delay = sConfig.GetIntDefault("MaxCoreStuckTime", 0)) + std::unique_ptr freeze_thread = nullptr; + if (uint32 freeze_delay = sConfig.GetIntDefault("MaxCoreStuckTime", 0)) { - freeze_thread = new std::thread(std::bind(&freezeDetector,freeze_delay*1000)); - //freeze_thread->setPriority(ACE_Based::Highest); + freeze_thread = IO::Multithreading::CreateThreadPtr("FreezeDetector", std::bind(&freezeDetector, freeze_delay * 1000)); } - // Wait for clients ? // Launch the world listener socket - uint16 wsport = sWorld.getConfig(CONFIG_UINT32_PORT_WORLD); - std::string bind_ip = sConfig.GetStringDefault("BindIP", "0.0.0.0"); + std::string bindIp = sConfig.GetStringDefault("BindIP", "0.0.0.0"); + uint16 bindPort = sWorld.getConfig(CONFIG_UINT32_PORT_WORLD); + int socketOutByteBufferSize = sConfig.GetIntDefault("Network.SystemSendBuffer", -1); + bool doExplicitTcpNoDelay = sConfig.GetBoolDefault("Network.TcpNoDelay", true); + std::vector trustedProxyIps = SplitStringByDelimiter(sConfig.GetStringDefault("Network.TrustedProxyServers", ""), ','); - // Start WorldSockets - sWorldSocketMgr->SetOutKBuff(sConfig.GetIntDefault("Network.OutKBuff", -1)); - sWorldSocketMgr->SetOutUBuff(sConfig.GetIntDefault("Network.OutUBuff", 65536)); - sWorldSocketMgr->SetThreads(sConfig.GetIntDefault("Network.Threads", 1) + 1); - sWorldSocketMgr->SetInterval(sConfig.GetIntDefault("Network.Interval", 10)); - sWorldSocketMgr->SetTcpNodelay(sConfig.GetBoolDefault("Network.TcpNodelay", true)); - - if (sWorldSocketMgr->StartNetwork(wsport, bind_ip) == -1) + WorldSocketMgrOptions socketOptions + { + bindIp, + bindPort, + socketOutByteBufferSize, + doExplicitTcpNoDelay, + trustedProxyIps, + }; + + if (!sWorldSocketMgr.StartWorldNetworking(ioCtx, socketOptions)) { - sLog.Out(LOG_BASIC, LOG_LVL_ERROR, "Failed to start WorldSocket network"); Log::WaitBeforeContinueIfNeed(); World::StopNow(ERROR_EXIT_CODE); - // go down and shutdown the server } - sWorldSocketMgr->Wait(); - // Stop freeze protection before shutdown tasks + world_thread.join(); // <-- This will block until the world stops + + _UnhookSignals(); // Remove signal handling before leaving + if (freeze_thread) - { freeze_thread->join(); - delete freeze_thread; - } - // Stop soap thread - if(soap_thread) - { + if (soap_thread) soap_thread->join(); - delete soap_thread; - } // Set server offline in realmlist - //LoginDatabase.DirectPExecute("UPDATE realmlist SET realmflags = realmflags | %u WHERE id = '%u'", REALM_FLAG_OFFLINE, realmID); + LoginDatabase.DirectPExecute("UPDATE realmlist SET realmflags = realmflags | %u WHERE id = '%u'", REALM_FLAG_OFFLINE, realmID); - // Remove signal handling before leaving - _UnhookSignals(); + sWorldSocketMgr.StopWorldNetworking(); - // when the main thread closes the singletons get unloaded - // since worldrunnable uses them, it will crash if unloaded after master - world_thread.join(); - - if(rar_thread) + if (remoteAccessServer) { - rar_thread->join(); - delete rar_thread; + sLog.Out(LOG_BASIC, LOG_LVL_MINIMAL, "Stop remote access..."); + remoteAccessServer->ClosePortAndStopAcceptingNewConnections(); + remoteAccessServer.reset(); } + sLog.Out(LOG_BASIC, LOG_LVL_MINIMAL, "Stop system timers..."); + sAsyncSystemTimer.RemoveAllTimersAndStopThread(); + + sLog.Out(LOG_BASIC, LOG_LVL_MINIMAL, "Stop IO context..."); + ioCtx->Shutdown(); + for (std::thread& thread : ioCtxRunners) + thread.join(); + // Clean account database before leaving sLog.Out(LOG_BASIC, LOG_LVL_MINIMAL, "Cleaning character database..."); clearOnlineAccounts(); @@ -374,7 +369,7 @@ int Master::Run() //_exit(1); // send keyboard input to safely unblock the CLI thread INPUT_RECORD b[5]; - HANDLE hStdIn = GetStdHandle(STD_INPUT_HANDLE); + HANDLE hStdIn = ::GetStdHandle(STD_INPUT_HANDLE); b[0].EventType = KEY_EVENT; b[0].Event.KeyEvent.bKeyDown = TRUE; b[0].Event.KeyEvent.uChar.AsciiChar = 'X'; @@ -403,21 +398,19 @@ int Master::Run() b[3].Event.KeyEvent.wVirtualScanCode = 0x1c; b[3].Event.KeyEvent.wRepeatCount = 1; DWORD numb; - WriteConsoleInput(hStdIn, b, 4, &numb); + ::WriteConsoleInput(hStdIn, b, 4, &numb); #else - fclose(stdin); + ::fclose(stdin); #endif if (cliThread->joinable()) cliThread->join(); - - delete cliThread; } // Exit the process with specified return value return World::GetExitCode(); } -bool StartDB(std::string name, DatabaseType& database, const char **migrations) +bool StartDB(const std::string& name, DatabaseType& database, char const** migrations) { // Get database info from configuration file std::string dbstring = sConfig.GetStringDefault((name + "Database.Info").c_str(), ""); @@ -478,7 +471,7 @@ bool Master::_StartDB() { // Get the realm Id from the configuration file realmID = sConfig.GetIntDefault("RealmID", 0); - if(!realmID) + if (!realmID) { sLog.Out(LOG_BASIC, LOG_LVL_ERROR, "Realm ID not defined in configuration file"); return false; @@ -518,16 +511,16 @@ void Master::clearOnlineAccounts() #include "ObjectAccessor.h" #include "Language.h" -void createdump(void) +void CreateCrashDump() { #ifndef WIN32 - if (!fork()) { //child process - // Crash the app - abort(); + if (!::fork()) // Create child process + { + ::abort(); // Crash the app immediately } #endif - } + // Handle termination signals void Master::SigvSignalHandler() { @@ -535,6 +528,7 @@ void Master::SigvSignalHandler() _OnSignal(SIGSEGV); exit(1); } + void Master::_OnSignal(int s) { switch (s) @@ -543,71 +537,77 @@ void Master::_OnSignal(int s) World::StopNow(RESTART_EXIT_CODE); break; case SIGTERM: - #ifdef _WIN32 - case SIGBREAK: - #endif +#ifdef _WIN32 + case SIGBREAK: +#endif World::StopNow(SHUTDOWN_EXIT_CODE); break; case SIGSEGV: - signal(SIGSEGV, 0); + ::signal(SIGSEGV, nullptr); if (!m_handleSigvSignals) return; + m_handleSigvSignals = false; // Disarm anti-crash + std::exception_ptr exc = std::current_exception(); - m_handleSigvSignals = false; // Desarm anticrash sWorld.SetAnticrashRearmTimer(sWorld.getConfig(CONFIG_UINT32_ANTICRASH_REARM_TIMER)); uint32 anticrashOptions = sWorld.getConfig(CONFIG_UINT32_ANTICRASH_OPTIONS); + // Log crash stack sLog.Out(LOG_BASIC, LOG_LVL_MINIMAL, "Received SIGSEGV"); - ACE_Stack_Trace st; - sLog.Out(LOG_BASIC, LOG_LVL_MINIMAL, "%s", st.c_str()); + MaNGOS::Errors::PrintStacktrace(); + if (anticrashOptions & ANTICRASH_GENERATE_COREDUMP) - createdump(); + CreateCrashDump(); + if (anticrashOptions & ANTICRASH_OPTION_ANNOUNCE_PLAYERS) { if (anticrashOptions & ANTICRASH_OPTION_SAVEALL) sWorld.SendWorldText(LANG_SYSTEMMESSAGE, "Server has crashed. Now saving online players ..."); else sWorld.SendWorldText(LANG_SYSTEMMESSAGE, "Crash server occurred :("); + std::this_thread::sleep_for(std::chrono::milliseconds(500)); } + if (anticrashOptions & ANTICRASH_OPTION_SAVEALL) { CharacterDatabase.ThreadStart(); sObjectAccessor.SaveAllPlayers(); std::this_thread::sleep_for(std::chrono::seconds(25)); } + std::rethrow_exception(exc); // Crash for real now. return; } - signal(s, _OnSignal); + ::signal(s, _OnSignal); } void Master::_HookSignals() { - signal(SIGINT, _OnSignal); - signal(SIGTERM, _OnSignal); - signal(SIGSEGV, _OnSignal); - #ifdef _WIN32 - signal(SIGBREAK, _OnSignal); - #endif + ::signal(SIGINT, _OnSignal); + ::signal(SIGTERM, _OnSignal); + ::signal(SIGSEGV, _OnSignal); +#ifdef _WIN32 + ::signal(SIGBREAK, _OnSignal); +#endif ArmAnticrash(); } void Master::ArmAnticrash() { - //signal(SIGSEGV, _OnSignal); + //::signal(SIGSEGV, _OnSignal); m_handleSigvSignals = true; } // Unhook the signals before leaving void Master::_UnhookSignals() { - signal(SIGINT, 0); - signal(SIGTERM, 0); - signal(SIGSEGV, 0); - #ifdef _WIN32 - signal(SIGBREAK, 0); - #endif + ::signal(SIGINT, nullptr); + ::signal(SIGTERM, nullptr); + ::signal(SIGSEGV, nullptr); +#ifdef _WIN32 + ::signal(SIGBREAK, nullptr); +#endif m_handleSigvSignals = false; } diff --git a/src/mangosd/RASocket.cpp b/src/mangosd/RASocket.cpp deleted file mode 100644 index 75fe8abbd42..00000000000 --- a/src/mangosd/RASocket.cpp +++ /dev/null @@ -1,328 +0,0 @@ -/* - * Copyright (C) 2005-2011 MaNGOS - * Copyright (C) 2009-2011 MaNGOSZero - * Copyright (C) 2011-2016 Nostalrius - * Copyright (C) 2016-2017 Elysium Project - * - * This program is free software; you can redistribute it and/or modify - * it under the terms of the GNU General Public License as published by - * the Free Software Foundation; either version 2 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU General Public License for more details. - * - * You should have received a copy of the GNU General Public License - * along with this program; if not, write to the Free Software - * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA - */ - -/** \file - \ingroup mangosd -*/ - -#include "Common.h" -#include "Database/DatabaseEnv.h" -#include "Log.h" -#include "RASocket.h" -#include "World.h" -#include "Config/Config.h" -#include "Util.h" -#include "AccountMgr.h" -#include "Language.h" -#include "ObjectMgr.h" - -// RASocket constructor -RASocket::RASocket() -:RAHandler(), -pendingCommands(0, USYNC_THREAD, "pendingCommands"), -outActive(false), -inputBufferLen(0), -outputBufferLen(0), -stage(NONE) -{ - // Get the config parameters - bSecure = sConfig.GetBoolDefault( "RA.Secure", true ); - bStricted = sConfig.GetBoolDefault( "RA.Stricted", false ); - iMinLevel = AccountTypes(sConfig.GetIntDefault( "RA.MinLevel", SEC_ADMINISTRATOR )); - reference_counting_policy ().value (ACE_Event_Handler::Reference_Counting_Policy::ENABLED); -} - -// RASocket destructor -RASocket::~RASocket() -{ - peer().close(); - sLog.Out(LOG_RA, LOG_LVL_MINIMAL, "Connection was closed."); -} - -// Accept an incoming connection -int RASocket::open(void* ) -{ - if (reactor ()->register_handler(this, ACE_Event_Handler::READ_MASK | ACE_Event_Handler::WRITE_MASK) == -1) - { - sLog.Out(LOG_BASIC, LOG_LVL_ERROR, "RASocket::open: unable to register client handler errno = %s", ACE_OS::strerror (errno)); - return -1; - } - - ACE_INET_Addr remote_addr; - - if (peer ().get_remote_addr (remote_addr) == -1) - { - sLog.Out(LOG_BASIC, LOG_LVL_ERROR, "RASocket::open: peer ().get_remote_addr errno = %s", ACE_OS::strerror (errno)); - return -1; - } - - - sLog.Out(LOG_RA, LOG_LVL_BASIC, "Incoming connection from %s.",remote_addr.get_host_addr()); - - // print Motd - sendf(sWorld.GetMotd()); - sendf("\r\n"); - sendf(sObjectMgr.GetMangosStringForDBCLocale(LANG_RA_USER)); - - return 0; -} - -int RASocket::close(int) -{ - if(closing_) - return -1; - sLog.Out(LOG_BASIC, LOG_LVL_DEBUG, "RASocket::close"); - shutdown(); - - closing_ = true; - - remove_reference(); - return 0; -} - -int RASocket::handle_close (ACE_HANDLE h, ACE_Reactor_Mask) -{ - if(closing_) - return -1; - sLog.Out(LOG_BASIC, LOG_LVL_DEBUG, "RASocket::handle_close"); - std::unique_lock lock (outBufferLock); - - closing_ = true; - - if (h == ACE_INVALID_HANDLE) - peer ().close_writer (); - remove_reference(); - return 0; -} - -int RASocket::handle_output (ACE_HANDLE) -{ - std::unique_lock lock (outBufferLock); - - if(closing_) - return -1; - - if (!outputBufferLen) - { - if(reactor()->cancel_wakeup(this, ACE_Event_Handler::WRITE_MASK) == -1) - { - sLog.Out(LOG_BASIC, LOG_LVL_ERROR, "RASocket::handle_output: error while cancel_wakeup"); - return -1; - } - outActive = false; - return 0; - } -#ifdef MSG_NOSIGNAL - ssize_t n = peer ().send (outputBuffer, outputBufferLen, MSG_NOSIGNAL); -#else - ssize_t n = peer ().send (outputBuffer, outputBufferLen); -#endif // MSG_NOSIGNAL - - if(n<=0) - return -1; - - ACE_OS::memmove(outputBuffer, outputBuffer+n, outputBufferLen-n); - - outputBufferLen -= n; - - return 0; -} - -// Read data from the network -int RASocket::handle_input(ACE_HANDLE) -{ - sLog.Out(LOG_BASIC, LOG_LVL_DEBUG, "RASocket::handle_input"); - if(closing_) - { - sLog.Out(LOG_BASIC, LOG_LVL_ERROR, "Called RASocket::handle_input with closing_ = true"); - return -1; - } - - size_t readBytes = peer().recv(inputBuffer+inputBufferLen, RA_BUFF_SIZE-inputBufferLen-1); - - if(readBytes <= 0) - { - sLog.Out(LOG_BASIC, LOG_LVL_DEBUG, "read %u bytes in RASocket::handle_input", readBytes); - return -1; - } - - // Discard data after line break or line feed - bool gotenter=false; - for(; readBytes > 0 ; --readBytes) - { - char c = inputBuffer[inputBufferLen]; - if (c=='\r'|| c=='\n') - { - gotenter=true; - break; - } - ++inputBufferLen; - } - - if (gotenter) - { - inputBuffer[inputBufferLen]=0; - inputBufferLen=0; - switch(stage) - { - //
  • If the input is '' - case NONE: - { - std::string szLogin=inputBuffer; - - accId = sAccountMgr.GetId(szLogin); - - // If the user is not found, deny access - if(!accId) - { - sendf("-No such user.\r\n"); - sLog.Out(LOG_RA, LOG_LVL_MINIMAL, "User %s does not exist.",szLogin.c_str()); - if(bSecure) - { - handle_output(); - return -1; - } - sendf("\r\n"); - sendf(sObjectMgr.GetMangosStringForDBCLocale(LANG_RA_USER)); - break; - } - - accAccessLevel = sAccountMgr.GetSecurity(accId); - - // - if gmlevel is too low, deny access - if (accAccessLevel < iMinLevel) - { - sendf("-Not enough privileges.\r\n"); - sLog.Out(LOG_RA, LOG_LVL_MINIMAL, "User %s has no privilege.",szLogin.c_str()); - if(bSecure) - { - handle_output(); - return -1; - } - sendf("\r\n"); - sendf(sObjectMgr.GetMangosStringForDBCLocale(LANG_RA_USER)); - break; - } - - // - allow by remotely connected admin use console level commands dependent from config setting - if (accAccessLevel >= SEC_ADMINISTRATOR && !bStricted) - accAccessLevel = SEC_CONSOLE; - - stage=LG; - sendf(sObjectMgr.GetMangosStringForDBCLocale(LANG_RA_PASS)); - break; - } - //
  • If the input is '' (and the user already gave his username) - case LG: - { //login+pass ok - std::string pw = inputBuffer; - - if (sAccountMgr.CheckPassword(accId, pw)) - { - stage=OK; - - sendf("+Logged in.\r\n"); - sLog.Out(LOG_RA, LOG_LVL_BASIC, "User account %u has logged in.", accId); - sendf("mangos>"); - } - else - { - // Else deny access - sendf("-Wrong pass.\r\n"); - sLog.Out(LOG_RA, LOG_LVL_BASIC, "User account %u has failed to log in.", accId); - if(bSecure) - { - handle_output(); - return -1; - } - sendf("\r\n"); - sendf(sObjectMgr.GetMangosStringForDBCLocale(LANG_RA_PASS)); - } - break; - } - //
  • If user is logged, parse and execute the command - case OK: - if (strlen(inputBuffer)) - { - sLog.Out(LOG_RA, LOG_LVL_BASIC, "Got '%s' cmd.",inputBuffer); - if (strncmp(inputBuffer,"quit",4)==0) - return -1; - else - { - CliCommandHolder* cmd = new CliCommandHolder(accId, accAccessLevel, this, inputBuffer, &RASocket::zprint, &RASocket::commandFinished); - sWorld.QueueCliCommand(cmd); - pendingCommands.acquire(); - } - } - else - sendf("mangos>"); - break; - //
- }; - - } - // no enter yet? wait for next input... - return 0; -} - -// Output function -void RASocket::zprint(void* callbackArg, const char * szText ) -{ - if( !szText ) - return; - - ((RASocket*)callbackArg)->sendf(szText); -} - -void RASocket::commandFinished(void* callbackArg, bool /*sucess*/) -{ - RASocket* raSocket = (RASocket*)callbackArg; - raSocket->sendf("mangos>"); - raSocket->pendingCommands.release(); -} - -int RASocket::sendf(const char* msg) -{ - std::unique_lock lock (outBufferLock); - - if(closing_) - return -1; - - int msgLen = strlen(msg); - - if(msgLen+outputBufferLen > RA_BUFF_SIZE) - return -1; - - ACE_OS::memcpy(outputBuffer+outputBufferLen, msg, msgLen); - outputBufferLen += msgLen; - - if(!outActive) - { - if (reactor ()->schedule_wakeup - (this, ACE_Event_Handler::WRITE_MASK) == -1) - { - sLog.Out(LOG_BASIC, LOG_LVL_ERROR, "RASocket::sendf error while schedule_wakeup"); - return -1; - } - outActive = true; - } - return 0; -} diff --git a/src/mangosd/RASocket.h b/src/mangosd/RASocket.h deleted file mode 100644 index 799fb59c41b..00000000000 --- a/src/mangosd/RASocket.h +++ /dev/null @@ -1,98 +0,0 @@ -/* - * Copyright (C) 2005-2011 MaNGOS - * Copyright (C) 2009-2011 MaNGOSZero - * Copyright (C) 2011-2016 Nostalrius - * Copyright (C) 2016-2017 Elysium Project - * - * This program is free software; you can redistribute it and/or modify - * it under the terms of the GNU General Public License as published by - * the Free Software Foundation; either version 2 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU General Public License for more details. - * - * You should have received a copy of the GNU General Public License - * along with this program; if not, write to the Free Software - * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA - */ - -// \addtogroup mangosd -// @{ -// \file - -#ifndef _RASOCKET_H -#define _RASOCKET_H - -#include "Common.h" -#include -#include -#include -#include -#include - -#define RA_BUFF_SIZE 8192 - - -// Remote Administration socket -typedef ACE_Svc_Handler < ACE_SOCK_STREAM, ACE_NULL_SYNCH> RAHandler; -class RASocket: protected RAHandler -{ - public: - ACE_Semaphore pendingCommands; - typedef ACE_Acceptor Acceptor; - friend class ACE_Acceptor; - - int sendf(const char*); - - protected: - // things called by ACE framework. - RASocket(void); - virtual ~RASocket(void); - - // Called on open ,the void* is the acceptor. - virtual int open (void *); - - // Called on failures inside of the acceptor, don't call from your code. - virtual int close (int); - - // Called when we can read from the socket. - virtual int handle_input (ACE_HANDLE = ACE_INVALID_HANDLE); - - // Called when the socket can write. - virtual int handle_output (ACE_HANDLE = ACE_INVALID_HANDLE); - - // Called when connection is closed or error happens. - virtual int handle_close (ACE_HANDLE = ACE_INVALID_HANDLE, - ACE_Reactor_Mask = ACE_Event_Handler::ALL_EVENTS_MASK); - - private: - bool outActive; - - char inputBuffer[RA_BUFF_SIZE]; - uint32 inputBufferLen; - - std::mutex outBufferLock; - char outputBuffer[RA_BUFF_SIZE]; - uint32 outputBufferLen; - - uint32 accId; - AccountTypes accAccessLevel; - bool bSecure; // kick on wrong pass, non exist. user OR user with no priv - // will protect from DOS, bruteforce attacks - bool bStricted; // not allow execute console only commands (SEC_CONSOLE) remotly - AccountTypes iMinLevel; - enum - { - NONE, // initial value - LG, // only login was entered - OK, // both login and pass were given, they were correct and user has enough priv. - }stage; - - static void zprint(void* callbackArg, const char * szText ); - static void commandFinished(void* callbackArg, bool success); -}; -#endif -// @} diff --git a/src/mangosd/WorldRunnable.cpp b/src/mangosd/WorldRunnable.cpp index b4c7e7d15e7..59f6200b25e 100644 --- a/src/mangosd/WorldRunnable.cpp +++ b/src/mangosd/WorldRunnable.cpp @@ -23,7 +23,6 @@ \ingroup mangosd */ -#include "WorldSocketMgr.h" #include "Common.h" #include "World.h" #include "WorldRunnable.h" @@ -41,7 +40,7 @@ #ifdef WIN32 #include "ServiceWin32.h" -extern int m_ServiceStatus; +extern volatile int m_ServiceStatus; #endif // Heartbeat for the World @@ -121,9 +120,6 @@ void WorldRunnable::operator()() // unload battleground templates before different singletons destroyed sBattleGroundMgr.DeleteAllBattleGrounds(); - sLog.Out(LOG_BASIC, LOG_LVL_MINIMAL, "Stopping network threads..."); - sWorldSocketMgr->StopNetwork(); - sLog.Out(LOG_BASIC, LOG_LVL_MINIMAL, "Unloading all maps..."); sMapMgr.UnloadAll(); // unload all grids (including locked in memory) diff --git a/src/mangosd/mangosd.conf.dist.in b/src/mangosd/mangosd.conf.dist.in index 439dbf58bd4..2df939edeb0 100644 --- a/src/mangosd/mangosd.conf.dist.in +++ b/src/mangosd/mangosd.conf.dist.in @@ -3,7 +3,7 @@ ##################################### [MangosdConf] -ConfVersion=2010100901 +ConfVersion=2024091701 ################################################################################################################### # CONNECTIONS AND DIRECTORIES @@ -592,7 +592,13 @@ DebuffLimit = 0 # # LogFile.Trades # Log file of trade related messages -# Default: "" (Disable) +# Default: "Trades.log" +# "" - Empty name for disable +# +# LogFile.Network +# Log file of network related messages +# Default: "Network.log" +# "" - Empty name for disable # # CharLogDump # Write character dump before deleting in Char.log @@ -685,6 +691,7 @@ LogFile.Performance = "Perf.log" LogFile.Gm = "" LogFile.GmCriticalCommands = "gm_critical.log" LogFile.Trades = "Trades.log" +LogFile.Network = "Network.log" LogMoneyTreshold = 10000 CharLogDump = 0 @@ -1081,7 +1088,8 @@ PerformanceLog.SlowPacketBroadcast = 0 # WaitAtStartupError # After startup error report wait for or some time before continuing (and possibly close console window) # -1 (wait until press) -# Default: 0 (not wait) +# 0 (no wait) +# Default: 5 (wait 5 sec) # N (>0, wait N secs) # # Motd @@ -1155,7 +1163,7 @@ Spell.EffectDelay = 400 Spell.ProcDelay = 800 BeepAtStart = 1 ShowProgressBars = 0 -WaitAtStartupError = 0 +WaitAtStartupError = 5 # consider fixing your actual problem before changing this value! Motd = "Welcome to World of Warcraft!" ################################################################################################################### @@ -2761,14 +2769,10 @@ LFG.MatchmakingTimer = 600 # Number of threads for network, recommend 1 thread per 1000 connections. # Default: 1 # -# Network.OutKBuff +# Network.SystemSendBuffer # The size of the output kernel buffer used ( SO_SNDBUF socket option, tcp manual ). # Default: -1 (Use system default setting) # -# Network.OutUBuff -# Userspace buffer for output. This is amount of memory reserved per each connection. -# Default: 65536 -# # Network.TcpNoDelay: # TCP Nagle algorithm setting # Default: 0 (enable Nagle algorithm, less traffic, more latency) @@ -2787,21 +2791,30 @@ LFG.MatchmakingTimer = 600 # How often packet broadcasting threads run in milliseconds. # Default: 50 # -# Network.Interval -# How often ACE will transmit the client's outbound packet buffer in milliseconds. -# Default: 10 +# Network.PacketBroadcast.ReduceVisDistance.DiffAbove +# Description: TODO +# Default: 0 +# +# Network.TrustedProxyServers +# Description: Enables the parsing of Proxy Protocol v2 for specific IPs. +# You can use this feature when your server is behind a proxy, load balancer, or similar component, +# to retrieve the real IP address of players. +# You need to enable Proxy Protocol v2 on both this server and the proxy/load balancer. +# For example see HaProxy "send-proxy-v2" option. +# Multiple servers can be separated with ',' +# Default: "" - (Disabled, no proxy) +# Example "10.13.37.1,10.13.37.2" - (to allow multiple proxy servers) # ################################################################################################################### Network.Threads = 1 -Network.OutKBuff = -1 -Network.OutUBuff = 65536 +Network.SystemSendBuffer = -1 Network.TcpNodelay = 1 Network.KickOnBadPacket = 0 Network.PacketBroadcast.Threads = 0 Network.PacketBroadcast.Frequency = 50 Network.PacketBroadcast.ReduceVisDistance.DiffAbove = 0 -Network.Interval = 10 +Network.TrustedProxyServers = "" ################################################################################################################### # CONSOLE, REMOTE ACCESS AND SOAP @@ -2823,17 +2836,12 @@ Network.Interval = 10 # Default remote console port # Default: 3443 # -# Ra.MinLevel -# Minimum level that's required to login,3 by default +# Ra.MinAccountLevel +# Minimum account level that's required to login, 3 by default # Default: 3 (Administrator) # -# Ra.Secure -# Kick client on wrong pass -# 0 - off -# Default: 1 - on -# -# Ra.Stricted -# Not allow execute console level only commands remotly by RA +# Ra.Restricted +# Not allow execute console level only commands remotely by RA # 0 - off # Default: 1 - on # @@ -2857,9 +2865,8 @@ Console.Enable = 1 Ra.Enable = 0 Ra.IP = 0.0.0.0 Ra.Port = 3443 -Ra.MinLevel = 3 -Ra.Secure = 1 -Ra.Stricted = 1 +Ra.MinAccountLevel = 3 +Ra.Restricted = 1 SOAP.Enabled = 0 SOAP.IP = 127.0.0.1 diff --git a/src/mangosd/remote/RemoteAccess/RASocket.cpp b/src/mangosd/remote/RemoteAccess/RASocket.cpp new file mode 100644 index 00000000000..0f108f8a529 --- /dev/null +++ b/src/mangosd/remote/RemoteAccess/RASocket.cpp @@ -0,0 +1,288 @@ +/* + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation; either version 2 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program; if not, write to the Free Software + * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA + */ + +#include "Common.h" +#include "Database/DatabaseEnv.h" +#include "Log.h" +#include "RASocket.h" +#include "World.h" +#include "Config/Config.h" +#include "Util.h" +#include "AccountMgr.h" +#include "Language.h" +#include "ObjectMgr.h" +#include "Memory/ArrayDeleter.h" + +#include +#include +#include + +static std::string const NEWLINE = "\r\n"; +static std::string const PROMPT = "mangos>"; + +RASocket::RASocket(IO::Networking::AsyncSocket socket) + : m_socket(std::move(socket)), + m_connectionState(ConnectionState::FreshConnection), + m_atLeastOnePacketWasReceived(false), + m_accountId(0), + m_username(), + m_accountLevel(AccountTypes::SEC_PLAYER) +{ + if (sConfig.IsSet("Ra.Stricted")) + { + sLog.Out(LOG_RA, LOG_LVL_ERROR, "Deprecated config option Ra.Stricted being used. Use Ra.Restricted instead."); + m_restricted = sConfig.GetBoolDefault("Ra.Stricted", true); + } + else + m_restricted = sConfig.GetBoolDefault("Ra.Restricted", true); +} + +RASocket::~RASocket() +{ + sLog.Out(LOG_RA, LOG_LVL_MINIMAL, "[%s] Connection was closed", m_socket.GetRemoteIpString().c_str()); +} + +void RASocket::Start() +{ + if (IO::NetworkError initError = m_socket.InitializeAndFixateMemoryLocation()) + { + sLog.Out(LOG_BASIC, LOG_LVL_ERROR, "[%s] Failed to initialize RASocket %s", m_socket.GetRemoteIpString().c_str(), initError.ToString().c_str()); + return; // implicit close() + } + + sLog.Out(LOG_RA, LOG_LVL_MINIMAL, "[%s] Incoming RA connection", m_socket.GetRemoteIpString().c_str()); + + std::string welcomeMessage; + welcomeMessage += sWorld.GetMotd(); // <-- technically, we should replace all '\n' in MOTD with NEWLINE + welcomeMessage += NEWLINE; + welcomeMessage += sObjectMgr.GetMangosStringForDBCLocale(LANG_RA_USER); + + SendAndRecvNextInput(welcomeMessage); +} + +void RASocket::DoRecvIncomingData() +{ + sLog.Out(LOG_RA, LOG_LVL_DEBUG, "RASocket::DoRecvIncomingData"); + + // Check if we got a full line in our buffer first + std::string::size_type newLinePos = m_pendingInputBuffer.find_first_of(NEWLINE); + if (newLinePos != std::string::npos) + { + // remove newline from buffer and forward line + std::string line = m_pendingInputBuffer.substr(0, newLinePos); + m_pendingInputBuffer.erase(0, newLinePos + NEWLINE.size()); + + if (line.size() == 4095) // Exact length match of the terminal limit. Maybe the user tries to execute a really long command. + sLog.Out(LOG_RA, LOG_LVL_ERROR, "[%s] A default telnet terminal only allows 4096 characters per line. This command could be executed incorrectly!", m_socket.GetRemoteIpString().c_str()); + + HandleInput(line); // This function must ensure that DoRecvIncomingData() is executed when done + return; + } + + if (m_connectionState != ConnectionState::Authenticated && m_pendingInputBuffer.size() > MAX_INPUT_BUFFER_SIZE_WHILE_UNAUTHENTICATED) + { + sLog.Out(LOG_RA, LOG_LVL_ERROR, "[%s] Unauthenticated connection had too large buffer", m_socket.GetRemoteIpString().c_str()); + return; // implicit socket close + } + + // we need more data to process this message + std::shared_ptr> recvBuffer(new std::vector(1024)); + m_socket.ReadSome(recvBuffer->data(), recvBuffer->size(), [self = shared_from_this(), recvBuffer](IO::NetworkError const& error, std::size_t amountRead) + { + if (error) + { + sLog.Out(LOG_RA, LOG_LVL_ERROR, "[%s] Connection had error: %s", self->m_socket.GetRemoteIpString().c_str(), error.ToString().c_str()); + return; // implicit socket close + } + + if (!self->m_atLeastOnePacketWasReceived) + { + // Some terminals send a negotiation packet in the very first message + self->m_atLeastOnePacketWasReceived = true; + if (amountRead >= 1 && (static_cast(recvBuffer->at(0)) == 0xFF)) + { + // We got a telnet protocol packet, most likely the terminal wants us to tell the capabilities it has, but we are not really interested in it + std::vector endOfNegotiationResponse = { 0xFF, 0xF0 }; + self->m_socket.Write(std::move(endOfNegotiationResponse), [self](IO::NetworkError const& error) { self->DoRecvIncomingData(); }); + return; + } + } + + self->m_pendingInputBuffer.append(recvBuffer->data(), amountRead); + self->DoRecvIncomingData(); // reprocesses our pending buffer + }); +} + +void RASocket::HandleInput(std::string const& line) +{ + switch (m_connectionState) + { + // If the input is '' + case ConnectionState::FreshConnection: + HandleInput_FreshConnection(line); + break; + + // If the input is '' (and the user already gave his username) + case ConnectionState::GotUsername: + HandleInput_GotUsername(line); + break; + + // If user is logged in: parse and execute the command + case ConnectionState::Authenticated: + HandleInput_Authenticated(line); + break; + + default: + MANGOS_ASSERT(false); + } +} + +void RASocket::HandleInput_FreshConnection(std::string const& line) +{ + m_username = line; + m_connectionState = ConnectionState::GotUsername; + SendAndRecvNextInput(sObjectMgr.GetMangosStringForDBCLocale(LANG_RA_PASS)); +} + +void RASocket::HandleInput_GotUsername(std::string const& line) +{ + AccountTypes minRequiredAccLevel = static_cast(sConfig.GetIntDefault("Ra.MinLevel", AccountTypes::SEC_ADMINISTRATOR)); + + bool loginSuccessful = true; + + if (loginSuccessful) // check username + { + m_accountId = sAccountMgr.GetId(m_username); + if (!m_accountId) + { + sLog.Out(LOG_RA, LOG_LVL_MINIMAL, "[%s] Account '%s' does not exist", m_socket.GetRemoteIpString().c_str(), m_username.c_str()); + loginSuccessful = false; + } + } + + if (loginSuccessful) // check password + { + if (!sAccountMgr.CheckPassword(m_accountId, line)) + { + sLog.Out(LOG_RA, LOG_LVL_MINIMAL,"[%s] Wrong password for account %s", m_socket.GetRemoteIpString().c_str(), m_username.c_str()); + loginSuccessful = false; + } + } + + if (loginSuccessful) // check account level + { + m_accountLevel = sAccountMgr.GetSecurity(m_accountId); + + if (m_accountLevel < minRequiredAccLevel) + { + sLog.Out(LOG_RA, LOG_LVL_MINIMAL,"[%s] Account %s has no privilege for RA", m_socket.GetRemoteIpString().c_str(), m_username.c_str()); + loginSuccessful = false; + } + else + { + // allow by remotely connected admin use console level commands dependent from config setting + if (m_accountLevel >= SEC_ADMINISTRATOR && !m_restricted) + m_accountLevel = SEC_CONSOLE; + } + } + + if (loginSuccessful) + { + sLog.Out(LOG_RA, LOG_LVL_MINIMAL,"[%s] Account %s has logged in", m_socket.GetRemoteIpString().c_str(), m_username.c_str()); + m_connectionState = ConnectionState::Authenticated; + SendAndRecvNextInput("+Logged in." + NEWLINE + " " + PROMPT); + } + else + { + SendAndDisconnect("-Authentication failed. Verify username, password and required accountLevel." + NEWLINE); + sLog.Out(LOG_RA, LOG_LVL_MINIMAL,"[%s] Account %s has failed to log in", m_socket.GetRemoteIpString().c_str(), m_username.c_str()); + } +} + +void RASocket::HandleInput_Authenticated(std::string const& line) +{ + if (line.empty()) + { + SendAndRecvNextInput(" " + PROMPT); + return; + } + + sLog.Out(LOG_RA, LOG_LVL_MINIMAL, "[%s/%s] Received command: %s", m_socket.GetRemoteIpString().c_str(), m_username.c_str(), line.c_str()); + + // handle quit, exit and logout commands to terminate connection + if (line == "quit" || line == "exit" || line == "logout") + return; + + // TODO: Make CliCommandHolder able to use std::function + struct InvokeOutputEnvironment + { + std::shared_ptr self; + std::string output; + }; + auto* invokeEnvironmentPtr = new InvokeOutputEnvironment + { + shared_from_this(), + "", + }; + + sWorld.QueueCliCommand(new CliCommandHolder( + m_accountId, + m_accountLevel, + invokeEnvironmentPtr, + line.c_str(), + [](void* opaquePointer, const char* buffer) + { + auto* invokeEnvironmentPtr = static_cast(opaquePointer); + invokeEnvironmentPtr->output.append(buffer); + }, + [](void* opaquePointer, bool commandWasSuccessful) + { + char const* statusSymbol = commandWasSuccessful ? "+" : "-"; + + auto* invokeEnvironmentPtr = static_cast(opaquePointer); + invokeEnvironmentPtr->output.append(statusSymbol + PROMPT); + invokeEnvironmentPtr->self->SendAndRecvNextInput(invokeEnvironmentPtr->output); + delete invokeEnvironmentPtr; + } + )); +} + + +void RASocket::SendAndDisconnect(std::string const& message) +{ + std::shared_ptr rawMessage(new uint8_t[message.size()], MaNGOS::Memory::array_deleter()); + memcpy(rawMessage.get(), message.c_str(), message.size()); + m_socket.Write({ rawMessage, message.size() }, [self = shared_from_this()](IO::NetworkError const& error) + { + if (error) + sLog.Out(LOG_RA, LOG_LVL_ERROR, "[%s] Sending message failed: %s", self->m_socket.GetRemoteIpString().c_str(), error.ToString().c_str()); + }); +} + +void RASocket::SendAndRecvNextInput(std::string const& message) +{ + std::shared_ptr rawMessage(new uint8_t[message.size()], MaNGOS::Memory::array_deleter()); + memcpy(rawMessage.get(), message.c_str(), message.size()); + m_socket.Write({ rawMessage, message.size() }, [self = shared_from_this()](IO::NetworkError const& error) + { + if (error) + { + sLog.Out(LOG_RA, LOG_LVL_ERROR, "[%s] Sending message failed: %s", self->m_socket.GetRemoteIpString().c_str(), error.ToString().c_str()); + return; + } + self->DoRecvIncomingData(); + }); +} diff --git a/src/mangosd/remote/RemoteAccess/RASocket.h b/src/mangosd/remote/RemoteAccess/RASocket.h new file mode 100644 index 00000000000..412d65ebc8f --- /dev/null +++ b/src/mangosd/remote/RemoteAccess/RASocket.h @@ -0,0 +1,67 @@ +/* + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation; either version 2 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program; if not, write to the Free Software + * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA + */ + +#ifndef _RASOCKET_H +#define _RASOCKET_H + +#include "Common.h" + +#include "IO/Networking/AsyncSocket.h" + +#include +#include +#include + +/// Remote Administration socket +class RASocket final : public std::enable_shared_from_this +{ + public: + RASocket(IO::Networking::AsyncSocket socket); + virtual ~RASocket(); + + void Start(); + + private: + IO::Networking::AsyncSocket m_socket; + + enum class ConnectionState + { + FreshConnection, + GotUsername, + Authenticated, + }; + + bool m_restricted; + + std::string m_pendingInputBuffer; // might contain multiple lines + std::size_t static constexpr MAX_INPUT_BUFFER_SIZE_WHILE_UNAUTHENTICATED = 128; + + ConnectionState m_connectionState; + bool m_atLeastOnePacketWasReceived; + std::string m_username; + uint32 m_accountId; + AccountTypes m_accountLevel; + + void DoRecvIncomingData(); + void HandleInput(std::string const& line); + void HandleInput_FreshConnection(std::string const& line); + void HandleInput_GotUsername(std::string const& line); + void HandleInput_Authenticated(std::string const& line); + + void SendAndDisconnect(std::string const& message); + void SendAndRecvNextInput(std::string const& message); +}; +#endif diff --git a/src/mangosd/remote/soap/MaNGOSsoap.cpp b/src/mangosd/remote/soap/MaNGOSsoap.cpp new file mode 100644 index 00000000000..6126e5d72ef --- /dev/null +++ b/src/mangosd/remote/soap/MaNGOSsoap.cpp @@ -0,0 +1,144 @@ +#include "MaNGOSsoap.h" +#include "stdsoap2.h" + +#include "World.h" +#include "Log.h" +#include "AccountMgr.h" + +#include "IO/Networking/IpAddress.h" +#include "IO/Multithreading/CreateThread.h" + +class SOAPCommand +{ + public: + /// Blocks until OnCommandFinished is called + bool WaitAndGetSuccessStatus() + { + return m_successStatusPromise.get_future().get(); + } + + static void OnPrint(void* opaquePointer, char const* msg) + { + SOAPCommand* self = static_cast(opaquePointer); + self->m_printBuffer += msg; + } + + static void OnCommandFinished(void* opaquePointer, bool success) + { + SOAPCommand* self = static_cast(opaquePointer); + self->m_successStatusPromise.set_value(success); + } + + std::string m_printBuffer; + std::promise m_successStatusPromise; +}; + +void SoapThreadBody(struct soap* soap) +{ + while (!World::IsStopped()) + { + if (!soap_valid_socket(soap_accept(soap))) + continue; // most likely, we ran into an accept timeout + + auto ip = IO::Networking::IpAddress::FromIpv4Uint32(soap->ip); + sLog.Out(LOG_RA, LOG_LVL_BASIC, "MaNGOSsoap: Accepted connection from %s", ip.ToString().c_str()); + + soap_serve(soap); // handle soap request + } + + sLog.Out(LOG_RA, LOG_LVL_MINIMAL, "MaNGOSsoap: Stopping..."); + soap_end(soap); + soap_done(soap); + soap_destroy(soap); +} + +std::unique_ptr StartSoapThread(std::string const& bindHost, uint16 bindPort) +{ + struct soap* soap = soap_new(); + soap_init(soap); + soap_set_imode(soap, SOAP_C_UTFSTRING); + soap_set_omode(soap, SOAP_C_UTFSTRING); + + soap->accept_timeout = 3; // sec | Check every 3 seconds if World::IsStopped() + soap->recv_timeout = 5; // sec + soap->send_timeout = 5; // sec + + int const acceptBacklogCount = 50; + + if (!soap_valid_socket(soap_bind(soap, bindHost.c_str(), bindPort, acceptBacklogCount))) + { + sLog.Out(LOG_RA, LOG_LVL_ERROR, "MaNGOSsoap: Couldn't bind to %s:%d", bindHost.c_str(), bindPort); + soap_done(soap); + soap_destroy(soap); + return nullptr; + } + + sLog.Out(LOG_BASIC, LOG_LVL_MINIMAL, "MaNGOSsoap: Bound to http://%s:%d/", bindHost.c_str(), bindPort); + + return IO::Multithreading::CreateThreadPtr("SOAP", [soap]() + { SoapThreadBody(soap); }); +} + +/// Defined by soap.stub +int ns1__executeCommand(soap* soap, char* command, char** result) +{ + // security check + if (!soap->userid || !soap->passwd) + { + sLog.Out(LOG_BASIC, LOG_LVL_DETAIL, "MaNGOSsoap: Client didn't provide login information"); + return 401; + } + + uint32 accountId = sAccountMgr.GetId(soap->userid); + if (!accountId) + { + sLog.Out(LOG_BASIC, LOG_LVL_DETAIL, "MaNGOSsoap: Client used invalid username '%s'", soap->userid); + return 401; + } + + if (!sAccountMgr.CheckPassword(accountId, soap->passwd)) + { + sLog.Out(LOG_BASIC, LOG_LVL_DETAIL, "MaNGOSsoap: invalid password for account '%s'", soap->userid); + return 401; + } + + if (sAccountMgr.GetSecurity(accountId) < SEC_ADMINISTRATOR) + { + sLog.Out(LOG_BASIC, LOG_LVL_DETAIL, "MaNGOSsoap: %s's gmlevel is too low", soap->userid); + return 403; + } + + if (!command || !*command) + return soap_sender_fault(soap, "Parameter 'command' can not be empty", "The supplied command was an empty string"); + + sLog.Out(LOG_BASIC, LOG_LVL_DEBUG, "MaNGOSsoap: Received command '%s'", command); + + // Commands are executed in the world thread. We have to wait for them to be completed + SOAPCommand commandHolder; + { + // CliCommandHolder will be deleted from world, accessing after queueing is NOT safe + CliCommandHolder* cmd = new CliCommandHolder(accountId, SEC_CONSOLE, &commandHolder, command, &SOAPCommand::OnPrint, &SOAPCommand::OnCommandFinished); + sWorld.QueueCliCommand(cmd); + } + + // Wait for callback to complete command + bool wasSuccessful = commandHolder.WaitAndGetSuccessStatus(); + + char* printBuffer = soap_strdup(soap, commandHolder.m_printBuffer.c_str()); + if (!wasSuccessful) + return soap_sender_fault(soap, printBuffer, printBuffer); + + *result = printBuffer; + return SOAP_OK; +} + +/// Namespace definition for gSOAP. +/// We must define this, because gSOAP is using it as an external symbol +struct Namespace namespaces[] = + {{ "SOAP-ENV", "http://schemas.xmlsoap.org/soap/envelope/" }, // must be first + { "SOAP-ENC", "http://schemas.xmlsoap.org/soap/encoding/" }, // must be second + { "xsi", "http://www.w3.org/1999/XMLSchema-instance", "http://www.w3.org/*/XMLSchema-instance" }, + { "xsd", "http://www.w3.org/1999/XMLSchema", "http://www.w3.org/*/XMLSchema" }, + { "ns1", "urn:MaNGOS" }, // "ns1" namespace prefix + { nullptr, nullptr } + }; diff --git a/src/mangosd/remote/soap/MaNGOSsoap.h b/src/mangosd/remote/soap/MaNGOSsoap.h new file mode 100644 index 00000000000..a2c6b369aa1 --- /dev/null +++ b/src/mangosd/remote/soap/MaNGOSsoap.h @@ -0,0 +1,11 @@ +#ifndef MANGOSSOAP_H +#define MANGOSSOAP_H + +#include +#include +#include +#include "Platform/Define.h" + +std::unique_ptr StartSoapThread(std::string const& bindHost, uint16 bindPort); + +#endif diff --git a/src/realmd/AuthCodes.h b/src/realmd/AuthCodes.h index 3d8742a4d9c..8f6fa8ce391 100644 --- a/src/realmd/AuthCodes.h +++ b/src/realmd/AuthCodes.h @@ -26,7 +26,7 @@ #ifndef _AUTHCODES_H #define _AUTHCODES_H -enum eAuthCmd +enum eAuthCmd : uint8 { CMD_AUTH_LOGON_CHALLENGE = 0x00, CMD_AUTH_LOGON_PROOF = 0x01, @@ -35,10 +35,9 @@ enum eAuthCmd CMD_REALM_LIST = 0x10, CMD_XFER_INITIATE = 0x30, CMD_XFER_DATA = 0x31, - // these opcodes no longer exist in currently supported client CMD_XFER_ACCEPT = 0x32, CMD_XFER_RESUME = 0x33, - CMD_XFER_CANCEL = 0x34 + CMD_XFER_CANCEL = 0x34, }; // not used by us currently @@ -57,7 +56,7 @@ enum eAuthSrvCmd CMD_GRUNT_SUNKEN_ONLINE = 0x46 }; -enum AuthResult +enum AuthResult : uint8 { WOW_SUCCESS = 0x00, WOW_FAIL_UNKNOWN0 = 0x01, // Unknown0 - Unable to connect diff --git a/src/realmd/AuthPackets.h b/src/realmd/AuthPackets.h new file mode 100644 index 00000000000..91d900b405e --- /dev/null +++ b/src/realmd/AuthPackets.h @@ -0,0 +1,155 @@ +#ifndef MANGOS_AUTHPACKETS_H +#define MANGOS_AUTHPACKETS_H + +#include "Platform/Define.h" + +// GCC have alternative #pragma pack(N) syntax and old gcc version not support pack(push,N), also any gcc version not support it at some platform +#if defined( __GNUC__ ) +#pragma pack(1) +#else +#pragma pack(push,1) +#endif + +# define AUTH_LOGON_MAX_NAME 16 + +struct sAuthLogonChallengeHeader +{ + uint8 error; + uint16 size; +}; + +struct sAuthLogonChallengeBody +{ + uint8 gamename[4]; + uint8 version1; + uint8 version2; + uint8 version3; + uint16 build; + uint8 platform[4]; + uint8 os[4]; + uint8 country[4]; + uint32 timezone_bias; + uint32 ip; + uint8 username_len; + uint8 username[AUTH_LOGON_MAX_NAME + 1]; +}; + +//typedef sAuthLogonChallenge_C sAuthReconnectChallenge_C; +/* +struct sAuthLogonChallenge_S +{ + uint8 cmd; + uint8 error; + uint8 unk2; + uint8 B[32]; + uint8 g_len; + uint8 g[1]; + uint8 N_len; + uint8 N[32]; + uint8 s[32]; + uint8 unk3[16]; +}; +*/ + +struct sAuthLogonProof_C_Pre_1_11_0 +{ + uint8 A[32]; + uint8 M1[20]; + uint8 crc_hash[20]; + uint8 number_of_keys; +}; + +enum SecurityFlags : uint8 +{ + SECURITY_FLAG_NONE = 0x00, + SECURITY_FLAG_PIN = 0x01, // pin was added in 1.11.0 + SECURITY_FLAG_UNK = 0x02, + SECURITY_FLAG_AUTHENTICATOR = 0x04, // authenticator was added in 2.4.3 +}; + +struct sAuthLogonProof_C : public sAuthLogonProof_C_Pre_1_11_0 +{ + SecurityFlags securityFlags; // 0x00-0x04 // See enum SecurityFlags +}; + +struct PINData +{ + uint8 salt[16]; + uint8 hash[20]; +}; + +/* +struct sAuthLogonProofKey_C +{ + uint16 unk1; + uint32 unk2; + uint8 unk3[4]; + uint16 unk4[20]; +}; +*/ + +struct AUTH_LOGON_PROOF_S_BUILD_8089 +{ + uint8 cmd; + uint8 error; + uint8 M2[20]; + uint32 accountFlags; // see enum AccountFlags + uint32 surveyId; // SurveyId + uint16 loginFlags; // some flags (AccountMsgAvailable = 0x01) +}; + +struct AUTH_LOGON_PROOF_S_BUILD_6299 +{ + uint8 cmd; + uint8 error; + uint8 M2[20]; + uint32 surveyId; // SurveyId + uint16 loginFlags; // some flags (AccountMsgAvailable = 0x01) +}; + +struct AUTH_LOGON_PROOF_S +{ + uint8 cmd; + uint8 error; + uint8 M2[20]; + uint32 surveyId; // SurveyId +}; + +struct AUTH_RECONNECT_PROOF_C +{ + uint8 R1[16]; + uint8 R2[20]; + uint8 R3[20]; + uint8 number_of_keys; +}; + +struct XFER_INIT +{ + uint8 cmd; // XFER_INITIATE + uint8 fileTypeNameLength; // strlen(fileTypeName); // size without '\0' + uint8 fileTypeName[5]; // fileName[fileTypeNameLength] // As of 1.12 it can only be "Patch" or "Survey" // currently hardcoded to 5, because we only want to "Patch" + uint64 fileSize; // file size (bytes) + uint8 md5[16]; // MD5 of the file, so the client can verify if downloaded correctly or if present patch is correct +}; + +struct XFER_DATA_CHUNK +{ + uint8 cmd; // this must be CMD_XFER_DATA + uint16 data_size; + uint8 data[4096]; // 4096 - page size on most arch // TODO: Is this a client limitation? +}; + +namespace +{ + struct MANGOS_AUTHPACKETS_H_PackedSizeVerification { uint64 a; uint8 b; uint64 c; }; + static_assert(sizeof(MANGOS_AUTHPACKETS_H_PackedSizeVerification) == 17, "It appears that this area does not pack structs as it should be. This will cause misalignment errors when raw send() or recv() is performed with structs!"); +} + +// GCC have alternative #pragma pack() syntax and old gcc version not support pack(pop), also any gcc version not support it at some platform +#if defined( __GNUC__ ) +#pragma pack() +#else +#pragma pack(pop) +#endif + +#endif // MANGOS_AUTHPACKETS_H diff --git a/src/realmd/AuthSocket.cpp b/src/realmd/AuthSocket.cpp index 1c09f03944a..6446d2bd10c 100644 --- a/src/realmd/AuthSocket.cpp +++ b/src/realmd/AuthSocket.cpp @@ -32,8 +32,16 @@ #include "RealmList.h" #include "AuthSocket.h" #include "AuthCodes.h" -#include "PatchHandler.h" #include "Util.h" +#include "ClientPatchCache.h" +#include "Memory/NoDeleter.h" +#include "Errors.h" + +#include "IO/Networking/Utils.h" +#include "IO/Networking/AsyncSocket.h" +#include "IO/Timer/AsyncSystemTimer.h" +#include "IO/Filesystem/FileSystem.h" +#include "ProxyProtocol/ProxyV2Reader.h" #ifdef ENABLE_MAILSENDER #include "MailerService.h" @@ -43,10 +51,6 @@ #include //#include "Util.h" -- for commented utf8ToUpperOnlyLatin -#include -#include -#include - enum AccountFlags { ACCOUNT_FLAG_GM = 0x00000001, @@ -54,139 +58,45 @@ enum AccountFlags ACCOUNT_FLAG_PROPASS = 0x00800000, }; -// GCC have alternative #pragma pack(N) syntax and old gcc version not support pack(push,N), also any gcc version not support it at some paltform -#if defined( __GNUC__ ) -#pragma pack(1) -#else -#pragma pack(push,1) -#endif - -typedef struct AUTH_LOGON_CHALLENGE_C -{ - uint8 cmd; - uint8 error; - uint16 size; - uint8 gamename[4]; - uint8 version1; - uint8 version2; - uint8 version3; - uint16 build; - uint8 platform[4]; - uint8 os[4]; - uint8 country[4]; - uint32 timezone_bias; - uint32 ip; - uint8 I_len; - uint8 I[1]; -} sAuthLogonChallenge_C; - -//typedef sAuthLogonChallenge_C sAuthReconnectChallenge_C; -/* -typedef struct -{ - uint8 cmd; - uint8 error; - uint8 unk2; - uint8 B[32]; - uint8 g_len; - uint8 g[1]; - uint8 N_len; - uint8 N[32]; - uint8 s[32]; - uint8 unk3[16]; -} sAuthLogonChallenge_S; -*/ - -struct sAuthLogonProof_C_Base +typedef struct AuthHandler { - uint8 cmd; - uint8 A[32]; - uint8 M1[20]; - uint8 crc_hash[20]; - uint8 number_of_keys; -}; + eAuthCmd cmd; + uint32 status; + void (AuthSocket::*asyncHandler)(); +} AuthHandler; -struct sAuthLogonProof_C_1_11 : public sAuthLogonProof_C_Base -{ - uint8 securityFlags; // 0x00-0x04 -}; -/* -typedef struct -{ - uint16 unk1; - uint32 unk2; - uint8 unk3[4]; - uint16 unk4[20]; -} sAuthLogonProofKey_C; -*/ -typedef struct AUTH_LOGON_PROOF_S_BUILD_8089 -{ - uint8 cmd; - uint8 error; - uint8 M2[20]; - uint32 accountFlags; // see enum AccountFlags - uint32 surveyId; // SurveyId - uint16 loginFlags; // some flags (AccountMsgAvailable = 0x01) -} sAuthLogonProof_S_BUILD_8089; - -typedef struct AUTH_LOGON_PROOF_S_BUILD_6299 -{ - uint8 cmd; - uint8 error; - uint8 M2[20]; - uint32 surveyId; // SurveyId - uint16 loginFlags; // some flags (AccountMsgAvailable = 0x01) -} sAuthLogonProof_S_BUILD_6299; - -typedef struct AUTH_LOGON_PROOF_S -{ - uint8 cmd; - uint8 error; - uint8 M2[20]; - uint32 surveyId; // SurveyId -} sAuthLogonProof_S; +std::array VersionChallenge = { { 0xBA, 0xA3, 0x1E, 0x99, 0xA0, 0x0B, 0x21, 0x57, 0xFC, 0x37, 0x3F, 0xB3, 0x69, 0xCD, 0xD2, 0xF1 } }; -typedef struct AUTH_RECONNECT_PROOF_C -{ - uint8 cmd; - uint8 R1[16]; - uint8 R2[20]; - uint8 R3[20]; - uint8 number_of_keys; -} sAuthReconnectProof_C; - -typedef struct XFER_INIT +// Accept the connection and set the s random value for SRP6 // TODO where is this SRP6 done? +AuthSocket::AuthSocket(IO::Networking::AsyncSocket socket) : + m_socket(std::move(socket)), + m_remoteIpAddressStringAfterProxy(m_socket.GetRemoteIpString()) { - uint8 cmd; // XFER_INITIATE - uint8 fileNameLen; // strlen(fileName); - uint8 fileName[5]; // fileName[fileNameLen] - uint64 file_size; // file size (bytes) - uint8 md5[Crypto::Hash::MD5::Digest::size()]; // MD5 -}XFER_INIT; +} -typedef struct AuthHandler +void AuthSocket::Start() { - eAuthCmd cmd; - uint32 status; - bool (AuthSocket::*handler)(void); -}AuthHandler; - -// GCC have alternative #pragma pack() syntax and old gcc version not support pack(pop), also any gcc version not support it at some paltform -#if defined( __GNUC__ ) -#pragma pack() -#else -#pragma pack(pop) -#endif - -#define AUTH_TOTAL_COMMANDS sizeof(table)/sizeof(AuthHandler) + if (int secs = sConfig.GetIntDefault("MaxSessionDuration", 300)) + { + m_sessionDurationTimeout = sAsyncSystemTimer.ScheduleFunctionOnce(std::chrono::seconds(secs), [this]() + { + sLog.Out(LOG_BASIC, LOG_LVL_BASIC, "[%s] Connection has reached MaxSessionDuration. Closing socket...", this->GetRemoteIpString().c_str()); + // It's correct that we capture _this_ and not a shared_ptr, since the timer will be canceled in destructor + this->CloseSocket(); + }); + } -std::array VersionChallenge = { { 0xBA, 0xA3, 0x1E, 0x99, 0xA0, 0x0B, 0x21, 0x57, 0xFC, 0x37, 0x3F, 0xB3, 0x69, 0xCD, 0xD2, 0xF1 } }; + DoRecvIncomingData(); +} -// Close patch file descriptor before leaving AuthSocket::~AuthSocket() { - if(m_patch != ACE_INVALID_HANDLE) - ACE_OS::close(m_patch); + CloseSocket(); + + if (m_sessionDurationTimeout) + m_sessionDurationTimeout->Cancel(); + + sLog.Out(LOG_NETWORK, LOG_LVL_BASIC, "[%s] Connection closed", GetRemoteIpString().c_str()); } AccountTypes AuthSocket::GetSecurityOn(uint32 realmId) const @@ -197,467 +107,526 @@ AccountTypes AuthSocket::GetSecurityOn(uint32 realmId) const return it->second; } -// Accept the connection and set the s random value for SRP6 -void AuthSocket::OnAccept() -{ - sLog.Out(LOG_BASIC, LOG_LVL_BASIC, "Accepting connection from '%s'", get_remote_address().c_str()); -} - // Read the packet from the client -void AuthSocket::OnRead() +void AuthSocket::DoRecvIncomingData() { - // benchmarking has demonstrated that this lookup method is faster than std::map - const static AuthHandler table[] = - { - { CMD_AUTH_LOGON_CHALLENGE, STATUS_CHALLENGE, &AuthSocket::_HandleLogonChallenge }, - { CMD_AUTH_LOGON_PROOF, STATUS_LOGON_PROOF, &AuthSocket::_HandleLogonProof }, - { CMD_AUTH_RECONNECT_CHALLENGE, STATUS_CHALLENGE, &AuthSocket::_HandleReconnectChallenge }, - { CMD_AUTH_RECONNECT_PROOF, STATUS_RECON_PROOF, &AuthSocket::_HandleReconnectProof }, - { CMD_REALM_LIST, STATUS_AUTHED, &AuthSocket::_HandleRealmList }, - { CMD_XFER_ACCEPT, STATUS_PATCH, &AuthSocket::_HandleXferAccept }, - { CMD_XFER_RESUME, STATUS_PATCH, &AuthSocket::_HandleXferResume }, - { CMD_XFER_CANCEL, STATUS_PATCH, &AuthSocket::_HandleXferCancel } - }; - - uint8 _cmd; - while (1) + std::shared_ptr cmd = std::make_shared(); + + sLog.Out(LOG_BASIC, LOG_LVL_DEBUG, "[%s] DoRecvIncomingData() Reading... Ready for next opcode", GetRemoteIpString().c_str()); + m_socket.Read((char*)cmd.get(), sizeof(eAuthCmd), [self = shared_from_this(), cmd](IO::NetworkError const& error, size_t) -> void { - if(!recv_soft((char *)&_cmd, 1)) + if (error) + { + if (error.GetErrorType() != IO::NetworkError::ErrorType::SocketClosed) + sLog.Out(LOG_BASIC, LOG_LVL_ERROR, "[%s] DoRecvIncomingData Read(cmd) error: %s", self->GetRemoteIpString().c_str(), error.ToString().c_str()); return; + } - size_t i; + // benchmarking has demonstrated that this lookup method is faster than std::map + constexpr AuthHandler table[] = + { + { CMD_AUTH_LOGON_CHALLENGE, STATUS_CHALLENGE, &AuthSocket::_HandleLogonChallenge }, + { CMD_AUTH_LOGON_PROOF, STATUS_LOGON_PROOF, &AuthSocket::_HandleLogonProof }, + { CMD_AUTH_RECONNECT_CHALLENGE, STATUS_CHALLENGE, &AuthSocket::_HandleReconnectChallenge }, + { CMD_AUTH_RECONNECT_PROOF, STATUS_RECON_PROOF, &AuthSocket::_HandleReconnectProof }, + { CMD_REALM_LIST, STATUS_AUTHED, &AuthSocket::_HandleRealmList }, + { CMD_XFER_ACCEPT, STATUS_PATCH, &AuthSocket::_HandleXferAccept }, + { CMD_XFER_RESUME, STATUS_PATCH, &AuthSocket::_HandleXferResume }, + { CMD_XFER_CANCEL, STATUS_PATCH, &AuthSocket::_HandleXferCancel } + }; + + constexpr size_t tableLength = sizeof(table) / sizeof(AuthHandler); + size_t i; // Circle through known commands and call the correct command handler - for (i = 0; i < AUTH_TOTAL_COMMANDS; ++i) + for (i = 0; i < tableLength; ++i) { - if (table[i].cmd != _cmd) + if (table[i].cmd != *cmd) continue; - // unauthorized - sLog.Out(LOG_BASIC, LOG_LVL_DEBUG, "[Auth] Status %u, table status %u", m_status, table[i].status); + sLog.Out(LOG_BASIC, LOG_LVL_DEBUG, "[%s] CMD: %u requires status %u, user has %u", self->GetRemoteIpString().c_str(), *cmd, table[i].status, self->m_status); - if (table[i].status != m_status) - { - sLog.Out(LOG_BASIC, LOG_LVL_DEBUG, "[Auth] Received unauthorized command %u length %u", _cmd, (uint32)recv_len()); + if (table[i].status != self->m_status) + { // unauthorized + sLog.Out(LOG_BASIC, LOG_LVL_ERROR, "[%s] Received unauthorized command %u", self->GetRemoteIpString().c_str(), *cmd); return; } - sLog.Out(LOG_BASIC, LOG_LVL_DEBUG, "[Auth] Got data for cmd %u recv length %u", _cmd, (uint32)recv_len()); + sLog.Out(LOG_BASIC, LOG_LVL_DEBUG, "[%s] Got data for cmd %u", self->GetRemoteIpString().c_str(), *cmd); - if (!(*this.*table[i].handler)()) - { - sLog.Out(LOG_BASIC, LOG_LVL_DEBUG, "[Auth] Command handler failed for cmd %u recv length %u", _cmd, (uint32)recv_len()); - close_connection(); - return; - } + // this handler will async call Read and Write, and hopefully will call DoRecvIncomingData or CloseSocket when done. + ((*self).*table[i].asyncHandler)(); break; } // Report unknown commands in the debug log - if (i == AUTH_TOTAL_COMMANDS) + if (i == tableLength) { - sLog.Out(LOG_BASIC, LOG_LVL_DEBUG, "[Auth] got unknown packet %u", (uint32)_cmd); + sLog.Out(LOG_BASIC, LOG_LVL_ERROR, "[Auth] got unknown packet cmd %u", *cmd); return; } - } + + // if we reach here, it means that a valid opcode was found and the handler completed successfully + // TODO: self->m_timeoutTimer.reset(); + }); } -void AuthSocket::SendProof(Crypto::Hash::SHA1::Digest sha) +std::shared_ptr AuthSocket::GenerateLogonProofResponse(Crypto::Hash::SHA1::Digest const& shaDigest) { + std::shared_ptr pkt(new ByteBuffer()); + if (m_build < 6299) // before version 2.0.3 (exclusive) { - sAuthLogonProof_S proof; - memcpy(proof.M2, sha.data(), sha.size()); + AUTH_LOGON_PROOF_S proof{}; + memcpy(proof.M2, shaDigest.data(), 20); proof.cmd = CMD_AUTH_LOGON_PROOF; proof.error = 0; proof.surveyId = 0x00000000; - send((char *)&proof, sizeof(proof)); + pkt->append(&proof, 1); } else if (m_build < 8089) // before version 2.4.0 (exclusive) { - sAuthLogonProof_S_BUILD_6299 proof; - memcpy(proof.M2, sha.data(), sha.size()); + AUTH_LOGON_PROOF_S_BUILD_6299 proof{}; + memcpy(proof.M2, shaDigest.data(), 20); proof.cmd = CMD_AUTH_LOGON_PROOF; proof.error = 0; proof.surveyId = 0x00000000; proof.loginFlags = 0x0000; - send((char *)&proof, sizeof(proof)); + pkt->append(&proof, 1); } else { - sAuthLogonProof_S_BUILD_8089 proof; - memcpy(proof.M2, sha.data(), sha.size()); + AUTH_LOGON_PROOF_S_BUILD_8089 proof{}; + memcpy(proof.M2, shaDigest.data(), 20); proof.cmd = CMD_AUTH_LOGON_PROOF; proof.error = 0; proof.accountFlags = ACCOUNT_FLAG_PROPASS; proof.surveyId = 0x00000000; proof.loginFlags = 0x0000; - send((char *)&proof, sizeof(proof)); + pkt->append(&proof, 1); } + + return pkt; } // Logon Challenge command handler -bool AuthSocket::_HandleLogonChallenge() +void AuthSocket::_HandleLogonChallenge() { sLog.Out(LOG_BASIC, LOG_LVL_DEBUG, "Entering _HandleLogonChallenge"); - if (recv_len() < sizeof(sAuthLogonChallenge_C)) - return false; - - // Read the first 4 bytes (header) to get the length of the remaining of the packet - std::vector buf; - buf.resize(4); - - recv((char *)&buf[0], 4); + m_status = STATUS_INVALID; - EndianConvert(*((uint16*)(&buf[0]))); - uint16 remaining = ((sAuthLogonChallenge_C *)&buf[0])->size; - sLog.Out(LOG_BASIC, LOG_LVL_DEBUG, "[AuthChallenge] got header, body is %#04x bytes", remaining); + std::shared_ptr header = std::make_shared(); - if ((remaining < sizeof(sAuthLogonChallenge_C) - buf.size()) || (recv_len() < remaining)) - return false; - - // Session is closed unless overriden - m_status = STATUS_CLOSED; - - // No big fear of memory outage (size is int16, i.e. < 65536) - buf.resize(remaining + buf.size() + 1); - buf[buf.size() - 1] = 0; - sAuthLogonChallenge_C *ch = (sAuthLogonChallenge_C*)&buf[0]; - - // Read the remaining of the packet - recv((char *)&buf[4], remaining); - sLog.Out(LOG_BASIC, LOG_LVL_DEBUG, "[AuthChallenge] got full packet, %#04x bytes", ch->size); - sLog.Out(LOG_BASIC, LOG_LVL_DEBUG, "[AuthChallenge] name(%d): '%s'", ch->I_len, ch->I); - - // BigEndian code, nop in little endian case - // size already converted - EndianConvert(*((uint32*)(&ch->gamename[0]))); - EndianConvert(ch->build); - EndianConvert(*((uint32*)(&ch->os[0]))); - EndianConvert(*((uint32*)(&ch->country[0]))); - EndianConvert(ch->timezone_bias); - EndianConvert(ch->ip); - - ByteBuffer pkt; - - m_login = (const char*)ch->I; - m_build = ch->build; - - ch->os[3] = '\0'; - std::reverse(ch->os, ch->os + 3); - memcpy(&m_os, ch->os, sizeof(m_os)); - - ch->platform[3] = '\0'; - std::reverse(ch->platform, ch->platform + 3); - memcpy(&m_platform, ch->platform, sizeof(m_platform)); - - // Normalize account name - // utf8ToUpperOnlyLatin(m_login); -- client already send account in expected form - - // Escape the user login to avoid further SQL injection - // Memory will be freed on AuthSocket object destruction - m_safelogin = m_login; - LoginDatabase.escape_string(m_safelogin); - - pkt << (uint8) CMD_AUTH_LOGON_CHALLENGE; - pkt << (uint8) 0x00; - - // Verify that this IP is not in the ip_banned table - // No SQL injection possible (paste the IP address as passed by the socket) - std::string address = get_remote_address(); - LoginDatabase.escape_string(address); - std::unique_ptr result = LoginDatabase.PQuery("SELECT `unbandate` FROM `ip_banned` WHERE " - // permanent still banned - "(`unbandate` = `bandate` OR `unbandate` > UNIX_TIMESTAMP()) AND `ip` = '%s'", address.c_str()); - if (result) - { - pkt << (uint8)WOW_FAIL_DB_BUSY; - sLog.Out(LOG_BASIC, LOG_LVL_BASIC, "[AuthChallenge] Banned ip '%s' tries to login with account '%s'!", get_remote_address().c_str(), m_login.c_str()); - } - else + // Read the header first, to get the length of the remaining packet + m_socket.Read((char*)header.get(), sizeof(sAuthLogonChallengeHeader), [self = shared_from_this(), header](IO::NetworkError const& error, size_t) -> void { - // Get the account details from the account table - // No SQL injection (escaped user name) - // 0 1 2 3 4 5 6 7 8 9 - result = LoginDatabase.PQuery("SELECT `id`, `locked`, `last_ip`, `v`, `s`, `security`, `email_verif`, `geolock_pin`, `email`, UNIX_TIMESTAMP(`joindate`) FROM `account` WHERE `username` = '%s'",m_safelogin.c_str ()); - if (result) + if (error) { - Field* fields = result->Fetch(); - - // Prevent login if the user's email address has not been verified - bool requireVerification = sConfig.GetBoolDefault("ReqEmailVerification", false); - int32 requireEmailSince = sConfig.GetIntDefault("ReqEmailSince", 0); - bool verified = (*result)[6].GetBool(); - - // Prevent login if the user's join date is bigger than the timestamp in configuration - if (requireEmailSince > 0) - { - uint32 t = (*result)[9].GetUInt32(); - requireVerification = requireVerification && (t >= uint32(requireEmailSince)); - } - - if (requireVerification && !verified) - { - sLog.Out(LOG_BASIC, LOG_LVL_BASIC, "[AuthChallenge] Account '%s' using IP '%s 'email address requires email verification - rejecting login", m_login.c_str(), get_remote_address().c_str()); - pkt << (uint8)WOW_FAIL_UNKNOWN_ACCOUNT; - send((char const*)pkt.contents(), pkt.size()); - return true; - } + sLog.Out(LOG_BASIC, LOG_LVL_ERROR, "[Auth] HandleLogonChallenge Read(header) error"); + self->CloseSocket(); // TODO: Remove me. Closing the socket will be done implicitly if all references to this socket are deleted (when there is no IO anymore) + return; + } - // If the IP is 'locked', check that the player comes indeed from the correct IP address - bool locked = false; - m_lockFlags = (LockFlag)(*result)[1].GetUInt32(); - m_securityInfo = (*result)[5].GetCppString(); - m_lastIP = fields[2].GetString(); - m_geoUnlockPIN = fields[7].GetUInt32(); - m_email = fields[8].GetCppString(); + uint16* pUint16 = reinterpret_cast(header.get()); + EndianConvert(*pUint16); + uint16 actualBodySize = header->size; - if (m_lockFlags & IP_LOCK) - { - sLog.Out(LOG_BASIC, LOG_LVL_DEBUG, "[AuthChallenge] Account '%s' is locked to IP - '%s'", m_login.c_str(), m_lastIP.c_str()); - sLog.Out(LOG_BASIC, LOG_LVL_DEBUG, "[AuthChallenge] Player address is '%s'", get_remote_address().c_str()); - - if (m_lastIP != get_remote_address()) - { - sLog.Out(LOG_BASIC, LOG_LVL_DEBUG, "[AuthChallenge] Account IP differs"); + if ((actualBodySize < sizeof(sAuthLogonChallengeBody) - AUTH_LOGON_MAX_NAME)) + { // The paket is too small and has no username??? + return; + } - // account is IP locked and the player does not have 2FA enabled - if (((m_lockFlags & TOTP) != TOTP && (m_lockFlags & FIXED_PIN) != FIXED_PIN)) - pkt << (uint8) WOW_FAIL_SUSPENDED; + sLog.Out(LOG_BASIC, LOG_LVL_DEBUG, "[AuthChallenge] got header, body is %#04x bytes", actualBodySize); - locked = true; - } - else - { - sLog.Out(LOG_BASIC, LOG_LVL_DEBUG, "[AuthChallenge] Account IP matches"); - } - } - else + // Read the remaining of the packet + std::shared_ptr body = std::make_shared(); + self->m_socket.Read((char*)body.get(), actualBodySize, [self, header, body](IO::NetworkError const& error, size_t) + { + if (error) { - sLog.Out(LOG_BASIC, LOG_LVL_DEBUG, "[AuthChallenge] Account '%s' is not locked to ip", m_login.c_str()); + sLog.Out(LOG_BASIC, LOG_LVL_ERROR, "_HandleLogonChallenge self->m_socket.Read(body): ERROR"); + self->CloseSocket(); // TODO: Remove me. Closing the socket will be done implicitly if all references to this socket are deleted (when there is no IO anymore) + return; } - std::string databaseV = fields[3].GetCppString(); - std::string databaseS = fields[4].GetCppString(); - bool broken = false; - - if (!srp.SetVerifier(databaseV.c_str()) || !srp.SetSalt(databaseS.c_str())) + if (body->username_len > AUTH_LOGON_MAX_NAME) + return; + body->username[body->username_len] = '\0'; + + sLog.Out(LOG_BASIC, LOG_LVL_DEBUG, "[AuthChallenge] got full packet, %#04x bytes", header->size); + sLog.Out(LOG_BASIC, LOG_LVL_DEBUG, "[AuthChallenge] name(%d): '%s'", body->username_len, body->username); + + // BigEndian code, nop in little endian case + // size already converted + EndianConvert(*((uint32*)(&body->gamename[0]))); + EndianConvert(body->build); + EndianConvert(*((uint32*)(&body->platform[0]))); + EndianConvert(*((uint32*)(&body->os[0]))); + EndianConvert(*((uint32*)(&body->country[0]))); + EndianConvert(body->timezone_bias); + EndianConvert(body->ip); + + std::shared_ptr pkt = std::make_shared(); + + self->m_build = body->build; + + // Convert uint8[4] to string, restore string order as its byte order is reversed + // To it for os + body->os[3] = '\0'; + self->m_os = (char*)body->os; + std::reverse(self->m_os.begin(), self->m_os.end()); + // To it for platform + body->platform[3] = '\0'; + self->m_platform = (char*)body->platform; + std::reverse(self->m_platform.begin(), self->m_platform.end()); + // Do it for locale + self->m_localizationName.resize(sizeof(body->country)); + self->m_localizationName.assign(body->country, (body->country + sizeof(body->country))); + std::reverse(self->m_localizationName.begin(), self->m_localizationName.end()); + + // Escape the user input used in DB to avoid further SQL injection + // Memory will be freed on AuthSocket object destruction + self->m_login = (const char*)body->username; + self->m_safelogin = self->m_login; + LoginDatabase.escape_string(self->m_safelogin); + + *pkt << (uint8) CMD_AUTH_LOGON_CHALLENGE; + *pkt << (uint8) 0x00; + + // Verify that this IP is not in the ip_banned table + // No SQL injection possible (paste the IP address as passed by the socket) + // permanent ban OR still banned + std::unique_ptr sqlIpBanResult = LoginDatabase.PQuery("SELECT `unbandate` FROM `ip_banned` WHERE (`unbandate` = `bandate` OR `unbandate` > UNIX_TIMESTAMP()) AND `ip` = '%s'", self->GetRemoteIpString().c_str()); + if (sqlIpBanResult) { - pkt << (uint8)WOW_FAIL_FAIL_NOACCESS; - sLog.Out(LOG_BASIC, LOG_LVL_BASIC, "[AuthChallenge] Broken v/s values in database for account %s!", m_login.c_str()); - broken = true; + *pkt << uint8(WOW_FAIL_FAIL_NOACCESS); + sLog.Out(LOG_BASIC, LOG_LVL_BASIC, "[AuthChallenge] Banned ip '%s' tries to login with account '%s'!", self->GetRemoteIpString().c_str(), self->m_login.c_str()); } - - if ((!locked || (locked && (m_lockFlags & FIXED_PIN || m_lockFlags & TOTP))) && !broken) + else { - uint32 account_id = fields[0].GetUInt32(); - // If the account is banned, reject the logon attempt - std::unique_ptr banResult = LoginDatabase.PQuery("SELECT `bandate`, `unbandate` FROM `account_banned` WHERE " - "`id` = %u AND `active` = 1 AND (`unbandate` > UNIX_TIMESTAMP() OR `unbandate` = `bandate`) LIMIT 1", account_id); - if (banResult) + // Get the account details from the account table + // No SQL injection (escaped username) + // 0 1 2 3 4 5 6 7 8 9 + std::unique_ptr sqlAccountResult = LoginDatabase.PQuery("SELECT `id`, `locked`, `last_ip`, `v`, `s`, `security`, `email_verif`, `geolock_pin`, `email`, UNIX_TIMESTAMP(`joindate`) FROM `account` WHERE `username` = '%s'", self->m_safelogin.c_str()); + if (sqlAccountResult) { - if((*banResult)[0].GetUInt64() == (*banResult)[1].GetUInt64()) - { - pkt << (uint8) WOW_FAIL_BANNED; - sLog.Out(LOG_BASIC, LOG_LVL_BASIC, "[AuthChallenge] Banned account '%s' using IP '%s' tries to login!",m_login.c_str (), get_remote_address().c_str()); - } - else - { - pkt << (uint8) WOW_FAIL_SUSPENDED; - sLog.Out(LOG_BASIC, LOG_LVL_BASIC, "[AuthChallenge] Temporarily banned account '%s' using IP '%s' tries to login!",m_login.c_str (), get_remote_address().c_str()); - } - } - else - { - sLog.Out(LOG_BASIC, LOG_LVL_DEBUG, "database authentication values: v='%s' s='%s'", databaseV.c_str(), databaseS.c_str()); - - BigNumber s; - s.SetHexStr(databaseS.c_str()); - - srp.CalculateHostPublicEphemeral(); - - // Fill the response packet with the result - pkt << uint8(WOW_SUCCESS); - - // B may be calculated < 32B so we force minimal length to 32B - pkt.append(srp.GetHostPublicEphemeral().AsByteArray(32).data(), 32); // 32 bytes - pkt << uint8(1); - pkt.append(srp.GetGeneratorModulo().AsByteArray().data(), 1); - pkt << uint8(32); - pkt.append(srp.GetPrime().AsByteArray(32).data(), 32); - pkt.append(s.AsByteArray()); // 32 bytes - pkt.append(VersionChallenge.data(), VersionChallenge.size()); + Field* fields = sqlAccountResult->Fetch(); - // figure out whether we need to display the PIN grid - m_promptPin = locked; // always prompt if the account is IP locked & 2FA is enabled + // Prevent login if the user's email address has not been verified + bool requireVerification = sConfig.GetBoolDefault("ReqEmailVerification", false); + int32 requireEmailSince = sConfig.GetIntDefault("ReqEmailSince", 0); + bool isVerified = fields[6].GetBool(); - if ((!locked && ((m_lockFlags & ALWAYS_ENFORCE) == ALWAYS_ENFORCE)) || m_geoUnlockPIN) + // Prevent login if the user's join date is bigger than the timestamp in configuration + if (requireEmailSince > 0) { - m_promptPin = true; // prompt if the lock hasn't been triggered but ALWAYS_ENFORCE is set + uint32 t = fields[9].GetUInt32(); + requireVerification = requireVerification && (t >= uint32(requireEmailSince)); } - if (m_promptPin) + if (requireVerification && !isVerified) { - sLog.Out(LOG_BASIC, LOG_LVL_BASIC, "[AuthChallenge] Account '%s' using IP '%s' requires PIN authentication", m_login.c_str(), get_remote_address().c_str()); + sLog.Out(LOG_BASIC, LOG_LVL_BASIC, "[AuthChallenge] Account '%s' using IP '%s 'email address requires email verification - rejecting login", self->m_login.c_str(), self->GetRemoteIpString().c_str()); + *pkt << (uint8) WOW_FAIL_UNKNOWN_ACCOUNT; + + self->m_socket.Write(std::move(pkt), [self](IO::NetworkError const& error) { + if (error) + sLog.Out(LOG_BASIC, LOG_LVL_ERROR, "_HandleLogonChallenge self->Write() Error: %s", error.ToString().c_str()); + else + self->DoRecvIncomingData(); + }); + return; // TODO refactor? + } - uint32 gridSeedPkt = m_gridSeed = static_cast(rand32()); - EndianConvert(gridSeedPkt); - m_serverSecuritySalt.SetRand(16 * 8); // 16 bytes random + // If the IP is 'locked', check that the player comes indeed from the correct IP address + bool locked = false; + self->m_lockFlags = (LockFlag)fields[1].GetUInt32(); + self->m_securityInfo = fields[5].GetCppString(); + self->m_lastIP = fields[2].GetString(); + self->m_geoUnlockPIN = fields[7].GetUInt32(); + self->m_email = fields[8].GetCppString(); - pkt << uint8(1); // securityFlags, only '1' is available in classic (PIN input) - pkt << gridSeedPkt; - pkt.append(m_serverSecuritySalt.AsByteArray(16).data(), 16); + if (self->m_lockFlags & IP_LOCK) + { + sLog.Out(LOG_BASIC, LOG_LVL_DEBUG, "[AuthChallenge] Account '%s' is locked to IP - '%s'", self->m_login.c_str(), self->m_lastIP.c_str()); + sLog.Out(LOG_BASIC, LOG_LVL_DEBUG, "[AuthChallenge] Player address is '%s'", self->GetRemoteIpString().c_str()); + + if (self->m_lastIP != self->GetRemoteIpString()) + { + sLog.Out(LOG_BASIC, LOG_LVL_DEBUG, "[AuthChallenge] Account IP differs"); + + // account is IP locked and the player does not have 2FA enabled + if (((self->m_lockFlags & TOTP) != TOTP && (self->m_lockFlags & FIXED_PIN) != FIXED_PIN)) + *pkt << (uint8) WOW_FAIL_SUSPENDED; + + locked = true; + } + else + { + sLog.Out(LOG_BASIC, LOG_LVL_DEBUG, "[AuthChallenge] Account IP matches"); + } } else { - if (m_build >= 5428) // version 1.11.0 or later - pkt << uint8(0); + sLog.Out(LOG_BASIC, LOG_LVL_DEBUG, "[AuthChallenge] Account '%s' is not locked to ip", self->m_login.c_str()); } - m_localizationName.resize(4); - for(int i = 0; i < 4; ++i) - m_localizationName[i] = ch->country[4-i-1]; - - LoadAccountSecurityLevels(account_id); - sLog.Out(LOG_BASIC, LOG_LVL_BASIC, "[AuthChallenge] Account '%s' using IP '%s' is using '%c%c%c%c' locale (%u)", m_login.c_str (), get_remote_address().c_str(), ch->country[3], ch->country[2], ch->country[1], ch->country[0], GetLocaleByName(m_localizationName)); + std::string databaseV = fields[3].GetCppString(); + std::string databaseS = fields[4].GetCppString(); + bool broken = false; - m_accountId = account_id; + if (!self->srp.SetVerifier(databaseV.c_str()) || !self->srp.SetSalt(databaseS.c_str())) + { + *pkt << uint8(WOW_FAIL_FAIL_NOACCESS); + sLog.Out(LOG_BASIC, LOG_LVL_ERROR, "[AuthChallenge] Broken v/s values in database for account %s!", self->m_login.c_str()); + broken = true; + } - // All good, await client's proof - m_status = STATUS_LOGON_PROOF; + if ((!locked || (locked && (self->m_lockFlags & FIXED_PIN || self->m_lockFlags & TOTP))) && !broken) + { + uint32 pendingAccountId = fields[0].GetUInt32(); + + // If the account is banned, reject the logon attempt + std::unique_ptr sqlAccountBanResult = LoginDatabase.PQuery("SELECT `bandate`, `unbandate` FROM `account_banned` WHERE `id` = %u AND `active` = 1 AND (`unbandate` > UNIX_TIMESTAMP() OR `unbandate` = `bandate`) LIMIT 1", pendingAccountId); + if (sqlAccountBanResult) + { + uint64_t banTimestamp = (*sqlAccountBanResult)[0].GetUInt64(); + uint64_t unbanTimestamp = (*sqlAccountBanResult)[1].GetUInt64(); + if (banTimestamp == unbanTimestamp) + { + *pkt << (uint8) WOW_FAIL_BANNED; + sLog.Out(LOG_BASIC, LOG_LVL_BASIC, "[AuthChallenge] Banned account '%s' using IP '%s' tries to login!", self->m_login.c_str(), self->GetRemoteIpString().c_str()); + } + else + { + *pkt << (uint8) WOW_FAIL_SUSPENDED; + sLog.Out(LOG_BASIC, LOG_LVL_BASIC, "[AuthChallenge] Temporarily banned account '%s' using IP '%s' tries to login!", self->m_login.c_str(), self->GetRemoteIpString().c_str()); + } + } + else + { + sLog.Out(LOG_BASIC, LOG_LVL_DEBUG, "database authentication values: v='%s' s='%s'", databaseV.c_str(), databaseS.c_str()); + + BigNumber s; + s.SetHexStr(databaseS.c_str()); + + self->srp.CalculateHostPublicEphemeral(); + + // Fill the response packet with the result + *pkt << uint8(WOW_SUCCESS); + + // B may be calculated < 32B so we force minimal length to 32B + pkt->append(self->srp.GetHostPublicEphemeral().AsByteArray(32)); // 32 bytes + *pkt << uint8(1); + pkt->append(self->srp.GetGeneratorModulo().AsByteArray()); + *pkt << uint8(32); + pkt->append(self->srp.GetPrime().AsByteArray(32)); + pkt->append(s.AsByteArray());// 32 bytes + pkt->append(VersionChallenge.data(), VersionChallenge.size()); + + // figure out whether we need to display the PIN grid + self->m_promptPin = locked; // always prompt if the account is IP locked & 2FA is enabled + + if ((!locked && ((self->m_lockFlags & ALWAYS_ENFORCE) == ALWAYS_ENFORCE)) || self->m_geoUnlockPIN) + { + self->m_promptPin = true; // prompt if the lock hasn't been triggered but ALWAYS_ENFORCE is set + } + + if (self->m_promptPin) + { + sLog.Out(LOG_BASIC, LOG_LVL_BASIC, "[AuthChallenge] Account '%s' using IP '%s' requires PIN authentication", self->m_login.c_str(), self->GetRemoteIpString().c_str()); + + uint32 gridSeedPkt = self->m_gridSeed = static_cast(rand32()); + EndianConvert(gridSeedPkt); + self->m_serverSecuritySalt.SetRand(16 * 8); // 16 bytes random + + *pkt << uint8(1); // securityFlags, only '1' is available in classic (PIN input) + *pkt << gridSeedPkt; + pkt->append(self->m_serverSecuritySalt.AsByteArray(16).data(), 16); + } + else + { + if (self->m_build >= 5428) // version 1.11.0 or later + *pkt << uint8(0); + } + + self->LoadAccountSecurityLevels(pendingAccountId); + self->m_accountId = pendingAccountId; + + // All good, await client's proof + self->m_status = STATUS_LOGON_PROOF; + } + } + } + else + { // no account + *pkt << (uint8) WOW_FAIL_UNKNOWN_ACCOUNT; } } - } - else // no account - { - pkt<< (uint8) WOW_FAIL_UNKNOWN_ACCOUNT; - } - } - send((char const*)pkt.contents(), pkt.size()); - return true; + + self->m_socket.Write(std::move(pkt), [self](IO::NetworkError const& error) + { + if (error) + sLog.Out(LOG_BASIC, LOG_LVL_ERROR, "_HandleLogonChallenge self->Write() Error: %s", error.ToString().c_str()); + else + self->DoRecvIncomingData(); + }); + }); + }); } // Logon Proof command handler -bool AuthSocket::_HandleLogonProof() +void AuthSocket::_HandleLogonProof() { sLog.Out(LOG_BASIC, LOG_LVL_DEBUG, "Entering _HandleLogonProof"); + m_status = STATUS_INVALID; - sAuthLogonProof_C_1_11 lp; - // Read the packet - if (m_build < 5428) // before version 1.11.0 (exclusive) - { - if (!recv((char *)&lp, sizeof(sAuthLogonProof_C_Base))) - return false; - lp.securityFlags = 0; - } - else - { - if (!recv((char *)&lp, sizeof(sAuthLogonProof_C_1_11))) - return false; + std::shared_ptr lp = std::make_shared(); + size_t expectedSize = sizeof(sAuthLogonProof_C); + if (m_build < 5428) { // Pin support was added in 1.11.0, so if an older client connects, we need to skip those fields + lp->securityFlags = SECURITY_FLAG_NONE; + expectedSize = sizeof(sAuthLogonProof_C_Pre_1_11_0); } - PINData pinData; - - if (lp.securityFlags) + m_socket.Read((char*) lp.get(), expectedSize, [self = shared_from_this(), lp](IO::NetworkError const& error, size_t) { - if (!recv((char*)&pinData, sizeof(pinData))) - return false; - } - - // Check if the client has one of the expected version numbers - bool valid_version = FindBuildInfo(m_build) != nullptr; - - // Session is closed unless overriden - m_status = STATUS_CLOSED; - - //
  • If the client has no valid version - if(!valid_version) - { - if (this->m_patch != ACE_INVALID_HANDLE) - return false; - - // Check if we have the apropriate patch on the disk - // file looks like: 65535enGB.mpq - char tmp[256]; - - snprintf(tmp, 256, "%s/%d%s.mpq", sConfig.GetStringDefault("PatchesDir","./patches").c_str(), m_build, m_localizationName.c_str()); - - char filename[PATH_MAX]; - if (ACE_OS::realpath(tmp, filename) != nullptr) + if (error) { - m_patch = ACE_OS::open(filename, GENERIC_READ | FILE_FLAG_SEQUENTIAL_SCAN); + sLog.Out(LOG_BASIC, LOG_LVL_ERROR, "_HandleLogonChallenge Read(): ERROR"); + self->CloseSocket(); // TODO: Remove me. Closing the socket will be done implicitly if all references to this socket are deleted (when there is no IO anymore) + return; } - if (m_patch == ACE_INVALID_HANDLE) + if (lp->securityFlags) { - // no patch found - ByteBuffer pkt; - pkt << (uint8) CMD_AUTH_LOGON_CHALLENGE; - pkt << (uint8) 0x00; - pkt << (uint8) WOW_FAIL_VERSION_INVALID; - sLog.Out(LOG_BASIC, LOG_LVL_DEBUG, "[AuthChallenge] %u is not a valid client version!", m_build); - sLog.Out(LOG_BASIC, LOG_LVL_DEBUG, "[AuthChallenge] Patch %s not found", tmp); - send((char const*)pkt.contents(), pkt.size()); - return true; + if (!(lp->securityFlags & SECURITY_FLAG_PIN)) + { + sLog.Out(LOG_BASIC, LOG_LVL_BASIC, "_HandleLogonChallenge Invalid/Unsupported securityFlags: %u", lp->securityFlags); + self->CloseSocket(); // TODO: Remove me. Closing the socket will be done implicitly if all references to this socket are deleted (when there is no IO anymore) + return; + } + + std::shared_ptr pinData(new PINData()); + self->m_socket.Read((char*) pinData.get(), sizeof(PINData), [self, lp, pinData](IO::NetworkError const& error, size_t) + { + self->_HandleLogonProof__PostRecv(lp, pinData); + }); + return; } - XFER_INIT xferh; + self->_HandleLogonProof__PostRecv(lp, nullptr); + }); +} + +void AuthSocket::_HandleLogonProof__PostRecv_HandleInvalidVersion(std::shared_ptr const& lp) +{ + if (m_pendingPatchFile) + { + sLog.Out(LOG_BASIC, LOG_LVL_BASIC, "_HandleLogonProof__PostRecv m_patch is already set?? The client should accept the XFER!"); + return; + } + + // Check if we have the apropriate patch on the disk + // file looks like: 65535enGB.mpq + char tmp[256]; - ACE_OFF_T file_size = ACE_OS::filesize(this->m_patch); + snprintf(tmp, 256, "%s/%d%s.mpq", sConfig.GetStringDefault("PatchesDir","./patches").c_str(), m_build, m_localizationName.c_str()); - if (file_size == -1) - { - close_connection(); - return false; - } + std::string pathFilePath = IO::Filesystem::ToAbsolutePath(tmp); + m_pendingPatchFile = IO::Filesystem::TryOpenFileReadonly(pathFilePath); - if (!PatchCache::instance()->GetHash(tmp, (uint8*)&xferh.md5)) + if (m_pendingPatchFile == nullptr) + { + // no patch found + std::shared_ptr pkt(new ByteBuffer()); + *pkt << (uint8) CMD_AUTH_LOGON_CHALLENGE; + *pkt << (uint8) 0x00; + *pkt << (uint8) WOW_FAIL_VERSION_INVALID; + sLog.Out(LOG_BASIC, LOG_LVL_DEBUG, "[AuthChallenge] %u is not a valid client version!", m_build); + sLog.Out(LOG_BASIC, LOG_LVL_DEBUG, "[AuthChallenge] Patch %s not found", tmp); + m_socket.Write(std::move(pkt), [self = shared_from_this(), pkt](IO::NetworkError const& error) { - // calculate patch md5, happens if patch was added while realmd was running - PatchCache::instance()->LoadPatchMD5(tmp); - PatchCache::instance()->GetHash(tmp, (uint8*)&xferh.md5); - } + if (error) + { + sLog.Out(LOG_BASIC, LOG_LVL_BASIC, "_HandleLogonProof__PostRecv Write(...) failed"); + self->CloseSocket(); // TODO: Remove me. Closing the socket will be done implicitly if all references to this socket are deleted (when there is no IO anymore) + return; + } + self->DoRecvIncomingData(); + }); + } + else + { + Crypto::Hash::MD5::Digest md5Hash = sRealmdPatchCache.GetOrCalculateHash(m_pendingPatchFile); + std::string wowClientPathType = "Patch"; // Must be patch "Patch" + MANGOS_ASSERT(wowClientPathType.size() <= 255); // Filename must fit inside a byte - uint8 data[2] = { CMD_AUTH_LOGON_PROOF, WOW_FAIL_VERSION_UPDATE}; - send((const char*)data, sizeof(data)); + std::shared_ptr pkt(new ByteBuffer()); - memcpy(&xferh, "0\x05Patch", 7); - xferh.cmd = CMD_XFER_INITIATE; - xferh.file_size = file_size; + // packet 1 + *pkt << (uint8) CMD_AUTH_LOGON_PROOF; + *pkt << (uint8) WOW_FAIL_VERSION_UPDATE; - send((const char*)&xferh, sizeof(xferh)); + // packet 2 - XFER_INIT + XFER_INIT initPkt{}; + initPkt.cmd = CMD_XFER_INITIATE; + initPkt.fileTypeNameLength = wowClientPathType.size(); + memcpy(initPkt.fileTypeName, wowClientPathType.c_str(), wowClientPathType.size()); + initPkt.fileSize = m_pendingPatchFile->GetTotalFileSize(); + memcpy(initPkt.md5, md5Hash.data(), md5Hash.size()); + pkt->append(&initPkt, 1); // Set right status m_status = STATUS_PATCH; - return true; + m_socket.Write(std::move(pkt), [self = shared_from_this()](IO::NetworkError const& error) + { + self->DoRecvIncomingData(); + }); + } +} + +void AuthSocket::_HandleLogonProof__PostRecv(std::shared_ptr const& lp, std::shared_ptr const& pinData) +{ + MANGOS_ASSERT(!lp->securityFlags || pinData.get() != nullptr); // PinData must be present, when securityFlags is set + + // Check if the client has one of the expected version numbers + bool valid_version = FindBuildInfo(m_build) != nullptr; + + // If the client has no valid version + if(!valid_version) + { + _HandleLogonProof__PostRecv_HandleInvalidVersion(lp); + return; } - //
// Continue the SRP6 calculation based on data received from the client - if (!srp.CalculateSessionKey(lp.A, 32)) - return false; + if (!srp.CalculateSessionKey(lp->A, 32)) + { + sLog.Out(LOG_BASIC, LOG_LVL_BASIC, "[AuthChallenge] Session calculation failed for account %s!", this->m_login.c_str()); + return; + } srp.HashSessionKey(); - srp.CalculateProof(m_login); + srp.CalculateProof(this->m_login); // Check PIN data is correct bool pinResult = true; - if (m_promptPin && !lp.securityFlags) + if (m_promptPin && !lp->securityFlags) pinResult = false; // expected PIN data but did not receive it - if (m_promptPin && lp.securityFlags) + if (m_promptPin && lp->securityFlags) { if ((m_lockFlags & FIXED_PIN) == FIXED_PIN) { - pinResult = VerifyPinData(std::stoi(m_securityInfo), pinData); - sLog.Out(LOG_BASIC, LOG_LVL_BASIC, "[AuthChallenge] Account '%s' using IP '%s' PIN result: %u", m_login.c_str(), get_remote_address().c_str(), pinResult); + pinResult = VerifyPinData(std::stoi(m_securityInfo), *pinData); + sLog.Out(LOG_BASIC, LOG_LVL_BASIC, "[AuthChallenge] Account '%s' using IP '%s' PIN result: %u", m_login.c_str(), GetRemoteIpString().c_str(), pinResult); } else if ((m_lockFlags & TOTP) == TOTP) { @@ -668,38 +637,42 @@ bool AuthSocket::_HandleLogonProof() if (pin == uint32(-1)) break; - if ((pinResult = VerifyPinData(pin, pinData))) + if ((pinResult = VerifyPinData(pin, *pinData))) break; } } else if (m_geoUnlockPIN) { - pinResult = VerifyPinData(m_geoUnlockPIN, pinData); + pinResult = VerifyPinData(m_geoUnlockPIN, *pinData); } else { pinResult = false; - sLog.Out(LOG_BASIC, LOG_LVL_ERROR, "[ERROR] Invalid PIN flags set for user %s - user cannot log-in until fixed", m_login.c_str()); + sLog.Out(LOG_BASIC, LOG_LVL_ERROR, "Invalid PIN flags set for user %s - user cannot log-in until fixed", m_login.c_str()); } } // Check if SRP6 results match (password is correct), else send an error - if (!srp.Proof(lp.M1, 20) && pinResult) + if (!srp.Proof(lp->M1, 20) && pinResult) { - if (!VerifyVersion(lp.A, sizeof(lp.A), lp.crc_hash, false)) + if (!VerifyVersion(lp->A, sizeof(lp->A), lp->crc_hash, false)) { sLog.Out(LOG_BASIC, LOG_LVL_BASIC, "[AuthChallenge] Account %s tried to login with modified client!", m_login.c_str()); - char data[2] = { CMD_AUTH_LOGON_PROOF, WOW_FAIL_VERSION_INVALID }; - send(data, sizeof(data)); - return true; + + std::shared_ptr pkt(new ByteBuffer()); + *pkt << (uint8) CMD_AUTH_LOGON_PROOF; + *pkt << (uint8) WOW_FAIL_VERSION_INVALID; + m_socket.Write(std::move(pkt), [self = shared_from_this()](IO::NetworkError const& error) + { + self->DoRecvIncomingData(); + }); + return; } // Geolocking checks must be done after an otherwise successful login to prevent lockout attacks if (m_geoUnlockPIN) // remove the PIN to unlock the account since login succeeded { - auto result = LoginDatabase.PExecute("UPDATE `account` SET `geolock_pin` = 0 WHERE `username` = '%s'", - m_safelogin.c_str()); - + bool result = LoginDatabase.PExecute("UPDATE `account` SET `geolock_pin` = 0 WHERE `username` = '%s'", m_safelogin.c_str()); if (!result) { sLog.Out(LOG_BASIC, LOG_LVL_ERROR, "Unable to remove geolock PIN for %s - account has not been unlocked", m_safelogin.c_str()); @@ -707,19 +680,22 @@ bool AuthSocket::_HandleLogonProof() } else if (GeographicalLockCheck()) { - sLog.Out(LOG_BASIC, LOG_LVL_BASIC, "Account '%s' (%u) using IP '%s' has been geolocked", m_login.c_str(), m_accountId, get_remote_address().c_str()); // todo, add additional logging info - - auto pin = urand(100000, 999999); // check rand32_max - auto result = LoginDatabase.PExecute("UPDATE `account` SET `geolock_pin` = %u WHERE `username` = '%s'", - pin, m_safelogin.c_str()); + sLog.Out(LOG_BASIC, LOG_LVL_BASIC, "Account '%s' (%u) using IP '%s' has been geolocked", m_login.c_str(), m_accountId, GetRemoteIpString().c_str()); // todo, add additional logging info + uint32_t pin = urand(100000, 999999); // check rand32_max + bool result = LoginDatabase.PExecute("UPDATE `account` SET `geolock_pin` = %u WHERE `username` = '%s'", pin, m_safelogin.c_str()); if (!result) { sLog.Out(LOG_BASIC, LOG_LVL_ERROR, "Unable to write geolock PIN for %s - account has not been locked", m_safelogin.c_str()); - char data[2] = { CMD_AUTH_LOGON_PROOF, WOW_FAIL_DB_BUSY }; - send(data, sizeof(data)); - return true; + std::shared_ptr pkt(new ByteBuffer()); + *pkt << (uint8) CMD_AUTH_LOGON_PROOF; + *pkt << (uint8) WOW_FAIL_DB_BUSY; + m_socket.Write(std::move(pkt), [self = shared_from_this()](IO::NetworkError const& error) + { + self->DoRecvIncomingData(); + }); + return; } #ifdef ENABLE_MAILSENDER @@ -735,7 +711,7 @@ bool AuthSocket::_HandleLogonProof() mail->from(sConfig.GetStringDefault("MailFrom", "")); mail->substitution("%username%", m_login); mail->substitution("%unlock_pin%", std::to_string(pin)); - mail->substitution("%originating_ip%", get_remote_address()); + mail->substitution("%originating_ip%", GetRemoteIpString()); MailerService::get_global_mailer()->send(std::move(mail), [](SendgridMail::Result res) @@ -746,41 +722,45 @@ bool AuthSocket::_HandleLogonProof() } #endif - char data[2] = { CMD_AUTH_LOGON_PROOF, WOW_FAIL_PARENTCONTROL }; - send(data, sizeof(data)); - return true; + std::shared_ptr pkt(new ByteBuffer()); + *pkt << (uint8) CMD_AUTH_LOGON_PROOF; + *pkt << (uint8) WOW_FAIL_PARENTCONTROL; + m_socket.Write(std::move(pkt), [self = shared_from_this()](IO::NetworkError const& error) + { + self->DoRecvIncomingData(); + }); + return; } - sLog.Out(LOG_BASIC, LOG_LVL_BASIC, "[AuthChallenge] Account '%s' using IP '%s' successfully authenticated", m_login.c_str(), get_remote_address().c_str()); + sLog.Out(LOG_BASIC, LOG_LVL_BASIC, "[AuthChallenge] Account '%s' using IP '%s' successfully authenticated", m_login.c_str(), GetRemoteIpString().c_str()); // Update the sessionkey, last_ip, last login time and reset number of failed logins in the account table for this account - // No SQL injection (escaped user name) and IP address as received by socket + // No SQL injection (escaped username) and IP address as received by socket std::string K_hex = srp.GetStrongSessionKey().AsHexStr(); - const char *os = reinterpret_cast(&m_os); // no injection as there are only two possible values - const char *platform = reinterpret_cast(&m_platform); // no injection as there are only two possible values - std::unique_ptr result = LoginDatabase.PQuery("UPDATE `account` SET `sessionkey` = '%s', `last_ip` = '%s', `last_login` = NOW(), `locale` = '%u', `failed_logins` = 0, `os` = '%s', `platform` = '%s' WHERE `username` = '%s'", - K_hex.c_str(), get_remote_address().c_str(), GetLocaleByName(m_localizationName), os, platform, m_safelogin.c_str() ); + // Why it must be sync: The new network implementation is so fast that the async db cant execute the UPDATE statement before the client tries to reach mangosd + // If it is async there would be a race condition + bool result = LoginDatabase.PExecute(DbExecMode::MustBeSync, "UPDATE `account` SET `sessionkey` = '%s', `last_ip` = '%s', `last_login` = NOW(), `locale` = '%u', `failed_logins` = 0, `os` = '%s', `platform` = '%s' WHERE `username` = '%s'", + K_hex.c_str(), GetRemoteIpString().c_str(), GetLocaleByName(m_localizationName), m_os.c_str(), m_platform.c_str(), m_safelogin.c_str() ); + if (!result) + { + sLog.Out(LOG_BASIC, LOG_LVL_ERROR, "Unable to update login stats for account '%s'", m_safelogin.c_str()); + } // Finish SRP6 and send the final result to the client - Crypto::Hash::SHA1::Digest sha = srp.Finalize(); + Crypto::Hash::SHA1::Digest shaDigest = srp.Finalize(); - SendProof(sha); + std::shared_ptr pkt = GenerateLogonProofResponse(shaDigest); m_status = STATUS_AUTHED; + + m_socket.Write(std::move(pkt), [self = shared_from_this()](IO::NetworkError const& error) + { + self->DoRecvIncomingData(); + }); } else { - if (m_build > 6005) // > 1.12.2 - { - char data[4] = { CMD_AUTH_LOGON_PROOF, WOW_FAIL_UNKNOWN_ACCOUNT, 0, 0}; - send(data, sizeof(data)); - } - else - { - // 1.x not react incorrectly at 4-byte message use 3 as real error - char data[2] = { CMD_AUTH_LOGON_PROOF, WOW_FAIL_UNKNOWN_ACCOUNT}; - send(data, sizeof(data)); - } - sLog.Out(LOG_BASIC, LOG_LVL_BASIC, "[AuthChallenge] Account '%s' using IP '%s' tried to login with wrong password!", m_login.c_str (), get_remote_address().c_str()); + // We are here because the password was incorrect + sLog.Out(LOG_BASIC, LOG_LVL_BASIC, "[AuthChallenge] Account '%s' using IP '%s' tried to login with wrong password!", m_login.c_str (), GetRemoteIpString().c_str()); uint32 MaxWrongPassCount = sConfig.GetIntDefault("WrongPass.MaxCount", 0); if(MaxWrongPassCount > 0) @@ -788,7 +768,7 @@ bool AuthSocket::_HandleLogonProof() //Increment number of failed logins by one and if it reaches the limit temporarily ban that account or IP LoginDatabase.PExecute("UPDATE `account` SET `failed_logins` = `failed_logins` + 1 WHERE `username` = '%s'",m_safelogin.c_str()); - if(std::unique_ptr failedLoginsDbResult = LoginDatabase.PQuery("SELECT `id`, `failed_logins` FROM `account` WHERE `username` = '%s'", m_safelogin.c_str())) + if (std::unique_ptr failedLoginsDbResult = LoginDatabase.PQuery("SELECT `id`, `failed_logins` FROM `account` WHERE `username` = '%s'", m_safelogin.c_str())) { Field* fields = failedLoginsDbResult->Fetch(); uint32 failed_logins = fields[1].GetUInt32(); @@ -805,11 +785,11 @@ bool AuthSocket::_HandleLogonProof() "VALUES ('%u',UNIX_TIMESTAMP(),UNIX_TIMESTAMP()+'%u','MaNGOS realmd','Failed login autoban',1,1)", acc_id, WrongPassBanTime); sLog.Out(LOG_BASIC, LOG_LVL_BASIC, "[AuthChallenge] Account '%s' using IP '%s' got banned for '%u' seconds because it failed to authenticate '%u' times", - m_login.c_str(), get_remote_address().c_str(), WrongPassBanTime, failed_logins); + m_login.c_str(), GetRemoteIpString().c_str(), WrongPassBanTime, failed_logins); } else { - std::string current_ip = get_remote_address(); + std::string current_ip = GetRemoteIpString(); LoginDatabase.escape_string(current_ip); LoginDatabase.PExecute("INSERT INTO `ip_banned` VALUES ('%s',UNIX_TIMESTAMP(),UNIX_TIMESTAMP()+'%u','MaNGOS realmd','Failed login autoban')", current_ip.c_str(), WrongPassBanTime); @@ -819,200 +799,245 @@ bool AuthSocket::_HandleLogonProof() } } } + + std::shared_ptr pkt(new ByteBuffer()); + *pkt << (uint8) CMD_AUTH_LOGON_PROOF; + *pkt << (uint8) WOW_FAIL_UNKNOWN_ACCOUNT; + if (m_build > 6005) // > 1.12.2 + { + *pkt << (uint8) 0; + *pkt << (uint8) 0; + } + m_socket.Write(std::move(pkt), [self = shared_from_this()](IO::NetworkError const& error) + { + self->DoRecvIncomingData(); + }); } - return true; } // Reconnect Challenge command handler -bool AuthSocket::_HandleReconnectChallenge() +void AuthSocket::_HandleReconnectChallenge() { sLog.Out(LOG_BASIC, LOG_LVL_DEBUG, "Entering _HandleReconnectChallenge"); - if (recv_len() < sizeof(sAuthLogonChallenge_C)) - return false; - - // Read the first 4 bytes (header) to get the length of the remaining of the packet - std::vector buf; - buf.resize(4); - - recv((char *)&buf[0], 4); - - EndianConvert(*((uint16*)(&buf[0]))); - uint16 remaining = ((sAuthLogonChallenge_C *)&buf[0])->size; - sLog.Out(LOG_BASIC, LOG_LVL_DEBUG, "[ReconnectChallenge] got header, body is %#04x bytes", remaining); - - if ((remaining < sizeof(sAuthLogonChallenge_C) - buf.size()) || (recv_len() < remaining)) - return false; - - // Session is closed unless overriden - m_status = STATUS_CLOSED; - - //No big fear of memory outage (size is int16, i.e. < 65536) - buf.resize(remaining + buf.size() + 1); - buf[buf.size() - 1] = 0; - sAuthLogonChallenge_C *ch = (sAuthLogonChallenge_C*)&buf[0]; + m_status = STATUS_INVALID; - // Read the remaining of the packet - recv((char *)&buf[4], remaining); - sLog.Out(LOG_BASIC, LOG_LVL_DEBUG, "[ReconnectChallenge] got full packet, %#04x bytes", ch->size); - sLog.Out(LOG_BASIC, LOG_LVL_DEBUG, "[ReconnectChallenge] name(%d): '%s'", ch->I_len, ch->I); - - EndianConvert(ch->build); - m_build = ch->build; + // Read the header first, to get the length of the remaining packet + std::shared_ptr header = std::make_shared(); + m_socket.Read((char*)header.get(), sizeof(sAuthLogonChallengeHeader), [self = shared_from_this(), header](IO::NetworkError const& error, size_t) + { + if (error) + { + sLog.Out(LOG_BASIC, LOG_LVL_ERROR, "_HandleReconnectChallenge Read(header): ERROR"); + self->CloseSocket(); // TODO: Remove me. Closing the socket will be done implicitly if all references to this socket are deleted (when there is no IO anymore) + return; + } - ch->os[3] = '\0'; - std::reverse(ch->os, ch->os + 3); - memcpy(&m_os, ch->os, sizeof(m_os)); + uint16* pUint16 = reinterpret_cast(header.get()); + EndianConvert(*pUint16); + uint16 actualBodySize = header->size; - ch->platform[3] = '\0'; - std::reverse(ch->platform, ch->platform + 3); - memcpy(&m_platform, ch->platform, sizeof(m_platform)); + if (actualBodySize < sizeof(sAuthLogonChallengeBody) - AUTH_LOGON_MAX_NAME) // TODO: @cMangos: Why is here "-10" and not AUTH_LOGON_MAX_NAME + { // The paket is too small and has no username??? + return; + } - m_login = (const char*)ch->I; - m_safelogin = m_login; - LoginDatabase.escape_string(m_safelogin); + sLog.Out(LOG_BASIC, LOG_LVL_DEBUG, "[ReconnectChallenge] got header, body is %#04x bytes", actualBodySize); - std::unique_ptr result = LoginDatabase.PQuery("SELECT `sessionkey`, `id` FROM `account` WHERE `username` = '%s'", m_safelogin.c_str()); + // Read the remaining of the packet + std::shared_ptr body = std::make_shared(); + self->m_socket.Read((char*)body.get(), actualBodySize, [self, header, body](IO::NetworkError const& error, size_t) + { + if (error) + { + sLog.Out(LOG_BASIC, LOG_LVL_ERROR, "_HandleReconnectChallenge self->m_socket.Read(body): ERROR"); + self->CloseSocket(); // TODO: Remove me. Closing the socket will be done implicitly if all references to this socket are deleted (when there is no IO anymore) + return; + } - // Stop if the account is not found - if (!result) - { - sLog.Out(LOG_BASIC, LOG_LVL_ERROR, "[ERROR] user %s tried to login and we cannot find his session key in the database.", m_login.c_str()); - close_connection(); - return false; - } + if (body->username_len > AUTH_LOGON_MAX_NAME) + return; + body->username[body->username_len] = '\0'; + + sLog.Out(LOG_BASIC, LOG_LVL_DEBUG, "[ReconnectChallenge] got full packet, %#04x bytes", header->size); + sLog.Out(LOG_BASIC, LOG_LVL_DEBUG, "[ReconnectChallenge] name(%d): '%s'", body->username_len, body->username); + + // BigEndian code, nop in little endian case + // size already converted + EndianConvert(*((uint32*)(&body->gamename[0]))); + EndianConvert(body->build); + EndianConvert(*((uint32*)(&body->platform[0]))); + EndianConvert(*((uint32*)(&body->os[0]))); + EndianConvert(*((uint32*)(&body->country[0]))); + EndianConvert(body->timezone_bias); + EndianConvert(body->ip); + + self->m_build = body->build; + + // Convert uint8[4] to string, restore string order as its byte order is reversed + // To it for os + body->os[3] = '\0'; + self->m_os = (char*)body->os; + std::reverse(self->m_os.begin(), self->m_os.end()); + // To it for platform + body->platform[3] = '\0'; + self->m_platform = (char*)body->platform; + std::reverse(self->m_platform.begin(), self->m_platform.end()); + // Do it for locale + self->m_localizationName.resize(sizeof(body->country)); + self->m_localizationName.assign(body->country, (body->country + sizeof(body->country))); + std::reverse(self->m_localizationName.begin(), self->m_localizationName.end()); + + // Escape the user input used in DB to avoid further SQL injection + // Memory will be freed on AuthSocket object destruction + self->m_login = (char const*)body->username; + self->m_safelogin = self->m_login; + LoginDatabase.escape_string(self->m_safelogin); + + std::unique_ptr queryResult = LoginDatabase.PQuery("SELECT `sessionkey`, `id` FROM `account` WHERE `username` = '%s'", self->m_safelogin.c_str()); + + // Stop if the account is not found + if (!queryResult) + { + sLog.Out(LOG_BASIC, LOG_LVL_ERROR, "user %s tried to login and we cannot find his session key in the database.", self->m_login.c_str()); + self->CloseSocket(); + return; + } - Field* fields = result->Fetch (); - srp.SetStrongSessionKey(fields[0].GetString()); - m_accountId = fields[1].GetUInt32(); - - // All good, await client's proof - m_status = STATUS_RECON_PROOF; - - // Sending response - ByteBuffer pkt; - pkt << (uint8) CMD_AUTH_RECONNECT_CHALLENGE; - pkt << (uint8) 0x00; - m_reconnectProof.SetRand(16 * 8); - pkt.append(m_reconnectProof.AsByteArray(16)); // 16 bytes random - pkt.append(VersionChallenge.data(), VersionChallenge.size()); - send((char const*)pkt.contents(), pkt.size()); - return true; + Field* fields = queryResult->Fetch(); + self->srp.SetStrongSessionKey(fields[0].GetString()); + self->m_accountId = fields[1].GetUInt32(); + + // All good, await client's proof + self->m_status = STATUS_RECON_PROOF; + + // Sending response + std::shared_ptr pkt = std::make_shared(); + *pkt << (uint8)CMD_AUTH_RECONNECT_CHALLENGE; + *pkt << (uint8)0x00; + self->m_reconnectProof.SetRand(16 * 8); + pkt->append(self->m_reconnectProof.AsByteArray(16)); // 16 bytes random + pkt->append(VersionChallenge.data(), VersionChallenge.size()); + self->m_socket.Write(std::move(pkt), [self](IO::NetworkError const& error) + { + self->DoRecvIncomingData(); + }); + }); + }); } // Reconnect Proof command handler -bool AuthSocket::_HandleReconnectProof() +void AuthSocket::_HandleReconnectProof() { sLog.Out(LOG_BASIC, LOG_LVL_DEBUG, "Entering _HandleReconnectProof"); - // Read the packet - sAuthReconnectProof_C lp; - if(!recv((char *)&lp, sizeof(sAuthReconnectProof_C))) - return false; + m_status = STATUS_INVALID; - // Session is closed unless overriden - m_status = STATUS_CLOSED; + // Read the packet + std::shared_ptr lp(new AUTH_RECONNECT_PROOF_C()); + m_socket.Read((char*) lp.get(), sizeof(AUTH_RECONNECT_PROOF_C), [self = shared_from_this(), lp](IO::NetworkError const& error, size_t) + { + if (error) + { + sLog.Out(LOG_BASIC, LOG_LVL_ERROR, "_HandleReconnectProof self->m_socket.Read(): ERROR"); + self->CloseSocket(); // TODO: Remove me. Closing the socket will be done implicitly if all references to this socket are deleted (when there is no IO anymore) + return; + } - BigNumber K = srp.GetStrongSessionKey(); - if (m_login.empty() || !m_reconnectProof.GetNumBytes() || !K.GetNumBytes()) - return false; + BigNumber K = self->srp.GetStrongSessionKey(); + if (self->m_login.empty() || !self->m_reconnectProof.GetNumBytes() || !K.GetNumBytes()) + return; - BigNumber t1; - t1.SetBinary(lp.R1, 16); + BigNumber t1; + t1.SetBinary(lp->R1, 16); - Crypto::Hash::SHA1::Generator sha; - sha.UpdateData(m_login); - sha.UpdateData(t1); - sha.UpdateData(m_reconnectProof); - sha.UpdateData(K); - Crypto::Hash::SHA1::Digest digest = sha.GetDigest(); + Crypto::Hash::SHA1::Generator sha; + sha.UpdateData(self->m_login); + sha.UpdateData(t1); + sha.UpdateData(self->m_reconnectProof); + sha.UpdateData(K); + Crypto::Hash::SHA1::Digest digest = sha.GetDigest(); - if (!memcmp(digest.data(), lp.R2, digest.size())) - { - if (!VerifyVersion(lp.R1, sizeof(lp.R1), lp.R3, true)) + if (!memcmp(digest.data(), lp->R2, digest.size())) { - ByteBuffer pkt; - pkt << uint8(CMD_AUTH_RECONNECT_PROOF); - pkt << uint8(WOW_FAIL_VERSION_INVALID); - send((char const*)pkt.contents(), pkt.size()); - return true; - } + if (!self->VerifyVersion(lp->R1, sizeof(lp->R1), lp->R3, true)) + { + std::shared_ptr pkt = std::make_shared(); + *pkt << uint8(CMD_AUTH_RECONNECT_PROOF); + *pkt << uint8(WOW_FAIL_VERSION_INVALID); + return; + } - // Sending response - ByteBuffer pkt; - pkt << uint8(CMD_AUTH_RECONNECT_PROOF); - pkt << uint8(WOW_SUCCESS); - send((char const*)pkt.contents(), pkt.size()); + // Sending response + std::shared_ptr pkt = std::make_shared(); + *pkt << uint8(CMD_AUTH_RECONNECT_PROOF); + *pkt << uint8(WOW_SUCCESS); + self->m_socket.Write(std::move(pkt), [self](IO::NetworkError const& error) + { + self->DoRecvIncomingData(); + }); - m_status = STATUS_AUTHED; - return true; - } - else - { - sLog.Out(LOG_BASIC, LOG_LVL_ERROR, "[ERROR] user %s tried to login, but session invalid.", m_login.c_str()); - close_connection(); - return false; - } + self->m_status = STATUS_AUTHED; + return; + } + else + { + sLog.Out(LOG_BASIC, LOG_LVL_ERROR, "user %s tried to login, but session invalid.", self->m_login.c_str()); + self->CloseSocket(); // TODO: Remove me. Closing the socket will be done implicitly if all references to this socket are deleted (when there is no IO anymore) + return; + } + }); } // %Realm List command handler -bool AuthSocket::_HandleRealmList() +void AuthSocket::_HandleRealmList() { - sLog.Out(LOG_BASIC, LOG_LVL_DEBUG, "Entering _HandleRealmList"); - if (recv_len() < 5) - return false; - - recv_skip(5); + assert(this->m_accountId); - // this shouldn't be possible, but just in case - if (!m_accountId) - return false; - - // check for too frequent requests - auto const minDelay = sConfig.GetIntDefault("MinRealmListDelay", 1); - auto const now = time(nullptr); - auto const delay = now - m_lastRealmListRequest; - - m_lastRealmListRequest = now; - - if (delay < minDelay) + sLog.Out(LOG_BASIC, LOG_LVL_DEBUG, "Entering _HandleRealmList"); + m_socket.ReadSkip(4, [self = shared_from_this()](IO::NetworkError const& error) { - sLog.Out(LOG_BASIC, LOG_LVL_ERROR, "[ERROR] user %s IP %s is sending CMD_REALM_LIST too frequently. Delay = %d seconds", m_login.c_str(), get_remote_address().c_str(), delay); - return false; - } + if (error) + { + self->CloseSocket(); // TODO: Remove me. Closing the socket will be done implicitly if all references to this socket are deleted (when there is no IO anymore) + return; + } - // Update realm list if need - sRealmList.UpdateIfNeed(); + // check for too frequent requests + auto const minDelay = sConfig.GetIntDefault("MinRealmListDelay", 1); + auto const now = std::chrono::steady_clock::now(); + if (self->m_lastRealmListRequest.has_value()) + { + auto const delay = std::chrono::duration_cast(now - self->m_lastRealmListRequest.value()).count(); + if (delay < minDelay) + { + sLog.Out(LOG_BASIC, LOG_LVL_ERROR, "user %s IP %s is sending CMD_REALM_LIST too frequently. Delay = %d seconds", self->m_login.c_str(), self->GetRemoteIpString().c_str(), delay); - // Circle through realms in the RealmList and construct the return packet (including # of user characters in each realm) - ByteBuffer pkt; - LoadRealmlist(pkt); + self->CloseSocket(); // TODO: Remove me. Closing the socket will be done implicitly if all references to this socket are deleted (when there is no IO anymore) + return; + } + } - ByteBuffer hdr; - hdr << (uint8) CMD_REALM_LIST; - hdr << (uint16)pkt.size(); - hdr.append(pkt); + self->m_lastRealmListRequest = now; - send((char const*)hdr.contents(), hdr.size()); + // Update realm list if need + sRealmList.UpdateIfNeed(); - return true; -} + // Circle through realms in the RealmList and construct the return packet (including # of user characters in each realm) + ByteBuffer realmlistBuffer; + self->LoadRealmlistAndWriteIntoBuffer(realmlistBuffer); -std::string AuthSocket::GetRealmAddress(Realm const& realm) const -{ - ACE_INET_Addr addr; - if (peer().get_remote_addr(addr) == 0) - { - ACE_INET_Addr localAddress; - if (localAddress.set(realm.localAddress.c_str()) == 0) - { - if ((addr.get_ip_address() & realm.localSubnetMask) == (localAddress.get_ip_address() & realm.localSubnetMask)) - return realm.localAddress; - } - } + std::shared_ptr pkt(new ByteBuffer()); + *pkt << (uint8) CMD_REALM_LIST; + *pkt << (uint16)realmlistBuffer.size(); + pkt->append(realmlistBuffer); - return realm.address; + self->m_socket.Write(std::move(pkt), [self](IO::NetworkError const& error) + { + self->DoRecvIncomingData(); + }); + }); } -void AuthSocket::LoadRealmlist(ByteBuffer &pkt) +void AuthSocket::LoadRealmlistAndWriteIntoBuffer(ByteBuffer &pkt) { if (m_build < 6299) // before version 2.0.3 (exclusive) { @@ -1054,12 +1079,13 @@ void AuthSocket::LoadRealmlist(ByteBuffer &pkt) if (!ok_build || (i->second.allowedSecurityLevel > GetSecurityOn(i->second.id))) realmflags = RealmFlags(realmflags | REALM_FLAG_OFFLINE); + std::string realmIpPortStr = i->second.GetAddressForClient(m_socket.GetRemoteEndpoint().ip).toString(); uint8 const categoryId = GetRealmCategoryIdByBuildAndZone(m_build, RealmZone(i->second.timeZone)); pkt << uint32(i->second.icon); // realm type pkt << uint8(realmflags); // realmflags pkt << name; // name - pkt << GetRealmAddress(i->second); // address + pkt << realmIpPortStr; // address pkt << float(i->second.populationLevel); pkt << uint8(AmountOfCharacters); pkt << uint8(categoryId); // realm category @@ -1104,13 +1130,14 @@ void AuthSocket::LoadRealmlist(ByteBuffer &pkt) if (!buildInfo) realmFlags = RealmFlags(realmFlags & ~REALM_FLAG_SPECIFYBUILD); + std::string realmIpPortStr = i->second.GetAddressForClient(m_socket.GetRemoteEndpoint().ip).toString(); uint8 const categoryId = GetRealmCategoryIdByBuildAndZone(m_build, RealmZone(i->second.timeZone)); pkt << uint8(i->second.icon); // realm type (this is second column in Cfg_Configs.dbc) pkt << uint8(lock); // flags, if 0x01, then realm locked pkt << uint8(realmFlags); // see enum RealmFlags pkt << i->first; // name - pkt << GetRealmAddress(i->second); // address + pkt << realmIpPortStr; // address pkt << float(i->second.populationLevel); pkt << uint8(AmountOfCharacters); pkt << uint8(categoryId); // realm category (Cfg_Categories.dbc) @@ -1129,69 +1156,59 @@ void AuthSocket::LoadRealmlist(ByteBuffer &pkt) } } -// Resume patch transfer -bool AuthSocket::_HandleXferResume() +// Accept patch transfer +void AuthSocket::_HandleXferAccept() { - sLog.Out(LOG_BASIC, LOG_LVL_DEBUG, "Entering _HandleXferResume"); - - if(recv_len() < 9) - return false; - - recv_skip(1); - - uint64 start_pos; - recv((char *)&start_pos, 8); + sLog.Out(LOG_BASIC, LOG_LVL_DEBUG, "Entering _HandleXferAccept"); - if(m_patch == ACE_INVALID_HANDLE) + if (!m_pendingPatchFile) { - close_connection(); - return false; + sLog.Out(LOG_BASIC, LOG_LVL_BASIC, "User '%s' tried to get patch file, but there is no patch file defined?", m_safelogin.c_str()); + return; } - ACE_OFF_T file_size = ACE_OS::filesize(m_patch); + InitAndHandOverControlToPatchHandler(); +} + +// Resume transfer. +// This function is called when the user disconnected during transfer and already has a `wow-patch.mpq.partial`. +// The client may not be closed, this only works if the client is not closed. +void AuthSocket::_HandleXferResume() +{ + sLog.Out(LOG_BASIC, LOG_LVL_DEBUG, "Entering _HandleXferResume"); - if(file_size == -1 || start_pos >= (uint64)file_size) + if (!m_pendingPatchFile) { - close_connection(); - return false; + sLog.Out(LOG_BASIC, LOG_LVL_BASIC, "User '%s' tried to get patch file, but there is no patch file defined?", m_safelogin.c_str()); + return; } - if(ACE_OS::lseek(m_patch, start_pos, SEEK_SET) == -1) + auto startPosPtr = std::make_shared(); + m_socket.Read(reinterpret_cast(startPosPtr.get()), sizeof(int64), [self = shared_from_this(), startPosPtr](IO::NetworkError const& error, std::size_t) { - close_connection(); - return false; - } + int64 startPos = *startPosPtr; + sLog.Out(LOG_BASIC, LOG_LVL_DEBUG, "[XFER] User '%s' wants to resume download at byte %llu", self->m_safelogin.c_str(), startPos); - InitPatch(); + if (startPos >= self->m_pendingPatchFile->GetTotalFileSize() || startPos < 0) + { + sLog.Out(LOG_BASIC, LOG_LVL_BASIC, "[XFER] User '%s' tried to resume download outside file bounds", self->m_safelogin.c_str()); + return; + } - return true; + self->m_pendingPatchFile->Seek(IO::Filesystem::SeekDirection::Start, startPos); + self->InitAndHandOverControlToPatchHandler(); + }); } // Cancel patch transfer -bool AuthSocket::_HandleXferCancel() +void AuthSocket::_HandleXferCancel() { sLog.Out(LOG_BASIC, LOG_LVL_DEBUG, "Entering _HandleXferCancel"); - - recv_skip(1); - close_connection(); - - return true; -} - -// Accept patch transfer -bool AuthSocket::_HandleXferAccept() -{ - sLog.Out(LOG_BASIC, LOG_LVL_DEBUG, "Entering _HandleXferAccept"); - - recv_skip(1); - - InitPatch(); - - return true; + // Socket will close implicitly } // Verify PIN entry data -bool AuthSocket::VerifyPinData(uint32 pin, const PINData& clientData) +bool AuthSocket::VerifyPinData(uint32 pin, PINData const& clientData) { // remap the grid to match the client's layout std::vector grid { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9 }; @@ -1263,9 +1280,10 @@ bool AuthSocket::VerifyPinData(uint32 pin, const PINData& clientData) return hash.AsDecStr() == clientHash.AsDecStr(); } -uint32 AuthSocket::GenerateTotpPin(const std::string& secret, int interval) { +uint32 AuthSocket::GenerateTotpPin(std::string const& secret, int interval) +{ std::vector decoded_key((secret.size() + 7) / 8 * 5); - int key_size = base32_decode((const uint8_t*)secret.data(), decoded_key.data(), decoded_key.size()); + int key_size = base32_decode((uint8_t const*)secret.data(), decoded_key.data(), decoded_key.size()); if (key_size == -1) { @@ -1294,25 +1312,48 @@ uint32 AuthSocket::GenerateTotpPin(const std::string& secret, int interval) { return pin; } -void AuthSocket::InitPatch() +/// Will Read() a chunk from m_pendingPatchFile into dataChunkHolder->data +/// This function will recursion call itself when the the sending callback is invoked +void AuthSocket::RepeatInternalXferLoop(std::shared_ptr const& chunk) { - PatchHandler* handler = new PatchHandler(ACE_OS::dup(get_handle()), m_patch); - - m_patch = ACE_INVALID_HANDLE; - - if(handler->open() == -1) + // Will the `chunk->data` array with actual data from the file + uint64_t actualReadAmount = m_pendingPatchFile->ReadSync(&(chunk->data[0]), sizeof(chunk->data)); + if (actualReadAmount == 0) { - handler->close(); - close_connection(); + sLog.Out(LOG_BASIC, LOG_LVL_DETAIL, "[XFER]: Done"); + return; } + chunk->data_size = (uint16_t) actualReadAmount; + + // This `fakeSharedPtr` is a bit hacky, we cannot simply Write() a XFER_DATA_CHUNK pointer. + // This is why we convert it to an uint8 pointer without a deallocator. + std::shared_ptr fakeSharedPtr((uint8_t const*)chunk.get(), MaNGOS::Memory::no_deleter()); + m_socket.Write({ fakeSharedPtr, size_t(sizeof(chunk->cmd) + sizeof(chunk->data_size) + actualReadAmount) }, [self = shared_from_this(), chunk](IO::NetworkError const& error) + { + if (error) + { + sLog.Out(LOG_BASIC, LOG_LVL_ERROR, "[XFER]: Write(...) failed: %s", error.ToString().c_str()); + return; + } + self->RepeatInternalXferLoop(chunk); // Do it again, until everything is transferred + }); +} + +void AuthSocket::InitAndHandOverControlToPatchHandler() +{ + MANGOS_ASSERT(m_pendingPatchFile); + + std::shared_ptr rawChunk(new XFER_DATA_CHUNK()); + rawChunk->cmd = CMD_XFER_DATA; + + RepeatInternalXferLoop(rawChunk); } void AuthSocket::LoadAccountSecurityLevels(uint32 accountId) { - std::unique_ptr result = LoginDatabase.PQuery("SELECT `gmlevel`, `RealmID` FROM `account_access` WHERE `id` = %u", - accountId); + std::unique_ptr result = LoginDatabase.PQuery("SELECT `gmlevel`, `RealmID` FROM `account_access` WHERE `id` = %u", accountId); if (!result) - return; + return; // The account has no special permissions (most likely a normal user) do { @@ -1333,7 +1374,7 @@ bool AuthSocket::GeographicalLockCheck() return false; } - if (m_lastIP.empty() || m_lastIP == get_remote_address()) + if (m_lastIP.empty() || m_lastIP == GetRemoteIpString()) { return false; } @@ -1348,7 +1389,7 @@ bool AuthSocket::GeographicalLockCheck() "FROM geoip " "WHERE network_last_integer >= INET_ATON('%s') " "ORDER BY network_last_integer ASC LIMIT 1", - get_remote_address().c_str(), get_remote_address().c_str()) + GetRemoteIpString().c_str(), GetRemoteIpString().c_str()) ); auto result_prev = std::unique_ptr(LoginDatabase.PQuery( @@ -1375,12 +1416,11 @@ bool AuthSocket::GeographicalLockCheck() uint32_t ip = result->Fetch()[0].GetUInt32(); uint32_t ip_prev = result_prev->Fetch()[0].GetUInt32(); - /* The optimised query will return the next highest range in the event - * of the address not being found in the database. Therefore, we need - * to perform a second check to ensure our address falls within - * the returned range. - * See: https://blog.jcole.us/2007/11/24/on-efficiently-geo-referencing-ips-with-maxmind-geoip-and-mysql-gis/ - */ + // The optimised query will return the next highest range in the event + // of the address not being found in the database. Therefore, we need + // to perform a second check to ensure our address falls within + // the returned range. + // See: https://blog.jcole.us/2007/11/24/on-efficiently-geo-referencing-ips-with-maxmind-geoip-and-mysql-gis/ if (net_start > ip || net_start_prev > ip_prev) { return false; @@ -1438,3 +1478,8 @@ bool AuthSocket::VerifyVersion(uint8 const* a, int32 aLength, uint8 const* versi return false; } + +void AuthSocket::CloseSocket() +{ + m_socket.CloseSocket(); +} diff --git a/src/realmd/AuthSocket.h b/src/realmd/AuthSocket.h index a8c5b8e6c97..12048506e3c 100644 --- a/src/realmd/AuthSocket.h +++ b/src/realmd/AuthSocket.h @@ -19,10 +19,6 @@ * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA */ -// \addtogroup realmd -// @{ -// \file - #ifndef _AUTHSOCKET_H #define _AUTHSOCKET_H @@ -31,14 +27,11 @@ #include "Crypto/Hash/SHA1.h" #include "Crypto/Authentication/SRP6.h" #include "ByteBuffer.h" - -#include "BufferedSocket.h" - -struct PINData -{ - uint8 salt[16]; - uint8 hash[20]; -}; +#include "IO/Networking/AsyncSocket.h" +#include "IO/Timer/TimerHandle.h" +#include "IO/Filesystem/FileHandle.h" +#include "Policies/ObjectConstructorTraits.h" +#include "AuthPackets.h" enum LockFlag { @@ -51,32 +44,43 @@ enum LockFlag GEO_CITY = 0x20 }; -// Handle login commands -class AuthSocket: public BufferedSocket +struct sAuthLogonProof_C; + +/// Handle login commands +class AuthSocket : public std::enable_shared_from_this, MaNGOS::Policies::NoCopyNoMove { public: - const static int s_BYTE_SIZE = 32; - - AuthSocket() = default; + explicit AuthSocket(IO::Networking::AsyncSocket socket); ~AuthSocket(); - void OnAccept(); - void OnRead(); - void SendProof(Crypto::Hash::SHA1::Digest sha); - void LoadRealmlist(ByteBuffer &pkt); - bool VerifyPinData(uint32 pin, const PINData& clientData); - uint32 GenerateTotpPin(const std::string& secret, int interval); - - bool _HandleLogonChallenge(); - bool _HandleLogonProof(); - bool _HandleReconnectChallenge(); - bool _HandleReconnectProof(); - bool _HandleRealmList(); + void Start(); + + void DoRecvIncomingData(); + std::shared_ptr GenerateLogonProofResponse(Crypto::Hash::SHA1::Digest const& shaDigest); + void LoadRealmlistAndWriteIntoBuffer(ByteBuffer& pkt); + bool VerifyPinData(uint32 pin, PINData const& clientData); + uint32 GenerateTotpPin(std::string const& secret, int interval); + + void _HandleLogonChallenge(); + void _HandleLogonProof(); + void _HandleLogonProof__PostRecv(std::shared_ptr const& lp, std::shared_ptr const& pinData); + void _HandleLogonProof__PostRecv_HandleInvalidVersion(std::shared_ptr const& lp); + void _HandleReconnectChallenge(); + void _HandleReconnectProof(); + void _HandleRealmList(); + //data transfer handle for patch + void _HandleXferAccept(); + void _HandleXferResume(); + void _HandleXferCancel(); - bool _HandleXferResume(); - bool _HandleXferCancel(); - bool _HandleXferAccept(); + /// Returns the IP of the peer e.g. "192.168.13.37" + inline std::string const& GetRemoteIpString() const { return m_remoteIpAddressStringAfterProxy; } + void CloseSocket(); + + public: // A bit hacky, that this is public. In WorldSocket we have WorldSocketMgr as friend, this is not possible here + IO::Networking::AsyncSocket m_socket; + std::string m_remoteIpAddressStringAfterProxy; // might differ from `m_socket.m_descriptor` if behind proxy private: enum eStatus @@ -84,13 +88,12 @@ class AuthSocket: public BufferedSocket STATUS_CHALLENGE, STATUS_LOGON_PROOF, STATUS_RECON_PROOF, - STATUS_PATCH, // unused in CMaNGOS + STATUS_PATCH, STATUS_AUTHED, - STATUS_CLOSED + STATUS_INVALID, }; bool VerifyVersion(uint8 const* a, int32 aLength, uint8 const* versionProof, bool isReconnect); - std::string GetRealmAddress(Realm const& realm) const; SRP6 srp; BigNumber m_reconnectProof; @@ -116,10 +119,10 @@ class AuthSocket: public BufferedSocket static constexpr uint32 X86 = 'x86'; static constexpr uint32 PPC = 'PPC'; - uint32 m_os = 0; - uint32 m_platform = 0; + std::string m_os; + std::string m_platform; uint32 m_accountId = 0; - uint32 m_lastRealmListRequest = 0; + nonstd::optional m_lastRealmListRequest; // Since GetLocaleByName() is _NOT_ bijective, we have to store the locale as a string. Otherwise we can't differ // between enUS and enGB, which is important for the patch system @@ -134,9 +137,14 @@ class AuthSocket: public BufferedSocket typedef std::map AccountSecurityMap; AccountSecurityMap m_accountSecurityOnRealm; - ACE_HANDLE m_patch = ACE_INVALID_HANDLE; + // Auto kick realmd client connection after some time + std::shared_ptr m_sessionDurationTimeout = nullptr; + + // Patching stuff + void InitAndHandOverControlToPatchHandler(); + std::unique_ptr m_pendingPatchFile = nullptr; - void InitPatch(); + void RepeatInternalXferLoop(std::shared_ptr const& chunkSharedPtr); }; + #endif -// @} diff --git a/src/realmd/BufferedSocket.cpp b/src/realmd/BufferedSocket.cpp deleted file mode 100644 index 950a5dde287..00000000000 --- a/src/realmd/BufferedSocket.cpp +++ /dev/null @@ -1,284 +0,0 @@ -/* - * Copyright (C) 2005-2011 MaNGOS - * Copyright (C) 2009-2011 MaNGOSZero - * Copyright (C) 2011-2016 Nostalrius - * Copyright (C) 2016-2017 Elysium Project - * - * This program is free software; you can redistribute it and/or modify - * it under the terms of the GNU General Public License as published by - * the Free Software Foundation; either version 2 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU General Public License for more details. - * - * You should have received a copy of the GNU General Public License - * along with this program; if not, write to the Free Software - * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA - */ - -/** \file - \ingroup realmd - */ - -#include "BufferedSocket.h" -#include "Config/Config.h" - -#include -#include -#include - -#ifndef MSG_NOSIGNAL -#define MSG_NOSIGNAL 0 -#endif - -BufferedSocket::BufferedSocket(void): - input_buffer_(4096), - remote_address_("") -{ -} - -/*virtual*/ BufferedSocket::~BufferedSocket(void) -{ -} - -/*virtual*/ int BufferedSocket::open(void * arg) -{ - if(Base::open(arg) == -1) - return -1; - - if (time_t timer = sConfig.GetIntDefault("MaxSessionDuration", 300)) - { - ACE_Time_Value interval(timer); - Base::reactor_timer_interface()->schedule_timer(this, NULL, interval); - } - - ACE_INET_Addr addr; - - if(peer().get_remote_addr(addr) == -1) - return -1; - - char address[1024]; - - addr.get_host_addr(address, 1024); - - this->remote_address_ = address; - - this->OnAccept(); - - return 0; -} - -const std::string& BufferedSocket::get_remote_address(void) const -{ - return this->remote_address_; -} - -size_t BufferedSocket::recv_len(void) const -{ - return this->input_buffer_.length(); -} - -bool BufferedSocket::recv_soft(char *buf, size_t len) -{ - if(this->input_buffer_.length() < len) - return false; - - ACE_OS::memcpy(buf, this->input_buffer_.rd_ptr(), len); - - return true; -} - -bool BufferedSocket::recv(char *buf, size_t len) -{ - bool ret = this->recv_soft(buf, len); - - if(ret) - this->recv_skip(len); - - return ret; -} - -void BufferedSocket::recv_skip(size_t len) -{ - this->input_buffer_.rd_ptr(len); -} - -ssize_t BufferedSocket::noblk_send(ACE_Message_Block &message_block) -{ - const size_t len = message_block.length(); - - if(len == 0) - return -1; - - // Try to send the message directly. - ssize_t n = this->peer().send(message_block.rd_ptr(), len, MSG_NOSIGNAL); - - if(n < 0) - { - if(errno == EWOULDBLOCK) - // Blocking signal - return 0; - else - // Error - return -1; - } - else if(n == 0) - { - // Can this happen ? - return -1; - } - - // return bytes transmitted - return n; -} - -bool BufferedSocket::send(const char *buf, size_t len) -{ - if(buf == nullptr || len == 0) - return true; - - ACE_Data_Block db( - len, - ACE_Message_Block::MB_DATA, - (const char*)buf, - 0, - 0, - ACE_Message_Block::DONT_DELETE, - 0); - - ACE_Message_Block message_block( - &db, - ACE_Message_Block::DONT_DELETE, - 0); - - message_block.wr_ptr(len); - - if(this->msg_queue()->is_empty()) - { - // Try to send it directly. - ssize_t n = this->noblk_send(message_block); - - if(n < 0) - return false; - else if(n == len) - return true; - - // adjust how much bytes we sent - message_block.rd_ptr((size_t)n); - - // fall down - } - - // enqueue the message, note: clone is needed cause we cant enqueue stuff on the stack - ACE_Message_Block *mb = message_block.clone(); - - if(this->msg_queue()->enqueue_tail(mb, (ACE_Time_Value *) &ACE_Time_Value::zero) == -1) - { - mb->release(); - return false; - } - - // tell reactor to call handle_output() when we can send more data - return this->reactor()->schedule_wakeup(this, ACE_Event_Handler::WRITE_MASK) != -1; -} - -/*virtual*/ int BufferedSocket::handle_output(ACE_HANDLE /*= ACE_INVALID_HANDLE*/) -{ - ACE_Message_Block *mb = 0; - - if(this->msg_queue()->is_empty()) - { - // if no more data to send, then cancel notification - this->reactor()->cancel_wakeup(this, ACE_Event_Handler::WRITE_MASK); - return 0; - } - - if(this->msg_queue()->dequeue_head(mb, (ACE_Time_Value *) &ACE_Time_Value::zero) == -1) - return -1; - - ssize_t n = this->noblk_send(*mb); - - if(n < 0) - { - mb->release(); - return -1; - } - else if(n == mb->length()) - { - mb->release(); - return 1; - } - else - { - mb->rd_ptr(n); - - if(this->msg_queue()->enqueue_head(mb, (ACE_Time_Value *) &ACE_Time_Value::zero) == -1) - { - mb->release(); - return -1; - } - - return 0; - } - - ACE_NOTREACHED(return -1); -} - -/*virtual*/ int BufferedSocket::handle_input(ACE_HANDLE /*= ACE_INVALID_HANDLE*/) -{ - const ssize_t space = this->input_buffer_.space(); - - ssize_t n = this->peer().recv(this->input_buffer_.wr_ptr(), space); - - if(n < 0) - { - // blocking signal or error - return errno == EWOULDBLOCK ? 0 : -1; - } - else if(n == 0) - { - // EOF - return -1; - } - - this->input_buffer_.wr_ptr((size_t)n); - - this->OnRead(); - - // move data in the buffer to the beginning of the buffer - this->input_buffer_.crunch(); - - // return 1 in case there might be more data to read from OS - return n == space ? 1 : 0; -} - -/*virtual*/ int BufferedSocket::handle_close(ACE_HANDLE /*h*/, ACE_Reactor_Mask /*m*/) -{ - this->OnClose(); - - Base::handle_close(); - - return 0; -} - -int BufferedSocket::handle_timeout(ACE_Time_Value const& current_time, void const* act) -{ - this->close_connection(); - - this->OnClose(); - - Base::handle_close(); - - return 0; -} - -void BufferedSocket::close_connection(void) -{ - this->peer().close_reader(); - this->peer().close_writer(); - - reactor()->remove_handler(this, ACE_Event_Handler::DONT_CALL | ACE_Event_Handler::ALL_EVENTS_MASK); -} diff --git a/src/realmd/BufferedSocket.h b/src/realmd/BufferedSocket.h deleted file mode 100644 index f395c952b2b..00000000000 --- a/src/realmd/BufferedSocket.h +++ /dev/null @@ -1,83 +0,0 @@ -/* - * Copyright (C) 2005-2011 MaNGOS - * Copyright (C) 2009-2011 MaNGOSZero - * Copyright (C) 2011-2016 Nostalrius - * Copyright (C) 2016-2017 Elysium Project - * - * This program is free software; you can redistribute it and/or modify - * it under the terms of the GNU General Public License as published by - * the Free Software Foundation; either version 2 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU General Public License for more details. - * - * You should have received a copy of the GNU General Public License - * along with this program; if not, write to the Free Software - * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA - */ - -/** \file - \ingroup realmd - */ - -#ifndef _BUFFEREDSOCKET_H_ -#define _BUFFEREDSOCKET_H_ - -#include -#include -#include -#include -#include -#include - -#include - -class BufferedSocket: public ACE_Svc_Handler -{ - protected: - typedef ACE_Svc_Handler Base; - - virtual void OnRead(void) { } - virtual void OnAccept(void) { } - virtual void OnClose(void) { } - - public: - BufferedSocket(void); - virtual ~BufferedSocket(void); - - size_t recv_len(void) const; - bool recv_soft(char *buf, size_t len); - bool recv(char *buf, size_t len); - void recv_skip(size_t len); - - bool send(const char *buf, size_t len); - - const std::string& get_remote_address(void) const; - - virtual int open(void *) override; - - void close_connection(void); - - virtual int handle_input(ACE_HANDLE = ACE_INVALID_HANDLE) override; - virtual int handle_output(ACE_HANDLE = ACE_INVALID_HANDLE) override; - - virtual int handle_close(ACE_HANDLE = ACE_INVALID_HANDLE, - ACE_Reactor_Mask = ACE_Event_Handler::ALL_EVENTS_MASK) override; - virtual int handle_timeout(ACE_Time_Value const& current_time, - void const* act = 0) override; - - private: - ssize_t noblk_send(ACE_Message_Block &message_block); - - private: - ACE_Message_Block input_buffer_; - - protected: - std::string remote_address_; - -}; - -#endif /* _BUFFEREDSOCKET_H_ */ diff --git a/src/realmd/CMakeLists.txt b/src/realmd/CMakeLists.txt index 6d82d3b470b..2adf7f49b1c 100644 --- a/src/realmd/CMakeLists.txt +++ b/src/realmd/CMakeLists.txt @@ -17,16 +17,15 @@ set(EXECUTABLE_NAME realmd) set(EXECUTABLE_SRCS + AuthPackets.h AuthCodes.h AuthSocket.h - BufferedSocket.h - PatchHandler.h RealmList.h AuthSocket.cpp - BufferedSocket.cpp - Main.cpp - PatchHandler.cpp RealmList.cpp + ClientPatchCache.h + ClientPatchCache.cpp + Main.cpp ) @@ -34,8 +33,21 @@ if(WIN32) list(APPEND EXECUTABLE_SRCS realmd.rc ) +endif() - set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} /D__ACE_INLINE__") +include_directories( + ${CMAKE_SOURCE_DIR}/src/shared + ${CMAKE_SOURCE_DIR}/src/framework + ${CMAKE_BINARY_DIR} + ${CMAKE_BINARY_DIR}/src/shared + ${MYSQL_INCLUDE_DIR} + ${OPENSSL_INCLUDE_DIR} +) + +if(WIN32) + include_directories( + ${CMAKE_SOURCE_DIR}/dep/windows/include + ) endif() add_executable(${EXECUTABLE_NAME} @@ -45,7 +57,6 @@ add_executable(${EXECUTABLE_NAME} target_link_libraries(${EXECUTABLE_NAME} shared framework - ${ACE_LIBRARIES} ) if(WIN32) diff --git a/src/realmd/ClientPatchCache.cpp b/src/realmd/ClientPatchCache.cpp new file mode 100644 index 00000000000..f88320fffe3 --- /dev/null +++ b/src/realmd/ClientPatchCache.cpp @@ -0,0 +1,95 @@ +#include "./ClientPatchCache.h" +#include "Policies/SingletonImp.h" +#include "Log.h" +#include "Errors.h" +#include "Config/Config.h" +#include "Crypto/Hash/HMACSHA1.h" +#include "Crypto/Hash/MD5.h" +#include "IO/Filesystem/FileSystem.h" +#include "IO/Filesystem/FileHandle.h" + +INSTANTIATE_SINGLETON_1(ClientPatchCache); + +ClientPatchCache::ClientPatchCache() +{ + LoadPatchesInfo(); +} + +void ClientPatchCache::LoadPatchesInfo() +{ + std::string folderPath = sConfig.GetStringDefault("PatchesDir", "./patches") + "/"; + std::string fullFolderPath = IO::Filesystem::ToAbsolutePath(folderPath); + + sLog.Out(LOG_BASIC, LOG_LVL_DEBUG, "[PatchCache] Loading available game client patches from folder %s", fullFolderPath.c_str()); + + for (const std::string& filePath : IO::Filesystem::GetAllFilesInFolder(fullFolderPath, IO::Filesystem::OutputFilePath::FullFilePath)) + { + auto fileHandle = IO::Filesystem::TryOpenFileReadonly(filePath); + if (fileHandle) + { + sLog.Out(LOG_BASIC, LOG_LVL_DEBUG, "[PatchCache] Calculate hash of %s", filePath.c_str()); + CalculateAndCacheHash(std::move(fileHandle)); + } + else + { + sLog.Out(LOG_BASIC, LOG_LVL_ERROR, "[PatchCache] Failed to open %s", filePath.c_str()); + } + } +} + +Crypto::Hash::MD5::Digest ClientPatchCache::GetOrCalculateHash(std::unique_ptr const& fileHandle) +{ + auto filePath = fileHandle->GetFilePath(); + auto lastModifyDate = fileHandle->GetLastModifyDate(); + auto fileSize = fileHandle->GetTotalFileSize(); + + m_knownPatches_mutex.lock(); + auto const& exisingEntry = m_knownPatches.find(filePath); + if (exisingEntry == m_knownPatches.end() || exisingEntry->second.lastModifyDate != lastModifyDate || exisingEntry->second.fileSize != fileSize) + { // file does not exist in cache or was changed + m_knownPatches_mutex.unlock(); + sLog.Out(LOG_BASIC, LOG_LVL_BASIC, "[PatchCache] Detected change of file '%s'. Will recalculate hash.", filePath.c_str()); + // It's important to have a duplicate file handle here, since we want to guarantee easy access + return CalculateAndCacheHash(fileHandle->DuplicateFileHandle()); + } + else + { // we can use the existent entry + Crypto::Hash::MD5::Digest md5Hash = exisingEntry->second.md5Hash; + m_knownPatches_mutex.unlock(); + return md5Hash; + } +} + +Crypto::Hash::MD5::Digest ClientPatchCache::CalculateAndCacheHash(std::unique_ptr fileHandle) +{ + Crypto::Hash::MD5::Generator md5; + + size_t constexpr CHECK_CHUNK_SIZE = 1024 * 1024; // 1 MiB Chunks + std::vector buffer(CHECK_CHUNK_SIZE); + + uint64_t totalRead = 0; + + do { // Read the file chunk by chunk and add insert it into our MD5_Update + uint64_t actuallyRead = fileHandle->ReadSync(buffer.data(), CHECK_CHUNK_SIZE); + md5.UpdateData(buffer.data(), (size_t) actuallyRead); + + totalRead += actuallyRead; + + if (actuallyRead < CHECK_CHUNK_SIZE) + break; // we read less than expected, meaning the file is done + } while (true); + + PatchCacheEntry entry; + entry.filePath = fileHandle->GetFilePath(); + entry.lastModifyDate = fileHandle->GetLastModifyDate(); + entry.fileSize = fileHandle->GetTotalFileSize(); + entry.md5Hash = md5.GetDigest(); + + MANGOS_ASSERT(totalRead == entry.fileSize); + + m_knownPatches_mutex.lock(); + m_knownPatches.emplace(entry.filePath, entry); + m_knownPatches_mutex.unlock(); + + return entry.md5Hash; +} diff --git a/src/realmd/ClientPatchCache.h b/src/realmd/ClientPatchCache.h new file mode 100644 index 00000000000..99f2b87117f --- /dev/null +++ b/src/realmd/ClientPatchCache.h @@ -0,0 +1,48 @@ +#ifndef MANGOS_CLIENTPATCHCACHE_H +#define MANGOS_CLIENTPATCHCACHE_H + +#include +#include +#include +#include +#include +#include + +#include "Crypto/Hash/HMACSHA1.h" +#include "Crypto/Hash/MD5.h" +#include "IO/Filesystem/FileHandle.h" +#include "Policies/Singleton.h" + +struct PatchCacheEntry +{ + // To figure out if the file was changed + std::string filePath; + uint64_t fileSize; + std::chrono::system_clock::time_point lastModifyDate; + + // The stuff we are actually interested in + Crypto::Hash::MD5::Digest md5Hash; +}; + +/// Caches MD5 hash of client patches present on the server +class ClientPatchCache : public MaNGOS::Singleton> +{ + public: + explicit ClientPatchCache(); + + /// This function will detect changes in the size or modification date of the file + /// The FileHandle will be untouched (You can use the same handle to send the data to the client) + Crypto::Hash::MD5::Digest GetOrCalculateHash(std::unique_ptr const& fileHandle); + + /// The FileHandle will be taken over and freed + Crypto::Hash::MD5::Digest CalculateAndCacheHash(std::unique_ptr fileHandle); + + private: + void LoadPatchesInfo(); + std::mutex m_knownPatches_mutex; + std::unordered_map m_knownPatches; +}; + +#define sRealmdPatchCache MaNGOS::Singleton::Instance() + +#endif //MANGOS_CLIENTPATCHCACHE_H diff --git a/src/realmd/Main.cpp b/src/realmd/Main.cpp index ad8b3f9f6bf..55e92959950 100644 --- a/src/realmd/Main.cpp +++ b/src/realmd/Main.cpp @@ -26,9 +26,11 @@ #include "Common.h" #include "Database/DatabaseEnv.h" #include "RealmList.h" +#include "ClientPatchCache.h" #include "Config/Config.h" #include "Log.h" +#include "Errors.h" #include "AuthSocket.h" #include "SystemConfig.h" #include "revision.h" @@ -36,13 +38,12 @@ #include "migrations_list.h" #include #include +#include "ArgparserForServer.h" +#include "ProxyProtocol/ProxyV2Reader.h" -#include -#include -#include -#include -#include -#include +#include "IO/Networking/AsyncSocketAcceptor.h" +#include "IO/Timer/AsyncSystemTimer.h" +#include "IO/Multithreading/CreateThread.h" #ifdef ENABLE_MAILSENDER #include "MailerService.h" @@ -51,6 +52,7 @@ #ifdef WIN32 #include "ServiceWin32.h" + char serviceName[] = "realmd"; char serviceLongName[] = "MaNGOS realmd service"; char serviceDescription[] = "Massive Network Game Object Server"; @@ -60,7 +62,7 @@ char serviceDescription[] = "Massive Network Game Object Server"; * 1 - running * 2 - paused */ -int m_ServiceStatus = -1; +volatile int m_ServiceStatus = -1; #else #include "PosixDaemon.h" #endif @@ -69,154 +71,83 @@ bool StartDB(); void UnhookSignals(); void HookSignals(); -bool stopEvent = false; // Setting it to true stops the server - -DatabaseType LoginDatabase; // Accessor to the realm server database - -// Print out the usage string for this program on the console. -void usage(const char *prog) -{ - sLog.Out(LOG_BASIC, LOG_LVL_MINIMAL, "Usage: \n %s []\n" - " -v, --version print version and exist\n\r" - " -c config_file use config_file as configuration file\n\r" - #ifdef WIN32 - " Running as service functions:\n\r" - " -s run run as service\n\r" - " -s install install service\n\r" - " -s uninstall uninstall service\n\r" - #else - " Running as daemon functions:\n\r" - " -s run run as daemon\n\r" - " -s stop stop daemon\n\r" - #endif - ,prog); -} - -char const* g_mainLogFileName = "Realmd.log"; +// Global initialization +char const* g_mainLogFileName = "Realmd.log"; // Log file path for sLog +volatile bool stopEvent = false; // Setting it to true stops the server +DatabaseType LoginDatabase; // Accessor to the realm server database // Launch the realm server -extern int main(int argc, char **argv) +extern int main(int argc, char** argv) { - // Command line parsing - char const* cfg_file = _REALMD_CONFIG; - - char const *options = ":c:s:"; - - ACE_Get_Opt cmd_opts(argc, argv, options); - cmd_opts.long_option("version", 'v'); - - char serviceDaemonMode = '\0'; - - int option; - while ((option = cmd_opts()) != EOF) + ServerStartupArguments args; { - switch (option) - { - case 'c': - cfg_file = cmd_opts.opt_arg(); - break; - case 'v': - printf("Core revion: %s\n", _FULLVERSION); - return 0; + // parseResult is std::expected, where the error is the return code, that might be present when invalid args or "--help" is given + auto parseResult = ParseServerStartupArguments(argc, argv); + if (!parseResult) + return parseResult.error(); - case 's': - { - const char *mode = cmd_opts.opt_arg(); - - if (!strcmp(mode, "run")) - serviceDaemonMode = 'r'; -#ifdef WIN32 - else if (!strcmp(mode, "install")) - serviceDaemonMode = 'i'; - else if (!strcmp(mode, "uninstall")) - serviceDaemonMode = 'u'; -#else - else if (!strcmp(mode, "stop")) - serviceDaemonMode = 's'; -#endif - else - { - sLog.Out(LOG_BASIC, LOG_LVL_ERROR, "Runtime-Error: -%c unsupported argument %s", cmd_opts.opt_opt(), mode); - usage(argv[0]); - Log::WaitBeforeContinueIfNeed(); - return 1; - } - break; - } - case ':': - sLog.Out(LOG_BASIC, LOG_LVL_ERROR, "Runtime-Error: -%c option requires an input argument", cmd_opts.opt_opt()); - usage(argv[0]); - Log::WaitBeforeContinueIfNeed(); - return 1; - default: - sLog.Out(LOG_BASIC, LOG_LVL_ERROR, "Runtime-Error: bad format of commandline arguments"); - usage(argv[0]); - Log::WaitBeforeContinueIfNeed(); - return 1; - } + args = parseResult.value(); } -#ifdef WIN32 // windows service command need execute before config read - switch (serviceDaemonMode) - { - case 'i': - if (WinServiceInstall()) - sLog.Out(LOG_BASIC, LOG_LVL_MINIMAL, "Installing service"); - return 1; - case 'u': - if (WinServiceUninstall()) - sLog.Out(LOG_BASIC, LOG_LVL_MINIMAL, "Uninstalling service"); - return 1; - case 'r': - WinServiceRun(); - break; - } -#endif + if (args.configFilePath.empty()) + args.configFilePath = _REALMD_CONFIG; - if (!sConfig.SetSource(cfg_file)) + if (!sConfig.LoadFromFile(args.configFilePath)) // must be done before (linux) service init { - sLog.Out(LOG_BASIC, LOG_LVL_ERROR, "Could not find configuration file %s.", cfg_file); + sLog.Out(LOG_BASIC, LOG_LVL_ERROR, "Could not find or parse configuration file %s", args.configFilePath.c_str()); Log::WaitBeforeContinueIfNeed(); - return 1; + return EXIT_FAILURE; } -#ifndef WIN32 // posix daemon commands need apply after config read - switch (serviceDaemonMode) + switch (args.inputServiceMode) { - case 'r': - startDaemon(); - break; - case 's': - stopDaemon(); - break; - } + case ServiceDaemonAction::NotSet: + break; +#ifdef WIN32 + // windows service command need execute before config read + case ServiceDaemonAction::Install: + sLog.Out(LOG_BASIC, LOG_LVL_MINIMAL, "Installing service..."); + return WinServiceInstall() ? EXIT_SUCCESS : EXIT_FAILURE; + case ServiceDaemonAction::Uninstall: + sLog.Out(LOG_BASIC, LOG_LVL_MINIMAL, "Uninstalling service..."); + return WinServiceUninstall() ? EXIT_SUCCESS : EXIT_FAILURE; + case ServiceDaemonAction::Start: + sLog.Out(LOG_BASIC, LOG_LVL_MINIMAL, "Starting service..."); + return WinServiceRun() ? EXIT_SUCCESS : EXIT_FAILURE; +#else + // posix daemon commands need apply after config read + case ServiceDaemonAction::Start: + startDaemon(); + break; + case ServiceDaemonAction::Stop: + stopDaemon(); + break; #endif + } sLog.Out(LOG_BASIC, LOG_LVL_MINIMAL, "Core revision: %s [realm-daemon]", _FULLVERSION); - sLog.Out(LOG_BASIC, LOG_LVL_MINIMAL, " to stop.\n" ); - sLog.Out(LOG_BASIC, LOG_LVL_MINIMAL, "Using configuration file %s.", cfg_file); + sLog.Out(LOG_BASIC, LOG_LVL_MINIMAL, " to stop.\n"); + sLog.Out(LOG_BASIC, LOG_LVL_MINIMAL, "Using configuration file %s.", sConfig.GetFilename().c_str()); // Check the version of the configuration file uint32 confVersion = sConfig.GetIntDefault("ConfVersion", 0); if (confVersion < _REALMDCONFVERSION) { sLog.Out(LOG_BASIC, LOG_LVL_ERROR, "*****************************************************************************"); - sLog.Out(LOG_BASIC, LOG_LVL_ERROR, " WARNING: Your realmd.conf version indicates your conf file is out of date!"); - sLog.Out(LOG_BASIC, LOG_LVL_ERROR, " Please check for updates, as your current default values may cause"); - sLog.Out(LOG_BASIC, LOG_LVL_ERROR, " strange behavior."); + sLog.Out(LOG_BASIC, LOG_LVL_ERROR, " WARNING: Your realmd.conf version indicates your conf file is out of date! "); + sLog.Out(LOG_BASIC, LOG_LVL_ERROR, " Please check for updates, as your current default values may cause "); + sLog.Out(LOG_BASIC, LOG_LVL_ERROR, " strange behavior. "); sLog.Out(LOG_BASIC, LOG_LVL_ERROR, "*****************************************************************************"); Log::WaitBeforeContinueIfNeed(); } sLog.Out(LOG_BASIC, LOG_LVL_DETAIL, "%s (Library: %s)", OPENSSL_VERSION_TEXT, SSLeay_version(SSLEAY_VERSION)); - if (SSLeay() < 0x009080bfL ) + if (SSLeay() < 0x009080bfL) { sLog.Out(LOG_BASIC, LOG_LVL_DETAIL, "WARNING: Outdated version of OpenSSL lib. Logins to server may not work!"); sLog.Out(LOG_BASIC, LOG_LVL_DETAIL, "WARNING: Minimal required version [OpenSSL 0.9.8k]"); } - sLog.Out(LOG_BASIC, LOG_LVL_DETAIL, "Using ACE: %s", ACE_VERSION); - #ifdef ENABLE_MAILSENDER sLog.Out(LOG_BASIC, LOG_LVL_DETAIL, "Using CURL version %s", curl_version()); @@ -225,31 +156,23 @@ extern int main(int argc, char **argv) MailerService::set_global_mailer(&mailer); #endif -#if defined (ACE_HAS_EVENT_POLL) || defined (ACE_HAS_DEV_POLL) - ACE_Reactor::instance(new ACE_Reactor(new ACE_Dev_Poll_Reactor(ACE::max_handles(), 1), 1), true); -#else - ACE_Reactor::instance(new ACE_Reactor(new ACE_TP_Reactor(), true), true); -#endif - - sLog.Out(LOG_BASIC, LOG_LVL_BASIC, "Max allowed open files is %d", ACE::max_handles()); - // realmd PID file creation std::string pidfile = sConfig.GetStringDefault("PidFile", ""); - if(!pidfile.empty()) + if (!pidfile.empty()) { uint32 pid = CreatePIDFile(pidfile); - if( !pid ) + if (!pid) { - sLog.Out(LOG_BASIC, LOG_LVL_ERROR, "Cannot create PID file %s.\n", pidfile.c_str() ); + sLog.Out(LOG_BASIC, LOG_LVL_ERROR, "Cannot create PID file %s.\n", pidfile.c_str()); Log::WaitBeforeContinueIfNeed(); return 1; } - sLog.Out(LOG_BASIC, LOG_LVL_BASIC, "Daemon PID: %u\n", pid ); + sLog.Out(LOG_BASIC, LOG_LVL_BASIC, "Daemon PID: %u\n", pid); } // Initialize the database connection - if(!StartDB()) + if (!StartDB()) { Log::WaitBeforeContinueIfNeed(); return 1; @@ -282,6 +205,10 @@ extern int main(int argc, char **argv) return 1; } + (void)sRealmdPatchCache; // <-- This will initialize the singleton. Which will preload all known patches. + (void)sAsyncSystemTimer; // <-- Pre-Initialize SystemTimer + IO::Multithreading::RenameCurrentThread("Main"); + // cleanup query // set expired bans to inactive LoginDatabase.BeginTransaction(); @@ -289,46 +216,89 @@ extern int main(int argc, char **argv) LoginDatabase.Execute("DELETE FROM `ip_banned` WHERE `unbandate`<=UNIX_TIMESTAMP() AND `unbandate`<>`bandate`"); LoginDatabase.CommitTransaction(); - // Launch the listening network socket - ACE_Acceptor acceptor; + std::string bindIp = sConfig.GetStringDefault("BindIP", "0.0.0.0"); + uint16 bindPort = sConfig.GetIntDefault("RealmServerPort", DEFAULT_REALMSERVER_PORT); - uint16 rmport = sConfig.GetIntDefault("RealmServerPort", DEFAULT_REALMSERVER_PORT); - std::string bind_ip = sConfig.GetStringDefault("BindIP", "0.0.0.0"); - - ACE_INET_Addr bind_addr(rmport, bind_ip.c_str()); + std::unique_ptr ioCtx = IO::IoContext::CreateIoContext(); + if (ioCtx == nullptr) + { + sLog.Out(LOG_BASIC, LOG_LVL_ERROR, "Failed to create IoContext"); + Log::WaitBeforeContinueIfNeed(); + return 1; + } - if(acceptor.open(bind_addr, ACE_Reactor::instance(), ACE_NONBLOCK) == -1) + // Launch the listening network socket + std::unique_ptr listener = IO::Networking::AsyncSocketAcceptor::CreateAndBindServer(ioCtx.get(), bindIp, bindPort); + if (listener == nullptr) { - sLog.Out(LOG_BASIC, LOG_LVL_ERROR, "MaNGOS realmd can not bind to %s:%d", bind_ip.c_str(), rmport); + sLog.Out(LOG_BASIC, LOG_LVL_ERROR, "MaNGOS realmd can not bind to %s:%d - Is the port already in use?", bindIp.c_str(), bindPort); Log::WaitBeforeContinueIfNeed(); return 1; } + std::vector trustedProxyIps = SplitStringByDelimiter(sConfig.GetStringDefault("TrustedProxyServers", ""), ','); + + listener->AutoAcceptSocketsUntilClose([ctx = ioCtx.get(), trustedProxyIps](IO::Networking::SocketDescriptor socketDescriptor) + { + // Create a socket and attach it to our global ioCtx + auto authSocket = std::make_shared(std::move(IO::Networking::AsyncSocket(ctx, std::move(socketDescriptor)))); + + if (IO::NetworkError initError = authSocket->m_socket.InitializeAndFixateMemoryLocation()) + { + sLog.Out(LOG_BASIC, LOG_LVL_ERROR, "[%s] Failed to initialize AuthSocket %s", authSocket->m_socket.GetRemoteIpString().c_str(), initError.ToString().c_str()); + return; // implicit close() + } + + + // Check if the remote endpoint is actually a trusted proxy, so we can retrieve the real client ip + if (!trustedProxyIps.empty() && std::find(trustedProxyIps.begin(), trustedProxyIps.end(), authSocket->m_socket.GetRemoteIpString()) != trustedProxyIps.end()) + { + // parse proxy header + ProxyProtocol::ReadProxyV2Handshake(&(authSocket->m_socket), [authSocket](nonstd::expected const& maybeIp) + { + if (!maybeIp.has_value()) + { + sLog.Out(LOG_NETWORK, LOG_LVL_ERROR, "[%s] Failed to parse proxy header. Error: %s", authSocket->m_socket.GetRemoteIpString().c_str(), maybeIp.error().ToString().c_str()); + return; // implicit close() + } + authSocket->m_remoteIpAddressStringAfterProxy = maybeIp.value().ToString(); + sLog.Out(LOG_NETWORK, LOG_LVL_BASIC, "[%s] Connection accepted (proxy ip: %s)", authSocket->GetRemoteIpString().c_str(), authSocket->m_socket.GetRemoteIpString().c_str()); + authSocket->Start(); + }); + } + else + { + // no proxy, we can start directly + sLog.Out(LOG_NETWORK, LOG_LVL_BASIC, "[%s] Connection accepted", authSocket->GetRemoteIpString().c_str()); + authSocket->Start(); + } + }); + // Catch termination signals HookSignals(); // Handle affinity for multiple processors and process priority on Windows - #ifdef WIN32 +#ifdef WIN32 { - HANDLE hProcess = GetCurrentProcess(); + HANDLE hProcess = ::GetCurrentProcess(); uint32 Aff = sConfig.GetIntDefault("UseProcessors", 0); - if(Aff > 0) + if (Aff > 0) { ULONG_PTR appAff; ULONG_PTR sysAff; - if(GetProcessAffinityMask(hProcess,&appAff,&sysAff)) + if (::GetProcessAffinityMask(hProcess, &appAff, &sysAff)) { ULONG_PTR curAff = Aff & appAff; // remove non accessible processors - if(!curAff ) + if (!curAff) { sLog.Out(LOG_BASIC, LOG_LVL_ERROR, "Processors marked in UseProcessors bitmask (hex) %x not accessible for realmd. Accessible processors bitmask (hex): %x",Aff,appAff); } else { - if(SetProcessAffinityMask(hProcess,curAff)) + if (::SetProcessAffinityMask(hProcess, curAff)) sLog.Out(LOG_BASIC, LOG_LVL_MINIMAL, "Using processors (bitmask, hex): %x", curAff); else sLog.Out(LOG_BASIC, LOG_LVL_ERROR, "Can't set used processors (hex): %x", curAff); @@ -339,55 +309,65 @@ extern int main(int argc, char **argv) bool Prio = sConfig.GetBoolDefault("ProcessPriority", false); + // if(Prio && (m_ServiceStatus == -1)/* need set to default process priority class in service mode*/) if(Prio) { - if(SetPriorityClass(hProcess,HIGH_PRIORITY_CLASS)) + if (::SetPriorityClass(hProcess,HIGH_PRIORITY_CLASS)) sLog.Out(LOG_BASIC, LOG_LVL_MINIMAL, "realmd process priority class set to HIGH"); else sLog.Out(LOG_BASIC, LOG_LVL_ERROR, "Can't set realmd process priority class."); - sLog.Out(LOG_BASIC, LOG_LVL_MINIMAL, ""); } } - #endif +#endif //server has started up successfully => enable async DB requests LoginDatabase.AllowAsyncTransactions(); // maximum counter for next ping - uint32 numLoops = (sConfig.GetIntDefault( "MaxPingTime", 30 ) * (MINUTE * 1000000 / 100000)); + uint32 numLoops = (sConfig.GetIntDefault("MaxPingTime", 30) * (MINUTE * 1000000 / 100000)); // TODO make this loop like mangosd uint32 loopCounter = 0; - #ifndef WIN32 + auto ioThread = IO::Multithreading::CreateThread("MainIoCtx", [&ioCtx]() + { + ioCtx->RunUntilShutdown(); + }); + +#ifndef WIN32 detachDaemon(); - #endif +#endif // Wait for termination signal while (!stopEvent) { - // dont move this outside the loop, the reactor will modify it - ACE_Time_Value interval(0, 100000); + std::this_thread::sleep_for(std::chrono::milliseconds(1000)); - if (ACE_Reactor::instance()->run_reactor_event_loop(interval) == -1) - break; - - if( (++loopCounter) == numLoops ) + if ((++loopCounter) == numLoops) // TODO make this loop like mangosd { loopCounter = 0; sLog.Out(LOG_BASIC, LOG_LVL_DETAIL, "Ping MySQL to keep connection alive"); LoginDatabase.Ping(); } + #ifdef WIN32 if (m_ServiceStatus == 0) stopEvent = true; while (m_ServiceStatus == 2) Sleep(1000); #endif } + sLog.Out(LOG_BASIC, LOG_LVL_BASIC, "Waiting for IO thread to finish"); + + listener->ClosePortAndStopAcceptingNewConnections(); + sAsyncSystemTimer.RemoveAllTimersAndStopThread(); + + ioCtx->Shutdown(); + ioThread.join(); + // Wait for the delay thread to exit LoginDatabase.HaltDelayThread(); // Remove signal handling before leaving UnhookSignals(); - sLog.Out(LOG_BASIC, LOG_LVL_BASIC, "Halting process..." ); + sLog.Out(LOG_BASIC, LOG_LVL_BASIC, "Halting process..."); return 0; } @@ -401,21 +381,21 @@ void OnSignal(int s) case SIGTERM: stopEvent = true; break; - #ifdef _WIN32 +#ifdef _WIN32 case SIGBREAK: stopEvent = true; break; - #endif +#endif } - signal(s, OnSignal); + ::signal(s, OnSignal); } // Initialize connection to the database bool StartDB() { std::string dbstring = sConfig.GetStringDefault("LoginDatabaseInfo", ""); - if(dbstring.empty()) + if (dbstring.empty()) { sLog.Out(LOG_BASIC, LOG_LVL_ERROR, "Database not specified"); return false; @@ -454,8 +434,8 @@ bool StartDB() return false; } - sLog.Out(LOG_BASIC, LOG_LVL_MINIMAL, "Database: %s", dbStringLog.c_str() ); - if(!LoginDatabase.Initialize(dbstring.c_str())) + sLog.Out(LOG_BASIC, LOG_LVL_MINIMAL, "Database: %s", dbStringLog.c_str()); + if (!LoginDatabase.Initialize(dbstring.c_str())) { sLog.Out(LOG_BASIC, LOG_LVL_ERROR, "Cannot connect to database"); return false; @@ -474,21 +454,21 @@ bool StartDB() // Define hook 'OnSignal' for all termination signals void HookSignals() { - signal(SIGINT, OnSignal); - signal(SIGTERM, OnSignal); - #ifdef _WIN32 - signal(SIGBREAK, OnSignal); - #endif + ::signal(SIGINT, OnSignal); + ::signal(SIGTERM, OnSignal); +#ifdef _WIN32 + ::signal(SIGBREAK, OnSignal); +#endif } // Unhook the signals before leaving void UnhookSignals() { - signal(SIGINT, 0); - signal(SIGTERM, 0); - #ifdef _WIN32 - signal(SIGBREAK, 0); - #endif + ::signal(SIGINT, nullptr); + ::signal(SIGTERM, nullptr); +#ifdef _WIN32 + ::signal(SIGBREAK, nullptr); +#endif } // @} diff --git a/src/realmd/PatchHandler.cpp b/src/realmd/PatchHandler.cpp deleted file mode 100644 index 1a6a47a946d..00000000000 --- a/src/realmd/PatchHandler.cpp +++ /dev/null @@ -1,232 +0,0 @@ -/* - * Copyright (C) 2005-2011 MaNGOS - * Copyright (C) 2009-2011 MaNGOSZero - * Copyright (C) 2011-2016 Nostalrius - * Copyright (C) 2016-2017 Elysium Project - * - * This program is free software; you can redistribute it and/or modify - * it under the terms of the GNU General Public License as published by - * the Free Software Foundation; either version 2 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU General Public License for more details. - * - * You should have received a copy of the GNU General Public License - * along with this program; if not, write to the Free Software - * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA - */ - -/** \file - \ingroup realmd - */ - -#include "PatchHandler.h" -#include "AuthCodes.h" -#include "Log.h" -#include "Common.h" - -#include -#include -#include -#include - -#include - -#include "Crypto/Hash/MD5.h" -#include "Policies/SingletonImp.h" -#include "Policies/ThreadingModel.h" - -#ifndef MSG_NOSIGNAL -#define MSG_NOSIGNAL 0 -#endif - -#if defined( __GNUC__ ) -#pragma pack(1) -#else -#pragma pack(push,1) -#endif - -struct Chunk -{ - ACE_UINT8 cmd; - ACE_UINT16 data_size; - ACE_UINT8 data[4096]; // 4096 - page size on most arch -}; - -#if defined( __GNUC__ ) -#pragma pack() -#else -#pragma pack(pop) -#endif - -PatchHandler::PatchHandler(ACE_HANDLE socket, ACE_HANDLE patch) -{ - reactor(nullptr); - set_handle(socket); - patch_fd_ = patch; -} - -PatchHandler::~PatchHandler() -{ - if(patch_fd_ != ACE_INVALID_HANDLE) - ACE_OS::close(patch_fd_); -} - -int PatchHandler::open(void*) -{ - if(get_handle() == ACE_INVALID_HANDLE || patch_fd_ == ACE_INVALID_HANDLE) - return -1; - - int nodelay = 0; - if (-1 == peer().set_option(ACE_IPPROTO_TCP, - TCP_NODELAY, - &nodelay, - sizeof(nodelay))) - { - return -1; - } - -#if defined(TCP_CORK) - int cork = 1; - if (-1 == peer().set_option(ACE_IPPROTO_TCP, - TCP_CORK, - &cork, - sizeof(cork))) - { - return -1; - } -#endif //TCP_CORK - - (void) peer().disable(ACE_NONBLOCK); - - return activate(THR_NEW_LWP | THR_DETACHED | THR_INHERIT_SCHED); -} - -int PatchHandler::svc(void) -{ - // Do 1 second sleep, similar to the one in game/WorldSocket.cpp - // Seems client have problems with too fast sends. - ACE_OS::sleep(1); - - int flags = MSG_NOSIGNAL; - - Chunk data; - data.cmd = CMD_XFER_DATA; - - ssize_t r; - - while((r = ACE_OS::read(patch_fd_, data.data, sizeof(data.data))) > 0) - { - data.data_size = (ACE_UINT16)r; - - if(peer().send((const char*)&data, - ((size_t) r) + sizeof(data) - sizeof(data.data), - flags) == -1) - { - return -1; - } - } - - if(r == -1) - { - return -1; - } - - return 0; -} - -PatchCache::~PatchCache() -{ - for (Patches::iterator i = patches_.begin (); i != patches_.end (); i++) - delete i->second; -} - -PatchCache::PatchCache() -{ - LoadPatchesInfo(); -} - - -using PatchCacheLock = MaNGOS::ClassLevelLockable; -INSTANTIATE_SINGLETON_2(PatchCache, PatchCacheLock); -INSTANTIATE_CLASS_MUTEX(PatchCache, std::mutex); - -PatchCache* PatchCache::instance() -{ - return &MaNGOS::Singleton::Instance(); -} - -void PatchCache::LoadPatchMD5(const char* szFileName) -{ - // Try to open the patch file - std::string path = szFileName; - FILE* pPatch = fopen(path.c_str (), "rb"); - sLog.Out(LOG_BASIC, LOG_LVL_DEBUG, "Loading patch info from file %s", path.c_str()); - - if(!pPatch) - return; - - // Calculate the MD5 hash - Crypto::Hash::MD5::Generator md5Generator; - - const size_t check_chunk_size = 4*1024; - - ACE_UINT8 buf[check_chunk_size]; - - while(!feof (pPatch)) - { - size_t read = fread(buf, 1, check_chunk_size, pPatch); - md5Generator.UpdateData(buf, read); - } - - fclose(pPatch); - - // Store the result in the internal patch hash map - patches_[path] = new PATCH_INFO { md5Generator.GetDigest() }; -} - -bool PatchCache::GetHash(const char * pat, ACE_UINT8 mymd5[Crypto::Hash::MD5::Digest::size()]) -{ - for (Patches::iterator i = patches_.begin (); i != patches_.end (); i++) - if (!stricmp(pat, i->first.c_str ())) - { - memcpy(mymd5, i->second->md5.data(), i->second->md5.size()); - return true; - } - - return false; -} - -void PatchCache::LoadPatchesInfo() -{ - std::string path = sConfig.GetStringDefault("PatchesDir", "./patches") + "/"; - std::string fullpath; - ACE_DIR* dirp = ACE_OS::opendir(ACE_TEXT(path.c_str())); - sLog.Out(LOG_BASIC, LOG_LVL_DEBUG, "Loading patch info from folder %s", path.c_str()); - - if (!dirp) - return; - - ACE_DIRENT* dp; - - while ((dp = ACE_OS::readdir(dirp)) != nullptr) - { - int l = strlen(dp->d_name); - if (l < 8) - continue; - - if (!memcmp(&dp->d_name[l - 4], ".mpq", 4)) - { - fullpath = path + dp->d_name; - LoadPatchMD5(fullpath.c_str()); - } - } - - // causes crash on windows -#ifndef _WIN32 - ACE_OS::closedir(dirp); -#endif -} diff --git a/src/realmd/PatchHandler.h b/src/realmd/PatchHandler.h deleted file mode 100644 index e46161f2101..00000000000 --- a/src/realmd/PatchHandler.h +++ /dev/null @@ -1,97 +0,0 @@ -/* - * Copyright (C) 2005-2011 MaNGOS - * Copyright (C) 2009-2011 MaNGOSZero - * Copyright (C) 2011-2016 Nostalrius - * Copyright (C) 2016-2017 Elysium Project - * - * This program is free software; you can redistribute it and/or modify - * it under the terms of the GNU General Public License as published by - * the Free Software Foundation; either version 2 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU General Public License for more details. - * - * You should have received a copy of the GNU General Public License - * along with this program; if not, write to the Free Software - * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA - */ - -/** \file - \ingroup realmd - */ - -#ifndef _PATCHHANDLER_H_ -#define _PATCHHANDLER_H_ - -#include "Config/Config.h" - -#include -#include -#include -#include -#include -#include -#include - -#include "Crypto/Hash/MD5.h" - -/** - * @brief Caches MD5 hash of client patches present on the server - */ -class PatchCache -{ - public: - ~PatchCache(); - PatchCache(); - - static PatchCache* instance(); - - struct PATCH_INFO - { - Crypto::Hash::MD5::Digest md5; - }; - - typedef std::map Patches; - - Patches::const_iterator begin() const - { - return patches_.begin(); - } - - Patches::const_iterator end() const - { - return patches_.end(); - } - - void LoadPatchMD5(const char*); - bool GetHash(const char * pat, ACE_UINT8 mymd5[Crypto::Hash::MD5::Digest::size()]); - - private: - void LoadPatchesInfo(); - Patches patches_; - -}; - -class PatchHandler: public ACE_Svc_Handler -{ - protected: - typedef ACE_Svc_Handler Base; - - public: - PatchHandler(ACE_HANDLE socket, ACE_HANDLE patch); - virtual ~PatchHandler(); - - int open(void* = 0); - - protected: - virtual int svc(void); - - private: - ACE_HANDLE patch_fd_; - -}; - -#endif /* _BK_PATCHHANDLER_H__ */ diff --git a/src/realmd/RealmList.cpp b/src/realmd/RealmList.cpp index 8e6be7a8a9b..3546b46483e 100644 --- a/src/realmd/RealmList.cpp +++ b/src/realmd/RealmList.cpp @@ -28,16 +28,18 @@ #include "AuthCodes.h" #include "Util.h" // for Tokens typedef #include "Log.h" +#include "Errors.h" #include "Policies/SingletonImp.h" #include "Database/DatabaseEnv.h" -#include +#include "IO/Networking/DNS.h" +#include "IO/Networking/Utils.h" INSTANTIATE_SINGLETON_1( RealmList ); // list sorted from high to low build and first build used as low bound for accepted by default range (any > it will accepted by realmd at least) std::vector ExpectedRealmdClientBuilds; -std::vector FindBuildInfo(uint16 build, uint32 os, uint32 platform) +std::vector FindBuildInfo(uint16 build, std::string const& os, std::string const& platform) { std::vector matchingBuilds; for (auto const& itr : ExpectedRealmdClientBuilds) @@ -107,7 +109,7 @@ void RealmList::Initialize(uint32 updateInterval) UpdateRealms(true); } -void RealmList::UpdateRealm(uint32 realmId, std::string const& name, std::string const& address, std::string const& localAddress, std::string const& localSubnetMask, uint32 port, uint8 icon, RealmFlags realmFlags, uint8 timeZone, AccountTypes allowedSecurityLevel, float population, std::string const& builds) +void RealmList::UpdateRealm(uint32 realmId, std::string const& name, IO::Networking::IpAddress const& externalIpAddress, IO::Networking::IpAddress const& localIpAddress, uint8 localSubnetMaskCidr, uint16 port, uint8 icon, RealmFlags realmFlags, uint8 timeZone, AccountTypes allowedSecurityLevel, float population, std::string const& builds) { // Create new if not exist or update existed Realm& realm = m_realms[name]; @@ -141,23 +143,9 @@ void RealmList::UpdateRealm(uint32 realmId, std::string const& name, std::string if (bInfo->build == first_build) realm.realmBuildInfo = *bInfo; - // Append port to IP address. - std::ostringstream ss; - ss << address << ":" << port; - realm.address = ss.str(); - - // Same for the local address. - ss.str(""); - ss.clear(); - ss << localAddress << ":" << port; - realm.localAddress = ss.str(); - - // Subnet mask does not need port. - ACE_INET_Addr subnetAddress; - if (subnetAddress.set("0", localSubnetMask.c_str()) == -1) - sLog.Out(LOG_BASIC, LOG_LVL_ERROR, "Failed to parse local subnet mask for realm id %u!", realmId); - else - realm.localSubnetMask = subnetAddress.get_ip_address(); + realm.externalAddress = IO::Networking::IpEndpoint(externalIpAddress, port); + realm.localAddress = IO::Networking::IpEndpoint(localIpAddress, port); + realm.localSubnetMaskCidr = localSubnetMaskCidr; } void RealmList::UpdateIfNeed() @@ -193,11 +181,12 @@ void RealmList::UpdateRealms(bool init) { Field *fields = result->Fetch(); - uint32 id = fields[0].GetUInt32(); + uint32 realmId = fields[0].GetUInt32(); std::string name = fields[1].GetCppString(); - std::string address = fields[2].GetCppString(); - std::string localAddress = fields[3].GetCppString(); - std::string localSubnetMask = fields[4].GetCppString(); + std::string externalAddressString = fields[2].GetCppString(); + std::string localAddressString = fields[3].GetCppString(); + // TODO the db should be changed to a numeric subnet mask, so invalid states cant be represented (instead of "255.255.255.0" it should be "24") + std::string localSubnetMaskString = fields[4].GetCppString(); uint32 port = fields[5].GetUInt32(); uint8 icon = fields[6].GetUInt8(); uint8 realmflags = fields[7].GetUInt8(); @@ -212,8 +201,45 @@ void RealmList::UpdateRealms(bool init) realmflags &= (REALM_FLAG_OFFLINE|REALM_FLAG_NEW_PLAYERS|REALM_FLAG_RECOMMENDED|REALM_FLAG_SPECIFYBUILD); } + auto externalIpAddress = IO::Networking::DNS::ResolveDomainSingle(externalAddressString, IO::Networking::IpAddress::Type::IPv4); + if (!externalIpAddress) + { + sLog.Out(LOG_BASIC, LOG_LVL_ERROR, "Could not parse externalAddress: %s for realm \"%s\" (id %d) will skip realm update.", externalAddressString.c_str(), name.c_str(), realmId); + continue; + } + + auto localIpAddress = IO::Networking::DNS::ResolveDomainSingle(localAddressString, IO::Networking::IpAddress::Type::IPv4); + if (!localIpAddress) + { + sLog.Out(LOG_BASIC, LOG_LVL_ERROR, "Could not parse localAddress: %s for realm \"%s\" (id %d) will skip realm update.", localAddressString.c_str(), name.c_str(), realmId); + continue; + } + + auto localSubnetMaskIp = IO::Networking::IpAddress::TryParseFromString(localSubnetMaskString); + if (!localSubnetMaskIp) + { + sLog.Out(LOG_BASIC, LOG_LVL_ERROR, "Could not parse localSubnetMask: %s for realm \"%s\" (id %d) will skip realm update.", localSubnetMaskString.c_str(), name.c_str(), realmId); + continue; + } + + uint8 localSubnetMaskCidr = 0; + // Check and convert subnet mask + { + uint32 localSubnetMaskBinary = localSubnetMaskIp->_getInternalIPv4ReprAsUint32(); + if (((~localSubnetMaskBinary) & ((~localSubnetMaskBinary) + 1)) != 0) // doing some binary trickery to check if this is really a valid subnet mask without holes + { + sLog.Out(LOG_BASIC, LOG_LVL_ERROR, "Invalid localSubnetMask: %s for realm \"%s\" (id %d) will skip realm update.", localSubnetMaskString.c_str(), name.c_str(), realmId); + continue; + } + while (localSubnetMaskBinary) + { + localSubnetMaskCidr += (localSubnetMaskBinary & 0x01); + localSubnetMaskBinary >>= 1; + } + } + UpdateRealm( - id, name, address, localAddress, localSubnetMask, port, icon, RealmFlags(realmflags), timezone, + realmId, name, *externalIpAddress, *localIpAddress, localSubnetMaskCidr, port, icon, RealmFlags(realmflags), timezone, (allowedSecurityLevel <= SEC_ADMINISTRATOR ? AccountTypes(allowedSecurityLevel) : SEC_ADMINISTRATOR), population, realmBuilds); @@ -245,14 +271,8 @@ void RealmList::LoadAllowedClients() std::string hotfixVersion = fields[3].GetCppString(); buildInfo.hotfixVersion = hotfixVersion.empty() ? 0 : hotfixVersion[0]; buildInfo.build = fields[4].GetUInt16(); - - std::string os = fields[5].GetCppString(); - MANGOS_ASSERT(os.size() == 3); - memcpy(&buildInfo.os, os.data(), 4); - - std::string platform = fields[6].GetCppString(); - MANGOS_ASSERT(platform.size() == 3); - memcpy(&buildInfo.platform, platform.data(), 4); + buildInfo.os = fields[5].GetCppString(); + buildInfo.platform = fields[6].GetCppString(); std::string integrityHash = fields[7].GetCppString(); if (!integrityHash.empty()) @@ -266,3 +286,15 @@ void RealmList::LoadAllowedClients() } while (result->NextRow()); } } + +IO::Networking::IpEndpoint Realm::GetAddressForClient(IO::Networking::IpAddress const& clientAddr) const +{ + if (clientAddr.GetType() != IO::Networking::IpAddress::Type::IPv4) + return externalAddress; + + // Check if user connected with an IpAddress that is in the same subnet as the localAddress of the realm + bool clientHasAccessToLocalSubnet = IO::Networking::IsInSameSubnet(clientAddr, localAddress.ip, localSubnetMaskCidr); + return clientHasAccessToLocalSubnet + ? localAddress + : externalAddress; +} diff --git a/src/realmd/RealmList.h b/src/realmd/RealmList.h index d417e511925..e120ef551ed 100644 --- a/src/realmd/RealmList.h +++ b/src/realmd/RealmList.h @@ -28,6 +28,8 @@ #include "Common.h" #include "RealmZone.h" +#include "IO/Networking/IpAddress.h" + #include struct RealmBuildInfo @@ -37,13 +39,13 @@ struct RealmBuildInfo uint8 bugfixVersion = 0; char hotfixVersion = 0; uint16 build = 0; - uint32 os = 0; - uint32 platform = 0; + std::string os; + std::string platform; std::array integrityHash = { }; }; RealmBuildInfo const* FindBuildInfo(uint16 build); -std::vector FindBuildInfo(uint16 build, uint32 os, uint32 platform); +std::vector FindBuildInfo(uint16 build, std::string const& os, std::string const& platform); uint8 GetRealmCategoryIdByBuildAndZone(uint16 build, RealmZone realmZone); extern std::vector ExpectedRealmdClientBuilds; @@ -53,9 +55,9 @@ typedef std::set RealmBuilds; struct Realm { uint32 id = 0; - std::string address; - std::string localAddress; - uint32 localSubnetMask = 0; + IO::Networking::IpEndpoint externalAddress; + IO::Networking::IpEndpoint localAddress; + uint8 localSubnetMaskCidr = 0; // only valid if localAddress is IPv4 uint8 icon = 0; RealmFlags realmFlags = REALM_FLAG_NONE; uint8 timeZone = 0; @@ -63,6 +65,8 @@ struct Realm float populationLevel = 0.0f; RealmBuilds realmBuilds; // list of supported builds (updated in DB by mangosd) RealmBuildInfo realmBuildInfo; // build info for showing version in list + + IO::Networking::IpEndpoint GetAddressForClient(IO::Networking::IpAddress const& clientAddr) const; }; // Storage object for the list of realms on the server @@ -85,7 +89,19 @@ class RealmList uint32 size() const { return m_realms.size(); } private: void UpdateRealms(bool init); - void UpdateRealm(uint32 realmId, std::string const& name, std::string const& address, std::string const& localAddress, std::string const& localSubnetMask, uint32 port, uint8 icon, RealmFlags realmFlags, uint8 timeZone, AccountTypes allowedSecurityLevel, float population, std::string const& builds); + void UpdateRealm( + uint32 realmId, + std::string const& name, + IO::Networking::IpAddress const& externalIpAddress, + IO::Networking::IpAddress const& localIpAddress, + uint8 localSubnetMaskCidr, + uint16 port, + uint8 icon, + RealmFlags realmFlags, + uint8 timeZone, + AccountTypes allowedSecurityLevel, + float population, + std::string const& builds); void LoadAllowedClients(); private: RealmMap m_realms; // Internal map of realms diff --git a/src/realmd/realmd.conf.dist.in b/src/realmd/realmd.conf.dist.in index c5486187855..66f34e65eba 100644 --- a/src/realmd/realmd.conf.dist.in +++ b/src/realmd/realmd.conf.dist.in @@ -3,7 +3,7 @@ ############################################ [RealmdConf] -ConfVersion=2020010501 +ConfVersion=2024091701 ################################################################################################################### # REALMD SETTINGS @@ -39,15 +39,30 @@ ConfVersion=2020010501 # on different IP addresses using default ports. # DO NOT CHANGE THIS UNLESS YOU _REALLY_ KNOW WHAT YOU'RE DOING # +# TrustedProxyServers +# Description: Enables the parsing of Proxy Protocol v2 for specific IPs. +# You can use this feature when your server is behind a proxy, load balancer, or similar component, +# to retrieve the real IP address of players. +# You need to enable Proxy Protocol v2 on both this server and the proxy/load balancer. +# For example see HaProxy "send-proxy-v2" option. +# Multiple servers can be separated with ',' +# Default: "" - (Disabled, no proxy) +# Example "10.13.37.1,10.13.37.2" - (to allow multiple proxy servers) +# # PidFile # Realmd daemon PID file # Default: "" - do not create PID file # "./realmd.pid" - create PID file (recommended name) # -# LogLevel +# LogLevel.Console # Server console level of logging -# 0 = Minimum; 1 = Error; 2 = Detail; 3 = Full/Debug -# Default: 0 +# 0 = Error; 1 = Minimum; 2 = Basic; 3 = Detail; 4 = Full/Debug +# Default: 2 +# +# LogLevel.File +# Server file level of logging +# 0 = Error; 1 = Minimum; 2 = Basic; 3 = Detail; 4 = Full/Debug +# Default: 2 # # LogTime # Include time in server console output [hh:mm:ss] @@ -82,7 +97,8 @@ ConfVersion=2020010501 # WaitAtStartupError # After startup error report wait or some time before continue (and possible close console window) # -1 (wait until press) -# Default: 0 (not wait) +# 0 (no wait) +# Default: 5 (wait 5 sec) # N (>0, wait N secs) # # MinRealmListDelay @@ -160,15 +176,17 @@ PatchesDir = "./patches" MaxPingTime = 30 RealmServerPort = 3724 BindIP = "0.0.0.0" +TrustedProxyServers = "" PidFile = "" -LogLevel = 0 -LogTime = 0 +LogLevel.Console = 2 +LogLevel.File = 2 +LogTime = 1 LogFile = "Realmd.log" LogTimestamp = 0 LogFileLevel = 0 UseProcessors = 0 ProcessPriority = 1 -WaitAtStartupError = 0 +WaitAtStartupError = 5 # consider fixing your actual problem before changing this value! MinRealmListDelay = 1 RealmsStateUpdateDelay = 20 WrongPass.MaxCount = 0 diff --git a/src/scripts/CMakeLists.txt b/src/scripts/CMakeLists.txt index 0d375e5b22d..aa2d21dbb5a 100644 --- a/src/scripts/CMakeLists.txt +++ b/src/scripts/CMakeLists.txt @@ -327,7 +327,6 @@ include_directories( ${CMAKE_BINARY_DIR} ${CMAKE_BINARY_DIR}/src/shared ${TBB_INCLUDE_DIRS} - ${ACE_INCLUDE_DIR} ${MYSQL_INCLUDE_DIR} ) diff --git a/src/scripts/eastern_kingdoms/eastern_plaguelands/naxxramas/instance_naxxramas.cpp b/src/scripts/eastern_kingdoms/eastern_plaguelands/naxxramas/instance_naxxramas.cpp index 3c3384475ca..e16d4ded14e 100644 --- a/src/scripts/eastern_kingdoms/eastern_plaguelands/naxxramas/instance_naxxramas.cpp +++ b/src/scripts/eastern_kingdoms/eastern_plaguelands/naxxramas/instance_naxxramas.cpp @@ -750,7 +750,7 @@ bool instance_naxxramas::IsEncounterInProgress() const void instance_naxxramas::SetData(uint32 uiType, uint32 uiData) { - ASSERT(this) + ASSERT(this); bool sameStateAsLast = false; if (uiType < MAX_ENCOUNTER) diff --git a/src/shared/ArgparserForServer.cpp b/src/shared/ArgparserForServer.cpp new file mode 100644 index 00000000000..5cf79959274 --- /dev/null +++ b/src/shared/ArgparserForServer.cpp @@ -0,0 +1,118 @@ +#include "ArgparserForServer.h" +#include "SystemConfig.h" + +// Print out the usage string for this program on the console. +void printUsage(char const* thisExecutableName) +{ +#if defined(WIN32) + printf(R"END(Usage: %s [] + -v, --version print version and exist + -c, --config use config_file as configuration file + Running as service functions: + -s run run as service + -s install install service + -s uninstall uninstall service +)END", thisExecutableName); +#else + printf(R"END(Usage: %s [] + -v, --version print version and exist + -c, --config use config_file as configuration file + Running as daemon functions: + -s run run as daemon + -s stop stop daemon +)END", thisExecutableName); +#endif +} + +nonstd::expected ParseServerStartupArguments(int /*MUTABLE*/ argc, char** /*MUTABLE*/ argv) // we are changing the meaning of argc and argv +{ + char const* thisExecutableName = argv[0]; + argv++; // skip the first argument, since it's just the executable name + argc--; + + ServerStartupArguments args + { + ServiceDaemonAction::NotSet, + "" + }; + + while (argc > 0) + { + std::string part = std::string(*argv); + argv++; + argc--; + + if (part == "-h" || part == "--help") + { + printUsage(thisExecutableName); + return nonstd::make_unexpected(EXIT_SUCCESS); // explicitly printing help is okay + } + else if (part == "-v" || part == "--version") + { + printf("Core revision: %s\n", _FULLVERSION); + return nonstd::make_unexpected(EXIT_SUCCESS); // explicitly printing version is okay + } + else if (part == "-c" || part == "--config") + { + if (argc == 0) + { + printf("Error: Config filename required\n"); + printUsage(thisExecutableName); + return nonstd::make_unexpected(EXIT_FAILURE); + } + + std::string nextPart = std::string(*argv); + argv++; + argc--; + + args.configFilePath = nextPart; + } + else if (part == "-s") + { + if (argc == 0) + { + printf("Error: Config filename required\n"); + printUsage(thisExecutableName); + return nonstd::make_unexpected(EXIT_FAILURE); + } + + std::string nextPart = std::string(*argv); + argv++; + argc--; + + if (nextPart == "run") + { + args.inputServiceMode = ServiceDaemonAction::Start; + } +#ifdef WIN32 + else if (nextPart == "install") + { + args.inputServiceMode = ServiceDaemonAction::Install; + } + else if (nextPart == "uninstall") + { + args.inputServiceMode = ServiceDaemonAction::Uninstall; + } +#else + else if (nextPart == "stop") + { + args.inputServiceMode = ServiceDaemonAction::Stop; + } +#endif + else + { + printf("Error: Unknown service mode\n"); + printUsage(thisExecutableName); + return nonstd::make_unexpected(EXIT_FAILURE); + } + } + else + { + printf("Error: Unknown argument provided '%s'\n", part.c_str()); + printUsage(thisExecutableName); + return nonstd::make_unexpected(EXIT_FAILURE); + } + } + + return args; +} diff --git a/src/shared/ArgparserForServer.h b/src/shared/ArgparserForServer.h new file mode 100644 index 00000000000..98ada9fdb92 --- /dev/null +++ b/src/shared/ArgparserForServer.h @@ -0,0 +1,28 @@ +#ifndef MANGOS_ARGPARSERFORSERVER_H +#define MANGOS_ARGPARSERFORSERVER_H + +#include +#include "nonstd/expected.hpp" + +enum class ServiceDaemonAction +{ + NotSet, + Start, +#ifdef WIN32 + Install, + Uninstall, +#else + Stop, +#endif +}; + +struct ServerStartupArguments +{ + ServiceDaemonAction inputServiceMode; // will be empty when no mode was set + std::string configFilePath; +}; + +/// Returns a parsed ServerStartupArguments or exit value +nonstd::expected ParseServerStartupArguments(int argc, char** argv); + +#endif // MANGOS_ARGPARSERFORSERVER_H diff --git a/src/shared/ByteBuffer.cpp b/src/shared/ByteBuffer.cpp index e1b4ac18a9a..f2e1e8174c3 100644 --- a/src/shared/ByteBuffer.cpp +++ b/src/shared/ByteBuffer.cpp @@ -21,6 +21,7 @@ #include "ByteBuffer.h" #include "Log.h" +#include "Errors.h" void ByteBufferException::PrintPosError() const { diff --git a/src/shared/ByteBuffer.h b/src/shared/ByteBuffer.h index 8036e95d1ec..94f90fadbbc 100644 --- a/src/shared/ByteBuffer.h +++ b/src/shared/ByteBuffer.h @@ -22,6 +22,8 @@ #ifndef _BYTEBUFFER_H #define _BYTEBUFFER_H +#include + #include "Common.h" #include "Utilities/ByteConverter.h" @@ -60,7 +62,7 @@ class ByteBuffer } // constructor - ByteBuffer(size_t res): _rpos(0), _wpos(0) + explicit ByteBuffer(size_t res): _rpos(0), _wpos(0) { _storage.reserve(res); } @@ -423,7 +425,13 @@ class ByteBuffer append((uint8 const*)str.c_str(), str.size() + 1); } - void append(std::vector const& src) + void append(std::vector const& src) + { + return append(src.data(), src.size()); + } + + template + void append(std::array const& src) { return append(src.data(), src.size()); } diff --git a/src/shared/CMakeLists.txt b/src/shared/CMakeLists.txt index 9d5dbae2cde..b045a204f29 100644 --- a/src/shared/CMakeLists.txt +++ b/src/shared/CMakeLists.txt @@ -19,8 +19,9 @@ set (shared_SRCS ByteBuffer.h Common.h - DelayExecutor.h + EnumFlag.h Errors.h + Errors.cpp LockedQueue.h Log.h migrations_list.h @@ -32,6 +33,8 @@ set (shared_SRCS ServiceWin32.h SystemConfig.h ThreadPool.h + ThreadSpecificPtr.h + ThreadSpecificPtr.cpp Timer.h Util.h WheatyExceptionReport.h @@ -70,11 +73,11 @@ set (shared_SRCS Database/SQLStorage.h Database/SQLStorageImpl.h Multithreading/Messager.h + nonstd/expected.hpp TimePeriod.h nonstd/optional.hpp ByteBuffer.cpp Common.cpp - DelayExecutor.cpp Log.cpp PosixDaemon.cpp ProgressBar.cpp @@ -100,6 +103,56 @@ set (shared_SRCS Database/SqlPreparedStatement.cpp Database/SQLStorage.cpp Multithreading/Messager.cpp + IO/Context/IoContext.h + IO/Utils.h + IO/Utils.cpp + IO/Utils_Unix.h + IO/ReadableBuffer.h + IO/Context/IoContext_macos.cpp + IO/Context/IoContext_unix.cpp + IO/Context/IoContext_windows.cpp + IO/SystemErrorToString.h + IO/SystemErrorToString.cpp + IO/Networking/Internal.h + IO/Networking/Internal.cpp + IO/Networking/AsyncSocket.h + IO/Networking/AsyncSocket.cpp + IO/Networking/AsyncSocket_posix.cpp + IO/Networking/AsyncSocket_windows.cpp + IO/Networking/AsyncSocketAcceptor.h + IO/Networking/AsyncSocketAcceptor_posix.cpp + IO/Networking/AsyncSocketAcceptor_windows.cpp + IO/Networking/SocketConnector.h + IO/Networking/SocketConnector.cpp + IO/Networking/NetworkError.h + IO/Networking/NetworkError.cpp + IO/Networking/SocketDescriptor.h + IO/Networking/SocketDescriptor.cpp + IO/Networking/Utils.h + IO/Networking/Utils.cpp + IO/Networking/IpAddress.h + IO/Networking/IpAddress.cpp + IO/Networking/DNS.h + IO/Networking/DNS.cpp + IO/Multithreading/CreateThread.h + IO/Multithreading/CreateThread.cpp + IO/Timer/impl/windows/AsyncSystemTimer.cpp + IO/Timer/impl/windows/TimerHandle.cpp + IO/Timer/impl/unix/AsyncSystemTimer.cpp + IO/Timer/impl/unix/TimerHandle.cpp + IO/Timer/AsyncSystemTimer.h + IO/Filesystem/FileSystem.h + IO/Filesystem/FileHandle.h + IO/Filesystem/impl/windows/FileSystem.cpp + IO/Filesystem/impl/windows/FileHandle.cpp + IO/Filesystem/impl/unix/FileSystem.cpp + IO/Filesystem/impl/unix/FileHandle.cpp + ProxyProtocol/ProxyV2Reader.h + ProxyProtocol/ProxyV2Reader.cpp + ArgparserForServer.h + ArgparserForServer.cpp + Memory/ArrayDeleter.h + Memory/NoDeleter.h TimePeriod.cpp ) @@ -113,13 +166,22 @@ if (ENABLE_MAILSENDER) ) endif() -# Exclude Win32 files -if(WIN32) + +if(WIN32) # For window build: Exclude Unix/MacOS files list(REMOVE_ITEM shared_SRCS PosixDaemon.h PosixDaemon.cpp revision.h migrations_list.h + IO/Utils_Unix.h + IO/Context/IoContext_unix.cpp + IO/Context/IoContext_macos.cpp + IO/Networking/AsyncSocket_posix.cpp + IO/Networking/AsyncSocketAcceptor_posix.cpp + IO/Timer/impl/unix/AsyncSystemTimer.cpp + IO/Timer/impl/unix/TimerHandle.cpp + IO/Filesystem/impl/unix/FileSystem.cpp + IO/Filesystem/impl/unix/FileHandle.cpp ) if (NOT MSVC) @@ -128,8 +190,6 @@ if(WIN32) WheatyExceptionReport.h ) endif() - - set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} /D__ACE_INLINE__") else() list(REMOVE_ITEM shared_SRCS WheatyExceptionReport.cpp @@ -138,7 +198,25 @@ else() ServiceWin32.h revision.h migrations_list.h + IO/Context/IoContext_windows.cpp + IO/Networking/AsyncSocket_windows.cpp + IO/Networking/AsyncSocketAcceptor_windows.cpp + IO/Timer/impl/windows/AsyncSystemTimer.cpp + IO/Timer/impl/windows/TimerHandle.cpp + IO/Filesystem/impl/windows/FileSystem.cpp + IO/Filesystem/impl/windows/FileHandle.cpp ) + if(APPLE) + # Remove Linux specific stuff + list(REMOVE_ITEM shared_SRCS + IO/Context/IoContext_unix.cpp + ) + else() + # Remove macOS specific stuff + list(REMOVE_ITEM shared_SRCS + IO/Context/IoContext_macos.cpp + ) + endif() endif() source_group("Util" @@ -172,7 +250,6 @@ target_include_directories(shared PUBLIC ${CMAKE_SOURCE_DIR}/dep/include ${CMAKE_SOURCE_DIR}/src/framework ${CMAKE_BINARY_DIR} - ${ACE_INCLUDE_DIR} ${MYSQL_INCLUDE_DIR} ${OPENSSL_INCLUDE_DIR} ) @@ -198,11 +275,13 @@ endif() if(UNIX) find_package(Threads) - target_link_libraries(shared PUBLIC ${ACE_LIBRARIES} ${CMAKE_THREAD_LIBS_INIT}) + target_link_libraries(shared PUBLIC ${CMAKE_THREAD_LIBS_INIT}) endif(UNIX) -if(MINGW) - target_link_libraries(shared PUBLIC ${ACE_LIBRARIES} -lws2_32) +if(WIN32) + target_link_libraries(shared PUBLIC ws2_32 mswsock) endif() -SET_TARGET_PROPERTIES (shared PROPERTIES FOLDER "Game Libs") +target_link_libraries(shared PRIVATE cpptrace::cpptrace) + +SET_TARGET_PROPERTIES(shared PROPERTIES FOLDER "Game Libs") diff --git a/src/shared/Common.h b/src/shared/Common.h index 74bf032eac5..aed443ab667 100644 --- a/src/shared/Common.h +++ b/src/shared/Common.h @@ -54,14 +54,6 @@ #include "Platform/Define.h" -#if COMPILER == COMPILER_MICROSOFT -# pragma warning(disable:4996) // 'function': was declared deprecated -#ifndef __SHOW_STUPID_WARNINGS__ -# pragma warning(disable:4244) // 'argument' : conversion from 'type1' to 'type2', possible loss of data -# pragma warning(disable:4355) // 'this' : used in base member initializer list -#endif // __SHOW_STUPID_WARNINGS__ -#endif // __GNUC__ - #include "Platform/CompilerDefs.h" #include "Platform/Define.h" #include @@ -73,10 +65,6 @@ #include #include -#if defined(__sun__) -#include // finite() on Solaris -#endif - #include #include #include @@ -90,43 +78,12 @@ typedef std::chrono::system_clock Clock; typedef std::chrono::time_point TimePoint; -#include "Errors.h" -#include "LockedQueue.h" - -#include -#include -#include - -// Old ACE versions (pre-ACE-5.5.4) not have this type (add for allow use at Unix side external old ACE versions) -#if PLATFORM != PLATFORM_WINDOWS -# ifndef ACE_OFF_T -typedef off_t ACE_OFF_T; -# endif -#endif - -#if PLATFORM == PLATFORM_WINDOWS -# if !defined (FD_SETSIZE) -# define FD_SETSIZE 4096 -# endif -# include -# include -#else -# include -# include -# include -# include -# include -# include -#endif - #if COMPILER == COMPILER_MICROSOFT -# include +# include # define I32FMT "%08I32X" # define I64FMT "%016I64X" -//# define snprintf _snprintf -# define vsnprintf _vsnprintf #else @@ -142,24 +99,22 @@ typedef off_t ACE_OFF_T; #endif -#define UI64FMTD ACE_UINT64_FORMAT_SPECIFIER -#define UI64LIT(N) ACE_UINT64_LITERAL(N) +#include -#define SI64FMTD ACE_INT64_FORMAT_SPECIFIER -#define SI64LIT(N) ACE_INT64_LITERAL(N) +#define UI64FMTD "%" PRIu64 +#define SI64FMTD "%" PRId64 -#define SIZEFMTD ACE_SIZE_T_FORMAT_SPECIFIER +#define SIZEFMTD "%zu" +/// Will always return a finite float. If the provided float is infinite it will return 0 inline float finiteAlways(float f) { return std::isfinite(f) ? f : 0.0f; } #define atol(a) strtoul(a, nullptr, 10) -#define STRINGIZE(a) #a - // used for creating values for respawn for example #define MAKE_PAIR64(l, h) uint64( uint32(l) | ( uint64(h) << 32 ) ) -#define PAIR64_HIPART(x) (uint32)((uint64(x) >> 32) & UI64LIT(0x00000000FFFFFFFF)) -#define PAIR64_LOPART(x) (uint32)(uint64(x) & UI64LIT(0x00000000FFFFFFFF)) +#define PAIR64_HIPART(x) (uint32)((uint64(x) >> 32) & uint64(0x00000000FFFFFFFF)) +#define PAIR64_LOPART(x) (uint32)(uint64(x) & uint64(0x00000000FFFFFFFF)) #define MAKE_PAIR32(l, h) uint32( uint16(l) | ( uint32(h) << 16 ) ) #define PAIR32_HIPART(x) (uint16)((uint32(x) >> 16) & 0x0000FFFF) @@ -277,10 +232,6 @@ inline char* mangos_strdup(char const* source) # define M_PI_F float(M_PI) #endif -#ifndef countof -#define countof(array) (sizeof(array) / sizeof((array)[0])) -#endif - -#define BATCHING_INTERVAL 400 +#define BATCHING_INTERVAL 400 // TODO, why is this here? What is "Spell.*Delay" in the config used for then? #endif diff --git a/src/shared/Config/Config.cpp b/src/shared/Config/Config.cpp index 0381d33d952..139500b3c65 100644 --- a/src/shared/Config/Config.cpp +++ b/src/shared/Config/Config.cpp @@ -1,9 +1,4 @@ /* - * Copyright (C) 2005-2011 MaNGOS - * Copyright (C) 2009-2011 MaNGOSZero - * Copyright (C) 2011-2016 Nostalrius - * Copyright (C) 2016-2017 Elysium Project - * * This program is free software; you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation; either version 2 of the License, or @@ -20,98 +15,193 @@ */ #include "Config.h" - #include "Policies/SingletonImp.h" INSTANTIATE_SINGLETON_2(Config, Config::Lock); INSTANTIATE_CLASS_MUTEX(Config, std::shared_timed_mutex); -// Defined here as it must not be exposed to end-users. -bool Config::GetValueHelper(char const* name, ACE_TString& result) +static bool IsLineEndChar(char chr) { - GuardType guard(m_configLock); - - if (!mConf) - return false; - - ACE_TString section_name; - ACE_Configuration_Section_Key section_key; - ACE_Configuration_Section_Key const& root_key = mConf->root_section(); - - int i = 0; - while (mConf->enumerate_sections(root_key, i, section_name) == 0) + switch (chr) { - mConf->open_section(root_key, section_name.c_str(), 0, section_key); - if (mConf->get_string_value(section_key, name, result) == 0) + case '\0': + case '\n': + case '\r': return true; - ++i; } return false; } -Config::Config() -: mConf(nullptr) +bool Config::LoadFromFile(std::string const& filename) { + m_fileName = filename; + return Reload(); } -Config::~Config() +bool Config::Reload() { - delete mConf; -} + FILE* pFile = fopen(m_fileName.c_str(), "r"); + if (!pFile) + return false; -bool Config::SetSource(char const* file) -{ - mFilename = file; + std::lock_guard guard(m_configLock); + m_configMap.clear(); - return Reload(); + char buffer[1024]; + while (fgets(buffer, sizeof(buffer), pFile)) + { + ProcessLine(buffer); + } + + fclose(pFile); + return !m_configMap.empty(); } -bool Config::Reload() +enum LineReadStage { - delete mConf; - mConf = new ACE_Configuration_Heap; + STAGE_FIND_NAME, + STAGE_READ_NAME, + STAGE_FIND_VALUE, + STAGE_READ_VALUE +}; - if (mConf->open() != -1) +bool Config::ProcessLine(char const* line) +{ + LineReadStage stage = STAGE_FIND_NAME; + std::string name; + std::string value; + + int i = 0; + while (!IsLineEndChar(line[i])) { - ACE_Ini_ImpExp config_importer(*mConf); - if (config_importer.import_config(mFilename.c_str()) != -1) - return true; + bool stop = false; + bool quoteFound = false; + + switch (stage) + { + case STAGE_FIND_NAME: + case STAGE_READ_NAME: + { + switch (line[i]) + { + case '#': // comment at unexpected place + case '[': // section + return false; + case ' ': // skip white space + break; + case '=': // name has been read + if (stage == STAGE_FIND_NAME) + return false; + if (name.empty()) + return false; + stage = STAGE_FIND_VALUE; + break; + default: + name += line[i]; + stage = STAGE_READ_NAME; + break; + } + break; + } + case STAGE_FIND_VALUE: + case STAGE_READ_VALUE: + { + switch (line[i]) + { + case '#': // comment can only be at end of line, stop reading + if (!quoteFound) + stop = true; + break; + case '"': // handle quoted text + if (quoteFound) + stop = true; + else + { + quoteFound = true; + stage = STAGE_READ_VALUE; + } + break; + case ' ': // ignore white space until text found + if (stage == STAGE_FIND_VALUE) + break; + default: + value += line[i]; + stage = STAGE_READ_VALUE; + break; + } + break; + } + } + + if (stop) + break; + + ++i; } - delete mConf; - mConf = nullptr; - return false; + if (name.empty() || value.empty()) + return false; + + if (!m_configMap.insert({ name, value }).second) + { + printf("Config setting '%s' appear twice in config! Ignoring second occurrence.\n", name.c_str()); + return false; + } + + return true; +} + +std::string Config::GetFilename() const +{ + return m_fileName; } -std::string Config::GetStringDefault(char const* name, char const* def) +bool Config::GetValueHelper(char const* name, std::string &result) const { - ACE_TString val; + std::shared_lock guard(m_configLock); + + auto itr = m_configMap.find(name); + if (itr == m_configMap.end()) + return false; + + result = itr->second; + return true; +} + +bool Config::IsSet(char const* name) const +{ + std::string val; // we are not interested in the value + return GetValueHelper(name, val); +} + +std::string Config::GetStringDefault(char const* name, char const* def) const +{ + std::string val; return GetValueHelper(name, val) ? val.c_str() : def; } -bool Config::GetBoolDefault(char const* name, bool def) +bool Config::GetBoolDefault(char const* name, bool def) const { - ACE_TString val; + std::string val; if (!GetValueHelper(name, val)) return def; - char const* str = val.c_str(); - return strcmp(str, "true") == 0 || strcmp(str, "TRUE") == 0 || - strcmp(str, "yes") == 0 || strcmp(str, "YES") == 0 || - strcmp(str, "1") == 0; + return val == "true" + || val == "TRUE" + || val == "yes" + || val == "YES" + || val == "1"; } - -int32 Config::GetIntDefault(char const* name, int32 def) +int32 Config::GetIntDefault(char const* name, int32 def) const { - ACE_TString val; - return GetValueHelper(name, val) ? atoi(val.c_str()) : def; + std::string val; + return GetValueHelper(name, val) ? std::stoi(val) : def; } - -float Config::GetFloatDefault(char const* name, float def) +float Config::GetFloatDefault(char const* name, float def) const { - ACE_TString val; - return GetValueHelper(name, val) ? (float)atof(val.c_str()) : def; + std::string val; + return GetValueHelper(name, val) ? std::stof(val) : def; } diff --git a/src/shared/Config/Config.h b/src/shared/Config/Config.h index 88b3dae065f..de90f6cb61f 100644 --- a/src/shared/Config/Config.h +++ b/src/shared/Config/Config.h @@ -1,9 +1,4 @@ /* - * Copyright (C) 2005-2011 MaNGOS - * Copyright (C) 2009-2011 MaNGOSZero - * Copyright (C) 2011-2016 Nostalrius - * Copyright (C) 2016-2017 Elysium Project - * * This program is free software; you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation; either version 2 of the License, or @@ -24,46 +19,41 @@ #include "Common.h" #include "Platform/Define.h" -#include "ace/Configuration_Import_Export.h" #include "Policies/SingletonImp.h" #include "Policies/ThreadingModel.h" #include - -class ACE_Configuration_Heap; +#include class Config { - public: - using Lock = MaNGOS::ClassLevelLockable; + public: + using Lock = MaNGOS::ClassLevelLockable; - Config(); - ~Config(); + bool LoadFromFile(std::string const& filename); + bool Reload(); - bool SetSource(char const* file); - bool Reload(); + bool IsSet(char const* name) const; - std::string GetStringDefault(char const* name, char const* def); - bool GetBoolDefault(char const* name, bool const def = false); - int32 GetIntDefault(char const* name, int32 const def); - float GetFloatDefault(char const* name, float const def); + std::string GetStringDefault(char const* name, char const* def) const; + bool GetBoolDefault(char const* name, bool def) const; + int32 GetIntDefault(char const* name, int32 def) const; + float GetFloatDefault(char const* name, float def) const; - std::string GetFilename() const { return mFilename; } - bool GetValueHelper(char const* name, ACE_TString &result); + std::string GetFilename() const; + bool GetValueHelper(char const* name, std::string& result) const; - private: - friend class MaNGOS::Singleton; + private: + friend class MaNGOS::Singleton; - std::string mFilename; - ACE_Configuration_Heap* mConf; + bool ProcessLine(char const* line); - using LockType = std::mutex; - using GuardType = std::unique_lock; + std::string m_fileName; + std::unordered_map m_configMap; - std::string _filename; - LockType m_configLock; + using LockType = std::shared_timed_mutex; + mutable LockType m_configLock; }; -// Nostalrius : multithreading lock #define sConfig (MaNGOS::Singleton::Instance()) #endif diff --git a/src/shared/Crypto/Authentication/SRP6.cpp b/src/shared/Crypto/Authentication/SRP6.cpp index 9057f80da41..bd6bfa045d8 100644 --- a/src/shared/Crypto/Authentication/SRP6.cpp +++ b/src/shared/Crypto/Authentication/SRP6.cpp @@ -18,6 +18,7 @@ #include "Common.h" #include "Log.h" +#include "Errors.h" #include "Crypto/Hash/HMACSHA1.h" #include "Auth/base32.h" #include "SRP6.h" @@ -65,7 +66,7 @@ void SRP6::CalculateProof(std::string username) M.SetBinary(hashM.data(), hashM.size()); } -bool SRP6::CalculateSessionKey(uint8* lp_A, int l) +bool SRP6::CalculateSessionKey(uint8 const* lp_A, int l) { A.SetBinary(lp_A, l); @@ -152,7 +153,7 @@ void SRP6::HashSessionKey(void) K.SetBinary(vK, 40); } -bool SRP6::Proof(uint8* lp_M, int l) +bool SRP6::Proof(uint8 const* lp_M, int l) { if (!memcmp(M.AsByteArray().data(), lp_M, l)) return false; diff --git a/src/shared/Crypto/Authentication/SRP6.h b/src/shared/Crypto/Authentication/SRP6.h index 8b391247097..4e597e28a35 100644 --- a/src/shared/Crypto/Authentication/SRP6.h +++ b/src/shared/Crypto/Authentication/SRP6.h @@ -51,7 +51,7 @@ class SRP6 \param l the length of client public ephemeral (A) \return true on valid safeguard conditions otherwise false */ - bool CalculateSessionKey(uint8* lp_A, int l); + bool CalculateSessionKey(uint8 const* lp_A, int l); //! calculates the password verifier (v) /*! @@ -77,7 +77,7 @@ class SRP6 \param l the length of client proof (M) \return true if client and server proof matches otherwise false */ - bool Proof(uint8* lp_M, int l); + bool Proof(uint8 const* lp_M, int l); //! compare password verifier (v) /*! diff --git a/src/shared/Crypto/BigNumber.cpp b/src/shared/Crypto/BigNumber.cpp index b0fa8ff17ce..92eda694345 100644 --- a/src/shared/Crypto/BigNumber.cpp +++ b/src/shared/Crypto/BigNumber.cpp @@ -19,6 +19,7 @@ #include "Crypto/BigNumber.h" #include #include +#include BigNumber::BigNumber() { diff --git a/src/shared/Database/DBCFileLoader.cpp b/src/shared/Database/DBCFileLoader.cpp index 7f1b272fccd..9e36afd46ff 100644 --- a/src/shared/Database/DBCFileLoader.cpp +++ b/src/shared/Database/DBCFileLoader.cpp @@ -19,10 +19,11 @@ * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA */ -#include - #include "DBCFileLoader.h" +#include +#include + DBCFileLoader::DBCFileLoader() { data = nullptr; diff --git a/src/shared/Database/Database.cpp b/src/shared/Database/Database.cpp index 6c678d8f265..a86ac6f0661 100644 --- a/src/shared/Database/Database.cpp +++ b/src/shared/Database/Database.cpp @@ -21,9 +21,12 @@ #include "Util.h" #include "Log.h" +#include "Errors.h" #include "DatabaseEnv.h" #include "Config/Config.h" #include "Database/SqlOperations.h" +#include "IO/Multithreading/CreateThread.h" +#include "Database.h" #include #include @@ -224,14 +227,14 @@ bool Database::InitDelayThread(std::string const& infoString) //New delay thread for delay execute SqlConnection* threadConnection = CreateConnection(); - if(!threadConnection->Initialize(infoString.c_str())) + if(!threadConnection->Initialize(infoString)) return false; std::shared_ptr tbody = std::make_shared(this, threadConnection); m_threadsBodies.emplace_back(tbody); - m_delayThreads.emplace_back([tbody](){ + m_delayThreads.emplace_back(IO::Multithreading::CreateThread("DB:" + threadConnection->DatabaseName(), [tbody](){ tbody->run(); - }); + })); return true; } @@ -392,20 +395,25 @@ std::unique_ptr Database::PQueryNamed(char const* format,...) } bool Database::Execute(char const* sql) +{ + return Execute(DbExecMode::CanBeAsync, sql); +} + +bool Database::Execute(DbExecMode mode, char const* sql) { if (!m_pAsyncConn) return false; - SqlTransaction * pTrans = m_TransStorage->get(); - if(pTrans) + SqlTransaction* pTrans = m_currentTransaction.get(); + if (pTrans) { - //add SQL request to trans queue + // add SQL request to trans queue pTrans->DelayExecute(new SqlPlainRequest(sql)); } else { - //if async execution is not available - if(!m_bAllowAsyncTransactions) + // if async execution is not available + if (!m_bAllowAsyncTransactions || mode == DbExecMode::MustBeSync) return DirectExecute(sql); // Simple sql statement @@ -415,6 +423,26 @@ bool Database::Execute(char const* sql) return true; } +bool Database::PExecute(DbExecMode mode, char const* format,...) +{ + if (!format) + return false; + + va_list ap; + char szQuery [MAX_QUERY_LEN]; + va_start(ap, format); + int res = vsnprintf(szQuery, MAX_QUERY_LEN, format, ap); + va_end(ap); + + if(res==-1) + { + sLog.Out(LOG_BASIC, LOG_LVL_ERROR, "SQL Query truncated (and not execute) for format: %s",format); + return false; + } + + return Execute(mode, szQuery); +} + bool Database::PExecute(char const* format,...) { if (!format) @@ -460,21 +488,21 @@ bool Database::BeginTransaction(uint32 serialId) if (!m_pAsyncConn) return false; - MANGOS_ASSERT(!m_TransStorage->get()); // if we will get a nested transaction request - we MUST fix code!!! + MANGOS_ASSERT(!m_currentTransaction.get()); // if we will get a nested transaction request - we MUST fix code!!! + + m_currentTransaction.reset(new SqlTransaction(serialId)); - //initiate transaction on current thread - m_TransStorage->init(serialId); return true; } bool Database::InTransaction() { - return m_TransStorage->get() != nullptr; + return m_currentTransaction.get() != nullptr; } uint32 Database::GetTransactionSerialId() { - if (SqlTransaction *trans = m_TransStorage->get()) + if (SqlTransaction *trans = m_currentTransaction.get()) return trans->GetSerialId(); return 0; @@ -482,21 +510,17 @@ uint32 Database::GetTransactionSerialId() bool Database::CommitTransaction() { - if (!m_pAsyncConn) + // We must have a pending transaction + if (!m_pAsyncConn || !m_currentTransaction.get()) return false; - //check if we have pending transaction - //ASSERT(m_TransStorage->get()); - if (!m_TransStorage->get()) - return false; - - //if async execution is not available - if(!m_bAllowAsyncTransactions) + // if async execution is not available + if (!m_bAllowAsyncTransactions) return CommitTransactionDirect(); - //add SqlTransaction to the async queue + // add SqlTransaction to the async queue // if serial ID > 0, add to the serial delay queue - SqlTransaction *trans = m_TransStorage->detach(); + SqlTransaction* trans = m_currentTransaction.release(); if (trans->GetSerialId() > 0) AddToSerialDelayQueue(trans); else @@ -509,11 +533,12 @@ bool Database::CommitTransactionDirect() if (!m_pAsyncConn) return false; - //check if we have pending transaction - ASSERT (m_TransStorage->get()); + // check if we have pending transaction + if (!m_currentTransaction.get()) + return false; - //directly execute SqlTransaction - SqlTransaction * pTrans = m_TransStorage->detach(); + // directly execute SqlTransaction + SqlTransaction* pTrans = m_currentTransaction.release(); pTrans->Execute(m_pAsyncConn); delete pTrans; @@ -525,11 +550,11 @@ bool Database::RollbackTransaction() if (!m_pAsyncConn) return false; - if(!m_TransStorage->get()) + if (!m_currentTransaction.get()) return false; - //remove scheduled transaction - m_TransStorage->reset(); + // remove scheduled transaction + m_currentTransaction.reset(); return true; } @@ -620,15 +645,15 @@ bool Database::ExecuteStmt(SqlStatementID const& id, SqlStmtParameters* params) if (!m_pAsyncConn) return false; - SqlTransaction * pTrans = m_TransStorage->get(); - if(pTrans) + SqlTransaction* pTrans = m_currentTransaction.get(); + if (pTrans) { - //add SQL request to trans queue + // add SQL request to trans queue pTrans->DelayExecute(new SqlPreparedRequest(id.ID(), params)); } else { - //if async execution is not available + // if async execution is not available if(!m_bAllowAsyncTransactions) return DirectExecuteStmt(id, params); @@ -692,33 +717,3 @@ std::string Database::GetStmtString(int const stmtId) const return std::string(); } - -//HELPER CLASSES AND FUNCTIONS -Database::TransHelper::~TransHelper() -{ - reset(); -} - -SqlTransaction * Database::TransHelper::init(uint32 serialId) -{ - MANGOS_ASSERT(!m_pTrans); //if we will get a nested transaction request - we MUST fix code!!! - m_pTrans = new SqlTransaction(serialId); - - return m_pTrans; -} - -SqlTransaction * Database::TransHelper::detach() -{ - SqlTransaction * pRes = m_pTrans; - m_pTrans = nullptr; - return pRes; -} - -void Database::TransHelper::reset() -{ - if(m_pTrans) - { - delete m_pTrans; - m_pTrans = nullptr; - } -} diff --git a/src/shared/Database/Database.h b/src/shared/Database/Database.h index 75bfa049fae..2493005ccc3 100644 --- a/src/shared/Database/Database.h +++ b/src/shared/Database/Database.h @@ -25,8 +25,8 @@ #include #include "Database/SqlDelayThread.h" #include "Policies/ThreadingModel.h" -#include #include "SqlPreparedStatement.h" +#include "ThreadSpecificPtr.h" #include #include #include @@ -71,6 +71,8 @@ class SqlConnection //methods to work with prepared statements bool ExecuteStmt(int nIndex, SqlStmtParameters const& id); + std::string const& DatabaseName() const { return m_database; } + //SqlConnection object lock /// TODO make SqlConnection a shared_ptr? class Lock @@ -116,6 +118,12 @@ class SqlConnection StmtHolder m_holder; }; +enum class DbExecMode +{ + CanBeAsync, + MustBeSync, +}; + class Database { public: @@ -218,7 +226,10 @@ class Database template bool DelayQueryHolderUnsafe(void (*method)(std::unique_ptr, SqlQueryHolder*, ParamType1), SqlQueryHolder* holder, ParamType1 param1); + /// Unless in Sync mode, the return value just gives you a hint whenever or not the statement was added to be async queue bool Execute(char const* sql); + bool Execute(DbExecMode executionMode, char const* sql); + bool PExecute(DbExecMode executionMode, char const* format,...) ATTR_PRINTF(3,4); bool PExecute(char const* format,...) ATTR_PRINTF(2,3); // Writes SQL commands to a LOG file (see mangosd.conf "LogSQL") @@ -285,30 +296,8 @@ class Database //factory method to create SqlConnection objects virtual SqlConnection* CreateConnection() = 0; - class TransHelper - { - public: - TransHelper() : m_pTrans(nullptr) {} - ~TransHelper(); - - //initializes new SqlTransaction object - SqlTransaction * init(uint32 serialId); - //gets pointer on current transaction object. Returns nullptr if transaction was not initiated - SqlTransaction * get() const { return m_pTrans; } - //detaches SqlTransaction object allocated by init() function - //next call to get() function will return nullptr! - //do not forget to destroy obtained SqlTransaction object! - SqlTransaction * detach(); - //destroyes SqlTransaction allocated by init() function - void reset(); - - private: - SqlTransaction * m_pTrans; - }; - - //per-thread based storage for SqlTransaction object initialization - no locking is required - typedef ACE_TSS DBTransHelperTSS; - Database::DBTransHelperTSS m_TransStorage; + // per-thread based storage for SqlTransaction object initialization - no locking is required + MaNGOS::ThreadSpecificPtr m_currentTransaction; // DB connections diff --git a/src/shared/Database/DatabaseMysql.cpp b/src/shared/Database/DatabaseMysql.cpp index 864ce986a85..ebbdfaec400 100644 --- a/src/shared/Database/DatabaseMysql.cpp +++ b/src/shared/Database/DatabaseMysql.cpp @@ -24,6 +24,7 @@ #include #include #include "Log.h" +#include "Errors.h" #include "Util.h" #include "Policies/SingletonImp.h" #include "Platform/Define.h" diff --git a/src/shared/Database/SqlOperations.cpp b/src/shared/Database/SqlOperations.cpp index dbbcba503fa..60e0b8fcd95 100644 --- a/src/shared/Database/SqlOperations.cpp +++ b/src/shared/Database/SqlOperations.cpp @@ -153,7 +153,7 @@ using SqlResultQueueWorker = ThreadPool::SingleQueue; #endif SqlResultQueue::SqlResultQueue() : - numUnsafeQueries(0), m_callbackThreads(new ThreadPool(6)) + numUnsafeQueries(0), m_callbackThreads(new ThreadPool("SqlResult", 6)) { m_callbackThreads->start(); } diff --git a/src/shared/Database/SqlPreparedStatement.cpp b/src/shared/Database/SqlPreparedStatement.cpp index 86ab85067f0..69e4e36d39c 100644 --- a/src/shared/Database/SqlPreparedStatement.cpp +++ b/src/shared/Database/SqlPreparedStatement.cpp @@ -18,6 +18,7 @@ #include "DatabaseEnv.h" #include "Log.h" +#include "Errors.h" bool SqlStmtFieldData::toBool() const { MANGOS_ASSERT(m_type == FIELD_BOOL); return m_binaryData.boolean; } uint8 SqlStmtFieldData::toUint8() const { MANGOS_ASSERT(m_type == FIELD_UI8); return m_binaryData.ui8; } diff --git a/src/shared/Database/SqlPreparedStatement.h b/src/shared/Database/SqlPreparedStatement.h index de00fcc8e47..de75b6bdc45 100644 --- a/src/shared/Database/SqlPreparedStatement.h +++ b/src/shared/Database/SqlPreparedStatement.h @@ -20,7 +20,6 @@ #define SQLPREPAREDSTATEMENTS_H #include "Common.h" -#include #include #include diff --git a/src/shared/DelayExecutor.cpp b/src/shared/DelayExecutor.cpp deleted file mode 100644 index 95213a704fe..00000000000 --- a/src/shared/DelayExecutor.cpp +++ /dev/null @@ -1,127 +0,0 @@ -/* - * Copyright (C) 2005-2010 MaNGOS - * - * This program is free software; you can redistribute it and/or modify - * it under the terms of the GNU General Public License as published by - * the Free Software Foundation; either version 2 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU General Public License for more details. - * - * You should have received a copy of the GNU General Public License - * along with this program; if not, write to the Free Software - * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA - */ - -#include - -#include "DelayExecutor.h" -#include "Policies/SingletonImp.h" -#include "Policies/ThreadingModel.h" - - -using DelayExecutorLock = MaNGOS::ClassLevelLockable; -INSTANTIATE_SINGLETON_2(DelayExecutor, DelayExecutorLock); -INSTANTIATE_CLASS_MUTEX(DelayExecutor, std::mutex); - -DelayExecutor* DelayExecutor::instance() -{ - return &MaNGOS::Singleton::Instance(); -} - -DelayExecutor::DelayExecutor() - : pre_svc_hook_(0), post_svc_hook_(0), activated_(false) -{ -} - -DelayExecutor::~DelayExecutor() -{ - delete pre_svc_hook_; - delete post_svc_hook_; - deactivate(); -} - -int DelayExecutor::deactivate() -{ - if (!activated()) - return -1; - - activated(false); - queue_.queue()->deactivate(); - wait(); - - return 0; -} - -int DelayExecutor::svc() -{ - if (pre_svc_hook_) - pre_svc_hook_->call(); - - for (;;) - { - ACE_Method_Request* rq = queue_.dequeue(); - - if (!rq) - break; - - rq->call(); - delete rq; - } - - if (post_svc_hook_) - post_svc_hook_->call(); - - return 0; -} - -int DelayExecutor::activate(int num_threads, ACE_Method_Request* pre_svc_hook, ACE_Method_Request* post_svc_hook) -{ - if (activated()) - return -1; - - if (num_threads < 1) - return -1; - - delete pre_svc_hook_; - delete post_svc_hook_; - - pre_svc_hook_ = pre_svc_hook; - post_svc_hook_ = post_svc_hook; - - queue_.queue()->activate(); - - if (ACE_Task_Base::activate(THR_NEW_LWP | THR_JOINABLE | THR_INHERIT_SCHED, num_threads) == -1) - return -1; - - activated(true); - - return true; -} - -int DelayExecutor::execute(ACE_Method_Request* new_req) -{ - if (new_req == nullptr) - return -1; - - if (queue_.enqueue(new_req, (ACE_Time_Value*)&ACE_Time_Value::zero) == -1) - { - delete new_req; - ACE_ERROR_RETURN((LM_ERROR, ACE_TEXT("(%t) %p\n"), ACE_TEXT("DelayExecutor::execute enqueue")), -1); - } - - return 0; -} - -bool DelayExecutor::activated() -{ - return activated_; -} - -void DelayExecutor::activated(bool s) -{ - activated_ = s; -} diff --git a/src/shared/DelayExecutor.h b/src/shared/DelayExecutor.h deleted file mode 100644 index e7d7e9cb70e..00000000000 --- a/src/shared/DelayExecutor.h +++ /dev/null @@ -1,55 +0,0 @@ -/* - * Copyright (C) 2005-2010 MaNGOS - * - * This program is free software; you can redistribute it and/or modify - * it under the terms of the GNU General Public License as published by - * the Free Software Foundation; either version 2 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU General Public License for more details. - * - * You should have received a copy of the GNU General Public License - * along with this program; if not, write to the Free Software - * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA - */ - -#ifndef _M_DELAY_EXECUTOR_H -#define _M_DELAY_EXECUTOR_H - -#include -#include -#include - -class DelayExecutor : protected ACE_Task_Base -{ - public: - - DelayExecutor(); - virtual ~DelayExecutor(); - - static DelayExecutor* instance(); - - int execute(ACE_Method_Request* new_req); - - int activate(int num_threads = 1, ACE_Method_Request* pre_svc_hook = nullptr, ACE_Method_Request* post_svc_hook = nullptr); - - int deactivate(); - - bool activated(); - - virtual int svc(); - - private: - - ACE_Activation_Queue queue_; - ACE_Method_Request* pre_svc_hook_; - ACE_Method_Request* post_svc_hook_; - bool activated_; - - void activated(bool s); -}; - -#endif // _M_DELAY_EXECUTOR_H \ No newline at end of file diff --git a/src/shared/EnumFlag.h b/src/shared/EnumFlag.h new file mode 100644 index 00000000000..04ac75f29a2 --- /dev/null +++ b/src/shared/EnumFlag.h @@ -0,0 +1,150 @@ +/* + * This file is part of the TrinityCore Project. See AUTHORS file for Copyright information + * + * This program is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License as published by the + * Free Software Foundation; either version 2 of the License, or (at your + * option) any later version. + * + * This program is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for + * more details. + * + * You should have received a copy of the GNU General Public License along + * with this program. If not, see . + */ + +#ifndef MANGOS_ENUMFLAG_H +#define MANGOS_ENUMFLAG_H + +#include + +/* Example usage: +// In header +enum class MyCoolFlags +{ + None = 0, + FlagA = (1 << 0), + FlagB = (1 << 0), + FlagC = (1 << 0), + FlagD = (1 << 0), + FlagE = (1 << 0), +}; +DEFINE_ENUM_FLAG(MyCoolFlags); +// In CPP file +void MyCoolFunction(EnumFlag flags) +{ + if (flags.HasFlag(MyCoolFlags::FlagA)) + // Do something +} + */ +template +constexpr bool IsEnumFlag(T) { return false; } + +#define DEFINE_ENUM_FLAG(enumType) constexpr bool IsEnumFlag(enumType) { return true; } + +namespace EnumTraits +{ + template + using IsFlag = std::integral_constant; +} + +template +inline constexpr auto operator&(T left, T right) -> typename std::enable_if::value, T>::type +{ + return static_cast(static_cast>(left) & static_cast>(right)); +} + +template +inline constexpr auto operator&=(T& left, T right) -> typename std::enable_if::value, T&>::type +{ + return left = left & right; +} + +template +inline constexpr auto operator|(T left, T right) -> typename std::enable_if::value, T>::type +{ + return static_cast(static_cast>(left) | static_cast>(right)); +} + +template +inline constexpr auto operator|=(T& left, T right) -> typename std::enable_if::value, T&>::type +{ + return left = left | right; +} + +template +inline constexpr auto operator~(T value) -> typename std::enable_if::value, T>::type +{ + return static_cast(~static_cast>(value)); +} + +template +class EnumFlag +{ + static_assert(EnumTraits::IsFlag::value, "EnumFlag must be used only with enums that are marked as flags by DEFINE_ENUM_FLAG macro"); + +public: + /*implicit*/ constexpr EnumFlag(T value) : _value(value) + { + } + + EnumFlag& operator&=(EnumFlag right) + { + _value &= right._value; + return *this; + } + + constexpr friend EnumFlag operator&(EnumFlag left, EnumFlag right) + { + return left &= right; + } + + EnumFlag& operator|=(EnumFlag right) + { + _value |= right._value; + return *this; + } + + constexpr friend EnumFlag operator|(EnumFlag left, EnumFlag right) + { + return left |= right; + } + + constexpr EnumFlag operator~() const + { + return static_cast(~static_cast>(_value)); + } + + void RemoveFlag(EnumFlag flag) + { + _value &= ~flag._value; + } + + constexpr bool HasFlag(T flag) const + { + using i = std::underlying_type_t; + return static_cast((i)_value & (i)flag) != static_cast(0); + } + + constexpr bool HasAllFlags(T flags) const + { + return (_value & flags) == flags; + } + + constexpr operator T() const + { + return _value; + } + + constexpr std::underlying_type_t AsUnderlyingType() const + { + return static_cast>(_value); + } + +private: + T _value; +}; + +#endif //MANGOS_ENUMFLAG_H diff --git a/src/shared/Errors.cpp b/src/shared/Errors.cpp new file mode 100644 index 00000000000..9f4e1783b75 --- /dev/null +++ b/src/shared/Errors.cpp @@ -0,0 +1,81 @@ +#include "Errors.h" +#include "Log.h" + +#include + +void MaNGOS::Errors::PrintStacktrace() +{ + PrintStacktrace(1, 64); +} + +void MaNGOS::Errors::PrintStacktrace(int skipFrames, int maxFrames) +{ + cpptrace::stacktrace st = cpptrace::generate_trace( + std::size_t(skipFrames) + 1, // we want to skip our own frame + std::size_t(maxFrames) + ); + + bool hasStacktraceInfo = false; + for (size_t i = 0; i < st.frames.size(); ++i) + { + cpptrace::stacktrace_frame const& trace = st.frames[i]; + if (trace.line.has_value()) + { + sLog.Out(LOG_BASIC, LOG_LVL_MINIMAL, + "#%zu [0x%" PRIXPTR "] %s %s:%u", + i, + trace.object_address, + trace.symbol.c_str(), + trace.filename.c_str(), + trace.line.value() + ); + } + else + { + // without line + sLog.Out(LOG_BASIC, LOG_LVL_MINIMAL, + "#%zu [0x%" PRIXPTR "] %s %s", + i, + trace.object_address, + trace.symbol.c_str(), + trace.filename.c_str() + ); + } + + if (!hasStacktraceInfo && trace.line.has_value() && !trace.symbol.empty()) + { + // we assume there are symbols if at least one frame was parsed successfully + hasStacktraceInfo = true; + } + } + + if (!hasStacktraceInfo) + { + // without line +#if PLATFORM == PLATFORM_WINDOWS + sLog.Out(LOG_BASIC, LOG_LVL_ERROR, "Missing debug symbols. Place an up-to-date PDB file next to executable and/or build with debug symbols."); +#else + sLog.Out(LOG_BASIC, LOG_LVL_ERROR, "Missing debug symbols. Please build with debug symbols."); +#endif + } +} + +[[noreturn]] +void MaNGOS::Errors::PrintStacktraceAndThrow(char const* filename, int line, char const* functionName, char const* failedExpression, char const* message) +{ + if (message) + sLog.Out(LOG_BASIC, LOG_LVL_ERROR, "%s:%i Error: Assertion in %s: %s (%s)", filename, line, functionName, failedExpression, message); + else + sLog.Out(LOG_BASIC, LOG_LVL_ERROR, "%s:%i Error: Assertion in %s: %s", filename, line, functionName, failedExpression); + + MaNGOS::Errors::PrintStacktrace(1, 32); + + std::string completeMessage = failedExpression; + if (message) + completeMessage += std::string(" Message: ") + message; + + throw std::runtime_error(completeMessage); + + // Just in case the std::runtime_error was ignored by a debugger, we throw an assert. + assert("MANGOS_ASSERT throw was skipped" && false); +} diff --git a/src/shared/Errors.h b/src/shared/Errors.h index a5ec27b1a00..15a83f26080 100644 --- a/src/shared/Errors.h +++ b/src/shared/Errors.h @@ -22,71 +22,33 @@ #ifndef MANGOSSERVER_ERRORS_H #define MANGOSSERVER_ERRORS_H -#include "Common.h" +namespace MaNGOS { namespace Errors +{ + /// Prints a stack trace to `sLog.Out(LOG_BASIC, LOG_LVL_MINIMAL, ...)` and will then terminate the program + [[noreturn]] + void PrintStacktraceAndThrow(char const* filename, int line, char const* functionName, char const* failedExpression, char const* message = nullptr); -//#ifndef HAVE_CONFIG_H -#define HAVE_ACE_STACK_TRACE_H 1 -//#endif + /// Prints a stack trace to `sLog.Out(LOG_BASIC, LOG_LVL_MINIMAL, ...)` + void PrintStacktrace(); -#ifdef HAVE_ACE_STACK_TRACE_H -#include "ace/Stack_Trace.h" -#include "Log.h" // sLog in only used when HAVE_ACE_STACK_TRACE_H -#endif - -#ifdef HAVE_ACE_STACK_TRACE_H -// Normal assert. -#define WPError(CONDITION) \ -if (!(CONDITION)) \ -{ \ - ACE_Stack_Trace st; \ - sLog.Out(LOG_BASIC, LOG_LVL_MINIMAL, "%s:%i: Error: Assertion in %s failed: %s", \ - __FILE__, __LINE__, __FUNCTION__, STRINGIZE(CONDITION)); \ - sLog.Out(LOG_BASIC, LOG_LVL_MINIMAL, "%s", st.c_str()); \ - throw std::runtime_error(STRINGIZE(CONDITION)); \ - assert(STRINGIZE(CONDITION) && 0); \ -} - -// Just warn. -#define WPWarning(CONDITION) \ -if (!(CONDITION)) \ -{ \ - ACE_Stack_Trace st; \ - sLog.Out(LOG_BASIC, LOG_LVL_ERROR, "%s:%i: Warning: Assertion in %s failed: %s",\ - __FILE__, __LINE__, __FUNCTION__, STRINGIZE(CONDITION)); \ - sLog.Out(LOG_BASIC, LOG_LVL_ERROR, "%s", st.c_str()); \ -} -#else -// Normal assert. -#define WPError(CONDITION) \ -if (!(CONDITION)) \ -{ \ - printf("%s:%i: Error: Assertion in %s failed: %s", \ - __FILE__, __LINE__, __FUNCTION__, STRINGIZE(CONDITION)); \ - assert(STRINGIZE(CONDITION) && 0); \ -} + /// Prints a stack trace to `sLog.Out(LOG_BASIC, LOG_LVL_MINIMAL, ...)` + void PrintStacktrace(int skipFrames, int maxFrames); +}} // namespace MaNGOS::Errors -// Just warn. -#define WPWarning(CONDITION) \ -if (!(CONDITION)) \ -{ \ - printf("%s:%i: Warning: Assertion in %s failed: %s",\ - __FILE__, __LINE__, __FUNCTION__, STRINGIZE(CONDITION)); \ -} -#endif +/// Just a macro that converse a raw string to a quoted "string" +#define MANGOS_ERROR_STRING_ESCAPE(a) #a -#define ASSERT MANGOS_ASSERT -#ifndef MANGOS_ASSERT -#ifdef MANGOS_DEBUG -# define MANGOS_ASSERT WPError -#else -# define MANGOS_ASSERT WPError // Error even if in release mode. -#endif -#endif +/// +/// MANGOS_ASSERT(abc == 2); // will throw if abc is not 2 +/// +#define MANGOS_ASSERT(condition) do { if (!(condition)) { MaNGOS::Errors::PrintStacktraceAndThrow(__FILE__, __LINE__, __FUNCTION__, MANGOS_ERROR_STRING_ESCAPE(condition)); } } while(0) #ifdef MANGOS_DEBUG #define MANGOS_DEBUG_ASSERT(x) MANGOS_ASSERT(x) #else -#define MANGOS_DEBUG_ASSERT(x) +#define MANGOS_DEBUG_ASSERT(x) do {} while(0) #endif +#define ASSERT MANGOS_ASSERT + #endif diff --git a/src/shared/IO/Context/AsyncIoOperation.h b/src/shared/IO/Context/AsyncIoOperation.h new file mode 100644 index 00000000000..060600e9808 --- /dev/null +++ b/src/shared/IO/Context/AsyncIoOperation.h @@ -0,0 +1,61 @@ +#ifndef MANGOS_IO_IOOPERATION_H +#define MANGOS_IO_IOOPERATION_H + +#include +#if defined(WIN32) +#include +#include +#include "../../Errors.h" + +#define WIN32_LEAN_AND_MEAN +#include +#undef WIN32_LEAN_AND_MEAN +#endif + +namespace IO +{ +#if defined(WIN32) + class IocpOperationTask : public OVERLAPPED + { + public: + void InitNew(std::function const& callback) + { + MANGOS_DEBUG_ASSERT(m_callback == nullptr && callback != nullptr); + + Internal = 0; + InternalHigh = 0; + Offset = 0; + OffsetHigh = 0; + hEvent = nullptr; + m_callback = callback; + } + + void Reset() + { + MANGOS_DEBUG_ASSERT(m_callback != nullptr); + + m_callback = nullptr; + } + + void OnComplete(DWORD errorCode) + { + m_callback(errorCode); + } + + std::function m_callback = nullptr; + }; + + typedef IocpOperationTask AsyncIoOperation; +#elif defined(__linux__) || defined(__APPLE__) + class SystemIoEventReceiver + { + public: + /// @param event On Linux: EPOLL flags, will be 0 for immediate events or a bitmask (multiple) of epoll events (e.g. EPOLLIN, EPOLLOUT, ...) + /// @param event On Macos: kqueue filter, will be EVFILT_USER for immediate events one of kqueue filter (e.g. EVFILT_READ, EVFILT_WRITE, ...) + virtual void OnIoEvent(uint32_t event) = 0; + }; + typedef SystemIoEventReceiver AsyncIoOperation; +#endif +} + +#endif //MANGOS_IO_IOOPERATION_H diff --git a/src/shared/IO/Context/IoContext.h b/src/shared/IO/Context/IoContext.h new file mode 100644 index 00000000000..eb74f831c18 --- /dev/null +++ b/src/shared/IO/Context/IoContext.h @@ -0,0 +1,88 @@ +#ifndef MANGOS_IO_IOCONTEXT_H +#define MANGOS_IO_IOCONTEXT_H + +#include +#include "./AsyncIoOperation.h" + +#if defined(WIN32) +#define WIN32_LEAN_AND_MEAN +#include +#undef WIN32_LEAN_AND_MEAN +#endif + +#if defined(__linux__) +#include "../NativeAliases.h" +#include "mutex" +#include "queue" + +enum class IoContextEpollTargetType // this is used in `(epoll_event).data.u32` to decide what to do with it +{ + IoEventReceiverFunction = 0, // ptr will be a pointer to a IO::SystemIoEventReceiver + ContextSwitchRequest = 1, // will only be called by m_contextSwitchRequestPipe.readHead +}; +#endif + +#if defined(__APPLE__) +#include "../NativeAliases.h" +#endif + +namespace IO +{ + class IoContext + { + public: + /// Returns nullptr in case of an error + static std::unique_ptr CreateIoContext(); + ~IoContext(); + IoContext(IoContext const&) = delete; + IoContext& operator=(IoContext const&) = delete; + IoContext(IoContext&&) = delete; + IoContext& operator=(IoContext&&) = delete; + + /// Will run the IO loop until .Shutdown() is called. + /// It is allowed to execute this function from multiple threads at the same time. + /// But try to limit this to a reasonable amount and not have more threads than (V)Cores on your CPU. + void RunUntilShutdown(); + bool IsRunning() const; + + void Shutdown(); + +#if defined(WIN32) + HANDLE GetWindowsCompletionPort() const; +#elif defined(__linux__) + IO::Native::FileHandle GetUnixEpollDescriptor() const { return m_epollDescriptor; } +#elif defined(__APPLE__) + IO::Native::FileHandle GetKqueueDescriptor() const { return m_kqueueDescriptor; } +#endif + +#if defined(WIN32) + void PostOperationForImmediateInvocation(IO::IocpOperationTask* operation); +#elif defined(__linux__) || defined(__APPLE__) + /// On linux with epoll: Invokes {SystemIoEventReceiver::OnIoEvent} in IO thread with parameter 0 + /// On macos with kqueue: Invokes {SystemIoEventReceiver::OnIoEvent} in IO thread with parameter EVFILT_USER + void PostForImmediateInvocation(IO::SystemIoEventReceiver* eventReceiver); +#endif + + private: + volatile bool m_isRunning; + +#if defined(WIN32) + explicit IoContext(HANDLE completionPort); + HANDLE m_completionPort; + volatile uint32_t m_runningThreadsCount; +#elif defined(__linux__) + IO::Native::FileHandle const m_epollDescriptor; + IO::Native::FileHandle const m_contextSwitchNotifyEventFd; + + std::mutex m_contextSwitchQueueLock; + std::queue m_contextSwitchQueue; + + explicit IoContext(IO::Native::FileHandle epollDescriptor, IO::Native::FileHandle contextSwitchEventFd); +#elif defined(__APPLE__) + IO::Native::FileHandle const m_kqueueDescriptor; + explicit IoContext(IO::Native::FileHandle kqueueDescriptor); +#endif + }; +} + +#endif //MANGOS_IO_IOCONTEXT_H diff --git a/src/shared/IO/Context/IoContext_macos.cpp b/src/shared/IO/Context/IoContext_macos.cpp new file mode 100644 index 00000000000..2bda5f9d827 --- /dev/null +++ b/src/shared/IO/Context/IoContext_macos.cpp @@ -0,0 +1,79 @@ +#include +#include +#include "Log.h" +#include "IoContext.h" +#include "IO/SystemErrorToString.h" + +IO::IoContext::IoContext(IO::Native::FileHandle kqueueDescriptor) + : m_kqueueDescriptor(kqueueDescriptor), m_isRunning{true} +{ +} + +IO::IoContext::~IoContext() +{ + ::close(m_kqueueDescriptor); +} + +std::unique_ptr IO::IoContext::CreateIoContext() +{ + // Initialize our main kqueue + int kqueueDescriptor = ::kqueue(); + if (kqueueDescriptor == -1) + { + sLog.Out(LOG_NETWORK, LOG_LVL_ERROR, "CreateIoContext() -> ::kqueue(...) Error: %s", SystemErrorToString(errno).c_str()); + return nullptr; + } + + return std::unique_ptr(new IO::IoContext(kqueueDescriptor)); +} + +void IO::IoContext::RunUntilShutdown() +{ + int const maxEventsPerLoop = 250; + + struct timespec timeout; + timeout.tv_sec = 0; + timeout.tv_nsec = 500 * 1000000; // 500 milliseconds in nanoseconds + + struct kevent events[maxEventsPerLoop]; + + while (m_isRunning) + { + int numEvents = ::kevent(m_kqueueDescriptor, nullptr, 0, events, maxEventsPerLoop, &timeout); + if (numEvents == -1) + { + if (errno != EINTR) // ignore interrupted system call + sLog.Out(LOG_NETWORK, LOG_LVL_ERROR, "RunEventLoop -> ::kevent(...) Error: %s", SystemErrorToString(errno).c_str()); + continue; + } + + for (int i = 0; i < numEvents; i++) + { + struct kevent const& event = events[i]; + ((SystemIoEventReceiver*)(event.udata))->OnIoEvent(event.filter); + } + } +} + +bool IO::IoContext::IsRunning() const +{ + return m_isRunning; +} + +void IO::IoContext::Shutdown() +{ + m_isRunning = false; +} + +void IO::IoContext::PostForImmediateInvocation(IO::SystemIoEventReceiver* eventReceiver) +{ + struct kevent addedEvent{}; + + // Create and trigger a one-time event + EV_SET(&addedEvent, (uint64_t)(eventReceiver), EVFILT_USER, EV_ADD | EV_ONESHOT, NOTE_TRIGGER, 0, eventReceiver); + + if (::kevent(m_kqueueDescriptor, &addedEvent, 1, nullptr, 0, nullptr) == -1) + { + sLog.Out(LOG_NETWORK, LOG_LVL_ERROR, "PostKqueueEventForImmediateExecution() -> ::kevent(...) Error: %s", SystemErrorToString(errno).c_str()); + } +} diff --git a/src/shared/IO/Context/IoContext_unix.cpp b/src/shared/IO/Context/IoContext_unix.cpp new file mode 100644 index 00000000000..d94d6831212 --- /dev/null +++ b/src/shared/IO/Context/IoContext_unix.cpp @@ -0,0 +1,112 @@ +#include +#include +#include +#include +#include "Log.h" +#include "IoContext.h" +#include "IO/SystemErrorToString.h" + +IO::IoContext::IoContext(IO::Native::FileHandle epollDescriptor, IO::Native::FileHandle contextSwitchEventFd) + : m_epollDescriptor(epollDescriptor), m_contextSwitchNotifyEventFd(contextSwitchEventFd), m_isRunning{true} +{ +} + +IO::IoContext::~IoContext() +{ + ::close(m_contextSwitchNotifyEventFd); + ::close(m_epollDescriptor); +} + +std::unique_ptr IO::IoContext::CreateIoContext() +{ + // Initialize our main epoll queue + int const epollSizeHint = 50; // <-- hint, how much initial epoll space we want to have. But in modern kernels this is ignored anyway. + int epollDescriptor = ::epoll_create(epollSizeHint); + if (epollDescriptor == -1) + { + sLog.Out(LOG_NETWORK, LOG_LVL_ERROR, "CreateIoContext() -> ::epoll_create(...) Error: %s", SystemErrorToString(errno).c_str()); + return nullptr; + } + + // Add eventfd, where we can listen to incoming context switch events + uint32_t constexpr initialCounter = 0; + IO::Native::FileHandle contextSwitchEventFd = ::eventfd(initialCounter, 0); + if (contextSwitchEventFd == -1) + { + sLog.Out(LOG_NETWORK, LOG_LVL_ERROR, "CreateIoContext() -> ::eventfd(...) : %s", SystemErrorToString(errno).c_str()); + return nullptr; + } + + // Add our contextSwitchEventFd to the epoll set + struct epoll_event event{}; + event.events = EPOLLIN | EPOLLET; // We are using edge here, since we are just using it as a "once" signalling process system + event.data.u32 = static_cast(IoContextEpollTargetType::ContextSwitchRequest); + if (::epoll_ctl(epollDescriptor, EPOLL_CTL_ADD, contextSwitchEventFd, &event) == -1) + { + sLog.Out(LOG_NETWORK, LOG_LVL_ERROR, "CreateIoContext() -> ::epoll_ctl(...) Error: %s", SystemErrorToString(errno).c_str()); + return nullptr; + } + + return std::unique_ptr(new IO::IoContext(epollDescriptor, contextSwitchEventFd)); +} + +void IO::IoContext::RunUntilShutdown() +{ + int const maxEventsPerLoop = 250; + + struct epoll_event events[maxEventsPerLoop]; + + while (m_isRunning) + { + int numEvents = ::epoll_wait(m_epollDescriptor, events, maxEventsPerLoop, 500); + if (numEvents == -1) + { + if (errno != EINTR) // ignore interrupted system call + sLog.Out(LOG_NETWORK, LOG_LVL_ERROR, "RunEventLoop -> ::epoll_wait(...) Error: %s", SystemErrorToString(errno).c_str()); + continue; + } + + for (int i = 0; i < numEvents; i++) + { + struct epoll_event const& event = events[i]; + + if (event.data.u32 == static_cast(IoContextEpollTargetType::ContextSwitchRequest)) + { + while (!m_contextSwitchQueue.empty()) + { + IO::SystemIoEventReceiver* eventReceiver; + { + std::lock_guard lock(m_contextSwitchQueueLock); + if (m_contextSwitchQueue.empty()) // re-check after we locked the queue if it's really not empty + continue; + eventReceiver = m_contextSwitchQueue.front(); + m_contextSwitchQueue.pop(); + } + eventReceiver->OnIoEvent(0); + } + } + else + { + ((SystemIoEventReceiver*)(event.data.ptr))->OnIoEvent(event.events); + } + } + } +} + +bool IO::IoContext::IsRunning() const +{ + return m_isRunning; +} + +void IO::IoContext::Shutdown() +{ + m_isRunning = false; +} + +void IO::IoContext::PostForImmediateInvocation(IO::SystemIoEventReceiver* eventReceiver) +{ + m_contextSwitchQueueLock.lock(); + m_contextSwitchQueue.push(eventReceiver); + m_contextSwitchQueueLock.unlock(); + ::eventfd_write(m_contextSwitchNotifyEventFd, 1); +} diff --git a/src/shared/IO/Context/IoContext_windows.cpp b/src/shared/IO/Context/IoContext_windows.cpp new file mode 100644 index 00000000000..c835f0800c9 --- /dev/null +++ b/src/shared/IO/Context/IoContext_windows.cpp @@ -0,0 +1,100 @@ +#include "IoContext.h" +#include "Log.h" +#include + +std::unique_ptr IO::IoContext::CreateIoContext() +{ + DWORD constexpr numberOfMaxThreads = 0; // 0 means as many as there are threads on the system + ULONG_PTR completionKey = 0; + HANDLE completionPort = ::CreateIoCompletionPort(INVALID_HANDLE_VALUE, nullptr, completionKey, numberOfMaxThreads); + if (completionPort == nullptr) + { + sLog.Out(LOG_NETWORK, LOG_LVL_ERROR, "::CreateIoCompletionPort(root, ...) Error: %u", GetLastError()); + return nullptr; + } + return std::unique_ptr(new IoContext(completionPort)); +} + +IO::IoContext::IoContext(HANDLE completionPort) : m_isRunning(true), m_completionPort(completionPort), m_runningThreadsCount(0) +{ +} + +IO::IoContext::~IoContext() +{ + if (m_isRunning) + { + Shutdown(); + } +} + +void IO::IoContext::RunUntilShutdown() +{ + ULONG_PTR completionKey = 0; + IocpOperationTask* task = nullptr; + + DWORD bytesWritten = 0; + DWORD constexpr maxWait = INFINITE; + + m_runningThreadsCount++; + while (m_isRunning) + { + bool isOkay = ::GetQueuedCompletionStatus(m_completionPort, &bytesWritten, &completionKey, reinterpret_cast(&task), maxWait); + + if (task) + { + task->OnComplete(isOkay ? 0 : ::GetLastError()); + } + else + { + DWORD errorCode = ::GetLastError(); + if (errorCode != WAIT_TIMEOUT && m_isRunning) + { + sLog.Out(LOG_NETWORK, LOG_LVL_ERROR, "::GetQueuedCompletionStatus(...) Has no TASK!!! Error: %u", errorCode); + } + std::this_thread::yield(); // wait one os tick to try again + } + } + m_runningThreadsCount--; +} + +bool IO::IoContext::IsRunning() const +{ + return m_isRunning; +} + +void IO::IoContext::Shutdown() +{ + if (m_isRunning) + { + uint32_t runningThreadsCountLocal = m_runningThreadsCount; // local count to prevent race condition after `running = false` + m_isRunning = false; + + // We need to wake up the running threads by sending a "null-completion-event" and wait until all thread stopped + for (uint32_t i = 0; i < runningThreadsCountLocal; i++) + { + ::PostQueuedCompletionStatus(m_completionPort, 0, 0, nullptr); + } + while (m_runningThreadsCount > 0) + { + std::this_thread::sleep_for(std::chrono::milliseconds(50)); + } + + ::CloseHandle(m_completionPort); + m_completionPort = nullptr; + } +} + +void IO::IoContext::PostOperationForImmediateInvocation(IO::AsyncIoOperation* task) +{ + ULONG_PTR completionKey = 0; + if (!::PostQueuedCompletionStatus(m_completionPort, 0, completionKey, task)) + { + DWORD error = ::GetLastError(); + sLog.Out(LOG_NETWORK, LOG_LVL_ERROR, "::PostQueuedCompletionStatus(...) Error: %u", error); + } +} + +HANDLE IO::IoContext::GetWindowsCompletionPort() const +{ + return m_completionPort; +} diff --git a/src/shared/IO/Filesystem/FileHandle.h b/src/shared/IO/Filesystem/FileHandle.h new file mode 100644 index 00000000000..13496d49c16 --- /dev/null +++ b/src/shared/IO/Filesystem/FileHandle.h @@ -0,0 +1,68 @@ +#ifndef MANGOS_IO_FILESYSTEM_FILEHANDLE_H +#define MANGOS_IO_FILESYSTEM_FILEHANDLE_H + +#include +#include +#include +#include +#include "IO/NativeAliases.h" + +namespace IO { namespace Filesystem { + + enum class SeekDirection + { + /// Seek from start of the file (allows positive numbers) + Start, + /// Seek from current position (allows negative and positive numbers) + Current, + /// Seek from current position (allows negative numbers) + End, + }; + + class FileHandle + { + public: + ~FileHandle(); + FileHandle(FileHandle const&) = delete; + FileHandle& operator=(FileHandle const&) = delete; + FileHandle(FileHandle&&) = delete; + FileHandle& operator=(FileHandle&&) = delete; + + void Seek(SeekDirection direction, int64_t offset); + + [[nodiscard]] + uint64_t GetTotalFileSize() const; + + [[nodiscard]] + std::chrono::system_clock::time_point GetLastModifyDate() const; + + /// Returns the file path used to open this file + [[nodiscard]] + std::string GetFilePath() const; + + protected: + explicit FileHandle(std::string filePath, IO::Native::FileHandle nativeFileHandle) : m_filePath(std::move(filePath)), m_nativeFileHandle(nativeFileHandle) {}; + std::string m_filePath; + IO::Native::FileHandle m_nativeFileHandle; + }; + + class FileHandleReadonly : public FileHandle + { + public: + explicit FileHandleReadonly(std::string const& filePath, IO::Native::FileHandle nativeFileHandle) : FileHandle(filePath, nativeFileHandle) {}; + FileHandleReadonly(FileHandleReadonly const&) = delete; + FileHandleReadonly& operator=(FileHandleReadonly const&) = delete; + FileHandleReadonly(FileHandleReadonly&&) = delete; + FileHandleReadonly& operator=(FileHandleReadonly&&) = delete; + + /// If return value is smaller than `amountToRead`, the end of file is reached + uint64_t ReadSync(uint8_t* dest, uint64_t amountToRead); + inline uint64_t ReadSync(int8_t* dest, uint64_t amountToRead) { return ReadSync((uint8_t*) dest, amountToRead); }; + + [[nodiscard]] + std::unique_ptr DuplicateFileHandle(); + }; + +}} // namespace IO::Filesystem + +#endif //MANGOS_IO_FILESYSTEM_FILEHANDLE_H diff --git a/src/shared/IO/Filesystem/FileSystem.h b/src/shared/IO/Filesystem/FileSystem.h new file mode 100644 index 00000000000..f0a98220c86 --- /dev/null +++ b/src/shared/IO/Filesystem/FileSystem.h @@ -0,0 +1,34 @@ +#ifndef MANGOS_IO_FILESYSTEM_H +#define MANGOS_IO_FILESYSTEM_H + +#include +#include +#include +#include "EnumFlag.h" +#include "FileHandle.h" + +namespace IO { namespace Filesystem { + enum class OutputFilePath + { + JustFileName, + FullFilePath, + }; + + /// This function will open a file in read shared and binary mode + /// You have to check the resulting pointer for nullptr! + /// If the file does not exists or you dont have permission to open it the ptr will be null + [[nodiscard("You need to use the file handle, otherwise the file will close immediately again")]] + std::unique_ptr TryOpenFileReadonly(std::string const& filePath); + + /// Will convert a partial path like "./data/myCoolFile.txt" to a complete absolute path like "/home/user/data/myCoolFile.txt" + [[nodiscard]] + std::string ToAbsolutePath(std::string const& partialPath); + + /// Returns all files in a folder, non-recursively. + /// if OutputFilePath::JustFileName the path will be based on the folderPath e.g. "myCoolFile.txt" + /// if OutputFilePath::FullFilePath the path will be absolute e.g. "/home/user/data/myCoolFile.txt" + [[nodiscard]] + std::vector GetAllFilesInFolder(std::string const& folderPath, OutputFilePath filePathOption); +}} // namespace IO::Filesystem + +#endif //MANGOS_IO_FILESYSTEM_H diff --git a/src/shared/IO/Filesystem/impl/unix/FileHandle.cpp b/src/shared/IO/Filesystem/impl/unix/FileHandle.cpp new file mode 100644 index 00000000000..c2754bb0ba3 --- /dev/null +++ b/src/shared/IO/Filesystem/impl/unix/FileHandle.cpp @@ -0,0 +1,82 @@ +#include "IO/Filesystem/FileHandle.h" +#include "Log.h" +#include "IO/SystemErrorToString.h" +#include "Errors.h" + +#include +#include + +IO::Filesystem::FileHandle::~FileHandle() +{ + ::close(m_nativeFileHandle); +} + +void IO::Filesystem::FileHandle::Seek(IO::Filesystem::SeekDirection direction, int64_t offset) +{ + int nativeDirection = direction == SeekDirection::Start ? SEEK_SET + : direction == SeekDirection::Current ? SEEK_CUR + : SEEK_END; + + MANGOS_ASSERT(::lseek(m_nativeFileHandle, offset, nativeDirection) != -1); +} + +uint64_t IO::Filesystem::FileHandle::GetTotalFileSize() const +{ + struct stat file_stat; + + if (::fstat(m_nativeFileHandle, &file_stat) == -1) + throw std::runtime_error("GetTotalFileSize -> ::fstat() Failed: " + SystemErrorToString(errno)); + + return file_stat.st_size; +} + +std::chrono::system_clock::time_point IO::Filesystem::FileHandle::GetLastModifyDate() const +{ + struct stat file_stat; + + if (::fstat(m_nativeFileHandle, &file_stat) == -1) + throw std::runtime_error("GetLastModifyDate -> ::fstat() Failed: " + SystemErrorToString(errno)); + + uint64_t unixSecs = file_stat.st_mtime; + + std::chrono::system_clock::time_point result(std::chrono::duration_cast(std::chrono::seconds(unixSecs))); + return result; +} + +std::string IO::Filesystem::FileHandle::GetFilePath() const +{ + return m_filePath; +} + +uint64_t IO::Filesystem::FileHandleReadonly::ReadSync(uint8_t* dest, uint64_t amountToRead) +{ + uint64_t leftToRead = amountToRead; + while (leftToRead > 0) + { + size_t amountToReadThisCycle = (size_t) std::min(leftToRead, ((uint64_t) std::numeric_limits::max()) - 1); + + ssize_t actuallyReadThisCycle = ::read(m_nativeFileHandle, dest, amountToReadThisCycle); + if (actuallyReadThisCycle == -1) + { + sLog.Out(LOG_BASIC, LOG_LVL_ERROR, "ReadSync -> ::read() Error: %s", SystemErrorToString(errno).c_str()); + return 0; + } + leftToRead -= actuallyReadThisCycle; + + if (actuallyReadThisCycle != amountToReadThisCycle) + { + break; + } + } + return amountToRead - leftToRead; +} + +std::unique_ptr IO::Filesystem::FileHandleReadonly::DuplicateFileHandle() +{ + IO::Native::FileHandle newNativeFileHandle = ::dup(m_nativeFileHandle); + + if (newNativeFileHandle == -1) + throw std::runtime_error("DuplicateFileHandle -> ::dup() Failed: " + SystemErrorToString(errno)); + + return std::make_unique(m_filePath, newNativeFileHandle); +} diff --git a/src/shared/IO/Filesystem/impl/unix/FileSystem.cpp b/src/shared/IO/Filesystem/impl/unix/FileSystem.cpp new file mode 100644 index 00000000000..24893b5bf0f --- /dev/null +++ b/src/shared/IO/Filesystem/impl/unix/FileSystem.cpp @@ -0,0 +1,79 @@ +#include "IO/Filesystem/FileSystem.h" +#include "IO/Filesystem/FileHandle.h" +#include "Log.h" +#include "IO/SystemErrorToString.h" +#include +#include +#include +#include + +/// This function will open a file in read shared and binary mode +/// You have to check the resulting pointer for nullptr! +/// If the file does not exists or you dont have permission to open it the ptr will be null +std::unique_ptr IO::Filesystem::TryOpenFileReadonly(std::string const& filePath) +{ + int nativeFlags = O_RDONLY; + IO::Native::FileHandle fileHandle = ::open(filePath.c_str(), nativeFlags); + + if (fileHandle == -1) { + sLog.Out(LOG_BASIC, LOG_LVL_ERROR, "Unable to open file. Error %s on file: %s", SystemErrorToString(errno).c_str(), filePath.c_str()); + return nullptr; + } + + return std::unique_ptr(new FileHandleReadonly(filePath, fileHandle)); +} + +/// Will convert a partial path like "./data/myCoolFile.txt" to a complete absolute path like "/home/user/data/myCoolFile.txt" +std::string IO::Filesystem::ToAbsolutePath(std::string const& partialPath) +{ + // There is no absolute/canonicalize path function in linux. + // There is :realpath, but it requires the file/folder to be present + + if (partialPath.find('/') == 0) + return partialPath; // already absolute + + std::string trimmedPath = (partialPath.find("./") == 0) + ? partialPath.substr(2) // starts with "relative from CWD path" + : partialPath; + + char temp[PATH_MAX]; + if (::getcwd(temp, sizeof(temp)) == nullptr) + throw std::runtime_error("ToAbsolutePath -> ::getcwd(...) Failed: " + SystemErrorToString(errno)); + + // TODO: Find a way to canonicalize_filepath without reimplementing it by my own + // TODO: For own impl: Keep in mind all the stuff that can be included like "../abc/../.\.\//d" or "./abc\./.conf/bash.config" + std::string completePath = std::string(temp) + "/" + trimmedPath; + return completePath; +} + +/// Returns all files in a folder, non-recursively. +/// if OutputFilePath::JustFileName the path will be based on the folderPath e.g. "myCoolFile.txt" +/// if OutputFilePath::FullFilePath the path will be absolute e.g. "/home/user/data/myCoolFile.txt" +std::vector IO::Filesystem::GetAllFilesInFolder(std::string const& folderPath, IO::Filesystem::OutputFilePath filePathOption) +{ + std::vector files; + DIR* dir = opendir(folderPath.c_str()); + if (dir == nullptr) + return files; + + std::string safeFolderPath = (folderPath.rfind('/') == (folderPath.size() - 1)) + ? folderPath // folder already ends with / + : folderPath + "/"; + + struct dirent* entry; + while ((entry = readdir(dir)) != nullptr) + { + std::string fullFilePath = safeFolderPath + entry->d_name; // we need the fullFilePath to check if it's a file via ::stat(...) + struct stat info{}; + if (::stat(fullFilePath.c_str(), &info) == 0 && S_ISREG(info.st_mode)) // S_ISREG means IsRegularFile + { + std::string filePath = filePathOption == OutputFilePath::FullFilePath + ? fullFilePath + : entry->d_name; + + files.emplace_back(filePath); + } + } + ::closedir(dir); + return files; +} diff --git a/src/shared/IO/Filesystem/impl/windows/FileHandle.cpp b/src/shared/IO/Filesystem/impl/windows/FileHandle.cpp new file mode 100644 index 00000000000..2b4e09a8eef --- /dev/null +++ b/src/shared/IO/Filesystem/impl/windows/FileHandle.cpp @@ -0,0 +1,97 @@ +#include "IO/Filesystem/FileHandle.h" +#include "Log.h" +#include "Errors.h" + +IO::Filesystem::FileHandle::~FileHandle() +{ + ::CloseHandle(m_nativeFileHandle); +} + +void IO::Filesystem::FileHandle::Seek(IO::Filesystem::SeekDirection direction, int64_t offset) +{ + DWORD nativeDirection = direction == SeekDirection::Start ? FILE_BEGIN + : direction == SeekDirection::Current ? FILE_CURRENT + : FILE_END; + + LARGE_INTEGER distanceToMove; + distanceToMove.QuadPart = offset; + MANGOS_ASSERT(::SetFilePointerEx(m_nativeFileHandle, distanceToMove, nullptr, nativeDirection)); +} + +uint64_t IO::Filesystem::FileHandle::GetTotalFileSize() const +{ + LARGE_INTEGER fileSize; + bool isOkay = ::GetFileSizeEx(m_nativeFileHandle, &fileSize); + if (!isOkay) + throw std::runtime_error("GetTotalFileSize -> ::GetFileSizeEx() Failed, ErrorCode = " + std::to_string(GetLastError())); + + return fileSize.QuadPart; +} + +std::chrono::system_clock::time_point IO::Filesystem::FileHandle::GetLastModifyDate() const +{ + FILETIME nativeWinFileTime; + bool isOkay = ::GetFileTime(m_nativeFileHandle, nullptr, nullptr, &nativeWinFileTime); + if (!isOkay) + throw std::runtime_error("GetLastModifyDate -> ::GetFileTime() Failed, ErrorCode = " + std::to_string(GetLastError())); + + // Convert FILETIME to ULARGE_INTEGER, so we can use it as uint64_t + ULARGE_INTEGER ulargeInt; + ulargeInt.LowPart = nativeWinFileTime.dwLowDateTime; + ulargeInt.HighPart = nativeWinFileTime.dwHighDateTime; + + // Windows epoch time is January 1, 1601 (UTC) in 100-nanosecond intervals -> Convert to UNIX epoch + uint64_t constexpr FILETIME_to_1970 = 116444736000000000ULL; + ulargeInt.QuadPart -= FILETIME_to_1970; + uint64_t unixNanos = ulargeInt.QuadPart * 100; + + std::chrono::system_clock::time_point result(std::chrono::duration_cast(std::chrono::nanoseconds(unixNanos))); + return result; +} + +std::string IO::Filesystem::FileHandle::GetFilePath() const +{ + return m_filePath; +} + +uint64_t IO::Filesystem::FileHandleReadonly::ReadSync(uint8_t* dest, uint64_t amountToRead) +{ + uint64_t leftToRead = amountToRead; + while (leftToRead > 0) + { + DWORD amountToReadThisCycle = (DWORD) std::min(leftToRead, ((uint64_t) std::numeric_limits::max()) - 1); + + DWORD actuallyReadThisCycle = 0; + bool isOkay = ::ReadFile(m_nativeFileHandle, dest, amountToReadThisCycle, &actuallyReadThisCycle, nullptr); + if (!isOkay) + { + sLog.Out(LOG_BASIC, LOG_LVL_ERROR, "ReadSync -> ::ReadFile() Error: %u", GetLastError()); + return 0; + } + leftToRead -= actuallyReadThisCycle; + + if (actuallyReadThisCycle != amountToReadThisCycle) + { + break; + } + } + return amountToRead - leftToRead; +} + +std::unique_ptr IO::Filesystem::FileHandleReadonly::DuplicateFileHandle() +{ + HANDLE newNativeFileHandle = nullptr; + + bool isOkay = DuplicateHandle( + GetCurrentProcess(), + m_nativeFileHandle, + GetCurrentProcess(), + &newNativeFileHandle, + 0, + FALSE, + DUPLICATE_SAME_ACCESS); + if (!isOkay) + throw std::runtime_error("DuplicateFileHandle -> ::DuplicateHandle() Failed, ErrorCode = " + std::to_string(GetLastError())); + + return std::make_unique(m_filePath, newNativeFileHandle); +} diff --git a/src/shared/IO/Filesystem/impl/windows/FileSystem.cpp b/src/shared/IO/Filesystem/impl/windows/FileSystem.cpp new file mode 100644 index 00000000000..7866173048f --- /dev/null +++ b/src/shared/IO/Filesystem/impl/windows/FileSystem.cpp @@ -0,0 +1,73 @@ +#include "IO/Filesystem/FileSystem.h" +#include "IO/Filesystem/FileHandle.h" +#include "Log.h" + +#define WIN32_LEAN_AND_MEAN +#include +#undef WIN32_LEAN_AND_MEAN + +/// This function will open a file in read shared and binary mode +/// You have to check the resulting pointer for nullptr! +/// If the file does not exists or you dont have permission to open it the ptr will be null +std::unique_ptr IO::Filesystem::TryOpenFileReadonly(std::string const& filePath) +{ + HANDLE nativeFileHandle = CreateFileA( + filePath.c_str(), + GENERIC_READ, + FILE_SHARE_READ, // Share mode: allow other processes to read + nullptr, // Security attributes + OPEN_EXISTING, // Open exising file. Fail if it does not exist + FILE_ATTRIBUTE_NORMAL, // Normal open, without any special flags + nullptr // Template file handle (would be used when creating a new file and copy the attributes) + ); + + if (nativeFileHandle == INVALID_HANDLE_VALUE) { + sLog.Out(LOG_BASIC, LOG_LVL_ERROR, "Unable to open file. Error %d on file: %s", GetLastError(), filePath.c_str()); + return nullptr; + } + + return std::unique_ptr(new FileHandleReadonly(filePath, nativeFileHandle)); +} + +/// Will convert a partial path like "./data/myCoolFile.txt" to a complete absolute path like "/home/user/data/myCoolFile.txt" +std::string IO::Filesystem::ToAbsolutePath(std::string const& partialPath) +{ + char fullPath[MAX_PATH]; + if (::GetFullPathNameA(partialPath.c_str(), MAX_PATH, fullPath, nullptr) == 0) + throw std::runtime_error("ToAbsolutePath -> ::GetFullPathNameA() failed " + std::to_string(GetLastError())); + + return std::string(fullPath); +} + +/// Returns all files in a folder, non-recursively. +/// if OutputFilePath::JustFileName the path will be based on the folderPath e.g. "myCoolFile.txt" +/// if OutputFilePath::FullFilePath the path will be absolute e.g. "/home/user/data/myCoolFile.txt" +std::vector IO::Filesystem::GetAllFilesInFolder(std::string const& folderPath, IO::Filesystem::OutputFilePath filePathOption) +{ + std::vector files; + + // Construct the search path + std::string searchPath = folderPath + "\\*"; + + WIN32_FIND_DATAA fileData; + HANDLE hFind = ::FindFirstFileA(searchPath.c_str(), &fileData); + if (hFind != INVALID_HANDLE_VALUE) + { + do + { + if (!(fileData.dwFileAttributes & FILE_ATTRIBUTE_DIRECTORY)) // Somehow Windows detects .MPQ files as FILE_ATTRIBUTE_ARCHIVE? Thus, we can't use FILE_ATTRIBUTE_NORMAL + { + std::string filePath = filePathOption == OutputFilePath::FullFilePath + ? IO::Filesystem::ToAbsolutePath(folderPath + "\\" + fileData.cFileName) + : fileData.cFileName; + + files.emplace_back(filePath); + } + + } while (::FindNextFileA(hFind, &fileData)); + + ::FindClose(hFind); + } + + return files; +} diff --git a/src/shared/IO/Multithreading/CreateThread.cpp b/src/shared/IO/Multithreading/CreateThread.cpp new file mode 100644 index 00000000000..8b8721e5cc0 --- /dev/null +++ b/src/shared/IO/Multithreading/CreateThread.cpp @@ -0,0 +1,67 @@ +#include "CreateThread.h" + +#if defined(WIN32) +#define WIN32_LEAN_AND_MEAN +#include +#undef WIN32_LEAN_AND_MEAN +#elif defined(__linux__) || defined(__APPLE__) +#include +#endif + +std::unique_ptr IO::Multithreading::CreateThreadPtr(std::string const& name, std::function entryFunction) +{ + return std::make_unique([name, entryFunction = std::move(entryFunction)]() + { + IO::Multithreading::RenameCurrentThread(name); + entryFunction(); + }); +} + +std::thread IO::Multithreading::CreateThread(std::string const& name, std::function entryFunction) +{ + return std::thread([name, entryFunction = std::move(entryFunction)]() + { + IO::Multithreading::RenameCurrentThread(name); + entryFunction(); + }); +} + +void IO::Multithreading::RenameCurrentThread(std::string const& name) +{ +#if defined(WIN32) + // Windows part taken from https://stackoverflow.com/a/23899379 + // SetThreadDescription is only supported on >= Win10, that's why we are using this approach + + const DWORD MS_VC_EXCEPTION=0x406D1388; +#pragma pack(push,8) + typedef struct tagTHREADNAME_INFO + { + DWORD dwType; // Must be 0x1000. + LPCSTR szName; // Pointer to name (in user addr space). + DWORD dwThreadID; // Thread ID (-1=caller thread). + DWORD dwFlags; // Reserved for future use, must be zero. + } THREADNAME_INFO; +#pragma pack(pop) + + THREADNAME_INFO info; + info.dwType = 0x1000; + info.szName = name.c_str(); + info.dwThreadID = GetCurrentThreadId(); + info.dwFlags = 0; + + __try + { + RaiseException( MS_VC_EXCEPTION, 0, sizeof(info)/sizeof(ULONG_PTR), (ULONG_PTR*)&info ); + } + __except(EXCEPTION_EXECUTE_HANDLER) + { + } +#elif defined(__linux__) + ::pthread_setname_np(pthread_self(), name.c_str()); +#elif defined(__APPLE__) + ::pthread_setname_np(name.c_str()); +#else + // It's not too serisous if we cant rename a thread + #warning "IO::Multithreading::_renameThisThread not supported on your platform" +#endif +} diff --git a/src/shared/IO/Multithreading/CreateThread.h b/src/shared/IO/Multithreading/CreateThread.h new file mode 100644 index 00000000000..6c0aae1aaac --- /dev/null +++ b/src/shared/IO/Multithreading/CreateThread.h @@ -0,0 +1,25 @@ +#ifndef MANGOS_IO_MULTITHREADING_CREATETHREAD_H +#define MANGOS_IO_MULTITHREADING_CREATETHREAD_H + +#include +#include +#include +#include + +namespace IO { namespace Multithreading { + /// Creates a new system thread that has a name attached to it. + /// Names are super useful when monitoring the utilization of each thread. + [[nodiscard("Use this return value to at least .join() or .detach() the thread")]] + std::unique_ptr CreateThreadPtr(std::string const& name, std::function entryFunction); + + /// Creates a new system thread that has a name attached to it. + /// Names are super useful when monitoring the utilization of each thread. + [[nodiscard("Use this return value to at least .join() or .detach() the thread")]] + std::thread CreateThread(std::string const& name, std::function entryFunction); + + /// Will rename your current thread. + /// Names are super useful when monitoring the utilization of each thread. + void RenameCurrentThread(std::string const& name); +}} // namespace IO::Multithreading + +#endif //MANGOS_IO_MULTITHREADING_CREATETHREAD_H diff --git a/src/shared/IO/NativeAliases.h b/src/shared/IO/NativeAliases.h new file mode 100644 index 00000000000..9b7b1d8fe5b --- /dev/null +++ b/src/shared/IO/NativeAliases.h @@ -0,0 +1,23 @@ +#ifndef MANGOS_IO_NATIVEALIASES_H +#define MANGOS_IO_NATIVEALIASES_H + +#if defined(WIN32) +#define WIN32_LEAN_AND_MEAN +#include +#undef WIN32_LEAN_AND_MEAN + +namespace IO { namespace Native { + typedef UINT_PTR SocketHandle; + typedef HANDLE FileHandle; +}} // namespace IO::_Native + +#elif defined(__linux__) || defined(__APPLE__) + +namespace IO { namespace Native { + typedef int SocketHandle; + typedef int FileHandle; +}} // namespace IO::Native + +#endif + +#endif //MANGOS_IO_NATIVEALIASES_H diff --git a/src/shared/IO/Networking/AsyncSocket.cpp b/src/shared/IO/Networking/AsyncSocket.cpp new file mode 100644 index 00000000000..cda81f6e2ef --- /dev/null +++ b/src/shared/IO/Networking/AsyncSocket.cpp @@ -0,0 +1,65 @@ +#include "AsyncSocket.h" +#include "Log.h" +#include "Errors.h" +#include "IpAddress.h" + +IO::Networking::AsyncSocket::AsyncSocket(AsyncSocket&& other) noexcept : + m_ctx(other.m_ctx), + m_descriptor(std::move(other.m_descriptor)), + m_contextCallback(std::move(other.m_contextCallback)), + m_readCallback(std::move(other.m_readCallback)), + m_writeCallback(std::move(other.m_writeCallback)), + m_writeSrc(std::move(other.m_writeSrc)), +#if defined(WIN32) + m_currentContextTask(std::move(other.m_currentContextTask)), + m_currentWriteTask(std::move(other.m_currentWriteTask)), + m_currentReadTask(std::move(other.m_currentReadTask)) +#elif defined(__linux__) || defined(__APPLE__) + m_readDstBuffer(other.m_readDstBuffer), + m_readDstBufferSize(other.m_readDstBufferSize), + m_readDstBufferBytesLeft(other.m_readDstBufferBytesLeft), + m_writeSrcAlreadyTransferred(other.m_writeSrcAlreadyTransferred) +#endif +{ + MANGOS_DEBUG_ASSERT(!(m_atomicState.load(std::memory_order_relaxed) & SocketStateFlags::IS_INITIALIZED)); // dont allow std::move() if memory address is fixed + + m_atomicState.exchange(other.m_atomicState); + other.m_atomicState.exchange(SocketStateFlags::WAS_MOVED_NO_DTOR); +} + +IO::Networking::AsyncSocket::~AsyncSocket() +{ + int state = m_atomicState.load(std::memory_order_relaxed); + if (state & SocketStateFlags::WAS_MOVED_NO_DTOR) + return; // Ignore destructor + + sLog.Out(LOG_NETWORK, LOG_LVL_DEBUG, "[%s] Destructor called ~AsyncSocket: No references left", GetRemoteIpString().c_str()); + m_descriptor.CloseSocket(); // <-- This will actually close the socket and release the file descriptor to the kernel + + // Logic behind these checks: + // If the destructor is called, there should be no more std::shared_ptr<> references to this object + // Every Read(...) or Write(...) should use `shared_from_this()` if this is not the case, one of the following checks will fail + MANGOS_ASSERT(!(state & SocketStateFlags::CONTEXT_PRESENT)); + MANGOS_ASSERT(!(state & SocketStateFlags::WRITE_PRESENT)); + MANGOS_ASSERT(!(state & SocketStateFlags::READ_PRESENT)); +} + +bool IO::Networking::AsyncSocket::IsClosing() const +{ + bool isClosing = m_atomicState.load(std::memory_order_relaxed) & SHUTDOWN_PENDING; + return isClosing; +} + +void IO::Networking::AsyncSocket::ReadSkip(std::size_t skipSize, std::function const& callback) +{ + std::shared_ptr> skipBuffer(new std::vector()); + skipBuffer->resize(skipSize); + Read((char*)skipBuffer->data(), skipSize, [skipBuffer, callback](IO::NetworkError const& error, size_t) + { + // KEEP skipBuffer in scope! + // Do not remove skipBuffer before Read() is done, since we are transferring into it via async IO + // and since we are using a raw pointer, the Task has no knowledge about the lifetime of the std::vector + skipBuffer->clear(); + callback(error); + }); +} diff --git a/src/shared/IO/Networking/AsyncSocket.h b/src/shared/IO/Networking/AsyncSocket.h new file mode 100644 index 00000000000..f907f4b01c1 --- /dev/null +++ b/src/shared/IO/Networking/AsyncSocket.h @@ -0,0 +1,148 @@ +#ifndef MANGOS_IO_NETWORKING_ASYNCSOCKET_H +#define MANGOS_IO_NETWORKING_ASYNCSOCKET_H + +#include "IO/Context/IoContext.h" +#include "IO/Networking/NetworkError.h" +#include "IO/Networking/SocketDescriptor.h" +#include "IO/Context/AsyncIoOperation.h" +#include "IO/NativeAliases.h" +#include "IO/ReadableBuffer.h" + +#include "Policies/ObjectConstructorTraits.h" + +#include +#include +#include +#include +#include +#include +#include + +namespace IO { namespace Networking { + + /// You have to keep the instance alive while a transaction is running. Use a shared pointer or something on every callback! + class AsyncSocket final : public MaNGOS::Policies::NoCopyButAllowMove +#if defined(__linux__) || defined(__APPLE__) + , public IO::SystemIoEventReceiver +#endif + { + public: // public functions + /// Dont forget to call `InitializeAndFixateMemoryLocation` before making a transfer + explicit AsyncSocket(IO::IoContext* ctx, SocketDescriptor socketDescriptor); + AsyncSocket(AsyncSocket&& other) noexcept; + ~AsyncSocket(); // this destructor will throw if there is a pending transaction + + /// You have to execute this function, before any transfer or context switch operation is performed. + /// Instead of having trampoline pointers which might hinder performance, + /// the memory address of this socket class is directly registered in ::epoll/::kqueue. + /// This makes it a bit harder to use this class, but improves performance a little bit. + /// This cannot be undone. You have to destruct the socket. + [[nodiscard("Check the returning error code and close the socket in case of an error")]] + IO::NetworkError InitializeAndFixateMemoryLocation(); + + /// If set to true, it asks the OS to disables "Nagle's algorithm" for this socket. + IO::NetworkError SetNativeSocketOption_NoDelay(bool doNoDelay); + /// Provides a hint to the OS how large the outgoing send buffer should be + IO::NetworkError SetNativeSocketOption_SystemOutgoingSendBuffer(int bytes); + + /// Keep in mind to keep the source buffer in scope of the callback, otherwise random memory might get overwritten + /// Most of the time this is not an issue, since you want to process the incoming buffer + /// You have to keep the pointer alive until the callback is called. Use [self = shared_from_this()] + void Read(char* target, size_t size, std::function const& callback); + void ReadSome(char* target, size_t maxSize, std::function const& callback); + void ReadSkip(size_t skipSize, std::function const& callback); + + // Development decision `char*` vs `ReadableBuffer`: + // Read() takes a `char*` while Write() uses a smart pointer to prevent accidental use-after-free. + // It's easy to forget to keep the buffer in scope. (without this precaution, Write() could also take a `char*`) + + /// Warning: Using this function will NOT copy the buffer content, dont overwrite it unless callback is triggered! + /// You have to keep the pointer alive until the callback is called. Use [self = shared_from_this()] + void Write(IO::ReadableBuffer const& source, std::function const& callback); + + /// The callback is invoked in the IO thread + /// Useful for computational expensive operations (e.g. packing and encryption), that should be avoided in the main loop + /// You have to keep the pointer alive until the callback is called. Use [self = shared_from_this()] + void EnterIoContext(std::function const& callback); + + void CloseSocket(); + bool IsClosing() const; + + IO::Networking::IpEndpoint const& GetRemoteEndpoint() const + { + return m_descriptor.GetRemoteEndpoint(); + } + + /// IPv4 Format: 255.255.255.255 + /// IPv6 Format: [FFFF:FFFF:FFFF:FFFF:FFFF:FFFF:FFFF:FFFF] + std::string const& GetRemoteIpString() const + { + return GetRemoteEndpoint().ip.ToString(); + } + + private: +#if defined(__linux__) || defined(__APPLE__) + void OnIoEvent(uint32_t event) final; // invoked by IoContext +#endif + + protected: // socket specific variables + IO::IoContext* m_ctx; + IO::Networking::SocketDescriptor m_descriptor; + + private: // internal stuff + + // We are doing all this atomic stuff, just so we don't have to std::mutex everything + enum SocketStateFlags : int + { + SHUTDOWN_PENDING = (1 << 0), // stop all new transaction requests. There should never be a new _INUSE when this is present + IGNORE_TRANSFERS = (1 << 1), // Like SHUTDOWN_PENDING but the event receivers `PerformNonBlockingRead` and `PerformNonBlockingWrite` will ignore the event + + // PRESENT: Stuff that is present and set + // PENDING: Stuff that is currently being used, if you want to close the socket you must spinwait it. + + WRITE_PENDING_SET = (1 << 2), + WRITE_PRESENT = (1 << 3), + WRITE_PENDING_LOAD = (1 << 4), + + READ_PENDING_SET = (1 << 5), + READ_PRESENT = (1 << 6), + READ_PENDING_LOAD = (1 << 7), + + CONTEXT_PENDING_SET = (1 << 8), + CONTEXT_PRESENT = (1 << 9), + CONTEXT_PENDING_LOAD = (1 << 10), + + WAS_MOVED_NO_DTOR = (1 << 11), + IS_INITIALIZED = (1 << 12), // memory address of this socket is fixed because its referenced in kernel code (epoll/kqueue). Cant std::move this object. + }; + std::atomic m_atomicState{0}; + + std::function m_contextCallback = nullptr; // <-- Callback into user code + + // Read = the target buffer to write the network stream to + std::function m_readCallback = nullptr; // <-- Callback into user code + + // Write = the source buffer from where to read to be able to write to the network stream + std::function m_writeCallback = nullptr; // <-- Callback into user code + IO::ReadableBuffer m_writeSrc{}; + +#if defined(WIN32) + IocpOperationTask m_currentContextTask; // <-- Internal tasks / callback to internal networking code + IocpOperationTask m_currentWriteTask; // <-- Internal tasks / callback to internal networking code + IocpOperationTask m_currentReadTask; // <-- Internal tasks / callback to internal networking code +#elif defined(__linux__) || defined(__APPLE__) + void PerformNonBlockingRead(); + void PerformNonBlockingWrite(); + void PerformContextSwitch(); + void StopPendingTransactionsAndForceClose(); + + char* m_readDstBuffer = nullptr; // this ptr will move along the buffer as its filled, check m_readDstBufferBytesLeft for space + size_t m_readDstBufferSize = 0; // will be 0 if ReadSome(), otherwise the original buffer size + size_t m_readDstBufferBytesLeft = 0; + + size_t m_writeSrcAlreadyTransferred = 0; +#endif + }; +}} // namespace IO::Networking + +#endif //MANGOS_IO_NETWORKING_ASYNCSOCKET_H diff --git a/src/shared/IO/Networking/AsyncSocketAcceptor.h b/src/shared/IO/Networking/AsyncSocketAcceptor.h new file mode 100644 index 00000000000..e7ef6f95db3 --- /dev/null +++ b/src/shared/IO/Networking/AsyncSocketAcceptor.h @@ -0,0 +1,59 @@ +#ifndef MANGOS_IO_NETWORKING_ASYNCSOCKETACCEPTOR_H +#define MANGOS_IO_NETWORKING_ASYNCSOCKETACCEPTOR_H + +#include "IO/NativeAliases.h" +#include "IO/Context/IoContext.h" +#include "IO/Context/AsyncIoOperation.h" +#include "IO/Networking/NetworkError.h" +#include "IO/Networking/SocketDescriptor.h" + +#include "Policies/ObjectConstructorTraits.h" + +#include "nonstd/expected.hpp" + +#include +#include +#include +#include + +namespace IO { namespace Networking { + + class AsyncSocket; + + /// A class that allows you to bind to a TCP address and accept connections + class AsyncSocketAcceptor : MaNGOS::Policies::NoCopyButAllowMove + #if defined(__linux__) || defined(__APPLE__) + , IO::SystemIoEventReceiver + #endif + { + public: + ~AsyncSocketAcceptor(); // this destructor will throw if ClosePortAndStopAcceptingNewConnections was not called + + static std::unique_ptr CreateAndBindServer(IO::IoContext* ctx, std::string const& bindIpStr, uint16_t port); + void ClosePortAndStopAcceptingNewConnections(); + + /// Automatically accepts all incoming connections until this Acceptor is StoppedAndClosed + void AutoAcceptSocketsUntilClose(std::function const& onNewSocket); + #if defined(__linux__) || defined(__APPLE__) + void OnIoEvent(uint32_t event); // used for ::accept + #endif + + private: + explicit AsyncSocketAcceptor(IO::IoContext* ctx, IO::Native::SocketHandle acceptorNativeSocket); + + IO::Native::SocketHandle m_acceptorNativeSocket; + IO::IoContext* m_ctx; + bool m_wasClosed; + + #if defined(WIN32) + void AcceptOne(std::function acceptResult)> const& afterAccept); + IocpOperationTask m_currentAcceptTask; + #elif defined(__linux__) || defined(__APPLE__) + std::function m_onNewSocketCallback; + void OnNewClientToAcceptAvailable(); // a new socket on ::accept() is available + #endif + + }; +}} // namespace IO::Networking + +#endif // MANGOS_IO_NETWORKING_ASYNCSOCKETACCEPTOR_H diff --git a/src/shared/IO/Networking/AsyncSocketAcceptor_posix.cpp b/src/shared/IO/Networking/AsyncSocketAcceptor_posix.cpp new file mode 100644 index 00000000000..a51fcb791c5 --- /dev/null +++ b/src/shared/IO/Networking/AsyncSocketAcceptor_posix.cpp @@ -0,0 +1,158 @@ +#include "./AsyncSocketAcceptor.h" + +#include "IO/Utils_Unix.h" +#include "IO/Context/IoContext.h" +#include "IO/Networking/SocketDescriptor.h" +#include "IO/Networking/IpAddress.h" +#include "IO/Networking/Internal.h" +#include "IO/SystemErrorToString.h" + +#include "Log.h" +#include "Errors.h" + +#if defined(__linux__) || defined(__APPLE__) +#include +#include +#include +#endif + +#if defined(__linux__) +#include +#elif defined(__APPLE__) +#include +#include +#endif + +#include +#include +#include +#include +#include + +IO::Networking::AsyncSocketAcceptor::AsyncSocketAcceptor(IO::IoContext* ctx, IO::Native::SocketHandle acceptorNativeSocket) + : m_ctx(ctx), m_acceptorNativeSocket(acceptorNativeSocket), m_wasClosed(false), m_onNewSocketCallback{nullptr} {} + +std::unique_ptr IO::Networking::AsyncSocketAcceptor::CreateAndBindServer(IO::IoContext* ctx, std::string const& bindIpStr, uint16_t port) +{ + nonstd::optional maybeBindIp = IpAddress::TryParseFromString(bindIpStr); + if (!maybeBindIp.has_value()) + { + sLog.Out(LOG_NETWORK, LOG_LVL_ERROR, "Fail to parse IP '%s'", bindIpStr.c_str()); + return nullptr; + } + + IO::Native::SocketHandle listenNativeSocket = ::socket(AF_INET, SOCK_STREAM, 0); + if (listenNativeSocket == -1) + { + sLog.Out(LOG_NETWORK, LOG_LVL_ERROR, "CreateAndBindServer -> ::socket(listen) Error: %s", SystemErrorToString(errno).c_str()); + return nullptr; + } + + int optionValue = 1; // Unix/macOS is a bit weird. When someone else is still connected to our socket, but we restart the server, the server cannot bind again. + if (::setsockopt(listenNativeSocket, SOL_SOCKET, SO_REUSEADDR, &optionValue, sizeof(optionValue)) != 0) + { + sLog.Out(LOG_NETWORK, LOG_LVL_ERROR, "CreateAndBindServer -> ::setsockopt(reuseaddr) Error: %s", SystemErrorToString(errno).c_str()); + return nullptr; + } + + sockaddr_in m_serverAddress{}; + m_serverAddress.sin_family = AF_INET; + IO::Networking::Internal::inet_pton(maybeBindIp.value(), &(m_serverAddress.sin_addr)); + m_serverAddress.sin_port = htons(port); + if (::bind(listenNativeSocket, (struct sockaddr*)&m_serverAddress, sizeof(m_serverAddress)) != 0) + { + sLog.Out(LOG_NETWORK, LOG_LVL_ERROR, "CreateAndBindServer -> ::bind(listen) Error: %s", SystemErrorToString(errno).c_str()); + return nullptr; + } + + int const acceptBacklogCount = 50; // the number of connection requests that are queued in the kernel until this process calls "accept" + if (::listen(listenNativeSocket, acceptBacklogCount) != 0) + { + sLog.Out(LOG_NETWORK, LOG_LVL_ERROR, "CreateAndBindServer -> ::listen(...) Error: %s", SystemErrorToString(errno).c_str()); + return nullptr; + } + + auto x = new AsyncSocketAcceptor(ctx, listenNativeSocket); + auto server = std::unique_ptr(x); + + // Add server socket to event queue (needed for ::accept(..)) +#if defined(__linux__) + ::epoll_event event; + event.events = EPOLLIN | EPOLLERR; // Don't use EdgeTrigger here, since if multiple ::accepts are in the queue, we one get notified for one + event.data.u32 = static_cast(IoContextEpollTargetType::IoEventReceiverFunction); + static_assert(std::is_base_of::element_type>::value, "Must implement SystemIoEventReceiver interface!"); + event.data.ptr = server.get(); // note static_assert above + if (::epoll_ctl(ctx->GetUnixEpollDescriptor(), EPOLL_CTL_ADD, listenNativeSocket, &event) == -1) + { + sLog.Out(LOG_NETWORK, LOG_LVL_ERROR, "CreateAndBindServer -> ::epoll_ctl(...) Error: %s", SystemErrorToString(errno).c_str()); + return nullptr; + } +#elif defined(__APPLE__) + struct ::kevent addedEvents{}; + static_assert(std::is_base_of::element_type>::value, "Must implement SystemIoEventReceiver interface!"); + EV_SET(&addedEvents, listenNativeSocket, EVFILT_READ, EV_ADD | EV_ERROR, 0, 0, server.get()); + if (::kevent(ctx->GetKqueueDescriptor(), &addedEvents, 1, nullptr, 0, nullptr) == -1) + { + sLog.Out(LOG_NETWORK, LOG_LVL_ERROR, "CreateAndBindServer -> ::kevent(...) Error: %s", SystemErrorToString(errno).c_str()); + return nullptr; + } +#else + #error "Unsupported" +#endif + + return server; +} + +IO::Networking::AsyncSocketAcceptor::~AsyncSocketAcceptor() +{ + MANGOS_ASSERT(m_wasClosed); +} + +void IO::Networking::AsyncSocketAcceptor::ClosePortAndStopAcceptingNewConnections() +{ + m_wasClosed = true; + + ::close(m_acceptorNativeSocket); +} + +void IO::Networking::AsyncSocketAcceptor::AutoAcceptSocketsUntilClose(std::function const& onNewSocket) +{ + m_onNewSocketCallback = onNewSocket; +} + +void IO::Networking::AsyncSocketAcceptor::OnNewClientToAcceptAvailable() +{ + if (m_onNewSocketCallback == nullptr) + return; // ignore, do not ::accept() yet + + ::sockaddr_in peerAddress; + socklen_t client_len = sizeof(peerAddress); + int nativePeerSocket = ::accept(m_acceptorNativeSocket, (struct sockaddr*)&peerAddress, &client_len); + if (nativePeerSocket == -1) + { + sLog.Out(LOG_NETWORK, LOG_LVL_ERROR, "RunEventLoop -> ::accept(...) Error: %s", SystemErrorToString(errno).c_str()); + return; + } + + IO::Networking::IpAddress peerIpAddress = IO::Networking::Internal::inet_ntop(&(peerAddress.sin_addr)); + uint16_t peerPort = ntohs(peerAddress.sin_port); + + IO::Networking::IpEndpoint peerEndpoint(peerIpAddress, peerPort); + IO::Networking::SocketDescriptor socketDescriptor{nativePeerSocket, peerEndpoint}; + + IO::NetworkError err = IO::Utils::SetFdStatusFlag(nativePeerSocket, O_NONBLOCK); + if (err) + { + sLog.Out(LOG_NETWORK, LOG_LVL_ERROR, "OnNewClientToAcceptAvailable -> ::IO::Utils::SetFdStatusFlag(...) Error: %s", err.ToString().c_str()); + socketDescriptor.CloseSocket(); + return; + } + + m_onNewSocketCallback(std::move(socketDescriptor)); +} + +void IO::Networking::AsyncSocketAcceptor::OnIoEvent(uint32_t event) +{ + // The only event we can receive is ::accept() + OnNewClientToAcceptAvailable(); +} diff --git a/src/shared/IO/Networking/AsyncSocketAcceptor_windows.cpp b/src/shared/IO/Networking/AsyncSocketAcceptor_windows.cpp new file mode 100644 index 00000000000..7754c32c35e --- /dev/null +++ b/src/shared/IO/Networking/AsyncSocketAcceptor_windows.cpp @@ -0,0 +1,180 @@ +#ifndef MANGOS_IO_NETWORKING_WIN32_AsyncSocketAcceptor_H +#define MANGOS_IO_NETWORKING_WIN32_AsyncSocketAcceptor_H + +#include "IO/Networking/AsyncSocketAcceptor.h" +#include "IO/Networking/Internal.h" +#include "IO/Context/AsyncIoOperation.h" +#include "IO/Networking/IpAddress.h" +#include "IO/Networking/SocketDescriptor.h" +#include "IO/SystemErrorToString.h" +#include "Log.h" +#include "Memory/ArrayDeleter.h" + +#include +#include +#include +#include + +#include +#include // TODO: Currently just needed for ::AcceptEx, maybe its better if we get this func-ptr at runtime, just like Microsoft recommends it + +IO::Networking::AsyncSocketAcceptor::AsyncSocketAcceptor(IO::IoContext* ctx, IO::Native::SocketHandle acceptorNativeSocket) + : m_ctx(ctx), m_acceptorNativeSocket(acceptorNativeSocket), m_wasClosed(false) {} + +IO::Networking::AsyncSocketAcceptor::~AsyncSocketAcceptor() +{ + MANGOS_ASSERT(m_wasClosed); +} + +void IO::Networking::AsyncSocketAcceptor::ClosePortAndStopAcceptingNewConnections() +{ + m_wasClosed = true; + + ::closesocket(m_acceptorNativeSocket); + + while (m_currentAcceptTask.m_callback != nullptr) + std::this_thread::yield(); // I think it's fine to "busy" wait here instead of adding complex .wait() logic to the hot `StartAcceptOperation` code. +} + +std::unique_ptr IO::Networking::AsyncSocketAcceptor::CreateAndBindServer(IO::IoContext* ctx, std::string const& bindIpStr, uint16_t port) +{ + nonstd::optional maybeBindIp = IpAddress::TryParseFromString(bindIpStr); + if (!maybeBindIp.has_value()) + { + sLog.Out(LOG_NETWORK, LOG_LVL_ERROR, "Fail to parse IP '%s'", bindIpStr.c_str()); + return nullptr; + } + + int errorCode; + + // TODO check if WSA was already initialized + // TODO if fatal error, close socket and CleanupWSA (if reference counter == 0) + + WSADATA wsaData; + errorCode = ::WSAStartup(MAKEWORD(2, 2), &wsaData); + if (errorCode != 0) + { + sLog.Out(LOG_NETWORK, LOG_LVL_ERROR, "::WSAStartup(...) Error: %u", errorCode); + return nullptr; + } + + // Create an IPv4 TCP server where other clients can connect to + SOCKET listenNativeSocket = ::socket(AF_INET, SOCK_STREAM, IPPROTO_TCP); + if (listenNativeSocket == INVALID_SOCKET) + { + sLog.Out(LOG_NETWORK, LOG_LVL_ERROR, "::socket(listen, ...) Error: %u", ::WSAGetLastError()); + return nullptr; + } + + // Attach our listener socket to our completion port + if (::CreateIoCompletionPort((HANDLE) listenNativeSocket, ctx->GetWindowsCompletionPort(), (u_long) 0, 0) != ctx->GetWindowsCompletionPort()) + { + sLog.Out(LOG_NETWORK, LOG_LVL_ERROR, "::CreateIoCompletionPort(listen, ...) Error: %u", ::WSAGetLastError()); + return nullptr; + } + + sockaddr_in m_serverAddress{}; + m_serverAddress.sin_family = AF_INET; + IO::Networking::Internal::inet_pton(maybeBindIp.value(), &(m_serverAddress.sin_addr)); + m_serverAddress.sin_port = ::htons(port); + if (::bind(listenNativeSocket, (struct sockaddr*)(&m_serverAddress), sizeof(m_serverAddress)) != 0) + { + sLog.Out(LOG_NETWORK, LOG_LVL_ERROR, "::bind(...) Error: %s", SystemErrorToString(::WSAGetLastError()).c_str()); + return nullptr; + } + + int const acceptBacklogCount = 50; // the number of connection requests that are queued in the kernel until this process calls "accept" + if (::listen(listenNativeSocket, acceptBacklogCount) != 0) + { + sLog.Out(LOG_NETWORK, LOG_LVL_ERROR, "::listen(...) Error: %s", SystemErrorToString(::WSAGetLastError()).c_str()); + return nullptr; + } + + auto server = std::unique_ptr(new AsyncSocketAcceptor(ctx, listenNativeSocket)); + return server; +} + +void IO::Networking::AsyncSocketAcceptor::AutoAcceptSocketsUntilClose(std::function const& onNewSocket) +{ + AcceptOne([onNewSocket, this](nonstd::expected acceptResult) + { + if (!acceptResult.has_value()) + { + if (!m_wasClosed) + sLog.Out(LOG_NETWORK, LOG_LVL_ERROR, "AcceptOne Error: %s", acceptResult.error().ToString().c_str()); + return; + } + + onNewSocket(std::move(acceptResult.value())); + AutoAcceptSocketsUntilClose(onNewSocket); + }); +} + +void IO::Networking::AsyncSocketAcceptor::AcceptOne(std::function)> const& afterAccept) +{ + SOCKET nativePeerSocket = ::socket(AF_INET, SOCK_STREAM, IPPROTO_TCP); // <-- will be filled when callback is called + if (nativePeerSocket == INVALID_SOCKET) + { + afterAccept(nonstd::make_unexpected(IO::NetworkError(NetworkError::ErrorType::InternalError, ::WSAGetLastError()))); + return; + } + + // AcceptEx stores the data in an internal Windows format. It can be read with `::GetAcceptExSockaddrs()`. + std::shared_ptr addrBuffer = std::shared_ptr(new uint8_t[((sizeof(sockaddr_in) + 16) * 2)], MaNGOS::Memory::array_deleter()); + m_currentAcceptTask.InitNew([nativePeerSocket, this, addrBuffer, afterAccept](DWORD errorCode) + { + auto localAfterAccept = std::move(afterAccept); + auto localAddrBuffer = std::move(addrBuffer); + auto localNativePeerSocket = std::move(nativePeerSocket); + this->m_currentAcceptTask.Reset(); // after we reset, the captured variables are no longer valid + + if (!errorCode) + { // No error, everything is fine + SOCKADDR_IN* localAddr = nullptr; + SOCKADDR_IN* remoteAddr = nullptr; + int localAddrLen = 0; + int remoteAddrLen = 0; + + ::GetAcceptExSockaddrs(localAddrBuffer.get(), 0, + sizeof(SOCKADDR_IN) + 16, sizeof(SOCKADDR_IN) + 16, + (SOCKADDR**)&localAddr, &localAddrLen, + (SOCKADDR**)&remoteAddr, &remoteAddrLen); + + IpAddress ip = IO::Networking::Internal::inet_ntop(&(remoteAddr->sin_addr)); + uint16_t port = ::ntohs(remoteAddr->sin_port); + + localAfterAccept(IO::Networking::SocketDescriptor(localNativePeerSocket, { ip, port })); + return; + } + + if (errorCode == ERROR_OPERATION_ABORTED && m_wasClosed) + { // ignore "aborted" error when we are in a closing state + localAfterAccept(nonstd::make_unexpected(IO::NetworkError(NetworkError::ErrorType::SocketClosed, 0))); + return; + } + + // we got a real error + localAfterAccept(nonstd::make_unexpected(IO::NetworkError(NetworkError::ErrorType::InternalError, ::WSAGetLastError()))); + }); + + DWORD bytesWritten = 0; + + bool booleanOkay = ::AcceptEx(m_acceptorNativeSocket, nativePeerSocket, + addrBuffer.get(), + 0, + sizeof (sockaddr_in) + 16, sizeof (sockaddr_in) + 16, + &bytesWritten, &m_currentAcceptTask + ); + if (!booleanOkay) + { + int lastError = ::WSAGetLastError(); + if (lastError != WSA_IO_PENDING) // Pending means that this task was queued (which is what we want) + { + m_currentAcceptTask.Reset(); + sLog.Out(LOG_NETWORK, LOG_LVL_ERROR, "::AcceptEx(...) Error: %u", lastError); + return; + } + } +} + +#endif //MANGOS_IO_NETWORKING_WIN32_AsyncSocketAcceptor_H diff --git a/src/shared/IO/Networking/AsyncSocket_posix.cpp b/src/shared/IO/Networking/AsyncSocket_posix.cpp new file mode 100644 index 00000000000..ef5cfa3b18d --- /dev/null +++ b/src/shared/IO/Networking/AsyncSocket_posix.cpp @@ -0,0 +1,596 @@ +#include "AsyncSocket.h" +#include "IO/SystemErrorToString.h" +#include "Errors.h" +#include "Log.h" + +#if defined(__linux__) +#include +#elif defined(__APPLE__) +#include +#include +#endif +#include +#include +#include + +IO::Networking::AsyncSocket::AsyncSocket(IO::IoContext* ctx, IO::Networking::SocketDescriptor socketDescriptor) + : m_ctx(ctx), m_descriptor(std::move(socketDescriptor)) +{ +} + +IO::NetworkError IO::Networking::AsyncSocket::InitializeAndFixateMemoryLocation() +{ + int state = m_atomicState.fetch_or(SocketStateFlags::IS_INITIALIZED); + MANGOS_ASSERT(!(state & SocketStateFlags::IS_INITIALIZED)); // can be only performed once + +#if defined(__linux__) + ::epoll_event event; + event.events = EPOLLIN | EPOLLOUT | EPOLLERR | EPOLLRDHUP | EPOLLET; + event.data.ptr = this; + if (::epoll_ctl(m_ctx->GetUnixEpollDescriptor(), EPOLL_CTL_ADD, m_descriptor.GetNativeSocket(), &event) == -1) + { + sLog.Out(LOG_NETWORK, LOG_LVL_ERROR, "OnNewClientToAcceptAvailable -> ::epoll_ctl(...) Error: %s", SystemErrorToString(errno).c_str()); + return IO::NetworkError(NetworkError::ErrorType::InternalError, errno); + } +#elif defined(__APPLE__) + struct kevent addedEvents[2]; + + // EVFILT_READ (epoll: EPOLLIN) + EV_SET(&addedEvents[0], m_descriptor.GetNativeSocket(), EVFILT_READ, EV_ADD | EV_CLEAR, 0, 0, this); + + // EVFILT_WRITE (epoll: EPOLLOUT) + EV_SET(&addedEvents[1], m_descriptor.GetNativeSocket(), EVFILT_WRITE, EV_ADD | EV_CLEAR, 0, 0, this); + + if (::kevent(m_ctx->GetKqueueDescriptor(), addedEvents, 2, nullptr, 0, nullptr) == -1) + { + sLog.Out(LOG_NETWORK, LOG_LVL_ERROR, "AsyncSocket -> ::kevent(...) Error: %s", SystemErrorToString(errno).c_str()); + return IO::NetworkError(NetworkError::ErrorType::InternalError, errno); + } +#else + #error "Unsupported" +#endif + return IO::NetworkError(NetworkError::ErrorType::NoError); +} + +void IO::Networking::AsyncSocket::Read(char* target, std::size_t size, std::function const& callback) +{ + int state = m_atomicState.fetch_or(SocketStateFlags::READ_PENDING_SET); + MANGOS_DEBUG_ASSERT(state & SocketStateFlags::IS_INITIALIZED); + + if (state & SocketStateFlags::READ_PENDING_SET) + { + callback(IO::NetworkError(IO::NetworkError::ErrorType::OnlyOneTransferPerDirectionAllowed), 0); + return; + } + + if (state & SocketStateFlags::SHUTDOWN_PENDING) + { + m_atomicState.fetch_and(~SocketStateFlags::READ_PENDING_SET); + callback(IO::NetworkError(IO::NetworkError::ErrorType::SocketClosed), 0); + return; + } + + if (state & SocketStateFlags::READ_PRESENT) + { + m_atomicState.fetch_and(~SocketStateFlags::READ_PENDING_SET); + callback(IO::NetworkError(IO::NetworkError::ErrorType::OnlyOneTransferPerDirectionAllowed), 0); + return; + } + + if (size == 0) + { + sLog.Out(LOG_NETWORK, LOG_LVL_ERROR, "ERROR: Tried to IO::Networking::AsyncSocket::Read(...) with size 0"); + m_atomicState.fetch_and(~SocketStateFlags::READ_PENDING_SET); + callback(IO::NetworkError(IO::NetworkError::ErrorType::NoError), 0); // technically not an error, we are just done with the buffer + return; + } + + // Check if there is already something for us buffered in memory + ssize_t alreadyRead = ::recv(m_descriptor.GetNativeSocket(), target, size, 0); + if (alreadyRead == 0) + { + m_atomicState.fetch_and(~SocketStateFlags::READ_PENDING_SET); + sLog.Out(LOG_NETWORK, LOG_LVL_DEBUG, "[%s] Read(...) -> ::recv() returned 0, which means the socket is half-closed.", GetRemoteIpString().c_str()); + StopPendingTransactionsAndForceClose(); + callback(IO::NetworkError(IO::NetworkError::ErrorType::SocketClosed), 0); + return; + } + else if (alreadyRead == -1) + { + if (errno != EWOULDBLOCK) + { + m_atomicState.fetch_and(~SocketStateFlags::READ_PENDING_SET); + callback(IO::NetworkError(IO::NetworkError::ErrorType::InternalError, errno), 0); + return; + } + alreadyRead = 0; // Would block, so we need to queue it for later + } + if (alreadyRead == size) + { // oh wow, we already have the whole buffer, no need to set up variables + m_atomicState.fetch_and(~SocketStateFlags::READ_PENDING_SET); + callback(IO::NetworkError(IO::NetworkError::ErrorType::NoError), alreadyRead); + return; + } + + m_readDstBuffer = target + alreadyRead; + m_readDstBufferSize = size; + m_readDstBufferBytesLeft = size - alreadyRead; + m_readCallback = callback; + + m_atomicState.fetch_xor(SocketStateFlags::READ_PRESENT | SocketStateFlags::READ_PENDING_SET); // set PRESENT and unset PENDING_SET +} + +void IO::Networking::AsyncSocket::ReadSome(char* target, std::size_t size, std::function const& callback) +{ + int state = m_atomicState.fetch_or(SocketStateFlags::READ_PENDING_SET); + MANGOS_DEBUG_ASSERT(state & SocketStateFlags::IS_INITIALIZED); + + if (state & SocketStateFlags::READ_PENDING_SET) + { + callback(IO::NetworkError(IO::NetworkError::ErrorType::OnlyOneTransferPerDirectionAllowed), 0); + return; + } + + if (state & SocketStateFlags::SHUTDOWN_PENDING) + { + m_atomicState.fetch_and(~SocketStateFlags::READ_PENDING_SET); + callback(IO::NetworkError(IO::NetworkError::ErrorType::SocketClosed), 0); + return; + } + + if (state & SocketStateFlags::READ_PRESENT) + { + m_atomicState.fetch_and(~SocketStateFlags::READ_PENDING_SET); + callback(IO::NetworkError(IO::NetworkError::ErrorType::OnlyOneTransferPerDirectionAllowed), 0); + return; + } + + if (size == 0) + { + sLog.Out(LOG_NETWORK, LOG_LVL_ERROR, "ERROR: Tried to IO::Networking::AsyncSocket::Read(...) with size 0"); + m_atomicState.fetch_and(~SocketStateFlags::READ_PENDING_SET); + callback(IO::NetworkError(IO::NetworkError::ErrorType::NoError), 0); // technically not an error, we are just done with the buffer + return; + } + + // Check if there is already something for us buffered in memory + ssize_t alreadyRead = ::recv(m_descriptor.GetNativeSocket(), target, size, 0); + if (alreadyRead == 0) + { + m_atomicState.fetch_and(~SocketStateFlags::READ_PENDING_SET); + sLog.Out(LOG_NETWORK, LOG_LVL_DEBUG, "[%s] Read(...) -> ::recv() returned 0, which means the socket is half-closed.", GetRemoteIpString().c_str()); + StopPendingTransactionsAndForceClose(); + callback(IO::NetworkError(IO::NetworkError::ErrorType::SocketClosed), 0); + return; + } + else if (alreadyRead == -1) + { + if (errno != EWOULDBLOCK) + { + m_atomicState.fetch_and(~SocketStateFlags::READ_PENDING_SET); + callback(IO::NetworkError(IO::NetworkError::ErrorType::InternalError, errno), 0); + return; + } + alreadyRead = 0; // Would block, so we need to queue it for later + } + if (alreadyRead != 0) + { // oh wow, we already have "some" buffer, no need to set up variables + m_atomicState.fetch_and(~SocketStateFlags::READ_PENDING_SET); + callback(IO::NetworkError(IO::NetworkError::ErrorType::NoError), alreadyRead); + return; + } + + m_readDstBuffer = target + alreadyRead; + m_readDstBufferSize = 0; // 0 means ReadSome(), only one ::recv call + m_readDstBufferBytesLeft = size - alreadyRead; + m_readCallback = callback; + + m_atomicState.fetch_xor(SocketStateFlags::READ_PRESENT | SocketStateFlags::READ_PENDING_SET); // set PRESENT and unset PENDING_SET +} + +/// Warning: Using this function will NOT copy the buffer, dont overwrite it unless callback is triggered! +/// (but a reference to the smart_ptr will be held throughout the transfer, so you dont need to) +void IO::Networking::AsyncSocket::Write(IO::ReadableBuffer const& source, std::function const& callback) +{ + int state = m_atomicState.fetch_or(SocketStateFlags::WRITE_PENDING_SET); + MANGOS_DEBUG_ASSERT(state & SocketStateFlags::IS_INITIALIZED); + + if (state & SocketStateFlags::WRITE_PENDING_SET) + { + callback(IO::NetworkError(IO::NetworkError::ErrorType::OnlyOneTransferPerDirectionAllowed)); + return; + } + + if (state & SocketStateFlags::SHUTDOWN_PENDING) + { + m_atomicState.fetch_and(~SocketStateFlags::WRITE_PENDING_SET); + callback(IO::NetworkError(IO::NetworkError::ErrorType::SocketClosed)); + return; + } + + if (state & SocketStateFlags::WRITE_PRESENT) + { + m_atomicState.fetch_and(~SocketStateFlags::WRITE_PENDING_SET); + callback(IO::NetworkError(IO::NetworkError::ErrorType::OnlyOneTransferPerDirectionAllowed)); + return; + } + + if (source.GetSize() == 0) + { + sLog.Out(LOG_NETWORK, LOG_LVL_ERROR, "ERROR: Tried to IO::Networking::AsyncSocket::Write(...) with size 0"); + m_atomicState.fetch_and(~SocketStateFlags::WRITE_PENDING_SET); + callback(IO::NetworkError(IO::NetworkError::ErrorType::NoError)); // technically not an error, we are just done with the buffer + return; + } + + // Check if we can write into memory buffer + ssize_t alreadySent = ::send(m_descriptor.GetNativeSocket(), source.GetPtr(), source.GetSize(), 0); + if (alreadySent == -1) + { + if (errno != EWOULDBLOCK) + { + m_atomicState.fetch_and(~SocketStateFlags::WRITE_PENDING_SET); + callback(IO::NetworkError(IO::NetworkError::ErrorType::InternalError, errno)); + return; + } + alreadySent = 0; // Would block, so we need to queue it for later + } + if (alreadySent == source.GetSize()) + { // oh wow, we already sent the whole buffer, no need to set up variables + m_atomicState.fetch_and(~SocketStateFlags::WRITE_PENDING_SET); + callback(IO::NetworkError(IO::NetworkError::ErrorType::NoError)); + return; + } + + m_writeSrc = source; + m_writeSrcAlreadyTransferred = alreadySent; + m_writeCallback = callback; + + m_atomicState.fetch_xor(SocketStateFlags::WRITE_PRESENT | SocketStateFlags::WRITE_PENDING_SET); // set PRESENT and unset PENDING_SET +} + +void IO::Networking::AsyncSocket::CloseSocket() +{ + // This function will not actually >close< the socket, since this would release the file descriptor and could cause race conditions, + // if the same descriptor id is reused by another socket. + + // set SHUTDOWN_PENDING flag, and check if there was already a previous one + if (m_atomicState.fetch_or(SocketStateFlags::SHUTDOWN_PENDING) & SocketStateFlags::SHUTDOWN_PENDING) + return; // there was already a ::shutdown() + + sLog.Out(LOG_NETWORK, LOG_LVL_DEBUG, "[%s] CloseSocket(): Disconnect request", GetRemoteIpString().c_str()); + ::shutdown(m_descriptor.GetNativeSocket(), SHUT_RDWR); +} + +void IO::Networking::AsyncSocket::PerformNonBlockingRead() +{ + int state = m_atomicState.fetch_or(SocketStateFlags::READ_PENDING_LOAD); + if (state & SocketStateFlags::READ_PENDING_LOAD) + return; // Someone else uses it + + if (!(state & SocketStateFlags::READ_PRESENT)) + { + // there is a really rare race condition + // since we are using EDGE trigger for ::recv + // if the buffer is not directly set, but we receive an EPOLL event, + // we might have to wait for the buffer to be set + // Otherwise this event is ignored forever + while ((state = m_atomicState.load()) & SocketStateFlags::READ_PENDING_SET) + std::this_thread::yield(); + } + + if (!(state & SocketStateFlags::READ_PRESENT)) + { + m_atomicState.fetch_and(~SocketStateFlags::READ_PENDING_LOAD); + return; // There is no buffer + } + + if (state & SocketStateFlags::IGNORE_TRANSFERS) + { + m_atomicState.fetch_and(~SocketStateFlags::READ_PENDING_LOAD); + return; // We are not allowed to react to it + } + + ssize_t newWrittenBytes = ::recv(m_descriptor.GetNativeSocket(), m_readDstBuffer, m_readDstBufferBytesLeft, 0); + if (newWrittenBytes == 0) + { + m_atomicState.fetch_and(~SocketStateFlags::READ_PENDING_LOAD); + sLog.Out(LOG_NETWORK, LOG_LVL_DEBUG, "[%s] ::recv() returned 0, which means the socket is half-closed.", GetRemoteIpString().c_str()); + StopPendingTransactionsAndForceClose(); + return; + } + if (newWrittenBytes < 0) + { + // If ::recv() failed because the socket is "not ready" we simply use a higher log level + sLog.Out(LOG_NETWORK, errno == EWOULDBLOCK ? LOG_LVL_BASIC : LOG_LVL_ERROR, "[%s] ::recv on client failed: %s", GetRemoteIpString().c_str(), SystemErrorToString(errno).c_str()); + m_atomicState.fetch_and(~SocketStateFlags::READ_PENDING_LOAD); + return; + } + + m_readDstBufferBytesLeft -= newWrittenBytes; + m_readDstBuffer += newWrittenBytes; + + bool isReadSome = m_readDstBufferSize == 0; // if we have readSome we only want to execute one ::recv() call + + if (m_readDstBufferBytesLeft == 0 || isReadSome) + { // we are done with this buffer + m_readDstBuffer = nullptr; + + auto tmpCallback = std::move(m_readCallback); + m_atomicState.fetch_and(~(SocketStateFlags::READ_PENDING_LOAD | SocketStateFlags::READ_PRESENT)); + + std::size_t transferSize = isReadSome ? newWrittenBytes : m_readDstBufferSize; + tmpCallback(IO::NetworkError(IO::NetworkError::ErrorType::NoError), transferSize); + } + else + { + m_atomicState.fetch_and(~SocketStateFlags::READ_PENDING_LOAD); + } +} + +void IO::Networking::AsyncSocket::PerformNonBlockingWrite() +{ + int state = m_atomicState.fetch_or(SocketStateFlags::WRITE_PENDING_LOAD); + + if (state & SocketStateFlags::WRITE_PENDING_LOAD) + return; // Someone else uses it + + if (!(state & SocketStateFlags::WRITE_PRESENT)) + { + // there is a really rare race condition + // since we are using EDGE trigger for ::send + // if the buffer is not directly set, but we receive an EPOLL event, + // we might have to wait for the buffer to be set + // Otherwise this event is ignored forever + while ((state = m_atomicState.load()) & SocketStateFlags::WRITE_PENDING_SET) + std::this_thread::yield(); + } + + if (!(state & SocketStateFlags::WRITE_PRESENT)) + { + m_atomicState.fetch_and(~SocketStateFlags::WRITE_PENDING_LOAD); + return; // There is no buffer :( + } + + if (state & SocketStateFlags::IGNORE_TRANSFERS) + { + m_atomicState.fetch_and(~SocketStateFlags::WRITE_PENDING_LOAD); + return; // We are not allowed to react to it + } + + ssize_t newSentBytes = ::send(m_descriptor.GetNativeSocket(), (m_writeSrc.GetPtr() + m_writeSrcAlreadyTransferred), (m_writeSrc.GetSize() - m_writeSrcAlreadyTransferred), 0); + if (newSentBytes == 0) + { + sLog.Out(LOG_NETWORK, LOG_LVL_DETAIL, "[Performance] Unnecessary call to PerformNonBlockingWrite()"); + m_atomicState.fetch_and(~SocketStateFlags::WRITE_PENDING_LOAD); + return; + } + if (newSentBytes == -1) + { + if (errno != EWOULDBLOCK) + { + sLog.Out(LOG_NETWORK, LOG_LVL_ERROR, "[%s] ::send on client failed: %s", GetRemoteIpString().c_str(), SystemErrorToString(errno).c_str()); + } + m_atomicState.fetch_and(~SocketStateFlags::WRITE_PENDING_LOAD); + return; + } + + m_writeSrcAlreadyTransferred += newSentBytes; + + if (m_writeSrcAlreadyTransferred == m_writeSrc.GetSize()) + { // we are done with this buffer + m_writeSrc = nullptr; + + auto tmpCallback = std::move(m_writeCallback); + m_atomicState.fetch_and(~(SocketStateFlags::WRITE_PENDING_LOAD | SocketStateFlags::WRITE_PRESENT)); + tmpCallback(IO::NetworkError(IO::NetworkError::ErrorType::NoError)); + } + else + { + m_atomicState.fetch_and(~SocketStateFlags::WRITE_PENDING_LOAD); + } +} + +void IO::Networking::AsyncSocket::PerformContextSwitch() +{ + int state = m_atomicState.fetch_or(SocketStateFlags::CONTEXT_PENDING_LOAD); + + if (state & SocketStateFlags::CONTEXT_PENDING_LOAD) + return; // Someone else uses it + + MANGOS_ASSERT(state & SocketStateFlags::CONTEXT_PRESENT); // why was this function even called if we have no context? + + auto tmpCallback = std::move(m_contextCallback); + m_atomicState.fetch_and(~(SocketStateFlags::CONTEXT_PENDING_LOAD | SocketStateFlags::CONTEXT_PRESENT)); + + if (state & SocketStateFlags::SHUTDOWN_PENDING) + { + m_atomicState.fetch_and(~SocketStateFlags::CONTEXT_PENDING_LOAD); + tmpCallback(IO::NetworkError(IO::NetworkError::ErrorType::SocketClosed)); + return; // The socket was closed, no transfers are allowed + } + + MANGOS_DEBUG_ASSERT(tmpCallback); + tmpCallback(IO::NetworkError(IO::NetworkError::ErrorType::NoError)); +} + +void IO::Networking::AsyncSocket::StopPendingTransactionsAndForceClose() +{ + CloseSocket(); // this guarantees SHUTDOWN_PENDING to be set + + int state = m_atomicState.fetch_or(SocketStateFlags::IGNORE_TRANSFERS); + if (state & SocketStateFlags::IGNORE_TRANSFERS) + return; // maybe another thread also called StopPendingTransactionsAndForceClose + + // we must wait for the other threads to finish + int const pendingTransferMask = SocketStateFlags::WRITE_PENDING_SET | + SocketStateFlags::WRITE_PENDING_LOAD | + SocketStateFlags::READ_PENDING_SET | + SocketStateFlags::READ_PENDING_LOAD; + if (state & pendingTransferMask) + { + while ((state = m_atomicState.load()) & pendingTransferMask) + std::this_thread::yield(); // :( atomic::wait() was implemented in C++20 + } + + if (state & SocketStateFlags::WRITE_PRESENT) + { + auto tmpWriteCallback = std::move(m_writeCallback); + m_writeSrc = nullptr; + m_atomicState.fetch_and(~SocketStateFlags::WRITE_PRESENT); + tmpWriteCallback(IO::NetworkError(IO::NetworkError::ErrorType::SocketClosed)); + } + + if (state & SocketStateFlags::READ_PRESENT) + { + auto tmpReadCallback = std::move(m_readCallback); + m_readDstBuffer = nullptr; + m_readDstBufferBytesLeft = 0; + m_atomicState.fetch_and(~SocketStateFlags::READ_PRESENT); + tmpReadCallback(IO::NetworkError(IO::NetworkError::ErrorType::SocketClosed), 0); + } + + // Note: Don't even think about clearing CONTEXT_PRESENT here, since it's stored as a raw pointer in `m_contextSwitchQueue` +} + +void IO::Networking::AsyncSocket::EnterIoContext(std::function const& callback) +{ + int state = m_atomicState.fetch_or(SocketStateFlags::CONTEXT_PENDING_SET); + if (state & SocketStateFlags::CONTEXT_PENDING_SET) + { + callback(IO::NetworkError(IO::NetworkError::ErrorType::OnlyOneTransferPerDirectionAllowed)); + return; + } + + if (state & SocketStateFlags::SHUTDOWN_PENDING) + { + m_atomicState.fetch_and(~SocketStateFlags::CONTEXT_PENDING_SET); + callback(IO::NetworkError(IO::NetworkError::ErrorType::SocketClosed)); + return; + } + + if (state & SocketStateFlags::CONTEXT_PRESENT) + { + m_atomicState.fetch_and(~SocketStateFlags::CONTEXT_PENDING_SET); + callback(IO::NetworkError(IO::NetworkError::ErrorType::OnlyOneTransferPerDirectionAllowed)); + return; + } + + m_contextCallback = callback; + m_atomicState.fetch_xor(SocketStateFlags::CONTEXT_PRESENT | SocketStateFlags::CONTEXT_PENDING_SET); // set PRESENT and unset PENDING_SET + + m_ctx->PostForImmediateInvocation(this); +} + +void IO::Networking::AsyncSocket::OnIoEvent(uint32_t event) +{ + int const CALLBACK_EVENT_FLAG = +#if defined(__linux__) + 0; +#elif defined(__APPLE__) + EVFILT_USER; +#else + #error "Unsupported" +#endif + + if (m_atomicState.load(std::memory_order_relaxed) & SocketStateFlags::IGNORE_TRANSFERS) + return; // This is just an initial check, must be atomically checked in the handlers later. + + if (event == CALLBACK_EVENT_FLAG) + { + PerformContextSwitch(); + return; + } + +#if defined(__linux__) + if (event & EPOLLERR) + { + int error = 0; + socklen_t errLen = sizeof(error); + if (::getsockopt(m_descriptor.GetNativeSocket(), SOL_SOCKET, SO_ERROR, (void*)&error, &errLen) == 0) + { + sLog.Out(LOG_NETWORK, LOG_LVL_ERROR, "[%s] epoll reported socket error: %s", GetRemoteIpString().c_str(), SystemErrorToString(error).c_str()); + } + else + { + sLog.Out(LOG_NETWORK, LOG_LVL_ERROR, "[%s] epoll reported socket error: Internal error", GetRemoteIpString().c_str()); + } + StopPendingTransactionsAndForceClose(); + } + else if (event & EPOLLRDHUP) + { + sLog.Out(LOG_NETWORK, LOG_LVL_DEBUG, "[%s] EPOLLRDHUP -> Going to disconnect.", GetRemoteIpString().c_str()); + StopPendingTransactionsAndForceClose(); + } + else + { + if (event & EPOLLIN) + PerformNonBlockingRead(); + + if (event & EPOLLOUT) + PerformNonBlockingWrite(); + } +#elif defined(__APPLE__) + switch ((int)event) // it's a "filter" from kqueue + { + case EVFILT_EXCEPT: + { + int error = 0; + socklen_t errLen = sizeof(error); + if (::getsockopt(m_descriptor.GetNativeSocket(), SOL_SOCKET, SO_ERROR, (void*)&error, &errLen) == 0) + { + sLog.Out(LOG_NETWORK, LOG_LVL_ERROR, "[%s] kqueue reported socket exception: Error: %s", GetRemoteIpString().c_str(), SystemErrorToString(error).c_str()); + + if (error == 0) + break; + } + else + { + sLog.Out(LOG_NETWORK, LOG_LVL_ERROR, "[%s] kqueue reported socket exception: Internal error", GetRemoteIpString().c_str()); + } + StopPendingTransactionsAndForceClose(); + break; + } + case EVFILT_READ: + { + PerformNonBlockingRead(); + break; + } + case EVFILT_WRITE: + { + PerformNonBlockingWrite(); + break; + } + default: + sLog.Out(LOG_NETWORK, LOG_LVL_ERROR, "Unhandled event %d", (int)event); + } +#else + #error "Unsupported" +#endif +} + +IO::NetworkError IO::Networking::AsyncSocket::SetNativeSocketOption_NoDelay(bool doNoDelay) +{ + if (IsClosing()) + return IO::NetworkError(IO::NetworkError::ErrorType::SocketClosed); + + int optionValue = doNoDelay ? 1 : 0; + if (::setsockopt(m_descriptor.GetNativeSocket(), IPPROTO_TCP, TCP_NODELAY, (char*) &optionValue, sizeof(optionValue)) != 0) + return IO::NetworkError::FromSystemError(errno); + + return IO::NetworkError(IO::NetworkError::ErrorType::NoError); +} + +IO::NetworkError IO::Networking::AsyncSocket::SetNativeSocketOption_SystemOutgoingSendBuffer(int bytes) +{ + MANGOS_ASSERT(bytes >= 1); // although a buffer of size 1 would already be pretty small... + + if (IsClosing()) + return IO::NetworkError(IO::NetworkError::ErrorType::SocketClosed); + + int optionValue = bytes; + if (::setsockopt(m_descriptor.GetNativeSocket(), SOL_SOCKET, SO_SNDBUF, (char*) &optionValue, sizeof(optionValue)) != 0) + return IO::NetworkError::FromSystemError(errno); + + return IO::NetworkError(IO::NetworkError::ErrorType::NoError); +} diff --git a/src/shared/IO/Networking/AsyncSocket_windows.cpp b/src/shared/IO/Networking/AsyncSocket_windows.cpp new file mode 100644 index 00000000000..ef958c2d055 --- /dev/null +++ b/src/shared/IO/Networking/AsyncSocket_windows.cpp @@ -0,0 +1,377 @@ +#include "AsyncSocket.h" +#include "Log.h" + +IO::NetworkError IO::Networking::AsyncSocket::InitializeAndFixateMemoryLocation() +{ + int state = m_atomicState.fetch_or(SocketStateFlags::IS_INITIALIZED); + MANGOS_ASSERT(!(state & SocketStateFlags::IS_INITIALIZED)); // can be only performed once + + // There is nothing to do on windows (using IOCP), since we are referencing this socket for each transfer individually + + return IO::NetworkError(NetworkError::ErrorType::NoError); +} + +IO::NetworkError IO::Networking::AsyncSocket::SetNativeSocketOption_NoDelay(bool doNoDelay) +{ + int optionValue = doNoDelay ? 1 : 0; + int result = ::setsockopt(m_descriptor.GetNativeSocket(), IPPROTO_TCP, TCP_NODELAY, (char*)&optionValue, sizeof(optionValue)); + if (result != 0) + return IO::NetworkError(NetworkError::ErrorType::InternalError, ::WSAGetLastError()); + return IO::NetworkError(NetworkError::ErrorType::NoError); +} + +IO::NetworkError IO::Networking::AsyncSocket::SetNativeSocketOption_SystemOutgoingSendBuffer(int bytes) +{ + int optionValue = bytes; + int result = ::setsockopt(m_descriptor.GetNativeSocket(), SOL_SOCKET, SO_SNDBUF, (char*)&optionValue, sizeof(optionValue)); + if (result != 0) + return IO::NetworkError(NetworkError::ErrorType::InternalError, ::WSAGetLastError()); + return IO::NetworkError(NetworkError::ErrorType::NoError); +} + +void IO::Networking::AsyncSocket::Read(char* target, std::size_t size, std::function const& callback) +{ + int state = m_atomicState.fetch_or(SocketStateFlags::READ_PENDING_SET); + MANGOS_DEBUG_ASSERT(state & SocketStateFlags::IS_INITIALIZED); + + if (state & SocketStateFlags::READ_PENDING_SET) + { + callback(IO::NetworkError(IO::NetworkError::ErrorType::OnlyOneTransferPerDirectionAllowed), 0); + return; + } + + if (state & SocketStateFlags::SHUTDOWN_PENDING) + { + m_atomicState.fetch_and(~SocketStateFlags::READ_PENDING_SET); + callback(IO::NetworkError(IO::NetworkError::ErrorType::SocketClosed), 0); + return; + } + + if (state & SocketStateFlags::READ_PRESENT) + { + m_atomicState.fetch_and(~SocketStateFlags::READ_PENDING_SET); + callback(IO::NetworkError(IO::NetworkError::ErrorType::OnlyOneTransferPerDirectionAllowed), 0); + return; + } + + if (size == 0) + { + sLog.Out(LOG_NETWORK, LOG_LVL_ERROR, "ERROR: Tried to IO::Networking::AsyncSocket::Read(...) with size 0"); + m_atomicState.fetch_and(~SocketStateFlags::READ_PENDING_SET); + callback(IO::NetworkError(IO::NetworkError::ErrorType::NoError), 0); // technically not an error, we are just done with the buffer + return; + } + + m_readCallback = callback; + + int const bufferCount = 1; + struct BufferCtx + { + WSABUF buffers[bufferCount]; + }; + + std::shared_ptr bufferCtx(new BufferCtx{0}); + bufferCtx->buffers[0].len = size; + bufferCtx->buffers[0].buf = target; + + m_currentReadTask.InitNew([this, bufferCtx, size](DWORD errorCode) { + uint64_t bytesProcessed = m_currentReadTask.InternalHigh; + if (bytesProcessed == 0) + { // 0 means the socket is already closed on the other side + sLog.Out(LOG_NETWORK, LOG_LVL_DEBUG, "[%s] Empty response -> Going to disconnect.", GetRemoteIpString().c_str()); + CloseSocket(); + auto tmpCallback = std::move(m_readCallback); + m_currentReadTask.Reset(); + m_atomicState.fetch_and(~SocketStateFlags::READ_PRESENT); + tmpCallback(IO::NetworkError(IO::NetworkError::ErrorType::SocketClosed), 0); + return; + } + + if (bytesProcessed < bufferCtx->buffers[0].len) + { // We are not done yet. We need to requeue our task + bufferCtx->buffers[0].buf += bytesProcessed; + bufferCtx->buffers[0].len -= bytesProcessed; + + int const bufferCount = 1; + DWORD flags = 0; + int errorCode = ::WSARecv(m_descriptor.GetNativeSocket(), bufferCtx->buffers, bufferCount, nullptr, &flags, &(m_currentReadTask), nullptr); + if (errorCode) + { + int err = ::WSAGetLastError(); + if (err != WSA_IO_PENDING) // Pending means that this task was queued (which is what we want) + { + sLog.Out(LOG_NETWORK, LOG_LVL_ERROR, "::WSARecv(...) Error: %u", err); + auto tmpCallback = std::move(m_readCallback); + m_currentReadTask.Reset(); + m_atomicState.fetch_and(~SocketStateFlags::READ_PRESENT); + tmpCallback(IO::NetworkError(IO::NetworkError::ErrorType::InternalError, err), 0); + return; + } + } + } + else + { + auto tmpCallback = std::move(m_readCallback); + m_currentReadTask.Reset(); + m_atomicState.fetch_and(~SocketStateFlags::READ_PRESENT); + tmpCallback(IO::NetworkError(IO::NetworkError::ErrorType::NoError), size); + } + }); + + DWORD flags = 0; + m_atomicState.fetch_xor(SocketStateFlags::READ_PRESENT | SocketStateFlags::READ_PENDING_SET); // set PRESENT and unset PENDING_SET + int errorCode = ::WSARecv(m_descriptor.GetNativeSocket(), bufferCtx->buffers, bufferCount, nullptr, &flags, &m_currentReadTask, nullptr); + if (errorCode) + { + int err = ::WSAGetLastError(); + if (err != WSA_IO_PENDING) // Pending means that this task was queued (which is what we want) + { + sLog.Out(LOG_NETWORK, LOG_LVL_ERROR, "::WSARecv(...) Error: %u", err); + auto tmpCallback = std::move(m_readCallback); + m_currentReadTask.Reset(); + m_atomicState.fetch_and(~SocketStateFlags::READ_PRESENT); + tmpCallback(IO::NetworkError(IO::NetworkError::ErrorType::InternalError, err), 0); + return; + } + } +} + +void IO::Networking::AsyncSocket::ReadSome(char* target, std::size_t size, std::function const& callback) +{ + int state = m_atomicState.fetch_or(SocketStateFlags::READ_PENDING_SET); + MANGOS_DEBUG_ASSERT(state & SocketStateFlags::IS_INITIALIZED); + + if (state & SocketStateFlags::READ_PENDING_SET) + { + callback(IO::NetworkError(IO::NetworkError::ErrorType::OnlyOneTransferPerDirectionAllowed), 0); + return; + } + + if (state & SocketStateFlags::SHUTDOWN_PENDING) + { + m_atomicState.fetch_and(~SocketStateFlags::READ_PENDING_SET); + callback(IO::NetworkError(IO::NetworkError::ErrorType::SocketClosed), 0); + return; + } + + if (state & SocketStateFlags::READ_PRESENT) + { + m_atomicState.fetch_and(~SocketStateFlags::READ_PENDING_SET); + callback(IO::NetworkError(IO::NetworkError::ErrorType::OnlyOneTransferPerDirectionAllowed), 0); + return; + } + + if (size == 0) + { + sLog.Out(LOG_NETWORK, LOG_LVL_ERROR, "ERROR: Tried to IO::Networking::AsyncSocket::Read(...) with size 0"); + m_atomicState.fetch_and(~SocketStateFlags::READ_PENDING_SET); + callback(IO::NetworkError(IO::NetworkError::ErrorType::NoError), 0); // technically not an error, we are just done with the buffer + return; + } + + m_readCallback = callback; + + int const bufferCount = 1; + struct BufferCtx + { + WSABUF buffers[bufferCount]; + }; + + std::shared_ptr bufferCtx(new BufferCtx{0}); + bufferCtx->buffers[0].len = size; + bufferCtx->buffers[0].buf = target; + + m_currentReadTask.InitNew([this, bufferCtx](DWORD errorCode) { + uint64_t bytesProcessed = m_currentReadTask.InternalHigh; + if (bytesProcessed == 0) + { // 0 means the socket is already closed on the other side + sLog.Out(LOG_NETWORK, LOG_LVL_DEBUG, "[%s] Empty response -> Going to disconnect.", GetRemoteIpString().c_str()); + CloseSocket(); + auto tmpCallback = std::move(m_readCallback); + m_currentReadTask.Reset(); + m_atomicState.fetch_and(~SocketStateFlags::READ_PRESENT); + tmpCallback(IO::NetworkError(IO::NetworkError::ErrorType::SocketClosed), 0); + return; + } + + auto tmpCallback = std::move(m_readCallback); + m_currentReadTask.Reset(); + m_atomicState.fetch_and(~SocketStateFlags::READ_PRESENT); + tmpCallback(IO::NetworkError(IO::NetworkError::ErrorType::NoError), bytesProcessed); + }); + + DWORD flags = 0; + m_atomicState.fetch_xor(SocketStateFlags::READ_PRESENT | SocketStateFlags::READ_PENDING_SET); // set PRESENT and unset PENDING_SET + int errorCode = ::WSARecv(m_descriptor.GetNativeSocket(), bufferCtx->buffers, bufferCount, nullptr, &flags, &m_currentReadTask, nullptr); + if (errorCode) + { + int err = ::WSAGetLastError(); + if (err != WSA_IO_PENDING) // Pending means that this task was queued (which is what we want) + { + sLog.Out(LOG_NETWORK, LOG_LVL_ERROR, "::WSARecv(...) Error: %u", err); + auto tmpCallback = std::move(m_readCallback); + m_currentReadTask.Reset(); + m_atomicState.fetch_and(~SocketStateFlags::READ_PRESENT); + tmpCallback(IO::NetworkError(IO::NetworkError::ErrorType::InternalError, err), 0); + return; + } + } +} + +/// Warning: Using this function will NOT copy the buffer, dont overwrite it unless callback is triggered! +/// (but a reference to the smart_ptr will be held throughout the transfer, so you dont need to) +void IO::Networking::AsyncSocket::Write(IO::ReadableBuffer const& source, std::function const& callback) +{ + if (source.GetSize() > 8*1024*1024) + sLog.Out(LOG_NETWORK, LOG_LVL_ERROR, "[NETWORK] You are about to send a very large message (%llu bytes). The Windows Kernel will happily accept that. Split the Write(...) calls next time!", source.GetSize()); + + int state = m_atomicState.fetch_or(SocketStateFlags::WRITE_PENDING_SET); + MANGOS_DEBUG_ASSERT(state & SocketStateFlags::IS_INITIALIZED); + + if (state & SocketStateFlags::WRITE_PENDING_SET) + { + callback(IO::NetworkError(IO::NetworkError::ErrorType::OnlyOneTransferPerDirectionAllowed)); + return; + } + + if (state & SocketStateFlags::SHUTDOWN_PENDING) + { + m_atomicState.fetch_and(~SocketStateFlags::WRITE_PENDING_SET); + callback(IO::NetworkError(IO::NetworkError::ErrorType::SocketClosed)); + return; + } + + if (state & SocketStateFlags::WRITE_PRESENT) + { + m_atomicState.fetch_and(~SocketStateFlags::WRITE_PENDING_SET); + callback(IO::NetworkError(IO::NetworkError::ErrorType::OnlyOneTransferPerDirectionAllowed)); + return; + } + + if (source.GetSize() == 0) + { + sLog.Out(LOG_NETWORK, LOG_LVL_ERROR, "ERROR: Tried to IO::Networking::AsyncSocket::Write(...) with size 0"); + m_atomicState.fetch_and(~SocketStateFlags::WRITE_PENDING_SET); + callback(IO::NetworkError(IO::NetworkError::ErrorType::NoError)); // technically not an error, we are just done with the buffer + return; + } + + m_writeCallback = callback; + m_writeSrc = source; + + int const bufferCount = 1; + struct BufferCtx + { + WSABUF buffers[bufferCount]; + }; + + std::shared_ptr bufferCtx(new BufferCtx{0}); + bufferCtx->buffers[0].len = m_writeSrc.GetSize(); + bufferCtx->buffers[0].buf = (char*)(m_writeSrc.GetPtr()); + + m_currentWriteTask.InitNew([this, bufferCtx](DWORD errorCode) { + uint64_t bytesProcessed = m_currentWriteTask.InternalHigh; + + IO::NetworkError errorResult(IO::NetworkError::ErrorType::InternalError, errorCode); + + if (bytesProcessed == 0) + { // 0 means the socket is already closed on the other side + CloseSocket(); + errorResult = IO::NetworkError(IO::NetworkError::ErrorType::SocketClosed); + } + else if (bytesProcessed < bufferCtx->buffers[0].len || errorCode != 0) + { // Compared to Read(...), the Write(...) system call should be able to transfer the whole buffer in one + CloseSocket(); + errorResult = IO::NetworkError(IO::NetworkError::ErrorType::InternalError, errorCode); + } + else + { + errorResult = IO::NetworkError(IO::NetworkError::ErrorType::NoError); + } + + auto tmpCallback = std::move(m_writeCallback); + m_writeSrc = nullptr; + m_currentWriteTask.Reset(); + m_atomicState.fetch_and(~SocketStateFlags::WRITE_PRESENT); + tmpCallback(errorResult); + }); + + DWORD flags = 0; + m_atomicState.fetch_xor(SocketStateFlags::WRITE_PRESENT | SocketStateFlags::WRITE_PENDING_SET); // set PRESENT and unset PENDING_SET + int errorCode = ::WSASend(m_descriptor.GetNativeSocket(), bufferCtx->buffers, bufferCount, nullptr, flags, &m_currentWriteTask, nullptr); + if (errorCode) + { + int err = ::WSAGetLastError(); + if (err != WSA_IO_PENDING) // Pending means that this task was queued (which is what we want) + { + sLog.Out(LOG_NETWORK, LOG_LVL_ERROR, "::WSASend(...) Error: %u", err); + auto tmpCallback = std::move(m_writeCallback); + m_writeSrc = nullptr; + m_currentWriteTask.Reset(); + m_atomicState.fetch_and(~SocketStateFlags::WRITE_PRESENT); + tmpCallback(IO::NetworkError(IO::NetworkError::ErrorType::InternalError, err)); + return; + } + } +} + +void IO::Networking::AsyncSocket::CloseSocket() +{ + // set SHUTDOWN_PENDING flag, and check if there was already a previous one + if (m_atomicState.fetch_or(SocketStateFlags::SHUTDOWN_PENDING) & SocketStateFlags::SHUTDOWN_PENDING) + return; // there was already a ::shutdown() + + ::shutdown(m_descriptor.GetNativeSocket(), SD_BOTH); // will interrupt and fail all pending IOCP events and post them to the queue +} + +/// The callback is invoked in the IO thread +/// Useful for computational expensive operations (e.g. packing and encryption), that should be avoided in the main loop +void IO::Networking::AsyncSocket::EnterIoContext(std::function const& callback) +{ + int state = m_atomicState.fetch_or(SocketStateFlags::CONTEXT_PENDING_SET); + MANGOS_DEBUG_ASSERT(state & SocketStateFlags::IS_INITIALIZED); + + if (state & SocketStateFlags::CONTEXT_PENDING_SET) + { + callback(IO::NetworkError(IO::NetworkError::ErrorType::OnlyOneTransferPerDirectionAllowed)); + return; + } + + if (state & SocketStateFlags::SHUTDOWN_PENDING) + { + m_atomicState.fetch_and(~SocketStateFlags::CONTEXT_PENDING_SET); + callback(IO::NetworkError(IO::NetworkError::ErrorType::SocketClosed)); + return; + } + + if (state & SocketStateFlags::CONTEXT_PRESENT) + { + m_atomicState.fetch_and(~SocketStateFlags::CONTEXT_PENDING_SET); + callback(IO::NetworkError(IO::NetworkError::ErrorType::OnlyOneTransferPerDirectionAllowed)); + return; + } + + m_contextCallback = callback; + + m_currentContextTask.InitNew([this](DWORD errorCode) { + auto tmpCallback = std::move(m_contextCallback); + m_currentContextTask.Reset(); + m_atomicState.fetch_and(~SocketStateFlags::CONTEXT_PRESENT); + tmpCallback(IO::NetworkError(IO::NetworkError::ErrorType::NoError)); + }); + + m_atomicState.fetch_xor(SocketStateFlags::CONTEXT_PRESENT | SocketStateFlags::CONTEXT_PENDING_SET); // set PRESENT and unset PENDING_SET + + m_ctx->PostOperationForImmediateInvocation(&m_currentContextTask); +} + +IO::Networking::AsyncSocket::AsyncSocket(IO::IoContext* ctx, IO::Networking::SocketDescriptor socketDescriptor) + : m_ctx(ctx), m_descriptor(std::move(socketDescriptor)) +{ + // Attach our acceptor socket to our completion port + if (::CreateIoCompletionPort((HANDLE) m_descriptor.GetNativeSocket(), m_ctx->GetWindowsCompletionPort(), (u_long)0, 0) != m_ctx->GetWindowsCompletionPort()) + { + sLog.Out(LOG_NETWORK, LOG_LVL_ERROR, "::CreateIoCompletionPort(accept, ...) Error: %u", ::WSAGetLastError()); + return; + } +} diff --git a/src/shared/IO/Networking/DNS.cpp b/src/shared/IO/Networking/DNS.cpp new file mode 100644 index 00000000000..2a9c036f867 --- /dev/null +++ b/src/shared/IO/Networking/DNS.cpp @@ -0,0 +1,102 @@ +#include "./DNS.h" +#include "./Internal.h" +#include "Log.h" +#include "Errors.h" +#include "Util.h" +#include "IO/SystemErrorToString.h" + +#if defined(WIN32) +#include +#include +#elif defined(__linux__) || defined(__APPLE__) +#include +#include +#endif + +std::string IO::Networking::DNS::GetOwnHostname() +{ + char hostname[1024]; + if (::gethostname(hostname, sizeof(hostname)) == -1) + { + sLog.Out(LogType::LOG_NETWORK, LOG_LVL_ERROR, "IO ERROR: ::gethostname(...): %s", SystemErrorToString(errno).c_str()); + MANGOS_ASSERT(false); + } + return hostname; +} + +std::vector IO::Networking::DNS::ResolveDomainAll(std::string const& domainName, IpAddress::Type type) +{ + MANGOS_ASSERT(type == IpAddress::Type::IPv4); // TODO: this function is only tested with IPv4. `inet_ntop` will fail + + // Check if we can parse the domain as an IP + nonstd::optional maybeIp = IpAddress::TryParseFromString(domainName); + if (maybeIp) + { + IpAddress const& ip = maybeIp.value(); + MANGOS_ASSERT(ip.GetType() == type); + return { ip }; // The "domain" can be directly parsed as an IP + } + + // try to resolve the domain + addrinfo hints = {}; + hints.ai_family = type == IpAddress::Type::IPv4 ? AF_INET : AF_INET6; + hints.ai_socktype = SOCK_STREAM; + hints.ai_protocol = IPPROTO_TCP; + + addrinfo* dnsResult = nullptr; + if (::getaddrinfo(domainName.c_str(), nullptr, &hints, &dnsResult) != 0) + { + sLog.Out(LogType::LOG_NETWORK, LOG_LVL_ERROR, "IO ERROR: ::getaddrinfo(...): %s", SystemErrorToString( +#if defined(WIN32) + ::WSAGetLastError() +#else + errno +#endif + ).c_str()); + return {}; // error occurred, empty return + } + + std::vector list; + + for (addrinfo* ptr = dnsResult; ptr != nullptr; ptr = ptr->ai_next) + { + if (ptr->ai_family == AF_INET) + { + sockaddr_in* sockaddr_ipv4 = reinterpret_cast(ptr->ai_addr); + IpAddress ip = IO::Networking::Internal::inet_ntop(&(sockaddr_ipv4->sin_addr)); + list.emplace_back(ip); + } + else if (ptr->ai_family == AF_INET6) + { + // TODO: Add inet_ntop for IPv6 + //sockaddr_in6* sockaddr_ipv6 = reinterpret_cast(ptr->ai_addr); + //IpAddress ip = IO::Networking::Internal::inet_ntop(&(sockaddr_ipv6->sin_addr)); + //list.emplace_back(ip); + } + else + { + MANGOS_ASSERT(false && (ptr->ai_family == AF_INET || ptr->ai_family == AF_INET6)); + } + } + + freeaddrinfo(dnsResult); + + return list; +} + +nonstd::optional IO::Networking::DNS::ResolveDomainSingle(std::string const& domainName, IpAddress::Type type, SelectionStrategy strategy) +{ + std::vector allIps = ResolveDomainAll(domainName, type); + if (allIps.empty()) + return nonstd::nullopt; // No IP found + + switch (strategy) + { + case SelectionStrategy::First: + return allIps.front(); + case SelectionStrategy::Random: + return allIps[urand(0, allIps.size() - 1)]; + } + + MANGOS_ASSERT(false); +} diff --git a/src/shared/IO/Networking/DNS.h b/src/shared/IO/Networking/DNS.h new file mode 100644 index 00000000000..ab1a94449af --- /dev/null +++ b/src/shared/IO/Networking/DNS.h @@ -0,0 +1,30 @@ +#ifndef MANGOS_IO_NETWORKING_DNS_H +#define MANGOS_IO_NETWORKING_DNS_H + +#include +#include "./IpAddress.h" + +namespace IO { namespace Networking { namespace DNS +{ + std::string GetOwnHostname(); + + /// Will also work with IP addresses without touching the DNS layer + /// \warning Will return an empty list if unable to resolve the domain + std::vector ResolveDomainAll(std::string const& domainName, IO::Networking::IpAddress::Type type); + + /// Different strategies on how to resolve multiple IPAddresses on the same domain name + /// This can happen if, for example, a domain has multiple "A-Records" with the same name but different IPs. + enum class SelectionStrategy + { + /// If multiple IPAddresses, take the first one + First, + + /// If multiple IPAddresses, take a random one, has a "load-balancing" effect + Random, + }; + + /// Just like `ResolveDomainAll` but will return at most one IPAddress + nonstd::optional ResolveDomainSingle(std::string const& domainName, IO::Networking::IpAddress::Type type, SelectionStrategy strategy = SelectionStrategy::First); +}}} // namespace IO::Networking + +#endif // MANGOS_IO_NETWORKING_DNS_H diff --git a/src/shared/IO/Networking/Internal.cpp b/src/shared/IO/Networking/Internal.cpp new file mode 100644 index 00000000000..149ed452ec1 --- /dev/null +++ b/src/shared/IO/Networking/Internal.cpp @@ -0,0 +1,61 @@ +#include "./Internal.h" + +#include "Errors.h" + +#if defined(WIN32) +#include +#include +#elif defined(__linux__) || defined(__APPLE__) +#include +#include +#endif + +/// Converts a native `IN_ADDR` to a `IO::Networking::IpAddress` +IO::Networking::IpAddress IO::Networking::Internal::inet_ntop(in_addr const* nativeAddress) +{ +#if defined(WIN32) + // We cant use ::inet_ntoa(...) because it's not thread safe. We cant use ::inet_ntop(...) because it's not WinXP compatible, so we have to do it ourselves. + int constexpr MAX_IPV4_LENGTH = 16; // "255.255.255.255" = length 15 + 1 for null-terminator + char ipv4AddressString[MAX_IPV4_LENGTH]; + { // This implementation was taken from ACE, should be universal + uint8_t const* p = reinterpret_cast(nativeAddress); + snprintf(ipv4AddressString, MAX_IPV4_LENGTH, "%d.%d.%d.%d", p[0], p[1], p[2], p[3]); + } + auto ipAddress = IO::Networking::IpAddress::TryParseFromString(ipv4AddressString); +#elif defined(__linux__) || defined(__APPLE__) + char ipv4AddressString[INET_ADDRSTRLEN]; + ::inet_ntop(AF_INET, nativeAddress, ipv4AddressString, INET_ADDRSTRLEN); + auto ipAddress = IO::Networking::IpAddress::TryParseFromString(ipv4AddressString); +#else + #error "Unsupported platform" +#endif + MANGOS_ASSERT(ipAddress.has_value()); // this should never fail, since we got a valid IP from IN_ADDR + return ipAddress.value(); +} + +void IO::Networking::Internal::CloseSocket(IO::Native::SocketHandle nativeSocket) +{ +#if defined(WIN32) + ::closesocket(nativeSocket); +#elif defined(__linux__) || defined(__APPLE__) + ::close(nativeSocket); +#else + #error "Unsupported platform" +#endif +} + +// Converts a `IO::Networking::IpAddress` to a native `IN_ADDR` +void IO::Networking::Internal::inet_pton(IO::Networking::IpAddress const& ipAddress, in_addr* out_dest) +{ + MANGOS_ASSERT(ipAddress.GetType() == IpAddress::Type::IPv4); + +#if defined(WIN32) + // We cant use `inet_pton`, because it's not supported on WinXP. + // But this method would basically just take the internal representation and store it in a union anyways ¯\_(ツ)_/¯ + out_dest->s_addr = ::htonl(ipAddress._getInternalIPv4ReprAsUint32()); +#elif defined(__linux__) || defined(__APPLE__) + MANGOS_ASSERT(::inet_pton(AF_INET, ipAddress.ToString().c_str(), out_dest) == 1); +#else + #error "Unsupported platform" +#endif +} diff --git a/src/shared/IO/Networking/Internal.h b/src/shared/IO/Networking/Internal.h new file mode 100644 index 00000000000..b39377f8d27 --- /dev/null +++ b/src/shared/IO/Networking/Internal.h @@ -0,0 +1,22 @@ +#ifndef MANGOS_IO_NETWORKING_INTERNAL_H +#define MANGOS_IO_NETWORKING_INTERNAL_H + +#include "./IpAddress.h" +#include "IO/NativeAliases.h" + +struct in_addr; + +namespace IO { namespace Networking { namespace Internal +{ + /// Converts a native `IN_ADDR` to a `IO::Networking::IpAddress` + IO::Networking::IpAddress inet_ntop(in_addr const* nativeAddress); + + /// Converts a `IO::Networking::IpAddress` to a native `IN_ADDR` + void inet_pton(IO::Networking::IpAddress const& ipAddress, in_addr* out_dest); + + /// Closes a socket + void CloseSocket(IO::Native::SocketHandle nativeSocket); + +}}} // IO::Networking::Internal + +#endif // MANGOS_IO_NETWORKING_INTERNAL_H diff --git a/src/shared/IO/Networking/IpAddress.cpp b/src/shared/IO/Networking/IpAddress.cpp new file mode 100644 index 00000000000..05601ea7d37 --- /dev/null +++ b/src/shared/IO/Networking/IpAddress.cpp @@ -0,0 +1,168 @@ +#include "IpAddress.h" +#include "Errors.h" + +#include + +IO::Networking::IpAddress IO::Networking::IpAddress::FromIpv4Uint32(uint32_t ip) +{ + IpAddress result; + result.m_address.type = Type::IPv4; + result.m_address.ipv4 = ip; + result.UpdateCachedString(); + return result; +} + +/// IPv4 Format: 255.255.255.255 +/// IPv6 Format: [FFFF:FFFF:FFFF:FFFF:FFFF:FFFF:FFFF:FFFF] +nonstd::optional IO::Networking::IpAddress::TryParseFromString(std::string const& ipAddressString) +{ + IpAddress result; + + size_t ipv6Begin = -1; + ipv6Begin = ipAddressString.find('['); + if (ipv6Begin == -1) + { + result.m_address.type = Type::IPv4; + + // IPv4 expected format: 255.255.255.255 + const char* const fixEndPtr = ipAddressString.c_str() + ipAddressString.size(); + + const char* tmpLastEndPtr = ipAddressString.c_str(); // <- loop variable + for (int i = 0; i < 4; i++) + { + char const* tmpStartPtr = tmpLastEndPtr; + tmpLastEndPtr = fixEndPtr; + + // Parse a number. Must be in range [0-255] + int64_t segment = std::strtoll(tmpStartPtr, const_cast(&tmpLastEndPtr), 10); + if (segment < 0 || segment > 255) + return nonstd::nullopt; // invalid number range, only [0..255] is valid + + if (i != 3) + { // We should not be at the end, and the next character should be a dot + if (tmpLastEndPtr >= fixEndPtr || tmpLastEndPtr[0] != '.') + return nonstd::nullopt; + tmpLastEndPtr++; // Skip the '.' + } + else + { // Last segment, we should be at the end + if (tmpLastEndPtr != fixEndPtr) + return nonstd::nullopt; + } + result.m_address.ipv4 <<= 8; + result.m_address.ipv4 |= (uint8_t) segment; + } + } + else + { + result.m_address.type = Type::IPv6; + // TODO: Implement me. Keep in mind all the IPv6 truncation possibilities + return nonstd::nullopt; + } + + result.UpdateCachedString(); + + return result; +} + +IO::Networking::IpAddress::Type IO::Networking::IpAddress::GetType() const +{ + return m_address.type; +} + +/// "127.0.0.1" would return 2130706433 +uint32_t IO::Networking::IpAddress::_getInternalIPv4ReprAsUint32() const +{ + MANGOS_ASSERT(m_address.type == Type::IPv4); + return m_address.ipv4; +} + +/// IPv4 Format: 255.255.255.255 +/// IPv6 Format: [FFFF:FFFF:FFFF:FFFF:FFFF:FFFF:FFFF:FFFF] +void IO::Networking::IpAddress::UpdateCachedString() +{ + if (m_address.type == Type::IPv4) + { + m_cachedToString = std::to_string((m_address.ipv4 >> (3*8)) & 0xFF) + "." + + std::to_string((m_address.ipv4 >> (2*8)) & 0xFF) + "." + + std::to_string((m_address.ipv4 >> (1*8)) & 0xFF) + "." + + std::to_string((m_address.ipv4 >> (0*8)) & 0xFF); + } + else + { + // The IPv6 spec allows multiple zeros in a row to be truncated _once_ to just "::" + // And segments where the number is just :FFFF: can be represented by :0: + // Leading zeros in a segment can be completely omitted. + // For example 1111:1100:0000:0000:0222:FFFF:0033:3333 + // Can be 1111:1100::222:0:33:3333 + int zeroStart = -1, zeroLength = 0; + + // Find the longest sequence of zeros for compression + { + bool inZeroSeq = false; + for (int i = 0; i < m_address.ipv6.size(); i++) + { + if (m_address.ipv6[i] == 0) + { + int length = 0; + while (i < 8 && m_address.ipv6[i] == 0) + { + ++length; + ++i; + } + if (length > zeroLength) + { + zeroStart = i - length; + zeroLength = length; + } + } + } + } + + std::stringstream result; + result << std::hex << std::uppercase; // enable number conversion to hex output + result << '['; + for (int i = 0; i < m_address.ipv6.size(); i++) + { + if (i == zeroStart) + { // we are in a zero truncation part + if (i == 0) + result << "::"; + else + result << ':'; + + i += zeroLength - 1; + } + else + { + if (i > 0) + result << ':'; // add : separator + + result << m_address.ipv6[i]; + } + } + result << ']'; + m_cachedToString = result.str(); + } +} + +bool IO::Networking::operator==(IpAddress const& lhs, IpAddress const& rhs) +{ + if (lhs.GetType() != rhs.GetType()) + return false; + + switch (lhs.GetType()) + { + case IpAddress::Type::IPv4: + return lhs.m_address.ipv4 == rhs.m_address.ipv4; + case IpAddress::Type::IPv6: + return lhs.m_address.ipv6 == rhs.m_address.ipv6; + } + + return false; +} + +bool IO::Networking::operator==(IpEndpoint const& lhs, IpEndpoint const& rhs) +{ + return lhs.ip == rhs.ip && lhs.port == rhs.port; +} diff --git a/src/shared/IO/Networking/IpAddress.h b/src/shared/IO/Networking/IpAddress.h new file mode 100644 index 00000000000..df8609f6d54 --- /dev/null +++ b/src/shared/IO/Networking/IpAddress.h @@ -0,0 +1,77 @@ +#ifndef MANGOS_IPADDRESS_H +#define MANGOS_IPADDRESS_H + +#include +#include +#include +#include +#include "nonstd/optional.hpp" + +namespace IO { namespace Networking +{ + class IpAddress + { + public: + enum class Type { IPv4, IPv6 }; + + static IpAddress FromIpv4Uint32(uint32_t ip); + + /// IPv4 Format: 255.255.255.255 + /// IPv6 Format: [FFFF:FFFF:FFFF:FFFF:FFFF:FFFF:FFFF:FFFF] + static nonstd::optional TryParseFromString(std::string const& ipAddressString); + + /// IPv4 Format: 255.255.255.255 + /// IPv6 Format: [FFFF:FFFF:FFFF:FFFF:FFFF:FFFF:FFFF:FFFF] + std::string const& ToString() const { return m_cachedToString; } + + Type GetType() const; + + /// "127.0.0.1" would return 2130706433 + uint32_t _getInternalIPv4ReprAsUint32() const; + private: + struct // NOLINT(*-pro-type-member-init) we manage the initialization on our own. + { + Type type = Type::IPv4; + union + { + uint32_t ipv4; // "127.0.0.1" would be 2130706433 + std::array ipv6; // index[0] is leftmost element in string representation + }; + } m_address; + + // Since IPs are used in a lot of logging, we just cache the result, so it is not re-created all the time + void UpdateCachedString(); + std::string m_cachedToString; + + public: + friend bool operator==(IpAddress const& lhs, IpAddress const& rhs); + }; + + class IpEndpoint + { + public: + IpAddress ip; + uint16_t port; + + public: + IpEndpoint() : ip{}, port{0} {} + IpEndpoint(IO::Networking::IpAddress ip, uint16_t port) : ip{std::move(ip)}, port{port} {} + + /// IPv4 Format: 255.255.255.255:1337 + /// IPv6 Format: [FFFF:FFFF:FFFF:FFFF:FFFF:FFFF:FFFF:FFFF]:1337 + std::string toString() const + { + return ip.ToString() + ':' + std::to_string(port); + }; + + public: + friend bool operator==(IpEndpoint const& lhs, IpEndpoint const& rhs); + }; + + // Forward declaration for operator== + bool operator==(IO::Networking::IpAddress const& lhs, IO::Networking::IpAddress const& rhs); + bool operator==(IO::Networking::IpEndpoint const& lhs, IO::Networking::IpEndpoint const& rhs); +}} // namespace IO::Networking + + +#endif //MANGOS_IPADDRESS_H diff --git a/src/shared/IO/Networking/NetworkError.cpp b/src/shared/IO/Networking/NetworkError.cpp new file mode 100644 index 00000000000..55d68f35968 --- /dev/null +++ b/src/shared/IO/Networking/NetworkError.cpp @@ -0,0 +1,53 @@ +#include "./NetworkError.h" +#include "../SystemErrorToString.h" + +std::string const& GetErrorBaseString(IO::NetworkError::ErrorType errorType) +{ + switch (errorType) + { + case IO::NetworkError::ErrorType::NoError: + { + static std::string txt = "NoError"; + return txt; + } + case IO::NetworkError::ErrorType::InternalError: + { + static std::string txt = "InternalError"; + return txt; + } + case IO::NetworkError::ErrorType::SocketClosed: + { + static std::string txt = "SocketClosed"; + return txt; + } + case IO::NetworkError::ErrorType::OnlyOneTransferPerDirectionAllowed: + { + static std::string txt = "OnlyOneTransferPerDirectionAllowed"; + return txt; + } + case IO::NetworkError::ErrorType::Timeout: + { + static std::string txt = "Timeout"; + return txt; + } + case IO::NetworkError::ErrorType::InvalidProtocolBehavior: + { + static std::string txt = "InvalidProtocolBehavior"; + return txt; + } + default: + { + static std::string txt = "UndefinedErrorType"; + return txt; + } + } +} + +std::string IO::NetworkError::ToString() const +{ + std::string result = GetErrorBaseString(this->GetErrorType()); + if (m_additionalOsErrorCode) + result += " (Code " + std::to_string(m_additionalOsErrorCode) + ": " + SystemErrorToString(m_additionalOsErrorCode) + ")"; + + return result; +} diff --git a/src/shared/IO/Networking/NetworkError.h b/src/shared/IO/Networking/NetworkError.h new file mode 100644 index 00000000000..ecca0248839 --- /dev/null +++ b/src/shared/IO/Networking/NetworkError.h @@ -0,0 +1,40 @@ +#ifndef MANGOS_IO_NETWORKING_NETWORKERROR_H +#define MANGOS_IO_NETWORKING_NETWORKERROR_H + +#include + +namespace IO +{ + class NetworkError { + public: + enum class ErrorType : int + { + NoError, + InternalError, // see m_additionalOsErrorCode + SocketClosed, + OnlyOneTransferPerDirectionAllowed, + Timeout, + InvalidProtocolBehavior, + }; + public: + explicit constexpr NetworkError(ErrorType errorType) : NetworkError(errorType, 0) {}; + explicit constexpr NetworkError(ErrorType errorType, int osErrorCode) : m_error{errorType}, m_additionalOsErrorCode{osErrorCode} {}; + + ErrorType GetErrorType() const { return m_error; }; + + explicit operator bool() const { return GetErrorType() != ErrorType::NoError; }; + std::string ToString() const; + + static NetworkError FromSystemError(int osErrorCode) + { + return NetworkError(ErrorType::InternalError, osErrorCode); + } + private: + ErrorType m_error; + /// internal variable for ToString(), might be os and situation dependent (On windows there is ::GetLastError()/errno and ::WSAGetLastError()) + int m_additionalOsErrorCode; + }; + +} // namespace IO + +#endif //MANGOS_IO_NETWORKING_NETWORKERROR_H diff --git a/src/shared/IO/Networking/SocketConnector.cpp b/src/shared/IO/Networking/SocketConnector.cpp new file mode 100644 index 00000000000..9ed8010f662 --- /dev/null +++ b/src/shared/IO/Networking/SocketConnector.cpp @@ -0,0 +1,106 @@ +#include "SocketConnector.h" +#include "Internal.h" + +#if defined(__linux__) || defined(__APPLE__) + #include + #include + #include + #include "IO/Utils_Unix.h" + #define GetNetworkError() errno +#elif defined(WIN32) + #include + #include + #define GetNetworkError() ::WSAGetLastError() +#endif + +#if defined(__linux__) + #include + #include +#elif defined(__APPLE__) + #include + #include +#endif + +nonstd::expected IO::Networking::SocketConnector::ConnectBlocking(IO::Networking::IpEndpoint const& target, std::chrono::milliseconds timeoutMs) +{ + IO::Native::SocketHandle nativeSocket = ::socket(AF_INET, SOCK_STREAM, IPPROTO_TCP); + if (nativeSocket == -1) + { + return nonstd::make_unexpected(IO::NetworkError(NetworkError::ErrorType::InternalError, GetNetworkError())); + } + + ::sockaddr_in targetAddress{}; + targetAddress.sin_family = AF_INET; + targetAddress.sin_port = htons(target.port); + IO::Networking::Internal::inet_pton(target.ip, &(targetAddress.sin_addr)); + + // Set socket to non-blocking mode, so the `::connect` does not block + +#if defined(__linux__) || defined(__APPLE__) + IO::NetworkError err = IO::Utils::SetFdStatusFlag(nativeSocket, O_NONBLOCK); + if (err) + return nonstd::make_unexpected(err); +#elif defined(WIN32) + u_long mode = 1; + int ioCtlStatus = ::ioctlsocket(nativeSocket, FIONBIO, &mode); + if (ioCtlStatus != 0) + return nonstd::make_unexpected(IO::NetworkError(NetworkError::ErrorType::InternalError, GetNetworkError())); +#endif + + int result = ::connect(nativeSocket, (struct sockaddr *)&targetAddress, sizeof(targetAddress)); + if (result == -1) + { + int lastError = GetNetworkError(); + if (lastError != +#if defined(__linux__) || defined(__APPLE__) + EINPROGRESS +#elif defined(WIN32) + WSAEWOULDBLOCK +#endif + ) + { // Oh, this is an actual error :( + IO::Networking::Internal::CloseSocket(nativeSocket); + return nonstd::make_unexpected(IO::NetworkError(NetworkError::ErrorType::InternalError, lastError)); + } + } + + ::timeval tv{}; + tv.tv_sec = (long)(timeoutMs.count() / 1000L); + tv.tv_usec = (long)((timeoutMs.count() % 1000L) * 1000L); + + ::fd_set selectFileDescriptors; + FD_ZERO(&selectFileDescriptors); + FD_SET(nativeSocket, &selectFileDescriptors); + + // wait for some kind response + int selectStatus = ::select((int)(nativeSocket + 1), &selectFileDescriptors, &selectFileDescriptors, &selectFileDescriptors, &tv); + if (selectStatus == -1) + { // ::select internal error + int lastError = GetNetworkError(); + IO::Networking::Internal::CloseSocket(nativeSocket); + return nonstd::make_unexpected(IO::NetworkError(NetworkError::ErrorType::InternalError, lastError)); + } + + if (selectStatus == 0) + { // timeout + IO::Networking::Internal::CloseSocket(nativeSocket); + return nonstd::make_unexpected(IO::NetworkError(NetworkError::ErrorType::Timeout, 0)); + } + + int socketError = 0; + socklen_t socketErrorLength = sizeof(socketError); + int getSocketOptStatus = ::getsockopt(nativeSocket, SOL_SOCKET, SO_ERROR, (char*)&socketError, &socketErrorLength); + if (getSocketOptStatus != 0) + { + IO::Networking::Internal::CloseSocket(nativeSocket); + return nonstd::make_unexpected(IO::NetworkError(NetworkError::ErrorType::InternalError, GetNetworkError())); + } + + if (socketError != 0) + { + IO::Networking::Internal::CloseSocket(nativeSocket); + return nonstd::make_unexpected(IO::NetworkError(NetworkError::ErrorType::InternalError, socketError)); + } + + return IO::Networking::SocketDescriptor(nativeSocket, target); +} diff --git a/src/shared/IO/Networking/SocketConnector.h b/src/shared/IO/Networking/SocketConnector.h new file mode 100644 index 00000000000..fba0baf0064 --- /dev/null +++ b/src/shared/IO/Networking/SocketConnector.h @@ -0,0 +1,67 @@ +#ifndef MANGOS_IO_NETWORKING_ASYNCSOCKETCONNECTOR_H +#define MANGOS_IO_NETWORKING_ASYNCSOCKETCONNECTOR_H + +#include +#include "SocketDescriptor.h" +#include "NetworkError.h" + +#include "nonstd/expected.hpp" + +namespace IO { namespace Networking +{ + /** Helper class to create a SocketDescriptor which connects to another server + \example + + // Create IpEndpoint + auto maybeIp = IO::Networking::IpAddress::TryParseFromString("127.0.0.1"); + MANGOS_ASSERT(maybeIp.has_value()); + IO::Networking::IpAddress ip = maybeIp.value(); + uint16 port = 8080; + IO::Networking::IpEndpoint endpoint(ip, port); + + // Try to connect + auto maybeSocketDescriptor = IO::Networking::SocketConnector::ConnectBlocking(endpoint, std::chrono::seconds(10)); + MANGOS_ASSERT(maybeSocketDescriptor.has_value()); + + // Bind socketDescriptor to AsyncSocket and initialize + auto socket = std::make_shared(ctx, std::move(maybeSocketDescriptor.value())); + MANGOS_ASSERT(!(socket->InitializeAndFixateMemoryLocation())); + + // Send example request + std::string requestString = "Hello World!!!"; + std::vector request(requestString.begin(), requestString.end()); + socket->Write(std::move(request), [socket](IO::NetworkError const& error) + { + MANGOS_ASSERT(!error); + + // Receive the response + auto response = std::make_shared>(); + response->resize(1024); + socket->ReadSome(response->data(), response->size(), [socket, response](IO::NetworkError const& error, size_t actuallyRead) + { + MANGOS_ASSERT(!error); + + std::string responseString(response->data(), actuallyRead); + std::cout << responseString << std::endl; + socket->CloseSocket(); + }); + }); + */ + class SocketConnector + { + public: + SocketConnector() = delete; + + template + static nonstd::expected ConnectBlocking(IO::Networking::IpEndpoint const& target, std::chrono::duration timeout) + { + return ConnectBlocking(target, std::chrono::duration_cast(timeout)); + } + + /// Creates a socket and connects it to the target endpoint. + /// Check for errors in the return value. + static nonstd::expected ConnectBlocking(IO::Networking::IpEndpoint const& target, std::chrono::milliseconds timeoutMs); + }; +}} // namespace IO::Networking + +#endif // MANGOS_IO_NETWORKING_ASYNCSOCKETCONNECTOR_H diff --git a/src/shared/IO/Networking/SocketDescriptor.cpp b/src/shared/IO/Networking/SocketDescriptor.cpp new file mode 100644 index 00000000000..9643813a6a7 --- /dev/null +++ b/src/shared/IO/Networking/SocketDescriptor.cpp @@ -0,0 +1,21 @@ +#include "SocketDescriptor.h" +#include "IO/Networking/Internal.h" + +#include "Errors.h" + +IO::Networking::SocketDescriptor::SocketDescriptor(IO::Native::SocketHandle nativeSocket, IO::Networking::IpEndpoint remoteEndpoint) + : m_nativeSocket(nativeSocket), m_remoteEndpoint(remoteEndpoint), m_isClosed(false) +{ +} + +IO::Networking::SocketDescriptor::~SocketDescriptor() +{ + MANGOS_ASSERT(m_isClosed); +} + +void IO::Networking::SocketDescriptor::CloseSocket() +{ + MANGOS_ASSERT(!m_isClosed); + m_isClosed = true; + IO::Networking::Internal::CloseSocket(m_nativeSocket); +} diff --git a/src/shared/IO/Networking/SocketDescriptor.h b/src/shared/IO/Networking/SocketDescriptor.h new file mode 100644 index 00000000000..11834deff67 --- /dev/null +++ b/src/shared/IO/Networking/SocketDescriptor.h @@ -0,0 +1,37 @@ +#ifndef MANGOS_IO_NETWORKING_SOCKETDESCRIPTOR_H +#define MANGOS_IO_NETWORKING_SOCKETDESCRIPTOR_H + +#include "IO/Networking/IpAddress.h" +#include "IO/NativeAliases.h" + +#include "Policies/ObjectConstructorTraits.h" + +#include + +namespace IO { namespace Networking { + +/// Is the owner of a native socket. Cannot be detached from it. +/// The socket must be closed before this is destructed. +class SocketDescriptor final : public MaNGOS::Policies::NoCopyButAllowMove +{ +public: + explicit SocketDescriptor(Native::SocketHandle nativeSocket, IO::Networking::IpEndpoint remoteEndpoint); + SocketDescriptor(SocketDescriptor&& other) noexcept : m_isClosed(other.m_isClosed), m_remoteEndpoint(other.m_remoteEndpoint), m_nativeSocket(other.m_nativeSocket) + { other.m_isClosed = true; } + ~SocketDescriptor(); + + void CloseSocket(); + + bool IsClosed() const { return m_isClosed; } + IO::Native::SocketHandle const& GetNativeSocket() const { return m_nativeSocket; } + IO::Networking::IpEndpoint const& GetRemoteEndpoint() const { return m_remoteEndpoint; } + + private: + bool m_isClosed; + IO::Native::SocketHandle const m_nativeSocket; + IO::Networking::IpEndpoint const m_remoteEndpoint; +}; + +}} // namespace IO::Networking + +#endif //MANGOS_IO_NETWORKING_SOCKETDESCRIPTOR_H diff --git a/src/shared/IO/Networking/Utils.cpp b/src/shared/IO/Networking/Utils.cpp new file mode 100644 index 00000000000..67f4b338730 --- /dev/null +++ b/src/shared/IO/Networking/Utils.cpp @@ -0,0 +1,19 @@ +#include "./Utils.h" +#include "../../Errors.h" + +/// Checks whenever the same (IPv4) Address is in the same subnet as the other one +bool IO::Networking::IsInSameSubnet(IpAddress const& ipAddressInQuestion, IpAddress const& subnetIpAddress, uint8_t subnetMaskInCidrNotation) +{ + if (ipAddressInQuestion.GetType() != IpAddress::Type::IPv4 || subnetIpAddress.GetType() != IpAddress::Type::IPv4) + return false; + + MANGOS_ASSERT(subnetMaskInCidrNotation >= 0 && subnetMaskInCidrNotation <= 32); // CIDR notation means that "255.255.255.0" is actually "24" + + // An IPv4 address is in the same subnet if the first n-bits (subnetMaskBits) are the same + uint32_t binarySubnetMask = 0xFFFFFFFF << (32 - subnetMaskInCidrNotation); + uint32_t inQuestionNet = ipAddressInQuestion._getInternalIPv4ReprAsUint32() & binarySubnetMask; + uint32_t subnetNet = subnetIpAddress._getInternalIPv4ReprAsUint32() & binarySubnetMask; + bool isInSameSubnet = (inQuestionNet == subnetNet); + + return isInSameSubnet; +} diff --git a/src/shared/IO/Networking/Utils.h b/src/shared/IO/Networking/Utils.h new file mode 100644 index 00000000000..9a69f10e615 --- /dev/null +++ b/src/shared/IO/Networking/Utils.h @@ -0,0 +1,13 @@ +#ifndef MANGOS_IO_NETWORKING_UTILS_H +#define MANGOS_IO_NETWORKING_UTILS_H + +#include +#include "./IpAddress.h" + +namespace IO { namespace Networking +{ + /// Checks whenever the same (IPv4) Address is in the same subnet as the other one + bool IsInSameSubnet(IpAddress const& ipAddressInQuestion, IpAddress const& subnetIpAddress, uint8_t subnetMaskInCidrNotation); +}} // namespace IO::Networking + +#endif //MANGOS_IO_NETWORKING_UTILS_H diff --git a/src/shared/IO/README.md b/src/shared/IO/README.md new file mode 100644 index 00000000000..348a75de7e3 --- /dev/null +++ b/src/shared/IO/README.md @@ -0,0 +1,110 @@ +# vMaNGOS custom IO library + +A custom IO library inspired by Boost::ASIO using native system calls. + +There are slightly different backend implementations on each OS: +- Windows `IOCP` +- Linux `epoll` (with the possibility to add `io_uring` support) +- macOS `kqueue` + +## Comparison to Boost +The usage is very similar to Boost. +Most of the time you have to rename functions to achieve the same thing. + +| vMaNGOS | Boost ASIO | +|-----------------------------------------|--------------------------------------| +| `IO::IoContext` | `boost::asio::io_contex` | +| `IO::NetworkError` | `boost::system::error_code` | +| `IO::Networking::SocketDescriptor` | `boost::asio::detail::socket_holder` | +| `IO::Networking::IpAddress` | `boost::asio::ip::address` | +| `IO::Networking::IpEndpoint` | `boost::asio::ip::basic_endpoint` | +| `IO::Networking::AsyncSocketAcceptor` | `boost::asio::ip::tcp::acceptor` | +| `IO::Networking::AsyncSocketConnector` | `boost::asio::connect` (function) | +| `IO::Networking::AsyncSocket` | `boost::asio::ip::tcp::socket` | +| `IO::ReadableBuffer` | `boost::asio::buffer` | +| And probably a lot more | ... | + +### One transfer per direction restriction +In Boost, multiple actions can be queued on the same socket, +but our implementation restricts it to a single transfer per direction. +This enables us to preallocate resources for IOCP and similar mechanisms directly on the socket, +so each Read/Write/SwitchContext operation avoids additional memory allocations +_(except for the callback and minimal OS-level operations)_. + +### Callbacks might not context switch +A context switch will execute a given callback in an IO thread. +Unlike Boost, where system calls always occur in an IO thread context, +we only switch context when explicitly requested (`socket->EnterIoContext(...)`) +or when a system call would otherwise block. +Developers must handle both scenarios +and ensure the socket handle remains valid for the entire transfer duration +until the callback is invoked. + +## Overview of the most vital elements + +### NetworkError +When a function or callback returns `IO::NetworkError` **you always have to check if there is an error present.** +Otherwise, you will run in undefined behavior and hard-crash your application. + +### IoContext +`IO::IoContext` is the main processing part where everything comes together. +Special IO threads created by you should run `ctx->RunUntilShutdown()`. +Multiple threads can run this function at the same time. +Callbacks are invoked in those threads. + +### AsyncSocketAcceptor +`IO::Networking::AsyncSocketAcceptor` can bind to a TCP port and accept incoming connects. + +When a new connection is accepted, a callback is invoked with `SocketDescirptor` as a parameter. +You can `std::move` it into an `AsyncSocket`, to get a fully working socket. + +## AsyncSocket +`IO::Networking::AsyncSocket` manages asynchronous read and write operations on a socket. +(_see one transfer per-direction restriction_) + +Before initiating any transfers, you must call `InitializeAndFixateMemoryLocation()` +to prepare the socket for IO operations. +For all transfers you have to keep the socket alive until the callback is called. + +**Read Operations**: Use `Read()`, `ReadSome()`, or `ReadSkip()` to read data asynchronously. +You must ensure that the buffer remains valid until the callback is invoked, +as the memory is not copied internally. + +**Write Operations**: Use `Write()` to send data asynchronously. +Using `IO::ReadableBuffer` will copy a shared_ptr of the pointer. +You don't need to hold a reference to it. + +**Context Switching**: You can use `EnterIoContext()` to explicitly execute a callback in the IO thread. +This is useful for tasks that require offloading from the main thread (e.g. encryption or packing). + +# Example Call Flow +This is a simplified call graph, showing the general flow of control made by a `ReadAsync` call. +``` + +----------------------------------------------+ + | | + | "User Code" e.g. AuthSocket | + | | + +----------+-----------------------------------+ + | ^ + | | + 1. ReadAsync | | 4. Callback + | | + v | + +-------------------------+-----+ + | | + | AsyncSocket | + | | + +-+-----------------------------+ + | ^ ^ + | | | +2. Read SystemCall | | 3a. * | 3. Notify + | | | + v | | + +----+-+ +--------+---------+ + | OS +------->| IoContext | + | | Queue | Multi-Threaded | + +------+ +------------------+ +3a.* = On Linux/macOS with POSIX, might directly invoke callback + if buffer can be filled instantly. No IO is queued. +``` +_Made with [asciiflow](https://asciiflow.com/)._ diff --git a/src/shared/IO/ReadableBuffer.h b/src/shared/IO/ReadableBuffer.h new file mode 100644 index 00000000000..0e2349691ae --- /dev/null +++ b/src/shared/IO/ReadableBuffer.h @@ -0,0 +1,294 @@ +#ifndef MANGOS_IO_SMARTBUFFER_H +#define MANGOS_IO_SMARTBUFFER_H + +#include +#include "Platform/Define.h" +#include "Policies/ObjectConstructorTraits.h" +#include "ByteBuffer.h" + +namespace IO +{ + /// A "SmartBuffer" which stores a reference to a std::shared_ptr<> and + /// exposes the `size` and `pointer` to the data in an unified interface + /// + /// Since this ReadableBuffer is intended to be used on AsyncSocket, the size and pointer of the buffer is cached. + /// Do not modify the buffer that this is holding. + /// Create a new ReadableBuffer for each transfer. + class ReadableBuffer // replace me with C++17 std::variant + { + public: + ReadableBuffer() : m_ptr(nullptr), m_size(0), m_type(BufferType::Unset) {} + + // Constructors from std::shared_ptr + ReadableBuffer(std::shared_ptr const& source) + : m_ptr(source->contents()), m_size(source->size()), m_type(BufferType::ByteBuffer) + { + new(&m_buffer.ByteBufferRef) std::shared_ptr(source); + } + + ReadableBuffer(std::shared_ptr&& source) + : m_ptr(source->contents()), m_size(source->size()), m_type(BufferType::ByteBuffer) + { + new(&m_buffer.ByteBufferRef) std::shared_ptr(std::move(source)); + } + + ReadableBuffer(std::shared_ptr const> const& source) + : m_ptr(source->data()), m_size(source->size()), m_type(BufferType::VectorU8) + { + new(&m_buffer.VectorU8) std::shared_ptr const>(source); + } + + ReadableBuffer(std::shared_ptr>&& source) + : m_ptr(source->data()), m_size(source->size()), m_type(BufferType::VectorU8) + { + new(&m_buffer.VectorU8) std::shared_ptr>(std::move(source)); + } + + ReadableBuffer(std::shared_ptr const> const& source) + : m_ptr(reinterpret_cast(source->data())), m_size(source->size()), m_type(BufferType::VectorS8) + { + new(&m_buffer.VectorS8) std::shared_ptr const>(source); + } + + ReadableBuffer(std::shared_ptr>&& source) + : m_ptr(reinterpret_cast(source->data())), m_size(source->size()), m_type(BufferType::VectorS8) + { + new(&m_buffer.VectorS8) std::shared_ptr const>(std::move(source)); + } + + ReadableBuffer(std::shared_ptr const> const& source) + : m_ptr(reinterpret_cast(source->data())), m_size(source->size()), m_type(BufferType::VectorN8) + { + new(&m_buffer.VectorN8) std::shared_ptr const>(source); + } + + ReadableBuffer(std::shared_ptr>&& source) + : m_ptr(reinterpret_cast(source->data())), m_size(source->size()), m_type(BufferType::VectorN8) + { + new(&m_buffer.VectorN8) std::shared_ptr const>(std::move(source)); + } + + ReadableBuffer(std::shared_ptr const& source, size_t size) + : m_ptr(reinterpret_cast(source.get())), m_size(size), m_type(BufferType::PtrU8) + { + new(&m_buffer.PtrU8) std::shared_ptr(source); + } + + // Constructor from buffer, std::move, we take ownership + ReadableBuffer(ByteBuffer&& source) : ReadableBuffer(std::move(std::make_shared(std::move(source)))) + { + } + + ReadableBuffer(std::vector&& source) : ReadableBuffer(std::move(std::make_shared>(std::move(source)))) + { + } + + ReadableBuffer(std::vector&& source) : ReadableBuffer(std::move(std::make_shared>(std::move(source)))) + { + } + + ReadableBuffer(std::vector&& source) : ReadableBuffer(std::move(std::make_shared>(std::move(source)))) + { + } + + // nullptr stuff + ReadableBuffer(std::nullptr_t) : m_ptr(nullptr), m_size(0), m_type(BufferType::Unset) {} + ReadableBuffer& operator=(std::nullptr_t) { + m_ptr = nullptr; + m_size = 0; + m_type = BufferType::Unset; + return *this; + } + + void Destruct() + { + switch (m_type) + { + case BufferType::ByteBuffer: + m_buffer.ByteBufferRef.~shared_ptr(); + break; + case BufferType::VectorU8: + m_buffer.VectorU8.~shared_ptr(); + break; + case BufferType::VectorS8: + m_buffer.VectorS8.~shared_ptr(); + break; + case BufferType::VectorN8: + m_buffer.VectorN8.~shared_ptr(); + break; + case BufferType::PtrU8: + m_buffer.PtrU8.~shared_ptr(); + break; + + case BufferType::Unset: + return; // dont set type again + } + m_type = BufferType::Unset; + } + + // Destructor + ~ReadableBuffer() + { + Destruct(); + } + + // copy + ReadableBuffer(ReadableBuffer const& other) + : m_ptr(other.m_ptr), m_size(other.m_size), m_type(other.m_type) + { + switch (m_type) + { + case BufferType::ByteBuffer: + new(&m_buffer.ByteBufferRef) std::shared_ptr(other.m_buffer.ByteBufferRef); + break; + case BufferType::VectorU8: + new(&m_buffer.VectorU8) std::shared_ptr const>(other.m_buffer.VectorU8); + break; + case BufferType::VectorS8: + new(&m_buffer.VectorS8) std::shared_ptr const>(other.m_buffer.VectorS8); + break; + case BufferType::VectorN8: + new(&m_buffer.VectorN8) std::shared_ptr const>(other.m_buffer.VectorN8); + break; + case BufferType::PtrU8: + new(&m_buffer.PtrU8) std::shared_ptr(other.m_buffer.PtrU8); + break; + + case BufferType::Unset: + break; + } + } + + ReadableBuffer& operator=(ReadableBuffer const& other) + { + if (this == &other) + return *this; // Self-assignment check + + m_ptr = other.m_ptr; + m_size = other.m_size; + m_type = other.m_type; + switch (m_type) + { + case BufferType::ByteBuffer: + new(&m_buffer.ByteBufferRef) std::shared_ptr(other.m_buffer.ByteBufferRef); + break; + case BufferType::VectorU8: + new(&m_buffer.VectorU8) std::shared_ptr const>(other.m_buffer.VectorU8); + break; + case BufferType::VectorS8: + new(&m_buffer.VectorS8) std::shared_ptr const>(other.m_buffer.VectorS8); + break; + case BufferType::VectorN8: + new(&m_buffer.VectorN8) std::shared_ptr const>(other.m_buffer.VectorN8); + break; + case BufferType::PtrU8: + new(&m_buffer.PtrU8) std::shared_ptr(other.m_buffer.PtrU8); + break; + + case BufferType::Unset: + break; + } + return *this; + } + + // move + ReadableBuffer(ReadableBuffer&& other) noexcept + : m_ptr(other.m_ptr), m_size(other.m_size), m_type(other.m_type) + { + switch (m_type) + { + case BufferType::ByteBuffer: + new(&m_buffer.ByteBufferRef) std::shared_ptr(std::move(other.m_buffer.ByteBufferRef)); + break; + case BufferType::VectorU8: + new(&m_buffer.VectorU8) std::shared_ptr const>(std::move(other.m_buffer.VectorU8)); + break; + case BufferType::VectorS8: + new(&m_buffer.VectorS8) std::shared_ptr const>(std::move(other.m_buffer.VectorS8)); + break; + case BufferType::VectorN8: + new(&m_buffer.VectorN8) std::shared_ptr const>(std::move(other.m_buffer.VectorN8)); + break; + case BufferType::PtrU8: + new(&m_buffer.PtrU8) std::shared_ptr(std::move(other.m_buffer.PtrU8)); + break; + + case BufferType::Unset: + break; + } + } + + ReadableBuffer& operator=(ReadableBuffer&& other) noexcept + { + m_ptr = other.m_ptr; + m_size = other.m_size; + m_type = other.m_type; + switch (m_type) + { + case BufferType::ByteBuffer: + new(&m_buffer.ByteBufferRef) std::shared_ptr(std::move(other.m_buffer.ByteBufferRef)); + break; + case BufferType::VectorU8: + new(&m_buffer.VectorU8) std::shared_ptr const>(std::move(other.m_buffer.VectorU8)); + break; + case BufferType::VectorS8: + new(&m_buffer.VectorS8) std::shared_ptr const>(std::move(other.m_buffer.VectorS8)); + break; + case BufferType::VectorN8: + new(&m_buffer.VectorN8) std::shared_ptr const>(std::move(other.m_buffer.VectorN8)); + break; + case BufferType::PtrU8: + new(&m_buffer.PtrU8) std::shared_ptr(std::move(other.m_buffer.PtrU8)); + break; + + case BufferType::Unset: + break; + } + return *this; + } + + size_t GetSize() const + { + return m_size; + } + uint8 const* GetPtr() const + { + return m_ptr; + } + + private: + enum class BufferType + { + Unset, + ByteBuffer, + VectorU8, // std::vector + VectorS8, // std::vector + VectorN8, // std::vector (not specified) + PtrU8, // raw pointer + }; + + union BufferUnion + { + std::shared_ptr ByteBufferRef; + std::shared_ptr const> VectorU8; + std::shared_ptr const> VectorS8; + std::shared_ptr const> VectorN8; + std::shared_ptr PtrU8; + + // we allocate and deallocate manually + BufferUnion() + { + } + ~BufferUnion() + { + } + }; + + uint8 const* m_ptr; + size_t m_size; + BufferType m_type; + BufferUnion m_buffer; + }; +} + +#endif // MANGOS_IO_SMARTBUFFER_H diff --git a/src/shared/IO/SystemErrorToString.cpp b/src/shared/IO/SystemErrorToString.cpp new file mode 100644 index 00000000000..adcfa317438 --- /dev/null +++ b/src/shared/IO/SystemErrorToString.cpp @@ -0,0 +1,32 @@ +#include "./SystemErrorToString.h" + +#if defined(WIN32) +#include +#elif defined(__linux__) +#include +#endif + +constexpr int MAX_ERROR_TEXT_LENGTH = 255; +thread_local char g_threadLocalStorage[MAX_ERROR_TEXT_LENGTH]; + +// The buffer is thread_local, don't free it +char const* SystemErrorToCString(int nativeSystemErrorCode) { +#if defined(WIN32) + if (!FormatMessageA(FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_MAX_WIDTH_MASK, nullptr, nativeSystemErrorCode, MAKELANGID(LANG_NEUTRAL, SUBLANG_NEUTRAL), g_threadLocalStorage, MAX_ERROR_TEXT_LENGTH, nullptr)) + return ""; + return g_threadLocalStorage; +#elif defined(__linux__) + // Linux might not need our buffer in all cases, sometimes it has a pointer to the text already + return ::strerror_r(nativeSystemErrorCode, g_threadLocalStorage, sizeof(g_threadLocalStorage)); +#elif defined(__APPLE__) + if (::strerror_r(nativeSystemErrorCode, g_threadLocalStorage, sizeof(g_threadLocalStorage)) != 0) + return ""; + return g_threadLocalStorage; +#else +#error "IO::SystemErrorToCString(...) not supported on your platform" +#endif +} + +std::string SystemErrorToString(int nativeSystemErrorCode) { + return '(' + std::to_string(nativeSystemErrorCode) + ") " + SystemErrorToCString(nativeSystemErrorCode); +} diff --git a/src/shared/IO/SystemErrorToString.h b/src/shared/IO/SystemErrorToString.h new file mode 100644 index 00000000000..4b90153c767 --- /dev/null +++ b/src/shared/IO/SystemErrorToString.h @@ -0,0 +1,9 @@ +#ifndef MANGOS_IO_SYSTEMERRORTOSTRING_H +#define MANGOS_IO_SYSTEMERRORTOSTRING_H + +#include + +/** Returns the status code and text description of a system error */ +std::string SystemErrorToString(int nativeSystemErrorCode); + +#endif //MANGOS_IO_SYSTEMERRORTOSTRING_H diff --git a/src/shared/IO/Timer/AsyncSystemTimer.h b/src/shared/IO/Timer/AsyncSystemTimer.h new file mode 100644 index 00000000000..7ff5b8dcc8b --- /dev/null +++ b/src/shared/IO/Timer/AsyncSystemTimer.h @@ -0,0 +1,71 @@ +#ifndef MANGOS_IO_TIMER_ASYNCSYSTEMTIMER_H +#define MANGOS_IO_TIMER_ASYNCSYSTEMTIMER_H + +#include "Common.h" +#include "Log.h" +#include "Policies/Singleton.h" +#include "./TimerHandle.h" + +#if defined(__linux__) || defined(__APPLE__) +#include +#include +#elif defined(WIN32) +#define WIN32_LEAN_AND_MEAN +#include +#undef WIN32_LEAN_AND_MEAN +#endif + +namespace IO { namespace Timer { + class AsyncSystemTimer : public MaNGOS::Singleton> { + friend IO::Timer::TimerHandle; + public: + explicit AsyncSystemTimer(); + ~AsyncSystemTimer() = default; + AsyncSystemTimer(AsyncSystemTimer const&) = delete; + AsyncSystemTimer& operator=(AsyncSystemTimer const&) = delete; + AsyncSystemTimer(AsyncSystemTimer&&) = delete; + AsyncSystemTimer& operator=(AsyncSystemTimer&&) = delete; + + void RemoveAllTimersAndStopThread(); + + /// Low resolution async clock system clock with ~16ms accuracy. + /// Do not use this function for in-game-logic inside mangosd! + /// Use `player->m_Events.AddEvent` instead. + /// Please lock the necessary resources inside this function + template + std::shared_ptr ScheduleFunctionOnce(std::chrono::duration timeFromNow, std::function const& function) + { + uint64_t milliseconds = std::chrono::duration_cast(timeFromNow).count(); + return this->_ScheduleFunctionOnceMs(milliseconds, function); + } + + private: + std::shared_ptr _ScheduleFunctionOnceMs(uint64_t milliseconds, std::function const& function); + +#if defined(WIN32) + static void CALLBACK _timerQueueTimeoutCallback(PVOID opaquePointer, BOOLEAN _thisVariableIsNotUsedInTimers); + std::mutex m_pendingTimers_mutex; + std::unordered_set> m_pendingTimers; + HANDLE m_nativeTimerQueueHandle; +#elif defined(__linux__) || defined(__APPLE__) + struct InternalTimerEntry { + std::chrono::time_point m_whenToTriggerMe; + std::shared_ptr m_timerHandle; + }; + + void _TimerThreadFunc(); + void _DeleteTimer(TimerHandle* timerHandle); + + std::mutex m_orderedPendingTimer_mutex; + std::deque m_orderedPendingTimer; + + std::condition_variable m_sleepSemaphore; // used to wake up the thread, if something changed at the front() of the timer queue + volatile bool m_threadRunning = true; + std::thread m_thread; +#endif + }; +}} // namespace IO::Timer + +#define sAsyncSystemTimer MaNGOS::Singleton::Instance() + +#endif //MANGOS_IO_TIMER_ASYNCSYSTEMTIMER_H diff --git a/src/shared/IO/Timer/TimerHandle.h b/src/shared/IO/Timer/TimerHandle.h new file mode 100644 index 00000000000..7607f0ab1d5 --- /dev/null +++ b/src/shared/IO/Timer/TimerHandle.h @@ -0,0 +1,27 @@ +#ifndef MANGOS_IO_TIMER_TIMERHANDLE_H +#define MANGOS_IO_TIMER_TIMERHANDLE_H + +#include +#include +#include + +namespace IO { namespace Timer { + class AsyncSystemTimer; + + class TimerHandle : public std::enable_shared_from_this + { + friend IO::Timer::AsyncSystemTimer; + public: + void Cancel(); + private: + explicit TimerHandle(IO::Timer::AsyncSystemTimer* systemTimer, std::function callbackFunction); + + IO::Timer::AsyncSystemTimer* m_asyncSystemTimer = nullptr; + std::function m_callback = nullptr; +#if defined(WIN32) + void* m_nativeTimerHandle = nullptr; +#endif + }; +}} // namespace IO::Timer + +#endif //MANGOS_IO_TIMER_TIMERHANDLE_H diff --git a/src/shared/IO/Timer/impl/unix/AsyncSystemTimer.cpp b/src/shared/IO/Timer/impl/unix/AsyncSystemTimer.cpp new file mode 100644 index 00000000000..d6def3d3486 --- /dev/null +++ b/src/shared/IO/Timer/impl/unix/AsyncSystemTimer.cpp @@ -0,0 +1,97 @@ +#include "../../AsyncSystemTimer.h" +#include "Policies/SingletonImp.h" +#include "IO/Multithreading/CreateThread.h" +#include "Log.h" +#include "IO/SystemErrorToString.h" + +INSTANTIATE_SINGLETON_1(IO::Timer::AsyncSystemTimer); + +IO::Timer::AsyncSystemTimer::AsyncSystemTimer() +{ + m_thread = IO::Multithreading::CreateThread("SystemTimer", [this](){ _TimerThreadFunc(); }); +} + +void IO::Timer::AsyncSystemTimer::RemoveAllTimersAndStopThread() +{ + m_threadRunning = false; + m_sleepSemaphore.notify_all(); + m_thread.join(); +} + +std::shared_ptr IO::Timer::AsyncSystemTimer::_ScheduleFunctionOnceMs(uint64_t milliseconds, std::function const& function) +{ + std::chrono::time_point whenToTriggerMe = std::chrono::system_clock::now() + std::chrono::milliseconds(milliseconds); + std::shared_ptr timerHandle(new IO::Timer::TimerHandle(this, function)); + InternalTimerEntry newEntry { whenToTriggerMe, timerHandle }; + + { + std::lock_guard lock(m_orderedPendingTimer_mutex); + + bool isNewFirstEntry; + if (m_orderedPendingTimer.empty() || (m_orderedPendingTimer.end()->m_whenToTriggerMe < newEntry.m_whenToTriggerMe)) + { // we can just append the new timer to the end + isNewFirstEntry = m_orderedPendingTimer.empty(); + m_orderedPendingTimer.emplace_back(newEntry); + } + else + { // we need to search, where we can insert it + // TODO: binary search it + for (auto it = m_orderedPendingTimer.begin(); it != m_orderedPendingTimer.end(); ++it) + { + if (it->m_whenToTriggerMe > newEntry.m_whenToTriggerMe) + { + auto insertLocation = m_orderedPendingTimer.emplace(it, newEntry); + isNewFirstEntry = insertLocation == m_orderedPendingTimer.begin(); + break; + } + } + } + + if (isNewFirstEntry) + m_sleepSemaphore.notify_all(); // wake the timer thread up + } + + + return newEntry.m_timerHandle; +} + +void IO::Timer::AsyncSystemTimer::_TimerThreadFunc() +{ + std::unique_lock lock(m_orderedPendingTimer_mutex, std::defer_lock); + + while (m_threadRunning) + { + auto now = std::chrono::system_clock::now(); + lock.lock(); + + std::chrono::time_point sleepUntil = m_orderedPendingTimer.empty() + ? std::chrono::time_point::max() // INFINITE SLEEP + : m_orderedPendingTimer.begin()->m_whenToTriggerMe; + + if (sleepUntil <= now) + { // we have something to process RIGHT NOW + auto timerElement = m_orderedPendingTimer.front(); + m_orderedPendingTimer.pop_front(); + lock.unlock(); + timerElement.m_timerHandle->m_callback(); + } + else + { + m_sleepSemaphore.wait_until(lock, sleepUntil); + lock.unlock(); + } + } +} + +void IO::Timer::AsyncSystemTimer::_DeleteTimer(IO::Timer::TimerHandle* timerHandle) +{ + std::lock_guard lock(m_orderedPendingTimer_mutex); + for (auto it = m_orderedPendingTimer.begin(); it != m_orderedPendingTimer.end(); ++it) + { + if (it->m_timerHandle.get() == timerHandle) + { + m_orderedPendingTimer.erase(it); + break; + } + } +} diff --git a/src/shared/IO/Timer/impl/unix/TimerHandle.cpp b/src/shared/IO/Timer/impl/unix/TimerHandle.cpp new file mode 100644 index 00000000000..b008e7e7612 --- /dev/null +++ b/src/shared/IO/Timer/impl/unix/TimerHandle.cpp @@ -0,0 +1,11 @@ +#include "../../AsyncSystemTimer.h" + +IO::Timer::TimerHandle::TimerHandle(IO::Timer::AsyncSystemTimer *systemTimer, std::function callbackFunction) + : m_asyncSystemTimer{systemTimer}, m_callback{std::move(callbackFunction)} +{ +} + +void IO::Timer::TimerHandle::Cancel() +{ + m_asyncSystemTimer->_DeleteTimer(this); +} diff --git a/src/shared/IO/Timer/impl/windows/AsyncSystemTimer.cpp b/src/shared/IO/Timer/impl/windows/AsyncSystemTimer.cpp new file mode 100644 index 00000000000..0e404ba53b5 --- /dev/null +++ b/src/shared/IO/Timer/impl/windows/AsyncSystemTimer.cpp @@ -0,0 +1,92 @@ +#define WIN32_LEAN_AND_MEAN +#include +#undef WIN32_LEAN_AND_MEAN + +#include "../../AsyncSystemTimer.h" +#include "IO/Multithreading/CreateThread.h" +#include "Log.h" +#include "Policies/SingletonImp.h" +#include "Errors.h" + +INSTANTIATE_SINGLETON_1(IO::Timer::AsyncSystemTimer); + +IO::Timer::AsyncSystemTimer::AsyncSystemTimer() +{ + m_nativeTimerQueueHandle = ::CreateTimerQueue(); + if (!m_nativeTimerQueueHandle) + sLog.Out(LOG_BASIC, LOG_LVL_ERROR, "::CreateTimerQueue() failed: %d", GetLastError()); + MANGOS_ASSERT(m_nativeTimerQueueHandle); + + ScheduleFunctionOnce(std::chrono::seconds(0), []() { + // Since we are single threaded, we can rename this Thread, so we know what's up + IO::Multithreading::RenameCurrentThread("SystemTimer"); + }); +} + +void IO::Timer::AsyncSystemTimer::RemoveAllTimersAndStopThread() +{ + HANDLE timerQueueHandle = m_nativeTimerQueueHandle; + m_nativeTimerQueueHandle = nullptr; + if (timerQueueHandle) + { + ::DeleteTimerQueueEx( + timerQueueHandle, + INVALID_HANDLE_VALUE // MSDN: If this parameter is INVALID_HANDLE_VALUE, the function waits for all callback functions to complete before returning. + ); + } + + m_pendingTimers_mutex.lock(); + m_pendingTimers.clear(); + m_pendingTimers_mutex.unlock(); +} + +void CALLBACK IO::Timer::AsyncSystemTimer::_timerQueueTimeoutCallback(PVOID opaquePointer, BOOLEAN _thisVariableIsNotUsedInTimers) +{ + (void)_thisVariableIsNotUsedInTimers; + + auto handleRawSharedPtr = (std::shared_ptr*)opaquePointer; + std::shared_ptr timerHandle = *handleRawSharedPtr; + delete handleRawSharedPtr; + + timerHandle->m_asyncSystemTimer->m_pendingTimers_mutex.lock(); + bool wasRemovedByMe = timerHandle->m_asyncSystemTimer->m_pendingTimers.erase(timerHandle); + timerHandle->m_asyncSystemTimer->m_pendingTimers_mutex.unlock(); + if (!wasRemovedByMe) + return; // The timer was already removed, so we don't want to re-execute it again. + + timerHandle->m_callback(); +} + +std::shared_ptr IO::Timer::AsyncSystemTimer::_ScheduleFunctionOnceMs(uint64_t milliseconds, std::function const& function) +{ + MANGOS_ASSERT(this->m_nativeTimerQueueHandle); + + std::shared_ptr timerHandle(new IO::Timer::TimerHandle(this, function)); + timerHandle->m_asyncSystemTimer = this; + timerHandle->m_callback = function; + + // since we are using an opaque pointer model of the kernel here, + // we have to allocate unsafe memory which we will free inside the function + auto handleRawSharedPtr = new std::shared_ptr(timerHandle); + bool wasOkay = ::CreateTimerQueueTimer( + &timerHandle->m_nativeTimerHandle, + m_nativeTimerQueueHandle, + _timerQueueTimeoutCallback, + handleRawSharedPtr, + milliseconds, + 0, // Period = 0: Don't repeat the timer + WT_EXECUTEONLYONCE | WT_EXECUTEINTIMERTHREAD); // Only execute in WT_EXECUTEINTIMERTHREAD (single thread), otherwise we would spam spawn new system threads. + + if (!wasOkay) + { + sLog.Out(LOG_BASIC, LOG_LVL_ERROR, "::CreateTimerQueueTimer failed: %d", GetLastError()); + delete handleRawSharedPtr; + return nullptr; + } + + timerHandle->m_asyncSystemTimer->m_pendingTimers_mutex.lock(); + timerHandle->m_asyncSystemTimer->m_pendingTimers.insert(timerHandle); + timerHandle->m_asyncSystemTimer->m_pendingTimers_mutex.unlock(); + + return timerHandle; +} diff --git a/src/shared/IO/Timer/impl/windows/TimerHandle.cpp b/src/shared/IO/Timer/impl/windows/TimerHandle.cpp new file mode 100644 index 00000000000..4f83b121acf --- /dev/null +++ b/src/shared/IO/Timer/impl/windows/TimerHandle.cpp @@ -0,0 +1,26 @@ +#include "../../TimerHandle.h" +#include "../../AsyncSystemTimer.h" +#include "Log.h" +#include + +IO::Timer::TimerHandle::TimerHandle(IO::Timer::AsyncSystemTimer *systemTimer, std::function callbackFunction) + : m_asyncSystemTimer{systemTimer}, m_callback{std::move(callbackFunction)} +{ +} + +void IO::Timer::TimerHandle::Cancel() +{ + m_asyncSystemTimer->m_pendingTimers_mutex.lock(); + bool wasRemoved = m_asyncSystemTimer->m_pendingTimers.erase(shared_from_this()); + m_asyncSystemTimer->m_pendingTimers_mutex.unlock(); + if (!wasRemoved) + return; // The timer was already removed, so we don't want to re-execute it again. + + // To avoid race conditions: + // MSDN: If this parameter (last one) is INVALID_HANDLE_VALUE, the function waits for any running timer callback functions to complete before returning. + bool wasOkay = ::DeleteTimerQueueTimer(m_asyncSystemTimer->m_nativeTimerQueueHandle, m_nativeTimerHandle, INVALID_HANDLE_VALUE); + if (!wasOkay) + { + sLog.Out(LOG_BASIC, LOG_LVL_ERROR, "::DeleteTimerQueueTimer failed: %d", GetLastError()); + } +} diff --git a/src/shared/IO/Utils.cpp b/src/shared/IO/Utils.cpp new file mode 100644 index 00000000000..99c417cf820 --- /dev/null +++ b/src/shared/IO/Utils.cpp @@ -0,0 +1,16 @@ +#include "Utils.h" + +#if defined(WIN32) +#include +#else +#include +#endif + +uint64_t IO::Utils::GetCurrentProcessId() +{ +#ifdef WIN32 + return ::GetCurrentProcessId(); +#else + return ::getpid(); +#endif +} diff --git a/src/shared/IO/Utils.h b/src/shared/IO/Utils.h new file mode 100644 index 00000000000..594bcf2bd10 --- /dev/null +++ b/src/shared/IO/Utils.h @@ -0,0 +1,10 @@ +#ifndef MANGOS_IO_UTILS_H +#define MANGOS_IO_UTILS_H + +#include + +namespace IO { namespace Utils { + uint64_t GetCurrentProcessId(); +}} // namespace IO::Utils + +#endif // MANGOS_IO_UTILS_H diff --git a/src/shared/IO/Utils_Unix.h b/src/shared/IO/Utils_Unix.h new file mode 100644 index 00000000000..2c2e3601523 --- /dev/null +++ b/src/shared/IO/Utils_Unix.h @@ -0,0 +1,26 @@ +#ifndef MANGOS_IO_NETWORKING_UNIX_LOWLEVELUTIL_H +#define MANGOS_IO_NETWORKING_UNIX_LOWLEVELUTIL_H + +#include +#include "./Networking/NetworkError.h" +#include "./NativeAliases.h" + +namespace IO { namespace Utils +{ + /// Sets a status flag on a handle (for example O_NONBLOCK) + inline IO::NetworkError SetFdStatusFlag(IO::Native::SocketHandle socket, int status) + { + int originalFileStatus = ::fcntl(socket, F_GETFL); + if (originalFileStatus == -1) + return IO::NetworkError(IO::NetworkError::ErrorType::InternalError, errno); + + int newFileStatus = originalFileStatus | status; + int returnVal = ::fcntl(socket, F_SETFL, newFileStatus); + if (returnVal == -1) + return IO::NetworkError(IO::NetworkError::ErrorType::InternalError, errno); + + return IO::NetworkError(IO::NetworkError::ErrorType::NoError); + } +}} // namespace UI::Util + +#endif //MANGOS_IO_NETWORKING_UNIX_LOWLEVELUTIL_H diff --git a/src/shared/Log.cpp b/src/shared/Log.cpp index 8e74560b696..c05d786e0ef 100644 --- a/src/shared/Log.cpp +++ b/src/shared/Log.cpp @@ -24,13 +24,15 @@ #include "Policies/SingletonImp.h" #include "Config/Config.h" #include "Util.h" -#include "ByteBuffer.h" #include "ProgressBar.h" -#include +#include #include +#include -#include "ace/OS_NS_unistd.h" +#if PLATFORM == PLATFORM_WINDOWS +#include +#endif INSTANTIATE_SINGLETON_1(Log); @@ -114,6 +116,7 @@ void Log::OpenWorldLogFiles() logFiles[LOG_GM_CRITICAL] = OpenLogFile("LogFile.CriticalCommands", "gm_critical.log", log_file_timestamp, false); logFiles[LOG_ANTICHEAT] = OpenLogFile("LogFile.Anticheat", "Anticheat.log", log_file_timestamp, false); logFiles[LOG_SCRIPTS] = OpenLogFile("LogFile.Scripts", "Scripts.log", log_file_timestamp, false); + logFiles[LOG_NETWORK] = OpenLogFile("LogFile.Network", "Network.log", log_file_timestamp, false); } void Log::InitSmartlogEntries(std::string const& str) @@ -377,7 +380,8 @@ if (logType != LOG_PERFORMANCE && logType != LOG_DBERRFIX && m_consoleLevel >= l void Log::Out(LogType logType, LogLevel logLevel, char const* format, ...) { - ASSERT(logType >= 0 && logType < LOG_TYPE_MAX&& logLevel >= 0 && logLevel <= LOG_LVL_DEBUG); + if (!(logType >= 0 && logType < LOG_TYPE_MAX&& logLevel >= 0 && logLevel <= LOG_LVL_DEBUG)) + return; if (!format) return; @@ -452,7 +456,7 @@ bool Log::IsSmartLog(uint32 entry, uint32 guid) const void Log::WaitBeforeContinueIfNeed() { - int mode = sConfig.GetIntDefault("WaitAtStartupError", 0); + int mode = sConfig.GetIntDefault("WaitAtStartupError", 5); if (mode < 0) { @@ -468,7 +472,7 @@ void Log::WaitBeforeContinueIfNeed() for (int i = 0; i < mode; ++i) { bar.step(); - ACE_OS::sleep(1); + std::this_thread::sleep_for(std::chrono::seconds(1)); } } } diff --git a/src/shared/Log.h b/src/shared/Log.h index 098f14a85fa..834e3a0157f 100644 --- a/src/shared/Log.h +++ b/src/shared/Log.h @@ -143,6 +143,7 @@ enum LogType LOG_GM_CRITICAL, LOG_ANTICHEAT, LOG_SCRIPTS, + LOG_NETWORK, LOG_TYPE_MAX }; diff --git a/src/shared/Memory/ArrayDeleter.h b/src/shared/Memory/ArrayDeleter.h new file mode 100644 index 00000000000..f67e0a030cd --- /dev/null +++ b/src/shared/Memory/ArrayDeleter.h @@ -0,0 +1,19 @@ +#ifndef MANGOS_ARRAY_DELETER_H +#define MANGOS_ARRAY_DELETER_H + +namespace MaNGOS { namespace Memory +{ + /// A array deleter implementation that can be used in std::shared_ptr + /// In C++14 it is not possible to make a shared_ptr out of an array + /// \example std::shared_ptr mySharedPtr = std::shared_ptr(new uint8_t[1024], MaNGOS::Memory::array_deleter()); + template + struct array_deleter + { + void operator()(T const* p) + { + delete[] p; + } + }; +}} // namespace MaNGOS::Memory + +#endif // MANGOS_ARRAY_DELETER_H diff --git a/src/shared/Memory/NoDeleter.h b/src/shared/Memory/NoDeleter.h new file mode 100644 index 00000000000..18e336fe62b --- /dev/null +++ b/src/shared/Memory/NoDeleter.h @@ -0,0 +1,19 @@ +#ifndef MANGOS_NO_DELETER_H +#define MANGOS_NO_DELETER_H + +namespace MaNGOS { namespace Memory +{ + /// A non deleter implementation that can be used in std::shared_ptr + /// \warning Using this will result in a memory leak, if not managed otherwise! + /// \example std::shared_ptr mySharedPtr = std::shared_ptr(new uint8_t[1024], MaNGOS::Memory::no_deleter()); + template + struct no_deleter + { + void operator()(T const*) + { + // Ignore. No call to `delete` + } + }; +}} // namespace MaNGOS::Memory + +#endif //MANGOS_NO_DELETER_H diff --git a/src/shared/PosixDaemon.cpp b/src/shared/PosixDaemon.cpp index 33806a981bc..26475f49a1b 100644 --- a/src/shared/PosixDaemon.cpp +++ b/src/shared/PosixDaemon.cpp @@ -21,13 +21,14 @@ #include #include #include +#include +#include pid_t parent_pid = 0, sid = 0; void daemonSignal(int s) { - - if (getpid() != parent_pid) + if (::getpid() != parent_pid) { return; } @@ -47,7 +48,7 @@ void daemonSignal(int s) void startDaemon(uint32_t timeout) { - parent_pid = getpid(); + parent_pid = ::getpid(); pid_t pid; signal(SIGUSR1, daemonSignal); @@ -69,7 +70,7 @@ void startDaemon(uint32_t timeout) exit(EXIT_FAILURE); } - umask(0); + ::umask(0); sid = setsid(); diff --git a/src/shared/ProgressBar.cpp b/src/shared/ProgressBar.cpp index 0f557532a59..bab3e08a0e2 100644 --- a/src/shared/ProgressBar.cpp +++ b/src/shared/ProgressBar.cpp @@ -20,6 +20,7 @@ */ #include +#include #include "ProgressBar.h" #include "Errors.h" @@ -40,21 +41,21 @@ BarGoLink::BarGoLink(int row_count) BarGoLink::BarGoLink(uint32 row_count) { - MANGOS_ASSERT(row_count < (uint32)ACE_INT32_MAX); - init((int)row_count); + MANGOS_ASSERT(row_count < std::numeric_limits::max()); + init(static_cast(row_count)); } BarGoLink::BarGoLink(uint64 row_count) { - MANGOS_ASSERT(row_count < (uint64)ACE_INT32_MAX); - init((int)row_count); + MANGOS_ASSERT(row_count < std::numeric_limits::max()); + init(static_cast(row_count)); } #ifdef __APPLE__ BarGoLink::BarGoLink(size_t row_count) { - //MANGOS_ASSERT(row_count < (uint64)ACE_INT32_MAX); - init((int)row_count); + MANGOS_ASSERT(row_count < std::numeric_limits::max()); + init(static_cast(row_count)); } #endif diff --git a/src/shared/ProgressBar.h b/src/shared/ProgressBar.h index 2472d214325..787ab69cbd4 100644 --- a/src/shared/ProgressBar.h +++ b/src/shared/ProgressBar.h @@ -24,17 +24,14 @@ #include "Platform/Define.h" -// Nostalrius : pour SD0. -#define barGoLink BarGoLink - class BarGoLink { public: // constructors explicit BarGoLink(int row_count); - explicit BarGoLink(uint32 row_count); // row_count < ACE_INT32_MAX - explicit BarGoLink(uint64 row_count); // row_count < ACE_INT32_MAX + explicit BarGoLink(uint32 row_count); // row_count < int32::max + explicit BarGoLink(uint64 row_count); // row_count < int32::max #ifdef __APPLE__ - explicit BarGoLink(size_t row_count); + explicit BarGoLink(size_t row_count); // row_count < int32::max #endif ~BarGoLink(); diff --git a/src/shared/ProxyProtocol/ProxyV2Reader.cpp b/src/shared/ProxyProtocol/ProxyV2Reader.cpp new file mode 100644 index 00000000000..5cde76ad08f --- /dev/null +++ b/src/shared/ProxyProtocol/ProxyV2Reader.cpp @@ -0,0 +1,120 @@ +#include "ProxyV2Reader.h" +#include "Log.h" + +#include + +#if defined(__linux__) || defined(__APPLE__) +#include // for ntohs +#endif + +// GCC have alternative #pragma pack(N) syntax and old gcc version not support pack(push,N), also any gcc version not support it at some platform +#if defined( __GNUC__ ) +#pragma pack(1) +#else +#pragma pack(push,1) +#endif + +struct proxy_hdr_v2 { + uint8_t sig[12]; /* hex 0D 0A 0D 0A 00 0D 0A 51 55 49 54 0A */ + uint8_t ver_cmd; /* protocol version and command */ + uint8_t fam; /* protocol family and address */ + uint16_t len; /* number of following bytes part of the header */ +}; + +union proxy_addr { + struct { /* for TCP/UDP over IPv4, len = 12 */ + uint32_t src_addr; + uint32_t dst_addr; + uint16_t src_port; + uint16_t dst_port; + } ipv4_addr; + struct { /* for TCP/UDP over IPv6, len = 36 */ + uint8_t src_addr[16]; + uint8_t dst_addr[16]; + uint16_t src_port; + uint16_t dst_port; + } ipv6_addr; + struct { /* for AF_UNIX sockets, len = 216 */ + uint8_t src_addr[108]; + uint8_t dst_addr[108]; + } unix_addr; +}; + +// GCC have alternative #pragma pack() syntax and old gcc version not support pack(pop), also any gcc version not support it at some platform +#if defined( __GNUC__ ) +#pragma pack() +#else +#pragma pack(pop) +#endif + + +// Read protocol doc at https://www.haproxy.org/download/1.8/doc/proxy-protocol.txt (2.2. Binary header format (version 2)). +void ProxyProtocol::ReadProxyV2Handshake(IO::Networking::AsyncSocket* socket, std::function const&)> const& callback) +{ + std::shared_ptr proxyHeader(new proxy_hdr_v2()); + socket->Read((char*)(proxyHeader.get()), sizeof(proxyHeader), [socket, callback, proxyHeader](IO::NetworkError const& error, size_t) + { + if (error) + { + callback(nonstd::make_unexpected(error)); + return; + } + + constexpr uint8_t expectedSignature[12] = { 0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A }; + // Check if we have a valid signature + if (std::memcmp(proxyHeader->sig, expectedSignature, sizeof(expectedSignature)) != 0) + { + sLog.Out(LOG_NETWORK, LOG_LVL_ERROR, "ProxyV2 invalid signature"); + callback(nonstd::make_unexpected(IO::NetworkError(IO::NetworkError::ErrorType::InvalidProtocolBehavior))); + return; + } + + // Check version + if ((proxyHeader->ver_cmd >> 4) != 2) // we only support proxy v2 + { + sLog.Out(LOG_NETWORK, LOG_LVL_ERROR, "ProxyV2 unexpected version"); + callback(nonstd::make_unexpected(IO::NetworkError(IO::NetworkError::ErrorType::InvalidProtocolBehavior))); + return; + } + + // Check command + constexpr uint8_t PROXY_V2_CMD_PROXY = 1; + if ((proxyHeader->ver_cmd & 0x0F) != PROXY_V2_CMD_PROXY) + { + sLog.Out(LOG_NETWORK, LOG_LVL_ERROR, "ProxyV2 unexpected cmd"); + callback(nonstd::make_unexpected(IO::NetworkError(IO::NetworkError::ErrorType::InvalidProtocolBehavior))); + return; + } + + // Check if we have IPv4_TCP which is currently the only one supported by the vanilla client + constexpr uint8_t PROXY_V2_TCP_OVER_IPV4 = 0x11; + if (proxyHeader->fam != PROXY_V2_TCP_OVER_IPV4) + { + sLog.Out(LOG_NETWORK, LOG_LVL_ERROR, "ProxyV2 unexpected familyAndProtocol"); + callback(nonstd::make_unexpected(IO::NetworkError(IO::NetworkError::ErrorType::InvalidProtocolBehavior))); + return; + } + + // Unexpected body size + uint16_t headerLength = ntohs(proxyHeader->len); + if (headerLength != sizeof(proxy_addr::ipv4_addr)) + { + sLog.Out(LOG_NETWORK, LOG_LVL_ERROR, "ProxyV2 unexpected body size"); + callback(nonstd::make_unexpected(IO::NetworkError(IO::NetworkError::ErrorType::InvalidProtocolBehavior))); + return; + } + + // Now we can read the actual payload + std::shared_ptr addressBody(new proxy_addr()); + socket->Read((char*)(addressBody.get()), headerLength, [callback, addressBody](IO::NetworkError const& error, size_t) + { + if (error) + { + callback(nonstd::make_unexpected(error)); + return; + } + IO::Networking::IpAddress ipAddress = IO::Networking::IpAddress::FromIpv4Uint32(ntohl(addressBody->ipv4_addr.src_addr)); + callback(ipAddress); + }); + }); +} diff --git a/src/shared/ProxyProtocol/ProxyV2Reader.h b/src/shared/ProxyProtocol/ProxyV2Reader.h new file mode 100644 index 00000000000..9cde1a19558 --- /dev/null +++ b/src/shared/ProxyProtocol/ProxyV2Reader.h @@ -0,0 +1,17 @@ +#ifndef MANGOS_PROXYPROTOCOL_PROXYV2READER_H +#define MANGOS_PROXYPROTOCOL_PROXYV2READER_H + +#include "IO/Networking/AsyncSocket.h" + +#include "nonstd/expected.hpp" + +namespace ProxyProtocol +{ + /// Allows you to receive the correct client IP via proxy protocol V2. + /// If you are using HaProxy you can enable it with the "send-proxy-v2" at the "backend server". + /// This function must be called before any other Read() call, because its sent by the proxy fist. + /// \warning You have to verify if the socket is from a trusted proxy IP! + void ReadProxyV2Handshake(IO::Networking::AsyncSocket* socket, std::function const&)> const& callback); +} + +#endif //MANGOS_PROXYPROTOCOL_PROXYV2READER_H diff --git a/src/shared/ServiceWin32.cpp b/src/shared/ServiceWin32.cpp index 622aad9ef74..ac8671d42d1 100644 --- a/src/shared/ServiceWin32.cpp +++ b/src/shared/ServiceWin32.cpp @@ -24,6 +24,7 @@ #include "Common.h" #include "Log.h" #include +#include #include #if !defined(WINADVAPI) @@ -34,17 +35,12 @@ #endif #endif - -#ifdef main -#undef main -#endif - extern int main(int argc, char** argv); extern char serviceLongName[]; extern char serviceName[]; extern char serviceDescription[]; -extern int m_ServiceStatus; +extern volatile int m_ServiceStatus; SERVICE_STATUS serviceStatus; @@ -275,7 +271,7 @@ bool WinServiceRun() SERVICE_TABLE_ENTRY serviceTable[] = { { serviceName, ServiceMain }, - { 0, 0 } + { nullptr, nullptr } }; if (!StartServiceCtrlDispatcher(serviceTable)) diff --git a/src/shared/SystemConfig.h b/src/shared/SystemConfig.h index 507c1cedc27..1c24fd4075e 100644 --- a/src/shared/SystemConfig.h +++ b/src/shared/SystemConfig.h @@ -28,13 +28,10 @@ // Format is YYYYMMDDRR where RR is the change in the conf file // for that day. #ifndef _MANGOSDCONFVERSION -# define _MANGOSDCONFVERSION 2010100901 +# define _MANGOSDCONFVERSION 2024091701 #endif #ifndef _REALMDCONFVERSION -# define _REALMDCONFVERSION 2020010501 -#endif -#ifndef _MODSCONFVERSION -# define _MODSCONFVERSION 2010062001 +# define _REALMDCONFVERSION 2024091701 #endif #if MANGOS_ENDIAN == MANGOS_BIGENDIAN @@ -70,7 +67,6 @@ # endif # define _MANGOSD_CONFIG SYSCONFDIR "mangosd.conf" # define _REALMD_CONFIG SYSCONFDIR "realmd.conf" -# define _MODS_CONFIG SYSCONFDIR "mods.conf" #else # if defined (__FreeBSD__) # define _ENDIAN_PLATFORM "FreeBSD_" ARCHITECTURE " (" _ENDIAN_STRING ")" @@ -89,7 +85,6 @@ # endif # define _MANGOSD_CONFIG SYSCONFDIR "mangosd.conf" # define _REALMD_CONFIG SYSCONFDIR "realmd.conf" -# define _MODS_CONFIG SYSCONFDIR "mods.conf" #endif #define _FULLVERSION REVISION_HASH " / " REVISION_DATE " / " _ENDIAN_PLATFORM diff --git a/src/shared/ThreadPool.cpp b/src/shared/ThreadPool.cpp index f039b386091..3cb2c8ae176 100644 --- a/src/shared/ThreadPool.cpp +++ b/src/shared/ThreadPool.cpp @@ -18,6 +18,7 @@ #include "ThreadPool.h" #include "Log.h" +#include "IO/Multithreading/CreateThread.h" #include #ifdef WIN32 @@ -25,8 +26,8 @@ #undef IGNORE #endif -ThreadPool::ThreadPool(int numThreads, ClearMode when, ErrorHandling mode) : - m_errorHandling(mode), m_size(numThreads), m_clearMode(when), m_active(0) +ThreadPool::ThreadPool(std::string const& name, int numThreads, ClearMode when, ErrorHandling mode) : + m_poolName(name), m_errorHandling(mode), m_size(numThreads), m_clearMode(when), m_active(0) { m_workers.reserve(m_size); } @@ -114,7 +115,7 @@ void ThreadPool::clearWorkload() } ThreadPool::worker::worker(ThreadPool *pool, int id, ThreadPool::ErrorHandling mode) : - id(id), errorHandling(mode), pool(pool), thread([this](){this->loop_wrapper();}) + id(id), errorHandling(mode), pool(pool), thread(IO::Multithreading::CreateThread(pool->m_poolName + "[" + std::to_string(id) + "]", [this](){this->loop_wrapper();})) { } diff --git a/src/shared/ThreadPool.h b/src/shared/ThreadPool.h index 261e07a4398..8367e930049 100644 --- a/src/shared/ThreadPool.h +++ b/src/shared/ThreadPool.h @@ -80,11 +80,10 @@ class ThreadPool /** * @brief ThreadPool allocates memory, use ThreadPool::start() to spawn the threads. + * @param name the name of the pool * @param numThreads the number of threads that will be created. */ - ThreadPool(int numThreads, ClearMode when = ClearMode::AT_NEXT_WORKLOAD, ErrorHandling mode = ErrorHandling::NONE); - - ThreadPool() = delete; + explicit ThreadPool(std::string const& name, int numThreads, ClearMode when = ClearMode::AT_NEXT_WORKLOAD, ErrorHandling mode = ErrorHandling::NONE); ~ThreadPool(); @@ -150,7 +149,7 @@ class ThreadPool private: struct worker { - worker(ThreadPool *pool, int id, ErrorHandling mode); + explicit worker(ThreadPool *pool, int id, ErrorHandling mode); virtual ~worker(); void loop_wrapper(); @@ -194,6 +193,7 @@ class ThreadPool using workers_t = std::vector>; Status m_status = Status::STOPPED; + std::string m_poolName; ErrorHandling m_errorHandling; size_t m_size; std::shared_timed_mutex m_mutex; diff --git a/src/shared/ThreadSpecificPtr.cpp b/src/shared/ThreadSpecificPtr.cpp new file mode 100644 index 00000000000..203ede8ce51 --- /dev/null +++ b/src/shared/ThreadSpecificPtr.cpp @@ -0,0 +1,13 @@ +#include "./ThreadSpecificPtr.h" +#include "Errors.h" + +thread_local MaNGOS::ThreadSpecificHolder MaNGOS::gtl_ThreadSpecificPtrHolder; + +MaNGOS::ThreadSpecificHolder::~ThreadSpecificHolder() +{ + for (auto it : thread_specific_ptr_data) + { + // Every reference must be deleted before the holder is deallocated, otherwise memory will leak + MANGOS_ASSERT(it.second == nullptr); + } +} diff --git a/src/shared/ThreadSpecificPtr.h b/src/shared/ThreadSpecificPtr.h new file mode 100644 index 00000000000..5a52159935f --- /dev/null +++ b/src/shared/ThreadSpecificPtr.h @@ -0,0 +1,97 @@ +/* + * This program is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License as published by the + * Free Software Foundation; either version 2 of the License, or (at your + * option) any later version. + * + * This program is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for + * more details. + * + * You should have received a copy of the GNU General Public License along + * with this program. If not, see . + */ + +#ifndef MANGOS_THREAD_SPECIFIC_PTR_H_ +#define MANGOS_THREAD_SPECIFIC_PTR_H_ + +#include "Policies/ObjectConstructorTraits.h" + +#include +#include + +namespace MaNGOS +{ + struct ThreadSpecificHolder + { + ~ThreadSpecificHolder(); + std::map thread_specific_ptr_data; + }; + + extern thread_local ThreadSpecificHolder gtl_ThreadSpecificPtrHolder; + + /// This class is the same as `boost::thread_specific_ptr` or `ACE_TSS` + /// Since you cant use `thread_local` on a non static class member, + /// this class allows you to have thread specific storage on instances bases. + /// Meaning instanceA and instanceB on the same thread, have a different object attached to it. + template + class ThreadSpecificPtr : public MaNGOS::Policies::NoCopyNoMove + { + public: + ThreadSpecificPtr() = default; + + // Important: We cannot `this->reset()` our pointer, because the holder might already be deallocated when ThreadSpecificPtr was used in a global context + // We have to trust that the user cleaned up everything (otherwise an ASSERT() in `ThreadSpecificHolder` will fail) + ~ThreadSpecificPtr() = default; + + T* get() const + { + auto it = gtl_ThreadSpecificPtrHolder.thread_specific_ptr_data.find(const_cast(static_cast(this))); + if (it != gtl_ThreadSpecificPtrHolder.thread_specific_ptr_data.end()) + return static_cast(it->second); + return nullptr; + } + + T* operator->() const + { + return get(); + } + /* + T& operator*() const + { + return *(get()); // I hope its not nullptr + } + */ + + /// Releases the pointer. You have ownership and have to delete it! + T* release() + { + auto it = gtl_ThreadSpecificPtrHolder.thread_specific_ptr_data.find(this); + if (it == gtl_ThreadSpecificPtrHolder.thread_specific_ptr_data.end()) + return nullptr; + + T* ptr = static_cast(it->second); + it->second = nullptr; + return ptr; + } + + /// This function will delete the exising pointer + void reset(T* new_value = nullptr) + { + auto it = gtl_ThreadSpecificPtrHolder.thread_specific_ptr_data.find(this); + if (it != gtl_ThreadSpecificPtrHolder.thread_specific_ptr_data.end()) + { + if (it->second != nullptr) + delete static_cast(it->second); + it->second = new_value; + } + else + { + gtl_ThreadSpecificPtrHolder.thread_specific_ptr_data.insert(it, std::make_pair((void*)this, (void*)new_value)); + } + } + }; +} + +#endif //MANGOS_THREAD_SPECIFIC_PTR_H_ diff --git a/src/shared/Timer.h b/src/shared/Timer.h index f847403b681..fb397de1a7b 100644 --- a/src/shared/Timer.h +++ b/src/shared/Timer.h @@ -23,7 +23,16 @@ #define MANGOS_TIMER_H #include "Common.h" -#include +#include + +inline std::chrono::steady_clock::time_point GetApplicationStartTime() +{ + using namespace std::chrono; + + static const steady_clock::time_point ApplicationStartTime = steady_clock::now(); + + return ApplicationStartTime; +} class WorldTimer { @@ -57,9 +66,6 @@ class WorldTimer WorldTimer(); WorldTimer(WorldTimer const&); - //analogue to getMSTime() but it persists m_SystemTickTime - static uint32 getMSTime_internal(); - static uint32 m_iTime; static uint32 m_iPrevTime; }; diff --git a/src/shared/Util.cpp b/src/shared/Util.cpp index 8eb0c076bd5..6cc53492719 100644 --- a/src/shared/Util.cpp +++ b/src/shared/Util.cpp @@ -22,16 +22,21 @@ #include "Util.h" #include "Timer.h" #include "Log.h" +#include "Errors.h" + +#include "IO/Utils.h" +#include "IO/Networking/IpAddress.h" #include "utf8cpp/utf8.h" #include "mersennetwister/MersenneTwister.h" -#include -#include +#include -typedef ACE_TSS MTRandTSS; -static MTRandTSS mtRand; +#if PLATFORM == PLATFORM_WINDOWS +#include +#endif +thread_local MTRand mtRand; Tokenizer::Tokenizer(std::string const& src, char const sep, uint32 vectorReserve) { @@ -67,8 +72,6 @@ Tokenizer::Tokenizer(std::string const& src, char const sep, uint32 vectorReserv } } -static ACE_Time_Value g_SystemTickTime = ACE_OS::gettimeofday(); - uint32 WorldTimer::m_iTime = 0; uint32 WorldTimer::m_iPrevTime = 0; @@ -81,7 +84,7 @@ uint32 WorldTimer::tick() m_iPrevTime = m_iTime; //get the new one and don't forget to persist current system time in m_SystemTickTime - m_iTime = WorldTimer::getMSTime_internal(); + m_iTime = WorldTimer::getMSTime(); //return tick diff return getMSTimeDiff(m_iPrevTime, m_iTime); @@ -89,63 +92,50 @@ uint32 WorldTimer::tick() uint32 WorldTimer::getMSTime() { - return getMSTime_internal(); -} - -uint32 WorldTimer::getMSTime_internal() -{ - //get current time - ACE_Time_Value const currTime = ACE_OS::gettimeofday(); - //calculate time diff between two world ticks - //special case: curr_time < old_time - we suppose that our time has not ticked at all - //this should be constant value otherwise it is possible that our time can start ticking backwards until next world tick!!! - uint64 diff = 0; - (currTime - g_SystemTickTime).msec(diff); + using namespace std::chrono; - //lets calculate current world time - uint32 iRes = uint32(diff % UI64LIT(0x00000000FFFFFFFF)); - return iRes; + return static_cast(duration_cast(steady_clock::now().time_since_epoch() - GetApplicationStartTime().time_since_epoch()).count()); } ////////////////////////////////////////////////////////////////////////// int32 irand (int32 min, int32 max) { - return int32 (mtRand->randInt (max - min)) + min; + return int32 (mtRand.randInt(max - min)) + min; } uint32 urand (uint32 min, uint32 max) { - return mtRand->randInt (max - min) + min; + return mtRand.randInt(max - min) + min; } float frand (float min, float max) { - return mtRand->randExc (max - min) + min; + return mtRand.randExc (max - min) + min; } int32 rand32 () { - return mtRand->randInt (); + return mtRand.randInt (); } double rand_norm(void) { - return mtRand->randExc (); + return mtRand.randExc (); } float rand_norm_f(void) { - return (float)mtRand->randExc (); + return (float)mtRand.randExc (); } double rand_chance (void) { - return mtRand->randExc (100.0); + return mtRand.randExc (100.0); } float rand_chance_f(void) { - return (float)mtRand->randExc (100.0); + return (float)mtRand.randExc (100.0); } Milliseconds randtime(Milliseconds const& min, Milliseconds const& max) @@ -357,15 +347,14 @@ std::string TimeToTimestampStr(time_t t) return std::string(buf); } -// Check if the string is a valid ip address representation -bool IsIPAddress(char const* ipaddress) +/// Check if the string is a valid ip address representation +bool IsIPAddress(char const* ipAddressString) { - if(!ipaddress) + if (!ipAddressString) return false; - // Let the big boys do it. - // Drawback: all valid ip address formats are recognized e.g.: 12.23,121234,0xABCD) - return ACE_OS::inet_addr(ipaddress) != INADDR_NONE; + auto result = IO::Networking::IpAddress::TryParseFromString(ipAddressString); + return result.has_value(); } // create PID file @@ -375,11 +364,7 @@ uint32 CreatePIDFile(std::string const& filename) if (pid_file == nullptr) return 0; -#ifdef WIN32 - DWORD pid = GetCurrentProcessId(); -#else - pid_t pid = getpid(); -#endif + int pid = IO::Utils::GetCurrentProcessId(); fprintf(pid_file, "%lu", pid); fclose(pid_file); @@ -471,7 +456,7 @@ bool utf8ToConsole(std::string const& utf8str, std::string& conStr) conStr.resize(wstr.size()); CharToOemBuffW(&wstr[0], &conStr[0], wstr.size()); #else - // not implemented yet + // On Linux/MacOS, typically no conversion is needed for UTF-8 strings conStr = utf8str; #endif @@ -487,7 +472,7 @@ bool consoleToUtf8(std::string const& conStr, std::string& utf8str) return WStrToUtf8(wstr, utf8str); #else - // not implemented yet + // On Linux/MacOS, typically no conversion is needed for UTF-8 strings utf8str = conStr; return true; #endif @@ -658,3 +643,20 @@ std::string FlagsToString(uint32 flags, ValueToStringFunc getNameFunc) } return names; } +std::vector SplitStringByDelimiter(std::string const& str, char delimiter) +{ + std::vector vec; + std::size_t old_pos = 0; + std::size_t pos = 0; + while((pos = str.find_first_of(delimiter, old_pos)) != std::string::npos) { + vec.emplace_back(str.substr(old_pos, pos - old_pos)); + old_pos = pos + 1; + } + + // add last element + std::string stringPart = str.substr(old_pos); + if (!stringPart.empty()) + vec.emplace_back(stringPart); + + return vec; +} diff --git a/src/shared/Util.h b/src/shared/Util.h index 4642e882ff4..3859caec210 100644 --- a/src/shared/Util.h +++ b/src/shared/Util.h @@ -417,7 +417,7 @@ bool Utf8FitTo(std::string const& str, std::wstring search); void utf8printf(FILE* out, char const* str, ...); void vutf8printf(FILE* out, char const* str, va_list* ap); -bool IsIPAddress(char const* ipaddress); +bool IsIPAddress(char const* ipAddressString); uint32 CreatePIDFile(std::string const& filename); void hexEncodeByteArray(uint8* bytes, uint32 arrayLen, std::string& result); @@ -452,5 +452,6 @@ inline float InterpolateValueAtIndex(float startIndex, float startValue, float e return startValue + GetLambda(startIndex, endIndex, currentIndex) * (endValue - startValue); } +std::vector SplitStringByDelimiter(std::string const& str, char delimiter); #endif diff --git a/src/shared/WorldPacket.h b/src/shared/WorldPacket.h index 6545aa3afaf..0f30b202e01 100644 --- a/src/shared/WorldPacket.h +++ b/src/shared/WorldPacket.h @@ -40,6 +40,7 @@ class WorldPacket : public ByteBuffer { } + // TODO this std::move() is technically illegal when we want to access packet.m_opcode and m_recvdTime. (maybe make a protected .ctor with just a reference but it std::std move(buf._storage)) WorldPacket(WorldPacket&& packet) : ByteBuffer(std::move(packet)), m_opcode(packet.m_opcode), m_recvdTime(packet.m_recvdTime) { } diff --git a/src/shared/nonstd/expected.hpp b/src/shared/nonstd/expected.hpp new file mode 100644 index 00000000000..235697d0a75 --- /dev/null +++ b/src/shared/nonstd/expected.hpp @@ -0,0 +1,2445 @@ +// SOURCE https://github.com/TartanLlama/expected/blob/292eff8bd8ee230a7df1d6a1c00c4ea0eb2f0362/include/tl/expected.hpp +// +// expected - An implementation of std::expected with extensions +// Written in 2017 by Sy Brand (tartanllama@gmail.com, @TartanLlama) +// +// Documentation available at http://tl.tartanllama.xyz/ +// +// To the extent possible under law, the author(s) have dedicated all +// copyright and related and neighboring rights to this software to the +// public domain worldwide. This software is distributed without any warranty. +// +// You should have received a copy of the CC0 Public Domain Dedication +// along with this software. If not, see +// . +// + +#ifndef TL_EXPECTED_HPP +#define TL_EXPECTED_HPP + +#define TL_EXPECTED_VERSION_MAJOR 1 +#define TL_EXPECTED_VERSION_MINOR 1 +#define TL_EXPECTED_VERSION_PATCH 0 + +#include +#include +#include +#include + +#if defined(__EXCEPTIONS) || defined(_CPPUNWIND) +#define TL_EXPECTED_EXCEPTIONS_ENABLED +#endif + +#if (defined(_MSC_VER) && _MSC_VER == 1900) +#define TL_EXPECTED_MSVC2015 +#define TL_EXPECTED_MSVC2015_CONSTEXPR +#else +#define TL_EXPECTED_MSVC2015_CONSTEXPR constexpr +#endif + +#if (defined(__GNUC__) && __GNUC__ == 4 && __GNUC_MINOR__ <= 9 && \ + !defined(__clang__)) +#define TL_EXPECTED_GCC49 +#endif + +#if (defined(__GNUC__) && __GNUC__ == 5 && __GNUC_MINOR__ <= 4 && \ + !defined(__clang__)) +#define TL_EXPECTED_GCC54 +#endif + +#if (defined(__GNUC__) && __GNUC__ == 5 && __GNUC_MINOR__ <= 5 && \ + !defined(__clang__)) +#define TL_EXPECTED_GCC55 +#endif + +#if !defined(TL_ASSERT) +//can't have assert in constexpr in C++11 and GCC 4.9 has a compiler bug +#if (__cplusplus > 201103L) && !defined(TL_EXPECTED_GCC49) +#include +#define TL_ASSERT(x) assert(x) +#else +#define TL_ASSERT(x) +#endif +#endif + +#if (defined(__GNUC__) && __GNUC__ == 4 && __GNUC_MINOR__ <= 9 && \ + !defined(__clang__)) +// GCC < 5 doesn't support overloading on const&& for member functions + +#define TL_EXPECTED_NO_CONSTRR +// GCC < 5 doesn't support some standard C++11 type traits +#define TL_EXPECTED_IS_TRIVIALLY_COPY_CONSTRUCTIBLE(T) \ + std::has_trivial_copy_constructor +#define TL_EXPECTED_IS_TRIVIALLY_COPY_ASSIGNABLE(T) \ + std::has_trivial_copy_assign + +// This one will be different for GCC 5.7 if it's ever supported +#define TL_EXPECTED_IS_TRIVIALLY_DESTRUCTIBLE(T) \ + std::is_trivially_destructible + +// GCC 5 < v < 8 has a bug in is_trivially_copy_constructible which breaks +// std::vector for non-copyable types +#elif (defined(__GNUC__) && __GNUC__ < 8 && !defined(__clang__)) +#ifndef TL_GCC_LESS_8_TRIVIALLY_COPY_CONSTRUCTIBLE_MUTEX +#define TL_GCC_LESS_8_TRIVIALLY_COPY_CONSTRUCTIBLE_MUTEX +namespace nonstd { +namespace detail { +template +struct is_trivially_copy_constructible + : std::is_trivially_copy_constructible {}; +#ifdef _GLIBCXX_VECTOR +template +struct is_trivially_copy_constructible> : std::false_type {}; +#endif +} // namespace detail +} // namespace nonstd +#endif + +#define TL_EXPECTED_IS_TRIVIALLY_COPY_CONSTRUCTIBLE(T) \ + nonstd::detail::is_trivially_copy_constructible +#define TL_EXPECTED_IS_TRIVIALLY_COPY_ASSIGNABLE(T) \ + std::is_trivially_copy_assignable +#define TL_EXPECTED_IS_TRIVIALLY_DESTRUCTIBLE(T) \ + std::is_trivially_destructible +#else +#define TL_EXPECTED_IS_TRIVIALLY_COPY_CONSTRUCTIBLE(T) \ + std::is_trivially_copy_constructible +#define TL_EXPECTED_IS_TRIVIALLY_COPY_ASSIGNABLE(T) \ + std::is_trivially_copy_assignable +#define TL_EXPECTED_IS_TRIVIALLY_DESTRUCTIBLE(T) \ + std::is_trivially_destructible +#endif + +#if __cplusplus > 201103L +#define TL_EXPECTED_CXX14 +#endif + +#ifdef TL_EXPECTED_GCC49 +#define TL_EXPECTED_GCC49_CONSTEXPR +#else +#define TL_EXPECTED_GCC49_CONSTEXPR constexpr +#endif + +#if (__cplusplus == 201103L || defined(TL_EXPECTED_MSVC2015) || \ + defined(TL_EXPECTED_GCC49)) +#define TL_EXPECTED_11_CONSTEXPR +#else +#define TL_EXPECTED_11_CONSTEXPR constexpr +#endif + +namespace nonstd { +template class expected; + +#ifndef TL_MONOSTATE_INPLACE_MUTEX +#define TL_MONOSTATE_INPLACE_MUTEX +class monostate {}; + +struct in_place_t { + explicit in_place_t() = default; +}; +static constexpr in_place_t in_place{}; +#endif + +template class unexpected { +public: + static_assert(!std::is_same::value, "E must not be void"); + + unexpected() = delete; + constexpr explicit unexpected(const E &e) : m_val(e) {} + + constexpr explicit unexpected(E &&e) : m_val(std::move(e)) {} + + template ::value>::type * = nullptr> + constexpr explicit unexpected(Args &&...args) + : m_val(std::forward(args)...) {} + template < + class U, class... Args, + typename std::enable_if &, Args &&...>::value>::type * = nullptr> + constexpr explicit unexpected(std::initializer_list l, Args &&...args) + : m_val(l, std::forward(args)...) {} + + constexpr const E &value() const & { return m_val; } + TL_EXPECTED_11_CONSTEXPR E &value() & { return m_val; } + TL_EXPECTED_11_CONSTEXPR E &&value() && { return std::move(m_val); } + constexpr const E &&value() const && { return std::move(m_val); } + +private: + E m_val; +}; + +#ifdef __cpp_deduction_guides +template unexpected(E) -> unexpected; +#endif + +template +constexpr bool operator==(const unexpected &lhs, const unexpected &rhs) { + return lhs.value() == rhs.value(); +} +template +constexpr bool operator!=(const unexpected &lhs, const unexpected &rhs) { + return lhs.value() != rhs.value(); +} +template +constexpr bool operator<(const unexpected &lhs, const unexpected &rhs) { + return lhs.value() < rhs.value(); +} +template +constexpr bool operator<=(const unexpected &lhs, const unexpected &rhs) { + return lhs.value() <= rhs.value(); +} +template +constexpr bool operator>(const unexpected &lhs, const unexpected &rhs) { + return lhs.value() > rhs.value(); +} +template +constexpr bool operator>=(const unexpected &lhs, const unexpected &rhs) { + return lhs.value() >= rhs.value(); +} + +template +unexpected::type> make_unexpected(E &&e) { + return unexpected::type>(std::forward(e)); +} + +struct unexpect_t { + unexpect_t() = default; +}; +static constexpr unexpect_t unexpect{}; + +namespace detail { +template +[[noreturn]] TL_EXPECTED_11_CONSTEXPR void throw_exception(E &&e) { +#ifdef TL_EXPECTED_EXCEPTIONS_ENABLED + throw std::forward(e); +#else + (void)e; +#ifdef _MSC_VER + __assume(0); +#else + __builtin_unreachable(); +#endif +#endif +} + +#ifndef TL_TRAITS_MUTEX +#define TL_TRAITS_MUTEX +// C++14-style aliases for brevity +template using remove_const_t = typename std::remove_const::type; +template +using remove_reference_t = typename std::remove_reference::type; +template using decay_t = typename std::decay::type; +template +using enable_if_t = typename std::enable_if::type; +template +using conditional_t = typename std::conditional::type; + +// std::conjunction from C++17 +template struct conjunction : std::true_type {}; +template struct conjunction : B {}; +template +struct conjunction + : std::conditional, B>::type {}; + +#if defined(_LIBCPP_VERSION) && __cplusplus == 201103L +#define TL_TRAITS_LIBCXX_MEM_FN_WORKAROUND +#endif + +// In C++11 mode, there's an issue in libc++'s std::mem_fn +// which results in a hard-error when using it in a noexcept expression +// in some cases. This is a check to workaround the common failing case. +#ifdef TL_TRAITS_LIBCXX_MEM_FN_WORKAROUND +template +struct is_pointer_to_non_const_member_func : std::false_type {}; +template +struct is_pointer_to_non_const_member_func + : std::true_type {}; +template +struct is_pointer_to_non_const_member_func + : std::true_type {}; +template +struct is_pointer_to_non_const_member_func + : std::true_type {}; +template +struct is_pointer_to_non_const_member_func + : std::true_type {}; +template +struct is_pointer_to_non_const_member_func + : std::true_type {}; +template +struct is_pointer_to_non_const_member_func + : std::true_type {}; + +template struct is_const_or_const_ref : std::false_type {}; +template struct is_const_or_const_ref : std::true_type {}; +template struct is_const_or_const_ref : std::true_type {}; +#endif + +// std::invoke from C++17 +// https://stackoverflow.com/questions/38288042/c11-14-invoke-workaround +template < + typename Fn, typename... Args, +#ifdef TL_TRAITS_LIBCXX_MEM_FN_WORKAROUND + typename = enable_if_t::value && + is_const_or_const_ref::value)>, +#endif + typename = enable_if_t>::value>, int = 0> +constexpr auto invoke(Fn &&f, Args &&...args) noexcept( + noexcept(std::mem_fn(f)(std::forward(args)...))) + -> decltype(std::mem_fn(f)(std::forward(args)...)) { + return std::mem_fn(f)(std::forward(args)...); +} + +template >::value>> +constexpr auto invoke(Fn &&f, Args &&...args) noexcept( + noexcept(std::forward(f)(std::forward(args)...))) + -> decltype(std::forward(f)(std::forward(args)...)) { + return std::forward(f)(std::forward(args)...); +} + +// std::invoke_result from C++17 +template struct invoke_result_impl; + +template +struct invoke_result_impl< + F, + decltype(detail::invoke(std::declval(), std::declval()...), void()), + Us...> { + using type = + decltype(detail::invoke(std::declval(), std::declval()...)); +}; + +template +using invoke_result = invoke_result_impl; + +template +using invoke_result_t = typename invoke_result::type; + +#if defined(_MSC_VER) && _MSC_VER <= 1900 +// TODO make a version which works with MSVC 2015 +template struct is_swappable : std::true_type {}; + +template struct is_nothrow_swappable : std::true_type {}; +#else +// https://stackoverflow.com/questions/26744589/what-is-a-proper-way-to-implement-is-swappable-to-test-for-the-swappable-concept +namespace swap_adl_tests { +// if swap ADL finds this then it would call std::swap otherwise (same +// signature) +struct tag {}; + +template tag swap(T &, T &); +template tag swap(T (&a)[N], T (&b)[N]); + +// helper functions to test if an unqualified swap is possible, and if it +// becomes std::swap +template std::false_type can_swap(...) noexcept(false); +template (), std::declval()))> +std::true_type can_swap(int) noexcept(noexcept(swap(std::declval(), + std::declval()))); + +template std::false_type uses_std(...); +template +std::is_same(), std::declval())), tag> +uses_std(int); + +template +struct is_std_swap_noexcept + : std::integral_constant::value && + std::is_nothrow_move_assignable::value> {}; + +template +struct is_std_swap_noexcept : is_std_swap_noexcept {}; + +template +struct is_adl_swap_noexcept + : std::integral_constant(0))> {}; +} // namespace swap_adl_tests + +template +struct is_swappable + : std::integral_constant< + bool, + decltype(detail::swap_adl_tests::can_swap(0))::value && + (!decltype(detail::swap_adl_tests::uses_std(0))::value || + (std::is_move_assignable::value && + std::is_move_constructible::value))> {}; + +template +struct is_swappable + : std::integral_constant< + bool, + decltype(detail::swap_adl_tests::can_swap(0))::value && + (!decltype(detail::swap_adl_tests::uses_std( + 0))::value || + is_swappable::value)> {}; + +template +struct is_nothrow_swappable + : std::integral_constant< + bool, + is_swappable::value && + ((decltype(detail::swap_adl_tests::uses_std(0))::value && + detail::swap_adl_tests::is_std_swap_noexcept::value) || + (!decltype(detail::swap_adl_tests::uses_std(0))::value && + detail::swap_adl_tests::is_adl_swap_noexcept::value))> {}; +#endif +#endif + +// Trait for checking if a type is a nonstd::expected +template struct is_expected_impl : std::false_type {}; +template +struct is_expected_impl> : std::true_type {}; +template using is_expected = is_expected_impl>; + +template +using expected_enable_forward_value = detail::enable_if_t< + std::is_constructible::value && + !std::is_same, in_place_t>::value && + !std::is_same, detail::decay_t>::value && + !std::is_same, detail::decay_t>::value>; + +template +using expected_enable_from_other = detail::enable_if_t< + std::is_constructible::value && + std::is_constructible::value && + !std::is_constructible &>::value && + !std::is_constructible &&>::value && + !std::is_constructible &>::value && + !std::is_constructible &&>::value && + !std::is_convertible &, T>::value && + !std::is_convertible &&, T>::value && + !std::is_convertible &, T>::value && + !std::is_convertible &&, T>::value>; + +template +using is_void_or = conditional_t::value, std::true_type, U>; + +template +using is_copy_constructible_or_void = + is_void_or>; + +template +using is_move_constructible_or_void = + is_void_or>; + +template +using is_copy_assignable_or_void = is_void_or>; + +template +using is_move_assignable_or_void = is_void_or>; + +} // namespace detail + +namespace detail { +struct no_init_t {}; +static constexpr no_init_t no_init{}; + +// Implements the storage of the values, and ensures that the destructor is +// trivial if it can be. +// +// This specialization is for where neither `T` or `E` is trivially +// destructible, so the destructors must be called on destruction of the +// `expected` +template ::value, + bool = std::is_trivially_destructible::value> +struct expected_storage_base { + constexpr expected_storage_base() : m_val(T{}), m_has_val(true) {} + constexpr expected_storage_base(no_init_t) : m_no_init(), m_has_val(false) {} + + template ::value> * = + nullptr> + constexpr expected_storage_base(in_place_t, Args &&...args) + : m_val(std::forward(args)...), m_has_val(true) {} + + template &, Args &&...>::value> * = nullptr> + constexpr expected_storage_base(in_place_t, std::initializer_list il, + Args &&...args) + : m_val(il, std::forward(args)...), m_has_val(true) {} + template ::value> * = + nullptr> + constexpr explicit expected_storage_base(unexpect_t, Args &&...args) + : m_unexpect(std::forward(args)...), m_has_val(false) {} + + template &, Args &&...>::value> * = nullptr> + constexpr explicit expected_storage_base(unexpect_t, + std::initializer_list il, + Args &&...args) + : m_unexpect(il, std::forward(args)...), m_has_val(false) {} + + ~expected_storage_base() { + if (m_has_val) { + m_val.~T(); + } else { + m_unexpect.~unexpected(); + } + } + union { + T m_val; + unexpected m_unexpect; + char m_no_init; + }; + bool m_has_val; +}; + +// This specialization is for when both `T` and `E` are trivially-destructible, +// so the destructor of the `expected` can be trivial. +template struct expected_storage_base { + constexpr expected_storage_base() : m_val(T{}), m_has_val(true) {} + constexpr expected_storage_base(no_init_t) : m_no_init(), m_has_val(false) {} + + template ::value> * = + nullptr> + constexpr expected_storage_base(in_place_t, Args &&...args) + : m_val(std::forward(args)...), m_has_val(true) {} + + template &, Args &&...>::value> * = nullptr> + constexpr expected_storage_base(in_place_t, std::initializer_list il, + Args &&...args) + : m_val(il, std::forward(args)...), m_has_val(true) {} + template ::value> * = + nullptr> + constexpr explicit expected_storage_base(unexpect_t, Args &&...args) + : m_unexpect(std::forward(args)...), m_has_val(false) {} + + template &, Args &&...>::value> * = nullptr> + constexpr explicit expected_storage_base(unexpect_t, + std::initializer_list il, + Args &&...args) + : m_unexpect(il, std::forward(args)...), m_has_val(false) {} + + ~expected_storage_base() = default; + union { + T m_val; + unexpected m_unexpect; + char m_no_init; + }; + bool m_has_val; +}; + +// T is trivial, E is not. +template struct expected_storage_base { + constexpr expected_storage_base() : m_val(T{}), m_has_val(true) {} + TL_EXPECTED_MSVC2015_CONSTEXPR expected_storage_base(no_init_t) + : m_no_init(), m_has_val(false) {} + + template ::value> * = + nullptr> + constexpr expected_storage_base(in_place_t, Args &&...args) + : m_val(std::forward(args)...), m_has_val(true) {} + + template &, Args &&...>::value> * = nullptr> + constexpr expected_storage_base(in_place_t, std::initializer_list il, + Args &&...args) + : m_val(il, std::forward(args)...), m_has_val(true) {} + template ::value> * = + nullptr> + constexpr explicit expected_storage_base(unexpect_t, Args &&...args) + : m_unexpect(std::forward(args)...), m_has_val(false) {} + + template &, Args &&...>::value> * = nullptr> + constexpr explicit expected_storage_base(unexpect_t, + std::initializer_list il, + Args &&...args) + : m_unexpect(il, std::forward(args)...), m_has_val(false) {} + + ~expected_storage_base() { + if (!m_has_val) { + m_unexpect.~unexpected(); + } + } + + union { + T m_val; + unexpected m_unexpect; + char m_no_init; + }; + bool m_has_val; +}; + +// E is trivial, T is not. +template struct expected_storage_base { + constexpr expected_storage_base() : m_val(T{}), m_has_val(true) {} + constexpr expected_storage_base(no_init_t) : m_no_init(), m_has_val(false) {} + + template ::value> * = + nullptr> + constexpr expected_storage_base(in_place_t, Args &&...args) + : m_val(std::forward(args)...), m_has_val(true) {} + + template &, Args &&...>::value> * = nullptr> + constexpr expected_storage_base(in_place_t, std::initializer_list il, + Args &&...args) + : m_val(il, std::forward(args)...), m_has_val(true) {} + template ::value> * = + nullptr> + constexpr explicit expected_storage_base(unexpect_t, Args &&...args) + : m_unexpect(std::forward(args)...), m_has_val(false) {} + + template &, Args &&...>::value> * = nullptr> + constexpr explicit expected_storage_base(unexpect_t, + std::initializer_list il, + Args &&...args) + : m_unexpect(il, std::forward(args)...), m_has_val(false) {} + + ~expected_storage_base() { + if (m_has_val) { + m_val.~T(); + } + } + union { + T m_val; + unexpected m_unexpect; + char m_no_init; + }; + bool m_has_val; +}; + +// `T` is `void`, `E` is trivially-destructible +template struct expected_storage_base { + #if __GNUC__ <= 5 + //no constexpr for GCC 4/5 bug + #else + TL_EXPECTED_MSVC2015_CONSTEXPR + #endif + expected_storage_base() : m_has_val(true) {} + + constexpr expected_storage_base(no_init_t) : m_val(), m_has_val(false) {} + + constexpr expected_storage_base(in_place_t) : m_has_val(true) {} + + template ::value> * = + nullptr> + constexpr explicit expected_storage_base(unexpect_t, Args &&...args) + : m_unexpect(std::forward(args)...), m_has_val(false) {} + + template &, Args &&...>::value> * = nullptr> + constexpr explicit expected_storage_base(unexpect_t, + std::initializer_list il, + Args &&...args) + : m_unexpect(il, std::forward(args)...), m_has_val(false) {} + + ~expected_storage_base() = default; + struct dummy {}; + union { + unexpected m_unexpect; + dummy m_val; + }; + bool m_has_val; +}; + +// `T` is `void`, `E` is not trivially-destructible +template struct expected_storage_base { + constexpr expected_storage_base() : m_dummy(), m_has_val(true) {} + constexpr expected_storage_base(no_init_t) : m_dummy(), m_has_val(false) {} + + constexpr expected_storage_base(in_place_t) : m_dummy(), m_has_val(true) {} + + template ::value> * = + nullptr> + constexpr explicit expected_storage_base(unexpect_t, Args &&...args) + : m_unexpect(std::forward(args)...), m_has_val(false) {} + + template &, Args &&...>::value> * = nullptr> + constexpr explicit expected_storage_base(unexpect_t, + std::initializer_list il, + Args &&...args) + : m_unexpect(il, std::forward(args)...), m_has_val(false) {} + + ~expected_storage_base() { + if (!m_has_val) { + m_unexpect.~unexpected(); + } + } + + union { + unexpected m_unexpect; + char m_dummy; + }; + bool m_has_val; +}; + +// This base class provides some handy member functions which can be used in +// further derived classes +template +struct expected_operations_base : expected_storage_base { + using expected_storage_base::expected_storage_base; + + template void construct(Args &&...args) noexcept { + new (std::addressof(this->m_val)) T(std::forward(args)...); + this->m_has_val = true; + } + + template void construct_with(Rhs &&rhs) noexcept { + new (std::addressof(this->m_val)) T(std::forward(rhs).get()); + this->m_has_val = true; + } + + template void construct_error(Args &&...args) noexcept { + new (std::addressof(this->m_unexpect)) + unexpected(std::forward(args)...); + this->m_has_val = false; + } + +#ifdef TL_EXPECTED_EXCEPTIONS_ENABLED + + // These assign overloads ensure that the most efficient assignment + // implementation is used while maintaining the strong exception guarantee. + // The problematic case is where rhs has a value, but *this does not. + // + // This overload handles the case where we can just copy-construct `T` + // directly into place without throwing. + template ::value> + * = nullptr> + void assign(const expected_operations_base &rhs) noexcept { + if (!this->m_has_val && rhs.m_has_val) { + geterr().~unexpected(); + construct(rhs.get()); + } else { + assign_common(rhs); + } + } + + // This overload handles the case where we can attempt to create a copy of + // `T`, then no-throw move it into place if the copy was successful. + template ::value && + std::is_nothrow_move_constructible::value> + * = nullptr> + void assign(const expected_operations_base &rhs) noexcept { + if (!this->m_has_val && rhs.m_has_val) { + T tmp = rhs.get(); + geterr().~unexpected(); + construct(std::move(tmp)); + } else { + assign_common(rhs); + } + } + + // This overload is the worst-case, where we have to move-construct the + // unexpected value into temporary storage, then try to copy the T into place. + // If the construction succeeds, then everything is fine, but if it throws, + // then we move the old unexpected value back into place before rethrowing the + // exception. + template ::value && + !std::is_nothrow_move_constructible::value> + * = nullptr> + void assign(const expected_operations_base &rhs) { + if (!this->m_has_val && rhs.m_has_val) { + auto tmp = std::move(geterr()); + geterr().~unexpected(); + +#ifdef TL_EXPECTED_EXCEPTIONS_ENABLED + try { + construct(rhs.get()); + } catch (...) { + geterr() = std::move(tmp); + throw; + } +#else + construct(rhs.get()); +#endif + } else { + assign_common(rhs); + } + } + + // These overloads do the same as above, but for rvalues + template ::value> + * = nullptr> + void assign(expected_operations_base &&rhs) noexcept { + if (!this->m_has_val && rhs.m_has_val) { + geterr().~unexpected(); + construct(std::move(rhs).get()); + } else { + assign_common(std::move(rhs)); + } + } + + template ::value> + * = nullptr> + void assign(expected_operations_base &&rhs) { + if (!this->m_has_val && rhs.m_has_val) { + auto tmp = std::move(geterr()); + geterr().~unexpected(); +#ifdef TL_EXPECTED_EXCEPTIONS_ENABLED + try { + construct(std::move(rhs).get()); + } catch (...) { + geterr() = std::move(tmp); + throw; + } +#else + construct(std::move(rhs).get()); +#endif + } else { + assign_common(std::move(rhs)); + } + } + +#else + + // If exceptions are disabled then we can just copy-construct + void assign(const expected_operations_base &rhs) noexcept { + if (!this->m_has_val && rhs.m_has_val) { + geterr().~unexpected(); + construct(rhs.get()); + } else { + assign_common(rhs); + } + } + + void assign(expected_operations_base &&rhs) noexcept { + if (!this->m_has_val && rhs.m_has_val) { + geterr().~unexpected(); + construct(std::move(rhs).get()); + } else { + assign_common(std::move(rhs)); + } + } + +#endif + + // The common part of move/copy assigning + template void assign_common(Rhs &&rhs) { + if (this->m_has_val) { + if (rhs.m_has_val) { + get() = std::forward(rhs).get(); + } else { + destroy_val(); + construct_error(std::forward(rhs).geterr()); + } + } else { + if (!rhs.m_has_val) { + geterr() = std::forward(rhs).geterr(); + } + } + } + + bool has_value() const { return this->m_has_val; } + + TL_EXPECTED_11_CONSTEXPR T &get() & { return this->m_val; } + constexpr const T &get() const & { return this->m_val; } + TL_EXPECTED_11_CONSTEXPR T &&get() && { return std::move(this->m_val); } +#ifndef TL_EXPECTED_NO_CONSTRR + constexpr const T &&get() const && { return std::move(this->m_val); } +#endif + + TL_EXPECTED_11_CONSTEXPR unexpected &geterr() & { + return this->m_unexpect; + } + constexpr const unexpected &geterr() const & { return this->m_unexpect; } + TL_EXPECTED_11_CONSTEXPR unexpected &&geterr() && { + return std::move(this->m_unexpect); + } +#ifndef TL_EXPECTED_NO_CONSTRR + constexpr const unexpected &&geterr() const && { + return std::move(this->m_unexpect); + } +#endif + + TL_EXPECTED_11_CONSTEXPR void destroy_val() { get().~T(); } +}; + +// This base class provides some handy member functions which can be used in +// further derived classes +template +struct expected_operations_base : expected_storage_base { + using expected_storage_base::expected_storage_base; + + template void construct() noexcept { this->m_has_val = true; } + + // This function doesn't use its argument, but needs it so that code in + // levels above this can work independently of whether T is void + template void construct_with(Rhs &&) noexcept { + this->m_has_val = true; + } + + template void construct_error(Args &&...args) noexcept { + new (std::addressof(this->m_unexpect)) + unexpected(std::forward(args)...); + this->m_has_val = false; + } + + template void assign(Rhs &&rhs) noexcept { + if (!this->m_has_val) { + if (rhs.m_has_val) { + geterr().~unexpected(); + construct(); + } else { + geterr() = std::forward(rhs).geterr(); + } + } else { + if (!rhs.m_has_val) { + construct_error(std::forward(rhs).geterr()); + } + } + } + + bool has_value() const { return this->m_has_val; } + + TL_EXPECTED_11_CONSTEXPR unexpected &geterr() & { + return this->m_unexpect; + } + constexpr const unexpected &geterr() const & { return this->m_unexpect; } + TL_EXPECTED_11_CONSTEXPR unexpected &&geterr() && { + return std::move(this->m_unexpect); + } +#ifndef TL_EXPECTED_NO_CONSTRR + constexpr const unexpected &&geterr() const && { + return std::move(this->m_unexpect); + } +#endif + + TL_EXPECTED_11_CONSTEXPR void destroy_val() { + // no-op + } +}; + +// This class manages conditionally having a trivial copy constructor +// This specialization is for when T and E are trivially copy constructible +template :: + value &&TL_EXPECTED_IS_TRIVIALLY_COPY_CONSTRUCTIBLE(E)::value> +struct expected_copy_base : expected_operations_base { + using expected_operations_base::expected_operations_base; +}; + +// This specialization is for when T or E are not trivially copy constructible +template +struct expected_copy_base : expected_operations_base { + using expected_operations_base::expected_operations_base; + + expected_copy_base() = default; + expected_copy_base(const expected_copy_base &rhs) + : expected_operations_base(no_init) { + if (rhs.has_value()) { + this->construct_with(rhs); + } else { + this->construct_error(rhs.geterr()); + } + } + + expected_copy_base(expected_copy_base &&rhs) = default; + expected_copy_base &operator=(const expected_copy_base &rhs) = default; + expected_copy_base &operator=(expected_copy_base &&rhs) = default; +}; + +// This class manages conditionally having a trivial move constructor +// Unfortunately there's no way to achieve this in GCC < 5 AFAIK, since it +// doesn't implement an analogue to std::is_trivially_move_constructible. We +// have to make do with a non-trivial move constructor even if T is trivially +// move constructible +#ifndef TL_EXPECTED_GCC49 +template >::value + &&std::is_trivially_move_constructible::value> +struct expected_move_base : expected_copy_base { + using expected_copy_base::expected_copy_base; +}; +#else +template struct expected_move_base; +#endif +template +struct expected_move_base : expected_copy_base { + using expected_copy_base::expected_copy_base; + + expected_move_base() = default; + expected_move_base(const expected_move_base &rhs) = default; + + expected_move_base(expected_move_base &&rhs) noexcept( + std::is_nothrow_move_constructible::value) + : expected_copy_base(no_init) { + if (rhs.has_value()) { + this->construct_with(std::move(rhs)); + } else { + this->construct_error(std::move(rhs.geterr())); + } + } + expected_move_base &operator=(const expected_move_base &rhs) = default; + expected_move_base &operator=(expected_move_base &&rhs) = default; +}; + +// This class manages conditionally having a trivial copy assignment operator +template >::value + &&TL_EXPECTED_IS_TRIVIALLY_COPY_ASSIGNABLE(E)::value + &&TL_EXPECTED_IS_TRIVIALLY_COPY_CONSTRUCTIBLE(E)::value + &&TL_EXPECTED_IS_TRIVIALLY_DESTRUCTIBLE(E)::value> +struct expected_copy_assign_base : expected_move_base { + using expected_move_base::expected_move_base; +}; + +template +struct expected_copy_assign_base : expected_move_base { + using expected_move_base::expected_move_base; + + expected_copy_assign_base() = default; + expected_copy_assign_base(const expected_copy_assign_base &rhs) = default; + + expected_copy_assign_base(expected_copy_assign_base &&rhs) = default; + expected_copy_assign_base &operator=(const expected_copy_assign_base &rhs) { + this->assign(rhs); + return *this; + } + expected_copy_assign_base & + operator=(expected_copy_assign_base &&rhs) = default; +}; + +// This class manages conditionally having a trivial move assignment operator +// Unfortunately there's no way to achieve this in GCC < 5 AFAIK, since it +// doesn't implement an analogue to std::is_trivially_move_assignable. We have +// to make do with a non-trivial move assignment operator even if T is trivially +// move assignable +#ifndef TL_EXPECTED_GCC49 +template , + std::is_trivially_move_constructible, + std::is_trivially_move_assignable>>:: + value &&std::is_trivially_destructible::value + &&std::is_trivially_move_constructible::value + &&std::is_trivially_move_assignable::value> +struct expected_move_assign_base : expected_copy_assign_base { + using expected_copy_assign_base::expected_copy_assign_base; +}; +#else +template struct expected_move_assign_base; +#endif + +template +struct expected_move_assign_base + : expected_copy_assign_base { + using expected_copy_assign_base::expected_copy_assign_base; + + expected_move_assign_base() = default; + expected_move_assign_base(const expected_move_assign_base &rhs) = default; + + expected_move_assign_base(expected_move_assign_base &&rhs) = default; + + expected_move_assign_base & + operator=(const expected_move_assign_base &rhs) = default; + + expected_move_assign_base & + operator=(expected_move_assign_base &&rhs) noexcept( + std::is_nothrow_move_constructible::value + &&std::is_nothrow_move_assignable::value) { + this->assign(std::move(rhs)); + return *this; + } +}; + +// expected_delete_ctor_base will conditionally delete copy and move +// constructors depending on whether T is copy/move constructible +template ::value && + std::is_copy_constructible::value), + bool EnableMove = (is_move_constructible_or_void::value && + std::is_move_constructible::value)> +struct expected_delete_ctor_base { + expected_delete_ctor_base() = default; + expected_delete_ctor_base(const expected_delete_ctor_base &) = default; + expected_delete_ctor_base(expected_delete_ctor_base &&) noexcept = default; + expected_delete_ctor_base & + operator=(const expected_delete_ctor_base &) = default; + expected_delete_ctor_base & + operator=(expected_delete_ctor_base &&) noexcept = default; +}; + +template +struct expected_delete_ctor_base { + expected_delete_ctor_base() = default; + expected_delete_ctor_base(const expected_delete_ctor_base &) = default; + expected_delete_ctor_base(expected_delete_ctor_base &&) noexcept = delete; + expected_delete_ctor_base & + operator=(const expected_delete_ctor_base &) = default; + expected_delete_ctor_base & + operator=(expected_delete_ctor_base &&) noexcept = default; +}; + +template +struct expected_delete_ctor_base { + expected_delete_ctor_base() = default; + expected_delete_ctor_base(const expected_delete_ctor_base &) = delete; + expected_delete_ctor_base(expected_delete_ctor_base &&) noexcept = default; + expected_delete_ctor_base & + operator=(const expected_delete_ctor_base &) = default; + expected_delete_ctor_base & + operator=(expected_delete_ctor_base &&) noexcept = default; +}; + +template +struct expected_delete_ctor_base { + expected_delete_ctor_base() = default; + expected_delete_ctor_base(const expected_delete_ctor_base &) = delete; + expected_delete_ctor_base(expected_delete_ctor_base &&) noexcept = delete; + expected_delete_ctor_base & + operator=(const expected_delete_ctor_base &) = default; + expected_delete_ctor_base & + operator=(expected_delete_ctor_base &&) noexcept = default; +}; + +// expected_delete_assign_base will conditionally delete copy and move +// constructors depending on whether T and E are copy/move constructible + +// assignable +template ::value && + std::is_copy_constructible::value && + is_copy_assignable_or_void::value && + std::is_copy_assignable::value), + bool EnableMove = (is_move_constructible_or_void::value && + std::is_move_constructible::value && + is_move_assignable_or_void::value && + std::is_move_assignable::value)> +struct expected_delete_assign_base { + expected_delete_assign_base() = default; + expected_delete_assign_base(const expected_delete_assign_base &) = default; + expected_delete_assign_base(expected_delete_assign_base &&) noexcept = + default; + expected_delete_assign_base & + operator=(const expected_delete_assign_base &) = default; + expected_delete_assign_base & + operator=(expected_delete_assign_base &&) noexcept = default; +}; + +template +struct expected_delete_assign_base { + expected_delete_assign_base() = default; + expected_delete_assign_base(const expected_delete_assign_base &) = default; + expected_delete_assign_base(expected_delete_assign_base &&) noexcept = + default; + expected_delete_assign_base & + operator=(const expected_delete_assign_base &) = default; + expected_delete_assign_base & + operator=(expected_delete_assign_base &&) noexcept = delete; +}; + +template +struct expected_delete_assign_base { + expected_delete_assign_base() = default; + expected_delete_assign_base(const expected_delete_assign_base &) = default; + expected_delete_assign_base(expected_delete_assign_base &&) noexcept = + default; + expected_delete_assign_base & + operator=(const expected_delete_assign_base &) = delete; + expected_delete_assign_base & + operator=(expected_delete_assign_base &&) noexcept = default; +}; + +template +struct expected_delete_assign_base { + expected_delete_assign_base() = default; + expected_delete_assign_base(const expected_delete_assign_base &) = default; + expected_delete_assign_base(expected_delete_assign_base &&) noexcept = + default; + expected_delete_assign_base & + operator=(const expected_delete_assign_base &) = delete; + expected_delete_assign_base & + operator=(expected_delete_assign_base &&) noexcept = delete; +}; + +// This is needed to be able to construct the expected_default_ctor_base which +// follows, while still conditionally deleting the default constructor. +struct default_constructor_tag { + explicit constexpr default_constructor_tag() = default; +}; + +// expected_default_ctor_base will ensure that expected has a deleted default +// consturctor if T is not default constructible. +// This specialization is for when T is default constructible +template ::value || std::is_void::value> +struct expected_default_ctor_base { + constexpr expected_default_ctor_base() noexcept = default; + constexpr expected_default_ctor_base( + expected_default_ctor_base const &) noexcept = default; + constexpr expected_default_ctor_base(expected_default_ctor_base &&) noexcept = + default; + expected_default_ctor_base & + operator=(expected_default_ctor_base const &) noexcept = default; + expected_default_ctor_base & + operator=(expected_default_ctor_base &&) noexcept = default; + + constexpr explicit expected_default_ctor_base(default_constructor_tag) {} +}; + +// This specialization is for when T is not default constructible +template struct expected_default_ctor_base { + constexpr expected_default_ctor_base() noexcept = delete; + constexpr expected_default_ctor_base( + expected_default_ctor_base const &) noexcept = default; + constexpr expected_default_ctor_base(expected_default_ctor_base &&) noexcept = + default; + expected_default_ctor_base & + operator=(expected_default_ctor_base const &) noexcept = default; + expected_default_ctor_base & + operator=(expected_default_ctor_base &&) noexcept = default; + + constexpr explicit expected_default_ctor_base(default_constructor_tag) {} +}; +} // namespace detail + +template class bad_expected_access : public std::exception { +public: + explicit bad_expected_access(E e) : m_val(std::move(e)) {} + + virtual const char *what() const noexcept override { + return "Bad expected access"; + } + + const E &error() const & { return m_val; } + E &error() & { return m_val; } + const E &&error() const && { return std::move(m_val); } + E &&error() && { return std::move(m_val); } + +private: + E m_val; +}; + +/// An `expected` object is an object that contains the storage for +/// another object and manages the lifetime of this contained object `T`. +/// Alternatively it could contain the storage for another unexpected object +/// `E`. The contained object may not be initialized after the expected object +/// has been initialized, and may not be destroyed before the expected object +/// has been destroyed. The initialization state of the contained object is +/// tracked by the expected object. +template +class expected : private detail::expected_move_assign_base, + private detail::expected_delete_ctor_base, + private detail::expected_delete_assign_base, + private detail::expected_default_ctor_base { + static_assert(!std::is_reference::value, "T must not be a reference"); + static_assert(!std::is_same::type>::value, + "T must not be in_place_t"); + static_assert(!std::is_same::type>::value, + "T must not be unexpect_t"); + static_assert( + !std::is_same>::type>::value, + "T must not be unexpected"); + static_assert(!std::is_reference::value, "E must not be a reference"); + + T *valptr() { return std::addressof(this->m_val); } + const T *valptr() const { return std::addressof(this->m_val); } + unexpected *errptr() { return std::addressof(this->m_unexpect); } + const unexpected *errptr() const { + return std::addressof(this->m_unexpect); + } + + template ::value> * = nullptr> + TL_EXPECTED_11_CONSTEXPR U &val() { + return this->m_val; + } + TL_EXPECTED_11_CONSTEXPR unexpected &err() { return this->m_unexpect; } + + template ::value> * = nullptr> + constexpr const U &val() const { + return this->m_val; + } + constexpr const unexpected &err() const { return this->m_unexpect; } + + using impl_base = detail::expected_move_assign_base; + using ctor_base = detail::expected_default_ctor_base; + +public: + typedef T value_type; + typedef E error_type; + typedef unexpected unexpected_type; + +#if defined(TL_EXPECTED_CXX14) && !defined(TL_EXPECTED_GCC49) && \ + !defined(TL_EXPECTED_GCC54) && !defined(TL_EXPECTED_GCC55) + template TL_EXPECTED_11_CONSTEXPR auto and_then(F &&f) & { + return and_then_impl(*this, std::forward(f)); + } + template TL_EXPECTED_11_CONSTEXPR auto and_then(F &&f) && { + return and_then_impl(std::move(*this), std::forward(f)); + } + template constexpr auto and_then(F &&f) const & { + return and_then_impl(*this, std::forward(f)); + } + +#ifndef TL_EXPECTED_NO_CONSTRR + template constexpr auto and_then(F &&f) const && { + return and_then_impl(std::move(*this), std::forward(f)); + } +#endif + +#else + template + TL_EXPECTED_11_CONSTEXPR auto + and_then(F &&f) & -> decltype(and_then_impl(std::declval(), + std::forward(f))) { + return and_then_impl(*this, std::forward(f)); + } + template + TL_EXPECTED_11_CONSTEXPR auto + and_then(F &&f) && -> decltype(and_then_impl(std::declval(), + std::forward(f))) { + return and_then_impl(std::move(*this), std::forward(f)); + } + template + constexpr auto and_then(F &&f) const & -> decltype(and_then_impl( + std::declval(), std::forward(f))) { + return and_then_impl(*this, std::forward(f)); + } + +#ifndef TL_EXPECTED_NO_CONSTRR + template + constexpr auto and_then(F &&f) const && -> decltype(and_then_impl( + std::declval(), std::forward(f))) { + return and_then_impl(std::move(*this), std::forward(f)); + } +#endif +#endif + +#if defined(TL_EXPECTED_CXX14) && !defined(TL_EXPECTED_GCC49) && \ + !defined(TL_EXPECTED_GCC54) && !defined(TL_EXPECTED_GCC55) + template TL_EXPECTED_11_CONSTEXPR auto map(F &&f) & { + return expected_map_impl(*this, std::forward(f)); + } + template TL_EXPECTED_11_CONSTEXPR auto map(F &&f) && { + return expected_map_impl(std::move(*this), std::forward(f)); + } + template constexpr auto map(F &&f) const & { + return expected_map_impl(*this, std::forward(f)); + } + template constexpr auto map(F &&f) const && { + return expected_map_impl(std::move(*this), std::forward(f)); + } +#else + template + TL_EXPECTED_11_CONSTEXPR decltype(expected_map_impl( + std::declval(), std::declval())) + map(F &&f) & { + return expected_map_impl(*this, std::forward(f)); + } + template + TL_EXPECTED_11_CONSTEXPR decltype(expected_map_impl(std::declval(), + std::declval())) + map(F &&f) && { + return expected_map_impl(std::move(*this), std::forward(f)); + } + template + constexpr decltype(expected_map_impl(std::declval(), + std::declval())) + map(F &&f) const & { + return expected_map_impl(*this, std::forward(f)); + } + +#ifndef TL_EXPECTED_NO_CONSTRR + template + constexpr decltype(expected_map_impl(std::declval(), + std::declval())) + map(F &&f) const && { + return expected_map_impl(std::move(*this), std::forward(f)); + } +#endif +#endif + +#if defined(TL_EXPECTED_CXX14) && !defined(TL_EXPECTED_GCC49) && \ + !defined(TL_EXPECTED_GCC54) && !defined(TL_EXPECTED_GCC55) + template TL_EXPECTED_11_CONSTEXPR auto transform(F &&f) & { + return expected_map_impl(*this, std::forward(f)); + } + template TL_EXPECTED_11_CONSTEXPR auto transform(F &&f) && { + return expected_map_impl(std::move(*this), std::forward(f)); + } + template constexpr auto transform(F &&f) const & { + return expected_map_impl(*this, std::forward(f)); + } + template constexpr auto transform(F &&f) const && { + return expected_map_impl(std::move(*this), std::forward(f)); + } +#else + template + TL_EXPECTED_11_CONSTEXPR decltype(expected_map_impl( + std::declval(), std::declval())) + transform(F &&f) & { + return expected_map_impl(*this, std::forward(f)); + } + template + TL_EXPECTED_11_CONSTEXPR decltype(expected_map_impl(std::declval(), + std::declval())) + transform(F &&f) && { + return expected_map_impl(std::move(*this), std::forward(f)); + } + template + constexpr decltype(expected_map_impl(std::declval(), + std::declval())) + transform(F &&f) const & { + return expected_map_impl(*this, std::forward(f)); + } + +#ifndef TL_EXPECTED_NO_CONSTRR + template + constexpr decltype(expected_map_impl(std::declval(), + std::declval())) + transform(F &&f) const && { + return expected_map_impl(std::move(*this), std::forward(f)); + } +#endif +#endif + +#if defined(TL_EXPECTED_CXX14) && !defined(TL_EXPECTED_GCC49) && \ + !defined(TL_EXPECTED_GCC54) && !defined(TL_EXPECTED_GCC55) + template TL_EXPECTED_11_CONSTEXPR auto map_error(F &&f) & { + return map_error_impl(*this, std::forward(f)); + } + template TL_EXPECTED_11_CONSTEXPR auto map_error(F &&f) && { + return map_error_impl(std::move(*this), std::forward(f)); + } + template constexpr auto map_error(F &&f) const & { + return map_error_impl(*this, std::forward(f)); + } + template constexpr auto map_error(F &&f) const && { + return map_error_impl(std::move(*this), std::forward(f)); + } +#else + template + TL_EXPECTED_11_CONSTEXPR decltype(map_error_impl(std::declval(), + std::declval())) + map_error(F &&f) & { + return map_error_impl(*this, std::forward(f)); + } + template + TL_EXPECTED_11_CONSTEXPR decltype(map_error_impl(std::declval(), + std::declval())) + map_error(F &&f) && { + return map_error_impl(std::move(*this), std::forward(f)); + } + template + constexpr decltype(map_error_impl(std::declval(), + std::declval())) + map_error(F &&f) const & { + return map_error_impl(*this, std::forward(f)); + } + +#ifndef TL_EXPECTED_NO_CONSTRR + template + constexpr decltype(map_error_impl(std::declval(), + std::declval())) + map_error(F &&f) const && { + return map_error_impl(std::move(*this), std::forward(f)); + } +#endif +#endif +#if defined(TL_EXPECTED_CXX14) && !defined(TL_EXPECTED_GCC49) && \ + !defined(TL_EXPECTED_GCC54) && !defined(TL_EXPECTED_GCC55) + template TL_EXPECTED_11_CONSTEXPR auto transform_error(F &&f) & { + return map_error_impl(*this, std::forward(f)); + } + template TL_EXPECTED_11_CONSTEXPR auto transform_error(F &&f) && { + return map_error_impl(std::move(*this), std::forward(f)); + } + template constexpr auto transform_error(F &&f) const & { + return map_error_impl(*this, std::forward(f)); + } + template constexpr auto transform_error(F &&f) const && { + return map_error_impl(std::move(*this), std::forward(f)); + } +#else + template + TL_EXPECTED_11_CONSTEXPR decltype(map_error_impl(std::declval(), + std::declval())) + transform_error(F &&f) & { + return map_error_impl(*this, std::forward(f)); + } + template + TL_EXPECTED_11_CONSTEXPR decltype(map_error_impl(std::declval(), + std::declval())) + transform_error(F &&f) && { + return map_error_impl(std::move(*this), std::forward(f)); + } + template + constexpr decltype(map_error_impl(std::declval(), + std::declval())) + transform_error(F &&f) const & { + return map_error_impl(*this, std::forward(f)); + } + +#ifndef TL_EXPECTED_NO_CONSTRR + template + constexpr decltype(map_error_impl(std::declval(), + std::declval())) + transform_error(F &&f) const && { + return map_error_impl(std::move(*this), std::forward(f)); + } +#endif +#endif + template expected TL_EXPECTED_11_CONSTEXPR or_else(F &&f) & { + return or_else_impl(*this, std::forward(f)); + } + + template expected TL_EXPECTED_11_CONSTEXPR or_else(F &&f) && { + return or_else_impl(std::move(*this), std::forward(f)); + } + + template expected constexpr or_else(F &&f) const & { + return or_else_impl(*this, std::forward(f)); + } + +#ifndef TL_EXPECTED_NO_CONSTRR + template expected constexpr or_else(F &&f) const && { + return or_else_impl(std::move(*this), std::forward(f)); + } +#endif + constexpr expected() = default; + constexpr expected(const expected &rhs) = default; + constexpr expected(expected &&rhs) = default; + expected &operator=(const expected &rhs) = default; + expected &operator=(expected &&rhs) = default; + + template ::value> * = + nullptr> + constexpr expected(in_place_t, Args &&...args) + : impl_base(in_place, std::forward(args)...), + ctor_base(detail::default_constructor_tag{}) {} + + template &, Args &&...>::value> * = nullptr> + constexpr expected(in_place_t, std::initializer_list il, Args &&...args) + : impl_base(in_place, il, std::forward(args)...), + ctor_base(detail::default_constructor_tag{}) {} + + template ::value> * = + nullptr, + detail::enable_if_t::value> * = + nullptr> + explicit constexpr expected(const unexpected &e) + : impl_base(unexpect, e.value()), + ctor_base(detail::default_constructor_tag{}) {} + + template < + class G = E, + detail::enable_if_t::value> * = + nullptr, + detail::enable_if_t::value> * = nullptr> + constexpr expected(unexpected const &e) + : impl_base(unexpect, e.value()), + ctor_base(detail::default_constructor_tag{}) {} + + template < + class G = E, + detail::enable_if_t::value> * = nullptr, + detail::enable_if_t::value> * = nullptr> + explicit constexpr expected(unexpected &&e) noexcept( + std::is_nothrow_constructible::value) + : impl_base(unexpect, std::move(e.value())), + ctor_base(detail::default_constructor_tag{}) {} + + template < + class G = E, + detail::enable_if_t::value> * = nullptr, + detail::enable_if_t::value> * = nullptr> + constexpr expected(unexpected &&e) noexcept( + std::is_nothrow_constructible::value) + : impl_base(unexpect, std::move(e.value())), + ctor_base(detail::default_constructor_tag{}) {} + + template ::value> * = + nullptr> + constexpr explicit expected(unexpect_t, Args &&...args) + : impl_base(unexpect, std::forward(args)...), + ctor_base(detail::default_constructor_tag{}) {} + + template &, Args &&...>::value> * = nullptr> + constexpr explicit expected(unexpect_t, std::initializer_list il, + Args &&...args) + : impl_base(unexpect, il, std::forward(args)...), + ctor_base(detail::default_constructor_tag{}) {} + + template ::value && + std::is_convertible::value)> * = + nullptr, + detail::expected_enable_from_other + * = nullptr> + explicit TL_EXPECTED_11_CONSTEXPR expected(const expected &rhs) + : ctor_base(detail::default_constructor_tag{}) { + if (rhs.has_value()) { + this->construct(*rhs); + } else { + this->construct_error(rhs.error()); + } + } + + template ::value && + std::is_convertible::value)> * = + nullptr, + detail::expected_enable_from_other + * = nullptr> + TL_EXPECTED_11_CONSTEXPR expected(const expected &rhs) + : ctor_base(detail::default_constructor_tag{}) { + if (rhs.has_value()) { + this->construct(*rhs); + } else { + this->construct_error(rhs.error()); + } + } + + template < + class U, class G, + detail::enable_if_t::value && + std::is_convertible::value)> * = nullptr, + detail::expected_enable_from_other * = nullptr> + explicit TL_EXPECTED_11_CONSTEXPR expected(expected &&rhs) + : ctor_base(detail::default_constructor_tag{}) { + if (rhs.has_value()) { + this->construct(std::move(*rhs)); + } else { + this->construct_error(std::move(rhs.error())); + } + } + + template < + class U, class G, + detail::enable_if_t<(std::is_convertible::value && + std::is_convertible::value)> * = nullptr, + detail::expected_enable_from_other * = nullptr> + TL_EXPECTED_11_CONSTEXPR expected(expected &&rhs) + : ctor_base(detail::default_constructor_tag{}) { + if (rhs.has_value()) { + this->construct(std::move(*rhs)); + } else { + this->construct_error(std::move(rhs.error())); + } + } + + template < + class U = T, + detail::enable_if_t::value> * = nullptr, + detail::expected_enable_forward_value * = nullptr> + explicit TL_EXPECTED_MSVC2015_CONSTEXPR expected(U &&v) + : expected(in_place, std::forward(v)) {} + + template < + class U = T, + detail::enable_if_t::value> * = nullptr, + detail::expected_enable_forward_value * = nullptr> + TL_EXPECTED_MSVC2015_CONSTEXPR expected(U &&v) + : expected(in_place, std::forward(v)) {} + + template < + class U = T, class G = T, + detail::enable_if_t::value> * = + nullptr, + detail::enable_if_t::value> * = nullptr, + detail::enable_if_t< + (!std::is_same, detail::decay_t>::value && + !detail::conjunction, + std::is_same>>::value && + std::is_constructible::value && + std::is_assignable::value && + std::is_nothrow_move_constructible::value)> * = nullptr> + expected &operator=(U &&v) { + if (has_value()) { + val() = std::forward(v); + } else { + err().~unexpected(); + ::new (valptr()) T(std::forward(v)); + this->m_has_val = true; + } + + return *this; + } + + template < + class U = T, class G = T, + detail::enable_if_t::value> * = + nullptr, + detail::enable_if_t::value> * = nullptr, + detail::enable_if_t< + (!std::is_same, detail::decay_t>::value && + !detail::conjunction, + std::is_same>>::value && + std::is_constructible::value && + std::is_assignable::value && + std::is_nothrow_move_constructible::value)> * = nullptr> + expected &operator=(U &&v) { + if (has_value()) { + val() = std::forward(v); + } else { + auto tmp = std::move(err()); + err().~unexpected(); + +#ifdef TL_EXPECTED_EXCEPTIONS_ENABLED + try { + ::new (valptr()) T(std::forward(v)); + this->m_has_val = true; + } catch (...) { + err() = std::move(tmp); + throw; + } +#else + ::new (valptr()) T(std::forward(v)); + this->m_has_val = true; +#endif + } + + return *this; + } + + template ::value && + std::is_assignable::value> * = nullptr> + expected &operator=(const unexpected &rhs) { + if (!has_value()) { + err() = rhs; + } else { + this->destroy_val(); + ::new (errptr()) unexpected(rhs); + this->m_has_val = false; + } + + return *this; + } + + template ::value && + std::is_move_assignable::value> * = nullptr> + expected &operator=(unexpected &&rhs) noexcept { + if (!has_value()) { + err() = std::move(rhs); + } else { + this->destroy_val(); + ::new (errptr()) unexpected(std::move(rhs)); + this->m_has_val = false; + } + + return *this; + } + + template ::value> * = nullptr> + void emplace(Args &&...args) { + if (has_value()) { + val().~T(); + } else { + err().~unexpected(); + this->m_has_val = true; + } + ::new (valptr()) T(std::forward(args)...); + } + + template ::value> * = nullptr> + void emplace(Args &&...args) { + if (has_value()) { + val().~T(); + ::new (valptr()) T(std::forward(args)...); + } else { + auto tmp = std::move(err()); + err().~unexpected(); + +#ifdef TL_EXPECTED_EXCEPTIONS_ENABLED + try { + ::new (valptr()) T(std::forward(args)...); + this->m_has_val = true; + } catch (...) { + err() = std::move(tmp); + throw; + } +#else + ::new (valptr()) T(std::forward(args)...); + this->m_has_val = true; +#endif + } + } + + template &, Args &&...>::value> * = nullptr> + void emplace(std::initializer_list il, Args &&...args) { + if (has_value()) { + T t(il, std::forward(args)...); + val() = std::move(t); + } else { + err().~unexpected(); + ::new (valptr()) T(il, std::forward(args)...); + this->m_has_val = true; + } + } + + template &, Args &&...>::value> * = nullptr> + void emplace(std::initializer_list il, Args &&...args) { + if (has_value()) { + T t(il, std::forward(args)...); + val() = std::move(t); + } else { + auto tmp = std::move(err()); + err().~unexpected(); + +#ifdef TL_EXPECTED_EXCEPTIONS_ENABLED + try { + ::new (valptr()) T(il, std::forward(args)...); + this->m_has_val = true; + } catch (...) { + err() = std::move(tmp); + throw; + } +#else + ::new (valptr()) T(il, std::forward(args)...); + this->m_has_val = true; +#endif + } + } + +private: + using t_is_void = std::true_type; + using t_is_not_void = std::false_type; + using t_is_nothrow_move_constructible = std::true_type; + using move_constructing_t_can_throw = std::false_type; + using e_is_nothrow_move_constructible = std::true_type; + using move_constructing_e_can_throw = std::false_type; + + void swap_where_both_have_value(expected & /*rhs*/, t_is_void) noexcept { + // swapping void is a no-op + } + + void swap_where_both_have_value(expected &rhs, t_is_not_void) { + using std::swap; + swap(val(), rhs.val()); + } + + void swap_where_only_one_has_value(expected &rhs, t_is_void) noexcept( + std::is_nothrow_move_constructible::value) { + ::new (errptr()) unexpected_type(std::move(rhs.err())); + rhs.err().~unexpected_type(); + std::swap(this->m_has_val, rhs.m_has_val); + } + + void swap_where_only_one_has_value(expected &rhs, t_is_not_void) { + swap_where_only_one_has_value_and_t_is_not_void( + rhs, typename std::is_nothrow_move_constructible::type{}, + typename std::is_nothrow_move_constructible::type{}); + } + + void swap_where_only_one_has_value_and_t_is_not_void( + expected &rhs, t_is_nothrow_move_constructible, + e_is_nothrow_move_constructible) noexcept { + auto temp = std::move(val()); + val().~T(); + ::new (errptr()) unexpected_type(std::move(rhs.err())); + rhs.err().~unexpected_type(); + ::new (rhs.valptr()) T(std::move(temp)); + std::swap(this->m_has_val, rhs.m_has_val); + } + + void swap_where_only_one_has_value_and_t_is_not_void( + expected &rhs, t_is_nothrow_move_constructible, + move_constructing_e_can_throw) { + auto temp = std::move(val()); + val().~T(); +#ifdef TL_EXPECTED_EXCEPTIONS_ENABLED + try { + ::new (errptr()) unexpected_type(std::move(rhs.err())); + rhs.err().~unexpected_type(); + ::new (rhs.valptr()) T(std::move(temp)); + std::swap(this->m_has_val, rhs.m_has_val); + } catch (...) { + val() = std::move(temp); + throw; + } +#else + ::new (errptr()) unexpected_type(std::move(rhs.err())); + rhs.err().~unexpected_type(); + ::new (rhs.valptr()) T(std::move(temp)); + std::swap(this->m_has_val, rhs.m_has_val); +#endif + } + + void swap_where_only_one_has_value_and_t_is_not_void( + expected &rhs, move_constructing_t_can_throw, + e_is_nothrow_move_constructible) { + auto temp = std::move(rhs.err()); + rhs.err().~unexpected_type(); +#ifdef TL_EXPECTED_EXCEPTIONS_ENABLED + try { + ::new (rhs.valptr()) T(std::move(val())); + val().~T(); + ::new (errptr()) unexpected_type(std::move(temp)); + std::swap(this->m_has_val, rhs.m_has_val); + } catch (...) { + rhs.err() = std::move(temp); + throw; + } +#else + ::new (rhs.valptr()) T(std::move(val())); + val().~T(); + ::new (errptr()) unexpected_type(std::move(temp)); + std::swap(this->m_has_val, rhs.m_has_val); +#endif + } + +public: + template + detail::enable_if_t::value && + detail::is_swappable::value && + (std::is_nothrow_move_constructible::value || + std::is_nothrow_move_constructible::value)> + swap(expected &rhs) noexcept( + std::is_nothrow_move_constructible::value + &&detail::is_nothrow_swappable::value + &&std::is_nothrow_move_constructible::value + &&detail::is_nothrow_swappable::value) { + if (has_value() && rhs.has_value()) { + swap_where_both_have_value(rhs, typename std::is_void::type{}); + } else if (!has_value() && rhs.has_value()) { + rhs.swap(*this); + } else if (has_value()) { + swap_where_only_one_has_value(rhs, typename std::is_void::type{}); + } else { + using std::swap; + swap(err(), rhs.err()); + } + } + + constexpr const T *operator->() const { + TL_ASSERT(has_value()); + return valptr(); + } + TL_EXPECTED_11_CONSTEXPR T *operator->() { + TL_ASSERT(has_value()); + return valptr(); + } + + template ::value> * = nullptr> + constexpr const U &operator*() const & { + TL_ASSERT(has_value()); + return val(); + } + template ::value> * = nullptr> + TL_EXPECTED_11_CONSTEXPR U &operator*() & { + TL_ASSERT(has_value()); + return val(); + } + template ::value> * = nullptr> + constexpr const U &&operator*() const && { + TL_ASSERT(has_value()); + return std::move(val()); + } + template ::value> * = nullptr> + TL_EXPECTED_11_CONSTEXPR U &&operator*() && { + TL_ASSERT(has_value()); + return std::move(val()); + } + + constexpr bool has_value() const noexcept { return this->m_has_val; } + constexpr explicit operator bool() const noexcept { return this->m_has_val; } + + template ::value> * = nullptr> + TL_EXPECTED_11_CONSTEXPR const U &value() const & { + if (!has_value()) + detail::throw_exception(bad_expected_access(err().value())); + return val(); + } + template ::value> * = nullptr> + TL_EXPECTED_11_CONSTEXPR U &value() & { + if (!has_value()) + detail::throw_exception(bad_expected_access(err().value())); + return val(); + } + template ::value> * = nullptr> + TL_EXPECTED_11_CONSTEXPR const U &&value() const && { + if (!has_value()) + detail::throw_exception(bad_expected_access(std::move(err()).value())); + return std::move(val()); + } + template ::value> * = nullptr> + TL_EXPECTED_11_CONSTEXPR U &&value() && { + if (!has_value()) + detail::throw_exception(bad_expected_access(std::move(err()).value())); + return std::move(val()); + } + + constexpr const E &error() const & { + TL_ASSERT(!has_value()); + return err().value(); + } + TL_EXPECTED_11_CONSTEXPR E &error() & { + TL_ASSERT(!has_value()); + return err().value(); + } + constexpr const E &&error() const && { + TL_ASSERT(!has_value()); + return std::move(err().value()); + } + TL_EXPECTED_11_CONSTEXPR E &&error() && { + TL_ASSERT(!has_value()); + return std::move(err().value()); + } + + template constexpr T value_or(U &&v) const & { + static_assert(std::is_copy_constructible::value && + std::is_convertible::value, + "T must be copy-constructible and convertible to from U&&"); + return bool(*this) ? **this : static_cast(std::forward(v)); + } + template TL_EXPECTED_11_CONSTEXPR T value_or(U &&v) && { + static_assert(std::is_move_constructible::value && + std::is_convertible::value, + "T must be move-constructible and convertible to from U&&"); + return bool(*this) ? std::move(**this) : static_cast(std::forward(v)); + } +}; + +namespace detail { +template using exp_t = typename detail::decay_t::value_type; +template using err_t = typename detail::decay_t::error_type; +template using ret_t = expected>; + +#ifdef TL_EXPECTED_CXX14 +template >::value> * = nullptr, + class Ret = decltype(detail::invoke(std::declval(), + *std::declval()))> +constexpr auto and_then_impl(Exp &&exp, F &&f) { + static_assert(detail::is_expected::value, "F must return an expected"); + + return exp.has_value() + ? detail::invoke(std::forward(f), *std::forward(exp)) + : Ret(unexpect, std::forward(exp).error()); +} + +template >::value> * = nullptr, + class Ret = decltype(detail::invoke(std::declval()))> +constexpr auto and_then_impl(Exp &&exp, F &&f) { + static_assert(detail::is_expected::value, "F must return an expected"); + + return exp.has_value() ? detail::invoke(std::forward(f)) + : Ret(unexpect, std::forward(exp).error()); +} +#else +template struct TC; +template (), + *std::declval())), + detail::enable_if_t>::value> * = nullptr> +auto and_then_impl(Exp &&exp, F &&f) -> Ret { + static_assert(detail::is_expected::value, "F must return an expected"); + + return exp.has_value() + ? detail::invoke(std::forward(f), *std::forward(exp)) + : Ret(unexpect, std::forward(exp).error()); +} + +template ())), + detail::enable_if_t>::value> * = nullptr> +constexpr auto and_then_impl(Exp &&exp, F &&f) -> Ret { + static_assert(detail::is_expected::value, "F must return an expected"); + + return exp.has_value() ? detail::invoke(std::forward(f)) + : Ret(unexpect, std::forward(exp).error()); +} +#endif + +#ifdef TL_EXPECTED_CXX14 +template >::value> * = nullptr, + class Ret = decltype(detail::invoke(std::declval(), + *std::declval())), + detail::enable_if_t::value> * = nullptr> +constexpr auto expected_map_impl(Exp &&exp, F &&f) { + using result = ret_t>; + return exp.has_value() ? result(detail::invoke(std::forward(f), + *std::forward(exp))) + : result(unexpect, std::forward(exp).error()); +} + +template >::value> * = nullptr, + class Ret = decltype(detail::invoke(std::declval(), + *std::declval())), + detail::enable_if_t::value> * = nullptr> +auto expected_map_impl(Exp &&exp, F &&f) { + using result = expected>; + if (exp.has_value()) { + detail::invoke(std::forward(f), *std::forward(exp)); + return result(); + } + + return result(unexpect, std::forward(exp).error()); +} + +template >::value> * = nullptr, + class Ret = decltype(detail::invoke(std::declval())), + detail::enable_if_t::value> * = nullptr> +constexpr auto expected_map_impl(Exp &&exp, F &&f) { + using result = ret_t>; + return exp.has_value() ? result(detail::invoke(std::forward(f))) + : result(unexpect, std::forward(exp).error()); +} + +template >::value> * = nullptr, + class Ret = decltype(detail::invoke(std::declval())), + detail::enable_if_t::value> * = nullptr> +auto expected_map_impl(Exp &&exp, F &&f) { + using result = expected>; + if (exp.has_value()) { + detail::invoke(std::forward(f)); + return result(); + } + + return result(unexpect, std::forward(exp).error()); +} +#else +template >::value> * = nullptr, + class Ret = decltype(detail::invoke(std::declval(), + *std::declval())), + detail::enable_if_t::value> * = nullptr> + +constexpr auto expected_map_impl(Exp &&exp, F &&f) + -> ret_t> { + using result = ret_t>; + + return exp.has_value() ? result(detail::invoke(std::forward(f), + *std::forward(exp))) + : result(unexpect, std::forward(exp).error()); +} + +template >::value> * = nullptr, + class Ret = decltype(detail::invoke(std::declval(), + *std::declval())), + detail::enable_if_t::value> * = nullptr> + +auto expected_map_impl(Exp &&exp, F &&f) -> expected> { + if (exp.has_value()) { + detail::invoke(std::forward(f), *std::forward(exp)); + return {}; + } + + return unexpected>(std::forward(exp).error()); +} + +template >::value> * = nullptr, + class Ret = decltype(detail::invoke(std::declval())), + detail::enable_if_t::value> * = nullptr> + +constexpr auto expected_map_impl(Exp &&exp, F &&f) + -> ret_t> { + using result = ret_t>; + + return exp.has_value() ? result(detail::invoke(std::forward(f))) + : result(unexpect, std::forward(exp).error()); +} + +template >::value> * = nullptr, + class Ret = decltype(detail::invoke(std::declval())), + detail::enable_if_t::value> * = nullptr> + +auto expected_map_impl(Exp &&exp, F &&f) -> expected> { + if (exp.has_value()) { + detail::invoke(std::forward(f)); + return {}; + } + + return unexpected>(std::forward(exp).error()); +} +#endif + +#if defined(TL_EXPECTED_CXX14) && !defined(TL_EXPECTED_GCC49) && \ + !defined(TL_EXPECTED_GCC54) && !defined(TL_EXPECTED_GCC55) +template >::value> * = nullptr, + class Ret = decltype(detail::invoke(std::declval(), + std::declval().error())), + detail::enable_if_t::value> * = nullptr> +constexpr auto map_error_impl(Exp &&exp, F &&f) { + using result = expected, detail::decay_t>; + return exp.has_value() + ? result(*std::forward(exp)) + : result(unexpect, detail::invoke(std::forward(f), + std::forward(exp).error())); +} +template >::value> * = nullptr, + class Ret = decltype(detail::invoke(std::declval(), + std::declval().error())), + detail::enable_if_t::value> * = nullptr> +auto map_error_impl(Exp &&exp, F &&f) { + using result = expected, monostate>; + if (exp.has_value()) { + return result(*std::forward(exp)); + } + + detail::invoke(std::forward(f), std::forward(exp).error()); + return result(unexpect, monostate{}); +} +template >::value> * = nullptr, + class Ret = decltype(detail::invoke(std::declval(), + std::declval().error())), + detail::enable_if_t::value> * = nullptr> +constexpr auto map_error_impl(Exp &&exp, F &&f) { + using result = expected, detail::decay_t>; + return exp.has_value() + ? result() + : result(unexpect, detail::invoke(std::forward(f), + std::forward(exp).error())); +} +template >::value> * = nullptr, + class Ret = decltype(detail::invoke(std::declval(), + std::declval().error())), + detail::enable_if_t::value> * = nullptr> +auto map_error_impl(Exp &&exp, F &&f) { + using result = expected, monostate>; + if (exp.has_value()) { + return result(); + } + + detail::invoke(std::forward(f), std::forward(exp).error()); + return result(unexpect, monostate{}); +} +#else +template >::value> * = nullptr, + class Ret = decltype(detail::invoke(std::declval(), + std::declval().error())), + detail::enable_if_t::value> * = nullptr> +constexpr auto map_error_impl(Exp &&exp, F &&f) + -> expected, detail::decay_t> { + using result = expected, detail::decay_t>; + + return exp.has_value() + ? result(*std::forward(exp)) + : result(unexpect, detail::invoke(std::forward(f), + std::forward(exp).error())); +} + +template >::value> * = nullptr, + class Ret = decltype(detail::invoke(std::declval(), + std::declval().error())), + detail::enable_if_t::value> * = nullptr> +auto map_error_impl(Exp &&exp, F &&f) -> expected, monostate> { + using result = expected, monostate>; + if (exp.has_value()) { + return result(*std::forward(exp)); + } + + detail::invoke(std::forward(f), std::forward(exp).error()); + return result(unexpect, monostate{}); +} + +template >::value> * = nullptr, + class Ret = decltype(detail::invoke(std::declval(), + std::declval().error())), + detail::enable_if_t::value> * = nullptr> +constexpr auto map_error_impl(Exp &&exp, F &&f) + -> expected, detail::decay_t> { + using result = expected, detail::decay_t>; + + return exp.has_value() + ? result() + : result(unexpect, detail::invoke(std::forward(f), + std::forward(exp).error())); +} + +template >::value> * = nullptr, + class Ret = decltype(detail::invoke(std::declval(), + std::declval().error())), + detail::enable_if_t::value> * = nullptr> +auto map_error_impl(Exp &&exp, F &&f) -> expected, monostate> { + using result = expected, monostate>; + if (exp.has_value()) { + return result(); + } + + detail::invoke(std::forward(f), std::forward(exp).error()); + return result(unexpect, monostate{}); +} +#endif + +#ifdef TL_EXPECTED_CXX14 +template (), + std::declval().error())), + detail::enable_if_t::value> * = nullptr> +constexpr auto or_else_impl(Exp &&exp, F &&f) { + static_assert(detail::is_expected::value, "F must return an expected"); + return exp.has_value() ? std::forward(exp) + : detail::invoke(std::forward(f), + std::forward(exp).error()); +} + +template (), + std::declval().error())), + detail::enable_if_t::value> * = nullptr> +detail::decay_t or_else_impl(Exp &&exp, F &&f) { + return exp.has_value() ? std::forward(exp) + : (detail::invoke(std::forward(f), + std::forward(exp).error()), + std::forward(exp)); +} +#else +template (), + std::declval().error())), + detail::enable_if_t::value> * = nullptr> +auto or_else_impl(Exp &&exp, F &&f) -> Ret { + static_assert(detail::is_expected::value, "F must return an expected"); + return exp.has_value() ? std::forward(exp) + : detail::invoke(std::forward(f), + std::forward(exp).error()); +} + +template (), + std::declval().error())), + detail::enable_if_t::value> * = nullptr> +detail::decay_t or_else_impl(Exp &&exp, F &&f) { + return exp.has_value() ? std::forward(exp) + : (detail::invoke(std::forward(f), + std::forward(exp).error()), + std::forward(exp)); +} +#endif +} // namespace detail + +template +constexpr bool operator==(const expected &lhs, + const expected &rhs) { + return (lhs.has_value() != rhs.has_value()) + ? false + : (!lhs.has_value() ? lhs.error() == rhs.error() : *lhs == *rhs); +} +template +constexpr bool operator!=(const expected &lhs, + const expected &rhs) { + return (lhs.has_value() != rhs.has_value()) + ? true + : (!lhs.has_value() ? lhs.error() != rhs.error() : *lhs != *rhs); +} +template +constexpr bool operator==(const expected &lhs, + const expected &rhs) { + return (lhs.has_value() != rhs.has_value()) + ? false + : (!lhs.has_value() ? lhs.error() == rhs.error() : true); +} +template +constexpr bool operator!=(const expected &lhs, + const expected &rhs) { + return (lhs.has_value() != rhs.has_value()) + ? true + : (!lhs.has_value() ? lhs.error() == rhs.error() : false); +} + +template +constexpr bool operator==(const expected &x, const U &v) { + return x.has_value() ? *x == v : false; +} +template +constexpr bool operator==(const U &v, const expected &x) { + return x.has_value() ? *x == v : false; +} +template +constexpr bool operator!=(const expected &x, const U &v) { + return x.has_value() ? *x != v : true; +} +template +constexpr bool operator!=(const U &v, const expected &x) { + return x.has_value() ? *x != v : true; +} + +template +constexpr bool operator==(const expected &x, const unexpected &e) { + return x.has_value() ? false : x.error() == e.value(); +} +template +constexpr bool operator==(const unexpected &e, const expected &x) { + return x.has_value() ? false : x.error() == e.value(); +} +template +constexpr bool operator!=(const expected &x, const unexpected &e) { + return x.has_value() ? true : x.error() != e.value(); +} +template +constexpr bool operator!=(const unexpected &e, const expected &x) { + return x.has_value() ? true : x.error() != e.value(); +} + +template ::value || + std::is_move_constructible::value) && + detail::is_swappable::value && + std::is_move_constructible::value && + detail::is_swappable::value> * = nullptr> +void swap(expected &lhs, + expected &rhs) noexcept(noexcept(lhs.swap(rhs))) { + lhs.swap(rhs); +} +} // namespace nonstd + +#endif