зеркало из https://github.com/mozilla/marian.git
update to marian-dev
This commit is contained in:
Коммит
4b23fe76ff
|
@ -1,6 +1,7 @@
|
|||
# Config files from CMake
|
||||
src/common/project_version.h
|
||||
src/common/git_revision.h
|
||||
src/common/build_info.cpp
|
||||
|
||||
*.vcxproj.user
|
||||
/vs/x64
|
||||
|
|
|
@ -1,9 +1,16 @@
|
|||
[submodule "examples"]
|
||||
path = examples
|
||||
url = https://github.com/marian-nmt/marian-examples
|
||||
[submodule "regression-tests"]
|
||||
path = regression-tests
|
||||
url = https://github.com/marian-nmt/marian-regression-tests
|
||||
[submodule "src/3rd_party/sentencepiece"]
|
||||
path = src/3rd_party/sentencepiece
|
||||
url = https://github.com/marian-nmt/sentencepiece
|
||||
[submodule "src/3rd_party/nccl"]
|
||||
path = src/3rd_party/nccl
|
||||
url = https://github.com/marian-nmt/nccl
|
||||
[submodule "src/3rd_party/fbgemm"]
|
||||
path = src/3rd_party/fbgemm
|
||||
url = https://github.com/marian-nmt/FBGEMM
|
||||
branch = master
|
||||
|
|
100
CHANGELOG.md
100
CHANGELOG.md
|
@ -5,23 +5,109 @@ All notable changes to this project will be documented in this file.
|
|||
The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/)
|
||||
and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.html).
|
||||
|
||||
|
||||
## [Unreleased]
|
||||
|
||||
### Added
|
||||
- Automatic detection of CPU intrisics when building with -arch=native
|
||||
- An option to print cached variables from CMake
|
||||
- Add support for compiling on Mac (and clang)
|
||||
- An option for resetting stalled validation metrics
|
||||
- Add CMAKE options to disable compilation for specific GPU SM types
|
||||
- An option to print word-level translation scores
|
||||
- An option to turn off automatic detokenization from SentencePiece
|
||||
- Separate quantization types for 8-bit FBGEMM for AVX2 and AVX512
|
||||
- Sequence-level unliklihood training
|
||||
- Allow file name templated valid-translation-output files
|
||||
- Support for lexical shortlists in marian-server
|
||||
- Support for 8-bit matrix multiplication with FBGEMM
|
||||
- CMakeLists.txt now looks for SSE 4.2
|
||||
- Purging of finished hypotheses during beam-search. A lot faster for large batches.
|
||||
- Faster option look-up, up to 20-30% faster translation
|
||||
- Added --cite and --authors flag
|
||||
- Added optional support for ccache
|
||||
- Switch to change abort to exception, only to be used in library mode
|
||||
- Support for 16-bit packed models with FBGEMM
|
||||
- Multiple separated parameter types in ExpressionGraph, currently inference-only
|
||||
- Safe handling of sigterm signal
|
||||
- Automatic vectorization of elementwise operations on CPU for tensors dims that
|
||||
are divisible by 4 (AVX) and 8 (AVX2)
|
||||
- Replacing std::shared_ptr<T> with custom IntrusivePtr<T> for small objects like
|
||||
Tensors, Hypotheses and Expressions.
|
||||
- Fp16 inference working for translation
|
||||
- Gradient-checkpointing
|
||||
|
||||
### Fixed
|
||||
- Windows build with recent changes
|
||||
- Bug with read-ahead buffer
|
||||
- Fixed handling of "dump-config: false" in YAML config
|
||||
- Errors due to warnings
|
||||
- Fixed issue concerning failed saving with single GPU training and --sync-sgd option.
|
||||
- Replace value for INVALID_PATH_SCORE with std::numer_limits<float>::lowest()
|
||||
to avoid overflow with long sequences
|
||||
- Break up potential circular references for GraphGroup*
|
||||
- Fix empty source batch entries with batch purging
|
||||
- Clear RNN chache in transformer model, add correct hash functions to nodes
|
||||
- Gather-operation for all index sizes
|
||||
- Fix word weighting with max length cropping
|
||||
- Fixed compilation on CPUs without support for AVX
|
||||
- FastOpt now reads "n" and "y" values as strings, not as boolean values
|
||||
- Fixed multiple reduction kernels on GPU
|
||||
- Fixed guided-alignment training with cross-entropy
|
||||
- Replace IntrusivePtr with std::uniq_ptr in FastOpt, fixes random segfaults
|
||||
due to thread-non-safty of reference counting.
|
||||
- Make sure that items are 256-byte aligned during saving
|
||||
- Make explicit matmul functions respect setting of cublasMathMode
|
||||
- Fix memory mapping for mixed paramter models
|
||||
- Removed naked pointer and potential memory-leak from file_stream.{cpp,h}
|
||||
- Compilation for GCC >= 7 due to exception thrown in destructor
|
||||
- Sort parameters by lexicographical order during allocation to ensure consistent
|
||||
memory-layout during allocation, loading, saving.
|
||||
- Output empty line when input is empty line. Previous behavior might result in
|
||||
hallucinated outputs.
|
||||
- Compilation with CUDA 10.1
|
||||
|
||||
### Changed
|
||||
- Combine two for-loops in nth_element.cpp on CPU
|
||||
- Revert LayerNorm eps to old position, i.e. sigma' = sqrt(sigma^2 + eps)
|
||||
- Downgrade NCCL to 2.3.7 as 2.4.2 is buggy (hangs with larger models)
|
||||
- Return error signal on SIGTERM
|
||||
- Dropped support for CUDA 8.0, CUDA 9.0 is now minimal requirement
|
||||
- Removed autotuner for now, will be switched back on later
|
||||
- Boost depdendency is now optional and only required for marian_server
|
||||
- Dropped support for g++-4.9
|
||||
- Simplified file stream and temporary file handling
|
||||
- Unified node intializers, same function API.
|
||||
- Remove overstuff/understuff code
|
||||
|
||||
## [1.8.0] - 2019-09-04
|
||||
|
||||
### Added
|
||||
- Alias options and new --task option
|
||||
- Automatic detection of CPU intrisics when building with -arch=native
|
||||
- First version of BERT-training and BERT-classifier, currently not compatible with TF models
|
||||
- New reduction operators
|
||||
- Use Cmake's ExternalProject to build NCCL and potentially other external libs
|
||||
- Code for Factored Vocabulary, currently not usable yet without outside tools
|
||||
|
||||
### Fixed
|
||||
- Issue with relative paths in automatically generated decoder config files
|
||||
- Bug with overlapping CXX flags and building spm_train executable
|
||||
- Compilation with gcc 8
|
||||
- Overwriting and unsetting vector options
|
||||
- Windows build with recent changes
|
||||
- Bug with read-ahead buffer
|
||||
- Handling of "dump-config: false" in YAML config
|
||||
- Errors due to warnings
|
||||
- Issue concerning failed saving with single GPU training and --sync-sgd option.
|
||||
- NaN problem when training with Tensor Cores on Volta GPUs
|
||||
- Fix pipe-handling
|
||||
- Fix compilation with GCC 9.1
|
||||
- Fix CMake build types
|
||||
|
||||
### Changed
|
||||
- Error message when using left-to-right and right-to-left models together in ensembles
|
||||
- Regression tests included as a submodule
|
||||
- Update NCCL to 2.4.2
|
||||
- Add zlib source to Marian's source tree, builds now as object lib
|
||||
- -DUSE_STATIC_LIBS=on now also looks for static versions of CUDA libraries
|
||||
- Include NCCL build from github.com/marian-nmt/nccl and compile within source tree
|
||||
- Set nearly all warnings as errors for Marian's own targets. Disable warnings for 3rd party.
|
||||
- Set nearly all warnings as errors for Marian's own targets. Disable warnings for 3rd party
|
||||
- Refactored beam search
|
||||
|
||||
## [1.7.0] - 2018-11-27
|
||||
|
||||
|
|
241
CMakeLists.txt
241
CMakeLists.txt
|
@ -5,7 +5,6 @@ if (POLICY CMP0074)
|
|||
cmake_policy(SET CMP0074 NEW) # CMake 3.12
|
||||
endif ()
|
||||
|
||||
|
||||
project(marian CXX C)
|
||||
set(CMAKE_CXX_STANDARD 11)
|
||||
set(CMAKE_CXX_STANDARD_REQUIRED ON)
|
||||
|
@ -14,14 +13,33 @@ set(BUILD_ARCH native CACHE STRING "Compile for this CPU architecture.")
|
|||
# Custom CMake options
|
||||
option(COMPILE_CPU "Compile CPU version" ON)
|
||||
option(COMPILE_CUDA "Compile GPU version" ON)
|
||||
option(COMPILE_CUDA_SM35 "Compile GPU version with SM35 support" ON)
|
||||
option(COMPILE_CUDA_SM50 "Compile GPU version with SM50 support" ON)
|
||||
option(COMPILE_CUDA_SM60 "Compile GPU version with SM60 support" ON)
|
||||
option(COMPILE_CUDA_SM70 "Compile GPU version with SM70 support" ON)
|
||||
option(COMPILE_EXAMPLES "Compile examples" OFF)
|
||||
option(COMPILE_SERVER "Compile marian-server" OFF)
|
||||
option(COMPILE_TESTS "Compile tests" OFF)
|
||||
option(USE_CCACHE "Use ccache compiler cache (https://ccache.dev)" OFF)
|
||||
option(USE_CUDNN "Use CUDNN library" OFF)
|
||||
option(USE_DOXYGEN "Build documentation with Doxygen" ON)
|
||||
option(USE_FBGEMM "Use FBGEMM" OFF)
|
||||
option(USE_MKL "Compile with MKL support" ON)
|
||||
option(USE_MPI "Use MPI library" OFF)
|
||||
option(USE_NCCL "Use NCCL library" ON)
|
||||
option(USE_SENTENCEPIECE "Download and compile SentencePiece" OFF)
|
||||
option(USE_STATIC_LIBS "Link statically against non-system libs" OFF)
|
||||
option(USE_CUDNN "Use CUDNN library" OFF)
|
||||
option(USE_NCCL "Use NCCL library" ON)
|
||||
option(USE_MPI "Use MPI library" OFF)
|
||||
option(COMPILE_EXAMPLES "Compile examples" OFF)
|
||||
option(COMPILE_TESTS "Compile tests" OFF)
|
||||
option(COMPILE_SERVER "Compile marian-server" ON)
|
||||
|
||||
# use ccache (https://ccache.dev) for faster compilation if requested and available
|
||||
if(USE_CCACHE)
|
||||
find_program(CCACHE_PROGRAM ccache)
|
||||
if(CCACHE_PROGRAM)
|
||||
message(STATUS "Will be using ccache for faster repeat compilation (use cmake -DUSE_CCACHE=off to disable).")
|
||||
set_property(GLOBAL PROPERTY RULE_LAUNCH_COMPILE "${CCACHE_PROGRAM}")
|
||||
else(CCACHE_PROGRAM)
|
||||
message(WARNING "Compilation with ccache requested but no ccache found.")
|
||||
endif(CCACHE_PROGRAM)
|
||||
endif(USE_CCACHE)
|
||||
|
||||
# Project versioning
|
||||
find_package(Git QUIET)
|
||||
|
@ -33,6 +51,12 @@ message(STATUS "Project version: ${PROJECT_VERSION_STRING_FULL}")
|
|||
execute_process(COMMAND git submodule update --init --recursive --no-fetch
|
||||
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR})
|
||||
|
||||
if(NOT CMAKE_BUILD_TYPE)
|
||||
message(WARNING "CMAKE_BUILD_TYPE not set; setting to Release")
|
||||
set(CMAKE_BUILD_TYPE "Release")
|
||||
endif()
|
||||
|
||||
###############################################################################
|
||||
# Set compilation flags
|
||||
if(MSVC)
|
||||
# These are used in src/CMakeLists.txt on a per-target basis
|
||||
|
@ -59,60 +83,110 @@ if(MSVC)
|
|||
|
||||
find_library(SHLWAPI Shlwapi.lib)
|
||||
set(EXT_LIBS ${EXT_LIBS} SHLWAPI)
|
||||
else()
|
||||
else(MSVC)
|
||||
|
||||
# Check we are using at least g++ 5.0
|
||||
if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND CMAKE_CXX_COMPILER_VERSION VERSION_LESS 5.0)
|
||||
message(FATAL_ERROR "FATAL ERROR: Compiling Marian requires at least g++ 5.0, your version is ${CMAKE_CXX_COMPILER_VERSION}")
|
||||
endif()
|
||||
|
||||
# Detect support CPU instrinsics for the current platform. This will
|
||||
# only by used with BUILD_ARCH=native. For overridden BUILD_ARCH we
|
||||
# minimally use -msse4.1. This seems to work with MKL.
|
||||
set(INTRINSICS "")
|
||||
list(APPEND INTRINSICS_NVCC)
|
||||
|
||||
if(BUILD_ARCH STREQUAL "native")
|
||||
message(STATUS "Checking support for CPU intrinsics")
|
||||
include(FindSSE)
|
||||
if(SSE2_FOUND)
|
||||
message(STATUS "SSE2 support found")
|
||||
set(INTRINSICS "${INTRINSICS} -msse2")
|
||||
list(APPEND INTRINSICS_NVCC -Xcompiler\ -msse2)
|
||||
endif(SSE2_FOUND)
|
||||
if(SSE3_FOUND)
|
||||
message(STATUS "SSE3 support found")
|
||||
set(INTRINSICS "${INTRINSICS} -msse3")
|
||||
list(APPEND INTRINSICS_NVCC -Xcompiler\ -msse3)
|
||||
endif(SSE3_FOUND)
|
||||
if(SSE4_1_FOUND)
|
||||
message(STATUS "SSE4.1 support found")
|
||||
set(INTRINSICS "${INTRINSICS} -msse4.1")
|
||||
list(APPEND INTRINSICS_NVCC -Xcompiler\ -msse4.1)
|
||||
endif(SSE4_1_FOUND)
|
||||
if(SSE4_2_FOUND)
|
||||
message(STATUS "SSE4.2 support found")
|
||||
set(INTRINSICS "${INTRINSICS} -msse4.2")
|
||||
list(APPEND INTRINSICS_NVCC -Xcompiler\ -msse4.2)
|
||||
endif(SSE4_2_FOUND)
|
||||
if(AVX_FOUND)
|
||||
message(STATUS "AVX support found")
|
||||
set(INTRINSICS "${INTRINSICS} -mavx")
|
||||
list(APPEND INTRINSICS_NVCC -Xcompiler\ -mavx)
|
||||
endif(AVX_FOUND)
|
||||
if(AVX2_FOUND)
|
||||
message(STATUS "AVX2 support found")
|
||||
set(INTRINSICS "${INTRINSICS} -mavx2")
|
||||
list(APPEND INTRINSICS_NVCC -Xcompiler\ -mavx2)
|
||||
endif(AVX2_FOUND)
|
||||
if(AVX512_FOUND)
|
||||
message(STATUS "AVX512 support found")
|
||||
set(INTRINSICS "${INTRINSICS} -mavx512f")
|
||||
list(APPEND INTRINSICS_NVCC -Xcompiler\ -mavx512f)
|
||||
endif(AVX512_FOUND)
|
||||
else()
|
||||
set(INTRINSICS "-msse4.1")
|
||||
endif()
|
||||
|
||||
set(DISABLE_GLOBALLY "-Wno-unused-result")
|
||||
if(USE_FBGEMM)
|
||||
set(EXT_LIBS ${EXT_LIBS} fbgemm dl)
|
||||
add_definitions(-DUSE_FBGEMM=1)
|
||||
endif(USE_FBGEMM)
|
||||
|
||||
if (CMAKE_CXX_COMPILER_ID MATCHES "Clang" AND CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 9.0)
|
||||
# Clang-10.0.0 complains when CUDA is newer than 10.1
|
||||
set(CLANG_IGNORE_UNKNOWN_CUDA "-Wno-unknown-cuda-version")
|
||||
endif()
|
||||
set(DISABLE_GLOBALLY "-Wno-unused-result -Wno-unknown-warning-option ${CLANG_IGNORE_UNKNOWN_CUDA}")
|
||||
|
||||
# These are used in src/CMakeLists.txt on a per-target basis
|
||||
list(APPEND ALL_WARNINGS -Wall; -Werror; -Wno-unused-result; -Wno-deprecated; -Wno-pragmas; -Wno-unused-parameter; -Wextra; -Wno-unused-function;
|
||||
-Wno-unused-value; -Wno-unknown-pragmas; -Wno-sign-compare; -Wno-missing-field-initializers;)
|
||||
list(APPEND ALL_WARNINGS -Wall; -Werror; -Wextra; -Wno-unused-result; -Wno-deprecated;
|
||||
-Wno-pragmas; -Wno-unused-parameter; -Wno-unused-function;
|
||||
-Wno-unused-value; -Wno-unknown-pragmas; -Wno-sign-compare;
|
||||
-Wno-missing-field-initializers;)
|
||||
|
||||
# This warning does not exist prior to gcc 5.0
|
||||
if(CMAKE_COMPILER_IS_GNUCC AND CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 5.0)
|
||||
list(APPEND ALL_WARNINGS -Wsuggest-override)
|
||||
list(APPEND ALL_WARNINGS -Wsuggest-override -Wno-int-in-bool-context)
|
||||
endif()
|
||||
|
||||
set(CMAKE_CXX_FLAGS "-std=c++11 -O3 -Ofast -m64 -pthread -march=${BUILD_ARCH} ${INTRINSICS} -Wl,--no-as-needed -funroll-loops -ffinite-math-only -fPIC ${DISABLE_GLOBALLY}")
|
||||
set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS} -g -rdynamic")
|
||||
set(CMAKE_CXX_FLAGS_DEBUG "-std=c++11 -g -rdynamic -O0 -pthread -Wl,--no-as-needed -fPIC -Wno-unused-result -Wno-deprecated -Wno-pragmas")
|
||||
set(CMAKE_CXX_FLAGS_SLIM "${CMAKE_CXX_FLAGS} -DNDEBUG")
|
||||
set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "${CMAKE_CXX_FLAGS} -g -rdynamic")
|
||||
set(CMAKE_CXX_FLAGS_PROFILE "${CMAKE_CXX_FLAGS_RELEASE} -pg -g -rdynamic")
|
||||
if(CMAKE_COMPILER_IS_GNUCC)
|
||||
# these flags are not known to clang
|
||||
set(CMAKE_GCC_FLAGS "-Wl,--no-as-needed")
|
||||
set(CMAKE_RDYNAMIC_FLAG "-rdynamic")
|
||||
endif(CMAKE_COMPILER_IS_GNUCC)
|
||||
|
||||
set(CMAKE_CXX_FLAGS "-std=c++11 -pthread ${CMAKE_GCC_FLAGS} -fPIC ${DISABLE_GLOBALLY} -march=${BUILD_ARCH} ${INTRINSICS}")
|
||||
set(CMAKE_CXX_FLAGS_RELEASE "-Ofast -m64 -funroll-loops -ffinite-math-only -g ${CMAKE_RDYNAMIC_FLAG}")
|
||||
set(CMAKE_CXX_FLAGS_DEBUG "-O0 -g ${CMAKE_RDYNAMIC_FLAG}")
|
||||
set(CMAKE_CXX_FLAGS_SLIM "-Ofast -m64 -funroll-loops -ffinite-math-only -DNDEBUG")
|
||||
set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "${CMAKE_CXX_FLAGS_RELEASE}")
|
||||
set(CMAKE_CXX_FLAGS_PROFILE "${CMAKE_CXX_FLAGS_RELEASE} -pg")
|
||||
set(CMAKE_CXX_FLAGS_PROFGEN "${CMAKE_CXX_FLAGS_RELEASE} -fprofile-generate -fprofile-correction")
|
||||
set(CMAKE_CXX_FLAGS_PROFUSE "${CMAKE_CXX_FLAGS_RELEASE} -fprofile-use -fprofile-correction")
|
||||
endif()
|
||||
|
||||
# these need to be set separately
|
||||
set(CMAKE_C_FLAGS "-pthread ${CMAKE_GCC_FLAGS} -fPIC ${DISABLE_GLOBALLY} -march=${BUILD_ARCH} ${INTRINSICS}")
|
||||
set(CMAKE_C_FLAGS_RELEASE "-O3 -m64 -funroll-loops -ffinite-math-only -g ${CMAKE_RDYNAMIC_FLAG}")
|
||||
set(CMAKE_C_FLAGS_DEBUG "-O0 -g ${CMAKE_RDYNAMIC_FLAG}")
|
||||
set(CMAKE_C_FLAGS_SLIM "-O3 -m64 -funroll-loops -ffinite-math-only -DNDEBUG")
|
||||
set(CMAKE_C_FLAGS_RELWITHDEBINFO "${CMAKE_C_FLAGS_RELEASE}")
|
||||
set(CMAKE_C_FLAGS_PROFILE "${CMAKE_C_FLAGS_RELEASE} -pg")
|
||||
set(CMAKE_C_FLAGS_PROFGEN "${CMAKE_C_FLAGS_RELEASE} -fprofile-generate -fprofile-correction")
|
||||
set(CMAKE_C_FLAGS_PROFUSE "${CMAKE_C_FLAGS_RELEASE} -fprofile-use -fprofile-correction")
|
||||
endif(MSVC)
|
||||
|
||||
###############################################################################
|
||||
# Downloading SentencePiece if requested and set to compile with it.
|
||||
# Requires all the dependencies imposed by SentencePiece
|
||||
if(USE_SENTENCEPIECE)
|
||||
|
@ -121,10 +195,10 @@ if(USE_SENTENCEPIECE)
|
|||
set(EXT_LIBS ${EXT_LIBS} sentencepiece sentencepiece_train)
|
||||
endif()
|
||||
|
||||
|
||||
# Find packages
|
||||
set(EXT_LIBS ${EXT_LIBS} ${CMAKE_DL_LIBS})
|
||||
|
||||
###############################################################################
|
||||
if(COMPILE_CUDA)
|
||||
|
||||
if(USE_STATIC_LIBS)
|
||||
|
@ -140,12 +214,37 @@ if(USE_STATIC_LIBS)
|
|||
endif()
|
||||
endif()
|
||||
|
||||
find_package(CUDA "8.0")
|
||||
find_package(CUDA "9.0") # TODO: only enable FP16-related options for compute_70 and higher.
|
||||
if(CUDA_FOUND)
|
||||
# CUDA >= 10.0 requires CMake >= 3.12.2
|
||||
if((CUDA_VERSION VERSION_EQUAL "10.0" OR CUDA_VERSION VERSION_GREATER "10.0") AND (CMAKE_VERSION VERSION_LESS "3.12.2"))
|
||||
message(WARNING "On some Unix systems CUDA 10.0+ requires CMake 3.12.2+; you use CMake ${CMAKE_VERSION}")
|
||||
endif()
|
||||
|
||||
if(COMPILE_CUDA_SM35)
|
||||
LIST(APPEND COMPUTE -arch=sm_35; -gencode=arch=compute_35,code=sm_35;) # Tesla K40 and above
|
||||
endif(COMPILE_CUDA_SM35)
|
||||
if(COMPILE_CUDA_SM50)
|
||||
LIST(APPEND COMPUTE -gencode=arch=compute_50,code=sm_50; -gencode=arch=compute_52,code=sm_52;) # Maxwell GPUs
|
||||
endif(COMPILE_CUDA_SM50)
|
||||
if(COMPILE_CUDA_SM60)
|
||||
LIST(APPEND COMPUTE -gencode=arch=compute_60,code=sm_60; -gencode=arch=compute_61,code=sm_61;) # Pascal GPUs
|
||||
endif(COMPILE_CUDA_SM60)
|
||||
if(COMPILE_CUDA_SM70)
|
||||
LIST(APPEND COMPUTE -gencode=arch=compute_70,code=sm_70; -gencode=arch=compute_70,code=compute_70) # Volta GPUs
|
||||
endif(COMPILE_CUDA_SM70)
|
||||
|
||||
if(USE_STATIC_LIBS)
|
||||
find_library(CUDA_culibos_LIBRARY NAMES culibos PATHS ${CUDA_TOOLKIT_ROOT_DIR}/lib64)
|
||||
set(EXT_LIBS ${EXT_LIBS} ${CUDA_curand_LIBRARY} ${CUDA_cusparse_LIBRARY} ${CUDA_culibos_LIBRARY} ${CUDA_CUBLAS_LIBRARIES})
|
||||
message(STATUS "Found CUDA libraries: ${CUDA_curand_LIBRARY} ${CUDA_cusparse_LIBRARY} ${CUDA_culibos_LIBRARY} ${CUDA_CUBLAS_LIBRARIES}")
|
||||
set(CUDA_LIBS ${CUDA_curand_LIBRARY} ${CUDA_cusparse_LIBRARY} ${CUDA_culibos_LIBRARY} ${CUDA_CUBLAS_LIBRARIES})
|
||||
# CUDA 10.1 introduces cublasLt library that is required on static build
|
||||
if ((CUDA_VERSION VERSION_EQUAL "10.1" OR CUDA_VERSION VERSION_GREATER "10.1"))
|
||||
find_library(CUDA_cublasLt_LIBRARY NAMES cublasLt PATHS ${CUDA_TOOLKIT_ROOT_DIR}/lib64)
|
||||
set(EXT_LIBS ${EXT_LIBS} ${CUDA_cublasLt_LIBRARY})
|
||||
set(CUDA_LIBS ${CUDA_LIBS} ${CUDA_cublasLt_LIBRARY})
|
||||
endif()
|
||||
message(STATUS "Found CUDA libraries: ${CUDA_LIBS}")
|
||||
else(USE_STATIC_LIBS)
|
||||
set(EXT_LIBS ${EXT_LIBS} ${CUDA_curand_LIBRARY} ${CUDA_cusparse_LIBRARY} ${CUDA_CUBLAS_LIBRARIES})
|
||||
message(STATUS "Found CUDA libraries: ${CUDA_curand_LIBRARY} ${CUDA_cusparse_LIBRARY} ${CUDA_CUBLAS_LIBRARIES}")
|
||||
|
@ -168,41 +267,11 @@ endif(USE_STATIC_LIBS)
|
|||
list(APPEND CUDA_NVCC_FLAGS -DBOOST_PP_VARIADICS=0; )
|
||||
endif()
|
||||
|
||||
# We compile NCCL ourselves, using the NVidia Makefile rather than CMake, this requires to pass a couple of parameters from
|
||||
# Cmake. This is also fairly untested, let's hope it does not explode.
|
||||
# @TODO: Make sure it does not use pre-installed NCCL headers
|
||||
if(USE_NCCL)
|
||||
# define and set the include dir for the generated nccl.h header
|
||||
set(NCCL_HEADER_LOCATION "${CMAKE_CURRENT_BINARY_DIR}/nccl/include")
|
||||
include_directories(${NCCL_HEADER_LOCATION})
|
||||
|
||||
# set the path for the generated static lib
|
||||
set(NCCL_LIB_STATIC "${CMAKE_CURRENT_BINARY_DIR}/nccl/lib/libnccl_static.a")
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DUSE_NCCL")
|
||||
|
||||
LIST(APPEND CUDA_NVCC_FLAGS -DUSE_NCCL; )
|
||||
|
||||
# disables compilation for sm_30 to avoid ptxas warning... that's general Kepler support. But K80s are supported for instance by sm_35
|
||||
set(GENCODE "-gencode=arch=compute_35,code=sm_35 -gencode=arch=compute_50,code=sm_50 -gencode=arch=compute_60,code=sm_60 -gencode=arch=compute_61,code=sm_61")
|
||||
|
||||
# We build using NVidia's custom makefile, for that we pass a number of variables from CMake.
|
||||
# Sets output to the chosen build folder, i.e. where the binaries and objects are generated.
|
||||
# Also passes CUDA location from FindCUDA, sets c++ compiler to the same one CMake uses.
|
||||
add_custom_command(OUTPUT ${NCCL_LIB_STATIC}
|
||||
COMMAND ${CMAKE_MAKE_PROGRAM} src.build
|
||||
BUILDDIR=${CMAKE_CURRENT_BINARY_DIR}/nccl
|
||||
CUDA_HOME=${CUDA_TOOLKIT_ROOT_DIR}
|
||||
CUDA8_GENCODE=${GENCODE}
|
||||
CXX=${CMAKE_CXX_COMPILER}
|
||||
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/src/3rd_party/nccl)
|
||||
add_custom_target(nccl_target DEPENDS ${NCCL_LIB_STATIC})
|
||||
add_library(nccl STATIC IMPORTED)
|
||||
set_target_properties(nccl PROPERTIES IMPORTED_LOCATION ${NCCL_LIB_STATIC})
|
||||
add_dependencies(nccl nccl_target)
|
||||
set(EXT_LIBS ${EXT_LIBS} nccl)
|
||||
|
||||
# adds the resulting files to be removed by `make clean`
|
||||
set_directory_properties(PROPERTY ADDITIONAL_MAKE_CLEAN_FILES ${CMAKE_CURRENT_BINARY_DIR}/nccl)
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DUSE_NCCL")
|
||||
LIST(APPEND CUDA_NVCC_FLAGS -DUSE_NCCL; )
|
||||
endif(USE_NCCL)
|
||||
|
||||
if(USE_STATIC_LIBS)
|
||||
|
@ -210,21 +279,30 @@ if(USE_STATIC_LIBS)
|
|||
endif()
|
||||
|
||||
else(CUDA_FOUND)
|
||||
message(FATAL_ERROR "CUDA has not been found, set -DCOMPILE_CUDA=off to avoid this check and to compile the CPU version only")
|
||||
message("
|
||||
Cannot find suitable CUDA libraries. Specify the path explicitly with
|
||||
-DCUDA_TOOLKIT_ROOT_DIR=/path/to/appropriate/cuda/installation
|
||||
(hint: try /usr/local/$(readlink /usr/local/cuda))
|
||||
OR compile the CPU-only version of Marian with
|
||||
-DCOMPILE_CUDA=off
|
||||
")
|
||||
message(FATAL_ERROR "FATAL ERROR: No suitable CUDA library found.")
|
||||
endif(CUDA_FOUND)
|
||||
|
||||
else(COMPILE_CUDA)
|
||||
message(WARNING "COMPILE_CUDA=off : Building only CPU version")
|
||||
endif(COMPILE_CUDA)
|
||||
|
||||
# TODO: make compatible with older CUDA versions
|
||||
if(CMAKE_BUILD_TYPE STREQUAL "Debug")
|
||||
list(APPEND CUDA_NVCC_FLAGS --default-stream per-thread; -O0; -g; -arch=sm_30; -gencode=arch=compute_30,code=sm_30; -gencode=arch=compute_50,code=sm_50; -gencode=arch=compute_52,code=sm_52; -gencode=arch=compute_60,code=sm_60; -gencode=arch=compute_61,code=sm_61; -gencode=arch=compute_61,code=compute_61 ;)
|
||||
list(APPEND CUDA_NVCC_FLAGS --default-stream per-thread; -O0; -g; --use_fast_math; ${COMPUTE})
|
||||
else(CMAKE_BUILD_TYPE STREQUAL "Debug")
|
||||
list(APPEND CUDA_NVCC_FLAGS --default-stream per-thread; -O3; -g; --use_fast_math; -arch=sm_30; -gencode=arch=compute_30,code=sm_30; -gencode=arch=compute_50,code=sm_50; -gencode=arch=compute_52,code=sm_52; -gencode=arch=compute_60,code=sm_60; -gencode=arch=compute_61,code=sm_61; -gencode=arch=compute_61,code=compute_61 ;)
|
||||
list(APPEND CUDA_NVCC_FLAGS --default-stream per-thread; -O3; -g; --use_fast_math; ${COMPUTE})
|
||||
endif(CMAKE_BUILD_TYPE STREQUAL "Debug")
|
||||
if(NOT MSVC)
|
||||
# @TODO: add warnings here too
|
||||
list(APPEND CUDA_NVCC_FLAGS -std=c++11; -Xcompiler\ -fPIC; -Xcompiler\ -Wno-unused-result; -Xcompiler\ -Wno-deprecated; -Xcompiler\ -Wno-pragmas; -Xcompiler\ -Wno-unused-value; -Xcompiler\ -Werror;)
|
||||
list(APPEND CUDA_NVCC_FLAGS -ccbin ${CMAKE_C_COMPILER}; -std=c++11; -Xcompiler\ -fPIC; -Xcompiler\ -Wno-unused-result; -Xcompiler\ -Wno-deprecated; -Xcompiler\ -Wno-pragmas; -Xcompiler\ -Wno-unused-value; -Xcompiler\ -Werror;)
|
||||
list(APPEND CUDA_NVCC_FLAGS ${INTRINSICS_NVCC})
|
||||
else()
|
||||
list(APPEND CUDA_NVCC_FLAGS -Xcompiler\ /FS; )
|
||||
endif()
|
||||
|
@ -241,6 +319,8 @@ if(USE_STATIC_LIBS)
|
|||
endif()
|
||||
endif()
|
||||
|
||||
###############################################################################
|
||||
# Find Tcmalloc
|
||||
if(NOT WIN32)
|
||||
find_package(Tcmalloc)
|
||||
if(Tcmalloc_FOUND)
|
||||
|
@ -251,6 +331,8 @@ if(NOT WIN32)
|
|||
endif(Tcmalloc_FOUND)
|
||||
endif()
|
||||
|
||||
###############################################################################
|
||||
# Find MPI
|
||||
if(USE_MPI)
|
||||
find_package(MPI 2.0)
|
||||
if(MPI_FOUND)
|
||||
|
@ -260,38 +342,40 @@ if(USE_MPI)
|
|||
endif(MPI_FOUND)
|
||||
endif(USE_MPI)
|
||||
|
||||
###############################################################################
|
||||
# Find MKL
|
||||
if(COMPILE_CPU)
|
||||
if(USE_MKL)
|
||||
find_package(MKL)
|
||||
endif(USE_MKL)
|
||||
if(MKL_FOUND)
|
||||
include_directories(${MKL_INCLUDE_DIR})
|
||||
set(EXT_LIBS ${EXT_LIBS} ${MKL_LIBRARIES})
|
||||
add_definitions(-DBLAS_FOUND=1 -DMKL_FOUND=1)
|
||||
else(MKL_FOUND)
|
||||
set(BLA_VENDOR "OpenBLAS")
|
||||
set(BLAS_VENDOR "OpenBLAS")
|
||||
find_package(BLAS)
|
||||
if(BLAS_FOUND)
|
||||
include_directories(${BLAS_INCLUDE_DIR})
|
||||
set(EXT_LIBS ${EXT_LIBS} ${BLAS_LIBRARIES})
|
||||
include(FindCBLAS)
|
||||
if(CBLAS_FOUND)
|
||||
include_directories(${BLAS_INCLUDE_DIR} ${CBLAS_INCLUDE_DIR})
|
||||
set(EXT_LIBS ${EXT_LIBS} ${BLAS_LIBRARIES} ${CBLAS_LIBRARIES})
|
||||
add_definitions(-DBLAS_FOUND=1)
|
||||
endif(CBLAS_FOUND)
|
||||
endif(BLAS_FOUND)
|
||||
endif(MKL_FOUND)
|
||||
endif(COMPILE_CPU)
|
||||
|
||||
set(BOOST_COMPONENTS timer iostreams filesystem system chrono)
|
||||
if("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU" AND CMAKE_CXX_COMPILER_VERSION VERSION_LESS 4.9)
|
||||
add_definitions(-DUSE_BOOST_REGEX=1)
|
||||
set(BOOST_COMPONENTS ${BOOST_COMPONENTS} regex)
|
||||
message(STATUS "Using boost::regex")
|
||||
else()
|
||||
message(STATUS "Using std::regex")
|
||||
endif()
|
||||
|
||||
###############################################################################
|
||||
# Find OpenSSL
|
||||
set(BOOST_COMPONENTS "")
|
||||
if(COMPILE_SERVER)
|
||||
find_package(OpenSSL)
|
||||
if(OpenSSL_FOUND)
|
||||
message(STATUS "Found OpenSSL")
|
||||
include_directories(${OPENSSL_INCLUDE_DIR})
|
||||
set(EXT_LIBS ${EXT_LIBS} ${OPENSSL_CRYPTO_LIBRARY})
|
||||
set(BOOST_COMPONENTS ${BOOST_COMPONENTS} system)
|
||||
else(OpenSSL_FOUND)
|
||||
message(WARNING "Cannot find OpenSSL library. Not compiling server.")
|
||||
set(COMPILE_SERVER "off")
|
||||
|
@ -302,10 +386,14 @@ if(USE_STATIC_LIBS)
|
|||
set(CMAKE_FIND_LIBRARY_SUFFIXES ${_ORIG_CMAKE_FIND_LIBRARY_SUFFIXES})
|
||||
endif()
|
||||
|
||||
# TODO: move inside if(BOOST_COMPONENTS) ?
|
||||
if(USE_STATIC_LIBS)
|
||||
set(Boost_USE_STATIC_LIBS ON)
|
||||
endif()
|
||||
|
||||
###############################################################################
|
||||
# Find Boost if required
|
||||
if(BOOST_COMPONENTS)
|
||||
find_package(Boost COMPONENTS ${BOOST_COMPONENTS})
|
||||
if(Boost_FOUND)
|
||||
include_directories(${Boost_INCLUDE_DIRS})
|
||||
|
@ -314,7 +402,9 @@ if(Boost_FOUND)
|
|||
else(Boost_FOUND)
|
||||
message(SEND_ERROR "Cannot find Boost libraries. Terminating.")
|
||||
endif(Boost_FOUND)
|
||||
endif(BOOST_COMPONENTS)
|
||||
|
||||
###############################################################################
|
||||
if(COMPILE_TESTS)
|
||||
enable_testing()
|
||||
endif(COMPILE_TESTS)
|
||||
|
@ -327,11 +417,18 @@ endif(COMPILE_EXAMPLES)
|
|||
configure_file(${CMAKE_CURRENT_SOURCE_DIR}/src/common/project_version.h.in
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/src/common/project_version.h @ONLY)
|
||||
|
||||
# Generate build_info.cpp with CMake cache variables
|
||||
include(GetCacheVariables)
|
||||
|
||||
configure_file(${CMAKE_CURRENT_SOURCE_DIR}/src/common/build_info.cpp.in
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/src/common/build_info.cpp @ONLY)
|
||||
|
||||
# Compile source files
|
||||
include_directories(${marian_SOURCE_DIR}/src)
|
||||
add_subdirectory(src)
|
||||
|
||||
|
||||
###############################################################################
|
||||
if(USE_DOXYGEN)
|
||||
# Add a target to generate API documentation with Doxygen
|
||||
find_package(Doxygen)
|
||||
if(DOXYGEN_FOUND)
|
||||
|
@ -343,4 +440,4 @@ if(DOXYGEN_FOUND)
|
|||
COMMENT "Generating API documentation with Doxygen" VERBATIM
|
||||
)
|
||||
endif(DOXYGEN_FOUND)
|
||||
|
||||
endif(USE_DOXYGEN)
|
||||
|
|
|
@ -1592,7 +1592,7 @@ PAPER_TYPE = a4
|
|||
# If left blank no extra packages will be included.
|
||||
# This tag requires that the tag GENERATE_LATEX is set to YES.
|
||||
|
||||
EXTRA_PACKAGES =
|
||||
EXTRA_PACKAGES = amsmath
|
||||
|
||||
# The LATEX_HEADER tag can be used to specify a personal LaTeX header for the
|
||||
# generated LaTeX document. The header should contain everything until the first
|
||||
|
|
51
README.md
51
README.md
|
@ -1,40 +1,26 @@
|
|||
Marian
|
||||
======
|
||||
|
||||
[![Build Status CUDA 9](https://img.shields.io/jenkins/s/http/vali.inf.ed.ac.uk/jenkins/view/marian/job/marian-dev-cuda-9.2.svg?label=CUDA%209)](http://vali.inf.ed.ac.uk/jenkins/job/marian-dev-cuda-9.2/)
|
||||
[![Build Status CUDA 10](https://img.shields.io/jenkins/s/http/vali.inf.ed.ac.uk/jenkins/view/marian/job/marian-dev-cuda-10.1.svg?label=CUDA%2010)](http://vali.inf.ed.ac.uk/jenkins/job/marian-dev-cuda-10.1/)
|
||||
[![CPU Build Status](https://img.shields.io/jenkins/s/http/vali.inf.ed.ac.uk/jenkins/view/marian/job/marian-dev-cpu.svg?label=CPU)](http://vali.inf.ed.ac.uk/jenkins/job/marian-dev-cpu/)
|
||||
[![Build Status CPU](https://img.shields.io/jenkins/s/http/vali.inf.ed.ac.uk/jenkins/view/marian/job/marian-dev-cpu.svg?label=CPU)](http://vali.inf.ed.ac.uk/jenkins/job/marian-dev-cpu/)
|
||||
[![Tests Status](https://img.shields.io/jenkins/s/http/vali.inf.ed.ac.uk/jenkins/view/marian/job/marian-regression-tests.svg?label=tests)](http://vali.inf.ed.ac.uk/jenkins/job/marian-regression-tests/)
|
||||
[![Latest release](https://img.shields.io/github/release/marian-nmt/marian.svg?label=release)](https://github.com/marian-nmt/marian/releases)
|
||||
[![License: MIT](https://img.shields.io/badge/License-MIT-blue.svg)](./LICENSE.md)
|
||||
[![Twitter](https://img.shields.io/twitter/follow/marian_nmt.svg?style=social)](https://twitter.com/intent/follow?screen_name=marian_nmt)
|
||||
|
||||
<p>
|
||||
<b>Marian</b> is an efficient Neural Machine Translation framework written
|
||||
in pure C++ with minimal dependencies.
|
||||
*Marian* is an efficient Neural Machine Translation framework written in pure
|
||||
C++ with minimal dependencies.
|
||||
|
||||
Named in honour of Marian Rejewski, a Polish mathematician and cryptologist.
|
||||
|
||||
<!--It has mainly been developed at the
|
||||
Adam Mickiewicz University in Poznań (AMU) and at the University of Edinburgh.-->
|
||||
</p>
|
||||
|
||||
<!--p>
|
||||
It is currently being deployed in
|
||||
multiple European projects and is the main translation and training engine
|
||||
behind the neural MT launch at the
|
||||
<a href="http://www.wipo.int/pressroom/en/articles/2016/article_0014.html">World Intellectual Property Organization</a>.
|
||||
</p-->
|
||||
|
||||
<p>
|
||||
Main features:
|
||||
<ul>
|
||||
<li> Fast multi-gpu training and translation </li>
|
||||
<li> Compatible with Nematus and DL4MT </li>
|
||||
<li> Efficient pure C++ implementation </li>
|
||||
<li> Permissive open source license (MIT) </li>
|
||||
<li> <a href="https://marian-nmt.github.io/features/"> more details... </a> </li>
|
||||
</ul>
|
||||
</p>
|
||||
|
||||
- Efficient pure C++ implementation
|
||||
- Fast multi-GPU training and GPU/CPU translation
|
||||
- State-of-the-art NMT architectures: deep RNN and transformer
|
||||
- Permissive open source license (MIT)
|
||||
- [more detail...](https://marian-nmt.github.io/features)
|
||||
|
||||
If you use this, please cite:
|
||||
|
||||
|
@ -59,20 +45,11 @@ Machine Translation in C++ (http://www.aclweb.org/anthology/P18-4020)
|
|||
url = {http://www.aclweb.org/anthology/P18-4020}
|
||||
}
|
||||
|
||||
<!--
|
||||
## Compilation
|
||||
|
||||
```
|
||||
cd marian-dev
|
||||
mkdir -p build
|
||||
cd build
|
||||
cmake .. -DCMAKE_BUILD_TYPE=Release
|
||||
make -j
|
||||
```
|
||||
-->
|
||||
|
||||
## Amun
|
||||
The handwritten decoder for RNN models compatible with Marian and Nematus has been superseded by the Marian decoder. The code is available in a separate repository: https://github.com/marian-nmt/amun
|
||||
|
||||
The handwritten decoder for RNN models compatible with Marian and Nematus has
|
||||
been superseded by the Marian decoder. The code is available in a separate
|
||||
repository: https://github.com/marian-nmt/amun
|
||||
|
||||
## Website
|
||||
|
||||
|
|
2
VERSION
2
VERSION
|
@ -1 +1 @@
|
|||
v1.7.6
|
||||
v1.8.52
|
||||
|
|
|
@ -0,0 +1,186 @@
|
|||
# - Find CBLAS library
|
||||
#
|
||||
# This module finds an installed fortran library that implements the CBLAS
|
||||
# linear-algebra interface (see http://www.netlib.org/blas/), with CBLAS
|
||||
# interface.
|
||||
#
|
||||
# This module sets the following variables:
|
||||
# CBLAS_FOUND - set to true if a library implementing the CBLAS interface
|
||||
# is found
|
||||
# CBLAS_LINKER_FLAGS - uncached list of required linker flags (excluding -l
|
||||
# and -L).
|
||||
# CBLAS_LIBRARIES - uncached list of libraries (using full path name) to
|
||||
# link against to use CBLAS
|
||||
# CBLAS_INCLUDE_DIR - path to includes
|
||||
# CBLAS_INCLUDE_FILE - the file to be included to use CBLAS
|
||||
#
|
||||
|
||||
## Based on https://github.com/Eyescale/CMake/blob/master/FindCBLAS.cmake
|
||||
|
||||
INCLUDE(CheckFunctionExists)
|
||||
INCLUDE(CheckIncludeFile)
|
||||
|
||||
MACRO(CHECK_ALL_LIBRARIES LIBRARIES INCLUDE _prefix _name _flags _list _include _search_include)
|
||||
# This macro checks for the existence of the combination of fortran libraries
|
||||
# given by _list. If the combination is found, this macro checks (using the
|
||||
# Check_Fortran_Function_Exists macro) whether can link against that library
|
||||
# combination using the name of a routine given by _name using the linker
|
||||
# flags given by _flags. If the combination of libraries is found and passes
|
||||
# the link test, LIBRARIES is set to the list of complete library paths that
|
||||
# have been found. Otherwise, LIBRARIES is set to FALSE.
|
||||
|
||||
# N.B. _prefix is the prefix applied to the names of all cached variables that
|
||||
# are generated internally and marked advanced by this macro.
|
||||
|
||||
SET(__list)
|
||||
FOREACH(_elem ${_list})
|
||||
IF(__list)
|
||||
SET(__list "${__list} - ${_elem}")
|
||||
ELSE(__list)
|
||||
SET(__list "${_elem}")
|
||||
ENDIF(__list)
|
||||
ENDFOREACH(_elem)
|
||||
MESSAGE(STATUS "Checking for [${__list}]")
|
||||
SET(_libraries_work TRUE)
|
||||
SET(${LIBRARIES})
|
||||
SET(_combined_name)
|
||||
SET(_paths)
|
||||
FOREACH(_library ${_list})
|
||||
SET(_combined_name ${_combined_name}_${_library})
|
||||
|
||||
# did we find all the libraries in the _list until now?
|
||||
# (we stop at the first unfound one)
|
||||
IF(_libraries_work)
|
||||
IF(APPLE)
|
||||
FIND_LIBRARY(${_prefix}_${_library}_LIBRARY
|
||||
NAMES ${_library}
|
||||
PATHS /usr/local/lib /usr/lib /usr/local/lib64 /usr/lib64 /usr/local/opt/openblas/lib ENV
|
||||
DYLD_LIBRARY_PATH
|
||||
)
|
||||
ELSE(APPLE)
|
||||
FIND_LIBRARY(${_prefix}_${_library}_LIBRARY
|
||||
NAMES ${_library}
|
||||
PATHS /usr/local/lib /usr/lib /usr/local/lib64 /usr/lib64 ENV
|
||||
LD_LIBRARY_PATH
|
||||
)
|
||||
ENDIF(APPLE)
|
||||
MARK_AS_ADVANCED(${_prefix}_${_library}_LIBRARY)
|
||||
IF(${_prefix}_${_library}_LIBRARY)
|
||||
GET_FILENAME_COMPONENT(_path ${${_prefix}_${_library}_LIBRARY} PATH)
|
||||
LIST(APPEND _paths ${_path}/../include ${_path}/../../include)
|
||||
ENDIF(${_prefix}_${_library}_LIBRARY)
|
||||
SET(${LIBRARIES} ${${LIBRARIES}} ${${_prefix}_${_library}_LIBRARY})
|
||||
SET(_libraries_work ${${_prefix}_${_library}_LIBRARY})
|
||||
ENDIF(_libraries_work)
|
||||
ENDFOREACH(_library ${_list})
|
||||
|
||||
# Test include
|
||||
SET(_bug_search_include ${_search_include}) #CMAKE BUG!!! SHOULD NOT BE THAT
|
||||
IF(_bug_search_include)
|
||||
FIND_PATH(${_prefix}${_combined_name}_INCLUDE ${_include} ${_paths})
|
||||
MARK_AS_ADVANCED(${_prefix}${_combined_name}_INCLUDE)
|
||||
IF(${_prefix}${_combined_name}_INCLUDE)
|
||||
MESSAGE(STATUS "Checking for [${__list}] -- includes found")
|
||||
SET(${_prefix}_INCLUDE_DIR ${${_prefix}${_combined_name}_INCLUDE})
|
||||
SET(${_prefix}_INCLUDE_FILE ${_include})
|
||||
SET(${INCLUDE} ${${_prefix}_INCLUDE_DIR})
|
||||
ELSE(${_prefix}${_combined_name}_INCLUDE)
|
||||
MESSAGE(STATUS "Checking for [${__list}] -- includes not found")
|
||||
SET(_libraries_work FALSE)
|
||||
ENDIF(${_prefix}${_combined_name}_INCLUDE)
|
||||
ELSE(_bug_search_include)
|
||||
SET(${_prefix}_INCLUDE_DIR)
|
||||
SET(${_prefix}_INCLUDE_FILE ${_include})
|
||||
ENDIF(_bug_search_include)
|
||||
|
||||
IF(_libraries_work)
|
||||
# Test this combination of libraries.
|
||||
SET(CMAKE_REQUIRED_LIBRARIES ${_flags} ${${LIBRARIES}})
|
||||
CHECK_FUNCTION_EXISTS(${_name} ${_prefix}${_combined_name}_WORKS)
|
||||
SET(CMAKE_REQUIRED_LIBRARIES)
|
||||
MARK_AS_ADVANCED(${_prefix}${_combined_name}_WORKS)
|
||||
SET(_libraries_work ${${_prefix}${_combined_name}_WORKS})
|
||||
|
||||
IF(_libraries_work)
|
||||
MESSAGE(STATUS "Checking for [${__list}] -- libraries found")
|
||||
ENDIF(_libraries_work)
|
||||
|
||||
ENDIF(_libraries_work)
|
||||
|
||||
|
||||
IF(NOT _libraries_work)
|
||||
SET(${LIBRARIES} FALSE)
|
||||
ENDIF(NOT _libraries_work)
|
||||
|
||||
ENDMACRO(CHECK_ALL_LIBRARIES)
|
||||
|
||||
SET(CBLAS_LINKER_FLAGS)
|
||||
SET(CBLAS_LIBRARIES)
|
||||
SET(CBLAS_INCLUDE_DIR)
|
||||
|
||||
# CBLAS in openBLAS
|
||||
IF(NOT CBLAS_LIBRARIES)
|
||||
CHECK_ALL_LIBRARIES(
|
||||
CBLAS_LIBRARIES
|
||||
CBLAS_INCLUDE_DIR
|
||||
cblas
|
||||
cblas_sgemm
|
||||
""
|
||||
"openblas"
|
||||
"cblas.h"
|
||||
TRUE
|
||||
)
|
||||
ENDIF(NOT CBLAS_LIBRARIES)
|
||||
|
||||
#MESSAGE(STATUS ${openblas_INCLUDE_DIR})
|
||||
|
||||
# CBLAS in CBLAS
|
||||
IF(NOT CBLAS_LIBRARIES)
|
||||
CHECK_ALL_LIBRARIES(
|
||||
CBLAS_LIBRARIES
|
||||
CBLAS_INCLUDE_DIR
|
||||
cblas
|
||||
cblas_sgemm
|
||||
""
|
||||
"cblas"
|
||||
"cblas.h"
|
||||
TRUE
|
||||
)
|
||||
ENDIF(NOT CBLAS_LIBRARIES)
|
||||
|
||||
#MESSAGE(STATUS ${cblas_INCLUDE_DIR})
|
||||
|
||||
# CBLAS in lapacke
|
||||
IF(NOT CBLAS_LIBRARIES)
|
||||
CHECK_ALL_LIBRARIES(
|
||||
CBLAS_LIBRARIES
|
||||
CBLAS_INCLUDE_DIR
|
||||
cblas
|
||||
cblas_sgemm
|
||||
""
|
||||
"lapacke"
|
||||
"cblas.h"
|
||||
TRUE
|
||||
)
|
||||
ENDIF(NOT CBLAS_LIBRARIES)
|
||||
|
||||
#MESSAGE(STATUS ${lapacke_INCLUDE_DIR})
|
||||
|
||||
IF(CBLAS_LIBRARIES)
|
||||
SET(CBLAS_FOUND TRUE)
|
||||
ELSE(CBLAS_LIBRARIES)
|
||||
SET(CBLAS_FOUND FALSE)
|
||||
ENDIF(CBLAS_LIBRARIES)
|
||||
|
||||
IF(NOT CBLAS_FOUND AND CBLAS_FIND_REQUIRED)
|
||||
MESSAGE(FATAL_ERROR "CBLAS library not found. Please specify library location")
|
||||
ENDIF(NOT CBLAS_FOUND AND CBLAS_FIND_REQUIRED)
|
||||
|
||||
IF(NOT CBLAS_FIND_QUIETLY)
|
||||
IF(CBLAS_FOUND)
|
||||
MESSAGE(STATUS "CBLAS library found: " ${CBLAS_LIBRARIES})
|
||||
MESSAGE(STATUS "cblas.h include directory: " ${CBLAS_INCLUDE_DIR})
|
||||
ELSE(CBLAS_FOUND)
|
||||
MESSAGE(STATUS "CBLAS library not found. Please specify library location")
|
||||
ENDIF(CBLAS_FOUND)
|
||||
ENDIF(NOT CBLAS_FIND_QUIETLY)
|
|
@ -89,7 +89,10 @@ find_library(MKL_CORE_LIBRARY
|
|||
NO_DEFAULT_PATH)
|
||||
|
||||
set(MKL_INCLUDE_DIRS ${MKL_INCLUDE_DIR})
|
||||
set(MKL_LIBRARIES ${MKL_INTERFACE_LIBRARY} ${MKL_SEQUENTIAL_LAYER_LIBRARY} ${MKL_CORE_LIBRARY})
|
||||
# Added -Wl block to avoid circular dependencies.
|
||||
# https://stackoverflow.com/questions/5651869/what-are-the-start-group-and-end-group-command-line-options
|
||||
# https://software.intel.com/en-us/articles/intel-mkl-link-line-advisor
|
||||
set(MKL_LIBRARIES -Wl,--start-group ${MKL_INTERFACE_LIBRARY} ${MKL_SEQUENTIAL_LAYER_LIBRARY} ${MKL_CORE_LIBRARY} -Wl,--end-group)
|
||||
|
||||
# message("1 ${MKL_INCLUDE_DIR}")
|
||||
# message("2 ${MKL_INTERFACE_LIBRARY}")
|
||||
|
|
|
@ -41,6 +41,14 @@ IF(CMAKE_SYSTEM_NAME MATCHES "Linux")
|
|||
set(SSE4_1_FOUND false CACHE BOOL "SSE4.1 available on host")
|
||||
ENDIF (SSE41_TRUE)
|
||||
|
||||
STRING(REGEX REPLACE "^.*(sse4_2).*$" "\\1" SSE_THERE ${CPUINFO})
|
||||
STRING(COMPARE EQUAL "sse4_2" "${SSE_THERE}" SSE42_TRUE)
|
||||
IF (SSE42_TRUE)
|
||||
set(SSE4_2_FOUND true CACHE BOOL "SSE4.2 available on host")
|
||||
ELSE (SSE42_TRUE)
|
||||
set(SSE4_2_FOUND false CACHE BOOL "SSE4.2 available on host")
|
||||
ENDIF (SSE42_TRUE)
|
||||
|
||||
STRING(REGEX REPLACE "^.*(avx).*$" "\\1" SSE_THERE ${CPUINFO})
|
||||
STRING(COMPARE EQUAL "avx" "${SSE_THERE}" AVX_TRUE)
|
||||
IF (AVX_TRUE)
|
||||
|
@ -57,6 +65,14 @@ IF(CMAKE_SYSTEM_NAME MATCHES "Linux")
|
|||
set(AVX2_FOUND false CACHE BOOL "AVX2 available on host")
|
||||
ENDIF (AVX2_TRUE)
|
||||
|
||||
STRING(REGEX REPLACE "^.*(avx512).*$" "\\1" SSE_THERE ${CPUINFO})
|
||||
STRING(COMPARE EQUAL "avx512" "${SSE_THERE}" AVX512_TRUE)
|
||||
IF (AVX512_TRUE)
|
||||
set(AVX512_FOUND true CACHE BOOL "AVX512 available on host")
|
||||
ELSE (AVX512_TRUE)
|
||||
set(AVX512_FOUND false CACHE BOOL "AVX512 available on host")
|
||||
ENDIF (AVX512_TRUE)
|
||||
|
||||
ELSEIF(CMAKE_SYSTEM_NAME MATCHES "Darwin")
|
||||
EXEC_PROGRAM("/usr/sbin/sysctl -n machdep.cpu.features" OUTPUT_VARIABLE
|
||||
CPUINFO)
|
||||
|
@ -109,6 +125,14 @@ ELSEIF(CMAKE_SYSTEM_NAME MATCHES "Darwin")
|
|||
set(AVX2_FOUND false CACHE BOOL "AVX2 available on host")
|
||||
ENDIF (AVX2_TRUE)
|
||||
|
||||
STRING(REGEX REPLACE "^.*(avx512).*$" "\\1" SSE_THERE ${CPUINFO})
|
||||
STRING(COMPARE EQUAL "avx512" "${SSE_THERE}" AVX512_TRUE)
|
||||
IF (AVX512_TRUE)
|
||||
set(AVX512_FOUND true CACHE BOOL "AVX512 available on host")
|
||||
ELSE (AVX512_TRUE)
|
||||
set(AVX512_FOUND false CACHE BOOL "AVX512 available on host")
|
||||
ENDIF (AVX512_TRUE)
|
||||
|
||||
ELSEIF(CMAKE_SYSTEM_NAME MATCHES "Windows")
|
||||
# TODO
|
||||
set(SSE2_FOUND true CACHE BOOL "SSE2 available on host")
|
||||
|
@ -117,6 +141,7 @@ ELSEIF(CMAKE_SYSTEM_NAME MATCHES "Windows")
|
|||
set(SSE4_1_FOUND false CACHE BOOL "SSE4.1 available on host")
|
||||
set(AVX_FOUND false CACHE BOOL "AVX available on host")
|
||||
set(AVX2_FOUND false CACHE BOOL "AVX2 available on host")
|
||||
set(AVX512_FOUND false CACHE BOOL "AVX512 available on host")
|
||||
ELSE(CMAKE_SYSTEM_NAME MATCHES "Linux")
|
||||
set(SSE2_FOUND true CACHE BOOL "SSE2 available on host")
|
||||
set(SSE3_FOUND false CACHE BOOL "SSE3 available on host")
|
||||
|
@ -124,6 +149,7 @@ ELSE(CMAKE_SYSTEM_NAME MATCHES "Linux")
|
|||
set(SSE4_1_FOUND false CACHE BOOL "SSE4.1 available on host")
|
||||
set(AVX_FOUND false CACHE BOOL "AVX available on host")
|
||||
set(AVX2_FOUND false CACHE BOOL "AVX2 available on host")
|
||||
set(AVX512_FOUND false CACHE BOOL "AVX512 available on host")
|
||||
ENDIF(CMAKE_SYSTEM_NAME MATCHES "Linux")
|
||||
|
||||
if(NOT SSE2_FOUND)
|
||||
|
@ -144,5 +170,8 @@ endif(NOT AVX_FOUND)
|
|||
if(NOT AVX2_FOUND)
|
||||
MESSAGE(STATUS "Could not find hardware support for AVX2 on this machine.")
|
||||
endif(NOT AVX2_FOUND)
|
||||
if(NOT AVX512_FOUND)
|
||||
MESSAGE(STATUS "Could not find hardware support for AVX512 on this machine.")
|
||||
endif(NOT AVX512_FOUND)
|
||||
|
||||
mark_as_advanced(SSE2_FOUND SSE3_FOUND SSSE3_FOUND SSE4_1_FOUND, AVX_FOUND, AVX2_FOUND)
|
||||
mark_as_advanced(SSE2_FOUND SSE3_FOUND SSSE3_FOUND SSE4_1_FOUND, AVX_FOUND, AVX2_FOUND, AVX512_FOUND)
|
||||
|
|
|
@ -0,0 +1,52 @@
|
|||
##
|
||||
# This module extracts CMake cached variables into a variable.
|
||||
#
|
||||
# Author: snukky
|
||||
#
|
||||
# This module sets the following variables:
|
||||
# * PROJECT_CMAKE_CACHE - to the output of "cmake -L" - an uncached list of
|
||||
# non-advanced cached variables
|
||||
# * PROJECT_CMAKE_CACHE_ADVANCED - to the output of "cmake -LA" - an uncached
|
||||
# list of advanced cached variables
|
||||
#
|
||||
|
||||
set(PROJECT_CMAKE_CACHE "")
|
||||
set(PROJECT_CMAKE_CACHE_ADVANCED "")
|
||||
|
||||
# Get all CMake variables
|
||||
get_cmake_property(_variableNames VARIABLES)
|
||||
list(SORT _variableNames)
|
||||
list(REMOVE_DUPLICATES _variableNames)
|
||||
|
||||
foreach(_variableName ${_variableNames})
|
||||
# If it is a cache variable
|
||||
get_property(_cachePropIsSet CACHE "${_variableName}" PROPERTY VALUE SET)
|
||||
if(_cachePropIsSet)
|
||||
# Get the variable's type
|
||||
get_property(_variableType CACHE ${_variableName} PROPERTY TYPE)
|
||||
|
||||
# Get the variable's value
|
||||
set(_variableValue "${${_variableName}}")
|
||||
|
||||
# Skip static or internal cached variables, cmake -L[A] does not print them, see
|
||||
# https://github.com/Kitware/CMake/blob/master/Source/cmakemain.cxx#L282
|
||||
if( (NOT "${_variableType}" STREQUAL "STATIC") AND
|
||||
(NOT "${_variableType}" STREQUAL "INTERNAL") AND
|
||||
(NOT "${_variableValue}" STREQUAL "") )
|
||||
|
||||
|
||||
set(PROJECT_CMAKE_CACHE_ADVANCED "${PROJECT_CMAKE_CACHE_ADVANCED} \"${_variableName}=${_variableValue}\\n\"\n")
|
||||
|
||||
# Get the variable's advanced flag
|
||||
get_property(_isAdvanced CACHE ${_variableName} PROPERTY ADVANCED SET)
|
||||
if(NOT _isAdvanced)
|
||||
set(PROJECT_CMAKE_CACHE "${PROJECT_CMAKE_CACHE} \"${_variableName}=${_variableValue}\\n\"\n")
|
||||
endif()
|
||||
|
||||
# Print variables for debugging
|
||||
#message(STATUS "${_variableName}=${${_variableName}}")
|
||||
#message(STATUS " Type=${_variableType}")
|
||||
#message(STATUS " Advanced=${_isAdvanced}")
|
||||
endif()
|
||||
endif(_cachePropIsSet)
|
||||
endforeach()
|
2
examples
2
examples
|
@ -1 +1 @@
|
|||
Subproject commit 336740065d9c23e53e912a1befff18981d9d27ab
|
||||
Subproject commit c19b7814d71febf1053bd93af6ac314b46204092
|
|
@ -0,0 +1 @@
|
|||
Subproject commit 6a08849b23f6c14eefbe12f4eb73dc638b962587
|
|
@ -0,0 +1,154 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
This script takes a Tensorflow BERT checkpoint and a model description in a JSON file and converts
|
||||
it to a Marian weight file with numpy weights and an internal YAML description.
|
||||
|
||||
This works with checkpoints from https://github.com/google-research/bert
|
||||
|
||||
Assmung a BERT checkpoint like this:
|
||||
drwxr-xr-x 2 marcinjd marcinjd 4.0K Nov 23 16:39 .
|
||||
-rw-r--r-- 1 marcinjd marcinjd 521 Nov 23 16:38 bert_config.json
|
||||
-rw-r--r-- 1 marcinjd marcinjd 682M Nov 23 16:39 bert_model.ckpt.data-00000-of-00001
|
||||
-rw-r--r-- 1 marcinjd marcinjd 8.5K Nov 23 16:39 bert_model.ckpt.index
|
||||
-rw-r--r-- 1 marcinjd marcinjd 888K Nov 23 16:39 bert_model.ckpt.meta
|
||||
-rw-r--r-- 1 marcinjd marcinjd 973K Nov 23 16:37 vocab.txt
|
||||
|
||||
usage:
|
||||
|
||||
./bert.py --bert_prefix bert_model.ckpt --bert_config bert_config.json --marian bert.npz
|
||||
"""
|
||||
|
||||
import tensorflow as tf
|
||||
import numpy as np
|
||||
import sys
|
||||
import yaml
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser(description='Convert Tensorflow BERT model to Marian weight file.')
|
||||
parser.add_argument('--bert_prefix', help='Prefix for Tensorflow BERT checkpoint', required=True)
|
||||
parser.add_argument('--bert_config', help='Path to Tensorflow BERT JSON config', required=True)
|
||||
parser.add_argument('--marian', help='Output path for Marian weight file', required=True)
|
||||
args = parser.parse_args()
|
||||
|
||||
print("Loading TensorFlow config from %s" % (args.bert_config,))
|
||||
bertConfig = yaml.load(open(args.bert_config))
|
||||
bertConfigYamlStr = yaml.dump(bertConfig, default_flow_style=False)
|
||||
print(bertConfigYamlStr)
|
||||
|
||||
print("Loading TensorFlow model from %s" % (args.bert_prefix,))
|
||||
|
||||
# Collect tensors from TF model as numpy matrices
|
||||
tfModel = dict()
|
||||
with tf.Session() as sess:
|
||||
preloader = tf.train.import_meta_graph(args.bert_prefix + ".meta")
|
||||
preloader.restore(sess, args.bert_prefix)
|
||||
vars = tf.global_variables()
|
||||
for v in vars:
|
||||
if len(v.shape) > 0:
|
||||
if "adam" not in v.name: # ignore adam parameters
|
||||
print(v.name, v.shape)
|
||||
tfModel[v.name] = sess.run(v.name) # get numpy matrix
|
||||
|
||||
# Prepare Marian model config
|
||||
config = dict()
|
||||
config["type"] = "bert"
|
||||
config["input-types"] = ["sequence", "class"]
|
||||
config["tied-embeddings-all"] = True
|
||||
config["dim-emb"] = tfModel["bert/embeddings/word_embeddings:0"].shape[-1]
|
||||
config["dim-vocabs"] = [ tfModel["bert/embeddings/word_embeddings:0"].shape[0],
|
||||
tfModel["cls/seq_relationship/output_weights:0"].shape[0] ]
|
||||
|
||||
config["transformer-dim-ffn"] = tfModel["bert/encoder/layer_0/intermediate/dense/kernel:0"].shape[-1]
|
||||
config["transformer-ffn-activation"] = bertConfig["hidden_act"]
|
||||
config["transformer-ffn-depth"] = 2
|
||||
config["transformer-heads"] = bertConfig["num_attention_heads"]
|
||||
config["transformer-train-position-embeddings"] = True
|
||||
config["transformer-preprocess"] = ""
|
||||
config["transformer-postprocess"] = "dan"
|
||||
config["transformer-postprocess-emb"] = "nd"
|
||||
config["bert-train-type-embeddings"] = True
|
||||
config["bert-type-vocab-size"] = tfModel["bert/embeddings/token_type_embeddings:0"].shape[0]
|
||||
config["version"] = "bert4marian.py conversion"
|
||||
|
||||
# check number of layers
|
||||
found = True
|
||||
config["enc-depth"] = 0;
|
||||
while found:
|
||||
found = False
|
||||
for key in tfModel:
|
||||
if "bert/encoder/layer_" + str(config["enc-depth"]) in key:
|
||||
config["enc-depth"] += 1
|
||||
found = True
|
||||
break
|
||||
|
||||
if config["enc-depth"] != bertConfig["num_hidden_layers"]:
|
||||
sys.exit("Number of layers in JSON config (%s) and number of layers found in checkpoint (%s) do not match!" % (config["enc-depth"], bertConfig["num_hidden_layers"]))
|
||||
|
||||
configYamlStr = yaml.dump(config, default_flow_style=False)
|
||||
desc = list(configYamlStr)
|
||||
npDesc = np.chararray((len(desc),))
|
||||
npDesc[:] = desc
|
||||
npDesc.dtype = np.int8
|
||||
|
||||
marianModel = dict()
|
||||
marianModel["special:model.yml"] = npDesc
|
||||
|
||||
# Map model weights here #
|
||||
# Embedding layers
|
||||
marianModel["Wemb"] = tfModel["bert/embeddings/word_embeddings:0"]
|
||||
marianModel["Wpos"] = tfModel["bert/embeddings/position_embeddings:0"]
|
||||
marianModel["Wtype"] = tfModel["bert/embeddings/token_type_embeddings:0"]
|
||||
marianModel["encoder_emb_ln_scale_pre"] = tfModel["bert/embeddings/LayerNorm/gamma:0"]
|
||||
marianModel["encoder_emb_ln_bias_pre"] = tfModel["bert/embeddings/LayerNorm/beta:0"]
|
||||
|
||||
for layer in range(config["enc-depth"]):
|
||||
marianPrefix = "encoder_l%s" % (layer + 1,)
|
||||
tfPrefix = "bert/encoder/layer_%s" % (layer,)
|
||||
|
||||
# Attention
|
||||
marianModel[marianPrefix + "_self_Wq"] = tfModel[tfPrefix + "/attention/self/query/kernel:0"]
|
||||
marianModel[marianPrefix + "_self_bq"] = tfModel[tfPrefix + "/attention/self/query/bias:0"]
|
||||
|
||||
marianModel[marianPrefix + "_self_Wk"] = tfModel[tfPrefix + "/attention/self/key/kernel:0"]
|
||||
marianModel[marianPrefix + "_self_bk"] = tfModel[tfPrefix + "/attention/self/key/bias:0"]
|
||||
|
||||
marianModel[marianPrefix + "_self_Wv"] = tfModel[tfPrefix + "/attention/self/value/kernel:0"]
|
||||
marianModel[marianPrefix + "_self_bv"] = tfModel[tfPrefix + "/attention/self/value/bias:0"]
|
||||
|
||||
marianModel[marianPrefix + "_self_Wo"] = tfModel[tfPrefix + "/attention/output/dense/kernel:0"]
|
||||
marianModel[marianPrefix + "_self_bo"] = tfModel[tfPrefix + "/attention/output/dense/bias:0"]
|
||||
|
||||
marianModel[marianPrefix + "_self_Wo_ln_scale"] = tfModel[tfPrefix + "/attention/output/LayerNorm/gamma:0"]
|
||||
marianModel[marianPrefix + "_self_Wo_ln_bias"] = tfModel[tfPrefix + "/attention/output/LayerNorm/beta:0"]
|
||||
|
||||
# FFN
|
||||
marianModel[marianPrefix + "_ffn_W1"] = tfModel[tfPrefix + "/intermediate/dense/kernel:0"]
|
||||
marianModel[marianPrefix + "_ffn_b1"] = tfModel[tfPrefix + "/intermediate/dense/bias:0"]
|
||||
|
||||
marianModel[marianPrefix + "_ffn_W2"] = tfModel[tfPrefix + "/output/dense/kernel:0"]
|
||||
marianModel[marianPrefix + "_ffn_b2"] = tfModel[tfPrefix + "/output/dense/bias:0"]
|
||||
|
||||
marianModel[marianPrefix + "_ffn_ffn_ln_scale"] = tfModel[tfPrefix + "/output/LayerNorm/gamma:0"]
|
||||
marianModel[marianPrefix + "_ffn_ffn_ln_bias"] = tfModel[tfPrefix + "/output/LayerNorm/beta:0"]
|
||||
|
||||
# Training objectives
|
||||
# Masked-LM output layer
|
||||
marianModel["masked-lm_ff_logit_l1_W"] = tfModel["cls/predictions/transform/dense/kernel:0"]
|
||||
marianModel["masked-lm_ff_logit_l1_b"] = tfModel["cls/predictions/transform/dense/bias:0"]
|
||||
|
||||
marianModel["masked-lm_ff_ln_scale"] = tfModel["cls/predictions/transform/LayerNorm/gamma:0"]
|
||||
marianModel["masked-lm_ff_ln_bias"] = tfModel["cls/predictions/transform/LayerNorm/beta:0"]
|
||||
|
||||
marianModel["masked-lm_ff_logit_l2_b"] = tfModel["cls/predictions/output_bias:0"]
|
||||
|
||||
# Next Sentence classifier
|
||||
marianModel["next-sentence_ff_logit_l1_W"] = tfModel["bert/pooler/dense/kernel:0"]
|
||||
marianModel["next-sentence_ff_logit_l1_b"] = tfModel["bert/pooler/dense/bias:0"]
|
||||
|
||||
marianModel["next-sentence_ff_logit_l2_W"] = np.transpose(tfModel["cls/seq_relationship/output_weights:0"]) # transpose?!
|
||||
marianModel["next-sentence_ff_logit_l2_b"] = tfModel["cls/seq_relationship/output_bias:0"]
|
||||
|
||||
print("\nMarian config:")
|
||||
print(configYamlStr)
|
||||
print("Saving Marian model to %s" % (args.marian,))
|
||||
np.savez(args.marian, **marianModel)
|
|
@ -0,0 +1,55 @@
|
|||
#!/usr/bin/env python
|
||||
"""
|
||||
This script takes multiple Marian *.npz model files and outputs an elementwise average of the model,
|
||||
meant to do check-point averaging from:
|
||||
|
||||
https://www.aclweb.org/anthology/W16-2316
|
||||
|
||||
usage:
|
||||
|
||||
./average.py -m model.1.npz model.2.npz --output model.avg.npz
|
||||
"""
|
||||
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import sys
|
||||
import argparse
|
||||
|
||||
import numpy as np
|
||||
|
||||
# Parse arguments
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-m', '--model', nargs='+', required=True,
|
||||
help="models to average")
|
||||
parser.add_argument('-o', '--output', required=True,
|
||||
help="output path")
|
||||
args = parser.parse_args()
|
||||
|
||||
# *average* holds the model matrix
|
||||
average = dict()
|
||||
# No. of models.
|
||||
n = len(args.model)
|
||||
|
||||
for filename in args.model:
|
||||
print("Loading {}".format(filename))
|
||||
with open(filename, "rb") as mfile:
|
||||
# Loads matrix from model file
|
||||
m = np.load(mfile)
|
||||
for k in m:
|
||||
if k != "history_errs":
|
||||
# Initialize the key
|
||||
if k not in average:
|
||||
average[k] = m[k]
|
||||
# Add to the appropriate value
|
||||
elif average[k].shape == m[k].shape and "special" not in k:
|
||||
average[k] += m[k]
|
||||
|
||||
# Actual averaging
|
||||
for k in average:
|
||||
if "special" not in k:
|
||||
average[k] /= n
|
||||
|
||||
# Save averaged model to file
|
||||
print("Saving to {}".format(args.output))
|
||||
np.savez(args.output, **average)
|
|
@ -1,4 +1,4 @@
|
|||
#!/usr/bin/env python
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
from __future__ import print_function
|
||||
|
@ -9,18 +9,22 @@ import numpy as np
|
|||
|
||||
|
||||
def main():
|
||||
desc = """Export word embedding from model"""
|
||||
desc = """Export word embeddings from model"""
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter, description=desc)
|
||||
parser.add_argument("-m", "--model", help="Model file", required=True)
|
||||
parser.add_argument(
|
||||
"-o", "--output-prefix", help="Output files prefix", required=True)
|
||||
parser.add_argument("-m", "--model", help="path to model.npz file", required=True)
|
||||
parser.add_argument("-o", "--output-prefix", help="prefix for output files", required=True)
|
||||
args = parser.parse_args()
|
||||
|
||||
print("Loading model")
|
||||
model = np.load(args.model)
|
||||
special = yaml.load(model["special:model.yml"][:-1].tobytes())
|
||||
|
||||
if special["tied-embeddings-all"] or special["tied-embeddings-src"]:
|
||||
all_emb = model["Wemb"]
|
||||
export_emb(args.output_prefix + ".all", all_emb)
|
||||
exit()
|
||||
|
||||
if special["type"] == "amun":
|
||||
enc_emb = model["Wemb"]
|
||||
dec_emb = model["Wemb_dec"]
|
||||
|
@ -28,16 +32,15 @@ def main():
|
|||
enc_emb = model["encoder_Wemb"]
|
||||
dec_emb = model["decoder_Wemb"]
|
||||
|
||||
with open(args.output_prefix + ".src", "w") as out:
|
||||
out.write("{0} {1}\n".format(*enc_emb.shape))
|
||||
for i in range(enc_emb.shape[0]):
|
||||
vec = " ".join("{0:.8f}".format(v) for v in enc_emb[i])
|
||||
out.write("{0} {1}\n".format(i, vec))
|
||||
export_emb(args.output_prefix + ".src", enc_emb)
|
||||
export_emb(args.output_prefix + ".trg", dec_emb)
|
||||
|
||||
with open(args.output_prefix + ".trg", "w") as out:
|
||||
out.write("{0} {1}\n".format(*dec_emb.shape))
|
||||
for i in range(dec_emb.shape[0]):
|
||||
vec = " ".join("{0:.8f}".format(v) for v in dec_emb[i])
|
||||
|
||||
def export_emb(filename, emb):
|
||||
with open(filename, "w") as out:
|
||||
out.write("{0} {1}\n".format(*emb.shape))
|
||||
for i in range(emb.shape[0]):
|
||||
vec = " ".join("{0:.8f}".format(v) for v in emb[i])
|
||||
out.write("{0} {1}\n".format(i, vec))
|
||||
|
||||
|
||||
|
|
|
@ -0,0 +1,3 @@
|
|||
bin
|
||||
fast_align
|
||||
extract-lex
|
|
@ -0,0 +1,8 @@
|
|||
`install.sh` is a helper script that downloads and compiles fastalign and extract-lex, and copies
|
||||
required binaries into _./bin_.
|
||||
|
||||
Shortlist files (_lex.s2t_ and _lex.t2s_) can be created using `generate_shortlists.pl`, for
|
||||
example:
|
||||
|
||||
perl generate_shortlists.pl --bindir ./bin -s corpus.bpe.src -t corpus.bpe.tgt
|
||||
|
|
@ -0,0 +1,97 @@
|
|||
#!/usr/bin/env perl
|
||||
|
||||
use strict;
|
||||
use Getopt::Long;
|
||||
use FindBin qw($Bin);
|
||||
use File::Temp qw(tempdir tempfile);
|
||||
use POSIX;
|
||||
|
||||
my $PID = $$;
|
||||
$SIG{TERM} = $SIG{INT} = $SIG{QUIT} = sub { die; };
|
||||
|
||||
my $BINDIR = "$Bin/bin";
|
||||
my $SRC;
|
||||
my $TRG;
|
||||
my $OUTPUT = "lex";
|
||||
my $THREADS = 8;
|
||||
my $PARALLEL = 0;
|
||||
my $HELP;
|
||||
|
||||
GetOptions(
|
||||
"b|bindir=s" => \$BINDIR,
|
||||
"s|source=s" => \$SRC,
|
||||
"t|target=s" => \$TRG,
|
||||
"o|output=s" => \$OUTPUT,
|
||||
"threads=i" => \$THREADS,
|
||||
"parallel" => \$PARALLEL,
|
||||
"h|help" => \$HELP,
|
||||
);
|
||||
|
||||
if($HELP) {
|
||||
print "Usage: perl $0 -b bindir -s corpus.src -t corpus.tgt [-o outputprefix] [--threads 8] [--parallel]\n";
|
||||
exit 0;
|
||||
}
|
||||
|
||||
die "--bindir arg is required" if not defined $BINDIR;
|
||||
die "-s|--source arg is required" if not defined $SRC;
|
||||
die "-t|--target arg is required" if not defined $TRG;
|
||||
die "-o|--output arg is required" if not defined $OUTPUT;
|
||||
|
||||
for my $app (qw(fast_align atools extract_lex)) {
|
||||
die "Could not find $app in $BINDIR" if not -e "$BINDIR/$app";
|
||||
}
|
||||
|
||||
my $TEMPDIR = tempdir(CLEANUP => 1);
|
||||
|
||||
my (undef, $CORPUS) = tempfile(DIR => $TEMPDIR);
|
||||
my (undef, $ALN_S2T) = tempfile(DIR => $TEMPDIR);
|
||||
my (undef, $ALN_T2S) = tempfile(DIR => $TEMPDIR);
|
||||
my (undef, $ALN_GDF) = tempfile(DIR => $TEMPDIR);
|
||||
|
||||
execute("paste $SRC $TRG | sed 's/\\t/ ||| /' > $CORPUS");
|
||||
|
||||
my @COMMANDS = (
|
||||
"OMP_NUM_THREADS=$THREADS $BINDIR/fast_align -vdo -i $CORPUS > $ALN_S2T",
|
||||
"OMP_NUM_THREADS=$THREADS $BINDIR/fast_align -vdor -i $CORPUS > $ALN_T2S"
|
||||
);
|
||||
|
||||
my @PIDS;
|
||||
for my $c (@COMMANDS) {
|
||||
if ($PARALLEL) {
|
||||
my $pid = fork();
|
||||
if (!$pid) {
|
||||
execute($c);
|
||||
exit(0);
|
||||
} else {
|
||||
push(@PIDS, $pid);
|
||||
print "Forked process $pid\n";
|
||||
}
|
||||
} else {
|
||||
execute($c);
|
||||
}
|
||||
}
|
||||
if ($PARALLEL) {
|
||||
waitpid($_, 0) foreach(@PIDS);
|
||||
}
|
||||
|
||||
execute("$BINDIR/atools -c grow-diag-final -i $ALN_S2T -j $ALN_T2S > $ALN_GDF");
|
||||
execute("$BINDIR/extract_lex $TRG $SRC $ALN_GDF $OUTPUT.s2t $OUTPUT.t2s");
|
||||
|
||||
sub execute {
|
||||
my $command = shift;
|
||||
logMessage("Executing:\t$command");
|
||||
my $ret = system($command);
|
||||
if ($ret != 0) {
|
||||
logMessage("Command '$command' finished with return status $ret");
|
||||
logMessage("Aborting and killing parent process");
|
||||
kill(2, $PID);
|
||||
die;
|
||||
}
|
||||
}
|
||||
|
||||
sub logMessage {
|
||||
my $message = shift;
|
||||
my $time = POSIX::strftime("%m/%d/%Y %H:%M:%S", localtime());
|
||||
my $log_message = $time."\t$message\n";
|
||||
print STDERR $log_message;
|
||||
}
|
|
@ -0,0 +1,25 @@
|
|||
#!/bin/bash -v
|
||||
|
||||
mkdir -p bin
|
||||
|
||||
# download and compile fast_align
|
||||
if [ ! -e bin/fast_align ]; then
|
||||
git clone https://github.com/clab/fast_align
|
||||
mkdir -p fast_align/build
|
||||
cd fast_align/build
|
||||
cmake ..
|
||||
make -j4
|
||||
cp fast_align atools ../../bin
|
||||
cd ../../
|
||||
fi
|
||||
|
||||
# download and compile extract-lex
|
||||
if [ ! -e bin/extract_lex ]; then
|
||||
git clone https://github.com/marian-nmt/extract-lex
|
||||
mkdir -p extract-lex/build
|
||||
cd extract-lex/build
|
||||
cmake ..
|
||||
make -j4
|
||||
cp extract_lex ../../bin
|
||||
cd ../../
|
||||
fi
|
|
@ -1590,7 +1590,12 @@ class App {
|
|||
}
|
||||
|
||||
// Unlimited vector parser
|
||||
// RG: A negative number for the total number of expected values means that the option is a
|
||||
// vector and accepts an unlimited number of values
|
||||
if(num < 0) {
|
||||
// RG: We need to keep track if the vector option is empty and handle this separately as
|
||||
// otherwise the parser will mark the command-line option as not set
|
||||
bool emptyVectorArgs = true;
|
||||
while(!args.empty() && _recognize(args.back()) == detail::Classifer::NONE) {
|
||||
if(collected >= -num) {
|
||||
// We could break here for allow extras, but we don't
|
||||
|
@ -1603,12 +1608,28 @@ class App {
|
|||
parse_order_.push_back(op.get());
|
||||
args.pop_back();
|
||||
collected++;
|
||||
emptyVectorArgs = false;
|
||||
}
|
||||
|
||||
// Allow -- to end an unlimited list and "eat" it
|
||||
if(!args.empty() && _recognize(args.back()) == detail::Classifer::POSITIONAL_MARK)
|
||||
args.pop_back();
|
||||
|
||||
// RG: Handle empty vector-like options
|
||||
if(emptyVectorArgs) {
|
||||
// RG: Set implicit value(s) if the option has it (them)
|
||||
if(op->get_implicit()) {
|
||||
for(const auto& ival : detail::split_up(op->get_implicitval())) {
|
||||
op->add_result(ival);
|
||||
parse_order_.push_back(op.get());
|
||||
}
|
||||
// RG: Abort if there is a minimum number of values expected. Note: get_expected()
|
||||
// equals to -N means at least N values are expected
|
||||
} else if (op->get_expected() < 0) {
|
||||
parse_order_.push_back(op.get());
|
||||
throw ArgumentMismatch(op->get_name(), op->get_expected(), 0);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
while(num > 0 && !args.empty()) {
|
||||
num--;
|
||||
|
|
|
@ -6,6 +6,33 @@ add_subdirectory(./SQLiteCpp)
|
|||
add_subdirectory(./pathie-cpp)
|
||||
add_subdirectory(./zlib)
|
||||
|
||||
if(USE_FBGEMM)
|
||||
# @TODO: find out if this is somehow harmful. This is supppressing CMake warnings for CMAKE_SUPPRESS_DEVELOPER_WARNINGS
|
||||
# meant to silence CMakeFiles of 3rd_party tools.
|
||||
if(NOT DEFINED CMAKE_SUPPRESS_DEVELOPER_WARNINGS)
|
||||
set(CMAKE_SUPPRESS_DEVELOPER_WARNINGS 1 CACHE INTERNAL "No dev warnings")
|
||||
endif()
|
||||
|
||||
if(NOT MSVC)
|
||||
# only locally disabled for the 3rd_party folder
|
||||
# set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-unused-value -Wno-unused-parameter -Wno-unused-variable -Wno-unused-function")
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-unused")
|
||||
endif()
|
||||
|
||||
set(FBGEMM_BUILD_TESTS OFF CACHE BOOL "Disable fbgemm tests")
|
||||
set(FBGEMM_BUILD_BENCHMARKS OFF CACHE BOOL "Disable fbgemm benchmark")
|
||||
add_subdirectory(./fbgemm)
|
||||
|
||||
# asmjit (3rd-party submodule of fbgemm) sets -Wall -Wextra near the end of
|
||||
# the compile options, invalidating any -Wno-... flags that we may have set
|
||||
# earlier. Let's remove them.
|
||||
get_property(ASMJIT_COMPILE_OPTIONS TARGET asmjit PROPERTY COMPILE_OPTIONS)
|
||||
list(REMOVE_ITEM ASMJIT_COMPILE_OPTIONS -Wall -Wextra)
|
||||
set_property(TARGET asmjit PROPERTY COMPILE_OPTIONS ${ASMJIT_COMPILE_OPTIONS})
|
||||
message(" ASMJIT COMPILE FLAGS: ${ASMJIT_COMPILE_OPTIONS}")
|
||||
|
||||
endif(USE_FBGEMM)
|
||||
|
||||
if(USE_SENTENCEPIECE)
|
||||
if(USE_STATIC_LIBS)
|
||||
set(_ORIG_CMAKE_FIND_LIBRARY_SUFFIXES ${CMAKE_FIND_LIBRARY_SUFFIXES})
|
||||
|
@ -16,16 +43,37 @@ if(USE_SENTENCEPIECE)
|
|||
endif()
|
||||
endif()
|
||||
|
||||
set(SPM_ENABLE_SHARED OFF CACHE BOOL "Builds shared libaries in addition to static libraries." FORCE)
|
||||
set(SPM_ENABLE_TCMALLOC ON CACHE BOOL "Enable TCMalloc if available." FORCE)
|
||||
|
||||
if(USE_STATIC_LIBS)
|
||||
message(WARNING "You are compiling SentencePiece binaries with -DUSE_STATIC_LIBS=on. \
|
||||
This will cause spm_train to segfault. No need to worry if you do not intend to use that binary. \
|
||||
Marian support for SentencePiece will work fine.")
|
||||
|
||||
set(SPM_ENABLE_SHARED OFF CACHE BOOL "Builds shared libaries in addition to static libraries." FORCE)
|
||||
set(SPM_TCMALLOC_STATIC ON CACHE BOOL "Link static library of TCMALLOC." FORCE)
|
||||
else(USE_STATIC_LIBS)
|
||||
set(SPM_ENABLE_SHARED ON CACHE BOOL "Builds shared libaries in addition to static libraries." FORCE)
|
||||
set(SPM_TCMALLOC_STATIC OFF CACHE BOOL "Link static library of TCMALLOC." FORCE)
|
||||
endif(USE_STATIC_LIBS)
|
||||
|
||||
add_subdirectory(./sentencepiece)
|
||||
include_directories(./sentencepiece)
|
||||
|
||||
set_target_properties(spm_encode spm_decode spm_train spm_normalize spm_export_vocab
|
||||
PROPERTIES
|
||||
RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}")
|
||||
PROPERTIES RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}")
|
||||
|
||||
if (CMAKE_CXX_COMPILER_ID MATCHES "Clang")
|
||||
foreach(t sentencepiece sentencepiece_train sentencepiece_train-static
|
||||
spm_decode spm_encode spm_export_vocab spm_normalize spm_train)
|
||||
set_property(TARGET ${t} APPEND_STRING PROPERTY COMPILE_FLAGS " -Wno-tautological-compare -Wno-unused")
|
||||
if (CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 9.0)
|
||||
set_property(TARGET ${t} APPEND_STRING PROPERTY COMPILE_FLAGS " -Wno-range-loop-construct")
|
||||
endif()
|
||||
# get_property(SENTENCEPIECE_COMPILE_FLAGS TARGET ${t} PROPERTY COMPILE_FLAGS)
|
||||
# message("-- SENTENCPIECE: compile flags for target ${t}: ${SENTENCEPIECE_COMPILE_FLAGS}")
|
||||
endforeach(t)
|
||||
endif()
|
||||
|
||||
if(USE_STATIC_LIBS)
|
||||
set(CMAKE_FIND_LIBRARY_SUFFIXES ${_ORIG_CMAKE_FIND_LIBRARY_SUFFIXES})
|
||||
|
@ -36,5 +84,66 @@ include_directories(./SQLiteCpp/include)
|
|||
include_directories(./CLI)
|
||||
include_directories(./pathie-cpp/include)
|
||||
|
||||
if (CMAKE_CXX_COMPILER_ID MATCHES "Clang")
|
||||
#set_target_properties(SQLiteCpp PROPERTIES COMPILE_FLAGS
|
||||
set_property(TARGET SQLiteCpp APPEND_STRING PROPERTY COMPILE_FLAGS
|
||||
" -Wno-parentheses-equality -Wno-unused-value")
|
||||
if (CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 9.0)
|
||||
set_property(TARGET SQLiteCpp APPEND_STRING PROPERTY COMPILE_FLAGS
|
||||
" -Wno-implicit-int-float-conversion")
|
||||
endif()
|
||||
set_property(TARGET libyaml-cpp APPEND_STRING PROPERTY COMPILE_FLAGS
|
||||
" -fPIC -Wno-unused-value")
|
||||
set_property(TARGET pathie-cpp APPEND_STRING PROPERTY COMPILE_FLAGS
|
||||
" -fPIC -Wno-unused-value")
|
||||
endif()
|
||||
|
||||
|
||||
|
||||
include_directories(./zlib)
|
||||
|
||||
include(ExternalProject)
|
||||
|
||||
set(INSTALLS "") # this will contain a list of 3rd part dependencies that we install locally
|
||||
if(CUDA_FOUND)
|
||||
if(USE_NCCL)
|
||||
|
||||
# disables compilation for sm_30 to avoid ptxas warning... that is general Kepler support. But K80s are supported for instance by sm_35
|
||||
|
||||
set(GENCODE "")
|
||||
if(COMPILE_CUDA_SM35)
|
||||
set(GENCODE "${GENCODE} -gencode=arch=compute_35,code=sm_35")
|
||||
endif(COMPILE_CUDA_SM35)
|
||||
if(COMPILE_CUDA_SM50)
|
||||
set(GENCODE "${GENCODE} -gencode=arch=compute_50,code=sm_50")
|
||||
endif(COMPILE_CUDA_SM50)
|
||||
if(COMPILE_CUDA_SM60)
|
||||
set(GENCODE "${GENCODE} -gencode=arch=compute_60,code=sm_60 -gencode=arch=compute_61,code=sm_61")
|
||||
endif(COMPILE_CUDA_SM60)
|
||||
if(COMPILE_CUDA_SM70)
|
||||
set(GENCODE "${GENCODE} -gencode=arch=compute_70,code=sm_70")
|
||||
endif(COMPILE_CUDA_SM70)
|
||||
|
||||
# install nccl in ${CMAKE_BINARY_DIR}/local similar to /usr/local linux installation
|
||||
ExternalProject_Add(nccl_install
|
||||
SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/nccl
|
||||
BINARY_DIR ${CMAKE_CURRENT_SOURCE_DIR}/nccl
|
||||
CONFIGURE_COMMAND ""
|
||||
BUILD_COMMAND
|
||||
$(MAKE) -f ${CMAKE_CURRENT_SOURCE_DIR}/nccl/Makefile src.build
|
||||
BUILDDIR=${CMAKE_BINARY_DIR}/local CUDA_HOME=${CUDA_TOOLKIT_ROOT_DIR}
|
||||
CUDA8_GENCODE=${GENCODE} CXX=${CMAKE_CXX_COMPILER}
|
||||
INSTALL_COMMAND "")
|
||||
|
||||
set_target_properties(nccl PROPERTIES IMPORTED_LOCATION ${CMAKE_BINARY_DIR}/local/lib/libnccl_static.a)
|
||||
add_dependencies(nccl nccl_install)
|
||||
set(INSTALLS ${INSTALLS} nccl_install)
|
||||
|
||||
endif(USE_NCCL)
|
||||
endif(CUDA_FOUND)
|
||||
|
||||
# @TODO: do the same for SentencePiece, Protobuf etc.
|
||||
# make clean will clean "${CMAKE_BINARY_DIR}/local"
|
||||
set_directory_properties(PROPERTY ADDITIONAL_MAKE_CLEAN_FILES ${CMAKE_BINARY_DIR}/local)
|
||||
|
||||
add_custom_target(3rd_party_installs DEPENDS ${INSTALLS})
|
||||
|
|
|
@ -5,6 +5,8 @@
|
|||
// ExceptionWithCallStack.h - debug util functions
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
|
||||
namespace Microsoft { namespace MSR { namespace CNTK {
|
||||
|
|
|
@ -0,0 +1,726 @@
|
|||
/*
|
||||
AVX implementation of sin, cos, sincos, exp and log
|
||||
|
||||
Based on "sse_mathfun.h", by Julien Pommier
|
||||
http://gruntthepeon.free.fr/ssemath/
|
||||
|
||||
Copyright (C) 2012 Giovanni Garberoglio
|
||||
Interdisciplinary Laboratory for Computational Science (LISC)
|
||||
Fondazione Bruno Kessler and University of Trento
|
||||
via Sommarive, 18
|
||||
I-38123 Trento (Italy)
|
||||
|
||||
This software is provided 'as-is', without any express or implied
|
||||
warranty. In no event will the authors be held liable for any damages
|
||||
arising from the use of this software.
|
||||
|
||||
Permission is granted to anyone to use this software for any purpose,
|
||||
including commercial applications, and to alter it and redistribute it
|
||||
freely, subject to the following restrictions:
|
||||
|
||||
1. The origin of this software must not be misrepresented; you must not
|
||||
claim that you wrote the original software. If you use this software
|
||||
in a product, an acknowledgment in the product documentation would be
|
||||
appreciated but is not required.
|
||||
2. Altered source versions must be plainly marked as such, and must not be
|
||||
misrepresented as being the original software.
|
||||
3. This notice may not be removed or altered from any source distribution.
|
||||
|
||||
(this is the zlib license)
|
||||
*/
|
||||
|
||||
#include <immintrin.h>
|
||||
|
||||
/* yes I know, the top of this file is quite ugly */
|
||||
#ifdef _MSC_VER
|
||||
# define ALIGN32_BEG __declspec(align(32))
|
||||
# define ALIGN32_END
|
||||
#else /* gcc or icc */
|
||||
# define ALIGN32_BEG
|
||||
# define ALIGN32_END __attribute__((aligned(32)))
|
||||
#endif
|
||||
|
||||
/* __m128 is ugly to write */
|
||||
typedef __m256 v8sf; // vector of 8 float (avx)
|
||||
typedef __m256i v8si; // vector of 8 int (avx)
|
||||
typedef __m128i v4si; // vector of 8 int (avx)
|
||||
|
||||
#define _PI32AVX_CONST(Name, Val) \
|
||||
static const ALIGN32_BEG int _pi32avx_##Name[4] ALIGN32_END = { Val, Val, Val, Val }
|
||||
|
||||
_PI32AVX_CONST(1, 1);
|
||||
_PI32AVX_CONST(inv1, ~1);
|
||||
_PI32AVX_CONST(2, 2);
|
||||
_PI32AVX_CONST(4, 4);
|
||||
|
||||
|
||||
/* declare some AVX constants -- why can't I figure a better way to do that? */
|
||||
#define _PS256_CONST(Name, Val) \
|
||||
static const ALIGN32_BEG float _ps256_##Name[8] ALIGN32_END = { (float)Val, (float)Val, (float)Val, (float)Val, (float)Val, (float)Val, (float)Val, (float)Val }
|
||||
#define _PI32_CONST256(Name, Val) \
|
||||
static const ALIGN32_BEG int _pi32_256_##Name[8] ALIGN32_END = { Val, Val, Val, Val, Val, Val, Val, Val }
|
||||
#define _PS256_CONST_TYPE(Name, Type, Val) \
|
||||
static const ALIGN32_BEG Type _ps256_##Name[8] ALIGN32_END = { Val, Val, Val, Val, Val, Val, Val, Val }
|
||||
|
||||
_PS256_CONST(1 , 1.0f);
|
||||
_PS256_CONST(0p5, 0.5f);
|
||||
/* the smallest non denormalized float number */
|
||||
_PS256_CONST_TYPE(min_norm_pos, int, 0x00800000);
|
||||
_PS256_CONST_TYPE(mant_mask, int, 0x7f800000);
|
||||
_PS256_CONST_TYPE(inv_mant_mask, int, ~0x7f800000);
|
||||
|
||||
_PS256_CONST_TYPE(sign_mask, int, (int)0x80000000);
|
||||
_PS256_CONST_TYPE(inv_sign_mask, int, ~0x80000000);
|
||||
|
||||
_PI32_CONST256(0, 0);
|
||||
_PI32_CONST256(1, 1);
|
||||
_PI32_CONST256(inv1, ~1);
|
||||
_PI32_CONST256(2, 2);
|
||||
_PI32_CONST256(4, 4);
|
||||
_PI32_CONST256(0x7f, 0x7f);
|
||||
|
||||
_PS256_CONST(cephes_SQRTHF, 0.707106781186547524);
|
||||
_PS256_CONST(cephes_log_p0, 7.0376836292E-2);
|
||||
_PS256_CONST(cephes_log_p1, - 1.1514610310E-1);
|
||||
_PS256_CONST(cephes_log_p2, 1.1676998740E-1);
|
||||
_PS256_CONST(cephes_log_p3, - 1.2420140846E-1);
|
||||
_PS256_CONST(cephes_log_p4, + 1.4249322787E-1);
|
||||
_PS256_CONST(cephes_log_p5, - 1.6668057665E-1);
|
||||
_PS256_CONST(cephes_log_p6, + 2.0000714765E-1);
|
||||
_PS256_CONST(cephes_log_p7, - 2.4999993993E-1);
|
||||
_PS256_CONST(cephes_log_p8, + 3.3333331174E-1);
|
||||
_PS256_CONST(cephes_log_q1, -2.12194440e-4);
|
||||
_PS256_CONST(cephes_log_q2, 0.693359375);
|
||||
|
||||
#ifndef __AVX2__
|
||||
|
||||
typedef union imm_xmm_union {
|
||||
v8si imm;
|
||||
v4si xmm[2];
|
||||
} imm_xmm_union;
|
||||
|
||||
#define COPY_IMM_TO_XMM(imm_, xmm0_, xmm1_) { \
|
||||
ALIGN32_BEG imm_xmm_union u ALIGN32_END; \
|
||||
u.imm = imm_; \
|
||||
xmm0_ = u.xmm[0]; \
|
||||
xmm1_ = u.xmm[1]; \
|
||||
}
|
||||
|
||||
#define COPY_XMM_TO_IMM(xmm0_, xmm1_, imm_) { \
|
||||
ALIGN32_BEG imm_xmm_union u ALIGN32_END; \
|
||||
u.xmm[0]=xmm0_; u.xmm[1]=xmm1_; imm_ = u.imm; \
|
||||
}
|
||||
|
||||
|
||||
#define AVX2_BITOP_USING_SSE2(fn) \
|
||||
static inline v8si avx2_mm256_##fn(v8si x, int a) \
|
||||
{ \
|
||||
/* use SSE2 instruction to perform the bitop AVX2 */ \
|
||||
v4si x1, x2; \
|
||||
v8si ret; \
|
||||
COPY_IMM_TO_XMM(x, x1, x2); \
|
||||
x1 = _mm_##fn(x1,a); \
|
||||
x2 = _mm_##fn(x2,a); \
|
||||
COPY_XMM_TO_IMM(x1, x2, ret); \
|
||||
return(ret); \
|
||||
}
|
||||
|
||||
//#warning "Using SSE2 to perform AVX2 bitshift ops"
|
||||
AVX2_BITOP_USING_SSE2(slli_epi32)
|
||||
AVX2_BITOP_USING_SSE2(srli_epi32)
|
||||
|
||||
#define AVX2_INTOP_USING_SSE2(fn) \
|
||||
static inline v8si avx2_mm256_##fn(v8si x, v8si y) \
|
||||
{ \
|
||||
/* use SSE2 instructions to perform the AVX2 integer operation */ \
|
||||
v4si x1, x2; \
|
||||
v4si y1, y2; \
|
||||
v8si ret; \
|
||||
COPY_IMM_TO_XMM(x, x1, x2); \
|
||||
COPY_IMM_TO_XMM(y, y1, y2); \
|
||||
x1 = _mm_##fn(x1,y1); \
|
||||
x2 = _mm_##fn(x2,y2); \
|
||||
COPY_XMM_TO_IMM(x1, x2, ret); \
|
||||
return(ret); \
|
||||
}
|
||||
|
||||
//#warning "Using SSE2 to perform AVX2 integer ops"
|
||||
AVX2_INTOP_USING_SSE2(and_si128)
|
||||
AVX2_INTOP_USING_SSE2(andnot_si128)
|
||||
AVX2_INTOP_USING_SSE2(cmpeq_epi32)
|
||||
AVX2_INTOP_USING_SSE2(sub_epi32)
|
||||
AVX2_INTOP_USING_SSE2(add_epi32)
|
||||
#define avx2_mm256_and_si256 avx2_mm256_and_si128
|
||||
#define avx2_mm256_andnot_si256 avx2_mm256_andnot_si128
|
||||
#else
|
||||
#define avx2_mm256_slli_epi32 _mm256_slli_epi32
|
||||
#define avx2_mm256_srli_epi32 _mm256_srli_epi32
|
||||
#define avx2_mm256_and_si256 _mm256_and_si256
|
||||
#define avx2_mm256_andnot_si256 _mm256_andnot_si256
|
||||
#define avx2_mm256_cmpeq_epi32 _mm256_cmpeq_epi32
|
||||
#define avx2_mm256_sub_epi32 _mm256_sub_epi32
|
||||
#define avx2_mm256_add_epi32 _mm256_add_epi32
|
||||
#endif /* __AVX2__ */
|
||||
|
||||
|
||||
/* natural logarithm computed for 8 simultaneous float
|
||||
return NaN for x <= 0
|
||||
*/
|
||||
static inline v8sf log256_ps(v8sf x) {
|
||||
v8si imm0;
|
||||
v8sf one = *(v8sf*)_ps256_1;
|
||||
|
||||
//v8sf invalid_mask = _mm256_cmple_ps(x, _mm256_setzero_ps());
|
||||
v8sf invalid_mask = _mm256_cmp_ps(x, _mm256_setzero_ps(), _CMP_LE_OS);
|
||||
|
||||
x = _mm256_max_ps(x, *(v8sf*)_ps256_min_norm_pos); /* cut off denormalized stuff */
|
||||
|
||||
// can be done with AVX2
|
||||
imm0 = avx2_mm256_srli_epi32(_mm256_castps_si256(x), 23);
|
||||
|
||||
/* keep only the fractional part */
|
||||
x = _mm256_and_ps(x, *(v8sf*)_ps256_inv_mant_mask);
|
||||
x = _mm256_or_ps(x, *(v8sf*)_ps256_0p5);
|
||||
|
||||
// this is again another AVX2 instruction
|
||||
imm0 = avx2_mm256_sub_epi32(imm0, *(v8si*)_pi32_256_0x7f);
|
||||
v8sf e = _mm256_cvtepi32_ps(imm0);
|
||||
|
||||
e = _mm256_add_ps(e, one);
|
||||
|
||||
/* part2:
|
||||
if( x < SQRTHF ) {
|
||||
e -= 1;
|
||||
x = x + x - 1.0;
|
||||
} else { x = x - 1.0; }
|
||||
*/
|
||||
//v8sf mask = _mm256_cmplt_ps(x, *(v8sf*)_ps256_cephes_SQRTHF);
|
||||
v8sf mask = _mm256_cmp_ps(x, *(v8sf*)_ps256_cephes_SQRTHF, _CMP_LT_OS);
|
||||
v8sf tmp = _mm256_and_ps(x, mask);
|
||||
x = _mm256_sub_ps(x, one);
|
||||
e = _mm256_sub_ps(e, _mm256_and_ps(one, mask));
|
||||
x = _mm256_add_ps(x, tmp);
|
||||
|
||||
v8sf z = _mm256_mul_ps(x,x);
|
||||
|
||||
v8sf y = *(v8sf*)_ps256_cephes_log_p0;
|
||||
y = _mm256_mul_ps(y, x);
|
||||
y = _mm256_add_ps(y, *(v8sf*)_ps256_cephes_log_p1);
|
||||
y = _mm256_mul_ps(y, x);
|
||||
y = _mm256_add_ps(y, *(v8sf*)_ps256_cephes_log_p2);
|
||||
y = _mm256_mul_ps(y, x);
|
||||
y = _mm256_add_ps(y, *(v8sf*)_ps256_cephes_log_p3);
|
||||
y = _mm256_mul_ps(y, x);
|
||||
y = _mm256_add_ps(y, *(v8sf*)_ps256_cephes_log_p4);
|
||||
y = _mm256_mul_ps(y, x);
|
||||
y = _mm256_add_ps(y, *(v8sf*)_ps256_cephes_log_p5);
|
||||
y = _mm256_mul_ps(y, x);
|
||||
y = _mm256_add_ps(y, *(v8sf*)_ps256_cephes_log_p6);
|
||||
y = _mm256_mul_ps(y, x);
|
||||
y = _mm256_add_ps(y, *(v8sf*)_ps256_cephes_log_p7);
|
||||
y = _mm256_mul_ps(y, x);
|
||||
y = _mm256_add_ps(y, *(v8sf*)_ps256_cephes_log_p8);
|
||||
y = _mm256_mul_ps(y, x);
|
||||
|
||||
y = _mm256_mul_ps(y, z);
|
||||
|
||||
tmp = _mm256_mul_ps(e, *(v8sf*)_ps256_cephes_log_q1);
|
||||
y = _mm256_add_ps(y, tmp);
|
||||
|
||||
|
||||
tmp = _mm256_mul_ps(z, *(v8sf*)_ps256_0p5);
|
||||
y = _mm256_sub_ps(y, tmp);
|
||||
|
||||
tmp = _mm256_mul_ps(e, *(v8sf*)_ps256_cephes_log_q2);
|
||||
x = _mm256_add_ps(x, y);
|
||||
x = _mm256_add_ps(x, tmp);
|
||||
x = _mm256_or_ps(x, invalid_mask); // negative arg will be NAN
|
||||
return x;
|
||||
}
|
||||
|
||||
_PS256_CONST(exp_hi, 88.3762626647949f);
|
||||
_PS256_CONST(exp_lo, -88.3762626647949f);
|
||||
|
||||
_PS256_CONST(cephes_LOG2EF, 1.44269504088896341);
|
||||
_PS256_CONST(cephes_exp_C1, 0.693359375);
|
||||
_PS256_CONST(cephes_exp_C2, -2.12194440e-4);
|
||||
|
||||
_PS256_CONST(cephes_exp_p0, 1.9875691500E-4);
|
||||
_PS256_CONST(cephes_exp_p1, 1.3981999507E-3);
|
||||
_PS256_CONST(cephes_exp_p2, 8.3334519073E-3);
|
||||
_PS256_CONST(cephes_exp_p3, 4.1665795894E-2);
|
||||
_PS256_CONST(cephes_exp_p4, 1.6666665459E-1);
|
||||
_PS256_CONST(cephes_exp_p5, 5.0000001201E-1);
|
||||
|
||||
static inline v8sf exp256_ps(v8sf x) {
|
||||
v8sf tmp = _mm256_setzero_ps(), fx;
|
||||
v8si imm0;
|
||||
v8sf one = *(v8sf*)_ps256_1;
|
||||
|
||||
x = _mm256_min_ps(x, *(v8sf*)_ps256_exp_hi);
|
||||
x = _mm256_max_ps(x, *(v8sf*)_ps256_exp_lo);
|
||||
|
||||
/* express exp(x) as exp(g + n*log(2)) */
|
||||
fx = _mm256_mul_ps(x, *(v8sf*)_ps256_cephes_LOG2EF);
|
||||
fx = _mm256_add_ps(fx, *(v8sf*)_ps256_0p5);
|
||||
|
||||
/* how to perform a floorf with SSE: just below */
|
||||
//imm0 = _mm256_cvttps_epi32(fx);
|
||||
//tmp = _mm256_cvtepi32_ps(imm0);
|
||||
|
||||
tmp = _mm256_floor_ps(fx);
|
||||
|
||||
/* if greater, substract 1 */
|
||||
//v8sf mask = _mm256_cmpgt_ps(tmp, fx);
|
||||
v8sf mask = _mm256_cmp_ps(tmp, fx, _CMP_GT_OS);
|
||||
mask = _mm256_and_ps(mask, one);
|
||||
fx = _mm256_sub_ps(tmp, mask);
|
||||
|
||||
tmp = _mm256_mul_ps(fx, *(v8sf*)_ps256_cephes_exp_C1);
|
||||
v8sf z = _mm256_mul_ps(fx, *(v8sf*)_ps256_cephes_exp_C2);
|
||||
x = _mm256_sub_ps(x, tmp);
|
||||
x = _mm256_sub_ps(x, z);
|
||||
|
||||
z = _mm256_mul_ps(x,x);
|
||||
|
||||
v8sf y = *(v8sf*)_ps256_cephes_exp_p0;
|
||||
y = _mm256_mul_ps(y, x);
|
||||
y = _mm256_add_ps(y, *(v8sf*)_ps256_cephes_exp_p1);
|
||||
y = _mm256_mul_ps(y, x);
|
||||
y = _mm256_add_ps(y, *(v8sf*)_ps256_cephes_exp_p2);
|
||||
y = _mm256_mul_ps(y, x);
|
||||
y = _mm256_add_ps(y, *(v8sf*)_ps256_cephes_exp_p3);
|
||||
y = _mm256_mul_ps(y, x);
|
||||
y = _mm256_add_ps(y, *(v8sf*)_ps256_cephes_exp_p4);
|
||||
y = _mm256_mul_ps(y, x);
|
||||
y = _mm256_add_ps(y, *(v8sf*)_ps256_cephes_exp_p5);
|
||||
y = _mm256_mul_ps(y, z);
|
||||
y = _mm256_add_ps(y, x);
|
||||
y = _mm256_add_ps(y, one);
|
||||
|
||||
/* build 2^n */
|
||||
imm0 = _mm256_cvttps_epi32(fx);
|
||||
// another two AVX2 instructions
|
||||
imm0 = avx2_mm256_add_epi32(imm0, *(v8si*)_pi32_256_0x7f);
|
||||
imm0 = avx2_mm256_slli_epi32(imm0, 23);
|
||||
v8sf pow2n = _mm256_castsi256_ps(imm0);
|
||||
y = _mm256_mul_ps(y, pow2n);
|
||||
return y;
|
||||
}
|
||||
|
||||
_PS256_CONST(minus_cephes_DP1, -0.78515625);
|
||||
_PS256_CONST(minus_cephes_DP2, -2.4187564849853515625e-4);
|
||||
_PS256_CONST(minus_cephes_DP3, -3.77489497744594108e-8);
|
||||
_PS256_CONST(sincof_p0, -1.9515295891E-4);
|
||||
_PS256_CONST(sincof_p1, 8.3321608736E-3);
|
||||
_PS256_CONST(sincof_p2, -1.6666654611E-1);
|
||||
_PS256_CONST(coscof_p0, 2.443315711809948E-005);
|
||||
_PS256_CONST(coscof_p1, -1.388731625493765E-003);
|
||||
_PS256_CONST(coscof_p2, 4.166664568298827E-002);
|
||||
_PS256_CONST(cephes_FOPI, 1.27323954473516); // 4 / M_PI
|
||||
|
||||
|
||||
/* evaluation of 8 sines at onces using AVX intrisics
|
||||
|
||||
The code is the exact rewriting of the cephes sinf function.
|
||||
Precision is excellent as long as x < 8192 (I did not bother to
|
||||
take into account the special handling they have for greater values
|
||||
-- it does not return garbage for arguments over 8192, though, but
|
||||
the extra precision is missing).
|
||||
|
||||
Note that it is such that sinf((float)M_PI) = 8.74e-8, which is the
|
||||
surprising but correct result.
|
||||
|
||||
*/
|
||||
static inline v8sf sin256_ps(v8sf x) { // any x
|
||||
v8sf xmm1, xmm2 = _mm256_setzero_ps(), xmm3, sign_bit, y;
|
||||
v8si imm0, imm2;
|
||||
|
||||
#ifndef __AVX2__
|
||||
v4si imm0_1, imm0_2;
|
||||
v4si imm2_1, imm2_2;
|
||||
#endif
|
||||
|
||||
sign_bit = x;
|
||||
/* take the absolute value */
|
||||
x = _mm256_and_ps(x, *(v8sf*)_ps256_inv_sign_mask);
|
||||
/* extract the sign bit (upper one) */
|
||||
sign_bit = _mm256_and_ps(sign_bit, *(v8sf*)_ps256_sign_mask);
|
||||
|
||||
/* scale by 4/Pi */
|
||||
y = _mm256_mul_ps(x, *(v8sf*)_ps256_cephes_FOPI);
|
||||
|
||||
/*
|
||||
Here we start a series of integer operations, which are in the
|
||||
realm of AVX2.
|
||||
If we don't have AVX, let's perform them using SSE2 directives
|
||||
*/
|
||||
|
||||
#ifdef __AVX2__
|
||||
/* store the integer part of y in mm0 */
|
||||
imm2 = _mm256_cvttps_epi32(y);
|
||||
/* j=(j+1) & (~1) (see the cephes sources) */
|
||||
// another two AVX2 instruction
|
||||
imm2 = avx2_mm256_add_epi32(imm2, *(v8si*)_pi32_256_1);
|
||||
imm2 = avx2_mm256_and_si256(imm2, *(v8si*)_pi32_256_inv1);
|
||||
y = _mm256_cvtepi32_ps(imm2);
|
||||
|
||||
/* get the swap sign flag */
|
||||
imm0 = avx2_mm256_and_si256(imm2, *(v8si*)_pi32_256_4);
|
||||
imm0 = avx2_mm256_slli_epi32(imm0, 29);
|
||||
/* get the polynom selection mask
|
||||
there is one polynom for 0 <= x <= Pi/4
|
||||
and another one for Pi/4<x<=Pi/2
|
||||
|
||||
Both branches will be computed.
|
||||
*/
|
||||
imm2 = avx2_mm256_and_si256(imm2, *(v8si*)_pi32_256_2);
|
||||
imm2 = avx2_mm256_cmpeq_epi32(imm2,*(v8si*)_pi32_256_0);
|
||||
#else
|
||||
/* we use SSE2 routines to perform the integer ops */
|
||||
COPY_IMM_TO_XMM(_mm256_cvttps_epi32(y),imm2_1,imm2_2);
|
||||
|
||||
imm2_1 = _mm_add_epi32(imm2_1, *(v4si*)_pi32avx_1);
|
||||
imm2_2 = _mm_add_epi32(imm2_2, *(v4si*)_pi32avx_1);
|
||||
|
||||
imm2_1 = _mm_and_si128(imm2_1, *(v4si*)_pi32avx_inv1);
|
||||
imm2_2 = _mm_and_si128(imm2_2, *(v4si*)_pi32avx_inv1);
|
||||
|
||||
COPY_XMM_TO_IMM(imm2_1,imm2_2,imm2);
|
||||
y = _mm256_cvtepi32_ps(imm2);
|
||||
|
||||
imm0_1 = _mm_and_si128(imm2_1, *(v4si*)_pi32avx_4);
|
||||
imm0_2 = _mm_and_si128(imm2_2, *(v4si*)_pi32avx_4);
|
||||
|
||||
imm0_1 = _mm_slli_epi32(imm0_1, 29);
|
||||
imm0_2 = _mm_slli_epi32(imm0_2, 29);
|
||||
|
||||
COPY_XMM_TO_IMM(imm0_1, imm0_2, imm0);
|
||||
|
||||
imm2_1 = _mm_and_si128(imm2_1, *(v4si*)_pi32avx_2);
|
||||
imm2_2 = _mm_and_si128(imm2_2, *(v4si*)_pi32avx_2);
|
||||
|
||||
imm2_1 = _mm_cmpeq_epi32(imm2_1, _mm_setzero_si128());
|
||||
imm2_2 = _mm_cmpeq_epi32(imm2_2, _mm_setzero_si128());
|
||||
|
||||
COPY_XMM_TO_IMM(imm2_1, imm2_2, imm2);
|
||||
#endif
|
||||
|
||||
v8sf swap_sign_bit = _mm256_castsi256_ps(imm0);
|
||||
v8sf poly_mask = _mm256_castsi256_ps(imm2);
|
||||
sign_bit = _mm256_xor_ps(sign_bit, swap_sign_bit);
|
||||
|
||||
/* The magic pass: "Extended precision modular arithmetic"
|
||||
x = ((x - y * DP1) - y * DP2) - y * DP3; */
|
||||
xmm1 = *(v8sf*)_ps256_minus_cephes_DP1;
|
||||
xmm2 = *(v8sf*)_ps256_minus_cephes_DP2;
|
||||
xmm3 = *(v8sf*)_ps256_minus_cephes_DP3;
|
||||
xmm1 = _mm256_mul_ps(y, xmm1);
|
||||
xmm2 = _mm256_mul_ps(y, xmm2);
|
||||
xmm3 = _mm256_mul_ps(y, xmm3);
|
||||
x = _mm256_add_ps(x, xmm1);
|
||||
x = _mm256_add_ps(x, xmm2);
|
||||
x = _mm256_add_ps(x, xmm3);
|
||||
|
||||
/* Evaluate the first polynom (0 <= x <= Pi/4) */
|
||||
y = *(v8sf*)_ps256_coscof_p0;
|
||||
v8sf z = _mm256_mul_ps(x,x);
|
||||
|
||||
y = _mm256_mul_ps(y, z);
|
||||
y = _mm256_add_ps(y, *(v8sf*)_ps256_coscof_p1);
|
||||
y = _mm256_mul_ps(y, z);
|
||||
y = _mm256_add_ps(y, *(v8sf*)_ps256_coscof_p2);
|
||||
y = _mm256_mul_ps(y, z);
|
||||
y = _mm256_mul_ps(y, z);
|
||||
v8sf tmp = _mm256_mul_ps(z, *(v8sf*)_ps256_0p5);
|
||||
y = _mm256_sub_ps(y, tmp);
|
||||
y = _mm256_add_ps(y, *(v8sf*)_ps256_1);
|
||||
|
||||
/* Evaluate the second polynom (Pi/4 <= x <= 0) */
|
||||
|
||||
v8sf y2 = *(v8sf*)_ps256_sincof_p0;
|
||||
y2 = _mm256_mul_ps(y2, z);
|
||||
y2 = _mm256_add_ps(y2, *(v8sf*)_ps256_sincof_p1);
|
||||
y2 = _mm256_mul_ps(y2, z);
|
||||
y2 = _mm256_add_ps(y2, *(v8sf*)_ps256_sincof_p2);
|
||||
y2 = _mm256_mul_ps(y2, z);
|
||||
y2 = _mm256_mul_ps(y2, x);
|
||||
y2 = _mm256_add_ps(y2, x);
|
||||
|
||||
/* select the correct result from the two polynoms */
|
||||
xmm3 = poly_mask;
|
||||
y2 = _mm256_and_ps(xmm3, y2); //, xmm3);
|
||||
y = _mm256_andnot_ps(xmm3, y);
|
||||
y = _mm256_add_ps(y,y2);
|
||||
/* update the sign */
|
||||
y = _mm256_xor_ps(y, sign_bit);
|
||||
|
||||
return y;
|
||||
}
|
||||
|
||||
/* almost the same as sin_ps */
|
||||
static inline v8sf cos256_ps(v8sf x) { // any x
|
||||
v8sf xmm1, xmm2 = _mm256_setzero_ps(), xmm3, y;
|
||||
v8si imm0, imm2;
|
||||
|
||||
#ifndef __AVX2__
|
||||
v4si imm0_1, imm0_2;
|
||||
v4si imm2_1, imm2_2;
|
||||
#endif
|
||||
|
||||
/* take the absolute value */
|
||||
x = _mm256_and_ps(x, *(v8sf*)_ps256_inv_sign_mask);
|
||||
|
||||
/* scale by 4/Pi */
|
||||
y = _mm256_mul_ps(x, *(v8sf*)_ps256_cephes_FOPI);
|
||||
|
||||
#ifdef __AVX2__
|
||||
/* store the integer part of y in mm0 */
|
||||
imm2 = _mm256_cvttps_epi32(y);
|
||||
/* j=(j+1) & (~1) (see the cephes sources) */
|
||||
imm2 = avx2_mm256_add_epi32(imm2, *(v8si*)_pi32_256_1);
|
||||
imm2 = avx2_mm256_and_si256(imm2, *(v8si*)_pi32_256_inv1);
|
||||
y = _mm256_cvtepi32_ps(imm2);
|
||||
imm2 = avx2_mm256_sub_epi32(imm2, *(v8si*)_pi32_256_2);
|
||||
|
||||
/* get the swap sign flag */
|
||||
imm0 = avx2_mm256_andnot_si256(imm2, *(v8si*)_pi32_256_4);
|
||||
imm0 = avx2_mm256_slli_epi32(imm0, 29);
|
||||
/* get the polynom selection mask */
|
||||
imm2 = avx2_mm256_and_si256(imm2, *(v8si*)_pi32_256_2);
|
||||
imm2 = avx2_mm256_cmpeq_epi32(imm2, *(v8si*)_pi32_256_0);
|
||||
#else
|
||||
|
||||
/* we use SSE2 routines to perform the integer ops */
|
||||
COPY_IMM_TO_XMM(_mm256_cvttps_epi32(y),imm2_1,imm2_2);
|
||||
|
||||
imm2_1 = _mm_add_epi32(imm2_1, *(v4si*)_pi32avx_1);
|
||||
imm2_2 = _mm_add_epi32(imm2_2, *(v4si*)_pi32avx_1);
|
||||
|
||||
imm2_1 = _mm_and_si128(imm2_1, *(v4si*)_pi32avx_inv1);
|
||||
imm2_2 = _mm_and_si128(imm2_2, *(v4si*)_pi32avx_inv1);
|
||||
|
||||
COPY_XMM_TO_IMM(imm2_1,imm2_2,imm2);
|
||||
y = _mm256_cvtepi32_ps(imm2);
|
||||
|
||||
imm2_1 = _mm_sub_epi32(imm2_1, *(v4si*)_pi32avx_2);
|
||||
imm2_2 = _mm_sub_epi32(imm2_2, *(v4si*)_pi32avx_2);
|
||||
|
||||
imm0_1 = _mm_andnot_si128(imm2_1, *(v4si*)_pi32avx_4);
|
||||
imm0_2 = _mm_andnot_si128(imm2_2, *(v4si*)_pi32avx_4);
|
||||
|
||||
imm0_1 = _mm_slli_epi32(imm0_1, 29);
|
||||
imm0_2 = _mm_slli_epi32(imm0_2, 29);
|
||||
|
||||
COPY_XMM_TO_IMM(imm0_1, imm0_2, imm0);
|
||||
|
||||
imm2_1 = _mm_and_si128(imm2_1, *(v4si*)_pi32avx_2);
|
||||
imm2_2 = _mm_and_si128(imm2_2, *(v4si*)_pi32avx_2);
|
||||
|
||||
imm2_1 = _mm_cmpeq_epi32(imm2_1, _mm_setzero_si128());
|
||||
imm2_2 = _mm_cmpeq_epi32(imm2_2, _mm_setzero_si128());
|
||||
|
||||
COPY_XMM_TO_IMM(imm2_1, imm2_2, imm2);
|
||||
#endif
|
||||
|
||||
v8sf sign_bit = _mm256_castsi256_ps(imm0);
|
||||
v8sf poly_mask = _mm256_castsi256_ps(imm2);
|
||||
|
||||
/* The magic pass: "Extended precision modular arithmetic"
|
||||
x = ((x - y * DP1) - y * DP2) - y * DP3; */
|
||||
xmm1 = *(v8sf*)_ps256_minus_cephes_DP1;
|
||||
xmm2 = *(v8sf*)_ps256_minus_cephes_DP2;
|
||||
xmm3 = *(v8sf*)_ps256_minus_cephes_DP3;
|
||||
xmm1 = _mm256_mul_ps(y, xmm1);
|
||||
xmm2 = _mm256_mul_ps(y, xmm2);
|
||||
xmm3 = _mm256_mul_ps(y, xmm3);
|
||||
x = _mm256_add_ps(x, xmm1);
|
||||
x = _mm256_add_ps(x, xmm2);
|
||||
x = _mm256_add_ps(x, xmm3);
|
||||
|
||||
/* Evaluate the first polynom (0 <= x <= Pi/4) */
|
||||
y = *(v8sf*)_ps256_coscof_p0;
|
||||
v8sf z = _mm256_mul_ps(x,x);
|
||||
|
||||
y = _mm256_mul_ps(y, z);
|
||||
y = _mm256_add_ps(y, *(v8sf*)_ps256_coscof_p1);
|
||||
y = _mm256_mul_ps(y, z);
|
||||
y = _mm256_add_ps(y, *(v8sf*)_ps256_coscof_p2);
|
||||
y = _mm256_mul_ps(y, z);
|
||||
y = _mm256_mul_ps(y, z);
|
||||
v8sf tmp = _mm256_mul_ps(z, *(v8sf*)_ps256_0p5);
|
||||
y = _mm256_sub_ps(y, tmp);
|
||||
y = _mm256_add_ps(y, *(v8sf*)_ps256_1);
|
||||
|
||||
/* Evaluate the second polynom (Pi/4 <= x <= 0) */
|
||||
|
||||
v8sf y2 = *(v8sf*)_ps256_sincof_p0;
|
||||
y2 = _mm256_mul_ps(y2, z);
|
||||
y2 = _mm256_add_ps(y2, *(v8sf*)_ps256_sincof_p1);
|
||||
y2 = _mm256_mul_ps(y2, z);
|
||||
y2 = _mm256_add_ps(y2, *(v8sf*)_ps256_sincof_p2);
|
||||
y2 = _mm256_mul_ps(y2, z);
|
||||
y2 = _mm256_mul_ps(y2, x);
|
||||
y2 = _mm256_add_ps(y2, x);
|
||||
|
||||
/* select the correct result from the two polynoms */
|
||||
xmm3 = poly_mask;
|
||||
y2 = _mm256_and_ps(xmm3, y2); //, xmm3);
|
||||
y = _mm256_andnot_ps(xmm3, y);
|
||||
y = _mm256_add_ps(y,y2);
|
||||
/* update the sign */
|
||||
y = _mm256_xor_ps(y, sign_bit);
|
||||
|
||||
return y;
|
||||
}
|
||||
|
||||
/* since sin256_ps and cos256_ps are almost identical, sincos256_ps could replace both of them..
|
||||
it is almost as fast, and gives you a free cosine with your sine */
|
||||
static inline void sincos256_ps(v8sf x, v8sf *s, v8sf *c) {
|
||||
|
||||
v8sf xmm1, xmm2, xmm3 = _mm256_setzero_ps(), sign_bit_sin, y;
|
||||
v8si imm0, imm2, imm4;
|
||||
|
||||
#ifndef __AVX2__
|
||||
v4si imm0_1, imm0_2;
|
||||
v4si imm2_1, imm2_2;
|
||||
v4si imm4_1, imm4_2;
|
||||
#endif
|
||||
|
||||
sign_bit_sin = x;
|
||||
/* take the absolute value */
|
||||
x = _mm256_and_ps(x, *(v8sf*)_ps256_inv_sign_mask);
|
||||
/* extract the sign bit (upper one) */
|
||||
sign_bit_sin = _mm256_and_ps(sign_bit_sin, *(v8sf*)_ps256_sign_mask);
|
||||
|
||||
/* scale by 4/Pi */
|
||||
y = _mm256_mul_ps(x, *(v8sf*)_ps256_cephes_FOPI);
|
||||
|
||||
#ifdef __AVX2__
|
||||
/* store the integer part of y in imm2 */
|
||||
imm2 = _mm256_cvttps_epi32(y);
|
||||
|
||||
/* j=(j+1) & (~1) (see the cephes sources) */
|
||||
imm2 = avx2_mm256_add_epi32(imm2, *(v8si*)_pi32_256_1);
|
||||
imm2 = avx2_mm256_and_si256(imm2, *(v8si*)_pi32_256_inv1);
|
||||
|
||||
y = _mm256_cvtepi32_ps(imm2);
|
||||
imm4 = imm2;
|
||||
|
||||
/* get the swap sign flag for the sine */
|
||||
imm0 = avx2_mm256_and_si256(imm2, *(v8si*)_pi32_256_4);
|
||||
imm0 = avx2_mm256_slli_epi32(imm0, 29);
|
||||
//v8sf swap_sign_bit_sin = _mm256_castsi256_ps(imm0);
|
||||
|
||||
/* get the polynom selection mask for the sine*/
|
||||
imm2 = avx2_mm256_and_si256(imm2, *(v8si*)_pi32_256_2);
|
||||
imm2 = avx2_mm256_cmpeq_epi32(imm2, *(v8si*)_pi32_256_0);
|
||||
//v8sf poly_mask = _mm256_castsi256_ps(imm2);
|
||||
#else
|
||||
/* we use SSE2 routines to perform the integer ops */
|
||||
COPY_IMM_TO_XMM(_mm256_cvttps_epi32(y),imm2_1,imm2_2);
|
||||
|
||||
imm2_1 = _mm_add_epi32(imm2_1, *(v4si*)_pi32avx_1);
|
||||
imm2_2 = _mm_add_epi32(imm2_2, *(v4si*)_pi32avx_1);
|
||||
|
||||
imm2_1 = _mm_and_si128(imm2_1, *(v4si*)_pi32avx_inv1);
|
||||
imm2_2 = _mm_and_si128(imm2_2, *(v4si*)_pi32avx_inv1);
|
||||
|
||||
COPY_XMM_TO_IMM(imm2_1,imm2_2,imm2);
|
||||
y = _mm256_cvtepi32_ps(imm2);
|
||||
|
||||
imm4_1 = imm2_1;
|
||||
imm4_2 = imm2_2;
|
||||
|
||||
imm0_1 = _mm_and_si128(imm2_1, *(v4si*)_pi32avx_4);
|
||||
imm0_2 = _mm_and_si128(imm2_2, *(v4si*)_pi32avx_4);
|
||||
|
||||
imm0_1 = _mm_slli_epi32(imm0_1, 29);
|
||||
imm0_2 = _mm_slli_epi32(imm0_2, 29);
|
||||
|
||||
COPY_XMM_TO_IMM(imm0_1, imm0_2, imm0);
|
||||
|
||||
imm2_1 = _mm_and_si128(imm2_1, *(v4si*)_pi32avx_2);
|
||||
imm2_2 = _mm_and_si128(imm2_2, *(v4si*)_pi32avx_2);
|
||||
|
||||
imm2_1 = _mm_cmpeq_epi32(imm2_1, _mm_setzero_si128());
|
||||
imm2_2 = _mm_cmpeq_epi32(imm2_2, _mm_setzero_si128());
|
||||
|
||||
COPY_XMM_TO_IMM(imm2_1, imm2_2, imm2);
|
||||
#endif
|
||||
v8sf swap_sign_bit_sin = _mm256_castsi256_ps(imm0);
|
||||
v8sf poly_mask = _mm256_castsi256_ps(imm2);
|
||||
|
||||
/* The magic pass: "Extended precision modular arithmetic"
|
||||
x = ((x - y * DP1) - y * DP2) - y * DP3; */
|
||||
xmm1 = *(v8sf*)_ps256_minus_cephes_DP1;
|
||||
xmm2 = *(v8sf*)_ps256_minus_cephes_DP2;
|
||||
xmm3 = *(v8sf*)_ps256_minus_cephes_DP3;
|
||||
xmm1 = _mm256_mul_ps(y, xmm1);
|
||||
xmm2 = _mm256_mul_ps(y, xmm2);
|
||||
xmm3 = _mm256_mul_ps(y, xmm3);
|
||||
x = _mm256_add_ps(x, xmm1);
|
||||
x = _mm256_add_ps(x, xmm2);
|
||||
x = _mm256_add_ps(x, xmm3);
|
||||
|
||||
#ifdef __AVX2__
|
||||
imm4 = avx2_mm256_sub_epi32(imm4, *(v8si*)_pi32_256_2);
|
||||
imm4 = avx2_mm256_andnot_si256(imm4, *(v8si*)_pi32_256_4);
|
||||
imm4 = avx2_mm256_slli_epi32(imm4, 29);
|
||||
#else
|
||||
imm4_1 = _mm_sub_epi32(imm4_1, *(v4si*)_pi32avx_2);
|
||||
imm4_2 = _mm_sub_epi32(imm4_2, *(v4si*)_pi32avx_2);
|
||||
|
||||
imm4_1 = _mm_andnot_si128(imm4_1, *(v4si*)_pi32avx_4);
|
||||
imm4_2 = _mm_andnot_si128(imm4_2, *(v4si*)_pi32avx_4);
|
||||
|
||||
imm4_1 = _mm_slli_epi32(imm4_1, 29);
|
||||
imm4_2 = _mm_slli_epi32(imm4_2, 29);
|
||||
|
||||
COPY_XMM_TO_IMM(imm4_1, imm4_2, imm4);
|
||||
#endif
|
||||
|
||||
v8sf sign_bit_cos = _mm256_castsi256_ps(imm4);
|
||||
|
||||
sign_bit_sin = _mm256_xor_ps(sign_bit_sin, swap_sign_bit_sin);
|
||||
|
||||
/* Evaluate the first polynom (0 <= x <= Pi/4) */
|
||||
v8sf z = _mm256_mul_ps(x,x);
|
||||
y = *(v8sf*)_ps256_coscof_p0;
|
||||
|
||||
y = _mm256_mul_ps(y, z);
|
||||
y = _mm256_add_ps(y, *(v8sf*)_ps256_coscof_p1);
|
||||
y = _mm256_mul_ps(y, z);
|
||||
y = _mm256_add_ps(y, *(v8sf*)_ps256_coscof_p2);
|
||||
y = _mm256_mul_ps(y, z);
|
||||
y = _mm256_mul_ps(y, z);
|
||||
v8sf tmp = _mm256_mul_ps(z, *(v8sf*)_ps256_0p5);
|
||||
y = _mm256_sub_ps(y, tmp);
|
||||
y = _mm256_add_ps(y, *(v8sf*)_ps256_1);
|
||||
|
||||
/* Evaluate the second polynom (Pi/4 <= x <= 0) */
|
||||
|
||||
v8sf y2 = *(v8sf*)_ps256_sincof_p0;
|
||||
y2 = _mm256_mul_ps(y2, z);
|
||||
y2 = _mm256_add_ps(y2, *(v8sf*)_ps256_sincof_p1);
|
||||
y2 = _mm256_mul_ps(y2, z);
|
||||
y2 = _mm256_add_ps(y2, *(v8sf*)_ps256_sincof_p2);
|
||||
y2 = _mm256_mul_ps(y2, z);
|
||||
y2 = _mm256_mul_ps(y2, x);
|
||||
y2 = _mm256_add_ps(y2, x);
|
||||
|
||||
/* select the correct result from the two polynoms */
|
||||
xmm3 = poly_mask;
|
||||
v8sf ysin2 = _mm256_and_ps(xmm3, y2);
|
||||
v8sf ysin1 = _mm256_andnot_ps(xmm3, y);
|
||||
y2 = _mm256_sub_ps(y2,ysin2);
|
||||
y = _mm256_sub_ps(y, ysin1);
|
||||
|
||||
xmm1 = _mm256_add_ps(ysin1,ysin2);
|
||||
xmm2 = _mm256_add_ps(y,y2);
|
||||
|
||||
/* update the sign */
|
||||
*s = _mm256_xor_ps(xmm1, sign_bit_sin);
|
||||
*c = _mm256_xor_ps(xmm2, sign_bit_cos);
|
||||
}
|
||||
|
Разница между файлами не показана из-за своего большого размера
Загрузить разницу
|
@ -0,0 +1 @@
|
|||
Subproject commit 84e66a976046180187724aff60a236c5378fde7c
|
|
@ -0,0 +1,115 @@
|
|||
|
||||
|
||||
|
||||
#include "umHalf.h"
|
||||
#include <iostream>
|
||||
#include <assert.h>
|
||||
|
||||
#define VALIDATE(x) if (!(x)){std::cout << "Failed: " << #x << std::endl;assert((x));}
|
||||
|
||||
int main(int, char*)
|
||||
{
|
||||
half h = 1.f, h2 = 2.f;
|
||||
--h2;
|
||||
++h2;
|
||||
--h;
|
||||
++h;
|
||||
h2 -= 1.f;
|
||||
float f = h2, f2 = h;
|
||||
VALIDATE(1.f == f && f == f2);
|
||||
|
||||
h = h2;
|
||||
h2 = 15.5f;
|
||||
|
||||
f = h2, f2 = h;
|
||||
VALIDATE(15.5f == f && 1.f == f2);
|
||||
h2 *= h;
|
||||
f = h2, f2 = h;
|
||||
VALIDATE(15.5f == f && 1.f == f2);
|
||||
h2 /= h;
|
||||
f = h2, f2 = h;
|
||||
VALIDATE(15.5f == f && 1.f == f2);
|
||||
h2 += h;
|
||||
f = h2, f2 = h;
|
||||
VALIDATE(16.5f == f && 1.f == f2);
|
||||
h++;h++;h++;
|
||||
h2 = -h2;
|
||||
h2 += 17.5f;
|
||||
h2 *= h;
|
||||
f = h2, f2 = h;
|
||||
VALIDATE(4.f == f && 4.f == f2);
|
||||
VALIDATE(h == h2);
|
||||
VALIDATE(h <= h2);
|
||||
--h;
|
||||
VALIDATE(h <= h2);
|
||||
|
||||
h -= 250.f;
|
||||
VALIDATE(h < h2);
|
||||
|
||||
h += 500.f;
|
||||
VALIDATE(h > h2);
|
||||
VALIDATE(h >= h2);
|
||||
|
||||
f = h2, f2 = h;
|
||||
VALIDATE(h * h2 == (half)(f * f2));
|
||||
|
||||
// addition
|
||||
// ****************************************************************************
|
||||
|
||||
// identical exponents
|
||||
for (float v = 0.f; v < 1000.f; ++v)
|
||||
{
|
||||
half one = v;
|
||||
half two = v;
|
||||
half three = one + two;
|
||||
f2 = three;
|
||||
VALIDATE(v*2.f == f2);
|
||||
}
|
||||
|
||||
// different exponents
|
||||
for (float v = 0.f, fp = 1000.f; v < 500.f; ++v, --fp)
|
||||
{
|
||||
half one = v;
|
||||
half two = fp;
|
||||
half three = one + two;
|
||||
f2 = three;
|
||||
VALIDATE(v+fp == f2);
|
||||
}
|
||||
|
||||
// very small numbers - this is already beyond the accuracy of 16 bit floats.
|
||||
for (float v = 0.003f; v < 1000.f; v += 0.0005f)
|
||||
{
|
||||
half one = v;
|
||||
half two = v;
|
||||
half three = one + two;
|
||||
f2 = three;
|
||||
float m = v*2.f;
|
||||
VALIDATE(f2 > (m-0.05*m) && f2 < (m+0.05*m));
|
||||
}
|
||||
|
||||
|
||||
// subtraction
|
||||
// ****************************************************************************
|
||||
|
||||
// identical exponents
|
||||
for (float v = 0.f; v < 1000.f; ++v)
|
||||
{
|
||||
half one = v;
|
||||
half two = v;
|
||||
half three = one - two;
|
||||
f2 = three;
|
||||
VALIDATE(0.f == f2);
|
||||
}
|
||||
|
||||
// different exponents
|
||||
for (float v = 0.f, fp = 1000.f; v < 500.f; ++v, --fp)
|
||||
{
|
||||
half one = v;
|
||||
half two = fp;
|
||||
half three = one - two;
|
||||
f2 = three;
|
||||
VALIDATE(v-fp == f2);
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
|
@ -0,0 +1,43 @@
|
|||
half_float
|
||||
========
|
||||
|
||||
#### 16 bit floating-point data type for C++ ####
|
||||
|
||||
Implements a `HalfFloat` class that implements all the common arithmetic operations for a 16 bit
|
||||
floating-point type (10 bits mantissa, 5 bits exponent and one sign bit) and can thus be used (almost)
|
||||
interchangeably with regular `float`s. Not all operations have efficent implementations (some just convert to `float`,
|
||||
compute the result and convert back again) - if in doubt, check out the source code.
|
||||
|
||||
The implementation tries to adhere to IEEE 754 in that it supports NaN and Infinity, but fails in other points:
|
||||
|
||||
- no difference between qnan and snan
|
||||
- no traps
|
||||
- no well-defined rounding mode
|
||||
|
||||
|
||||
We also supply a specialization for `std::numeric_limits<half>` that `half` be usable in template code
|
||||
dependent on type traits.
|
||||
|
||||
|
||||
#### Usage ####
|
||||
|
||||
// get some halfs (half is a typedef for HalfFloat)
|
||||
half a = 1.0f;
|
||||
half b = 0.5f;
|
||||
|
||||
// and have some FUN
|
||||
half c = (a+b) / (a-b);
|
||||
++c;
|
||||
|
||||
// now that we have a result in loosy precision,
|
||||
// convert it back to double precision.
|
||||
// if anybody asks, it's for the lulz.
|
||||
double result = c;
|
||||
|
||||
|
||||
Credits to _Chris Maiwald_ for the conversion code to `double` and extensive testing.
|
||||
|
||||
|
||||
#### License ####
|
||||
|
||||
3-clause BSD license: use it for anything, but give credit, don't blame us if your rocket crashes and don't advertise with it (who would).
|
|
@ -0,0 +1,222 @@
|
|||
// ISO C9x compliant stdint.h for Microsoft Visual Studio
|
||||
// Based on ISO/IEC 9899:TC2 Committee draft (May 6, 2005) WG14/N1124
|
||||
//
|
||||
// Copyright (c) 2006 Alexander Chemeris
|
||||
//
|
||||
// Redistribution and use in source and binary forms, with or without
|
||||
// modification, are permitted provided that the following conditions are met:
|
||||
//
|
||||
// 1. Redistributions of source code must retain the above copyright notice,
|
||||
// this list of conditions and the following disclaimer.
|
||||
//
|
||||
// 2. Redistributions in binary form must reproduce the above copyright
|
||||
// notice, this list of conditions and the following disclaimer in the
|
||||
// documentation and/or other materials provided with the distribution.
|
||||
//
|
||||
// 3. The name of the author may be used to endorse or promote products
|
||||
// derived from this software without specific prior written permission.
|
||||
//
|
||||
// THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR IMPLIED
|
||||
// WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
|
||||
// MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO
|
||||
// EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
// OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
|
||||
// WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
|
||||
// OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
|
||||
// ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
//
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#ifndef _MSC_VER // [
|
||||
#error "Use this header only with Microsoft Visual C++ compilers!"
|
||||
#endif // _MSC_VER ]
|
||||
|
||||
#ifndef _MSC_STDINT_H_ // [
|
||||
#define _MSC_STDINT_H_
|
||||
|
||||
#if _MSC_VER > 1000
|
||||
#pragma once
|
||||
#endif
|
||||
|
||||
#include <limits.h>
|
||||
|
||||
// For Visual Studio 6 in C++ mode wrap <wchar.h> include with 'extern "C++" {}'
|
||||
// or compiler give many errors like this:
|
||||
// error C2733: second C linkage of overloaded function 'wmemchr' not allowed
|
||||
#if (_MSC_VER < 1300) && defined(__cplusplus)
|
||||
extern "C++" {
|
||||
#endif
|
||||
# include <wchar.h>
|
||||
#if (_MSC_VER < 1300) && defined(__cplusplus)
|
||||
}
|
||||
#endif
|
||||
|
||||
// 7.18.1 Integer types
|
||||
|
||||
// 7.18.1.1 Exact-width integer types
|
||||
typedef __int8 int8_t;
|
||||
typedef __int16 int16_t;
|
||||
typedef __int32 int32_t;
|
||||
typedef __int64 int64_t;
|
||||
typedef unsigned __int8 uint8_t;
|
||||
typedef unsigned __int16 uint16_t;
|
||||
typedef unsigned __int32 uint32_t;
|
||||
typedef unsigned __int64 uint64_t;
|
||||
|
||||
// 7.18.1.2 Minimum-width integer types
|
||||
typedef int8_t int_least8_t;
|
||||
typedef int16_t int_least16_t;
|
||||
typedef int32_t int_least32_t;
|
||||
typedef int64_t int_least64_t;
|
||||
typedef uint8_t uint_least8_t;
|
||||
typedef uint16_t uint_least16_t;
|
||||
typedef uint32_t uint_least32_t;
|
||||
typedef uint64_t uint_least64_t;
|
||||
|
||||
// 7.18.1.3 Fastest minimum-width integer types
|
||||
typedef int8_t int_fast8_t;
|
||||
typedef int16_t int_fast16_t;
|
||||
typedef int32_t int_fast32_t;
|
||||
typedef int64_t int_fast64_t;
|
||||
typedef uint8_t uint_fast8_t;
|
||||
typedef uint16_t uint_fast16_t;
|
||||
typedef uint32_t uint_fast32_t;
|
||||
typedef uint64_t uint_fast64_t;
|
||||
|
||||
// 7.18.1.4 Integer types capable of holding object pointers
|
||||
#ifdef _WIN64 // [
|
||||
typedef __int64 intptr_t;
|
||||
typedef unsigned __int64 uintptr_t;
|
||||
#else // _WIN64 ][
|
||||
typedef int intptr_t;
|
||||
typedef unsigned int uintptr_t;
|
||||
#endif // _WIN64 ]
|
||||
|
||||
// 7.18.1.5 Greatest-width integer types
|
||||
typedef int64_t intmax_t;
|
||||
typedef uint64_t uintmax_t;
|
||||
|
||||
|
||||
// 7.18.2 Limits of specified-width integer types
|
||||
|
||||
#if !defined(__cplusplus) || defined(__STDC_LIMIT_MACROS) // [ See footnote 220 at page 257 and footnote 221 at page 259
|
||||
|
||||
// 7.18.2.1 Limits of exact-width integer types
|
||||
#define INT8_MIN ((int8_t)_I8_MIN)
|
||||
#define INT8_MAX _I8_MAX
|
||||
#define INT16_MIN ((int16_t)_I16_MIN)
|
||||
#define INT16_MAX _I16_MAX
|
||||
#define INT32_MIN ((int32_t)_I32_MIN)
|
||||
#define INT32_MAX _I32_MAX
|
||||
#define INT64_MIN ((int64_t)_I64_MIN)
|
||||
#define INT64_MAX _I64_MAX
|
||||
#define UINT8_MAX _UI8_MAX
|
||||
#define UINT16_MAX _UI16_MAX
|
||||
#define UINT32_MAX _UI32_MAX
|
||||
#define UINT64_MAX _UI64_MAX
|
||||
|
||||
// 7.18.2.2 Limits of minimum-width integer types
|
||||
#define INT_LEAST8_MIN INT8_MIN
|
||||
#define INT_LEAST8_MAX INT8_MAX
|
||||
#define INT_LEAST16_MIN INT16_MIN
|
||||
#define INT_LEAST16_MAX INT16_MAX
|
||||
#define INT_LEAST32_MIN INT32_MIN
|
||||
#define INT_LEAST32_MAX INT32_MAX
|
||||
#define INT_LEAST64_MIN INT64_MIN
|
||||
#define INT_LEAST64_MAX INT64_MAX
|
||||
#define UINT_LEAST8_MAX UINT8_MAX
|
||||
#define UINT_LEAST16_MAX UINT16_MAX
|
||||
#define UINT_LEAST32_MAX UINT32_MAX
|
||||
#define UINT_LEAST64_MAX UINT64_MAX
|
||||
|
||||
// 7.18.2.3 Limits of fastest minimum-width integer types
|
||||
#define INT_FAST8_MIN INT8_MIN
|
||||
#define INT_FAST8_MAX INT8_MAX
|
||||
#define INT_FAST16_MIN INT16_MIN
|
||||
#define INT_FAST16_MAX INT16_MAX
|
||||
#define INT_FAST32_MIN INT32_MIN
|
||||
#define INT_FAST32_MAX INT32_MAX
|
||||
#define INT_FAST64_MIN INT64_MIN
|
||||
#define INT_FAST64_MAX INT64_MAX
|
||||
#define UINT_FAST8_MAX UINT8_MAX
|
||||
#define UINT_FAST16_MAX UINT16_MAX
|
||||
#define UINT_FAST32_MAX UINT32_MAX
|
||||
#define UINT_FAST64_MAX UINT64_MAX
|
||||
|
||||
// 7.18.2.4 Limits of integer types capable of holding object pointers
|
||||
#ifdef _WIN64 // [
|
||||
# define INTPTR_MIN INT64_MIN
|
||||
# define INTPTR_MAX INT64_MAX
|
||||
# define UINTPTR_MAX UINT64_MAX
|
||||
#else // _WIN64 ][
|
||||
# define INTPTR_MIN INT32_MIN
|
||||
# define INTPTR_MAX INT32_MAX
|
||||
# define UINTPTR_MAX UINT32_MAX
|
||||
#endif // _WIN64 ]
|
||||
|
||||
// 7.18.2.5 Limits of greatest-width integer types
|
||||
#define INTMAX_MIN INT64_MIN
|
||||
#define INTMAX_MAX INT64_MAX
|
||||
#define UINTMAX_MAX UINT64_MAX
|
||||
|
||||
// 7.18.3 Limits of other integer types
|
||||
|
||||
#ifdef _WIN64 // [
|
||||
# define PTRDIFF_MIN _I64_MIN
|
||||
# define PTRDIFF_MAX _I64_MAX
|
||||
#else // _WIN64 ][
|
||||
# define PTRDIFF_MIN _I32_MIN
|
||||
# define PTRDIFF_MAX _I32_MAX
|
||||
#endif // _WIN64 ]
|
||||
|
||||
#define SIG_ATOMIC_MIN INT_MIN
|
||||
#define SIG_ATOMIC_MAX INT_MAX
|
||||
|
||||
#ifndef SIZE_MAX // [
|
||||
# ifdef _WIN64 // [
|
||||
# define SIZE_MAX _UI64_MAX
|
||||
# else // _WIN64 ][
|
||||
# define SIZE_MAX _UI32_MAX
|
||||
# endif // _WIN64 ]
|
||||
#endif // SIZE_MAX ]
|
||||
|
||||
// WCHAR_MIN and WCHAR_MAX are also defined in <wchar.h>
|
||||
#ifndef WCHAR_MIN // [
|
||||
# define WCHAR_MIN 0
|
||||
#endif // WCHAR_MIN ]
|
||||
#ifndef WCHAR_MAX // [
|
||||
# define WCHAR_MAX _UI16_MAX
|
||||
#endif // WCHAR_MAX ]
|
||||
|
||||
#define WINT_MIN 0
|
||||
#define WINT_MAX _UI16_MAX
|
||||
|
||||
#endif // __STDC_LIMIT_MACROS ]
|
||||
|
||||
|
||||
// 7.18.4 Limits of other integer types
|
||||
|
||||
#if !defined(__cplusplus) || defined(__STDC_CONSTANT_MACROS) // [ See footnote 224 at page 260
|
||||
|
||||
// 7.18.4.1 Macros for minimum-width integer constants
|
||||
|
||||
#define INT8_C(val) val##i8
|
||||
#define INT16_C(val) val##i16
|
||||
#define INT32_C(val) val##i32
|
||||
#define INT64_C(val) val##i64
|
||||
|
||||
#define UINT8_C(val) val##ui8
|
||||
#define UINT16_C(val) val##ui16
|
||||
#define UINT32_C(val) val##ui32
|
||||
#define UINT64_C(val) val##ui64
|
||||
|
||||
// 7.18.4.2 Macros for greatest-width integer constants
|
||||
#define INTMAX_C INT64_C
|
||||
#define UINTMAX_C UINT64_C
|
||||
|
||||
#endif // __STDC_CONSTANT_MACROS ]
|
||||
|
||||
|
||||
#endif // _MSC_STDINT_H_ ]
|
|
@ -0,0 +1,294 @@
|
|||
|
||||
///////////////////////////////////////////////////////////////////////////////////
|
||||
/*
|
||||
Copyright (c) 2006-2008,
|
||||
Chris "Krishty" Maiwald, Alexander "Aramis" Gessler
|
||||
|
||||
All rights reserved.
|
||||
|
||||
Redistribution and use of this software in source and binary forms,
|
||||
with or without modification, are permitted provided that the following
|
||||
conditions are met:
|
||||
|
||||
* Redistributions of source code must retain the above
|
||||
copyright notice, this list of conditions and the
|
||||
following disclaimer.
|
||||
|
||||
* Redistributions in binary form must reproduce the above
|
||||
copyright notice, this list of conditions and the
|
||||
following disclaimer in the documentation and/or other
|
||||
materials provided with the distribution.
|
||||
|
||||
* Neither the name of the class, nor the names of its
|
||||
contributors may be used to endorse or promote products
|
||||
derived from this software without specific prior
|
||||
written permission of the Development Team.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
///////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#ifndef UM_HALF_H_INCLUDED
|
||||
#define UM_HALF_H_INCLUDED
|
||||
|
||||
#include <limits>
|
||||
#include <algorithm>
|
||||
|
||||
#include <stdint.h>
|
||||
|
||||
#undef min
|
||||
#undef max
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////
|
||||
/** 1. Represents a half-precision floating point value (16 bits) that behaves
|
||||
* nearly conformant to the IEE 754 standard for floating-point computations.
|
||||
*
|
||||
* Not all operators have special implementations, most perform time-consuming
|
||||
* conversions from half to float and back again.
|
||||
* Differences to IEEE 754:
|
||||
* - no difference between qnan and snan
|
||||
* - no traps
|
||||
* - no well-defined rounding mode
|
||||
*/
|
||||
///////////////////////////////////////////////////////////////////////////////////
|
||||
class HalfFloat
|
||||
{
|
||||
friend HalfFloat operator+ (HalfFloat, HalfFloat);
|
||||
friend HalfFloat operator- (HalfFloat, HalfFloat);
|
||||
friend HalfFloat operator* (HalfFloat, HalfFloat);
|
||||
friend HalfFloat operator/ (HalfFloat, HalfFloat);
|
||||
|
||||
public:
|
||||
|
||||
enum { BITS_MANTISSA = 10 };
|
||||
enum { BITS_EXPONENT = 5 };
|
||||
|
||||
enum { MAX_EXPONENT_VALUE = 31 };
|
||||
enum { BIAS = MAX_EXPONENT_VALUE/2 };
|
||||
|
||||
enum { MAX_EXPONENT = BIAS };
|
||||
enum { MIN_EXPONENT = -BIAS };
|
||||
|
||||
enum { MAX_EXPONENT10 = 9 };
|
||||
enum { MIN_EXPONENT10 = -9 };
|
||||
|
||||
public:
|
||||
|
||||
/** Default constructor. Unitialized by default.
|
||||
*/
|
||||
inline HalfFloat() {}
|
||||
|
||||
/** Construction from an existing half
|
||||
*/
|
||||
inline HalfFloat(const HalfFloat& other)
|
||||
: bits(other.GetBits())
|
||||
{}
|
||||
|
||||
/** Construction from existing values for mantissa, sign
|
||||
* and exponent. No validation is performed.
|
||||
* @note The exponent is unsigned and biased by #BIAS
|
||||
*/
|
||||
inline HalfFloat(uint16_t _m,uint16_t _e,uint16_t _s);
|
||||
|
||||
|
||||
/** Construction from a single-precision float
|
||||
*/
|
||||
inline HalfFloat(float other);
|
||||
|
||||
/** Conversion operator to convert from half to float
|
||||
*/
|
||||
inline operator float() const;
|
||||
|
||||
/** Assignment operator to assign another half to
|
||||
* *this* object.
|
||||
*/
|
||||
inline HalfFloat& operator= (HalfFloat other);
|
||||
inline HalfFloat& operator= (float other);
|
||||
|
||||
/** Comparison operators
|
||||
*/
|
||||
inline bool operator== (HalfFloat other) const;
|
||||
inline bool operator!= (HalfFloat other) const;
|
||||
|
||||
|
||||
/** Relational comparison operators
|
||||
*/
|
||||
inline bool operator< (HalfFloat other) const;
|
||||
inline bool operator> (HalfFloat other) const;
|
||||
inline bool operator<= (HalfFloat other) const;
|
||||
inline bool operator>= (HalfFloat other) const;
|
||||
|
||||
inline bool operator< (float other) const;
|
||||
inline bool operator> (float other) const;
|
||||
inline bool operator<= (float other) const;
|
||||
inline bool operator>= (float other) const;
|
||||
|
||||
|
||||
/** Combined assignment operators
|
||||
*/
|
||||
inline HalfFloat& operator += (HalfFloat other);
|
||||
inline HalfFloat& operator -= (HalfFloat other);
|
||||
inline HalfFloat& operator *= (HalfFloat other);
|
||||
inline HalfFloat& operator /= (HalfFloat other);
|
||||
|
||||
inline HalfFloat& operator += (float other);
|
||||
inline HalfFloat& operator -= (float other);
|
||||
inline HalfFloat& operator *= (float other);
|
||||
inline HalfFloat& operator /= (float other);
|
||||
|
||||
/** Post and prefix increment operators
|
||||
*/
|
||||
inline HalfFloat& operator++();
|
||||
inline HalfFloat operator++(int);
|
||||
|
||||
/** Post and prefix decrement operators
|
||||
*/
|
||||
inline HalfFloat& operator--();
|
||||
inline HalfFloat operator--(int);
|
||||
|
||||
/** Unary minus operator
|
||||
*/
|
||||
inline HalfFloat operator-() const;
|
||||
|
||||
|
||||
/** Provides direct access to the bits of a half float
|
||||
*/
|
||||
inline uint16_t GetBits() const;
|
||||
inline uint16_t& GetBits();
|
||||
|
||||
|
||||
/** Classification of floating-point types
|
||||
*/
|
||||
inline bool IsNaN() const;
|
||||
inline bool IsInfinity() const;
|
||||
inline bool IsDenorm() const;
|
||||
|
||||
/** Returns the sign of the floating-point value -
|
||||
* true stands for positive.
|
||||
*/
|
||||
inline bool GetSign() const;
|
||||
|
||||
public:
|
||||
|
||||
union
|
||||
{
|
||||
uint16_t bits; // All bits
|
||||
struct
|
||||
{
|
||||
uint16_t Frac : 10; // mantissa
|
||||
uint16_t Exp : 5; // exponent
|
||||
uint16_t Sign : 1; // sign
|
||||
} IEEE;
|
||||
};
|
||||
|
||||
|
||||
union IEEESingle
|
||||
{
|
||||
float Float;
|
||||
struct
|
||||
{
|
||||
uint32_t Frac : 23;
|
||||
uint32_t Exp : 8;
|
||||
uint32_t Sign : 1;
|
||||
} IEEE;
|
||||
};
|
||||
};
|
||||
|
||||
/** 2. Binary operations
|
||||
*/
|
||||
inline HalfFloat operator+ (HalfFloat one, HalfFloat two);
|
||||
inline HalfFloat operator- (HalfFloat one, HalfFloat two);
|
||||
inline HalfFloat operator* (HalfFloat one, HalfFloat two);
|
||||
inline HalfFloat operator/ (HalfFloat one, HalfFloat two);
|
||||
|
||||
inline float operator+ (HalfFloat one, float two);
|
||||
inline float operator- (HalfFloat one, float two);
|
||||
inline float operator* (HalfFloat one, float two);
|
||||
inline float operator/ (HalfFloat one, float two);
|
||||
|
||||
inline float operator+ (float one, HalfFloat two);
|
||||
inline float operator- (float one, HalfFloat two);
|
||||
inline float operator* (float one, HalfFloat two);
|
||||
inline float operator/ (float one, HalfFloat two);
|
||||
|
||||
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////
|
||||
/** 3. Specialization of std::numeric_limits for type half.
|
||||
*/
|
||||
///////////////////////////////////////////////////////////////////////////////////
|
||||
namespace std {
|
||||
template <>
|
||||
class numeric_limits<HalfFloat> {
|
||||
|
||||
public:
|
||||
|
||||
// General -- meaningful for all specializations.
|
||||
|
||||
static const bool is_specialized = true;
|
||||
static HalfFloat min ()
|
||||
{return HalfFloat(0,1,0);}
|
||||
static HalfFloat max ()
|
||||
{return HalfFloat((uint16_t)~0,HalfFloat::MAX_EXPONENT_VALUE-1,0);}
|
||||
static const int radix = 2;
|
||||
static const int digits = 10; // conservative assumption
|
||||
static const int digits10 = 2; // conservative assumption
|
||||
static const bool is_signed = true;
|
||||
static const bool is_integer = true;
|
||||
static const bool is_exact = false;
|
||||
static const bool traps = false;
|
||||
static const bool is_modulo = false;
|
||||
static const bool is_bounded = true;
|
||||
|
||||
static const HalfFloat lowest() {
|
||||
return HalfFloat((uint16_t)~0,HalfFloat::MAX_EXPONENT_VALUE-1,(uint16_t)~0);
|
||||
}
|
||||
|
||||
// Floating point specific.
|
||||
|
||||
static HalfFloat epsilon ()
|
||||
{return HalfFloat(0.00097656f);} // from OpenEXR, needs to be confirmed
|
||||
static HalfFloat round_error ()
|
||||
{return HalfFloat(0.00097656f/2);}
|
||||
static const int min_exponent10 = HalfFloat::MIN_EXPONENT10;
|
||||
static const int max_exponent10 = HalfFloat::MAX_EXPONENT10;
|
||||
static const int min_exponent = HalfFloat::MIN_EXPONENT;
|
||||
static const int max_exponent = HalfFloat::MAX_EXPONENT;
|
||||
|
||||
static const bool has_infinity = true;
|
||||
static const bool has_quiet_NaN = true;
|
||||
static const bool has_signaling_NaN = true;
|
||||
static const bool is_iec559 = false;
|
||||
static const bool has_denorm = denorm_present;
|
||||
static const bool tinyness_before = false;
|
||||
static const float_round_style round_style = round_to_nearest;
|
||||
|
||||
static HalfFloat denorm_min ()
|
||||
{return HalfFloat(1,0,1);}
|
||||
static HalfFloat infinity ()
|
||||
{return HalfFloat(0,HalfFloat::MAX_EXPONENT_VALUE,0);}
|
||||
static HalfFloat quiet_NaN ()
|
||||
{return HalfFloat(1,HalfFloat::MAX_EXPONENT_VALUE,0);}
|
||||
static HalfFloat signaling_NaN ()
|
||||
{return HalfFloat(1,HalfFloat::MAX_EXPONENT_VALUE,0);}
|
||||
};
|
||||
} // end namespace std
|
||||
|
||||
|
||||
#include "umHalf.inl"
|
||||
|
||||
#ifndef UM_HALF_NO_TYPEDEFS
|
||||
typedef HalfFloat float16;
|
||||
#endif
|
||||
|
||||
#endif // !! UM_HALF_H_INCLUDED
|
|
@ -0,0 +1,495 @@
|
|||
|
||||
///////////////////////////////////////////////////////////////////////////////////
|
||||
/*
|
||||
Copyright (c) 2006-2008, Alexander Gessler
|
||||
|
||||
All rights reserved.
|
||||
|
||||
Redistribution and use of this software in source and binary forms,
|
||||
with or without modification, are permitted provided that the following
|
||||
conditions are met:
|
||||
|
||||
* Redistributions of source code must retain the above
|
||||
copyright notice, this list of conditions and the
|
||||
following disclaimer.
|
||||
|
||||
* Redistributions in binary form must reproduce the above
|
||||
copyright notice, this list of conditions and the
|
||||
following disclaimer in the documentation and/or other
|
||||
materials provided with the distribution.
|
||||
|
||||
* Neither the name of the ASSIMP team, nor the names of its
|
||||
contributors may be used to endorse or promote products
|
||||
derived from this software without specific prior
|
||||
written permission of the ASSIMP Development Team.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
///////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#ifndef UM_HALF_INL_INCLUDED
|
||||
#define UM_HALF_INL_INCLUDED
|
||||
|
||||
#ifdef _MSC_VER
|
||||
#include <intrin.h>
|
||||
#pragma intrinsic(_BitScanReverse)
|
||||
#endif
|
||||
|
||||
// ------------------------------------------------------------------------------------------------
|
||||
inline HalfFloat::HalfFloat(float other)
|
||||
{
|
||||
IEEESingle f;
|
||||
f.Float = other;
|
||||
|
||||
IEEE.Sign = f.IEEE.Sign;
|
||||
|
||||
if ( !f.IEEE.Exp)
|
||||
{
|
||||
IEEE.Frac = 0;
|
||||
IEEE.Exp = 0;
|
||||
}
|
||||
else if (f.IEEE.Exp==0xff)
|
||||
{
|
||||
// NaN or INF
|
||||
IEEE.Frac = (f.IEEE.Frac!=0) ? 1 : 0;
|
||||
IEEE.Exp = 31;
|
||||
}
|
||||
else
|
||||
{
|
||||
// regular number
|
||||
int new_exp = f.IEEE.Exp-127;
|
||||
|
||||
if (new_exp<-24)
|
||||
{ // this maps to 0
|
||||
IEEE.Frac = 0;
|
||||
IEEE.Exp = 0;
|
||||
}
|
||||
|
||||
else if (new_exp<-14)
|
||||
{
|
||||
// this maps to a denorm
|
||||
IEEE.Exp = 0;
|
||||
unsigned int exp_val = (unsigned int) (-14 - new_exp); // 2^-exp_val
|
||||
switch (exp_val)
|
||||
{
|
||||
case 0:
|
||||
IEEE.Frac = 0;
|
||||
break;
|
||||
case 1: IEEE.Frac = 512 + (f.IEEE.Frac>>14); break;
|
||||
case 2: IEEE.Frac = 256 + (f.IEEE.Frac>>15); break;
|
||||
case 3: IEEE.Frac = 128 + (f.IEEE.Frac>>16); break;
|
||||
case 4: IEEE.Frac = 64 + (f.IEEE.Frac>>17); break;
|
||||
case 5: IEEE.Frac = 32 + (f.IEEE.Frac>>18); break;
|
||||
case 6: IEEE.Frac = 16 + (f.IEEE.Frac>>19); break;
|
||||
case 7: IEEE.Frac = 8 + (f.IEEE.Frac>>20); break;
|
||||
case 8: IEEE.Frac = 4 + (f.IEEE.Frac>>21); break;
|
||||
case 9: IEEE.Frac = 2 + (f.IEEE.Frac>>22); break;
|
||||
case 10: IEEE.Frac = 1; break;
|
||||
}
|
||||
}
|
||||
else if (new_exp>15)
|
||||
{ // map this value to infinity
|
||||
IEEE.Frac = 0;
|
||||
IEEE.Exp = 31;
|
||||
}
|
||||
else
|
||||
{
|
||||
IEEE.Exp = new_exp+15;
|
||||
IEEE.Frac = (f.IEEE.Frac >> 13);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
inline HalfFloat::HalfFloat(uint16_t _m,uint16_t _e,uint16_t _s)
|
||||
{
|
||||
IEEE.Frac = _m;
|
||||
IEEE.Exp = _e;
|
||||
IEEE.Sign = _s;
|
||||
}
|
||||
// ------------------------------------------------------------------------------------------------
|
||||
HalfFloat::operator float() const
|
||||
{
|
||||
IEEESingle sng;
|
||||
sng.IEEE.Sign = IEEE.Sign;
|
||||
|
||||
if (!IEEE.Exp)
|
||||
{
|
||||
if (!IEEE.Frac)
|
||||
{
|
||||
sng.IEEE.Frac=0;
|
||||
sng.IEEE.Exp=0;
|
||||
}
|
||||
else
|
||||
{
|
||||
const float half_denorm = (1.0f/16384.0f);
|
||||
float mantissa = ((float)(IEEE.Frac)) / 1024.0f;
|
||||
float sgn = (IEEE.Sign)? -1.0f :1.0f;
|
||||
sng.Float = sgn*mantissa*half_denorm;
|
||||
}
|
||||
}
|
||||
else if (31 == IEEE.Exp)
|
||||
{
|
||||
sng.IEEE.Exp = 0xff;
|
||||
sng.IEEE.Frac = (IEEE.Frac!=0) ? 1 : 0;
|
||||
}
|
||||
else
|
||||
{
|
||||
sng.IEEE.Exp = IEEE.Exp+112;
|
||||
sng.IEEE.Frac = (IEEE.Frac << 13);
|
||||
}
|
||||
return sng.Float;
|
||||
}
|
||||
|
||||
// ------------------------------------------------------------------------------------------------
|
||||
inline bool HalfFloat::IsNaN() const
|
||||
{
|
||||
return IEEE.Frac != 0 && IEEE.Exp == MAX_EXPONENT_VALUE;
|
||||
}
|
||||
// ------------------------------------------------------------------------------------------------
|
||||
inline bool HalfFloat::IsInfinity() const
|
||||
{
|
||||
return IEEE.Frac == 0 && IEEE.Exp == MAX_EXPONENT_VALUE;
|
||||
}
|
||||
// ------------------------------------------------------------------------------------------------
|
||||
inline bool HalfFloat::IsDenorm() const
|
||||
{
|
||||
return IEEE.Exp == 0;
|
||||
}
|
||||
// ------------------------------------------------------------------------------------------------
|
||||
inline bool HalfFloat::GetSign() const
|
||||
{
|
||||
return IEEE.Sign == 0;
|
||||
}
|
||||
// ------------------------------------------------------------------------------------------------
|
||||
inline HalfFloat& HalfFloat::operator= (HalfFloat other)
|
||||
{
|
||||
bits = other.GetBits();
|
||||
return *this;
|
||||
}
|
||||
// ------------------------------------------------------------------------------------------------
|
||||
inline HalfFloat& HalfFloat::operator= (float other)
|
||||
{
|
||||
*this = (HalfFloat)other;
|
||||
return *this;
|
||||
}
|
||||
// ------------------------------------------------------------------------------------------------
|
||||
inline bool HalfFloat::operator== (HalfFloat other) const
|
||||
{
|
||||
// +0 and -0 are considered to be equal
|
||||
if ((bits << 1u) == 0 && (other.bits << 1u) == 0) return true;
|
||||
|
||||
return bits == other.bits && !this->IsNaN();
|
||||
}
|
||||
// ------------------------------------------------------------------------------------------------
|
||||
inline bool HalfFloat::operator!= (HalfFloat other) const
|
||||
{
|
||||
// +0 and -0 are considered to be equal
|
||||
if ((bits << 1u) == 0 && (other.bits << 1u) == 0) return false;
|
||||
|
||||
return bits != other.bits || this->IsNaN();
|
||||
}
|
||||
// ------------------------------------------------------------------------------------------------
|
||||
inline bool HalfFloat::operator< (HalfFloat other) const
|
||||
{
|
||||
// NaN comparisons are always false
|
||||
if (this->IsNaN() || other.IsNaN())
|
||||
return false;
|
||||
|
||||
// this works since the segment oder is s,e,m.
|
||||
return (int16_t)this->bits < (int16_t)other.GetBits();
|
||||
}
|
||||
// ------------------------------------------------------------------------------------------------
|
||||
inline bool HalfFloat::operator> (HalfFloat other) const
|
||||
{
|
||||
// NaN comparisons are always false
|
||||
if (this->IsNaN() || other.IsNaN())
|
||||
return false;
|
||||
|
||||
// this works since the segment oder is s,e,m.
|
||||
return (int16_t)this->bits > (int16_t)other.GetBits();
|
||||
}
|
||||
// ------------------------------------------------------------------------------------------------
|
||||
inline bool HalfFloat::operator<= (HalfFloat other) const
|
||||
{
|
||||
return !(*this > other);
|
||||
}
|
||||
// ------------------------------------------------------------------------------------------------
|
||||
inline bool HalfFloat::operator>= (HalfFloat other) const
|
||||
{
|
||||
return !(*this < other);
|
||||
}
|
||||
// ------------------------------------------------------------------------------------------------
|
||||
inline HalfFloat& HalfFloat::operator += (HalfFloat other)
|
||||
{
|
||||
*this = (*this) + other;
|
||||
return *this;
|
||||
}
|
||||
// ------------------------------------------------------------------------------------------------
|
||||
inline HalfFloat& HalfFloat::operator -= (HalfFloat other)
|
||||
{
|
||||
*this = (*this) - other;
|
||||
return *this;
|
||||
}
|
||||
// ------------------------------------------------------------------------------------------------
|
||||
inline HalfFloat& HalfFloat::operator *= (HalfFloat other)
|
||||
{
|
||||
*this = (float)(*this) * (float)other;
|
||||
return *this;
|
||||
}
|
||||
// ------------------------------------------------------------------------------------------------
|
||||
inline HalfFloat& HalfFloat::operator /= (HalfFloat other)
|
||||
{
|
||||
*this = (float)(*this) / (float)other;
|
||||
return *this;
|
||||
}
|
||||
// ------------------------------------------------------------------------------------------------
|
||||
inline HalfFloat& HalfFloat::operator += (float other)
|
||||
{
|
||||
*this = (*this) + (HalfFloat)other;
|
||||
return *this;
|
||||
}
|
||||
// ------------------------------------------------------------------------------------------------
|
||||
inline HalfFloat& HalfFloat::operator -= (float other)
|
||||
{
|
||||
*this = (*this) - (HalfFloat)other;
|
||||
return *this;
|
||||
}
|
||||
// ------------------------------------------------------------------------------------------------
|
||||
inline HalfFloat& HalfFloat::operator *= (float other)
|
||||
{
|
||||
*this = (float)(*this) * other;
|
||||
return *this;
|
||||
}
|
||||
// ------------------------------------------------------------------------------------------------
|
||||
inline HalfFloat& HalfFloat::operator /= (float other)
|
||||
{
|
||||
*this = (float)(*this) / other;
|
||||
return *this;
|
||||
}
|
||||
// ------------------------------------------------------------------------------------------------
|
||||
inline HalfFloat& HalfFloat::operator++()
|
||||
{
|
||||
// setting the exponent to bias means using 0 as exponent - thus we
|
||||
// can set the mantissa to any value we like, we'll always get 1.0
|
||||
return this->operator+=(HalfFloat(0,BIAS,0));
|
||||
}
|
||||
// ------------------------------------------------------------------------------------------------
|
||||
inline HalfFloat HalfFloat::operator++(int)
|
||||
{
|
||||
HalfFloat f = *this;
|
||||
this->operator+=(HalfFloat(0,BIAS,0));
|
||||
return f;
|
||||
}
|
||||
// ------------------------------------------------------------------------------------------------
|
||||
inline HalfFloat& HalfFloat::operator--()
|
||||
{
|
||||
return this->operator-=(HalfFloat(0,BIAS,0));
|
||||
}
|
||||
// ------------------------------------------------------------------------------------------------
|
||||
inline HalfFloat HalfFloat::operator--(int)
|
||||
{
|
||||
HalfFloat f = *this;
|
||||
this->operator-=(HalfFloat(0,BIAS,0));
|
||||
return f;
|
||||
}
|
||||
// ------------------------------------------------------------------------------------------------
|
||||
inline HalfFloat HalfFloat::operator-() const
|
||||
{
|
||||
return HalfFloat(IEEE.Frac,IEEE.Exp,~IEEE.Sign);
|
||||
}
|
||||
// ------------------------------------------------------------------------------------------------
|
||||
inline uint16_t HalfFloat::GetBits() const
|
||||
{
|
||||
return bits;
|
||||
}
|
||||
// ------------------------------------------------------------------------------------------------
|
||||
inline uint16_t& HalfFloat::GetBits()
|
||||
{
|
||||
return bits;
|
||||
}
|
||||
// ------------------------------------------------------------------------------------------------
|
||||
inline HalfFloat operator+ (HalfFloat one, HalfFloat two)
|
||||
{
|
||||
#if (!defined HALFFLOAT_NO_CUSTOM_IMPLEMENTATIONS)
|
||||
|
||||
if (one.IEEE.Exp == HalfFloat::MAX_EXPONENT_VALUE)
|
||||
{
|
||||
// if one of the components is NaN the result becomes NaN, too.
|
||||
if (0 != one.IEEE.Frac || two.IsNaN())
|
||||
return HalfFloat(1,HalfFloat::MAX_EXPONENT_VALUE,0);
|
||||
|
||||
// otherwise this must be infinity
|
||||
return HalfFloat(0,HalfFloat::MAX_EXPONENT_VALUE,one.IEEE.Sign | two.IEEE.Sign);
|
||||
}
|
||||
else if (two.IEEE.Exp == HalfFloat::MAX_EXPONENT_VALUE)
|
||||
{
|
||||
if (one.IsNaN() || 0 != two.IEEE.Frac)
|
||||
return HalfFloat(1,HalfFloat::MAX_EXPONENT_VALUE,0);
|
||||
|
||||
return HalfFloat(0,HalfFloat::MAX_EXPONENT_VALUE,one.IEEE.Sign | two.IEEE.Sign);
|
||||
}
|
||||
|
||||
HalfFloat out;
|
||||
long m1,m2,temp;
|
||||
|
||||
// compute the difference between the two exponents. shifts with negative
|
||||
// numbers are undefined, thus we need two code paths
|
||||
register int expDiff = one.IEEE.Exp - two.IEEE.Exp;
|
||||
|
||||
if (0 == expDiff)
|
||||
{
|
||||
// the exponents are equal, thus we must just add the hidden bit
|
||||
temp = two.IEEE.Exp;
|
||||
|
||||
if (0 == one.IEEE.Exp)m1 = one.IEEE.Frac;
|
||||
else m1 = (int)one.IEEE.Frac | ( 1 << HalfFloat::BITS_MANTISSA );
|
||||
|
||||
if (0 == two.IEEE.Exp)m2 = two.IEEE.Frac;
|
||||
else m2 = (int)two.IEEE.Frac | ( 1 << HalfFloat::BITS_MANTISSA );
|
||||
}
|
||||
else
|
||||
{
|
||||
if (expDiff < 0)
|
||||
{
|
||||
expDiff = -expDiff;
|
||||
std::swap(one,two);
|
||||
}
|
||||
|
||||
m1 = (int)one.IEEE.Frac | ( 1 << HalfFloat::BITS_MANTISSA );
|
||||
|
||||
if (0 == two.IEEE.Exp)m2 = two.IEEE.Frac;
|
||||
else m2 = (int)two.IEEE.Frac | ( 1 << HalfFloat::BITS_MANTISSA );
|
||||
|
||||
if (expDiff < ((sizeof(long)<<3)-(HalfFloat::BITS_MANTISSA+1)))
|
||||
{
|
||||
m1 <<= expDiff;
|
||||
temp = two.IEEE.Exp;
|
||||
}
|
||||
else
|
||||
{
|
||||
if (0 != two.IEEE.Exp)
|
||||
{
|
||||
// arithmetic underflow
|
||||
if (expDiff > HalfFloat::BITS_MANTISSA)return HalfFloat(0,0,0);
|
||||
else
|
||||
{
|
||||
m2 >>= expDiff;
|
||||
}
|
||||
}
|
||||
temp = one.IEEE.Exp;
|
||||
}
|
||||
}
|
||||
|
||||
// convert from sign-bit to two's complement representation
|
||||
if (one.IEEE.Sign)m1 = -m1;
|
||||
if (two.IEEE.Sign)m2 = -m2;
|
||||
m1 += m2;
|
||||
if (m1 < 0)
|
||||
{
|
||||
out.IEEE.Sign = 1;
|
||||
m1 = -m1;
|
||||
}
|
||||
else out.IEEE.Sign = 0;
|
||||
|
||||
// and renormalize the result to fit in a half
|
||||
if (0 == m1)return HalfFloat(0,0,0);
|
||||
|
||||
#ifdef _MSC_VER
|
||||
_BitScanReverse((unsigned long*)&m2,m1);
|
||||
#else
|
||||
m2 = __builtin_clz(m1);
|
||||
#endif
|
||||
expDiff = m2 - HalfFloat::BITS_MANTISSA;
|
||||
temp += expDiff;
|
||||
if (expDiff >= HalfFloat::MAX_EXPONENT_VALUE)
|
||||
{
|
||||
// arithmetic overflow. return INF and keep the sign
|
||||
return HalfFloat(0,HalfFloat::MAX_EXPONENT_VALUE,out.IEEE.Sign);
|
||||
}
|
||||
else if (temp <= 0)
|
||||
{
|
||||
// this maps to a denorm
|
||||
m1 <<= (-expDiff-1);
|
||||
temp = 0;
|
||||
}
|
||||
else
|
||||
{
|
||||
// rebuild the normalized representation, take care of the hidden bit
|
||||
if (expDiff < 0)m1 <<= (-expDiff);
|
||||
else m1 >>= expDiff; // m1 >= 0
|
||||
}
|
||||
out.IEEE.Frac = m1;
|
||||
out.IEEE.Exp = temp;
|
||||
return out;
|
||||
|
||||
#else
|
||||
return HalfFloat((float)one + (float)two);
|
||||
#endif
|
||||
}
|
||||
// ------------------------------------------------------------------------------------------------
|
||||
inline HalfFloat operator- (HalfFloat one, HalfFloat two)
|
||||
{
|
||||
return HalfFloat(one + (-two));
|
||||
}
|
||||
// ------------------------------------------------------------------------------------------------
|
||||
inline HalfFloat operator* (HalfFloat one, HalfFloat two)
|
||||
{
|
||||
return HalfFloat((float)one * (float)two);
|
||||
}
|
||||
// ------------------------------------------------------------------------------------------------
|
||||
inline HalfFloat operator/ (HalfFloat one, HalfFloat two)
|
||||
{
|
||||
return HalfFloat((float)one / (float)two);
|
||||
}
|
||||
// ------------------------------------------------------------------------------------------------
|
||||
inline float operator+ (HalfFloat one, float two)
|
||||
{
|
||||
return (float)one + two;
|
||||
}
|
||||
// ------------------------------------------------------------------------------------------------
|
||||
inline float operator- (HalfFloat one, float two)
|
||||
{
|
||||
return (float)one - two;
|
||||
}
|
||||
// ------------------------------------------------------------------------------------------------
|
||||
inline float operator* (HalfFloat one, float two)
|
||||
{
|
||||
return (float)one * two;
|
||||
}
|
||||
// ------------------------------------------------------------------------------------------------
|
||||
inline float operator/ (HalfFloat one, float two)
|
||||
{
|
||||
return (float)one / two;
|
||||
}
|
||||
// ------------------------------------------------------------------------------------------------
|
||||
inline float operator+ (float one, HalfFloat two)
|
||||
{
|
||||
return two + one;
|
||||
}
|
||||
// ------------------------------------------------------------------------------------------------
|
||||
inline float operator- (float one, HalfFloat two)
|
||||
{
|
||||
return two - one;
|
||||
}
|
||||
// ------------------------------------------------------------------------------------------------
|
||||
inline float operator* (float one, HalfFloat two)
|
||||
{
|
||||
return two * one;
|
||||
}
|
||||
// ------------------------------------------------------------------------------------------------
|
||||
inline float operator/ (float one, HalfFloat two)
|
||||
{
|
||||
return two / one;
|
||||
}
|
||||
|
||||
#endif //!! UM_HALF_INL_INCLUDED
|
|
@ -0,0 +1,21 @@
|
|||
MIT License
|
||||
|
||||
Copyright (c) 2018 https://github.com/mandreyel/
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
|
@ -0,0 +1,337 @@
|
|||
# mio
|
||||
An easy to use header-only cross-platform C++11 memory mapping library with an MIT license.
|
||||
|
||||
mio has been created with the goal to be easily includable (i.e. no dependencies) in any C++ project that needs memory mapped file IO without the need to pull in Boost.
|
||||
|
||||
Please feel free to open an issue, I'll try to address any concerns as best I can.
|
||||
|
||||
### Why?
|
||||
Because memory mapping is the best thing since sliced bread!
|
||||
|
||||
More seriously, the primary motivation for writing this library instead of using Boost.Iostreams, was the lack of support for establishing a memory mapping with an already open file handle/descriptor. This is possible with mio.
|
||||
|
||||
Furthermore, Boost.Iostreams' solution requires that the user pick offsets exactly at page boundaries, which is cumbersome and error prone. mio, on the other hand, manages this internally, accepting any offset and finding the nearest page boundary.
|
||||
|
||||
Albeit a minor nitpick, Boost.Iostreams implements memory mapped file IO with a `std::shared_ptr` to provide shared semantics, even if not needed, and the overhead of the heap allocation may be unnecessary and/or unwanted.
|
||||
In mio, there are two classes to cover the two use-cases: one that is move-only (basically a zero-cost abstraction over the system specific mmapping functions), and the other that acts just like its Boost.Iostreams counterpart, with shared semantics.
|
||||
|
||||
### How to create a mapping
|
||||
NOTE: the file must exist before creating a mapping.
|
||||
|
||||
There are three ways to map a file into memory:
|
||||
|
||||
- Using the constructor, which throws a `std::system_error` on failure:
|
||||
```c++
|
||||
mio::mmap_source mmap(path, offset, size_to_map);
|
||||
```
|
||||
or you can omit the `offset` and `size_to_map` arguments, in which case the
|
||||
entire file is mapped:
|
||||
```c++
|
||||
mio::mmap_source mmap(path);
|
||||
```
|
||||
|
||||
- Using the factory function:
|
||||
```c++
|
||||
std::error_code error;
|
||||
mio::mmap_source mmap = mio::make_mmap_source(path, offset, size_to_map, error);
|
||||
```
|
||||
or:
|
||||
```c++
|
||||
mio::mmap_source mmap = mio::make_mmap_source(path, error);
|
||||
```
|
||||
|
||||
- Using the `map` member function:
|
||||
```c++
|
||||
std::error_code error;
|
||||
mio::mmap_source mmap;
|
||||
mmap.map(path, offset, size_to_map, error);
|
||||
```
|
||||
or:
|
||||
```c++
|
||||
mmap.map(path, error);
|
||||
```
|
||||
**NOTE:** The constructors **require** exceptions to be enabled. If you prefer
|
||||
to build your projects with `-fno-exceptions`, you can still use the other ways.
|
||||
|
||||
Moreover, in each case, you can provide either some string type for the file's path, or you can use an existing, valid file handle.
|
||||
```c++
|
||||
#include <sys/types.h>
|
||||
#include <sys/stat.h>
|
||||
#include <fcntl.h>
|
||||
#include <mio/mmap.hpp>
|
||||
#include <algorithm>
|
||||
|
||||
int main()
|
||||
{
|
||||
// NOTE: error handling omitted for brevity.
|
||||
const int fd = open("file.txt", O_RDONLY);
|
||||
mio::mmap_source mmap(fd, 0, mio::map_entire_file);
|
||||
// ...
|
||||
}
|
||||
```
|
||||
However, mio does not check whether the provided file descriptor has the same access permissions as the desired mapping, so the mapping may fail. Such errors are reported via the `std::error_code` out parameter that is passed to the mapping function.
|
||||
|
||||
**WINDOWS USERS**: This library *does* support the use of wide character types
|
||||
for functions where character strings are expected (e.g. path parameters).
|
||||
|
||||
### Example
|
||||
|
||||
```c++
|
||||
#include <mio/mmap.hpp>
|
||||
#include <system_error> // for std::error_code
|
||||
#include <cstdio> // for std::printf
|
||||
#include <cassert>
|
||||
#include <algorithm>
|
||||
#include <fstream>
|
||||
|
||||
int handle_error(const std::error_code& error);
|
||||
void allocate_file(const std::string& path, const int size);
|
||||
|
||||
int main()
|
||||
{
|
||||
const auto path = "file.txt";
|
||||
|
||||
// NOTE: mio does *not* create the file for you if it doesn't exist! You
|
||||
// must ensure that the file exists before establishing a mapping. It
|
||||
// must also be non-empty. So for illustrative purposes the file is
|
||||
// created now.
|
||||
allocate_file(path, 155);
|
||||
|
||||
// Read-write memory map the whole file by using `map_entire_file` where the
|
||||
// length of the mapping is otherwise expected, with the factory method.
|
||||
std::error_code error;
|
||||
mio::mmap_sink rw_mmap = mio::make_mmap_sink(
|
||||
path, 0, mio::map_entire_file, error);
|
||||
if (error) { return handle_error(error); }
|
||||
|
||||
// You can use any iterator based function.
|
||||
std::fill(rw_mmap.begin(), rw_mmap.end(), 'a');
|
||||
|
||||
// Or manually iterate through the mapped region just as if it were any other
|
||||
// container, and change each byte's value (since this is a read-write mapping).
|
||||
for (auto& b : rw_mmap) {
|
||||
b += 10;
|
||||
}
|
||||
|
||||
// Or just change one value with the subscript operator.
|
||||
const int answer_index = rw_mmap.size() / 2;
|
||||
rw_mmap[answer_index] = 42;
|
||||
|
||||
// Don't forget to flush changes to disk before unmapping. However, if
|
||||
// `rw_mmap` were to go out of scope at this point, the destructor would also
|
||||
// automatically invoke `sync` before `unmap`.
|
||||
rw_mmap.sync(error);
|
||||
if (error) { return handle_error(error); }
|
||||
|
||||
// We can then remove the mapping, after which rw_mmap will be in a default
|
||||
// constructed state, i.e. this and the above call to `sync` have the same
|
||||
// effect as if the destructor had been invoked.
|
||||
rw_mmap.unmap();
|
||||
|
||||
// Now create the same mapping, but in read-only mode. Note that calling the
|
||||
// overload without the offset and file length parameters maps the entire
|
||||
// file.
|
||||
mio::mmap_source ro_mmap;
|
||||
ro_mmap.map(path, error);
|
||||
if (error) { return handle_error(error); }
|
||||
|
||||
const int the_answer_to_everything = ro_mmap[answer_index];
|
||||
assert(the_answer_to_everything == 42);
|
||||
}
|
||||
|
||||
int handle_error(const std::error_code& error)
|
||||
{
|
||||
const auto& errmsg = error.message();
|
||||
std::printf("error mapping file: %s, exiting...\n", errmsg.c_str());
|
||||
return error.value();
|
||||
}
|
||||
|
||||
void allocate_file(const std::string& path, const int size)
|
||||
{
|
||||
std::ofstream file(path);
|
||||
std::string s(size, '0');
|
||||
file << s;
|
||||
}
|
||||
```
|
||||
|
||||
`mio::basic_mmap` is move-only, but if multiple copies to the same mapping are needed, use `mio::basic_shared_mmap` which has `std::shared_ptr` semantics and has the same interface as `mio::basic_mmap`.
|
||||
```c++
|
||||
#include <mio/shared_mmap.hpp>
|
||||
|
||||
mio::shared_mmap_source shared_mmap1("path", offset, size_to_map);
|
||||
mio::shared_mmap_source shared_mmap2(std::move(mmap1)); // or use operator=
|
||||
mio::shared_mmap_source shared_mmap3(std::make_shared<mio::mmap_source>(mmap1)); // or use operator=
|
||||
mio::shared_mmap_source shared_mmap4;
|
||||
shared_mmap4.map("path", offset, size_to_map, error);
|
||||
```
|
||||
|
||||
It's possible to define the type of a byte (which has to be the same width as `char`), though aliases for the most common ones are provided by default:
|
||||
```c++
|
||||
using mmap_source = basic_mmap_source<char>;
|
||||
using ummap_source = basic_mmap_source<unsigned char>;
|
||||
|
||||
using mmap_sink = basic_mmap_sink<char>;
|
||||
using ummap_sink = basic_mmap_sink<unsigned char>;
|
||||
```
|
||||
But it may be useful to define your own types, say when using the new `std::byte` type in C++17:
|
||||
```c++
|
||||
using mmap_source = mio::basic_mmap_source<std::byte>;
|
||||
using mmap_sink = mio::basic_mmap_sink<std::byte>;
|
||||
```
|
||||
|
||||
Though generally not needed, since mio maps users requested offsets to page boundaries, you can query the underlying system's page allocation granularity by invoking `mio::page_size()`, which is located in `mio/page.hpp`.
|
||||
|
||||
### Single Header File
|
||||
Mio can be added to your project as a single header file simply by including `\single_include\mio\mio.hpp`. Single header files can be regenerated at any time by running the `amalgamate.py` script within `\third_party`.
|
||||
```
|
||||
python amalgamate.py -c config.json -s ../include
|
||||
```
|
||||
|
||||
## CMake
|
||||
As a header-only library, mio has no compiled components. Nevertheless, a [CMake](https://cmake.org/overview/) build system is provided to allow easy testing, installation, and subproject composition on many platforms and operating systems.
|
||||
|
||||
### Testing
|
||||
Mio is distributed with a small suite of tests and examples.
|
||||
When mio is configured as the highest level CMake project, this suite of executables is built by default.
|
||||
Mio's test executables are integrated with the CMake test driver program, [CTest](https://cmake.org/cmake/help/latest/manual/ctest.1.html).
|
||||
|
||||
CMake supports a number of backends for compilation and linking.
|
||||
|
||||
To use a static configuration build tool, such as GNU Make or Ninja:
|
||||
|
||||
```sh
|
||||
cd <mio source directory>
|
||||
mkdir build
|
||||
cd build
|
||||
|
||||
# Configure the build
|
||||
cmake -D CMAKE_BUILD_TYPE=<Debug | Release> \
|
||||
-G <"Unix Makefiles" | "Ninja"> ..
|
||||
|
||||
# build the tests
|
||||
< make | ninja | cmake --build . >
|
||||
|
||||
# run the tests
|
||||
< make test | ninja test | cmake --build . --target test | ctest >
|
||||
```
|
||||
|
||||
To use a dynamic configuration build tool, such as Visual Studio or Xcode:
|
||||
|
||||
```sh
|
||||
cd <mio source directory>
|
||||
mkdir build
|
||||
cd build
|
||||
|
||||
# Configure the build
|
||||
cmake -G <"Visual Studio 14 2015 Win64" | "Xcode"> ..
|
||||
|
||||
# build the tests
|
||||
cmake --build . --config <Debug | Release>
|
||||
|
||||
# run the tests via ctest...
|
||||
ctest --build-config <Debug | Release>
|
||||
|
||||
# ... or via CMake build tool mode...
|
||||
cmake --build . --config <Debug | Release> --target test
|
||||
```
|
||||
|
||||
Of course the **build** and **test** steps can also be executed via the **all** and **test** targets, respectively, from within the IDE after opening the project file generated during the configuration step.
|
||||
|
||||
Mio's testing is also configured to operate as a client to the [CDash](https://www.cdash.org/) software quality dashboard application. Please see the [Kitware documentation](https://cmake.org/cmake/help/latest/manual/ctest.1.html#dashboard-client) for more information on this mode of operation.
|
||||
|
||||
### Installation
|
||||
|
||||
Mio's build system provides an installation target and support for downstream consumption via CMake's [`find_package`](https://cmake.org/cmake/help/v3.0/command/find_package.html) intrinsic function.
|
||||
CMake allows installation to an arbitrary location, which may be specified by defining `CMAKE_INSTALL_PREFIX` at configure time.
|
||||
In the absense of a user specification, CMake will install mio to conventional location based on the platform operating system.
|
||||
|
||||
To use a static configuration build tool, such as GNU Make or Ninja:
|
||||
|
||||
```sh
|
||||
cd <mio source directory>
|
||||
mkdir build
|
||||
cd build
|
||||
|
||||
# Configure the build
|
||||
cmake [-D CMAKE_INSTALL_PREFIX="path/to/installation"] \
|
||||
[-D BUILD_TESTING=False] \
|
||||
-D CMAKE_BUILD_TYPE=Release \
|
||||
-G <"Unix Makefiles" | "Ninja"> ..
|
||||
|
||||
# install mio
|
||||
<make install | ninja install | cmake --build . --target install>
|
||||
```
|
||||
|
||||
To use a dynamic configuration build tool, such as Visual Studio or Xcode:
|
||||
|
||||
```sh
|
||||
cd <mio source directory>
|
||||
mkdir build
|
||||
cd build
|
||||
|
||||
# Configure the project
|
||||
cmake [-D CMAKE_INSTALL_PREFIX="path/to/installation"] \
|
||||
[-D BUILD_TESTING=False] \
|
||||
-G <"Visual Studio 14 2015 Win64" | "Xcode"> ..
|
||||
|
||||
# install mio
|
||||
cmake --build . --config Release --target install
|
||||
```
|
||||
|
||||
Note that the last command of the installation sequence may require administrator privileges (e.g. `sudo`) if the installation root directory lies outside your home directory.
|
||||
|
||||
This installation
|
||||
+ copies the mio header files to the `include/mio` subdirectory of the installation root
|
||||
+ generates and copies several CMake configuration files to the `share/cmake/mio` subdirectory of the installation root
|
||||
|
||||
This latter step allows downstream CMake projects to consume mio via `find_package`, e.g.
|
||||
|
||||
```cmake
|
||||
find_package( mio REQUIRED )
|
||||
target_link_libraries( MyTarget PUBLIC mio::mio )
|
||||
```
|
||||
|
||||
**WINDOWS USERS**: The `mio::mio` target `#define`s `WIN32_LEAN_AND_MEAN` and `NOMINMAX`. The former ensures the imported surface area of the Win API is minimal, and the latter disables Windows' `min` and `max` macros so they don't intefere with `std::min` and `std::max`. Because *mio* is a header only library, these defintions will leak into downstream CMake builds. If their presence is causing problems with your build then you can use the alternative `mio::mio_full_winapi` target, which adds none of these defintions.
|
||||
|
||||
If mio was installed to a non-conventional location, it may be necessary for downstream projects to specify the mio installation root directory via either
|
||||
|
||||
+ the `CMAKE_PREFIX_PATH` configuration option,
|
||||
+ the `CMAKE_PREFIX_PATH` environment variable, or
|
||||
+ `mio_DIR` environment variable.
|
||||
|
||||
Please see the [Kitware documentation](https://cmake.org/cmake/help/v3.0/command/find_package.html) for more information.
|
||||
|
||||
In addition, mio supports packaged relocatable installations via [CPack](https://cmake.org/cmake/help/latest/manual/cpack.1.html).
|
||||
Following configuration, from the build directory, invoke cpack as follows to generate a packaged installation:
|
||||
|
||||
```sh
|
||||
cpack -G <generator name> -C Release
|
||||
```
|
||||
|
||||
The list of supported generators varies from platform to platform. See the output of `cpack --help` for a complete list of supported generators on your platform.
|
||||
|
||||
### Subproject Composition
|
||||
To use mio as a subproject, copy the mio repository to your project's dependencies/externals folder.
|
||||
If your project is version controlled using git, a git submodule or git subtree can be used to syncronize with the updstream repository.
|
||||
The [use](https://services.github.com/on-demand/downloads/submodule-vs-subtree-cheat-sheet/) and [relative advantages](https://andrey.nering.com.br/2016/git-submodules-vs-subtrees/) of these git facilities is beyond the scope of this document, but in brief, each may be established as follows:
|
||||
|
||||
```sh
|
||||
# via git submodule
|
||||
cd <my project's dependencies directory>
|
||||
git submodule add -b master https://github.com/mandreyel/mio.git
|
||||
|
||||
# via git subtree
|
||||
cd <my project's root directory>
|
||||
git subtree add --prefix <path/to/dependencies>/mio \
|
||||
https://github.com/mandreyel/mio.git master --squash
|
||||
```
|
||||
|
||||
Given a mio subdirectory in a project, simply add the following lines to your project's to add mio include directories to your target's include path.
|
||||
|
||||
```cmake
|
||||
add_subdirectory( path/to/mio/ )
|
||||
target_link_libraries( MyTarget PUBLIC <mio::mio | mio> )
|
||||
```
|
||||
|
||||
Note that, as a subproject, mio's tests and examples will not be built and CPack integration is deferred to the host project.
|
||||
|
Разница между файлами не показана из-за своего большого размера
Загрузить разницу
|
@ -1 +1 @@
|
|||
Subproject commit d6297d250433715c283d17f1969cfcb50d2b6531
|
||||
Subproject commit b56650c7f59b8cd40d18809784a6d6be38ef8acb
|
|
@ -31,7 +31,7 @@
|
|||
#include "../include/path.hpp"
|
||||
#include "../include/errors.hpp"
|
||||
|
||||
#if defined(__unix__)
|
||||
#if defined(__unix__) || defined(__APPLE__)
|
||||
#include <sys/types.h>
|
||||
#include <dirent.h>
|
||||
#include <errno.h>
|
||||
|
@ -178,7 +178,7 @@ entry_iterator& entry_iterator::operator++(int)
|
|||
/// Same as the other operator++().
|
||||
entry_iterator& entry_iterator::operator++()
|
||||
{
|
||||
return (operator++());
|
||||
return (operator++(0));
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
@ -51,7 +51,7 @@
|
|||
#include <shlwapi.h>
|
||||
//#include <ntifs.h> // Currently not in msys2
|
||||
|
||||
// @TODO: This is a hack to make it compile under Windows, check if this is save.
|
||||
// @TODO: This is a hack to make it compile under Windows, check if this is safe.
|
||||
#define F_OK 0
|
||||
|
||||
#elif defined(_PATHIE_UNIX)
|
||||
|
@ -149,6 +149,8 @@ Path::Path(const std::vector<Path>& components)
|
|||
*/
|
||||
void Path::sanitize()
|
||||
{
|
||||
bool isWindowsUNCPath = m_path.size() >= 2 && (m_path[0] == '\\' && m_path[1] == '\\'); // UNC path
|
||||
|
||||
// Replace any backslashes \ with forward slashes /.
|
||||
size_t cur = string::npos;
|
||||
while ((cur = m_path.find("\\")) != string::npos) { // assignment intended
|
||||
|
@ -156,8 +158,9 @@ void Path::sanitize()
|
|||
}
|
||||
|
||||
// Replace all double slashes // with a single one
|
||||
// [fseide] except for the first position, which would be a Windows UNC path
|
||||
cur = string::npos;
|
||||
while ((cur = m_path.find("//")) != string::npos) { // assignment intended
|
||||
while ((cur = m_path.find("//", isWindowsUNCPath ? 1 : 0)) != string::npos) { // assignment intended
|
||||
m_path.replace(cur, 2, "/");
|
||||
}
|
||||
|
||||
|
@ -899,7 +902,7 @@ Path Path::pwd()
|
|||
*/
|
||||
Path Path::exe()
|
||||
{
|
||||
#if defined(__linux__)
|
||||
#if defined(__linux__) || defined(__APPLE__)
|
||||
char buf[PATH_MAX];
|
||||
ssize_t size = ::readlink("/proc/self/exe", buf, PATH_MAX);
|
||||
|
||||
|
@ -1546,7 +1549,7 @@ bool Path::is_directory() const
|
|||
throw(Pathie::ErrnoError(errsav));
|
||||
}
|
||||
|
||||
return s.st_mode & S_IFDIR;
|
||||
return (s.st_mode & S_IFDIR) != 0;
|
||||
#else
|
||||
#error Unsupported system.
|
||||
#endif
|
||||
|
@ -1590,7 +1593,7 @@ bool Path::is_file() const
|
|||
throw(Pathie::ErrnoError(errno));
|
||||
}
|
||||
|
||||
return s.st_mode & S_IFREG;
|
||||
return (s.st_mode & S_IFREG) != 0;
|
||||
#else
|
||||
#error Unsupported system.
|
||||
#endif
|
||||
|
@ -1710,9 +1713,9 @@ void Path::remove() const
|
|||
* function uses the apropriate native Win32API function
|
||||
* calls accordingly therefore. */
|
||||
if (is_directory())
|
||||
result = RemoveDirectoryW(utf16.c_str());
|
||||
result = RemoveDirectoryW(utf16.c_str()) != 0;
|
||||
else
|
||||
result = DeleteFileW(utf16.c_str());
|
||||
result = DeleteFileW(utf16.c_str()) != 0;
|
||||
|
||||
if (!result) {
|
||||
DWORD err = GetLastError();
|
||||
|
@ -3282,7 +3285,7 @@ bool Path::fnmatch(const std::string& pattern, int flags /* = 0 */) const
|
|||
#elif defined(_WIN32)
|
||||
std::wstring utf16path = utf8_to_utf16(m_path);
|
||||
std::wstring utf16pattern = utf8_to_utf16(pattern);
|
||||
return PathMatchSpecW(utf16path.c_str(), utf16pattern.c_str());
|
||||
return PathMatchSpecW(utf16path.c_str(), utf16pattern.c_str()) != 0;
|
||||
#else
|
||||
#error Unsupported system.
|
||||
#endif
|
||||
|
|
|
@ -143,7 +143,7 @@ std::string Pathie::convert_encodings(const char* from_encoding, const char* to_
|
|||
errno = 0;
|
||||
errsav = 0;
|
||||
|
||||
#ifdef BSD
|
||||
#if defined(BSD) && ! defined(__APPLE__) //Since MacOS evolved from BSD, it is captured here but the iconv on macos behaves differently
|
||||
// What the heck. FreeBSD violates POSIX.1-2008: it declares iconv()
|
||||
// differently than mandated by POSIX: http://pubs.opengroup.org/onlinepubs/9699919799/functions/iconv.html
|
||||
// (it declares a `const' where it must not be).
|
||||
|
@ -181,11 +181,10 @@ std::string Pathie::convert_encodings(const char* from_encoding, const char* to_
|
|||
std::string Pathie::utf8_to_filename(const std::string& utf8)
|
||||
{
|
||||
bool fs_encoding_is_utf8 = false;
|
||||
|
||||
char* fsencoding = NULL;
|
||||
#if defined(__APPLE__) || defined(PATHIE_ASSUME_UTF8_ON_UNIX)
|
||||
fs_encoding_is_utf8 = true;
|
||||
#else
|
||||
char* fsencoding = NULL;
|
||||
fsencoding = nl_langinfo(CODESET);
|
||||
fs_encoding_is_utf8 = (strcmp(fsencoding, "UTF-8") == 0);
|
||||
#endif
|
||||
|
@ -206,11 +205,10 @@ std::string Pathie::utf8_to_filename(const std::string& utf8)
|
|||
std::string Pathie::filename_to_utf8(const std::string& native_filename)
|
||||
{
|
||||
bool fs_encoding_is_utf8 = false;
|
||||
|
||||
char* fsencoding = NULL;
|
||||
#if defined(__APPLE__) || defined(PATHIE_ASSUME_UTF8_ON_UNIX)
|
||||
fs_encoding_is_utf8 = true;
|
||||
#else
|
||||
char* fsencoding = NULL;
|
||||
fsencoding = nl_langinfo(CODESET);
|
||||
fs_encoding_is_utf8 = (strcmp(fsencoding, "UTF-8") == 0);
|
||||
#endif
|
||||
|
|
|
@ -0,0 +1,19 @@
|
|||
Copyright (c) 2014-2015 William Ahern
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to
|
||||
deal in the Software without restriction, including without limitation the
|
||||
rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
|
||||
sell copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in
|
||||
all copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
|
||||
IN THE SOFTWARE.
|
|
@ -0,0 +1,182 @@
|
|||
# Introduction #
|
||||
|
||||
This is a simple implementation of the CHD perfect hash algorithm. CHD can
|
||||
generate perfect hash functions for very large key sets--on the order of
|
||||
millions of keys--in a very short time. On my circa 2012 desktop and using
|
||||
the default parameters (hash load factor of 80% and average displacement map
|
||||
bucket load of 4.0 keys) this implementation can generate a hash function
|
||||
for 1,000 keys in less than 1/100th of a second, and 1,000,000 keys in less
|
||||
than a second.
|
||||
|
||||
For more information about the algorithm, see
|
||||
http://cmph.sourceforge.net/chd.html.
|
||||
|
||||
# Dependencies #
|
||||
|
||||
* No runtime dependencies.
|
||||
* Requires a modern C++ compiler to build.
|
||||
* The included build requires GNU Make.
|
||||
|
||||
# Building #
|
||||
|
||||
## Make Macros ##
|
||||
|
||||
The typical GNU macros can be used control the build.
|
||||
|
||||
### Compilation ###
|
||||
|
||||
Note that the modules for Lua 5.1, 5.2, and 5.3 can be built simultaneously.
|
||||
|
||||
* CXX: C++ compiler path.
|
||||
* CXXFLAGS: C++ compiler flags.
|
||||
* CPPFLAGS: C preprocessor flags. Necessary if Lua API cannot be discovered
|
||||
automatically. You can specify multiple include paths if building more than
|
||||
one Lua module.
|
||||
* LDFLAGS: Linker flags. Not normally needed.
|
||||
* SOFLAGS: Flags needed to build dynamic library.
|
||||
* LOFLAGS: Flags needed to build loadable module. Normally should be the
|
||||
same as SOFLAGS, except on OS X.
|
||||
* LIBS: Library dependencies. Normally empty, but see the section Avoiding
|
||||
C++ Dependencies.
|
||||
|
||||
#### Avoiding C++ Dependencies
|
||||
|
||||
Defining the preprocessor macro PHF_NO_LIBCXX to 1 will prevent usage of C++
|
||||
interfaces such as std::string that would require a dependency on libc++ or
|
||||
libstdc++. This allows using platform-dependent flags in CXXFLAGS, LDFLAGS,
|
||||
and SOFLAGS to prevent a dependency on the system C++ library.
|
||||
|
||||
For example, on OS X you can do:
|
||||
```sh
|
||||
$ make CPPFLAGS="-DPHF_NO_LIBCXX" \
|
||||
CXXFLAGS="-std=c++11 -fno-rtti -fno-exceptions -O3 -march=native" \
|
||||
LDFLAGS="-nostdlib" \
|
||||
LIBS="-lSystem"
|
||||
```
|
||||
|
||||
### Installation ####
|
||||
* prefix
|
||||
* includedir
|
||||
* libdir
|
||||
* luacpath: Lua C module install path. Can be used for one-shot installation
|
||||
of a particular Lua version module.
|
||||
* lua51cpath: Lua 5.1 C module install path.
|
||||
* lua52cpath: Same as above, for 5.2.
|
||||
* lua53cpath: Same as above, for 5.3.
|
||||
|
||||
## Make Targets ##
|
||||
|
||||
* phf: Builds command-line utility (development)
|
||||
* libphf.so: Builds dynamic library for non-OS X
|
||||
* libphf.dylib: Builds dynamic library for OS X
|
||||
* lua5.1: Builds Lua 5.1 module at 5.1/phf.so. Lua 5.1 headers should be
|
||||
specified using CPPFLAGS if not in normal locations.
|
||||
* lua5.2: Same as above, for Lua 5.2.
|
||||
* lua5.3: Same as above, for Lua 5.3.
|
||||
|
||||
# Usage #
|
||||
|
||||
## Lua ##
|
||||
|
||||
## API ###
|
||||
|
||||
### phf.new(keys[, lambda][, alpha][, seed][, nodiv]) ###
|
||||
|
||||
* keys: array of keys in order from 1..#keys. They should be all
|
||||
numbers or all strings.
|
||||
|
||||
* lambda: number of keys per bucket when generating the g() function mapping.
|
||||
|
||||
* alpha: output hash space loading factor as percentage from
|
||||
1..100. 100% generates a *minimal* perfect hash function. But note that
|
||||
the implementation does *not* implement the necessary optimizations to
|
||||
ensure timely generation of minimal perfect hash functions. Normally you
|
||||
want a loading factor of 80% to 90% for large key sets.
|
||||
|
||||
* seed: random integer seed.
|
||||
|
||||
* nodiv: if true rounds r and m to powers of 2, and performs modular
|
||||
reduction using bitwise AND. Otherwise, r and m are rounded up to the
|
||||
nearest primes and modulo division used when indexing tables. Note that
|
||||
the rounding occurs after calculation of the intermediate and output hash
|
||||
table loading.
|
||||
|
||||
This is more important when building small hash tables with the C
|
||||
interface. The optimization is substantial when the compiler can inline
|
||||
the code, but isn't substantial from Lua.
|
||||
|
||||
Returns a callable object.
|
||||
|
||||
### phf:hash(key)
|
||||
|
||||
* Returns an integer hash in the range 1..phf:m(). The returned integer will
|
||||
be unique for all keys in the original set. Otherwise the result is
|
||||
unspecified.
|
||||
|
||||
### Example ###
|
||||
|
||||
```Lua
|
||||
local phf = require"phf"
|
||||
|
||||
local lambda = 4 -- how many keys per intermediate bucket
|
||||
local alpha = 80 -- output hash space loading in percentage.
|
||||
|
||||
local keys = { "apple", "banana", "cherry", "date", "eggplant", "fig",
|
||||
"guava", "honeydew", "jackfruit", "kiwi", "lemon", "mango" }
|
||||
|
||||
local F = phf.new(keys, lambda, alpha)
|
||||
|
||||
for i=1,#keys do
|
||||
print(keys[i], F(keys[i]))
|
||||
end
|
||||
|
||||
```
|
||||
|
||||
## C++ ##
|
||||
|
||||
## API ##
|
||||
|
||||
### PHF::uniq<T>(T k[], size_t n); ###
|
||||
|
||||
Similar to the shell command `sort | uniq`. Sorts, deduplicates, and shifts
|
||||
down the keys in the array k. Returns the number of unique keys, which will
|
||||
have been moved to the beginning of the array. If necessary do this before
|
||||
calling PHF::init, as PHF::init does not tolerate duplicate keys.
|
||||
|
||||
### int PHF::init<T, nodiv>(struct phf *f, const T k[], size_t n, size_t l, size_t a, phf_seed_t s);
|
||||
|
||||
Generate a perfect hash function for the n keys in array k and store the
|
||||
results in f. Returns a system error number on failure, or 0 on success. f
|
||||
is unmodified on failure.
|
||||
|
||||
### void PHF::destroy(struct phf *);
|
||||
|
||||
Deallocates internal tables, but not the struct object itself.
|
||||
|
||||
### void PHF::compact<T, nodiv>(struct phf *);
|
||||
|
||||
By default the displacement map is an array of uint32_t integers. This
|
||||
function will select the smallest type necessary to hold the largest
|
||||
displacement value and update the internal state accordingly. For a loading
|
||||
factor of 80% (0.8) in the output hash space, and displacement map loading
|
||||
factor of 4 (400%), the smallest primitive type will often be uint8_t.
|
||||
|
||||
### phf_hash_t PHF::hash<T>(struct phf *f, T k);
|
||||
|
||||
Returns an integer hash value, h, where 0 <= h < f->m. h will be unique for
|
||||
each unique key provided when generating the function. f->m will be larger
|
||||
than the number of unique keys and is based on the specified loading factor
|
||||
(alpha), rounded up to the nearest prime or nearest power of 2, depending on
|
||||
the mode of modular reduction selected. For example, for a loading factor of
|
||||
80% m will be 127: 100 is 80% of 125, and 127 is the closest prime greater
|
||||
than or equal to 125. With the nodiv option, m would be 128: 100 is 80% of
|
||||
125, and 128 is the closest power of 2 greater than or equal to 125.
|
||||
|
||||
## C ##
|
||||
|
||||
The C API is nearly identical to the C++ API, except the prefix is phf_
|
||||
instead of PHF::. phf_uniq, phf_init, and phf_hash are macros which utilize
|
||||
C11's _Generic or GCC's __builtin_types_compatible_p interfaces to overload
|
||||
the interfaces by key type. The explicit suffixes _uint32, _uint64, and
|
||||
_string may be used directly.
|
||||
|
Разница между файлами не показана из-за своего большого размера
Загрузить разницу
|
@ -0,0 +1,299 @@
|
|||
/* ==========================================================================
|
||||
* phf.h - Tiny perfect hash function library.
|
||||
* --------------------------------------------------------------------------
|
||||
* Copyright (c) 2014-2015 William Ahern
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a
|
||||
* copy of this software and associated documentation files (the
|
||||
* "Software"), to deal in the Software without restriction, including
|
||||
* without limitation the rights to use, copy, modify, merge, publish,
|
||||
* distribute, sublicense, and/or sell copies of the Software, and to permit
|
||||
* persons to whom the Software is furnished to do so, subject to the
|
||||
* following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included
|
||||
* in all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN
|
||||
* NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
|
||||
* DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
|
||||
* OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE
|
||||
* USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
* ==========================================================================
|
||||
*/
|
||||
#ifndef PHF_H
|
||||
#define PHF_H
|
||||
|
||||
#include <stddef.h> /* size_t */
|
||||
#include <stdint.h> /* UINT32_MAX uint32_t uint64_t */
|
||||
#include <stdbool.h> /* bool */
|
||||
#include <inttypes.h> /* PRIu32 PRIx32 */
|
||||
|
||||
|
||||
/*
|
||||
* C O M P I L E R F E A T U R E S & D I A G N O S T I C S
|
||||
*
|
||||
* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */
|
||||
|
||||
#define PHF_GNUC_PREREQ(M, m) (__GNUC__ > (M) || (__GNUC__ == (M) && __GNUC_MINOR__ >= (m)))
|
||||
|
||||
#ifdef __clang__
|
||||
#define phf_has_extension(x) __has_extension(x)
|
||||
#define phf_has_attribute(x) __has_attribute(x)
|
||||
#else
|
||||
#define phf_has_extension(x) 0
|
||||
#define phf_has_attribute(x) 0
|
||||
#endif
|
||||
|
||||
#ifndef PHF_HAVE_NOEXCEPT
|
||||
#define PHF_HAVE_NOEXCEPT \
|
||||
(__cplusplus >= 201103L || \
|
||||
phf_has_extension(cxx_noexcept) || \
|
||||
PHF_GNUC_PREREQ(4, 6))
|
||||
#endif
|
||||
|
||||
#ifndef PHF_HAVE_GENERIC
|
||||
#define PHF_HAVE_GENERIC \
|
||||
(__STDC_VERSION__ >= 201112L || \
|
||||
phf_has_extension(c_generic_selections) || \
|
||||
PHF_GNUC_PREREQ(4, 9))
|
||||
#endif
|
||||
|
||||
#ifndef PHF_HAVE_BUILTIN_TYPES_COMPATIBLE_P
|
||||
#define PHF_HAVE_BUILTIN_TYPES_COMPATIBLE_P (defined __GNUC__)
|
||||
#endif
|
||||
|
||||
#ifndef PHF_HAVE_BUILTIN_CHOOSE_EXPR
|
||||
#define PHF_HAVE_BUILTIN_CHOOSE_EXPR (defined __GNUC__)
|
||||
#endif
|
||||
|
||||
#ifndef PHF_HAVE_ATTRIBUTE_VISIBILITY
|
||||
#define PHF_HAVE_ATTRIBUTE_VISIBILITY \
|
||||
(phf_has_attribute(visibility) || PHF_GNUC_PREREQ(4, 0))
|
||||
#endif
|
||||
|
||||
#ifndef PHF_HAVE_COMPUTED_GOTOS
|
||||
#ifdef __GNUC__
|
||||
#define PHF_HAVE_COMPUTED_GOTOS 1
|
||||
#else
|
||||
#define PHF_HAVE_COMPUTED_GOTOS 0
|
||||
#endif
|
||||
#endif
|
||||
|
||||
#ifdef __clang__
|
||||
#pragma clang diagnostic push
|
||||
#if __cplusplus < 201103L
|
||||
#pragma clang diagnostic ignored "-Wc++11-extensions"
|
||||
#pragma clang diagnostic ignored "-Wvariadic-macros"
|
||||
#endif
|
||||
#elif PHF_GNUC_PREREQ(4, 6)
|
||||
#pragma GCC diagnostic push
|
||||
#if __cplusplus < 201103L
|
||||
#pragma GCC diagnostic ignored "-Wpedantic"
|
||||
#pragma GCC diagnostic ignored "-Wvariadic-macros"
|
||||
#endif
|
||||
#endif
|
||||
|
||||
|
||||
/*
|
||||
* C / C + + V I S I B I L I T Y
|
||||
*
|
||||
* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */
|
||||
|
||||
#ifndef PHF_PUBLIC
|
||||
#define PHF_PUBLIC
|
||||
#endif
|
||||
|
||||
#ifndef PHF_LOCAL
|
||||
#if PHF_HAVE_ATTRIBUTE_VISIBILITY
|
||||
#define PHF_LOCAL __attribute__((visibility("hidden")))
|
||||
#else
|
||||
#define PHF_LOCAL
|
||||
#endif
|
||||
#endif
|
||||
|
||||
|
||||
/*
|
||||
* C / C + + S H A R E D T Y P E S
|
||||
*
|
||||
* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */
|
||||
|
||||
#define phf_error_t int /* for documentation purposes */
|
||||
|
||||
#define PHF_HASH_MAX UINT32_MAX
|
||||
#define PHF_PRIuHASH PRIu32
|
||||
#define PHF_PRIxHASH PRIx32
|
||||
|
||||
typedef uint32_t phf_hash_t;
|
||||
typedef uint32_t phf_seed_t;
|
||||
|
||||
typedef struct phf_string {
|
||||
void *p;
|
||||
size_t n;
|
||||
} phf_string_t;
|
||||
|
||||
struct phf {
|
||||
bool nodiv;
|
||||
|
||||
phf_seed_t seed;
|
||||
|
||||
size_t r; /* number of elements in g */
|
||||
size_t m; /* number of elements in perfect hash */
|
||||
uint32_t *g; /* displacement map indexed by g(k) % r */
|
||||
|
||||
size_t d_max; /* maximum displacement value in g */
|
||||
|
||||
enum {
|
||||
PHF_G_UINT8_MOD_R = 1,
|
||||
PHF_G_UINT8_BAND_R,
|
||||
PHF_G_UINT16_MOD_R,
|
||||
PHF_G_UINT16_BAND_R,
|
||||
PHF_G_UINT32_MOD_R,
|
||||
PHF_G_UINT32_BAND_R,
|
||||
} g_op;
|
||||
|
||||
const void *g_jmp;
|
||||
}; /* struct phf */
|
||||
|
||||
|
||||
/*
|
||||
* C + + I N T E R F A C E S
|
||||
*
|
||||
* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */
|
||||
#ifdef __cplusplus
|
||||
|
||||
#if !PHF_NO_LIBCXX
|
||||
#include <string> /* std::string */
|
||||
#endif
|
||||
|
||||
namespace PHF {
|
||||
template<typename key_t>
|
||||
PHF_PUBLIC size_t uniq(key_t[], const size_t);
|
||||
|
||||
template<typename key_t, bool nodiv>
|
||||
PHF_PUBLIC phf_error_t init(struct phf *, const key_t[], const size_t, const size_t, const size_t, const phf_seed_t);
|
||||
|
||||
PHF_PUBLIC void compact(struct phf *);
|
||||
|
||||
template<typename key_t>
|
||||
PHF_PUBLIC phf_hash_t hash(struct phf *, key_t);
|
||||
|
||||
PHF_PUBLIC void destroy(struct phf *);
|
||||
}
|
||||
|
||||
extern template size_t PHF::uniq<uint32_t>(uint32_t[], const size_t);
|
||||
extern template size_t PHF::uniq<uint64_t>(uint64_t[], const size_t);
|
||||
extern template size_t PHF::uniq<phf_string_t>(phf_string_t[], const size_t);
|
||||
#if !PHF_NO_LIBCXX
|
||||
extern template size_t PHF::uniq<std::string>(std::string[], const size_t);
|
||||
#endif
|
||||
|
||||
extern template phf_error_t PHF::init<uint32_t, true>(struct phf *, const uint32_t[], const size_t, const size_t, const size_t, const phf_seed_t);
|
||||
extern template phf_error_t PHF::init<uint64_t, true>(struct phf *, const uint64_t[], const size_t, const size_t, const size_t, const phf_seed_t);
|
||||
extern template phf_error_t PHF::init<phf_string_t, true>(struct phf *, const phf_string_t[], const size_t, const size_t, const size_t, const phf_seed_t);
|
||||
#if !PHF_NO_LIBCXX
|
||||
extern template phf_error_t PHF::init<std::string, true>(struct phf *, const std::string[], const size_t, const size_t, const size_t, const phf_seed_t);
|
||||
#endif
|
||||
|
||||
extern template phf_error_t PHF::init<uint32_t, false>(struct phf *, const uint32_t[], const size_t, const size_t, const size_t, const phf_seed_t);
|
||||
extern template phf_error_t PHF::init<uint64_t, false>(struct phf *, const uint64_t[], const size_t, const size_t, const size_t, const phf_seed_t);
|
||||
extern template phf_error_t PHF::init<phf_string_t, false>(struct phf *, const phf_string_t[], const size_t, const size_t, const size_t, const phf_seed_t);
|
||||
#if !PHF_NO_LIBCXX
|
||||
extern template phf_error_t PHF::init<std::string, false>(struct phf *, const std::string[], const size_t, const size_t, const size_t, const phf_seed_t);
|
||||
#endif
|
||||
|
||||
extern template phf_hash_t PHF::hash<uint32_t>(struct phf *, uint32_t);
|
||||
extern template phf_hash_t PHF::hash<uint64_t>(struct phf *, uint64_t);
|
||||
extern template phf_hash_t PHF::hash<phf_string_t>(struct phf *, phf_string_t);
|
||||
#if !PHF_NO_LIBCXX
|
||||
extern template phf_hash_t PHF::hash<std::string>(struct phf *, std::string);
|
||||
#endif
|
||||
|
||||
#endif /* __cplusplus */
|
||||
|
||||
|
||||
/*
|
||||
* C 8 9 I N T E R F A C E S
|
||||
*
|
||||
* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
PHF_PUBLIC size_t phf_uniq_uint32(uint32_t *, const size_t);
|
||||
PHF_PUBLIC size_t phf_uniq_uint64(uint64_t *, const size_t);
|
||||
PHF_PUBLIC size_t phf_uniq_string(phf_string_t *, const size_t);
|
||||
|
||||
PHF_PUBLIC phf_error_t phf_init_uint32(struct phf *, const uint32_t *, const size_t, const size_t, const size_t, const phf_seed_t, const bool nodiv);
|
||||
PHF_PUBLIC phf_error_t phf_init_uint64(struct phf *, const uint64_t *, const size_t, const size_t, const size_t, const phf_seed_t, const bool nodiv);
|
||||
PHF_PUBLIC phf_error_t phf_init_string(struct phf *, const phf_string_t *, const size_t, const size_t, const size_t, const phf_seed_t, const bool nodiv);
|
||||
|
||||
PHF_PUBLIC void phf_compact(struct phf *);
|
||||
|
||||
PHF_PUBLIC phf_hash_t phf_hash_uint32(struct phf *, const uint32_t);
|
||||
PHF_PUBLIC phf_hash_t phf_hash_uint64(struct phf *, const uint64_t);
|
||||
PHF_PUBLIC phf_hash_t phf_hash_string(struct phf *, const phf_string_t);
|
||||
|
||||
PHF_PUBLIC void phf_destroy(struct phf *);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
|
||||
/*
|
||||
* C 1 1 / G N U I N T E R F A C E S
|
||||
*
|
||||
* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */
|
||||
#if PHF_HAVE_GENERIC
|
||||
|
||||
#define phf_uniq(k, n) _Generic(*(k), \
|
||||
uint32_t: phf_uniq_uint32, \
|
||||
uint64_t: phf_uniq_uint64, \
|
||||
phf_string_t: phf_uniq_string)((k), (n))
|
||||
|
||||
#define phf_init(f, k, ...) _Generic(*(k), \
|
||||
uint32_t: phf_init_uint32, \
|
||||
uint64_t: phf_init_uint64, \
|
||||
phf_string_t: phf_init_string)((f), (k), __VA_ARGS__)
|
||||
|
||||
#define phf_hash(f, k) _Generic((k), \
|
||||
uint32_t: phf_hash_uint32, \
|
||||
uint64_t: phf_hash_uint64, \
|
||||
phf_string_t: phf_hash_string)((f), (k))
|
||||
|
||||
#elif PHF_HAVE_BUILTIN_TYPES_COMPATIBLE_P && PHF_HAVE_BUILTIN_CHOOSE_EXPR
|
||||
|
||||
#define phf_choose(cond, a, b) __builtin_choose_expr(cond, a, b)
|
||||
#define phf_istype(E, T) __builtin_types_compatible_p(__typeof__(E), T)
|
||||
|
||||
#define phf_uniq(k, n) \
|
||||
phf_choose(phf_istype(*(k), uint32_t), phf_uniq_uint32((uint32_t *)(k), (n)), \
|
||||
phf_choose(phf_istype(*(k), uint64_t), phf_uniq_uint64((uint64_t *)(k), (n)), \
|
||||
phf_choose(phf_istype(*(k), phf_string_t), phf_uniq_string((phf_string_t *)(k), (n)), \
|
||||
(void)0)))
|
||||
|
||||
#define phf_init(f, k, ...) \
|
||||
phf_choose(phf_istype(*(k), uint32_t), phf_init_uint32((f), (const uint32_t *)(k), __VA_ARGS__), \
|
||||
phf_choose(phf_istype(*(k), uint64_t), phf_init_uint64((f), (const uint64_t *)(k), __VA_ARGS__), \
|
||||
phf_choose(phf_istype(*(k), phf_string_t), phf_init_string((f), (const phf_string_t *)(k), __VA_ARGS__), \
|
||||
(void)0)))
|
||||
|
||||
#define phf_hash(f, k) ((*(phf_hash_t (*)()) \
|
||||
phf_choose(phf_istype((k), uint32_t), &phf_hash_uint32, \
|
||||
phf_choose(phf_istype((k), uint64_t), &phf_hash_uint64, \
|
||||
phf_choose(phf_istype((k), phf_string_t), &phf_hash_string, \
|
||||
(void)0))))((f), (k)))
|
||||
|
||||
#endif
|
||||
|
||||
|
||||
#ifdef __clang__
|
||||
#pragma clang diagnostic pop
|
||||
#elif PHF_GNUC_PREREQ(4, 6)
|
||||
#pragma GCC diagnostic pop
|
||||
#endif
|
||||
|
||||
#endif /* PHF_H */
|
|
@ -1,318 +1,245 @@
|
|||
/*
|
||||
* Copyright 1993-2015 NVIDIA Corporation. All rights reserved.
|
||||
/* Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Please refer to the NVIDIA end user license agreement (EULA) associated
|
||||
* with this source code for terms and conditions that govern your use of
|
||||
* this software. Any use, reproduction, disclosure, or distribution of
|
||||
* this software and related documentation outside the terms of the EULA
|
||||
* is strictly prohibited.
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions
|
||||
* are met:
|
||||
* * Redistributions of source code must retain the above copyright
|
||||
* notice, this list of conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright
|
||||
* notice, this list of conditions and the following disclaimer in the
|
||||
* documentation and/or other materials provided with the distribution.
|
||||
* * Neither the name of NVIDIA CORPORATION nor the names of its
|
||||
* contributors may be used to endorse or promote products derived
|
||||
* from this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
|
||||
* EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
|
||||
* PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
|
||||
* CONTRIBUTORS BE LIABLE FOR ANY DIRECINDIRECFunctor, T, AccTyf, pe, INCIDENTAL, SPECIAL,
|
||||
* EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
* PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
* PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
|
||||
* OF LIABILITY, WHETHER IN CONTRACSf, TRICT LIABILITY, OR TORT
|
||||
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "tensors/tensor.h"
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
#include "functional/tmp.h"
|
||||
#include <cooperative_groups.h>
|
||||
|
||||
namespace marian {
|
||||
|
||||
template <unsigned int blockSize>
|
||||
__device__ void
|
||||
reduceBlock(volatile float *sdata, float mySum, const unsigned int tid)
|
||||
{
|
||||
sdata[tid] = mySum;
|
||||
__syncthreads();
|
||||
namespace cg = cooperative_groups;
|
||||
|
||||
// do reduction in shared mem
|
||||
if (blockSize >= 512)
|
||||
{
|
||||
if (tid < 256)
|
||||
{
|
||||
sdata[tid] = mySum = mySum + sdata[tid + 256];
|
||||
// Utility class used to avoid linker errors with extern
|
||||
// unsized shared memory arrays with templated type
|
||||
template <class T>
|
||||
struct SharedMemory {
|
||||
__device__ inline operator T *() {
|
||||
extern __shared__ int __smem[];
|
||||
return (T *)__smem;
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
__device__ inline operator const T *() const {
|
||||
extern __shared__ int __smem[];
|
||||
return (T *)__smem;
|
||||
}
|
||||
};
|
||||
|
||||
// specialize for double to avoid unaligned memory
|
||||
// access compile errors
|
||||
template <>
|
||||
struct SharedMemory<double> {
|
||||
__device__ inline operator double *() {
|
||||
extern __shared__ double __smem_d[];
|
||||
return (double *)__smem_d;
|
||||
}
|
||||
|
||||
if (blockSize >= 256)
|
||||
{
|
||||
if (tid < 128)
|
||||
{
|
||||
sdata[tid] = mySum = mySum + sdata[tid + 128];
|
||||
__device__ inline operator const double *() const {
|
||||
extern __shared__ double __smem_d[];
|
||||
return (double *)__smem_d;
|
||||
}
|
||||
};
|
||||
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
if (blockSize >= 128)
|
||||
{
|
||||
if (tid < 64)
|
||||
{
|
||||
sdata[tid] = mySum = mySum + sdata[tid + 64];
|
||||
}
|
||||
/*
|
||||
This version adds multiple elements per thread sequentially. This reduces
|
||||
the overall cost of the algorithm while keeping the work complexity O(n) and
|
||||
the step complexity O(log n). (Brent's Theorem optimization)
|
||||
|
||||
__syncthreads();
|
||||
}
|
||||
Note, this kernel needs a minimum of 64*sizeof(T) bytes of shared memory.
|
||||
In other words if blockSize <= 32, allocate 64*sizeof(T) bytes.
|
||||
If blockSize > 32, allocate blockSize*sizeof(T) bytes.
|
||||
*/
|
||||
template <typename T, typename AccType, unsigned int blockSize, bool nIsPow2Greater1, size_t K, class Functor, class AggFunctor>
|
||||
__global__ void reduceSinglePass(Functor functor, AccType aggInit, AggFunctor aggFunctor, AccType scale,
|
||||
const functional::Shape full,
|
||||
functional::Tensor<AccType> out,
|
||||
functional::Array<functional::Tensor<T>, K> ins) {
|
||||
int n = full.elements();
|
||||
|
||||
if (tid < 32)
|
||||
{
|
||||
if (blockSize >= 64)
|
||||
{
|
||||
sdata[tid] = mySum = mySum + sdata[tid + 32];
|
||||
}
|
||||
|
||||
if (blockSize >= 32)
|
||||
{
|
||||
sdata[tid] = mySum = mySum + sdata[tid + 16];
|
||||
}
|
||||
|
||||
if (blockSize >= 16)
|
||||
{
|
||||
sdata[tid] = mySum = mySum + sdata[tid + 8];
|
||||
}
|
||||
|
||||
if (blockSize >= 8)
|
||||
{
|
||||
sdata[tid] = mySum = mySum + sdata[tid + 4];
|
||||
}
|
||||
|
||||
if (blockSize >= 4)
|
||||
{
|
||||
sdata[tid] = mySum = mySum + sdata[tid + 2];
|
||||
}
|
||||
|
||||
if (blockSize >= 2)
|
||||
{
|
||||
sdata[tid] = mySum = mySum + sdata[tid + 1];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <unsigned int blockSize, bool nIsPow2, class Functor>
|
||||
__device__ void
|
||||
reduceBlocks(Functor f, float *g_idata, float *g_odata, unsigned int n)
|
||||
{
|
||||
extern __shared__ float sdata[];
|
||||
// Handle to thread block group
|
||||
cg::thread_block cta = cg::this_thread_block();
|
||||
AccType *sdata = SharedMemory<AccType>();
|
||||
|
||||
// perform first level of reduction,
|
||||
// reading from global memory, writing to shared memory
|
||||
unsigned int tid = threadIdx.x;
|
||||
unsigned int i = blockIdx.x*(blockSize*2) + threadIdx.x;
|
||||
unsigned int i = blockIdx.x * blockSize * 2 + threadIdx.x;
|
||||
unsigned int gridSize = blockSize * 2 * gridDim.x;
|
||||
float mySum = 0;
|
||||
|
||||
// we reduce multiple elements per thread. The number is determined by the
|
||||
AccType mySum = aggInit;
|
||||
|
||||
// we reduceSinglePass multiple elements per thread. The number is determined by the
|
||||
// number of active thread blocks (via gridDim). More blocks will result
|
||||
// in a larger gridSize and therefore fewer elements per thread
|
||||
while (i < n)
|
||||
{
|
||||
mySum += f(g_idata[i]);
|
||||
while (i < n) {
|
||||
mySum = aggFunctor(mySum, functional::applyWithCast<AccType>(functor, ins, i));
|
||||
|
||||
// ensure we don't read out of bounds -- this is optimized away for powerOf2 sized arrays
|
||||
if (nIsPow2 || i + blockSize < n)
|
||||
mySum += f(g_idata[i+blockSize]);
|
||||
// ensure we don't read out of bounds -- this is optimized away for powerOf2
|
||||
// sized arrays
|
||||
if (nIsPow2Greater1 || i + blockSize < n)
|
||||
mySum = aggFunctor(mySum, functional::applyWithCast<AccType>(functor, ins, i + blockSize));
|
||||
|
||||
i += gridSize;
|
||||
}
|
||||
|
||||
// each thread puts its local sum into shared memory
|
||||
sdata[tid] = mySum;
|
||||
cg::sync(cta);
|
||||
|
||||
// do reduction in shared mem
|
||||
reduceBlock<blockSize>(sdata, mySum, tid);
|
||||
if ((blockSize >= 512) && (tid < 256)) {
|
||||
sdata[tid] = mySum = aggFunctor(mySum, sdata[tid + 256]);
|
||||
}
|
||||
|
||||
cg::sync(cta);
|
||||
|
||||
if ((blockSize >= 256) && (tid < 128)) {
|
||||
sdata[tid] = mySum = aggFunctor(mySum, sdata[tid + 128]);
|
||||
}
|
||||
|
||||
cg::sync(cta);
|
||||
|
||||
if ((blockSize >= 128) && (tid < 64)) {
|
||||
sdata[tid] = mySum = aggFunctor(mySum, sdata[tid + 64]);
|
||||
}
|
||||
|
||||
cg::sync(cta);
|
||||
|
||||
cg::thread_block_tile<32> tile32 = cg::tiled_partition<32>(cta);
|
||||
|
||||
if (cta.thread_rank() < 32) {
|
||||
// Fetch final intermediate sum from 2nd warp
|
||||
if (blockSize >= 64)
|
||||
mySum = aggFunctor(mySum, sdata[tid + 32]);
|
||||
// reduce final warp using shuffle
|
||||
for (int offset = tile32.size() / 2; offset > 0; offset /= 2) {
|
||||
mySum = aggFunctor(mySum, tile32.shfl_down(mySum, offset));
|
||||
}
|
||||
}
|
||||
|
||||
// write result for this block to global mem
|
||||
if (tid == 0) g_odata[blockIdx.x] = sdata[0];
|
||||
if (cta.thread_rank() == 0)
|
||||
out[blockIdx.x] = aggFunctor(out[blockIdx.x], mySum * scale); // aggFunctor?
|
||||
}
|
||||
|
||||
// Global variable used by reduceSinglePass to count how many blocks have finished
|
||||
__device__ unsigned int retirementCount = 0;
|
||||
|
||||
cudaError_t setRetirementCount(int retCnt)
|
||||
{
|
||||
return cudaMemcpyToSymbol(retirementCount, &retCnt, sizeof(unsigned int), 0, cudaMemcpyHostToDevice);
|
||||
static inline bool isPow2Greater1(unsigned int x) { // is power of two but also larger than 1, otherwise an out-of-bounds read occurs
|
||||
return x > 1 && ((x & (x - 1)) == 0);
|
||||
}
|
||||
|
||||
// This reduction kernel reduces an arbitrary size array in a single kernel invocation
|
||||
// It does so by keeping track of how many blocks have finished. After each thread
|
||||
// block completes the reduction of its own block of data, it "takes a ticket" by
|
||||
// atomically incrementing a global counter. If the ticket value is equal to the number
|
||||
// of thread blocks, then the block holding the ticket knows that it is the last block
|
||||
// to finish. This last block is responsible for summing the results of all the other
|
||||
// blocks.
|
||||
//
|
||||
// In order for this to work, we must be sure that before a block takes a ticket, all
|
||||
// of its memory transactions have completed. This is what __threadfence() does -- it
|
||||
// blocks until the results of all outstanding memory transactions within the
|
||||
// calling thread are visible to all other threads.
|
||||
//
|
||||
// For more details on the reduction algorithm (notably the multi-pass approach), see
|
||||
// the "reduction" sample in the CUDA SDK.
|
||||
|
||||
template <unsigned int blockSize, bool nIsPow2, class Functor>
|
||||
__global__ void reduceSinglePass(Functor f, float *g_idata, float *g_odata, unsigned int n)
|
||||
{
|
||||
|
||||
//
|
||||
// PHASE 1: Process all inputs assigned to this block
|
||||
//
|
||||
|
||||
reduceBlocks<blockSize, nIsPow2>(f, g_idata, g_odata, n);
|
||||
|
||||
//
|
||||
// PHASE 2: Last block finished will process all partial sums
|
||||
//
|
||||
|
||||
if (gridDim.x > 1)
|
||||
{
|
||||
const unsigned int tid = threadIdx.x;
|
||||
__shared__ bool amLast;
|
||||
extern float __shared__ smem[];
|
||||
|
||||
// wait until all outstanding memory instructions in this thread are finished
|
||||
__threadfence();
|
||||
|
||||
// Thread 0 takes a ticket
|
||||
if (tid==0)
|
||||
{
|
||||
unsigned int ticket = atomicInc(&retirementCount, gridDim.x);
|
||||
// If the ticket ID is equal to the number of blocks, we are the last block!
|
||||
amLast = (ticket == gridDim.x-1);
|
||||
static inline unsigned int nextPow2(unsigned int x) {
|
||||
--x;
|
||||
x |= x >> 1;
|
||||
x |= x >> 2;
|
||||
x |= x >> 4;
|
||||
x |= x >> 8;
|
||||
x |= x >> 16;
|
||||
return ++x;
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// The last block sums the results of all other blocks
|
||||
if (amLast)
|
||||
{
|
||||
int i = tid;
|
||||
float mySum = 0;
|
||||
|
||||
while (i < gridDim.x)
|
||||
{
|
||||
mySum += g_odata[i];
|
||||
i += blockSize;
|
||||
}
|
||||
|
||||
reduceBlock<blockSize>(smem, mySum, tid);
|
||||
|
||||
if (tid==0)
|
||||
{
|
||||
g_odata[0] = smem[0];
|
||||
|
||||
// reset retirement count so that next run succeeds
|
||||
retirementCount = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool isPow2(unsigned int x)
|
||||
{
|
||||
return ((x&(x-1))==0);
|
||||
}
|
||||
|
||||
template <class Functor>
|
||||
void ReduceAll(Functor f, Tensor out, Tensor in)
|
||||
{
|
||||
cudaSetDevice(out->getDeviceId().no);
|
||||
int size = in->shape().elements();
|
||||
int threads = std::min(MAX_THREADS, size);
|
||||
int blocks = std::min(MAX_BLOCKS, size / threads + (size % threads != 0));
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// Wrapper function for kernel launch
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
template <typename T, typename AccType, size_t K, class Functor, class AggFunctor>
|
||||
void reduceSinglePass(Functor functor, AccType aggInit, AggFunctor aggFunctor, AccType scale,
|
||||
const functional::Shape full,
|
||||
functional::Tensor<AccType> out,
|
||||
functional::Array<functional::Tensor<T>, K> ins,
|
||||
int threads, int blocks) {
|
||||
int size = full.elements();
|
||||
// when there is only one warp per block, we need to allocate two warps
|
||||
// worth of shared memory so that we don't index shared memory out of bounds
|
||||
int smemSize = (threads <= 32) ? 2 * threads * sizeof(AccType) : threads * sizeof(AccType);
|
||||
dim3 dimBlock(threads, 1, 1);
|
||||
dim3 dimGrid(blocks, 1, 1);
|
||||
int smemSize = threads * sizeof(float);
|
||||
|
||||
float* d_idata = in->data();
|
||||
float* d_odata = out->data();
|
||||
|
||||
// choose which of the optimized versions of reduction to launch
|
||||
if (isPow2(size))
|
||||
{
|
||||
switch (threads)
|
||||
{
|
||||
if (isPow2Greater1(size)) {
|
||||
switch (threads) {
|
||||
case 512:
|
||||
reduceSinglePass<512, true><<< dimGrid, dimBlock, smemSize >>>(f, d_idata, d_odata, size);
|
||||
reduceSinglePass<T, AccType, 512, true><<<dimGrid, dimBlock, smemSize>>>(functor, aggInit, aggFunctor, scale, full, out, ins);
|
||||
break;
|
||||
|
||||
case 256:
|
||||
reduceSinglePass<256, true><<< dimGrid, dimBlock, smemSize >>>(f, d_idata, d_odata, size);
|
||||
reduceSinglePass<T, AccType, 256, true><<<dimGrid, dimBlock, smemSize>>>(functor, aggInit, aggFunctor, scale, full, out, ins);
|
||||
break;
|
||||
|
||||
case 128:
|
||||
reduceSinglePass<128, true><<< dimGrid, dimBlock, smemSize >>>(f, d_idata, d_odata, size);
|
||||
reduceSinglePass<T, AccType, 128, true><<<dimGrid, dimBlock, smemSize>>>(functor, aggInit, aggFunctor, scale, full, out, ins);
|
||||
break;
|
||||
|
||||
case 64:
|
||||
reduceSinglePass< 64, true><<< dimGrid, dimBlock, smemSize >>>(f, d_idata, d_odata, size);
|
||||
reduceSinglePass<T, AccType, 64, true><<<dimGrid, dimBlock, smemSize>>>(functor, aggInit, aggFunctor, scale, full, out, ins);
|
||||
break;
|
||||
|
||||
case 32:
|
||||
reduceSinglePass< 32, true><<< dimGrid, dimBlock, smemSize >>>(f, d_idata, d_odata, size);
|
||||
reduceSinglePass<T, AccType, 32, true><<<dimGrid, dimBlock, smemSize>>>(functor, aggInit, aggFunctor, scale, full, out, ins);
|
||||
break;
|
||||
|
||||
case 16:
|
||||
reduceSinglePass< 16, true><<< dimGrid, dimBlock, smemSize >>>(f, d_idata, d_odata, size);
|
||||
reduceSinglePass<T, AccType, 16, true><<<dimGrid, dimBlock, smemSize>>>(functor, aggInit, aggFunctor, scale, full, out, ins);
|
||||
break;
|
||||
|
||||
case 8:
|
||||
reduceSinglePass< 8, true><<< dimGrid, dimBlock, smemSize >>>(f, d_idata, d_odata, size);
|
||||
reduceSinglePass<T, AccType, 8, true><<<dimGrid, dimBlock, smemSize>>>(functor, aggInit, aggFunctor, scale, full, out, ins);
|
||||
break;
|
||||
|
||||
case 4:
|
||||
reduceSinglePass< 4, true><<< dimGrid, dimBlock, smemSize >>>(f, d_idata, d_odata, size);
|
||||
reduceSinglePass<T, AccType, 4, true><<<dimGrid, dimBlock, smemSize>>>(functor, aggInit, aggFunctor, scale, full, out, ins);
|
||||
break;
|
||||
|
||||
case 2:
|
||||
reduceSinglePass< 2, true><<< dimGrid, dimBlock, smemSize >>>(f, d_idata, d_odata, size);
|
||||
reduceSinglePass<T, AccType, 2, true><<<dimGrid, dimBlock, smemSize>>>(functor, aggInit, aggFunctor, scale, full, out, ins);
|
||||
break;
|
||||
|
||||
case 1:
|
||||
reduceSinglePass< 1, true><<< dimGrid, dimBlock, smemSize >>>(f, d_idata, d_odata, size);
|
||||
reduceSinglePass<T, AccType, 1, true><<<dimGrid, dimBlock, smemSize>>>(functor, aggInit, aggFunctor, scale, full, out, ins);
|
||||
break;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
switch (threads)
|
||||
{
|
||||
} else {
|
||||
switch (threads) {
|
||||
case 512:
|
||||
reduceSinglePass<512, false><<< dimGrid, dimBlock, smemSize >>>(f, d_idata, d_odata, size);
|
||||
reduceSinglePass<T, AccType, 512, false><<<dimGrid, dimBlock, smemSize>>>(functor, aggInit, aggFunctor, scale, full, out, ins);
|
||||
break;
|
||||
|
||||
case 256:
|
||||
reduceSinglePass<256, false><<< dimGrid, dimBlock, smemSize >>>(f, d_idata, d_odata, size);
|
||||
reduceSinglePass<T, AccType, 256, false><<<dimGrid, dimBlock, smemSize>>>(functor, aggInit, aggFunctor, scale, full, out, ins);
|
||||
break;
|
||||
|
||||
case 128:
|
||||
reduceSinglePass<128, false><<< dimGrid, dimBlock, smemSize >>>(f, d_idata, d_odata, size);
|
||||
reduceSinglePass<T, AccType, 128, false><<<dimGrid, dimBlock, smemSize>>>(functor, aggInit, aggFunctor, scale, full, out, ins);
|
||||
break;
|
||||
|
||||
case 64:
|
||||
reduceSinglePass< 64, false><<< dimGrid, dimBlock, smemSize >>>(f, d_idata, d_odata, size);
|
||||
reduceSinglePass<T, AccType, 64, false><<<dimGrid, dimBlock, smemSize>>>(functor, aggInit, aggFunctor, scale, full, out, ins);
|
||||
break;
|
||||
|
||||
case 32:
|
||||
reduceSinglePass< 32, false><<< dimGrid, dimBlock, smemSize >>>(f, d_idata, d_odata, size);
|
||||
reduceSinglePass<T, AccType, 32, false><<<dimGrid, dimBlock, smemSize>>>(functor, aggInit, aggFunctor, scale, full, out, ins);
|
||||
break;
|
||||
|
||||
case 16:
|
||||
reduceSinglePass< 16, false><<< dimGrid, dimBlock, smemSize >>>(f, d_idata, d_odata, size);
|
||||
reduceSinglePass<T, AccType, 16, false><<<dimGrid, dimBlock, smemSize>>>(functor, aggInit, aggFunctor, scale, full, out, ins);
|
||||
break;
|
||||
|
||||
case 8:
|
||||
reduceSinglePass< 8, false><<< dimGrid, dimBlock, smemSize >>>(f, d_idata, d_odata, size);
|
||||
reduceSinglePass<T, AccType, 8, false><<<dimGrid, dimBlock, smemSize>>>(functor, aggInit, aggFunctor, scale, full, out, ins);
|
||||
break;
|
||||
|
||||
case 4:
|
||||
reduceSinglePass< 4, false><<< dimGrid, dimBlock, smemSize >>>(f, d_idata, d_odata, size);
|
||||
reduceSinglePass<T, AccType, 4, false><<<dimGrid, dimBlock, smemSize>>>(functor, aggInit, aggFunctor, scale, full, out, ins);
|
||||
break;
|
||||
|
||||
case 2:
|
||||
reduceSinglePass< 2, false><<< dimGrid, dimBlock, smemSize >>>(f, d_idata, d_odata, size);
|
||||
reduceSinglePass<T, AccType, 2, false><<<dimGrid, dimBlock, smemSize>>>(functor, aggInit, aggFunctor, scale, full, out, ins);
|
||||
break;
|
||||
|
||||
case 1:
|
||||
reduceSinglePass< 1, false><<< dimGrid, dimBlock, smemSize >>>(f, d_idata, d_odata, size);
|
||||
reduceSinglePass<T, AccType, 1, false><<<dimGrid, dimBlock, smemSize>>>(functor, aggInit, aggFunctor, scale, full, out, ins);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1 +1 @@
|
|||
Subproject commit 1a38d26a13cc67b1aae641d4983b624bef6d5305
|
||||
Subproject commit 1d33bb67c3b6b2a51d3c9ffd55f37725801da39d
|
|
@ -0,0 +1,711 @@
|
|||
/* SIMD (SSE1+MMX or SSE2) implementation of sin, cos, exp and log
|
||||
|
||||
Inspired by Intel Approximate Math library, and based on the
|
||||
corresponding algorithms of the cephes math library
|
||||
|
||||
The default is to use the SSE1 version. If you define USE_SSE2 the
|
||||
the SSE2 intrinsics will be used in place of the MMX intrinsics. Do
|
||||
not expect any significant performance improvement with SSE2.
|
||||
*/
|
||||
|
||||
/* Copyright (C) 2007 Julien Pommier
|
||||
|
||||
This software is provided 'as-is', without any express or implied
|
||||
warranty. In no event will the authors be held liable for any damages
|
||||
arising from the use of this software.
|
||||
|
||||
Permission is granted to anyone to use this software for any purpose,
|
||||
including commercial applications, and to alter it and redistribute it
|
||||
freely, subject to the following restrictions:
|
||||
|
||||
1. The origin of this software must not be misrepresented; you must not
|
||||
claim that you wrote the original software. If you use this software
|
||||
in a product, an acknowledgment in the product documentation would be
|
||||
appreciated but is not required.
|
||||
2. Altered source versions must be plainly marked as such, and must not be
|
||||
misrepresented as being the original software.
|
||||
3. This notice may not be removed or altered from any source distribution.
|
||||
|
||||
(this is the zlib license)
|
||||
*/
|
||||
|
||||
#include <xmmintrin.h>
|
||||
|
||||
/* yes I know, the top of this file is quite ugly */
|
||||
|
||||
#ifdef _MSC_VER /* visual c++ */
|
||||
# define ALIGN16_BEG __declspec(align(16))
|
||||
# define ALIGN16_END
|
||||
#else /* gcc or icc */
|
||||
# define ALIGN16_BEG
|
||||
# define ALIGN16_END __attribute__((aligned(16)))
|
||||
#endif
|
||||
|
||||
/* __m128 is ugly to write */
|
||||
typedef __m128 v4sf; // vector of 4 float (sse1)
|
||||
|
||||
#ifdef USE_SSE2
|
||||
# include <emmintrin.h>
|
||||
typedef __m128i v4si; // vector of 4 int (sse2)
|
||||
#else
|
||||
typedef __m64 v2si; // vector of 2 int (mmx)
|
||||
#endif
|
||||
|
||||
/* declare some SSE constants -- why can't I figure a better way to do that? */
|
||||
#define _PS_CONST(Name, Val) \
|
||||
static const ALIGN16_BEG float _ps_##Name[4] ALIGN16_END = { (float)Val, (float)Val, (float)Val, (float)Val }
|
||||
#define _PI32_CONST(Name, Val) \
|
||||
static const ALIGN16_BEG int _pi32_##Name[4] ALIGN16_END = { Val, Val, Val, Val }
|
||||
#define _PS_CONST_TYPE(Name, Type, Val) \
|
||||
static const ALIGN16_BEG Type _ps_##Name[4] ALIGN16_END = { Val, Val, Val, Val }
|
||||
|
||||
_PS_CONST(1 , 1.0f);
|
||||
_PS_CONST(0p5, 0.5f);
|
||||
/* the smallest non denormalized float number */
|
||||
_PS_CONST_TYPE(min_norm_pos, int, 0x00800000);
|
||||
_PS_CONST_TYPE(mant_mask, int, 0x7f800000);
|
||||
_PS_CONST_TYPE(inv_mant_mask, int, ~0x7f800000);
|
||||
|
||||
_PS_CONST_TYPE(sign_mask, int, (int)0x80000000);
|
||||
_PS_CONST_TYPE(inv_sign_mask, int, ~0x80000000);
|
||||
|
||||
_PI32_CONST(1, 1);
|
||||
_PI32_CONST(inv1, ~1);
|
||||
_PI32_CONST(2, 2);
|
||||
_PI32_CONST(4, 4);
|
||||
_PI32_CONST(0x7f, 0x7f);
|
||||
|
||||
_PS_CONST(cephes_SQRTHF, 0.707106781186547524);
|
||||
_PS_CONST(cephes_log_p0, 7.0376836292E-2);
|
||||
_PS_CONST(cephes_log_p1, - 1.1514610310E-1);
|
||||
_PS_CONST(cephes_log_p2, 1.1676998740E-1);
|
||||
_PS_CONST(cephes_log_p3, - 1.2420140846E-1);
|
||||
_PS_CONST(cephes_log_p4, + 1.4249322787E-1);
|
||||
_PS_CONST(cephes_log_p5, - 1.6668057665E-1);
|
||||
_PS_CONST(cephes_log_p6, + 2.0000714765E-1);
|
||||
_PS_CONST(cephes_log_p7, - 2.4999993993E-1);
|
||||
_PS_CONST(cephes_log_p8, + 3.3333331174E-1);
|
||||
_PS_CONST(cephes_log_q1, -2.12194440e-4);
|
||||
_PS_CONST(cephes_log_q2, 0.693359375);
|
||||
|
||||
#ifndef USE_SSE2
|
||||
typedef union xmm_mm_union {
|
||||
__m128 xmm;
|
||||
__m64 mm[2];
|
||||
} xmm_mm_union;
|
||||
|
||||
#define COPY_XMM_TO_MM(xmm_, mm0_, mm1_) { \
|
||||
xmm_mm_union u; u.xmm = xmm_; \
|
||||
mm0_ = u.mm[0]; \
|
||||
mm1_ = u.mm[1]; \
|
||||
}
|
||||
|
||||
#define COPY_MM_TO_XMM(mm0_, mm1_, xmm_) { \
|
||||
xmm_mm_union u; u.mm[0]=mm0_; u.mm[1]=mm1_; xmm_ = u.xmm; \
|
||||
}
|
||||
|
||||
#endif // USE_SSE2
|
||||
|
||||
/* natural logarithm computed for 4 simultaneous float
|
||||
return NaN for x <= 0
|
||||
*/
|
||||
static inline v4sf log_ps(v4sf x) {
|
||||
#ifdef USE_SSE2
|
||||
v4si emm0;
|
||||
#else
|
||||
v2si mm0, mm1;
|
||||
#endif
|
||||
v4sf one = *(v4sf*)_ps_1;
|
||||
|
||||
v4sf invalid_mask = _mm_cmple_ps(x, _mm_setzero_ps());
|
||||
|
||||
x = _mm_max_ps(x, *(v4sf*)_ps_min_norm_pos); /* cut off denormalized stuff */
|
||||
|
||||
#ifndef USE_SSE2
|
||||
/* part 1: x = frexpf(x, &e); */
|
||||
COPY_XMM_TO_MM(x, mm0, mm1);
|
||||
mm0 = _mm_srli_pi32(mm0, 23);
|
||||
mm1 = _mm_srli_pi32(mm1, 23);
|
||||
#else
|
||||
emm0 = _mm_srli_epi32(_mm_castps_si128(x), 23);
|
||||
#endif
|
||||
/* keep only the fractional part */
|
||||
x = _mm_and_ps(x, *(v4sf*)_ps_inv_mant_mask);
|
||||
x = _mm_or_ps(x, *(v4sf*)_ps_0p5);
|
||||
|
||||
#ifndef USE_SSE2
|
||||
/* now e=mm0:mm1 contain the really base-2 exponent */
|
||||
mm0 = _mm_sub_pi32(mm0, *(v2si*)_pi32_0x7f);
|
||||
mm1 = _mm_sub_pi32(mm1, *(v2si*)_pi32_0x7f);
|
||||
v4sf e = _mm_cvtpi32x2_ps(mm0, mm1);
|
||||
_mm_empty(); /* bye bye mmx */
|
||||
#else
|
||||
emm0 = _mm_sub_epi32(emm0, *(v4si*)_pi32_0x7f);
|
||||
v4sf e = _mm_cvtepi32_ps(emm0);
|
||||
#endif
|
||||
|
||||
e = _mm_add_ps(e, one);
|
||||
|
||||
/* part2:
|
||||
if( x < SQRTHF ) {
|
||||
e -= 1;
|
||||
x = x + x - 1.0;
|
||||
} else { x = x - 1.0; }
|
||||
*/
|
||||
v4sf mask = _mm_cmplt_ps(x, *(v4sf*)_ps_cephes_SQRTHF);
|
||||
v4sf tmp = _mm_and_ps(x, mask);
|
||||
x = _mm_sub_ps(x, one);
|
||||
e = _mm_sub_ps(e, _mm_and_ps(one, mask));
|
||||
x = _mm_add_ps(x, tmp);
|
||||
|
||||
|
||||
v4sf z = _mm_mul_ps(x,x);
|
||||
|
||||
v4sf y = *(v4sf*)_ps_cephes_log_p0;
|
||||
y = _mm_mul_ps(y, x);
|
||||
y = _mm_add_ps(y, *(v4sf*)_ps_cephes_log_p1);
|
||||
y = _mm_mul_ps(y, x);
|
||||
y = _mm_add_ps(y, *(v4sf*)_ps_cephes_log_p2);
|
||||
y = _mm_mul_ps(y, x);
|
||||
y = _mm_add_ps(y, *(v4sf*)_ps_cephes_log_p3);
|
||||
y = _mm_mul_ps(y, x);
|
||||
y = _mm_add_ps(y, *(v4sf*)_ps_cephes_log_p4);
|
||||
y = _mm_mul_ps(y, x);
|
||||
y = _mm_add_ps(y, *(v4sf*)_ps_cephes_log_p5);
|
||||
y = _mm_mul_ps(y, x);
|
||||
y = _mm_add_ps(y, *(v4sf*)_ps_cephes_log_p6);
|
||||
y = _mm_mul_ps(y, x);
|
||||
y = _mm_add_ps(y, *(v4sf*)_ps_cephes_log_p7);
|
||||
y = _mm_mul_ps(y, x);
|
||||
y = _mm_add_ps(y, *(v4sf*)_ps_cephes_log_p8);
|
||||
y = _mm_mul_ps(y, x);
|
||||
|
||||
y = _mm_mul_ps(y, z);
|
||||
|
||||
|
||||
tmp = _mm_mul_ps(e, *(v4sf*)_ps_cephes_log_q1);
|
||||
y = _mm_add_ps(y, tmp);
|
||||
|
||||
|
||||
tmp = _mm_mul_ps(z, *(v4sf*)_ps_0p5);
|
||||
y = _mm_sub_ps(y, tmp);
|
||||
|
||||
tmp = _mm_mul_ps(e, *(v4sf*)_ps_cephes_log_q2);
|
||||
x = _mm_add_ps(x, y);
|
||||
x = _mm_add_ps(x, tmp);
|
||||
x = _mm_or_ps(x, invalid_mask); // negative arg will be NAN
|
||||
return x;
|
||||
}
|
||||
|
||||
_PS_CONST(exp_hi, 88.3762626647949f);
|
||||
_PS_CONST(exp_lo, -88.3762626647949f);
|
||||
|
||||
_PS_CONST(cephes_LOG2EF, 1.44269504088896341);
|
||||
_PS_CONST(cephes_exp_C1, 0.693359375);
|
||||
_PS_CONST(cephes_exp_C2, -2.12194440e-4);
|
||||
|
||||
_PS_CONST(cephes_exp_p0, 1.9875691500E-4);
|
||||
_PS_CONST(cephes_exp_p1, 1.3981999507E-3);
|
||||
_PS_CONST(cephes_exp_p2, 8.3334519073E-3);
|
||||
_PS_CONST(cephes_exp_p3, 4.1665795894E-2);
|
||||
_PS_CONST(cephes_exp_p4, 1.6666665459E-1);
|
||||
_PS_CONST(cephes_exp_p5, 5.0000001201E-1);
|
||||
|
||||
static inline v4sf exp_ps(v4sf x) {
|
||||
v4sf tmp = _mm_setzero_ps(), fx;
|
||||
#ifdef USE_SSE2
|
||||
v4si emm0;
|
||||
#else
|
||||
v2si mm0, mm1;
|
||||
#endif
|
||||
v4sf one = *(v4sf*)_ps_1;
|
||||
|
||||
x = _mm_min_ps(x, *(v4sf*)_ps_exp_hi);
|
||||
x = _mm_max_ps(x, *(v4sf*)_ps_exp_lo);
|
||||
|
||||
/* express exp(x) as exp(g + n*log(2)) */
|
||||
fx = _mm_mul_ps(x, *(v4sf*)_ps_cephes_LOG2EF);
|
||||
fx = _mm_add_ps(fx, *(v4sf*)_ps_0p5);
|
||||
|
||||
/* how to perform a floorf with SSE: just below */
|
||||
#ifndef USE_SSE2
|
||||
/* step 1 : cast to int */
|
||||
tmp = _mm_movehl_ps(tmp, fx);
|
||||
mm0 = _mm_cvttps_pi32(fx);
|
||||
mm1 = _mm_cvttps_pi32(tmp);
|
||||
/* step 2 : cast back to float */
|
||||
tmp = _mm_cvtpi32x2_ps(mm0, mm1);
|
||||
#else
|
||||
emm0 = _mm_cvttps_epi32(fx);
|
||||
tmp = _mm_cvtepi32_ps(emm0);
|
||||
#endif
|
||||
/* if greater, substract 1 */
|
||||
v4sf mask = _mm_cmpgt_ps(tmp, fx);
|
||||
mask = _mm_and_ps(mask, one);
|
||||
fx = _mm_sub_ps(tmp, mask);
|
||||
|
||||
tmp = _mm_mul_ps(fx, *(v4sf*)_ps_cephes_exp_C1);
|
||||
v4sf z = _mm_mul_ps(fx, *(v4sf*)_ps_cephes_exp_C2);
|
||||
x = _mm_sub_ps(x, tmp);
|
||||
x = _mm_sub_ps(x, z);
|
||||
|
||||
z = _mm_mul_ps(x,x);
|
||||
|
||||
v4sf y = *(v4sf*)_ps_cephes_exp_p0;
|
||||
y = _mm_mul_ps(y, x);
|
||||
y = _mm_add_ps(y, *(v4sf*)_ps_cephes_exp_p1);
|
||||
y = _mm_mul_ps(y, x);
|
||||
y = _mm_add_ps(y, *(v4sf*)_ps_cephes_exp_p2);
|
||||
y = _mm_mul_ps(y, x);
|
||||
y = _mm_add_ps(y, *(v4sf*)_ps_cephes_exp_p3);
|
||||
y = _mm_mul_ps(y, x);
|
||||
y = _mm_add_ps(y, *(v4sf*)_ps_cephes_exp_p4);
|
||||
y = _mm_mul_ps(y, x);
|
||||
y = _mm_add_ps(y, *(v4sf*)_ps_cephes_exp_p5);
|
||||
y = _mm_mul_ps(y, z);
|
||||
y = _mm_add_ps(y, x);
|
||||
y = _mm_add_ps(y, one);
|
||||
|
||||
/* build 2^n */
|
||||
#ifndef USE_SSE2
|
||||
z = _mm_movehl_ps(z, fx);
|
||||
mm0 = _mm_cvttps_pi32(fx);
|
||||
mm1 = _mm_cvttps_pi32(z);
|
||||
mm0 = _mm_add_pi32(mm0, *(v2si*)_pi32_0x7f);
|
||||
mm1 = _mm_add_pi32(mm1, *(v2si*)_pi32_0x7f);
|
||||
mm0 = _mm_slli_pi32(mm0, 23);
|
||||
mm1 = _mm_slli_pi32(mm1, 23);
|
||||
|
||||
v4sf pow2n;
|
||||
COPY_MM_TO_XMM(mm0, mm1, pow2n);
|
||||
_mm_empty();
|
||||
#else
|
||||
emm0 = _mm_cvttps_epi32(fx);
|
||||
emm0 = _mm_add_epi32(emm0, *(v4si*)_pi32_0x7f);
|
||||
emm0 = _mm_slli_epi32(emm0, 23);
|
||||
v4sf pow2n = _mm_castsi128_ps(emm0);
|
||||
#endif
|
||||
y = _mm_mul_ps(y, pow2n);
|
||||
return y;
|
||||
}
|
||||
|
||||
_PS_CONST(minus_cephes_DP1, -0.78515625);
|
||||
_PS_CONST(minus_cephes_DP2, -2.4187564849853515625e-4);
|
||||
_PS_CONST(minus_cephes_DP3, -3.77489497744594108e-8);
|
||||
_PS_CONST(sincof_p0, -1.9515295891E-4);
|
||||
_PS_CONST(sincof_p1, 8.3321608736E-3);
|
||||
_PS_CONST(sincof_p2, -1.6666654611E-1);
|
||||
_PS_CONST(coscof_p0, 2.443315711809948E-005);
|
||||
_PS_CONST(coscof_p1, -1.388731625493765E-003);
|
||||
_PS_CONST(coscof_p2, 4.166664568298827E-002);
|
||||
_PS_CONST(cephes_FOPI, 1.27323954473516); // 4 / M_PI
|
||||
|
||||
|
||||
/* evaluation of 4 sines at onces, using only SSE1+MMX intrinsics so
|
||||
it runs also on old athlons XPs and the pentium III of your grand
|
||||
mother.
|
||||
|
||||
The code is the exact rewriting of the cephes sinf function.
|
||||
Precision is excellent as long as x < 8192 (I did not bother to
|
||||
take into account the special handling they have for greater values
|
||||
-- it does not return garbage for arguments over 8192, though, but
|
||||
the extra precision is missing).
|
||||
|
||||
Note that it is such that sinf((float)M_PI) = 8.74e-8, which is the
|
||||
surprising but correct result.
|
||||
|
||||
Performance is also surprisingly good, 1.33 times faster than the
|
||||
macos vsinf SSE2 function, and 1.5 times faster than the
|
||||
__vrs4_sinf of amd's ACML (which is only available in 64 bits). Not
|
||||
too bad for an SSE1 function (with no special tuning) !
|
||||
However the latter libraries probably have a much better handling of NaN,
|
||||
Inf, denormalized and other special arguments..
|
||||
|
||||
On my core 1 duo, the execution of this function takes approximately 95 cycles.
|
||||
|
||||
From what I have observed on the experiments with Intel AMath lib, switching to an
|
||||
SSE2 version would improve the perf by only 10%.
|
||||
|
||||
Since it is based on SSE intrinsics, it has to be compiled at -O2 to
|
||||
deliver full speed.
|
||||
*/
|
||||
static inline v4sf sin_ps(v4sf x) { // any x
|
||||
v4sf xmm1, xmm2 = _mm_setzero_ps(), xmm3, sign_bit, y;
|
||||
|
||||
#ifdef USE_SSE2
|
||||
v4si emm0, emm2;
|
||||
#else
|
||||
v2si mm0, mm1, mm2, mm3;
|
||||
#endif
|
||||
sign_bit = x;
|
||||
/* take the absolute value */
|
||||
x = _mm_and_ps(x, *(v4sf*)_ps_inv_sign_mask);
|
||||
/* extract the sign bit (upper one) */
|
||||
sign_bit = _mm_and_ps(sign_bit, *(v4sf*)_ps_sign_mask);
|
||||
|
||||
/* scale by 4/Pi */
|
||||
y = _mm_mul_ps(x, *(v4sf*)_ps_cephes_FOPI);
|
||||
|
||||
#ifdef USE_SSE2
|
||||
/* store the integer part of y in mm0 */
|
||||
emm2 = _mm_cvttps_epi32(y);
|
||||
/* j=(j+1) & (~1) (see the cephes sources) */
|
||||
emm2 = _mm_add_epi32(emm2, *(v4si*)_pi32_1);
|
||||
emm2 = _mm_and_si128(emm2, *(v4si*)_pi32_inv1);
|
||||
y = _mm_cvtepi32_ps(emm2);
|
||||
|
||||
/* get the swap sign flag */
|
||||
emm0 = _mm_and_si128(emm2, *(v4si*)_pi32_4);
|
||||
emm0 = _mm_slli_epi32(emm0, 29);
|
||||
/* get the polynom selection mask
|
||||
there is one polynom for 0 <= x <= Pi/4
|
||||
and another one for Pi/4<x<=Pi/2
|
||||
|
||||
Both branches will be computed.
|
||||
*/
|
||||
emm2 = _mm_and_si128(emm2, *(v4si*)_pi32_2);
|
||||
emm2 = _mm_cmpeq_epi32(emm2, _mm_setzero_si128());
|
||||
|
||||
v4sf swap_sign_bit = _mm_castsi128_ps(emm0);
|
||||
v4sf poly_mask = _mm_castsi128_ps(emm2);
|
||||
sign_bit = _mm_xor_ps(sign_bit, swap_sign_bit);
|
||||
|
||||
#else
|
||||
/* store the integer part of y in mm0:mm1 */
|
||||
xmm2 = _mm_movehl_ps(xmm2, y);
|
||||
mm2 = _mm_cvttps_pi32(y);
|
||||
mm3 = _mm_cvttps_pi32(xmm2);
|
||||
/* j=(j+1) & (~1) (see the cephes sources) */
|
||||
mm2 = _mm_add_pi32(mm2, *(v2si*)_pi32_1);
|
||||
mm3 = _mm_add_pi32(mm3, *(v2si*)_pi32_1);
|
||||
mm2 = _mm_and_si64(mm2, *(v2si*)_pi32_inv1);
|
||||
mm3 = _mm_and_si64(mm3, *(v2si*)_pi32_inv1);
|
||||
y = _mm_cvtpi32x2_ps(mm2, mm3);
|
||||
/* get the swap sign flag */
|
||||
mm0 = _mm_and_si64(mm2, *(v2si*)_pi32_4);
|
||||
mm1 = _mm_and_si64(mm3, *(v2si*)_pi32_4);
|
||||
mm0 = _mm_slli_pi32(mm0, 29);
|
||||
mm1 = _mm_slli_pi32(mm1, 29);
|
||||
/* get the polynom selection mask */
|
||||
mm2 = _mm_and_si64(mm2, *(v2si*)_pi32_2);
|
||||
mm3 = _mm_and_si64(mm3, *(v2si*)_pi32_2);
|
||||
mm2 = _mm_cmpeq_pi32(mm2, _mm_setzero_si64());
|
||||
mm3 = _mm_cmpeq_pi32(mm3, _mm_setzero_si64());
|
||||
v4sf swap_sign_bit, poly_mask;
|
||||
COPY_MM_TO_XMM(mm0, mm1, swap_sign_bit);
|
||||
COPY_MM_TO_XMM(mm2, mm3, poly_mask);
|
||||
sign_bit = _mm_xor_ps(sign_bit, swap_sign_bit);
|
||||
_mm_empty(); /* good-bye mmx */
|
||||
#endif
|
||||
|
||||
/* The magic pass: "Extended precision modular arithmetic"
|
||||
x = ((x - y * DP1) - y * DP2) - y * DP3; */
|
||||
xmm1 = *(v4sf*)_ps_minus_cephes_DP1;
|
||||
xmm2 = *(v4sf*)_ps_minus_cephes_DP2;
|
||||
xmm3 = *(v4sf*)_ps_minus_cephes_DP3;
|
||||
xmm1 = _mm_mul_ps(y, xmm1);
|
||||
xmm2 = _mm_mul_ps(y, xmm2);
|
||||
xmm3 = _mm_mul_ps(y, xmm3);
|
||||
x = _mm_add_ps(x, xmm1);
|
||||
x = _mm_add_ps(x, xmm2);
|
||||
x = _mm_add_ps(x, xmm3);
|
||||
|
||||
/* Evaluate the first polynom (0 <= x <= Pi/4) */
|
||||
y = *(v4sf*)_ps_coscof_p0;
|
||||
v4sf z = _mm_mul_ps(x,x);
|
||||
|
||||
y = _mm_mul_ps(y, z);
|
||||
y = _mm_add_ps(y, *(v4sf*)_ps_coscof_p1);
|
||||
y = _mm_mul_ps(y, z);
|
||||
y = _mm_add_ps(y, *(v4sf*)_ps_coscof_p2);
|
||||
y = _mm_mul_ps(y, z);
|
||||
y = _mm_mul_ps(y, z);
|
||||
v4sf tmp = _mm_mul_ps(z, *(v4sf*)_ps_0p5);
|
||||
y = _mm_sub_ps(y, tmp);
|
||||
y = _mm_add_ps(y, *(v4sf*)_ps_1);
|
||||
|
||||
/* Evaluate the second polynom (Pi/4 <= x <= 0) */
|
||||
|
||||
v4sf y2 = *(v4sf*)_ps_sincof_p0;
|
||||
y2 = _mm_mul_ps(y2, z);
|
||||
y2 = _mm_add_ps(y2, *(v4sf*)_ps_sincof_p1);
|
||||
y2 = _mm_mul_ps(y2, z);
|
||||
y2 = _mm_add_ps(y2, *(v4sf*)_ps_sincof_p2);
|
||||
y2 = _mm_mul_ps(y2, z);
|
||||
y2 = _mm_mul_ps(y2, x);
|
||||
y2 = _mm_add_ps(y2, x);
|
||||
|
||||
/* select the correct result from the two polynoms */
|
||||
xmm3 = poly_mask;
|
||||
y2 = _mm_and_ps(xmm3, y2); //, xmm3);
|
||||
y = _mm_andnot_ps(xmm3, y);
|
||||
y = _mm_add_ps(y,y2);
|
||||
/* update the sign */
|
||||
y = _mm_xor_ps(y, sign_bit);
|
||||
return y;
|
||||
}
|
||||
|
||||
/* almost the same as sin_ps */
|
||||
static inline v4sf cos_ps(v4sf x) { // any x
|
||||
v4sf xmm1, xmm2 = _mm_setzero_ps(), xmm3, y;
|
||||
#ifdef USE_SSE2
|
||||
v4si emm0, emm2;
|
||||
#else
|
||||
v2si mm0, mm1, mm2, mm3;
|
||||
#endif
|
||||
/* take the absolute value */
|
||||
x = _mm_and_ps(x, *(v4sf*)_ps_inv_sign_mask);
|
||||
|
||||
/* scale by 4/Pi */
|
||||
y = _mm_mul_ps(x, *(v4sf*)_ps_cephes_FOPI);
|
||||
|
||||
#ifdef USE_SSE2
|
||||
/* store the integer part of y in mm0 */
|
||||
emm2 = _mm_cvttps_epi32(y);
|
||||
/* j=(j+1) & (~1) (see the cephes sources) */
|
||||
emm2 = _mm_add_epi32(emm2, *(v4si*)_pi32_1);
|
||||
emm2 = _mm_and_si128(emm2, *(v4si*)_pi32_inv1);
|
||||
y = _mm_cvtepi32_ps(emm2);
|
||||
|
||||
emm2 = _mm_sub_epi32(emm2, *(v4si*)_pi32_2);
|
||||
|
||||
/* get the swap sign flag */
|
||||
emm0 = _mm_andnot_si128(emm2, *(v4si*)_pi32_4);
|
||||
emm0 = _mm_slli_epi32(emm0, 29);
|
||||
/* get the polynom selection mask */
|
||||
emm2 = _mm_and_si128(emm2, *(v4si*)_pi32_2);
|
||||
emm2 = _mm_cmpeq_epi32(emm2, _mm_setzero_si128());
|
||||
|
||||
v4sf sign_bit = _mm_castsi128_ps(emm0);
|
||||
v4sf poly_mask = _mm_castsi128_ps(emm2);
|
||||
#else
|
||||
/* store the integer part of y in mm0:mm1 */
|
||||
xmm2 = _mm_movehl_ps(xmm2, y);
|
||||
mm2 = _mm_cvttps_pi32(y);
|
||||
mm3 = _mm_cvttps_pi32(xmm2);
|
||||
|
||||
/* j=(j+1) & (~1) (see the cephes sources) */
|
||||
mm2 = _mm_add_pi32(mm2, *(v2si*)_pi32_1);
|
||||
mm3 = _mm_add_pi32(mm3, *(v2si*)_pi32_1);
|
||||
mm2 = _mm_and_si64(mm2, *(v2si*)_pi32_inv1);
|
||||
mm3 = _mm_and_si64(mm3, *(v2si*)_pi32_inv1);
|
||||
|
||||
y = _mm_cvtpi32x2_ps(mm2, mm3);
|
||||
|
||||
|
||||
mm2 = _mm_sub_pi32(mm2, *(v2si*)_pi32_2);
|
||||
mm3 = _mm_sub_pi32(mm3, *(v2si*)_pi32_2);
|
||||
|
||||
/* get the swap sign flag in mm0:mm1 and the
|
||||
polynom selection mask in mm2:mm3 */
|
||||
|
||||
mm0 = _mm_andnot_si64(mm2, *(v2si*)_pi32_4);
|
||||
mm1 = _mm_andnot_si64(mm3, *(v2si*)_pi32_4);
|
||||
mm0 = _mm_slli_pi32(mm0, 29);
|
||||
mm1 = _mm_slli_pi32(mm1, 29);
|
||||
|
||||
mm2 = _mm_and_si64(mm2, *(v2si*)_pi32_2);
|
||||
mm3 = _mm_and_si64(mm3, *(v2si*)_pi32_2);
|
||||
|
||||
mm2 = _mm_cmpeq_pi32(mm2, _mm_setzero_si64());
|
||||
mm3 = _mm_cmpeq_pi32(mm3, _mm_setzero_si64());
|
||||
|
||||
v4sf sign_bit, poly_mask;
|
||||
COPY_MM_TO_XMM(mm0, mm1, sign_bit);
|
||||
COPY_MM_TO_XMM(mm2, mm3, poly_mask);
|
||||
_mm_empty(); /* good-bye mmx */
|
||||
#endif
|
||||
/* The magic pass: "Extended precision modular arithmetic"
|
||||
x = ((x - y * DP1) - y * DP2) - y * DP3; */
|
||||
xmm1 = *(v4sf*)_ps_minus_cephes_DP1;
|
||||
xmm2 = *(v4sf*)_ps_minus_cephes_DP2;
|
||||
xmm3 = *(v4sf*)_ps_minus_cephes_DP3;
|
||||
xmm1 = _mm_mul_ps(y, xmm1);
|
||||
xmm2 = _mm_mul_ps(y, xmm2);
|
||||
xmm3 = _mm_mul_ps(y, xmm3);
|
||||
x = _mm_add_ps(x, xmm1);
|
||||
x = _mm_add_ps(x, xmm2);
|
||||
x = _mm_add_ps(x, xmm3);
|
||||
|
||||
/* Evaluate the first polynom (0 <= x <= Pi/4) */
|
||||
y = *(v4sf*)_ps_coscof_p0;
|
||||
v4sf z = _mm_mul_ps(x,x);
|
||||
|
||||
y = _mm_mul_ps(y, z);
|
||||
y = _mm_add_ps(y, *(v4sf*)_ps_coscof_p1);
|
||||
y = _mm_mul_ps(y, z);
|
||||
y = _mm_add_ps(y, *(v4sf*)_ps_coscof_p2);
|
||||
y = _mm_mul_ps(y, z);
|
||||
y = _mm_mul_ps(y, z);
|
||||
v4sf tmp = _mm_mul_ps(z, *(v4sf*)_ps_0p5);
|
||||
y = _mm_sub_ps(y, tmp);
|
||||
y = _mm_add_ps(y, *(v4sf*)_ps_1);
|
||||
|
||||
/* Evaluate the second polynom (Pi/4 <= x <= 0) */
|
||||
|
||||
v4sf y2 = *(v4sf*)_ps_sincof_p0;
|
||||
y2 = _mm_mul_ps(y2, z);
|
||||
y2 = _mm_add_ps(y2, *(v4sf*)_ps_sincof_p1);
|
||||
y2 = _mm_mul_ps(y2, z);
|
||||
y2 = _mm_add_ps(y2, *(v4sf*)_ps_sincof_p2);
|
||||
y2 = _mm_mul_ps(y2, z);
|
||||
y2 = _mm_mul_ps(y2, x);
|
||||
y2 = _mm_add_ps(y2, x);
|
||||
|
||||
/* select the correct result from the two polynoms */
|
||||
xmm3 = poly_mask;
|
||||
y2 = _mm_and_ps(xmm3, y2); //, xmm3);
|
||||
y = _mm_andnot_ps(xmm3, y);
|
||||
y = _mm_add_ps(y,y2);
|
||||
/* update the sign */
|
||||
y = _mm_xor_ps(y, sign_bit);
|
||||
|
||||
return y;
|
||||
}
|
||||
|
||||
/* since sin_ps and cos_ps are almost identical, sincos_ps could replace both of them..
|
||||
it is almost as fast, and gives you a free cosine with your sine */
|
||||
static inline void sincos_ps(v4sf x, v4sf *s, v4sf *c) {
|
||||
v4sf xmm1, xmm2, xmm3 = _mm_setzero_ps(), sign_bit_sin, y;
|
||||
#ifdef USE_SSE2
|
||||
v4si emm0, emm2, emm4;
|
||||
#else
|
||||
v2si mm0, mm1, mm2, mm3, mm4, mm5;
|
||||
#endif
|
||||
sign_bit_sin = x;
|
||||
/* take the absolute value */
|
||||
x = _mm_and_ps(x, *(v4sf*)_ps_inv_sign_mask);
|
||||
/* extract the sign bit (upper one) */
|
||||
sign_bit_sin = _mm_and_ps(sign_bit_sin, *(v4sf*)_ps_sign_mask);
|
||||
|
||||
/* scale by 4/Pi */
|
||||
y = _mm_mul_ps(x, *(v4sf*)_ps_cephes_FOPI);
|
||||
|
||||
#ifdef USE_SSE2
|
||||
/* store the integer part of y in emm2 */
|
||||
emm2 = _mm_cvttps_epi32(y);
|
||||
|
||||
/* j=(j+1) & (~1) (see the cephes sources) */
|
||||
emm2 = _mm_add_epi32(emm2, *(v4si*)_pi32_1);
|
||||
emm2 = _mm_and_si128(emm2, *(v4si*)_pi32_inv1);
|
||||
y = _mm_cvtepi32_ps(emm2);
|
||||
|
||||
emm4 = emm2;
|
||||
|
||||
/* get the swap sign flag for the sine */
|
||||
emm0 = _mm_and_si128(emm2, *(v4si*)_pi32_4);
|
||||
emm0 = _mm_slli_epi32(emm0, 29);
|
||||
v4sf swap_sign_bit_sin = _mm_castsi128_ps(emm0);
|
||||
|
||||
/* get the polynom selection mask for the sine*/
|
||||
emm2 = _mm_and_si128(emm2, *(v4si*)_pi32_2);
|
||||
emm2 = _mm_cmpeq_epi32(emm2, _mm_setzero_si128());
|
||||
v4sf poly_mask = _mm_castsi128_ps(emm2);
|
||||
#else
|
||||
/* store the integer part of y in mm2:mm3 */
|
||||
xmm3 = _mm_movehl_ps(xmm3, y);
|
||||
mm2 = _mm_cvttps_pi32(y);
|
||||
mm3 = _mm_cvttps_pi32(xmm3);
|
||||
|
||||
/* j=(j+1) & (~1) (see the cephes sources) */
|
||||
mm2 = _mm_add_pi32(mm2, *(v2si*)_pi32_1);
|
||||
mm3 = _mm_add_pi32(mm3, *(v2si*)_pi32_1);
|
||||
mm2 = _mm_and_si64(mm2, *(v2si*)_pi32_inv1);
|
||||
mm3 = _mm_and_si64(mm3, *(v2si*)_pi32_inv1);
|
||||
|
||||
y = _mm_cvtpi32x2_ps(mm2, mm3);
|
||||
|
||||
mm4 = mm2;
|
||||
mm5 = mm3;
|
||||
|
||||
/* get the swap sign flag for the sine */
|
||||
mm0 = _mm_and_si64(mm2, *(v2si*)_pi32_4);
|
||||
mm1 = _mm_and_si64(mm3, *(v2si*)_pi32_4);
|
||||
mm0 = _mm_slli_pi32(mm0, 29);
|
||||
mm1 = _mm_slli_pi32(mm1, 29);
|
||||
v4sf swap_sign_bit_sin;
|
||||
COPY_MM_TO_XMM(mm0, mm1, swap_sign_bit_sin);
|
||||
|
||||
/* get the polynom selection mask for the sine */
|
||||
|
||||
mm2 = _mm_and_si64(mm2, *(v2si*)_pi32_2);
|
||||
mm3 = _mm_and_si64(mm3, *(v2si*)_pi32_2);
|
||||
mm2 = _mm_cmpeq_pi32(mm2, _mm_setzero_si64());
|
||||
mm3 = _mm_cmpeq_pi32(mm3, _mm_setzero_si64());
|
||||
v4sf poly_mask;
|
||||
COPY_MM_TO_XMM(mm2, mm3, poly_mask);
|
||||
#endif
|
||||
|
||||
/* The magic pass: "Extended precision modular arithmetic"
|
||||
x = ((x - y * DP1) - y * DP2) - y * DP3; */
|
||||
xmm1 = *(v4sf*)_ps_minus_cephes_DP1;
|
||||
xmm2 = *(v4sf*)_ps_minus_cephes_DP2;
|
||||
xmm3 = *(v4sf*)_ps_minus_cephes_DP3;
|
||||
xmm1 = _mm_mul_ps(y, xmm1);
|
||||
xmm2 = _mm_mul_ps(y, xmm2);
|
||||
xmm3 = _mm_mul_ps(y, xmm3);
|
||||
x = _mm_add_ps(x, xmm1);
|
||||
x = _mm_add_ps(x, xmm2);
|
||||
x = _mm_add_ps(x, xmm3);
|
||||
|
||||
#ifdef USE_SSE2
|
||||
emm4 = _mm_sub_epi32(emm4, *(v4si*)_pi32_2);
|
||||
emm4 = _mm_andnot_si128(emm4, *(v4si*)_pi32_4);
|
||||
emm4 = _mm_slli_epi32(emm4, 29);
|
||||
v4sf sign_bit_cos = _mm_castsi128_ps(emm4);
|
||||
#else
|
||||
/* get the sign flag for the cosine */
|
||||
mm4 = _mm_sub_pi32(mm4, *(v2si*)_pi32_2);
|
||||
mm5 = _mm_sub_pi32(mm5, *(v2si*)_pi32_2);
|
||||
mm4 = _mm_andnot_si64(mm4, *(v2si*)_pi32_4);
|
||||
mm5 = _mm_andnot_si64(mm5, *(v2si*)_pi32_4);
|
||||
mm4 = _mm_slli_pi32(mm4, 29);
|
||||
mm5 = _mm_slli_pi32(mm5, 29);
|
||||
v4sf sign_bit_cos;
|
||||
COPY_MM_TO_XMM(mm4, mm5, sign_bit_cos);
|
||||
_mm_empty(); /* good-bye mmx */
|
||||
#endif
|
||||
|
||||
sign_bit_sin = _mm_xor_ps(sign_bit_sin, swap_sign_bit_sin);
|
||||
|
||||
|
||||
/* Evaluate the first polynom (0 <= x <= Pi/4) */
|
||||
v4sf z = _mm_mul_ps(x,x);
|
||||
y = *(v4sf*)_ps_coscof_p0;
|
||||
|
||||
y = _mm_mul_ps(y, z);
|
||||
y = _mm_add_ps(y, *(v4sf*)_ps_coscof_p1);
|
||||
y = _mm_mul_ps(y, z);
|
||||
y = _mm_add_ps(y, *(v4sf*)_ps_coscof_p2);
|
||||
y = _mm_mul_ps(y, z);
|
||||
y = _mm_mul_ps(y, z);
|
||||
v4sf tmp = _mm_mul_ps(z, *(v4sf*)_ps_0p5);
|
||||
y = _mm_sub_ps(y, tmp);
|
||||
y = _mm_add_ps(y, *(v4sf*)_ps_1);
|
||||
|
||||
/* Evaluate the second polynom (Pi/4 <= x <= 0) */
|
||||
|
||||
v4sf y2 = *(v4sf*)_ps_sincof_p0;
|
||||
y2 = _mm_mul_ps(y2, z);
|
||||
y2 = _mm_add_ps(y2, *(v4sf*)_ps_sincof_p1);
|
||||
y2 = _mm_mul_ps(y2, z);
|
||||
y2 = _mm_add_ps(y2, *(v4sf*)_ps_sincof_p2);
|
||||
y2 = _mm_mul_ps(y2, z);
|
||||
y2 = _mm_mul_ps(y2, x);
|
||||
y2 = _mm_add_ps(y2, x);
|
||||
|
||||
/* select the correct result from the two polynoms */
|
||||
xmm3 = poly_mask;
|
||||
v4sf ysin2 = _mm_and_ps(xmm3, y2);
|
||||
v4sf ysin1 = _mm_andnot_ps(xmm3, y);
|
||||
y2 = _mm_sub_ps(y2,ysin2);
|
||||
y = _mm_sub_ps(y, ysin1);
|
||||
|
||||
xmm1 = _mm_add_ps(ysin1,ysin2);
|
||||
xmm2 = _mm_add_ps(y,y2);
|
||||
|
||||
/* update the sign */
|
||||
*s = _mm_xor_ps(xmm1, sign_bit_sin);
|
||||
*c = _mm_xor_ps(xmm2, sign_bit_cos);
|
||||
}
|
||||
|
|
@ -107,6 +107,7 @@ class ThreadPool {
|
|||
// the constructor just launches some amount of workers
|
||||
inline ThreadPool::ThreadPool(size_t threads, size_t in_bound)
|
||||
: bound(in_bound), stop(false) {
|
||||
ABORT_IF(getThrowExceptionOnAbort(), "Throwing of MarianRuntimeException not presently supported in threads");
|
||||
reserve(threads);
|
||||
}
|
||||
|
||||
|
|
|
@ -27,7 +27,7 @@ static std::string strerror()
|
|||
{
|
||||
buff = "Unknown error";
|
||||
}
|
||||
#elif (_POSIX_C_SOURCE >= 200112L || _XOPEN_SOURCE >= 600) && ! _GNU_SOURCE
|
||||
#elif (_POSIX_C_SOURCE >= 200112L || _XOPEN_SOURCE >= 600 || __APPLE__) && ! _GNU_SOURCE
|
||||
// XSI-compliant strerror_r()
|
||||
if (strerror_r(errno, &buff[0], buff.size()) != 0)
|
||||
{
|
||||
|
@ -125,7 +125,7 @@ struct static_method_holder
|
|||
is_p->peek();
|
||||
peek_failed = is_p->fail();
|
||||
}
|
||||
catch (std::ios_base::failure e) {}
|
||||
catch (const std::ios_base::failure &e) {}
|
||||
if (peek_failed)
|
||||
{
|
||||
throw Exception(std::string("strict_fstream: open('")
|
||||
|
|
|
@ -4,8 +4,12 @@ include_directories(.)
|
|||
include_directories(3rd_party)
|
||||
include_directories(3rd_party/SQLiteCpp/include)
|
||||
include_directories(3rd_party/sentencepiece)
|
||||
include_directories(3rd_party/fbgemm/include)
|
||||
include_directories(${CMAKE_BINARY_DIR}/local/include)
|
||||
|
||||
add_library(marian STATIC
|
||||
common/aliases.cpp
|
||||
common/fastopt.cpp
|
||||
common/version.cpp
|
||||
common/utils.cpp
|
||||
common/logging.cpp
|
||||
|
@ -14,13 +18,19 @@ add_library(marian STATIC
|
|||
common/config.cpp
|
||||
common/config_parser.cpp
|
||||
common/config_validator.cpp
|
||||
common/options.cpp
|
||||
common/binary.cpp
|
||||
common/build_info.cpp
|
||||
common/io.cpp
|
||||
common/filesystem.cpp
|
||||
common/file_stream.cpp
|
||||
common/types.cpp
|
||||
|
||||
data/alignment.cpp
|
||||
data/vocab.cpp
|
||||
data/default_vocab.cpp
|
||||
data/sentencepiece_vocab.cpp
|
||||
data/factored_vocab.cpp
|
||||
data/corpus_base.cpp
|
||||
data/corpus.cpp
|
||||
data/corpus_sqlite.cpp
|
||||
|
@ -30,8 +40,11 @@ add_library(marian STATIC
|
|||
3rd_party/cnpy/cnpy.cpp
|
||||
3rd_party/ExceptionWithCallStack.cpp
|
||||
|
||||
3rd_party/phf/phf.cc
|
||||
|
||||
tensors/backend.cpp
|
||||
tensors/rand.cpp
|
||||
tensors/tensor.cpp
|
||||
tensors/cpu/device.cpp
|
||||
tensors/cpu/prod.cpp
|
||||
tensors/cpu/tensor_operators.cpp
|
||||
|
@ -39,6 +52,7 @@ add_library(marian STATIC
|
|||
tensors/cpu/sharp/int_gemm.cpp
|
||||
tensors/cpu/sharp/avx_gemm.cpp
|
||||
tensors/cpu/sharp/sse_gemm.cpp
|
||||
tensors/cpu/fbgemm/packed_gemm.cpp
|
||||
|
||||
graph/expression_graph.cpp
|
||||
graph/expression_operators.cpp
|
||||
|
@ -47,6 +61,7 @@ add_library(marian STATIC
|
|||
graph/node_initializers.cpp
|
||||
|
||||
layers/convolution.cpp
|
||||
layers/generic.cpp
|
||||
layers/loss.cpp
|
||||
layers/weight.cpp
|
||||
|
||||
|
@ -77,6 +92,7 @@ add_library(marian STATIC
|
|||
training/graph_group_multinode_sync.cpp
|
||||
training/validator.cpp
|
||||
training/communicator.cpp
|
||||
training/scheduler.cpp
|
||||
|
||||
# this is only compiled to catch build errors, but not linked
|
||||
microsoft/quicksand.cpp
|
||||
|
@ -91,21 +107,40 @@ target_compile_options(marian PUBLIC ${ALL_WARNINGS})
|
|||
# Generate git_revision.h to reflect current git revision information
|
||||
# [https://stackoverflow.com/questions/1435953/how-can-i-pass-git-sha1-to-compiler-as-definition-using-cmake]
|
||||
# Git updates .git/logs/HEAD file whenever you pull or commit something.
|
||||
|
||||
# If Marian is checked out as a submodule in another repository,
|
||||
# there's no .git directory in ${CMAKE_SOURCE_DIR}. Instead .git is a
|
||||
# file that specifies the relative path from ${CMAKE_SOURCE_DIR} to
|
||||
# ./git/modules/<MARIAN_ROOT_DIR> in the root of the repository that
|
||||
# contains Marian as a submodule. We set MARIAN_GIT_DIR to the appropriate
|
||||
# path, depending on whether ${CMAKE_SOURCE_DIR}/.git is a directory or file.
|
||||
if(IS_DIRECTORY ${CMAKE_SOURCE_DIR}/.git) # not a submodule
|
||||
set(MARIAN_GIT_DIR ${CMAKE_SOURCE_DIR}/.git)
|
||||
else(IS_DIRECTORY ${CMAKE_SOURCE_DIR}/.git)
|
||||
file(READ ${CMAKE_SOURCE_DIR}/.git MARIAN_GIT_DIR)
|
||||
string(REGEX REPLACE "gitdir: (.*)\n" "\\1" MARIAN_GIT_DIR ${MARIAN_GIT_DIR})
|
||||
get_filename_component(MARIAN_GIT_DIR "${CMAKE_SOURCE_DIR}/${MARIAN_GIT_DIR}" ABSOLUTE)
|
||||
endif(IS_DIRECTORY ${CMAKE_SOURCE_DIR}/.git)
|
||||
|
||||
add_custom_command(OUTPUT ${CMAKE_CURRENT_SOURCE_DIR}/common/git_revision.h
|
||||
WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}
|
||||
COMMAND git log -1 --pretty=format:\#define\ GIT_REVISION\ \"\%h\ \%ai\" > ${CMAKE_CURRENT_SOURCE_DIR}/common/git_revision.h
|
||||
DEPENDS ${CMAKE_SOURCE_DIR}/.git/logs/HEAD
|
||||
DEPENDS ${MARIAN_GIT_DIR}/logs/HEAD
|
||||
VERBATIM
|
||||
)
|
||||
add_custom_target(marian_version DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/common/git_revision.h)
|
||||
add_dependencies(marian marian_version) # marian must depend on it so that it gets created first
|
||||
# make sure all local dependencies are installed first before this is built
|
||||
add_dependencies(marian 3rd_party_installs)
|
||||
|
||||
if(CUDA_FOUND)
|
||||
cuda_add_library(marian_cuda
|
||||
tensors/gpu/device.cu
|
||||
tensors/gpu/algorithm.cu
|
||||
tensors/gpu/prod.cu
|
||||
tensors/gpu/prod.cpp
|
||||
tensors/gpu/element.cu
|
||||
tensors/gpu/add.cu
|
||||
tensors/gpu/add_all.cu
|
||||
tensors/gpu/tensor_operators.cu
|
||||
tensors/gpu/cudnn_wrappers.cu
|
||||
translator/nth_element.cu
|
||||
|
@ -115,6 +150,8 @@ cuda_add_library(marian_cuda
|
|||
STATIC)
|
||||
|
||||
target_compile_options(marian_cuda PUBLIC ${ALL_WARNINGS})
|
||||
# make sure all local dependencies are installed first before this is built
|
||||
add_dependencies(marian_cuda 3rd_party_installs)
|
||||
endif(CUDA_FOUND)
|
||||
|
||||
set_target_properties(marian PROPERTIES LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}")
|
||||
|
@ -179,6 +216,10 @@ if(COMPILE_SERVER)
|
|||
set(EXECUTABLES ${EXECUTABLES} marian_server)
|
||||
endif(COMPILE_SERVER)
|
||||
|
||||
if(APPLE) # This is a dependency of pathie but I can't seem to link it into that CMakeLists because we're not compiling it as a library.
|
||||
set(EXT_LIBS ${EXT_LIBS} iconv)
|
||||
endif()
|
||||
|
||||
foreach(exec ${EXECUTABLES})
|
||||
target_link_libraries(${exec} marian ${EXT_LIBS} ${EXT_LIBS} ${CMAKE_THREAD_LIBS_INIT})
|
||||
if(CUDA_FOUND)
|
||||
|
@ -187,13 +228,6 @@ foreach(exec ${EXECUTABLES})
|
|||
set_target_properties(${exec} PROPERTIES RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}")
|
||||
endforeach(exec)
|
||||
|
||||
#add_executable(
|
||||
# align2steps
|
||||
# tools/align2steps.cpp
|
||||
#)
|
||||
|
||||
#set_target_properties(align2steps PROPERTIES RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}")
|
||||
|
||||
if(COMPILE_TESTS)
|
||||
set(CATCH_INCLUDE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/3rd_party)
|
||||
add_library(Catch INTERFACE)
|
||||
|
|
|
@ -4,6 +4,8 @@
|
|||
|
||||
#include <sstream>
|
||||
|
||||
#include "tensors/cpu/fbgemm/expression_graph_packable.h"
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
using namespace marian;
|
||||
|
||||
|
@ -11,19 +13,36 @@ int main(int argc, char** argv) {
|
|||
|
||||
auto options = New<Options>();
|
||||
{
|
||||
YAML::Node config; // @TODO: get rid of YAML::Node here entirely to avoid the pattern. Currently not fixing as it requires more changes to the Options object.
|
||||
auto cli = New<cli::CLIWrapper>(
|
||||
options,
|
||||
"Convert a model in the .npz format to a mmap-able binary model",
|
||||
config,
|
||||
"Convert a model in the .npz format and normal memory layout to a mmap-able binary model which could be in normal memory layout or packed memory layout",
|
||||
"Allowed options",
|
||||
"Examples:\n"
|
||||
" ./marian-conv -f model.npz -t model.bin");
|
||||
" ./marian-conv -f model.npz -t model.bin --gemm-type packed16");
|
||||
cli->add<std::string>("--from,-f", "Input model", "model.npz");
|
||||
cli->add<std::string>("--to,-t", "Output model", "model.bin");
|
||||
cli->add<std::string>("--gemm-type,-g", "GEMM Type to be used: float32, packed16, packed8avx2, packed8avx512", "float32");
|
||||
cli->parse(argc, argv);
|
||||
options->merge(config);
|
||||
}
|
||||
auto modelFrom = options->get<std::string>("from");
|
||||
auto modelTo = options->get<std::string>("to");
|
||||
|
||||
auto saveGemmTypeStr = options->get<std::string>("gemm-type", "float32");
|
||||
Type saveGemmType;
|
||||
if(saveGemmTypeStr == "float32") {
|
||||
saveGemmType = Type::float32;
|
||||
} else if(saveGemmTypeStr == "packed16") { // packed16 only supports AVX2. AVX512 might be added later
|
||||
saveGemmType = Type::packed16;
|
||||
} else if(saveGemmTypeStr == "packed8avx2") { // packed8 for AVX2
|
||||
saveGemmType = Type::packed8avx2;
|
||||
} else if(saveGemmTypeStr == "packed8avx512") { // packed8 for AVX512
|
||||
saveGemmType = Type::packed8avx512;
|
||||
} else {
|
||||
ABORT("Unknown gemm-type: {}", saveGemmTypeStr);
|
||||
}
|
||||
|
||||
LOG(info, "Outputting {}", modelTo);
|
||||
|
||||
YAML::Node config;
|
||||
|
@ -31,12 +50,14 @@ int main(int argc, char** argv) {
|
|||
marian::io::getYamlFromModel(config, "special:model.yml", modelFrom);
|
||||
configStr << config;
|
||||
|
||||
auto graph = New<ExpressionGraph>(true, false);
|
||||
auto graph = New<ExpressionGraphPackable>();
|
||||
graph->setDevice(CPU0);
|
||||
graph->getBackend()->setOptimized(false);
|
||||
|
||||
graph->load(modelFrom);
|
||||
graph->forward();
|
||||
graph->save(modelTo, configStr.str());
|
||||
// added a flag if the weights needs to be packed or not
|
||||
graph->packAndSave(modelTo, configStr.str(), /* --gemm-type */ saveGemmType, Type::float32);
|
||||
|
||||
// graph->saveBinary(vm["bin"].as<std::string>());
|
||||
|
||||
|
|
|
@ -8,7 +8,6 @@
|
|||
|
||||
int main(int argc, char** argv) {
|
||||
using namespace marian;
|
||||
|
||||
auto options = parseOptions(argc, argv, cli::mode::translation);
|
||||
auto task = New<Translate<BeamSearch>>(options);
|
||||
|
||||
|
|
|
@ -12,7 +12,7 @@ int main(int argc, char **argv) {
|
|||
using namespace marian;
|
||||
|
||||
// Initialize translation task
|
||||
auto options = parseOptions(argc, argv, cli::mode::translation, true);
|
||||
auto options = parseOptions(argc, argv, cli::mode::server, true);
|
||||
auto task = New<TranslateService<BeamSearch>>(options);
|
||||
|
||||
// Initialize web server
|
||||
|
@ -44,7 +44,7 @@ int main(int argc, char **argv) {
|
|||
|
||||
// Error Codes for error code meanings
|
||||
// http://www.boost.org/doc/libs/1_55_0/doc/html/boost_asio/reference.html
|
||||
translate.on_error = [](Ptr<WSServer::Connection> connection,
|
||||
translate.on_error = [](Ptr<WSServer::Connection> /*connection*/,
|
||||
const SimpleWeb::error_code &ec) {
|
||||
LOG(error, "Connection error: ({}) {}", ec.value(), ec.message());
|
||||
};
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
#include <signal.h>
|
||||
#include "marian.h"
|
||||
|
||||
#include "training/graph_group_async.h"
|
||||
|
@ -11,10 +12,12 @@
|
|||
#include "training/graph_group_multinode.h"
|
||||
#endif
|
||||
|
||||
#include "3rd_party/ExceptionWithCallStack.h"
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
using namespace marian;
|
||||
|
||||
auto options = parseOptions(argc, argv);
|
||||
auto options = parseOptions(argc, argv, cli::mode::training);
|
||||
|
||||
// selects MultiNodeGraphGroup family
|
||||
//
|
||||
|
@ -66,5 +69,13 @@ int main(int argc, char** argv) {
|
|||
}
|
||||
}
|
||||
|
||||
return 0;
|
||||
// If we exit due to SIGTERM, exit with 128 + the signal number, as suggested
|
||||
// for bash in http://tldp.org/LDP/abs/html/exitcodes.html. This allows parent
|
||||
// scripts to determine if training terminated naturally or via SIGTERM.
|
||||
// Whith this approach we can accommodate additional signals in the future.
|
||||
// An alternative would be to return 124, which is what the timeout command
|
||||
// returns for timeout -s SIGTERM <seconds> ...., because exiting after SIGTERM
|
||||
// is not technically a fatal error (which is what the 128+x convention usually
|
||||
// stands for).
|
||||
return getSigtermFlag() ? (128 + SIGTERM) : 0;
|
||||
}
|
||||
|
|
|
@ -9,10 +9,11 @@ int main(int argc, char** argv) {
|
|||
|
||||
createLoggers();
|
||||
|
||||
auto options = New<Options>();
|
||||
Ptr<Options> options = New<Options>();
|
||||
{
|
||||
YAML::Node config; // @TODO: get rid of YAML::Node here entirely to avoid the pattern. Currently not fixing as it requires more changes to the Options object.
|
||||
auto cli = New<cli::CLIWrapper>(
|
||||
options,
|
||||
config,
|
||||
"Create a vocabulary from text corpora given on STDIN",
|
||||
"Allowed options",
|
||||
"Examples:\n"
|
||||
|
@ -20,6 +21,7 @@ int main(int argc, char** argv) {
|
|||
" cat text.src text.trg | ./marian-vocab > vocab.yml");
|
||||
cli->add<size_t>("--max-size,-m", "Generate only UINT most common vocabulary items", 0);
|
||||
cli->parse(argc, argv);
|
||||
options->merge(config);
|
||||
}
|
||||
|
||||
LOG(info, "Creating vocabulary...");
|
||||
|
|
|
@ -0,0 +1,150 @@
|
|||
#include "common/config_parser.h"
|
||||
#include "common/definitions.h"
|
||||
|
||||
namespace marian {
|
||||
|
||||
/**
|
||||
* Add all aliases
|
||||
*
|
||||
* An alias is a command line option name and value pair that sets multiple other non-alias
|
||||
* (standard) command line options. And example would be `--task transformer-big` which --
|
||||
* as a whole -- is an alias for setting options and hyperparameters that would be reasonable
|
||||
* for training a Google-style Transformer-Big model. Below key-value pairs
|
||||
* ("task", "transformer-base") and ("task", "transformer-big") are different aliases that result
|
||||
* in different option sets to be defined.
|
||||
*
|
||||
* The alias option has to be first defined using cli.add<T>(). Defining
|
||||
* multiple aliases for the same option name but with different values is allowed.
|
||||
*
|
||||
* As aliases are key-value pairs by default, values are compared as std::string.
|
||||
* If the command line option corresponding to the alias is a vector, the alias
|
||||
* will be triggered if the requested value exists in that vector at least once.
|
||||
*
|
||||
* @see CLIWrapper::alias()
|
||||
*
|
||||
* The order of alias definitions *does* matter: options from later aliases override earlier
|
||||
* regardless of its order in the command line or config file.
|
||||
*/
|
||||
void ConfigParser::addAliases(cli::CLIWrapper& cli) {
|
||||
cli.alias("fp16", "true", [&](YAML::Node& config) {
|
||||
if(mode_ == cli::mode::training) {
|
||||
config["precision"] = std::vector<std::string>({"float16", "float32", "float32"}); // inference type, optimization type, save type
|
||||
// @TODO: review this
|
||||
// scaling factor (power of 2), frequency, multiplier at increase, tolerance, range, minium factor
|
||||
config["cost-scaling"] = std::vector<std::string>({"7", "2000", "2", "0.05", "10", "1"});
|
||||
} else {
|
||||
config["precision"] = std::vector<std::string>({"float16"}); // for inference we do not need the other types
|
||||
}
|
||||
});
|
||||
|
||||
if(mode_ == cli::mode::training) {
|
||||
// for backwards-compatibility with older version, "--no-shuffle" maps to "--shuffle none"
|
||||
cli.alias("no-shuffle", "true", [](YAML::Node& config) {
|
||||
config["shuffle"] = "none";
|
||||
});
|
||||
|
||||
// Options setting the BiDeep architecture proposed in http://www.aclweb.org/anthology/W17-4710
|
||||
cli.alias("best-deep", "true", [](YAML::Node& config) {
|
||||
config["layer-normalization"] = true;
|
||||
config["tied-embeddings"] = true;
|
||||
config["enc-type"] = "alternating";
|
||||
config["enc-cell-depth"] = 2;
|
||||
config["enc-depth"] = 4;
|
||||
config["dec-cell-base-depth"] = 4;
|
||||
config["dec-cell-high-depth"] = 2;
|
||||
config["dec-depth"] = 4;
|
||||
config["skip"] = true;
|
||||
|
||||
// Training specific options
|
||||
config["learn-rate"] = 0.0003;
|
||||
config["cost-type"] = "ce-mean-words";
|
||||
config["lr-decay-inv-sqrt"] = 16000;
|
||||
config["label-smoothing"] = 0.1;
|
||||
config["clip-norm"] = 0;
|
||||
config["sync-sgd"] = true;
|
||||
config["exponential-smoothing"] = 1e-4;
|
||||
config["mini-batch-fit"] = true;
|
||||
config["mini-batch"] = 1000;
|
||||
config["maxi-batch"] = 1000;
|
||||
// config["workspace"] = 6500;
|
||||
});
|
||||
|
||||
// Architecture and proposed training settings for a Transformer "base" model introduced in
|
||||
// https://papers.nips.cc/paper/7181-attention-is-all-you-need.pdf
|
||||
cli.alias("task", "transformer-base", [](YAML::Node& config) {
|
||||
// Model options
|
||||
config["type"] = "transformer";
|
||||
config["enc-depth"] = 6;
|
||||
config["dec-depth"] = 6;
|
||||
config["dim-emb"] = 512;
|
||||
config["tied-embeddings-all"] = true;
|
||||
config["transformer-dim-ffn"] = 2048;
|
||||
config["transformer-heads"] = 8;
|
||||
config["transformer-postprocess"] = "dan";
|
||||
config["transformer-preprocess"] = "";
|
||||
config["transformer-ffn-activation"] = "relu";
|
||||
config["transformer-dropout"] = 0.1;
|
||||
|
||||
// Training specific options
|
||||
config["learn-rate"] = 0.0003;
|
||||
config["cost-type"] = "ce-mean-words";
|
||||
config["lr-warmup"] = 16000;
|
||||
config["lr-decay-inv-sqrt"] = 16000;
|
||||
config["label-smoothing"] = 0.1;
|
||||
config["clip-norm"] = 0;
|
||||
config["sync-sgd"] = true;
|
||||
config["exponential-smoothing"] = 1e-4;
|
||||
config["max-length"] = 100;
|
||||
config["mini-batch-fit"] = true;
|
||||
config["mini-batch"] = 1000;
|
||||
config["maxi-batch"] = 1000;
|
||||
config["workspace"] = 9500;
|
||||
config["optimizer-params"] = std::vector<float>({0.9f, 0.98f, 1e-09f});
|
||||
|
||||
// Validation specific options
|
||||
config["beam-size"] = 8;
|
||||
config["valid-mini-batch"] = 16;
|
||||
config["normalize"] = 1.0;
|
||||
});
|
||||
|
||||
// Architecture and proposed training settings for a Transformer "big" model introduced in
|
||||
// https://papers.nips.cc/paper/7181-attention-is-all-you-need.pdf
|
||||
cli.alias("task", "transformer-big", [](YAML::Node& config) {
|
||||
// Model options
|
||||
config["type"] = "transformer";
|
||||
config["enc-depth"] = 6;
|
||||
config["dec-depth"] = 6;
|
||||
config["dim-emb"] = 1024;
|
||||
config["tied-embeddings-all"] = true;
|
||||
config["transformer-dim-ffn"] = 4096;
|
||||
config["transformer-heads"] = 16;
|
||||
config["transformer-postprocess"] = "dan";
|
||||
config["transformer-preprocess"] = "";
|
||||
config["transformer-ffn-activation"] = "relu";
|
||||
config["transformer-dropout"] = 0.1;
|
||||
|
||||
// Training specific options
|
||||
config["learn-rate"] = 0.0002;
|
||||
config["cost-type"] = "ce-mean-words";
|
||||
config["lr-warmup"] = 8000;
|
||||
config["lr-decay-inv-sqrt"] = 8000;
|
||||
config["label-smoothing"] = 0.1;
|
||||
config["clip-norm"] = 0;
|
||||
config["sync-sgd"] = true;
|
||||
config["exponential-smoothing"] = 1e-4;
|
||||
config["max-length"] = 100;
|
||||
config["mini-batch-fit"] = true;
|
||||
config["mini-batch"] = 1000;
|
||||
config["maxi-batch"] = 1000;
|
||||
config["workspace"] = 13000;
|
||||
config["optimizer-params"] = std::vector<float>({0.9f, 0.998f, 1e-09f});
|
||||
|
||||
// Validation specific options
|
||||
config["beam-size"] = 8;
|
||||
config["valid-mini-batch"] = 8;
|
||||
config["normalize"] = 1.0;
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace marian
|
|
@ -0,0 +1,65 @@
|
|||
#pragma once
|
||||
|
||||
#include <string>
|
||||
|
||||
namespace marian {
|
||||
|
||||
std::string citation() {
|
||||
return "Marian: Fast Neural Machine Translation in C++\n"
|
||||
"\n"
|
||||
"Please cite the following paper if you use Marian:\n"
|
||||
"\n"
|
||||
"@InProceedings{mariannmt,\n"
|
||||
" title = {Marian: Fast Neural Machine Translation in {C++}},\n"
|
||||
" author = {Junczys-Dowmunt, Marcin and Grundkiewicz, Roman and\n"
|
||||
" Dwojak, Tomasz and Hoang, Hieu and Heafield, Kenneth and\n"
|
||||
" Neckermann, Tom and Seide, Frank and Germann, Ulrich and\n"
|
||||
" Fikri Aji, Alham and Bogoychev, Nikolay and\n"
|
||||
" Martins, Andr\\'{e} F. T. and Birch, Alexandra},\n"
|
||||
" booktitle = {Proceedings of ACL 2018, System Demonstrations},\n"
|
||||
" pages = {116--121},\n"
|
||||
" publisher = {Association for Computational Linguistics},\n"
|
||||
" year = {2018},\n"
|
||||
" month = {July},\n"
|
||||
" address = {Melbourne, Australia},\n"
|
||||
" url = {http://www.aclweb.org/anthology/P18-4020}\n"
|
||||
"}\n";
|
||||
}
|
||||
|
||||
// The list of contributors has been compiled semi-automatically from the
|
||||
// GitHub contributor list in default order. That list can be printed out with
|
||||
// `git shortlog -s -n`.
|
||||
std::string authors() {
|
||||
return "Marian: Fast Neural Machine Translation in C++\n"
|
||||
"\n"
|
||||
"An inevitably non-exhaustive list of contributors:\n"
|
||||
"\n"
|
||||
"Marcin Junczys-Dowmunt <marcinjd@microsoft.com>\n"
|
||||
"Roman Grundkiewicz <rgrundki@inf.ed.ac.uk>\n"
|
||||
"Frank Seide <fseide@microsoft.com>\n"
|
||||
"Hieu Hoang <hieuhoang@gmail.com>\n"
|
||||
"Tomasz Dwojak <t.dwojak@amu.edu.pl>\n"
|
||||
"Ulrich Germann <ugermann@inf.ed.ac.uk>\n"
|
||||
"Alham Fikri Aji <afaji321@gmail.com>\n"
|
||||
"Cédric Rousseau <cedrou@gmail.com>\n"
|
||||
"Young Jin Kim <youki@microsoft.com>\n"
|
||||
"Lane Schwartz <dowobeha@gmail.com>\n"
|
||||
"Andre Martins <andre.t.martins@gmail.com>\n"
|
||||
"Nikolay Bogoychev <n.bogoych@ed.ac.uk>\n"
|
||||
"Kenneth Heafield <kheafiel@ed.ac.uk>\n"
|
||||
"Maximiliana Behnke <mbehnke@inf.ed.ac.uk>\n"
|
||||
"Tom Neckermann <tomneckermann@gmail.com>\n"
|
||||
"Hany Hassan Awadalla <hanyh@microsoft.com>\n"
|
||||
"Jim Geovedi <jim@geovedi.com>\n"
|
||||
"Catarina Silva <catarina.cruz.csilva@gmail.com>\n"
|
||||
"Jon Clark <jonathac@microsoft.com>\n"
|
||||
"Rihards Krišlauks <rihards.krislauks@gmail.com>\n"
|
||||
"Vishal Chowdhary <vishalc@microsoft.com>\n"
|
||||
"Barry Haddow <bhaddow@inf.ed.ac.uk>\n"
|
||||
"Dominik Stańczak <stanczakdominik@gmail.com>\n"
|
||||
"Michael Hutt <Michael.Hutt@gmail.com>\n"
|
||||
"Richard Wei <rxwei@users.noreply.github.com>\n"
|
||||
"Wenyong Huang <weyo.huang@gmail.com>\n"
|
||||
"alancucki <alancucki+github@gmail.com>\n";
|
||||
}
|
||||
} // namespace marian
|
|
@ -18,6 +18,7 @@ struct Header {
|
|||
size_t dataLength;
|
||||
};
|
||||
|
||||
// cast current void pointer to T pointer and move forward by num elements
|
||||
template <typename T>
|
||||
const T* get(const void*& current, size_t num = 1) {
|
||||
const T* ptr = (const T*)current;
|
||||
|
@ -32,9 +33,10 @@ void loadItems(const void* current, std::vector<io::Item>& items, bool mapped) {
|
|||
binaryFileVersion,
|
||||
BINARY_FILE_VERSION);
|
||||
|
||||
size_t numHeaders = *get<size_t>(current);
|
||||
const Header* headers = get<Header>(current, numHeaders);
|
||||
size_t numHeaders = *get<size_t>(current); // number of item headers that follow
|
||||
const Header* headers = get<Header>(current, numHeaders); // read that many headers
|
||||
|
||||
// prepopulate items with meta data from headers
|
||||
items.resize(numHeaders);
|
||||
for(int i = 0; i < numHeaders; ++i) {
|
||||
items[i].type = (Type)headers[i].type;
|
||||
|
@ -42,21 +44,22 @@ void loadItems(const void* current, std::vector<io::Item>& items, bool mapped) {
|
|||
items[i].mapped = mapped;
|
||||
}
|
||||
|
||||
// read in actual shape and data
|
||||
for(int i = 0; i < numHeaders; ++i) {
|
||||
size_t len = headers[i].shapeLength;
|
||||
items[i].shape.resize(len);
|
||||
const int* arr = get<int>(current, len);
|
||||
std::copy(arr, arr + len, items[i].shape.begin());
|
||||
const int* arr = get<int>(current, len); // read shape
|
||||
std::copy(arr, arr + len, items[i].shape.begin()); // copy to Item::shape
|
||||
}
|
||||
|
||||
// move by offset bytes
|
||||
// move by offset bytes, aligned to 256-bytes boundary
|
||||
size_t offset = *get<size_t>(current);
|
||||
get<char>(current, offset);
|
||||
|
||||
for(int i = 0; i < numHeaders; ++i) {
|
||||
if(items[i].mapped) {
|
||||
if(items[i].mapped) { // memory-mapped, hence only set pointer
|
||||
items[i].ptr = get<char>(current, headers[i].dataLength);
|
||||
} else {
|
||||
} else { // reading into item data
|
||||
size_t len = headers[i].dataLength;
|
||||
items[i].bytes.resize(len);
|
||||
const char* ptr = get<char>(current, len);
|
||||
|
@ -68,15 +71,21 @@ void loadItems(const void* current, std::vector<io::Item>& items, bool mapped) {
|
|||
void loadItems(const std::string& fileName, std::vector<io::Item>& items) {
|
||||
// Read file into buffer
|
||||
size_t fileSize = filesystem::fileSize(fileName);
|
||||
char* ptr = new char[fileSize];
|
||||
std::vector<char> buf(fileSize);
|
||||
// @TODO: check this again:
|
||||
#if 1 // for some reason, the #else branch fails with "file not found" in the *read* operation (open succeeds)
|
||||
FILE *f = fopen(fileName.c_str(), "rb");
|
||||
ABORT_IF(f == nullptr, "Error {} ('{}') opening file '{}'", errno, strerror(errno), fileName);
|
||||
auto rc = fread(buf.data(), sizeof(*buf.data()), buf.size(), f);
|
||||
ABORT_IF(rc != buf.size(), "Error {} ('{}') reading file '{}'", errno, strerror(errno), fileName);
|
||||
fclose(f);
|
||||
#else
|
||||
io::InputFileStream in(fileName);
|
||||
in.read(ptr, fileSize);
|
||||
in.read(buf.data(), buf.size());
|
||||
#endif
|
||||
|
||||
// Load items from buffer without mapping
|
||||
loadItems(ptr, items, false);
|
||||
|
||||
// Delete buffer
|
||||
delete[] ptr;
|
||||
loadItems(buf.data(), items, false);
|
||||
}
|
||||
|
||||
io::Item getItem(const void* current, const std::string& varName) {
|
||||
|
@ -114,7 +123,7 @@ void saveItems(const std::string& fileName,
|
|||
headers.push_back(Header{item.name.size() + 1,
|
||||
(size_t)item.type,
|
||||
item.shape.size(),
|
||||
item.size()});
|
||||
item.bytes.size()}); // binary item size with padding, will be 256-byte-aligned
|
||||
}
|
||||
|
||||
size_t headerSize = headers.size();
|
||||
|
@ -141,9 +150,11 @@ void saveItems(const std::string& fileName,
|
|||
}
|
||||
|
||||
// Write out all values
|
||||
for(const auto& item : items) {
|
||||
pos += out.write(item.data(), item.size());
|
||||
}
|
||||
for(const auto& item : items)
|
||||
pos += out.write(item.data(), item.bytes.size()); // writes out data with padding, keeps 256-byte boundary.
|
||||
// Amazingly this is binary-compatible with V1 and aligned and
|
||||
// non-aligned models can be read with the same procedure.
|
||||
// No version-bump required. Gets 5-8% of speed back when mmapped.
|
||||
}
|
||||
|
||||
} // namespace binary
|
||||
|
|
|
@ -5,10 +5,10 @@
|
|||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
// Increase this if binary format changes
|
||||
#define BINARY_FILE_VERSION 1
|
||||
|
||||
namespace marian {
|
||||
|
||||
const static int BINARY_FILE_VERSION = 1;
|
||||
|
||||
namespace io {
|
||||
namespace binary {
|
||||
|
||||
|
|
|
@ -0,0 +1,18 @@
|
|||
#include "common/build_info.h"
|
||||
|
||||
/*
|
||||
* File build_info.cpp is generated using CMake. Do NOT modify it manually! Edit
|
||||
* build_info.cpp.in file instead.
|
||||
*/
|
||||
|
||||
std::string marian::cmakeBuildOptions() {
|
||||
return ""
|
||||
@PROJECT_CMAKE_CACHE@
|
||||
;
|
||||
}
|
||||
|
||||
std::string marian::cmakeBuildOptionsAdvanced() {
|
||||
return ""
|
||||
@PROJECT_CMAKE_CACHE_ADVANCED@
|
||||
;
|
||||
}
|
|
@ -0,0 +1,13 @@
|
|||
#pragma once
|
||||
|
||||
#include <string>
|
||||
|
||||
namespace marian {
|
||||
|
||||
// Returns list of non-advanced cache variables used by CMake
|
||||
std::string cmakeBuildOptions();
|
||||
|
||||
// Returns list of advanced cache variables used by CMake
|
||||
std::string cmakeBuildOptionsAdvanced();
|
||||
|
||||
} // namespace marian
|
|
@ -15,7 +15,6 @@ static inline std::string InterpolateEnvVars(std::string str) {
|
|||
// presently has the form /hdfs/VC instead of /{gfs,hdfs}/CLUSTER/VC
|
||||
|
||||
// Catch stdin/stdout and do not process
|
||||
std::cerr << str << std::endl;
|
||||
if(str == "stdin" || str == "stdout") {
|
||||
return str;
|
||||
}
|
||||
|
|
|
@ -3,11 +3,20 @@
|
|||
#include "common/logging.h"
|
||||
#include "common/options.h"
|
||||
#include "common/timer.h"
|
||||
#include "common/utils.h"
|
||||
#include "common/version.h"
|
||||
|
||||
namespace marian {
|
||||
namespace cli {
|
||||
|
||||
// clang-format off
|
||||
const std::unordered_set<std::string> DEPRECIATED_OPTIONS = {
|
||||
"version",
|
||||
"special-vocab"
|
||||
};
|
||||
// clang-format on
|
||||
|
||||
|
||||
/*
|
||||
static uint16_t guess_terminal_width(uint16_t max_width, uint16_t default_width) {
|
||||
uint16_t cols = 0;
|
||||
|
@ -91,18 +100,14 @@ CLIWrapper::CLIWrapper(YAML::Node &config,
|
|||
optVersion_->group(defaultGroup_);
|
||||
}
|
||||
|
||||
CLIWrapper::CLIWrapper(Ptr<marian::Options> options,
|
||||
const std::string &description,
|
||||
const std::string &header,
|
||||
const std::string &footer,
|
||||
size_t columnWidth,
|
||||
size_t screenWidth)
|
||||
: CLIWrapper(options->getYaml(), description, header, footer, columnWidth, screenWidth) {}
|
||||
|
||||
CLIWrapper::~CLIWrapper() {}
|
||||
|
||||
void CLIWrapper::switchGroup(const std::string &name) {
|
||||
currentGroup_ = name.empty() ? defaultGroup_ : name;
|
||||
// set current group to name, return previous group
|
||||
std::string CLIWrapper::switchGroup(std::string name) {
|
||||
currentGroup_.swap(name);
|
||||
if (currentGroup_.empty())
|
||||
currentGroup_ = defaultGroup_;
|
||||
return name;
|
||||
}
|
||||
|
||||
void CLIWrapper::parse(int argc, char **argv) {
|
||||
|
@ -119,43 +124,110 @@ void CLIWrapper::parse(int argc, char **argv) {
|
|||
}
|
||||
}
|
||||
|
||||
std::string CLIWrapper::failureMessage(const CLI::App *app, const CLI::Error &e) {
|
||||
std::string header = "Error: " + std::string(e.what()) + "\n";
|
||||
if(app->get_help_ptr() != nullptr)
|
||||
header += "Run with " + app->get_help_ptr()->get_name() + " for more information.\n";
|
||||
return header;
|
||||
void CLIWrapper::parseAliases() {
|
||||
// Exit if no aliases defined
|
||||
if(aliases_.empty())
|
||||
return;
|
||||
|
||||
// Iterate all known aliases, each alias has a key, value, and config
|
||||
for(const auto &alias : aliases_) {
|
||||
// Check if the alias option exists in the config (it may come from command line or a config
|
||||
// file)
|
||||
if(config_[alias.key]) {
|
||||
// Check if the option in the config stores the value required to expand the alias. If so,
|
||||
// expand the alias.
|
||||
// Two cases:
|
||||
// * the option is a sequence: extract it as a vector of strings and look for the value
|
||||
// * otherwise: compare values as strings
|
||||
bool expand = false;
|
||||
if(config_[alias.key].IsSequence()) {
|
||||
auto aliasOpts = config_[alias.key].as<std::vector<std::string>>();
|
||||
expand = std::find(aliasOpts.begin(), aliasOpts.end(), alias.value) != aliasOpts.end();
|
||||
} else {
|
||||
expand = config_[alias.key].as<std::string>() == alias.value;
|
||||
}
|
||||
|
||||
bool CLIWrapper::updateConfig(const YAML::Node &config) {
|
||||
bool success = true;
|
||||
if(expand) {
|
||||
// Update global config options with the config associated with the alias. Abort if the
|
||||
// alias contains an undefined option.
|
||||
updateConfig(alias.config,
|
||||
// Priority of each expanded option is the same as the priority of the alias
|
||||
options_[alias.key].priority,
|
||||
"Unknown option(s) in alias '" + alias.key + ": " + alias.value + "'");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Remove aliases from the global config to avoid redundancy when writing/reading config files
|
||||
for(const auto &alias : aliases_) {
|
||||
config_.remove(alias.key);
|
||||
}
|
||||
}
|
||||
|
||||
void CLIWrapper::updateConfig(const YAML::Node &config, cli::OptionPriority priority, const std::string &errorMsg) {
|
||||
auto cmdOptions = getParsedOptionNames();
|
||||
// Keep track of unrecognized options from the provided config
|
||||
std::vector<std::string> unknownOpts;
|
||||
|
||||
// Iterate incoming options: they need to be merged into the global config
|
||||
for(auto it : config) {
|
||||
auto key = it.first.as<std::string>();
|
||||
// skip options specified via command-line to allow overwriting them
|
||||
|
||||
// Skip options specified via command-line to allow overwriting them
|
||||
if(cmdOptions.count(key))
|
||||
continue;
|
||||
// Skip options that might exist in config files generated by older versions of Marian
|
||||
if(DEPRECIATED_OPTIONS.count(key))
|
||||
continue;
|
||||
|
||||
// Check if an incoming option has been defined in CLI
|
||||
if(options_.count(key)) {
|
||||
// Do not proceed if the priority of incoming option is not greater than the existing option
|
||||
if(priority <= options_[key].priority) {
|
||||
continue;
|
||||
}
|
||||
// Check if the option exists in the global config and types match
|
||||
if(config_[key] && config_[key].Type() == it.second.Type()) {
|
||||
config_[key] = YAML::Clone(it.second);
|
||||
options_[key].modified = true;
|
||||
options_[key].priority = priority;
|
||||
// If types doesn't match, try to convert
|
||||
} else {
|
||||
success = false;
|
||||
// Default value is a sequence and incoming node is a scalar, hence we can upcast to
|
||||
// single element sequence
|
||||
if(config_[key].IsSequence() && it.second.IsScalar()) {
|
||||
// create single element sequence
|
||||
YAML::Node sequence;
|
||||
sequence.push_back(YAML::Clone(it.second));
|
||||
config_[key] = sequence; // overwrite to replace default values
|
||||
options_[key].priority = priority;
|
||||
} else {
|
||||
// Cannot convert other non-matching types, e.g. scalar <- list should fail
|
||||
ABORT("Cannot convert values for the option: " + key);
|
||||
}
|
||||
}
|
||||
return success;
|
||||
} else { // an unknown option
|
||||
unknownOpts.push_back(key);
|
||||
}
|
||||
}
|
||||
|
||||
std::string CLIWrapper::dumpConfig(bool skipDefault /*= false*/) const {
|
||||
ABORT_IF(!unknownOpts.empty(), errorMsg + ": " + utils::join(unknownOpts, ", "));
|
||||
}
|
||||
|
||||
std::string CLIWrapper::dumpConfig(bool skipUnmodified /*= false*/) const {
|
||||
YAML::Emitter out;
|
||||
out << YAML::Comment("Marian configuration file generated at " + timer::currentDate()
|
||||
+ " with version " + buildVersion());
|
||||
out << YAML::BeginMap;
|
||||
std::string comment;
|
||||
// Iterate option names in the same order as they have been created
|
||||
for(const auto &key : getOrderedOptionNames()) {
|
||||
// do not proceed keys that are removed from config_
|
||||
// Do not dump options that were removed from config_
|
||||
if(!config_[key])
|
||||
continue;
|
||||
if(skipDefault && !options_.at(key).modified)
|
||||
// Do not dump options that were not passed via the command line
|
||||
if(skipUnmodified && options_.at(key).priority == cli::OptionPriority::DefaultValue)
|
||||
continue;
|
||||
// Put the group name as a comment before the first option in the group
|
||||
auto group = options_.at(key).opt->get_group();
|
||||
if(comment != group) {
|
||||
if(!comment.empty())
|
||||
|
@ -192,5 +264,12 @@ std::vector<std::string> CLIWrapper::getOrderedOptionNames() const {
|
|||
return keys;
|
||||
}
|
||||
|
||||
std::string CLIWrapper::failureMessage(const CLI::App *app, const CLI::Error &e) {
|
||||
std::string header = "Error: " + std::string(e.what()) + "\n";
|
||||
if(app->get_help_ptr() != nullptr)
|
||||
header += "Run with " + app->get_help_ptr()->get_name() + " for more information.\n";
|
||||
return header;
|
||||
}
|
||||
|
||||
} // namespace cli
|
||||
} // namespace marian
|
||||
|
|
|
@ -9,6 +9,7 @@
|
|||
#include <map>
|
||||
#include <string>
|
||||
#include <unordered_set>
|
||||
#include <vector>
|
||||
|
||||
namespace marian {
|
||||
|
||||
|
@ -16,28 +17,30 @@ class Options;
|
|||
|
||||
namespace cli {
|
||||
|
||||
// Try to determine the width of the terminal
|
||||
//
|
||||
// TODO: make use of it in the current CLI or remove. This is an old code used
|
||||
// for boost::program_options and might not be needed anymore.
|
||||
//static uint16_t guess_terminal_width(uint16_t max_width = 0,
|
||||
// uint16_t default_width = 180);
|
||||
|
||||
// TODO: use validators in ConfigParser
|
||||
namespace validators {
|
||||
const CLI::detail::ExistingFileValidator file_exists;
|
||||
const CLI::detail::ExistingDirectoryValidator dir_exists;
|
||||
const CLI::detail::ExistingPathValidator path_exists;
|
||||
|
||||
const CLI::detail::NonexistentPathValidator path_not_exists;
|
||||
|
||||
typedef CLI::Range range;
|
||||
}
|
||||
// Option priority
|
||||
enum struct OptionPriority : int { DefaultValue = 0, ConfigFile = 1, CommandLine = 2 };
|
||||
|
||||
/**
|
||||
* The helper class for cli::CLIWrapper handling formatting of options and their
|
||||
* descriptions.
|
||||
* Helper tuple storing an option object, the associated variable and creation index
|
||||
*
|
||||
* Note: bare pointers are used for CLI::Option objects as this comes from the CLI11 library.
|
||||
* Removing it would require deep modifications in the 3rd party library, what we want to avoid.
|
||||
*/
|
||||
struct CLIOptionTuple {
|
||||
CLI::Option *opt; // a pointer to an option object from CLI11
|
||||
Ptr<any_type> var; // value assigned to the option via command-line
|
||||
size_t idx{0}; // order in which the option was created
|
||||
OptionPriority priority{cli::OptionPriority::DefaultValue};
|
||||
};
|
||||
|
||||
// Helper tuple for aliases storing the alias name, value, and options to be expanded
|
||||
struct CLIAliasTuple {
|
||||
std::string key; // alias option name
|
||||
std::string value; // value for the alias option indicating that it should be expanded
|
||||
YAML::Node config; // config with options that the alias adds
|
||||
};
|
||||
|
||||
// The helper class for cli::CLIWrapper handling formatting of options and their descriptions.
|
||||
class CLIFormatter : public CLI::Formatter {
|
||||
public:
|
||||
CLIFormatter(size_t columnWidth, size_t screenWidth);
|
||||
|
@ -47,57 +50,39 @@ private:
|
|||
size_t screenWidth_{0};
|
||||
};
|
||||
|
||||
// @TODO: in this file review the use of naked pointers. We use Ptr<Type> anywhere else,
|
||||
// what's up with that?
|
||||
|
||||
/**
|
||||
* The helper structure storing an option object, the associated variable and creation index.
|
||||
*/
|
||||
struct CLIOptionTuple {
|
||||
CLI::Option *opt;
|
||||
Ptr<any_type> var;
|
||||
size_t idx{0};
|
||||
bool modified{false};
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief The class used to define and parse command-line arguments.
|
||||
*
|
||||
* It is a wrapper around https://github.com/CLIUtils/CLI11 that stores defined
|
||||
* command-line arguments in a YAML object.
|
||||
* It is a wrapper around https://github.com/CLIUtils/CLI11 that stores defined command-line
|
||||
* arguments in a YAML object.
|
||||
*
|
||||
* Usage outline: first call add() methods to create all the options; then call
|
||||
* parse(argv, argc) to parse command line and get defined options and their
|
||||
* values in a YAML object. The object can be also obtained later by calling
|
||||
* Usage outline: first call add() methods to create all the options; then call parse(argv, argc) to
|
||||
* parse command line and get defined options and their values in a YAML object; finally call
|
||||
* parseAliases() to expand alias options. The config object can be also obtained later by calling
|
||||
* getConfig().
|
||||
*
|
||||
* Options are organized in option groups. Each option group has a header that
|
||||
* preceeds all options in the group. The header for the default option group
|
||||
* can be set from the class constructor.
|
||||
* Options are organized in option groups. Each option group has a header that preceeds all options
|
||||
* in the group. The header for the default option group can be set from the class constructor.
|
||||
*/
|
||||
class CLIWrapper {
|
||||
private:
|
||||
// Map with option names and option tuples
|
||||
std::unordered_map<std::string, CLIOptionTuple> options_;
|
||||
// Counter for created options
|
||||
// Counter for created options to keep track of order in which options were created
|
||||
size_t counter_{0};
|
||||
// Command-line argument parser
|
||||
Ptr<CLI::App> app_;
|
||||
std::vector<CLIAliasTuple> aliases_; // List of alias tuples
|
||||
|
||||
// Name of the default option group
|
||||
std::string defaultGroup_{""};
|
||||
// Name of the current option group
|
||||
std::string currentGroup_{""};
|
||||
Ptr<CLI::App> app_; // Command-line argument parser from CLI11
|
||||
|
||||
// Reference to the main config object
|
||||
YAML::Node &config_;
|
||||
std::string defaultGroup_{""}; // Name of the default option group
|
||||
std::string currentGroup_{""}; // Name of the current option group
|
||||
|
||||
YAML::Node &config_; // Reference to the main config object
|
||||
|
||||
// Option for --version flag. This is a special flag and similarly to --help,
|
||||
// the key "version" will be not added into the YAML config
|
||||
CLI::Option *optVersion_;
|
||||
|
||||
static std::string failureMessage(const CLI::App *app, const CLI::Error &e);
|
||||
|
||||
// Extract option name from a comma-separated list of long and short options, e.g. 'help' from
|
||||
// '--help,-h'
|
||||
std::string keyName(const std::string &args) const {
|
||||
|
@ -107,7 +92,15 @@ private:
|
|||
.front(); // get first long name
|
||||
}
|
||||
|
||||
// Get names of options passed via command-line
|
||||
std::unordered_set<std::string> getParsedOptionNames() const;
|
||||
// Get option names in the same order as they are created
|
||||
std::vector<std::string> getOrderedOptionNames() const;
|
||||
|
||||
static std::string failureMessage(const CLI::App *app, const CLI::Error &e);
|
||||
|
||||
public:
|
||||
|
||||
/**
|
||||
* @brief Create an instance of the command-line argument parser
|
||||
*
|
||||
|
@ -118,8 +111,7 @@ public:
|
|||
* @param header Header text for the main option group
|
||||
* @param footer Text displayed after the list of options
|
||||
* @param columnWidth Width of the column with option names
|
||||
* @param screenWidth Maximum allowed width for help messages, 0 means no
|
||||
* limit
|
||||
* @param screenWidth Maximum allowed width for help messages, 0 means no limit
|
||||
*/
|
||||
CLIWrapper(YAML::Node &config,
|
||||
const std::string &description = "",
|
||||
|
@ -128,24 +120,13 @@ public:
|
|||
size_t columnWidth = 40,
|
||||
size_t screenWidth = 0);
|
||||
|
||||
/**
|
||||
* @brief Create an instance of the command-line argument parser,
|
||||
* short-cuft for Options object.
|
||||
*
|
||||
* @see Other constructor
|
||||
*/
|
||||
CLIWrapper(Ptr<Options> options,
|
||||
const std::string &description = "",
|
||||
const std::string &header = "General options",
|
||||
const std::string &footer = "",
|
||||
size_t columnWidth = 30,
|
||||
size_t screenWidth = 0);
|
||||
|
||||
virtual ~CLIWrapper();
|
||||
|
||||
/**
|
||||
* @brief Define an option with a default value
|
||||
*
|
||||
* Explicit default values will appear in help messages.
|
||||
*
|
||||
* @param args Comma-separated list of short and long option names
|
||||
* @param help Help message
|
||||
* @param val Default value
|
||||
|
@ -154,109 +135,121 @@ public:
|
|||
*/
|
||||
template <typename T>
|
||||
CLI::Option *add(const std::string &args, const std::string &help, T val) {
|
||||
return add_option<T>(keyName(args),
|
||||
return addOption<T>(keyName(args),
|
||||
args,
|
||||
help,
|
||||
val,
|
||||
/*defaulted =*/true,
|
||||
/*addToConfig =*/true);
|
||||
/*defaulted =*/true);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Define an option without an explicit default value. The implicit
|
||||
* default value is T()
|
||||
* @brief Define an option without an explicit default value. The implicit default value is T()
|
||||
*
|
||||
* The option will be defined in the config file even if not given as a
|
||||
* command-line argument. The implicit default value for a boolean or numeric
|
||||
* option is 0, for a string is an empty string, and for a vector is an empty
|
||||
* vector.
|
||||
* The option will be defined in the config file even if not given as a command-line argument. The
|
||||
* implicit default value for a boolean or numeric option is 0, for a string is an empty string,
|
||||
* and for a vector is an empty vector.
|
||||
*
|
||||
* Implicit default values will *NOT* appear in help messages.
|
||||
*
|
||||
* @param args Comma-separated list of short and long option names
|
||||
* @param help Help message
|
||||
*
|
||||
* @return Option object
|
||||
*
|
||||
* TODO: require to always state the default value creating the parser as this
|
||||
* will be clearer
|
||||
* @TODO: require to always state the default value creating the parser as this will be clearer
|
||||
*/
|
||||
template <typename T>
|
||||
CLI::Option *add(const std::string &args, const std::string &help) {
|
||||
return add_option<T>(keyName(args),
|
||||
return addOption<T>(keyName(args),
|
||||
args,
|
||||
help,
|
||||
T(),
|
||||
/*defaulted =*/false,
|
||||
/*addToConfig =*/true);
|
||||
/*defaulted =*/false);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Define a non-defaulted option
|
||||
* @brief Transform a command line option into an alias. This alias will set other options later.
|
||||
*
|
||||
* The option will not be present in the config file unless given as a
|
||||
* command-line argument.
|
||||
* An alias sets one or more options to predefined values. The options expanded by the alias are
|
||||
* provided as a function setting a temporary YAML config.
|
||||
*
|
||||
* @param args Comma-separated list of short and long option names
|
||||
* @param help Help message
|
||||
* The alias option has to be first defined using `add<T>()`. Otherwise, the program will abort.
|
||||
*
|
||||
* @return Option object
|
||||
* Defining more than one alias for the same `key` but different `value` is allowed.
|
||||
*
|
||||
* @TODO: consider removing this method during final refactorization of
|
||||
* command-line/config parsers in the future as all options should either
|
||||
* have a default value or be non-defaulted
|
||||
* Option values are compared as std::string. If the alias option is a vector, the alias will be
|
||||
* triggered if `value` exists in that vector at least once.
|
||||
*
|
||||
* Options set directly via command line have precedence over options defined in an alias, i.e. an
|
||||
* option added via alias can be overwritten by setting a specific option via command line.
|
||||
*
|
||||
* @param key Alias option name
|
||||
* @param value Option value that trigger the alias
|
||||
* @param fun Function setting a temporary YAML config with options expanded by alias
|
||||
*/
|
||||
template <typename T>
|
||||
CLI::Option *add_nondefault(const std::string &args, const std::string &help) {
|
||||
return add_option<T>(keyName(args),
|
||||
args,
|
||||
help,
|
||||
T(),
|
||||
/*defaulted =*/false,
|
||||
/*addToConfig =*/false);
|
||||
void alias(const std::string &key,
|
||||
const std::string &value,
|
||||
const std::function<void(YAML::Node &config)> &fun) {
|
||||
ABORT_IF(!options_.count(key), "Option '{}' is not defined so alias can not be created", key);
|
||||
aliases_.resize(aliases_.size() + 1);
|
||||
aliases_.back().key = key;
|
||||
aliases_.back().value = value;
|
||||
fun(aliases_.back().config);
|
||||
}
|
||||
|
||||
/**
|
||||
* Switch to different option group or to the default group if argument is empty.
|
||||
*
|
||||
* @param name Header of the option group
|
||||
* @return Previous group.
|
||||
*/
|
||||
void switchGroup(const std::string &name = "");
|
||||
std::string switchGroup(std::string name = "");
|
||||
|
||||
// Parse command-line arguments. Handles --help and --version options
|
||||
void parse(int argc, char **argv);
|
||||
|
||||
/*
|
||||
* @brief Overwrite values for unparsed options
|
||||
/**
|
||||
* @brief Expand aliases based on arguments parsed with parse(int, char**)
|
||||
*
|
||||
* Should be called after parse(int, char**) to take an effect. If any alias tries to expand an
|
||||
* undefined option, the method will abort the program.
|
||||
*
|
||||
* All options defined as aliases are removed from the global config object to avoid redundancy
|
||||
* when options are dumped (explicitly or implicitly) to a config file.
|
||||
*/
|
||||
void parseAliases();
|
||||
|
||||
/**
|
||||
* @brief Overwrite options with lower priority
|
||||
*
|
||||
* Values for options with lower priority than the provided priority remain unchanged. This allows
|
||||
* for overwritting default options by options from config files, or both by options provided in
|
||||
* the command line.
|
||||
*
|
||||
* Default values are overwritten with the options from the config provided, while parsed
|
||||
* command-line options remain unchanged.
|
||||
* This should be a preferred way of updating config options as the class keeps track of options,
|
||||
* which values have changed.
|
||||
*
|
||||
* @param node YAML config with new default values for options
|
||||
* @param config YAML config with new default values for options
|
||||
* @param priority priority of incoming options
|
||||
* @param errorMsg error message printed if config contains undefined keys. The message is
|
||||
* appended with ": <comma-separated list of invalid options>"
|
||||
*/
|
||||
bool updateConfig(const YAML::Node &config);
|
||||
void updateConfig(const YAML::Node &config, cli::OptionPriority priority, const std::string &errorMsg);
|
||||
|
||||
// Get textual YAML representation of the config
|
||||
std::string dumpConfig(bool skipDefault = false) const;
|
||||
std::string dumpConfig(bool skipUnmodified = false) const;
|
||||
|
||||
private:
|
||||
// Get names of options passed via command-line
|
||||
std::unordered_set<std::string> getParsedOptionNames() const;
|
||||
// Get option names in the same order as they are created
|
||||
std::vector<std::string> getOrderedOptionNames() const;
|
||||
|
||||
template <typename T,
|
||||
// options with numeric and string-like values
|
||||
CLI::enable_if_t<!CLI::is_bool<T>::value && !CLI::is_vector<T>::value,
|
||||
CLI::detail::enabler> = CLI::detail::dummy>
|
||||
CLI::Option *add_option(const std::string &key,
|
||||
CLI::Option *addOption(const std::string &key,
|
||||
const std::string &args,
|
||||
const std::string &help,
|
||||
T val,
|
||||
bool defaulted,
|
||||
bool addToConfig) {
|
||||
// define YAML entry if requested
|
||||
if(addToConfig)
|
||||
bool defaulted) {
|
||||
// add key to YAML
|
||||
config_[key] = val;
|
||||
|
||||
// create option tuple
|
||||
|
@ -266,7 +259,7 @@ private:
|
|||
|
||||
// callback function collecting a command-line argument
|
||||
CLI::callback_t fun = [this, key](CLI::results_t res) {
|
||||
options_[key].modified = true;
|
||||
options_[key].priority = cli::OptionPriority::CommandLine;
|
||||
// get variable associated with the option
|
||||
auto &var = options_[key].var->as<T>();
|
||||
// store parser result in var
|
||||
|
@ -298,14 +291,12 @@ private:
|
|||
template <typename T,
|
||||
// options with vector values
|
||||
CLI::enable_if_t<CLI::is_vector<T>::value, CLI::detail::enabler> = CLI::detail::dummy>
|
||||
CLI::Option *add_option(const std::string &key,
|
||||
CLI::Option *addOption(const std::string &key,
|
||||
const std::string &args,
|
||||
const std::string &help,
|
||||
T val,
|
||||
bool defaulted,
|
||||
bool addToConfig) {
|
||||
// define YAML entry if requested
|
||||
if(addToConfig)
|
||||
bool defaulted) {
|
||||
// add key to YAML
|
||||
config_[key] = val;
|
||||
|
||||
// create option tuple
|
||||
|
@ -315,7 +306,7 @@ private:
|
|||
|
||||
// callback function collecting command-line arguments
|
||||
CLI::callback_t fun = [this, key](CLI::results_t res) {
|
||||
options_[key].modified = true;
|
||||
options_[key].priority = cli::OptionPriority::CommandLine;
|
||||
// get vector variable associated with the option
|
||||
auto &vec = options_[key].var->as<T>();
|
||||
vec.clear();
|
||||
|
@ -357,14 +348,12 @@ private:
|
|||
template <typename T,
|
||||
// options with boolean values, called flags in CLI11
|
||||
CLI::enable_if_t<CLI::is_bool<T>::value, CLI::detail::enabler> = CLI::detail::dummy>
|
||||
CLI::Option *add_option(const std::string &key,
|
||||
CLI::Option *addOption(const std::string &key,
|
||||
const std::string &args,
|
||||
const std::string &help,
|
||||
T val,
|
||||
bool defaulted,
|
||||
bool addToConfig) {
|
||||
// define YAML entry if requested
|
||||
if(addToConfig)
|
||||
bool defaulted) {
|
||||
// add key to YAML
|
||||
config_[key] = val;
|
||||
|
||||
// create option tuple
|
||||
|
@ -374,7 +363,7 @@ private:
|
|||
|
||||
// callback function setting the flag
|
||||
CLI::callback_t fun = [this, key](CLI::results_t res) {
|
||||
options_[key].modified = true;
|
||||
options_[key].priority = cli::OptionPriority::CommandLine;
|
||||
// get parser result, it is safe as boolean options have an implicit value
|
||||
auto val = res[0];
|
||||
auto ret = true;
|
||||
|
|
|
@ -1,9 +1,11 @@
|
|||
#include "common/config.h"
|
||||
#include "common/config_parser.h"
|
||||
#include "common/file_stream.h"
|
||||
#include "common/logging.h"
|
||||
#include "common/options.h"
|
||||
#include "common/regex.h"
|
||||
#include "common/utils.h"
|
||||
#include "common/version.h"
|
||||
#include "common/regex.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <set>
|
||||
|
@ -14,35 +16,26 @@ namespace marian {
|
|||
// @TODO: keep seed in a single place, now it is kept here and in Config/Options
|
||||
size_t Config::seed = (size_t)time(0);
|
||||
|
||||
Config::Config(int argc,
|
||||
char** argv,
|
||||
cli::mode mode /*= cli::mode::training*/,
|
||||
bool validate /*= true*/) {
|
||||
initialize(argc, argv, mode, validate);
|
||||
|
||||
Config::Config(ConfigParser const& cp) {
|
||||
initialize(cp);
|
||||
}
|
||||
|
||||
Config::Config(const Config& other) : config_(YAML::Clone(other.config_)) {}
|
||||
Config::Config(const Options& options) : config_(YAML::Clone(options.getYaml())) {}
|
||||
Config::Config(int argc, char** argv, cli::mode mode, bool validate /*= true*/)
|
||||
: Config(ConfigParser(argc, argv, mode, validate)) {}
|
||||
|
||||
void Config::initialize(int argc, char** argv, cli::mode mode, bool validate) {
|
||||
auto parser = ConfigParser(argc, argv, mode, validate);
|
||||
config_ = parser.getConfig();
|
||||
Config::Config(const Config& other) : config_(YAML::Clone(other.config_)) {}
|
||||
Config::Config(const Options& options) : config_(options.cloneToYamlNode()) {}
|
||||
|
||||
void Config::initialize(ConfigParser const& cp) {
|
||||
config_ = YAML::Clone(cp.getConfig());
|
||||
cli::mode mode = cp.getMode();
|
||||
|
||||
createLoggers(this);
|
||||
|
||||
// echo version and command line
|
||||
LOG(info, "[marian] Marian {}", buildVersion());
|
||||
std::string cmdLine;
|
||||
for (int i = 0; i < argc; i++) {
|
||||
std::string arg = argv[i];
|
||||
std::string quote; // attempt to quote special chars
|
||||
if (arg.empty() || arg.find_first_of(" #`\"'\\${}|&^?*!()%><") != std::string::npos)
|
||||
quote = "'";
|
||||
arg = regex::regex_replace(arg, regex::regex("'"), "'\\''");
|
||||
if (!cmdLine.empty())
|
||||
cmdLine.push_back(' ');
|
||||
cmdLine += quote + arg + quote;
|
||||
}
|
||||
std::string cmdLine = cp.cmdLine();
|
||||
std::string hostname; int pid; std::tie
|
||||
(hostname, pid) = utils::hostnameAndProcessId();
|
||||
LOG(info, "[marian] Running on {} as process {} with command line:", hostname, pid);
|
||||
|
@ -56,7 +49,17 @@ void Config::initialize(int argc, char** argv, cli::mode mode, bool validate) {
|
|||
}
|
||||
|
||||
// load model parameters
|
||||
if(mode != cli::mode::translation) {
|
||||
if(mode == cli::mode::translation || mode == cli::mode::server) {
|
||||
auto model = get<std::vector<std::string>>("models")[0];
|
||||
try {
|
||||
if(!get<bool>("ignore-model-config"))
|
||||
loadModelParameters(model);
|
||||
} catch(std::runtime_error& ) {
|
||||
LOG(info, "[config] No model configuration found in model file");
|
||||
}
|
||||
}
|
||||
// if cli::mode::training or cli::mode::scoring
|
||||
else {
|
||||
auto model = get<std::string>("model");
|
||||
if(filesystem::exists(model) && !get<bool>("no-reload")) {
|
||||
try {
|
||||
|
@ -67,16 +70,6 @@ void Config::initialize(int argc, char** argv, cli::mode mode, bool validate) {
|
|||
}
|
||||
}
|
||||
}
|
||||
// if cli::mode::translation
|
||||
else {
|
||||
auto model = get<std::vector<std::string>>("models")[0];
|
||||
try {
|
||||
if(!get<bool>("ignore-model-config"))
|
||||
loadModelParameters(model);
|
||||
} catch(std::runtime_error& ) {
|
||||
LOG(info, "[config] No model configuration found in model file");
|
||||
}
|
||||
}
|
||||
|
||||
// echo full configuration
|
||||
log();
|
||||
|
@ -95,15 +88,14 @@ void Config::initialize(int argc, char** argv, cli::mode mode, bool validate) {
|
|||
version,
|
||||
buildVersion());
|
||||
else
|
||||
LOG(info,
|
||||
"[config] Loaded model has been created with Marian {}",
|
||||
version);
|
||||
LOG(info, "[config] Loaded model has been created with Marian {}", version);
|
||||
|
||||
// Remove "version" from config to make it consistent among different start-up scenarios
|
||||
config_.remove("version");
|
||||
}
|
||||
// If this is a newly started training
|
||||
else if(mode == cli::mode::training) {
|
||||
LOG(info,
|
||||
"[config] Model is being created with Marian {}",
|
||||
buildVersion());
|
||||
LOG(info, "[config] Model is being created with Marian {}", buildVersion());
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -156,10 +148,9 @@ void Config::log() {
|
|||
std::string configString = out.c_str();
|
||||
|
||||
// print YAML prepending each line with [config]
|
||||
std::vector<std::string> results;
|
||||
utils::split(configString, results, "\n");
|
||||
for(auto& r : results)
|
||||
LOG(info, "[config] {}", r);
|
||||
auto lines = utils::split(configString, "\n");
|
||||
for(auto& line : lines)
|
||||
LOG(info, "[config] {}", line);
|
||||
}
|
||||
|
||||
// Parse the device-spec parameters (--num-devices, --devices, --cpu-threads) into an array of
|
||||
|
@ -264,14 +255,17 @@ std::vector<DeviceId> Config::getDevices(Ptr<Options> options,
|
|||
return devices;
|
||||
}
|
||||
|
||||
Ptr<Options> parseOptions(int argc,
|
||||
char** argv,
|
||||
cli::mode mode /*= cli::mode::training*/,
|
||||
bool validate /*= true*/) {
|
||||
auto config = New<Config>(argc, argv, mode, validate);
|
||||
auto options = New<Options>();
|
||||
options->merge(config->get());
|
||||
return options;
|
||||
Ptr<Options>
|
||||
parseOptions(int argc, char** argv, cli::mode mode, bool validate){
|
||||
ConfigParser cp(mode);
|
||||
return cp.parseOptions(argc, argv, validate);
|
||||
}
|
||||
|
||||
std::ostream& operator<<(std::ostream& out, const Config& config) {
|
||||
YAML::Emitter outYaml;
|
||||
cli::OutputYaml(config.get(), outYaml);
|
||||
out << outYaml.c_str();
|
||||
return out;
|
||||
}
|
||||
|
||||
} // namespace marian
|
||||
|
|
|
@ -38,6 +38,7 @@ public:
|
|||
|
||||
typedef YAML::Node YamlNode;
|
||||
|
||||
Config(ConfigParser const& cp);
|
||||
// TODO: remove mode from this class
|
||||
Config(int argc,
|
||||
char** argv,
|
||||
|
@ -47,7 +48,7 @@ public:
|
|||
Config(const Config& other);
|
||||
Config(const Options& options);
|
||||
|
||||
void initialize(int argc, char** argv, cli::mode mode, bool validate);
|
||||
void initialize(ConfigParser const& cp);
|
||||
|
||||
bool has(const std::string& key) const;
|
||||
|
||||
|
@ -83,12 +84,7 @@ public:
|
|||
|
||||
void save(const std::string& name);
|
||||
|
||||
friend std::ostream& operator<<(std::ostream& out, const Config& config) {
|
||||
YAML::Emitter outYaml;
|
||||
cli::OutputYaml(config.get(), outYaml);
|
||||
out << outYaml.c_str();
|
||||
return out;
|
||||
}
|
||||
friend std::ostream& operator<<(std::ostream& out, const Config& config);
|
||||
|
||||
static std::vector<DeviceId> getDevices(Ptr<Options> options,
|
||||
size_t myMPIRank = 0,
|
||||
|
@ -115,7 +111,7 @@ private:
|
|||
*/
|
||||
Ptr<Options> parseOptions(int argc,
|
||||
char** argv,
|
||||
cli::mode mode = cli::mode::training,
|
||||
cli::mode mode,
|
||||
bool validate = true);
|
||||
|
||||
} // namespace marian
|
||||
|
|
|
@ -1,12 +1,15 @@
|
|||
#include "common/config_parser.h"
|
||||
|
||||
#include "common/definitions.h"
|
||||
#include "common/authors.h"
|
||||
#include "common/build_info.h"
|
||||
#include "common/cli_helper.h"
|
||||
#include "common/config.h"
|
||||
#include "common/config_parser.h"
|
||||
#include "common/config_validator.h"
|
||||
#include "common/definitions.h"
|
||||
#include "common/file_stream.h"
|
||||
#include "common/logging.h"
|
||||
#include "common/options.h"
|
||||
#include "common/regex.h"
|
||||
#include "common/utils.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <set>
|
||||
#include <stdexcept>
|
||||
|
@ -22,7 +25,8 @@
|
|||
|
||||
namespace marian {
|
||||
|
||||
// TODO: move to CLIWrapper
|
||||
// TODO: Move this to CLIWrapper and allow to mark options as paths in the same place they are
|
||||
// defined
|
||||
// clang-format off
|
||||
const std::set<std::string> PATHS = {
|
||||
"model",
|
||||
|
@ -32,6 +36,7 @@ const std::set<std::string> PATHS = {
|
|||
"embedding-vectors",
|
||||
"valid-sets",
|
||||
"valid-script-path",
|
||||
"valid-script-args",
|
||||
"valid-log",
|
||||
"valid-translation-output",
|
||||
"input", // except: stdin
|
||||
|
@ -47,23 +52,80 @@ const std::set<std::string> PATHS = {
|
|||
};
|
||||
// clang-format on
|
||||
|
||||
std::string escapeCmdLine(int argc, char** argv){
|
||||
std::string cmdLine;
|
||||
for(int i = 0; i < argc; i++) {
|
||||
std::string arg = argv[i];
|
||||
std::string quote; // attempt to quote special chars
|
||||
if(arg.empty() || arg.find_first_of(" #`\"'\\${}|&^?*!()%><") != std::string::npos)
|
||||
quote = "'";
|
||||
arg = regex::regex_replace(arg, regex::regex("'"), "'\\''");
|
||||
if(!cmdLine.empty())
|
||||
cmdLine.push_back(' ');
|
||||
cmdLine += quote + arg + quote;
|
||||
}
|
||||
return cmdLine;
|
||||
}
|
||||
|
||||
std::string const& ConfigParser::cmdLine() const {
|
||||
return cmdLine_;
|
||||
}
|
||||
|
||||
ConfigParser::ConfigParser(cli::mode mode)
|
||||
: cli_(config_,"Marian: Fast Neural Machine Translation in C++",
|
||||
"General options", "", 40),
|
||||
mode_(mode == cli::mode::server ? cli::mode::translation : mode) {
|
||||
|
||||
addOptionsGeneral(cli_);
|
||||
if (mode == cli::mode::server)
|
||||
addOptionsServer(cli_);
|
||||
addOptionsModel(cli_);
|
||||
|
||||
// clang-format off
|
||||
switch(mode_) {
|
||||
case cli::mode::training:
|
||||
addOptionsTraining(cli_);
|
||||
addOptionsValidation(cli_);
|
||||
break;
|
||||
case cli::mode::translation:
|
||||
addOptionsTranslation(cli_);
|
||||
break;
|
||||
case cli::mode::scoring:
|
||||
addOptionsScoring(cli_);
|
||||
break;
|
||||
default:
|
||||
ABORT("wrong CLI mode");
|
||||
break;
|
||||
}
|
||||
|
||||
addAliases(cli_);
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
void ConfigParser::addOptionsGeneral(cli::CLIWrapper& cli) {
|
||||
int defaultWorkspace = (mode_ == cli::mode::translation) ? 512 : 2048;
|
||||
|
||||
cli.switchGroup("General options");
|
||||
|
||||
// clang-format off
|
||||
cli.add<bool>("--authors",
|
||||
"Print list of authors and exit");
|
||||
cli.add<bool>("--cite",
|
||||
"Print citation and exit");
|
||||
cli.add<std::string>("--build-info",
|
||||
"Print CMake build options and exit. Set to 'all' to print advanced options")
|
||||
->implicit_val("basic");
|
||||
cli.add<std::vector<std::string>>("--config,-c",
|
||||
"Configuration file(s). If multiple, later overrides earlier");
|
||||
cli.add<size_t>("--workspace,-w",
|
||||
"Preallocate arg MB of work space",
|
||||
defaultWorkspace);
|
||||
cli.add_nondefault<std::string>("--log",
|
||||
cli.add<std::string>("--log",
|
||||
"Log training process information to file given by arg");
|
||||
cli.add<std::string>("--log-level",
|
||||
"Set verbosity level of logging: trace, debug, info, warn, err(or), critical, off",
|
||||
"info");
|
||||
cli.add_nondefault<std::string>("--log-time-zone",
|
||||
cli.add<std::string>("--log-time-zone",
|
||||
"Set time zone for the date shown on logging");
|
||||
cli.add<bool>("--quiet",
|
||||
"Suppress all logging to stderr. Logging to files still works");
|
||||
|
@ -77,14 +139,24 @@ void ConfigParser::addOptionsGeneral(cli::CLIWrapper& cli) {
|
|||
"allow the use of environment variables in paths, of the form ${VAR_NAME}");
|
||||
cli.add<bool>("--relative-paths",
|
||||
"All paths are relative to the config file location");
|
||||
cli.add_nondefault<std::string>("--dump-config",
|
||||
"Dump current (modified) configuration to stdout and exit. Possible values: full, minimal")
|
||||
cli.add<std::string>("--dump-config",
|
||||
"Dump current (modified) configuration to stdout and exit. Possible values: full, minimal, expand")
|
||||
->implicit_val("full");
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
void ConfigParser::addOptionsServer(cli::CLIWrapper& cli) {
|
||||
// clang-format off
|
||||
auto previous_group = cli.switchGroup("Server options");
|
||||
cli.add<size_t>("--port,-p",
|
||||
"Port number for web socket server",
|
||||
8080);
|
||||
cli.switchGroup(previous_group);
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
void ConfigParser::addOptionsModel(cli::CLIWrapper& cli) {
|
||||
cli.switchGroup("Model options");
|
||||
auto previous_group = cli.switchGroup("Model options");
|
||||
|
||||
// clang-format off
|
||||
if(mode_ == cli::mode::translation) {
|
||||
|
@ -96,7 +168,7 @@ void ConfigParser::addOptionsModel(cli::CLIWrapper& cli) {
|
|||
"model.npz");
|
||||
|
||||
if(mode_ == cli::mode::training) {
|
||||
cli.add_nondefault<std::string>("--pretrained-model",
|
||||
cli.add<std::string>("--pretrained-model",
|
||||
"Path prefix for pre-trained model to initialize model weights");
|
||||
}
|
||||
}
|
||||
|
@ -108,10 +180,13 @@ void ConfigParser::addOptionsModel(cli::CLIWrapper& cli) {
|
|||
"amun");
|
||||
cli.add<std::vector<int>>("--dim-vocabs",
|
||||
"Maximum items in vocabulary ordered by rank, 0 uses all items in the provided/created vocabulary file",
|
||||
std::vector<int>({0, 0}));
|
||||
{0, 0});
|
||||
cli.add<int>("--dim-emb",
|
||||
"Size of embedding vector",
|
||||
512);
|
||||
cli.add<int>("--lemma-dim-emb",
|
||||
"Re-embedding dimension of lemma in factors",
|
||||
0);
|
||||
cli.add<int>("--dim-rnn",
|
||||
"Size of rnn hidden state", 1024);
|
||||
cli.add<std::string>("--enc-type",
|
||||
|
@ -143,10 +218,12 @@ void ConfigParser::addOptionsModel(cli::CLIWrapper& cli) {
|
|||
"Enable layer normalization");
|
||||
cli.add<bool>("--right-left",
|
||||
"Train right-to-left model");
|
||||
cli.add<std::vector<std::string>>("--input-types",
|
||||
"Provide type of input data if different than 'sequence'. "
|
||||
"Possible values: sequence, class. You need to provide one type per input.",
|
||||
{});
|
||||
cli.add<bool>("--best-deep",
|
||||
"Use Edinburgh deep RNN configuration (s2s)");
|
||||
cli.add_nondefault<std::vector<size_t>>("--special-vocab",
|
||||
"Model-specific special vocabulary ids");
|
||||
cli.add<bool>("--tied-embeddings",
|
||||
"Tie target embeddings and output embeddings in output layer");
|
||||
cli.add<bool>("--tied-embeddings-src",
|
||||
|
@ -196,7 +273,17 @@ void ConfigParser::addOptionsModel(cli::CLIWrapper& cli) {
|
|||
cli.add<std::string>("--transformer-postprocess",
|
||||
"Operation after each transformer layer: d = dropout, a = add, n = normalize",
|
||||
"dan");
|
||||
cli.add<bool>("--transformer-train-position-embeddings",
|
||||
"Train positional embeddings instead of using static sinusoidal embeddings");
|
||||
cli.add<bool>("--transformer-depth-scaling",
|
||||
"Scale down weight initialization in transformer layers by 1 / sqrt(depth)");
|
||||
|
||||
cli.add<std::string>("--bert-mask-symbol", "Masking symbol for BERT masked-LM training", "[MASK]");
|
||||
cli.add<std::string>("--bert-sep-symbol", "Sentence separator symbol for BERT next sentence prediction training", "[SEP]");
|
||||
cli.add<std::string>("--bert-class-symbol", "Class symbol BERT classifier training", "[CLS]");
|
||||
cli.add<float>("--bert-masking-fraction", "Fraction of masked out tokens during training", 0.15f);
|
||||
cli.add<bool>("--bert-train-type-embeddings", "Train bert type embeddings, set to false to use static sinusoidal embeddings", true);
|
||||
cli.add<int>("--bert-type-vocab-size", "Size of BERT type vocab (sentence A and B)", 2);
|
||||
#ifdef CUDNN
|
||||
cli.add<int>("--char-stride",
|
||||
"Width of max-pooling layer after convolution layer in char-s2s model",
|
||||
|
@ -205,11 +292,11 @@ void ConfigParser::addOptionsModel(cli::CLIWrapper& cli) {
|
|||
"Number of highway network layers after max-pooling in char-s2s model",
|
||||
4);
|
||||
cli.add<std::vector<int>>("--char-conv-filters-num",
|
||||
"Numbers of convolution filters of correspoding width in char-s2s model",
|
||||
std::vector<int>({200, 200, 250, 250, 300, 300, 300, 300}));
|
||||
"Numbers of convolution filters of corresponding width in char-s2s model",
|
||||
{200, 200, 250, 250, 300, 300, 300, 300});
|
||||
cli.add<std::vector<int>>("--char-conv-filters-widths",
|
||||
"Convolution window widths in char-s2s model",
|
||||
std::vector<int>({1, 2, 3, 4, 5, 6, 7, 8}));
|
||||
{1, 2, 3, 4, 5, 6, 7, 8});
|
||||
#endif
|
||||
|
||||
if(mode_ == cli::mode::training) {
|
||||
|
@ -234,14 +321,19 @@ void ConfigParser::addOptionsModel(cli::CLIWrapper& cli) {
|
|||
cli.add<float>("--transformer-dropout-ffn",
|
||||
"Dropout for transformer filter (0 = no dropout)");
|
||||
}
|
||||
cli.switchGroup(previous_group);
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
void ConfigParser::addOptionsTraining(cli::CLIWrapper& cli) {
|
||||
cli.switchGroup("Training options");
|
||||
auto previous_group = cli.switchGroup("Training options");
|
||||
// clang-format off
|
||||
cli.add<std::string>("--cost-type",
|
||||
cli.add<std::string>("--cost-type", // @TODO: rename to loss-type
|
||||
"Optimization criterion: ce-mean, ce-mean-words, ce-sum, perplexity", "ce-mean");
|
||||
cli.add<std::string>("--multi-loss-type",
|
||||
"How to accumulate multi-objective losses: sum, scaled, mean", "sum");
|
||||
cli.add<bool>("--unlikelihood-loss",
|
||||
"Use word-level weights as indicators for sequence-level unlikelihood training");
|
||||
cli.add<bool>("--overwrite",
|
||||
"Do not create model checkpoints, only overwrite main model file with last checkpoint. "
|
||||
"Reduces disk usage");
|
||||
|
@ -273,9 +365,11 @@ void ConfigParser::addOptionsTraining(cli::CLIWrapper& cli) {
|
|||
"Display information every arg updates (append 't' for every arg target labels)",
|
||||
"1000u");
|
||||
cli.add<size_t>("--disp-first",
|
||||
"Display nformation for the first arg updates");
|
||||
"Display information for the first arg updates");
|
||||
cli.add<bool>("--disp-label-counts",
|
||||
"Display label counts when logging loss progress");
|
||||
// cli.add<int>("--disp-label-index",
|
||||
// "Display label counts based on i-th input stream (-1 is last)", -1);
|
||||
cli.add<std::string/*SchedulerPeriod*/>("--save-freq",
|
||||
"Save model file every arg updates (append 't' for every arg target labels)",
|
||||
"10000u");
|
||||
|
@ -283,8 +377,12 @@ void ConfigParser::addOptionsTraining(cli::CLIWrapper& cli) {
|
|||
addSuboptionsInputLength(cli);
|
||||
|
||||
// data management options
|
||||
cli.add<std::string>("--shuffle",
|
||||
"How to shuffle input data (data: shuffles data and sorted batches; batches: "
|
||||
"data is read in order into batches, but batches are shuffled; none: no shuffling). "
|
||||
"Use with '--maxi-batch-sort none' in order to achieve exact reading order", "data");
|
||||
cli.add<bool>("--no-shuffle",
|
||||
"Skip shuffling of training data before each epoch");
|
||||
"Shortcut for backwards compatiblity, equivalent to --shuffle none (deprecated)");
|
||||
cli.add<bool>("--no-restore-corpus",
|
||||
"Skip restoring corpus state after training is restarted");
|
||||
cli.add<std::string>("--tempdir,-T",
|
||||
|
@ -304,30 +402,33 @@ void ConfigParser::addOptionsTraining(cli::CLIWrapper& cli) {
|
|||
cli.add<std::string>("--optimizer,-o",
|
||||
"Optimization algorithm: sgd, adagrad, adam",
|
||||
"adam");
|
||||
cli.add_nondefault<std::vector<float>>("--optimizer-params",
|
||||
"Parameters for optimization algorithm, e.g. betas for adam");
|
||||
cli.add<size_t>("--optimizer-delay",
|
||||
"SGD update delay, 1 = no delay",
|
||||
1);
|
||||
cli.add<std::vector<float>>("--optimizer-params",
|
||||
"Parameters for optimization algorithm, e.g. betas for Adam. "
|
||||
"Auto-adjusted to --mini-batch-words-ref if given");
|
||||
cli.add<float>("--optimizer-delay",
|
||||
"SGD update delay (#batches between updates). 1 = no delay. "
|
||||
"Can be fractional, e.g. 0.1 to use only 10% of each batch",
|
||||
1.f);
|
||||
|
||||
cli.add<bool>("--sync-sgd",
|
||||
"Use synchronous SGD instead of asynchronous for multi-gpu training");
|
||||
|
||||
// learning rate options
|
||||
cli.add<double>("--learn-rate,-l",
|
||||
"Learning rate",
|
||||
0.0001);
|
||||
cli.add<float>("--learn-rate,-l",
|
||||
"Learning rate. "
|
||||
"Auto-adjusted to --mini-batch-words-ref if given",
|
||||
0.0001f);
|
||||
cli.add<bool>("--lr-report",
|
||||
"Report learning rate for each update");
|
||||
|
||||
cli.add<double>("--lr-decay",
|
||||
cli.add<float>("--lr-decay",
|
||||
"Per-update decay factor for learning rate: lr <- lr * arg (0 to disable)");
|
||||
cli.add<std::string>("--lr-decay-strategy",
|
||||
"Strategy for learning rate decaying: epoch, batches, stalled, epoch+batches, epoch+stalled",
|
||||
"epoch+stalled");
|
||||
cli.add<std::vector<size_t>>("--lr-decay-start",
|
||||
"The first number of (epoch, batches, stalled) validations to start learning rate decaying (tuple)",
|
||||
std::vector<size_t>({10,1}));
|
||||
{10,1});
|
||||
cli.add<size_t>("--lr-decay-freq",
|
||||
"Learning rate decaying frequency for batches, requires --lr-decay-strategy to be batches",
|
||||
50000);
|
||||
|
@ -335,9 +436,10 @@ void ConfigParser::addOptionsTraining(cli::CLIWrapper& cli) {
|
|||
"Reset running statistics of optimizer whenever learning rate decays");
|
||||
cli.add<bool>("--lr-decay-repeat-warmup",
|
||||
"Repeat learning rate warmup when learning rate is decayed");
|
||||
cli.add<std::string/*SchedulerPeriod*/>("--lr-decay-inv-sqrt",
|
||||
"Decrease learning rate at arg / sqrt(no. batches) starting at arg (append 't' or 'e' for sqrt(target labels or epochs))",
|
||||
"0");
|
||||
cli.add<std::vector<std::string/*SchedulerPeriod*/>>("--lr-decay-inv-sqrt",
|
||||
"Decrease learning rate at arg / sqrt(no. batches) starting at arg (append 't' or 'e' for sqrt(target labels or epochs)). "
|
||||
"Add second argument to define the starting point (default: same as first value)",
|
||||
{"0"});
|
||||
|
||||
cli.add<std::string/*SchedulerPeriod*/>("--lr-warmup",
|
||||
"Increase learning rate linearly for arg first batches (append 't' for arg first target labels)",
|
||||
|
@ -351,12 +453,15 @@ void ConfigParser::addOptionsTraining(cli::CLIWrapper& cli) {
|
|||
|
||||
cli.add<double>("--label-smoothing",
|
||||
"Epsilon for label smoothing (0 to disable)");
|
||||
cli.add<double>("--clip-norm",
|
||||
"Clip gradient norm to argcli.add<int>(0 to disable)",
|
||||
1.f);
|
||||
cli.add<double>("--factor-weight",
|
||||
"Weight for loss function for factors (factored vocab only) (1 to disable)", 1.0f);
|
||||
cli.add<float>("--clip-norm",
|
||||
"Clip gradient norm to arg (0 to disable)",
|
||||
1.f); // @TODO: this is currently wrong with ce-sum and should rather be disabled or fixed by multiplying with labels
|
||||
cli.add<float>("--exponential-smoothing",
|
||||
"Maintain smoothed version of parameters for validation and saving with smoothing factor. 0 to disable",
|
||||
0)->implicit_val("1e-4");
|
||||
"Maintain smoothed version of parameters for validation and saving with smoothing factor. 0 to disable. "
|
||||
"Auto-adjusted to --mini-batch-words-ref if given.",
|
||||
0.f)->implicit_val("1e-4");
|
||||
cli.add<std::string>("--guided-alignment",
|
||||
"Path to a file with word alignments. Use guided alignment to guide attention or 'none'",
|
||||
"none");
|
||||
|
@ -366,14 +471,14 @@ void ConfigParser::addOptionsTraining(cli::CLIWrapper& cli) {
|
|||
cli.add<double>("--guided-alignment-weight",
|
||||
"Weight for guided alignment cost",
|
||||
0.1);
|
||||
cli.add_nondefault<std::string>("--data-weighting",
|
||||
cli.add<std::string>("--data-weighting",
|
||||
"Path to a file with sentence or word weights");
|
||||
cli.add<std::string>("--data-weighting-type",
|
||||
"Processing level for data weighting: sentence, word",
|
||||
"sentence");
|
||||
|
||||
// embedding options
|
||||
cli.add_nondefault<std::vector<std::string>>("--embedding-vectors",
|
||||
cli.add<std::vector<std::string>>("--embedding-vectors",
|
||||
"Paths to files with custom source and target embedding vectors");
|
||||
cli.add<bool>("--embedding-normalization",
|
||||
"Normalize values from custom embedding vectors to [-1, 1]");
|
||||
|
@ -382,21 +487,40 @@ void ConfigParser::addOptionsTraining(cli::CLIWrapper& cli) {
|
|||
cli.add<bool>("--embedding-fix-trg",
|
||||
"Fix target embeddings. Affects all decoders");
|
||||
|
||||
// mixed precision training
|
||||
cli.add<bool>("--fp16",
|
||||
"Shortcut for mixed precision training with float16 and cost-scaling, "
|
||||
"corresponds to: --precision float16 float32 float32 --cost-scaling 7 2000 2 0.05 10 1");
|
||||
cli.add<std::vector<std::string>>("--precision",
|
||||
"Mixed precision training for forward/backward pass and optimizaton. "
|
||||
"Defines types for: forward/backward, optimization, saving.",
|
||||
{"float32", "float32", "float32"});
|
||||
cli.add<std::vector<std::string>>("--cost-scaling",
|
||||
"Dynamic cost scaling for mixed precision training: "
|
||||
"power of 2, scaling window, scaling factor, tolerance, range, minimum factor")->implicit_val("7.f 2000 2.f 0.05f 10 1.f");
|
||||
cli.add<bool>("--normalize-gradient", "Normalize gradient by multiplying with no. devices / total labels");
|
||||
|
||||
// multi-node training
|
||||
cli.add<bool>("--multi-node",
|
||||
"Enable asynchronous multi-node training through MPI (and legacy sync if combined with --sync-sgd)");
|
||||
cli.add<bool>("--multi-node-overlap",
|
||||
"Overlap model computations with MPI communication",
|
||||
true);
|
||||
|
||||
// add ULR settings
|
||||
addSuboptionsULR(cli);
|
||||
|
||||
cli.add<std::vector<std::string>>("--task",
|
||||
"Use predefined set of options. Possible values: transformer, transformer-big");
|
||||
cli.switchGroup(previous_group);
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
void ConfigParser::addOptionsValidation(cli::CLIWrapper& cli) {
|
||||
cli.switchGroup("Validation set options");
|
||||
auto previous_group = cli.switchGroup("Validation set options");
|
||||
|
||||
// clang-format off
|
||||
cli.add_nondefault<std::vector<std::string>>("--valid-sets",
|
||||
cli.add<std::vector<std::string>>("--valid-sets",
|
||||
"Paths to validation corpora: source target");
|
||||
cli.add<std::string/*SchedulerPeriod*/>("--valid-freq",
|
||||
"Validate model every arg updates (append 't' for every arg target labels)",
|
||||
|
@ -404,7 +528,9 @@ void ConfigParser::addOptionsValidation(cli::CLIWrapper& cli) {
|
|||
cli.add<std::vector<std::string>>("--valid-metrics",
|
||||
"Metric to use during validation: cross-entropy, ce-mean-words, perplexity, valid-script, "
|
||||
"translation, bleu, bleu-detok. Multiple metrics can be specified",
|
||||
std::vector<std::string>({"cross-entropy"}));
|
||||
{"cross-entropy"});
|
||||
cli.add<bool>("--valid-reset-stalled",
|
||||
"Reset all stalled validation metrics when the training is restarted");
|
||||
cli.add<size_t>("--early-stopping",
|
||||
"Stop if the first validation metric does not improve for arg consecutive validation steps",
|
||||
10);
|
||||
|
@ -425,38 +551,47 @@ void ConfigParser::addOptionsValidation(cli::CLIWrapper& cli) {
|
|||
"Allow unknown words to appear in output");
|
||||
cli.add<bool>("--n-best",
|
||||
"Generate n-best list");
|
||||
cli.add<bool>("--word-scores",
|
||||
"Print word-level scores");
|
||||
|
||||
// efficiency options
|
||||
cli.add<int>("--valid-mini-batch",
|
||||
"Size of mini-batch used during validation",
|
||||
32);
|
||||
cli.add<size_t>("--valid-max-length",
|
||||
"Maximum length of a sentence in a validating sentence pair",
|
||||
"Maximum length of a sentence in a validating sentence pair. "
|
||||
"Sentences longer than valid-max-length are cropped to valid-max-length",
|
||||
1000);
|
||||
|
||||
// options for validation script
|
||||
cli.add_nondefault<std::string>("--valid-script-path",
|
||||
cli.add<std::string>("--valid-script-path",
|
||||
"Path to external validation script."
|
||||
" It should print a single score to stdout."
|
||||
" If the option is used with validating translation, the output"
|
||||
" translation file will be passed as a first argument");
|
||||
cli.add_nondefault<std::string>("--valid-translation-output",
|
||||
"Path to store the translation");
|
||||
|
||||
cli.add<std::vector<std::string>>("--valid-script-args",
|
||||
"Additional args passed to --valid-script-path. These are inserted"
|
||||
" between the script path and the output translation-file path");
|
||||
cli.add<std::string>("--valid-translation-output",
|
||||
"(Template for) path to store the translation. "
|
||||
"E.g., validation-output-after-{U}-updates-{T}-tokens.txt. Template "
|
||||
"parameters: {E} for epoch; {B} for No. of batches within epoch; "
|
||||
"{U} for total No. of updates; {T} for total No. of tokens seen.");
|
||||
cli.add<bool>("--keep-best",
|
||||
"Keep best model for each validation metric");
|
||||
cli.add_nondefault<std::string>("--valid-log",
|
||||
cli.add<std::string>("--valid-log",
|
||||
"Log validation scores to file given by arg");
|
||||
cli.switchGroup(previous_group);
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
void ConfigParser::addOptionsTranslation(cli::CLIWrapper& cli) {
|
||||
cli.switchGroup("Translator options");
|
||||
auto previous_group = cli.switchGroup("Translator options");
|
||||
|
||||
// clang-format off
|
||||
cli.add<std::vector<std::string>>("--input,-i",
|
||||
"Paths to input file(s), stdin by default",
|
||||
std::vector<std::string>({"stdin"}));
|
||||
{"stdin"});
|
||||
cli.add<std::string>("--output,-o",
|
||||
"Path to output file, stdout by default",
|
||||
"stdout");
|
||||
|
@ -478,9 +613,15 @@ void ConfigParser::addOptionsTranslation(cli::CLIWrapper& cli) {
|
|||
"Allow unknown words to appear in output");
|
||||
cli.add<bool>("--n-best",
|
||||
"Generate n-best list");
|
||||
cli.add_nondefault<std::string>("--alignment",
|
||||
cli.add<std::string>("--alignment",
|
||||
"Return word alignment. Possible values: 0.0-1.0, hard, soft")
|
||||
->implicit_val("1");
|
||||
cli.add<bool>("--word-scores",
|
||||
"Print word-level scores");
|
||||
#ifdef USE_SENTENCEPIECE
|
||||
cli.add<bool>("--no-spm-decode",
|
||||
"Keep the output segmented into SentencePiece subwords");
|
||||
#endif
|
||||
|
||||
addSuboptionsDevices(cli);
|
||||
addSuboptionsInputLength(cli);
|
||||
|
@ -490,26 +631,31 @@ void ConfigParser::addOptionsTranslation(cli::CLIWrapper& cli) {
|
|||
"Optimize speed aggressively sacrificing memory or precision");
|
||||
cli.add<bool>("--skip-cost",
|
||||
"Ignore model cost during translation, not recommended for beam-size > 1");
|
||||
cli.add<bool>("--fp16",
|
||||
"Shortcut for mixed precision inference with float16, corresponds to: --precision float16");
|
||||
cli.add<std::vector<std::string>>("--precision",
|
||||
"Mixed precision for inference, set parameter type in expression graph",
|
||||
{"float32"});
|
||||
|
||||
cli.add_nondefault<std::vector<std::string>>("--shortlist",
|
||||
cli.add<std::vector<std::string>>("--shortlist",
|
||||
"Use softmax shortlist: path first best prune");
|
||||
cli.add_nondefault<std::vector<float>>("--weights",
|
||||
cli.add<std::vector<float>>("--weights",
|
||||
"Scorer weights");
|
||||
cli.add<bool>("--output-sampling",
|
||||
"Noise output layer with gumbel noise",
|
||||
false);
|
||||
|
||||
// TODO: the options should be available only in server
|
||||
cli.add_nondefault<size_t>("--port,-p",
|
||||
"Port number for web socket server");
|
||||
#if 0 // @TODO: Ask Hany if there are any decoding-time options
|
||||
// add ULR settings
|
||||
addSuboptionsULR(cli);
|
||||
#endif
|
||||
|
||||
cli.switchGroup(previous_group);
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
void ConfigParser::addOptionsScoring(cli::CLIWrapper& cli) {
|
||||
cli.switchGroup("Scorer options");
|
||||
auto previous_group = cli.switchGroup("Scorer options");
|
||||
|
||||
// clang-format off
|
||||
cli.add<bool>("--no-reload",
|
||||
|
@ -530,10 +676,10 @@ void ConfigParser::addOptionsScoring(cli::CLIWrapper& cli) {
|
|||
"Feature name to be inserted into n-best list", "Score");
|
||||
cli.add<bool>("--normalize,-n",
|
||||
"Divide translation score by translation length");
|
||||
cli.add_nondefault<std::string>("--summary",
|
||||
cli.add<std::string>("--summary",
|
||||
"Only print total cost, possible values: cross-entropy (ce-mean), ce-mean-words, ce-sum, perplexity")
|
||||
->implicit_val("cross-entropy");
|
||||
cli.add_nondefault<std::string>("--alignment",
|
||||
cli.add<std::string>("--alignment",
|
||||
"Return word alignments. Possible values: 0.0-1.0, hard, soft")
|
||||
->implicit_val("1"),
|
||||
|
||||
|
@ -543,6 +689,13 @@ void ConfigParser::addOptionsScoring(cli::CLIWrapper& cli) {
|
|||
|
||||
cli.add<bool>("--optimize",
|
||||
"Optimize speed aggressively sacrificing memory or precision");
|
||||
cli.add<bool>("--fp16",
|
||||
"Shortcut for mixed precision inference with float16, corresponds to: --precision float16");
|
||||
cli.add<std::vector<std::string>>("--precision",
|
||||
"Mixed precision for inference, set parameter type in expression graph",
|
||||
{"float32"});
|
||||
|
||||
cli.switchGroup(previous_group);
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
|
@ -550,8 +703,8 @@ void ConfigParser::addSuboptionsDevices(cli::CLIWrapper& cli) {
|
|||
// clang-format off
|
||||
cli.add<std::vector<std::string>>("--devices,-d",
|
||||
"Specifies GPU ID(s) to use for training. Defaults to 0..num-devices-1",
|
||||
std::vector<std::string>({"0"}));
|
||||
cli.add_nondefault<size_t>("--num-devices",
|
||||
{"0"});
|
||||
cli.add<size_t>("--num-devices",
|
||||
"Number of GPUs to use for this process. Defaults to length(devices) or 1");
|
||||
#ifdef USE_NCCL
|
||||
if(mode_ == cli::mode::training)
|
||||
|
@ -560,12 +713,13 @@ void ConfigParser::addSuboptionsDevices(cli::CLIWrapper& cli) {
|
|||
#endif
|
||||
#ifdef CUDA_FOUND
|
||||
cli.add<size_t>("--cpu-threads",
|
||||
"Use CPU-based computation with this many independent threads, 0 means GPU-based computation")
|
||||
->default_val("0")->implicit_val("1");
|
||||
"Use CPU-based computation with this many independent threads, 0 means GPU-based computation",
|
||||
0)
|
||||
->implicit_val("1");
|
||||
#else
|
||||
cli.add<size_t>("--cpu-threads",
|
||||
"Use CPU-based computation with this many independent threads, 0 means GPU-based computation")
|
||||
->default_val("1");
|
||||
"Use CPU-based computation with this many independent threads, 0 means GPU-based computation",
|
||||
1);
|
||||
#endif
|
||||
// clang-format on
|
||||
}
|
||||
|
@ -593,6 +747,8 @@ void ConfigParser::addSuboptionsBatching(cli::CLIWrapper& cli) {
|
|||
cli.add<size_t>("--mini-batch-fit-step",
|
||||
"Step size for mini-batch-fit statistics",
|
||||
10);
|
||||
cli.add<bool>("--gradient-checkpointing",
|
||||
"Enable gradient-checkpointing to minimize memory usage");
|
||||
}
|
||||
|
||||
cli.add<int>("--maxi-batch",
|
||||
|
@ -602,8 +758,25 @@ void ConfigParser::addSuboptionsBatching(cli::CLIWrapper& cli) {
|
|||
"Sorting strategy for maxi-batch: none, src, trg (not available for decoder)",
|
||||
defaultMaxiBatchSort);
|
||||
|
||||
if(mode_ == cli::mode::training) {
|
||||
cli.add<bool>("--shuffle-in-ram",
|
||||
"Keep shuffled corpus in RAM, do not write to temp file");
|
||||
// @TODO: Consider making the next two options options of the vocab instead, to make it more local in scope.
|
||||
cli.add<size_t>("--all-caps-every",
|
||||
"When forming minibatches, preprocess every Nth line on the fly to all-caps. Assumes UTF-8");
|
||||
cli.add<size_t>("--english-title-case-every",
|
||||
"When forming minibatches, preprocess every Nth line on the fly to title-case. Assumes English (ASCII only)");
|
||||
|
||||
cli.add<int>("--mini-batch-words-ref",
|
||||
"If given, the following hyper parameters are adjusted as-if we had this mini-batch size: "
|
||||
"--learn-rate, --optimizer-params, --exponential-smoothing, --mini-batch-warmup");
|
||||
cli.add<std::string/*SchedulerPeriod*/>("--mini-batch-warmup",
|
||||
"Linear ramp-up of MB size, up to this #updates (append 't' for up to this #target labels). "
|
||||
"Auto-adjusted to --mini-batch-words-ref if given",
|
||||
{"0"});
|
||||
cli.add<bool>("--mini-batch-track-lr",
|
||||
"Dynamically track mini-batch size inverse to actual learning rate (not considering lr-warmup)");
|
||||
}
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
|
@ -614,7 +787,7 @@ void ConfigParser::addSuboptionsInputLength(cli::CLIWrapper& cli) {
|
|||
"Maximum length of a sentence in a training sentence pair",
|
||||
defaultMaxLength);
|
||||
cli.add<bool>("--max-length-crop",
|
||||
"Crop a sentence to max-length instead of ommitting it if longer than max-length");
|
||||
"Crop a sentence to max-length instead of omitting it if longer than max-length");
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
|
@ -622,8 +795,7 @@ void ConfigParser::addSuboptionsULR(cli::CLIWrapper& cli) {
|
|||
// clang-format off
|
||||
// support for universal encoder ULR https://arxiv.org/pdf/1802.05368.pdf
|
||||
cli.add<bool>("--ulr",
|
||||
"Enable ULR (Universal Language Representation)",
|
||||
false);
|
||||
"Enable ULR (Universal Language Representation)");
|
||||
// reading pre-trained universal embeddings for multi-sources.
|
||||
// Note that source and target here is relative to ULR not the translation langs
|
||||
// queries: EQ in Fig2 : is the unified embeddings projected to one space.
|
||||
|
@ -636,8 +808,7 @@ void ConfigParser::addSuboptionsULR(cli::CLIWrapper& cli) {
|
|||
"Path to file with universal sources embeddings of traget keys from projection into universal space",
|
||||
"");
|
||||
cli.add<bool>("--ulr-trainable-transformation",
|
||||
"Make Query Transformation Matrix A trainable",
|
||||
false);
|
||||
"Make Query Transformation Matrix A trainable");
|
||||
cli.add<int>("--ulr-dim-emb",
|
||||
"ULR monolingual embeddings dimension");
|
||||
cli.add<float>("--ulr-dropout",
|
||||
|
@ -649,62 +820,41 @@ void ConfigParser::addSuboptionsULR(cli::CLIWrapper& cli) {
|
|||
// clang-format on
|
||||
}
|
||||
|
||||
void ConfigParser::expandAliases(cli::CLIWrapper& cli) {
|
||||
YAML::Node config;
|
||||
// The order of aliases does matter as later options overwrite earlier
|
||||
|
||||
if(config_["best-deep"].as<bool>()) {
|
||||
config["layer-normalization"] = true;
|
||||
config["tied-embeddings"] = true;
|
||||
config["enc-type"] = "alternating";
|
||||
config["enc-cell-depth"] = 2;
|
||||
config["enc-depth"] = 4;
|
||||
config["dec-cell-base-depth"] = 4;
|
||||
config["dec-cell-high-depth"] = 2;
|
||||
config["dec-depth"] = 4;
|
||||
config["skip"] = true;
|
||||
}
|
||||
cli::mode ConfigParser::getMode() const { return mode_; }
|
||||
|
||||
if(config) {
|
||||
auto success = cli.updateConfig(config);
|
||||
ABORT_IF(!success, "Unknown option(s) in aliases, check if aliases consist of correct options");
|
||||
}
|
||||
}
|
||||
|
||||
void ConfigParser::parseOptions(int argc, char** argv, bool doValidate) {
|
||||
cli::CLIWrapper cli(config_,
|
||||
"Marian: Fast Neural Machine Translation in C++",
|
||||
"General options",
|
||||
"",
|
||||
40);
|
||||
|
||||
addOptionsGeneral(cli);
|
||||
addOptionsModel(cli);
|
||||
|
||||
// clang-format off
|
||||
switch(mode_) {
|
||||
case cli::mode::training:
|
||||
addOptionsTraining(cli);
|
||||
addOptionsValidation(cli);
|
||||
break;
|
||||
case cli::mode::translation:
|
||||
addOptionsTranslation(cli);
|
||||
break;
|
||||
case cli::mode::scoring:
|
||||
addOptionsScoring(cli);
|
||||
break;
|
||||
}
|
||||
// clang-format on
|
||||
Ptr<Options> ConfigParser::parseOptions(int argc, char** argv, bool doValidate){
|
||||
cmdLine_ = escapeCmdLine(argc,argv);
|
||||
|
||||
// parse command-line options and fill wrapped YAML config
|
||||
cli.parse(argc, argv);
|
||||
cli_.parse(argc, argv);
|
||||
|
||||
if(get<bool>("authors")) {
|
||||
std::cerr << authors() << std::endl;
|
||||
exit(0);
|
||||
}
|
||||
|
||||
if(get<bool>("cite")) {
|
||||
std::cerr << citation() << std::endl;
|
||||
exit(0);
|
||||
}
|
||||
|
||||
auto buildInfo = get<std::string>("build-info");
|
||||
if(!buildInfo.empty() && buildInfo != "false") {
|
||||
if(buildInfo == "all")
|
||||
std::cerr << cmakeBuildOptionsAdvanced() << std::endl;
|
||||
else
|
||||
std::cerr << cmakeBuildOptions() << std::endl;
|
||||
exit(0);
|
||||
}
|
||||
|
||||
// get paths to extra config files
|
||||
auto configPaths = findConfigPaths();
|
||||
if(!configPaths.empty()) {
|
||||
auto config = loadConfigFiles(configPaths);
|
||||
auto success = cli.updateConfig(config);
|
||||
ABORT_IF(!success, "There are option(s) in a config file that are not expected");
|
||||
cli_.updateConfig(config,
|
||||
cli::OptionPriority::ConfigFile,
|
||||
"There are option(s) in a config file that are not expected");
|
||||
}
|
||||
|
||||
if(get<bool>("interpolate-env-vars")) {
|
||||
|
@ -712,21 +862,29 @@ void ConfigParser::parseOptions(int argc, char** argv, bool doValidate) {
|
|||
}
|
||||
|
||||
if(doValidate) {
|
||||
// this aborts the program on first validation error
|
||||
ConfigValidator(config_).validateOptions(mode_);
|
||||
}
|
||||
|
||||
// remove extra config files from the config to avoid redundancy
|
||||
config_.remove("config");
|
||||
|
||||
if(has("dump-config") && get<std::string>("dump-config") != "false") {
|
||||
bool skipDefault = get<std::string>("dump-config") == "minimal";
|
||||
if(!get<std::string>("dump-config").empty() && get<std::string>("dump-config") != "false") {
|
||||
auto dumpMode = get<std::string>("dump-config");
|
||||
config_.remove("dump-config");
|
||||
std::cout << cli.dumpConfig(skipDefault) << std::endl;
|
||||
|
||||
if(dumpMode == "expand") {
|
||||
cli_.parseAliases();
|
||||
}
|
||||
|
||||
bool minimal = (dumpMode == "minimal" || dumpMode == "expand");
|
||||
std::cout << cli_.dumpConfig(minimal) << std::endl;
|
||||
exit(0);
|
||||
}
|
||||
|
||||
expandAliases(cli);
|
||||
cli_.parseAliases();
|
||||
auto opts = New<Options>();
|
||||
opts->merge(Config(*this).get());
|
||||
return opts;
|
||||
}
|
||||
|
||||
std::vector<std::string> ConfigParser::findConfigPaths() {
|
||||
|
@ -760,7 +918,8 @@ YAML::Node ConfigParser::loadConfigFiles(const std::vector<std::string>& paths)
|
|||
|
||||
for(auto& path : paths) {
|
||||
// load single config file
|
||||
YAML::Node config = YAML::Load(io::InputFileStream(path));
|
||||
io::InputFileStream strm(path);
|
||||
YAML::Node config = YAML::Load(strm);
|
||||
|
||||
// expand relative paths if requested
|
||||
if(config["relative-paths"] && config["relative-paths"].as<bool>()) {
|
||||
|
@ -787,7 +946,7 @@ YAML::Node ConfigParser::loadConfigFiles(const std::vector<std::string>& paths)
|
|||
return configAll;
|
||||
}
|
||||
|
||||
YAML::Node ConfigParser::getConfig() const {
|
||||
const YAML::Node& ConfigParser::getConfig() const {
|
||||
return config_;
|
||||
}
|
||||
} // namespace marian
|
||||
|
|
|
@ -14,21 +14,72 @@
|
|||
namespace marian {
|
||||
|
||||
namespace cli {
|
||||
enum struct mode { training, translation, scoring };
|
||||
enum struct mode { training, translation, scoring, server };
|
||||
} // namespace cli
|
||||
|
||||
/**
|
||||
* @brief Command-line options parser
|
||||
*
|
||||
* New options and aliases should be defined within `addOptions*` methods.
|
||||
* ... unless they are specific to certain executables.
|
||||
* In that case, use a pattern like this (e.g., for a server):
|
||||
* int main(int argc, char* argv[]) {
|
||||
* ConfigParser cp(cli::mode::translation);
|
||||
* cp.addOption<int>("--port", // option name
|
||||
* "Server Options", // option group name
|
||||
* "Port for server.", // help string
|
||||
* 5678); // default value
|
||||
* auto opts = cp.parseOptions(argc,argv,true); // 'true' for validation
|
||||
* ...
|
||||
*
|
||||
*
|
||||
*/
|
||||
class ConfigParser {
|
||||
public:
|
||||
|
||||
ConfigParser(cli::mode mode);
|
||||
|
||||
ConfigParser(int argc, char** argv, cli::mode mode, bool validate = false)
|
||||
: mode_(mode) {
|
||||
: ConfigParser(mode) {
|
||||
parseOptions(argc, argv, validate);
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
ConfigParser&
|
||||
addOption(const std::string& args,
|
||||
const std::string& group,
|
||||
const std::string& help,
|
||||
const T val) {
|
||||
std::string previous_group = cli_.switchGroup(group);
|
||||
cli_.add<T>(args,help,val);
|
||||
cli_.switchGroup(previous_group);
|
||||
return *this;
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
ConfigParser&
|
||||
addOption(const std::string& args,
|
||||
const std::string& group,
|
||||
const std::string& help,
|
||||
const T val,
|
||||
const T implicit_val) {
|
||||
std::string previous_group = cli_.switchGroup(group);
|
||||
cli_.add<T>(args,help,val)->implicit_val(implicit_val);
|
||||
cli_.switchGroup(previous_group);
|
||||
return *this;
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
ConfigParser&
|
||||
addOption(const std::string& args,
|
||||
const std::string& group,
|
||||
const std::string& help) {
|
||||
std::string previous_group = cli_.switchGroup(group);
|
||||
cli_.add<T>(args,help);
|
||||
cli_.switchGroup(previous_group);
|
||||
return *this;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Parse command-line options
|
||||
*
|
||||
|
@ -45,14 +96,18 @@ public:
|
|||
* @param argc
|
||||
* @param argv
|
||||
* @param validate Do or do not validate parsed options
|
||||
* @return (YAML::Node const&)config_
|
||||
*/
|
||||
void parseOptions(int argc, char** argv, bool validate);
|
||||
|
||||
YAML::Node getConfig() const;
|
||||
|
||||
Ptr<Options> parseOptions(int argc, char** argv, bool validate);
|
||||
YAML::Node const& getConfig() const;
|
||||
cli::mode getMode() const;
|
||||
std::string const& cmdLine() const;
|
||||
private:
|
||||
cli::CLIWrapper cli_;
|
||||
cli::mode mode_;
|
||||
YAML::Node config_;
|
||||
std::string cmdLine_;
|
||||
|
||||
// Check if the config contains value for option key
|
||||
bool has(const std::string& key) const {
|
||||
|
@ -68,17 +123,19 @@ private:
|
|||
}
|
||||
|
||||
void addOptionsGeneral(cli::CLIWrapper&);
|
||||
void addOptionsServer(cli::CLIWrapper&);
|
||||
void addOptionsModel(cli::CLIWrapper&);
|
||||
void addOptionsTraining(cli::CLIWrapper&);
|
||||
void addOptionsValidation(cli::CLIWrapper&);
|
||||
void addOptionsTranslation(cli::CLIWrapper&);
|
||||
void addOptionsScoring(cli::CLIWrapper&);
|
||||
|
||||
void addAliases(cli::CLIWrapper&);
|
||||
|
||||
void addSuboptionsDevices(cli::CLIWrapper&);
|
||||
void addSuboptionsBatching(cli::CLIWrapper&);
|
||||
void addSuboptionsInputLength(cli::CLIWrapper&);
|
||||
void addSuboptionsULR(cli::CLIWrapper&);
|
||||
void expandAliases(cli::CLIWrapper&);
|
||||
|
||||
// Extract paths to all config files found in the config object.
|
||||
// Look at --config option and model.npz.yml files.
|
||||
|
|
|
@ -10,7 +10,10 @@ bool ConfigValidator::has(const std::string& key) const {
|
|||
return config_[key];
|
||||
}
|
||||
|
||||
ConfigValidator::ConfigValidator(const YAML::Node& config) : config_(config) {}
|
||||
ConfigValidator::ConfigValidator(const YAML::Node& config)
|
||||
: config_(config),
|
||||
dumpConfigOnly_(config["dump-config"] && !config["dump-config"].as<std::string>().empty()
|
||||
&& config["dump-config"].as<std::string>() != "false") {}
|
||||
|
||||
ConfigValidator::~ConfigValidator() {}
|
||||
|
||||
|
@ -28,6 +31,9 @@ void ConfigValidator::validateOptions(cli::mode mode) const {
|
|||
validateOptionsParallelData();
|
||||
validateOptionsTraining();
|
||||
break;
|
||||
default:
|
||||
ABORT("wrong CLI mode");
|
||||
break;
|
||||
}
|
||||
// clang-format on
|
||||
|
||||
|
@ -42,16 +48,25 @@ void ConfigValidator::validateOptionsTranslation() const {
|
|||
ABORT_IF(models.empty() && configs.empty(),
|
||||
"You need to provide at least one model file or a config file");
|
||||
|
||||
auto vocabs = get<std::vector<std::string>>("vocabs");
|
||||
ABORT_IF(vocabs.empty(), "Translating, but vocabularies are not given!");
|
||||
|
||||
for(const auto& modelFile : models) {
|
||||
filesystem::Path modelPath(modelFile);
|
||||
ABORT_IF(!filesystem::exists(modelPath), "Model file does not exist: " + modelFile);
|
||||
}
|
||||
|
||||
auto vocabs = get<std::vector<std::string>>("vocabs");
|
||||
ABORT_IF(vocabs.empty(), "Translating, but vocabularies are not given");
|
||||
|
||||
for(const auto& vocabFile : vocabs) {
|
||||
filesystem::Path vocabPath(vocabFile);
|
||||
ABORT_IF(!filesystem::exists(vocabPath), "Vocabulary file does not exist: " + vocabFile);
|
||||
}
|
||||
}
|
||||
|
||||
void ConfigValidator::validateOptionsParallelData() const {
|
||||
// Do not check these constraints if only goal is to dump config
|
||||
if(dumpConfigOnly_)
|
||||
return;
|
||||
|
||||
auto trainSets = get<std::vector<std::string>>("train-sets");
|
||||
ABORT_IF(trainSets.empty(), "No train sets given in config file or on command line");
|
||||
|
||||
|
@ -62,17 +77,23 @@ void ConfigValidator::validateOptionsParallelData() const {
|
|||
|
||||
void ConfigValidator::validateOptionsScoring() const {
|
||||
filesystem::Path modelPath(get<std::string>("model"));
|
||||
|
||||
ABORT_IF(!filesystem::exists(modelPath), "Model file does not exist: " + modelPath.string());
|
||||
ABORT_IF(get<std::vector<std::string>>("vocabs").empty(),
|
||||
"Scoring, but vocabularies are not given!");
|
||||
|
||||
auto vocabs = get<std::vector<std::string>>("vocabs");
|
||||
ABORT_IF(vocabs.empty(), "Scoring, but vocabularies are not given");
|
||||
|
||||
for(const auto& vocabFile : vocabs) {
|
||||
filesystem::Path vocabPath(vocabFile);
|
||||
ABORT_IF(!filesystem::exists(vocabPath), "Vocabulary file does not exist: " + vocabFile);
|
||||
}
|
||||
}
|
||||
|
||||
void ConfigValidator::validateOptionsTraining() const {
|
||||
auto trainSets = get<std::vector<std::string>>("train-sets");
|
||||
|
||||
ABORT_IF(has("embedding-vectors")
|
||||
&& get<std::vector<std::string>>("embedding-vectors").size() != trainSets.size(),
|
||||
&& get<std::vector<std::string>>("embedding-vectors").size() != trainSets.size()
|
||||
&& !get<std::vector<std::string>>("embedding-vectors").empty(),
|
||||
"There should be as many embedding vector files as training sets");
|
||||
|
||||
filesystem::Path modelPath(get<std::string>("model"));
|
||||
|
@ -84,12 +105,13 @@ void ConfigValidator::validateOptionsTraining() const {
|
|||
ABORT_IF(!modelDir.empty() && !filesystem::isDirectory(modelDir),
|
||||
"Model directory does not exist");
|
||||
|
||||
ABORT_IF(
|
||||
has("valid-sets") && get<std::vector<std::string>>("valid-sets").size() != trainSets.size(),
|
||||
ABORT_IF(has("valid-sets")
|
||||
&& get<std::vector<std::string>>("valid-sets").size() != trainSets.size()
|
||||
&& !get<std::vector<std::string>>("valid-sets").empty(),
|
||||
"There should be as many validation sets as training sets");
|
||||
|
||||
// validations for learning rate decaying
|
||||
ABORT_IF(get<double>("lr-decay") > 1.0, "Learning rate decay factor greater than 1.0 is unusual");
|
||||
ABORT_IF(get<float>("lr-decay") > 1.f, "Learning rate decay factor greater than 1.0 is unusual");
|
||||
|
||||
auto strategy = get<std::string>("lr-decay-strategy");
|
||||
|
||||
|
|
|
@ -5,19 +5,20 @@
|
|||
|
||||
namespace marian {
|
||||
|
||||
// TODO: Finally refactorize Config, Options, ConfigParser and ConfigValidator
|
||||
// classes.
|
||||
class ConfigValidator {
|
||||
private:
|
||||
const YAML::Node& config_;
|
||||
|
||||
bool has(const std::string& key) const;
|
||||
|
||||
template <typename T>
|
||||
T get(const std::string& key) const {
|
||||
return config_[key].as<T>();
|
||||
}
|
||||
|
||||
// The option --dump-config is used, so alleviate some constraints, e.g. we don't want to require
|
||||
// --train-sets or --vocabs
|
||||
bool dumpConfigOnly_{false};
|
||||
|
||||
void validateOptionsTranslation() const;
|
||||
void validateOptionsParallelData() const;
|
||||
void validateOptionsScoring() const;
|
||||
|
|
|
@ -1,7 +1,8 @@
|
|||
#pragma once
|
||||
|
||||
#include "common/logging.h"
|
||||
#include "shape.h"
|
||||
#include "common/shape.h"
|
||||
#include "common/intrusive_ptr.h"
|
||||
|
||||
#include <functional>
|
||||
#include <iostream>
|
||||
|
@ -9,10 +10,33 @@
|
|||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
//#define THREAD_GUARD(body) std::thread([&]() { body; }).join()
|
||||
// The macro MAYBE_UNUSED is used to selectively disable
|
||||
// unused-variable warnings. C++17 defines the attribute
|
||||
// [[maybe_unused]], but I don't think we're at C++17 yet. We can add it when we reach C++17.
|
||||
// The compilers gcc and clang (and maybe others) define
|
||||
// __has_attribute and support __attribute__(unused) in C++11,
|
||||
#if defined __has_attribute
|
||||
# if __has_attribute(unused)
|
||||
# define MAYBE_UNUSED __attribute__((unused))
|
||||
# else
|
||||
# define MAYBE_UNUSED
|
||||
# endif
|
||||
#else
|
||||
# define MAYBE_UNUSED
|
||||
#endif
|
||||
|
||||
#define THREAD_GUARD(body) [&]() { body; }() // test if THREAD_GUARD is neccessary, remove if no problems occur.
|
||||
#define NodeOp(op) [=]() { op; }
|
||||
|
||||
// helper macro to disable optimization (gcc only)
|
||||
// To use this, just insert DONT_OPTIMIZE right before the function definition
|
||||
// (e.g. where the "static" keyword would go).
|
||||
#ifdef __GNUC__
|
||||
#define DONT_OPTIMIZE __attribute__((optimize("O0")))
|
||||
#else
|
||||
#define DONT_OPTIMIZE // silently ignore on Visual Studio, where this is less of a problem
|
||||
#endif
|
||||
|
||||
namespace marian {
|
||||
|
||||
// Type to be used for all index types, e.g. for integer tensors for rows operator.
|
||||
|
@ -21,12 +45,21 @@ namespace marian {
|
|||
// This minimizes bandwith at little cost.
|
||||
typedef uint32_t IndexType;
|
||||
|
||||
// @TODO: come up with better short name. "I..." stands for interface now. Here it stands
|
||||
// for "intrusive". Not a good overlap.
|
||||
template <class T>
|
||||
using Ptr = std::shared_ptr<T>;
|
||||
using IPtr = IntrusivePtr<T>;
|
||||
|
||||
template <class T>
|
||||
using UPtr = std::unique_ptr<T>;
|
||||
|
||||
// @TODO: come up with better short name. "I..." stands for interface now.
|
||||
template <class T>
|
||||
using IWeak = T*;
|
||||
|
||||
template <class T>
|
||||
using Ptr = std::shared_ptr<T>;
|
||||
|
||||
template <class T>
|
||||
using Weak = std::weak_ptr<T>;
|
||||
|
||||
|
@ -42,6 +75,18 @@ Ptr<T> New(Ptr<T> p) {
|
|||
return Ptr<T>(p);
|
||||
}
|
||||
|
||||
/** @brief Creates InstrusivePtr of any type, passes all arguments to any available
|
||||
* constructor */
|
||||
template <class T, typename... Args>
|
||||
IPtr<T> INew(Args&&... args) {
|
||||
return IPtr<T>(new T(std::forward<Args>(args)...));
|
||||
}
|
||||
|
||||
template <class T>
|
||||
IPtr<T> INew(Ptr<T> p) {
|
||||
return IPtr<T>(p);
|
||||
}
|
||||
|
||||
enum class DeviceType : size_t { gpu = 0, cpu = 1 };
|
||||
|
||||
struct DeviceId {
|
||||
|
@ -51,8 +96,16 @@ struct DeviceId {
|
|||
DeviceId() : no{0}, type{DeviceType::gpu} {}
|
||||
DeviceId(size_t no_, DeviceType type_) : no(no_), type(type_) {}
|
||||
|
||||
std::string typeAsString() const {
|
||||
return (type == DeviceType::gpu ? "gpu" : "cpu");
|
||||
}
|
||||
|
||||
operator std::string() const {
|
||||
return typeAsString() + std::to_string(no);
|
||||
}
|
||||
|
||||
friend std::ostream& operator<<(std::ostream& out, DeviceId deviceId) {
|
||||
out << (deviceId.type == DeviceType::gpu ? "gpu" : "cpu") << deviceId.no;
|
||||
out << std::string(deviceId);
|
||||
return out;
|
||||
}
|
||||
|
||||
|
@ -81,12 +134,14 @@ const DeviceId GPU5{5, DeviceType::gpu};
|
|||
const DeviceId GPU6{6, DeviceType::gpu};
|
||||
const DeviceId GPU7{7, DeviceType::gpu};
|
||||
|
||||
// These are many small objects, hence use IntrusivePtr
|
||||
class TensorBase;
|
||||
typedef Ptr<TensorBase> Tensor;
|
||||
typedef IPtr<TensorBase> Tensor;
|
||||
|
||||
// These are many small objects, hence use IntrusivePtr
|
||||
template <class DataType>
|
||||
class Chainable;
|
||||
typedef Ptr<Chainable<Tensor>> Expr;
|
||||
typedef IPtr<Chainable<Tensor>> Expr;
|
||||
|
||||
class OptimizerBase;
|
||||
typedef Ptr<OptimizerBase> OptimizerBasePtr;
|
||||
|
|
|
@ -0,0 +1,112 @@
|
|||
#include "common/fastopt.h"
|
||||
|
||||
#include <utility>
|
||||
|
||||
namespace marian {
|
||||
|
||||
const std::unique_ptr<const FastOpt> FastOpt::uniqueNullPtr{nullptr};
|
||||
|
||||
// see fastopt.h for comments
|
||||
namespace fastopt_helpers {
|
||||
|
||||
// helper structs for dynamic type conversion and specializations
|
||||
// for different conversion scenarios.
|
||||
|
||||
// general template, mostly for numerical and logical types
|
||||
template <typename To, typename From>
|
||||
struct Convert {
|
||||
static inline To apply(const From& from) {
|
||||
return (To)from;
|
||||
}
|
||||
};
|
||||
|
||||
// specialization for translating from string, @TODO check if this is required at all, mostly for compilation now.
|
||||
template <typename To>
|
||||
struct Convert<To, std::string> {
|
||||
static inline To apply(const std::string& /* from */) {
|
||||
ABORT("Not implemented");
|
||||
}
|
||||
};
|
||||
|
||||
// convert anything to string, checked at compile-time
|
||||
template <typename From>
|
||||
struct Convert<std::string, From> {
|
||||
static inline std::string apply(const From& from) {
|
||||
return std::to_string(from);
|
||||
}
|
||||
};
|
||||
|
||||
// do nothing conversion for std::string
|
||||
template <>
|
||||
struct Convert<std::string, std::string> {
|
||||
static inline std::string apply(const std::string& from) {
|
||||
return from;
|
||||
}
|
||||
};
|
||||
|
||||
// helper class for FastOpt::as<T>() used for specializations
|
||||
template <typename T>
|
||||
T As<T>::apply(const FastOpt& node) {
|
||||
ABORT_IF(!node.isScalar(), "Node is not a scalar node");
|
||||
|
||||
if(node.isBool())
|
||||
return Convert<T, bool>::apply(node.value_.as<bool>());
|
||||
else if(node.isInt())
|
||||
return Convert<T, int64_t>::apply(node.value_.as<int64_t>());
|
||||
else if(node.isFloat())
|
||||
return Convert<T, double>::apply(node.value_.as<double>());
|
||||
else if(node.isString())
|
||||
return Convert<T, std::string>::apply(node.value_.as<std::string>());
|
||||
else {
|
||||
ABORT("Casting of value failed");
|
||||
}
|
||||
}
|
||||
|
||||
// specializations for simple types
|
||||
template struct As<bool>;
|
||||
template struct As<int>;
|
||||
template struct As<unsigned long>;
|
||||
template struct As<float>;
|
||||
template struct As<double>;
|
||||
template struct As<std::string>;
|
||||
|
||||
// specialization of above class for std::vector<T>
|
||||
template <typename T>
|
||||
std::vector<T> As<std::vector<T>>::apply(const FastOpt& node) {
|
||||
ABORT_IF(!node.isSequence(), "Node is not a sequence node");
|
||||
|
||||
std::vector<T> seq;
|
||||
for(const auto& elem : node.array_)
|
||||
seq.push_back(elem->as<T>());
|
||||
return seq;
|
||||
}
|
||||
|
||||
// specializations for simple vector types
|
||||
template struct As<std::vector<bool>>;
|
||||
template struct As<std::vector<int>>;
|
||||
// Windows, Linux based OS and Mac have different type definitions for 'unsigned long'.
|
||||
// So, we need an explicit definitions for uint64_t, that cover different platforms.
|
||||
// Otherwise, there's a linking error on windows or Linux or Mac.
|
||||
// https://software.intel.com/en-us/articles/size-of-long-integer-type-on-different-architecture-and-os/
|
||||
// https://stackoverflow.com/questions/32021860/c-should-you-size-t-with-a-regular-array
|
||||
// MacOS: size_t = unsigned long (8 bytes), uint64_t = unsigned long long (8 bytes)
|
||||
// Linux: size_t = unsigned long (8 bytes), uint64_t = unsigned long (8 bytes)
|
||||
// Windows: size_t = unsigned long long (8 bytes), uint64_t = unsigned long long (8 bytes)
|
||||
template struct As<std::vector<unsigned long long>>;
|
||||
template struct As<std::vector<unsigned long>>;
|
||||
template struct As<std::vector<float>>;
|
||||
template struct As<std::vector<double>>;
|
||||
template struct As<std::vector<std::string>>;
|
||||
|
||||
// specialization of above class for std::pair<T>
|
||||
template <typename T1, typename T2>
|
||||
std::pair<T1, T2> As<std::pair<T1, T2>>::apply(const FastOpt& node) {
|
||||
ABORT_IF(!node.isSequence(), "Node is not a sequence node");
|
||||
ABORT_IF(node.size() != 2, "Sequence must contain two elements in order to convert to pair");
|
||||
return std::make_pair(node[0].as<T1>(), node[1].as<T2>());
|
||||
}
|
||||
|
||||
template struct As<std::pair<int, int>>;
|
||||
|
||||
}
|
||||
}
|
|
@ -0,0 +1,379 @@
|
|||
#pragma once
|
||||
|
||||
#include "common/definitions.h"
|
||||
#include "3rd_party/any_type.h"
|
||||
#include "3rd_party/phf/phf.h"
|
||||
#include "3rd_party/yaml-cpp/yaml.h"
|
||||
|
||||
// This file contains code to create a fast access option class,
|
||||
// meant as a replacment/supplement to YAML::Node.
|
||||
|
||||
namespace marian {
|
||||
|
||||
namespace crc {
|
||||
// has to stay in header due to constexpr
|
||||
|
||||
// This code comes from https://notes.underscorediscovery.com/constexpr-fnv1a/
|
||||
// and is distributed as public domain as stated by the author under that link
|
||||
|
||||
// constants for hash computations
|
||||
constexpr uint64_t val_64_const = 0xcbf29ce484222325;
|
||||
constexpr uint64_t prime_64_const = 0x100000001b3;
|
||||
|
||||
// recursive compile-time hash, looking for stack-overflow source
|
||||
inline constexpr uint64_t
|
||||
hash_64_fnv1a_const(const char* const str,
|
||||
const uint64_t value = val_64_const) noexcept {
|
||||
return (str[0] == '\0') ? value :
|
||||
hash_64_fnv1a_const(&str[1], (value ^ uint64_t(str[0])) * prime_64_const);
|
||||
}
|
||||
|
||||
// Compile time string hashing. Should work particularly well for option look up with explicitly used keys like options->get("dim-input");
|
||||
inline constexpr uint64_t crc(const char* const str) noexcept {
|
||||
return hash_64_fnv1a_const(str);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
/*****************************************************************************/
|
||||
|
||||
// PerfectHash constructs a perfect hash for a set K of n numeric keys. The size of
|
||||
// the hash is m > n (not much larger) and n << max(K) (much smaller). If I am not wrong m
|
||||
// is the next power of 2 larger than n? We then build an array of size m with n fields defined.
|
||||
// m - n fields stay undefined (a bit of waste).
|
||||
class PerfectHash {
|
||||
private:
|
||||
phf phf_;
|
||||
|
||||
PerfectHash(const uint64_t keys[], size_t num) {
|
||||
int error = PHF::init<uint64_t, true>(&phf_, keys, num,
|
||||
/* bucket size */ 4,
|
||||
/* loading factor */ 90,
|
||||
/* seed */ 123456);
|
||||
ABORT_IF(error != 0, "PHF error {}", error);
|
||||
}
|
||||
|
||||
public:
|
||||
|
||||
PerfectHash(const std::vector<uint64_t>& v)
|
||||
: PerfectHash(v.data(), v.size()) { }
|
||||
|
||||
~PerfectHash() {
|
||||
PHF::destroy(&phf_);
|
||||
}
|
||||
|
||||
uint32_t operator[](const uint64_t& key) const {
|
||||
return PHF::hash<uint64_t>(const_cast<phf*>(&phf_), key);
|
||||
}
|
||||
|
||||
uint32_t operator[](const char* const keyStr) const {
|
||||
return (*this)[crc::crc(keyStr)];
|
||||
}
|
||||
|
||||
size_t size() const {
|
||||
return phf_.m;
|
||||
}
|
||||
};
|
||||
|
||||
/*****************************************************************************/
|
||||
|
||||
class FastOpt;
|
||||
|
||||
// helper class for conversion, see fastopt.cpp
|
||||
namespace fastopt_helpers {
|
||||
template <typename T>
|
||||
struct As {
|
||||
static T apply(const FastOpt&);
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct As<std::vector<T>> {
|
||||
static std::vector<T> apply(const FastOpt&);
|
||||
};
|
||||
|
||||
template <typename T1, typename T2>
|
||||
struct As<std::pair<T1, T2>> {
|
||||
static std::pair<T1, T2> apply(const FastOpt&);
|
||||
};
|
||||
}
|
||||
|
||||
// Fast access option class, meant as a replacment/supplement to YAML::Node.
|
||||
// Relatively expensive to construct, fast to access (not visible in profiler)
|
||||
// via std::vector or perfect hash. The perfect hash only requires a few look-ups
|
||||
// and arithmentic operations, still O(1).
|
||||
// Still requires YAML::Node support for parsing and modification via rebuilding.
|
||||
class FastOpt {
|
||||
private:
|
||||
template <typename T>
|
||||
friend struct fastopt_helpers::As;
|
||||
|
||||
public:
|
||||
// Node types for FastOpt, seem to be enough to cover YAML:NodeType
|
||||
enum struct NodeType {
|
||||
Null, Bool, Int64, Float64, String, Sequence, Map
|
||||
};
|
||||
|
||||
private:
|
||||
any_type value_;
|
||||
std::unique_ptr<const PerfectHash> ph_;
|
||||
std::vector<std::unique_ptr<const FastOpt>> array_;
|
||||
NodeType type_{NodeType::Null};
|
||||
|
||||
static const std::unique_ptr<const FastOpt> uniqueNullPtr; // return this unique_ptr if key not found, equivalent to nullptr
|
||||
|
||||
uint64_t fingerprint_{0}; // When node is used as a value in a map, used to check if the perfect hash
|
||||
// returned the right value (they can produce false positives)
|
||||
size_t elements_{0}; // Number of elements if isMap or isSequence is true, 0 otherwise.
|
||||
|
||||
// Used to find elements if isSequence() is true.
|
||||
inline const std::unique_ptr<const FastOpt>& arrayLookup(size_t keyId) const {
|
||||
if(keyId < array_.size())
|
||||
return array_[keyId];
|
||||
else
|
||||
return uniqueNullPtr;
|
||||
}
|
||||
|
||||
// Used to find elements if isMap() is true.
|
||||
inline const std::unique_ptr<const FastOpt>& phLookup(size_t keyId) const {
|
||||
if(ph_)
|
||||
return array_[(*ph_)[keyId]];
|
||||
else
|
||||
return uniqueNullPtr;
|
||||
}
|
||||
|
||||
// Build Null node.
|
||||
void makeNull() {
|
||||
elements_ = 0;
|
||||
type_ = NodeType::Null;
|
||||
|
||||
ABORT_IF(ph_, "ph_ should be undefined");
|
||||
ABORT_IF(!array_.empty(), "array_ should be empty");
|
||||
}
|
||||
|
||||
// Build Scalar node via controlled failure to convert from a YAML::Node object.
|
||||
void makeScalar(const YAML::Node& v) {
|
||||
elements_ = 0;
|
||||
try {
|
||||
// Cast node to text first, that works for any scalar node and test that it does not contain single characters
|
||||
// that according to YAML could be boolean values. Unfortunately, we do not have any type information at this point.
|
||||
// This means we are disabling support for boolean values in YAML that are expressed with these characters.
|
||||
auto asText = v.as<std::string>();
|
||||
if(asText.size() == 1 && asText.find_first_of("nyNYtfTF") == 0) // @TODO: should we disallow other strings too?
|
||||
throw YAML::BadConversion(YAML::Mark()); // get's picked up by next catch block
|
||||
|
||||
value_ = v.as<bool>();
|
||||
type_ = NodeType::Bool;
|
||||
} catch(const YAML::BadConversion& /*e*/) {
|
||||
try {
|
||||
value_ = v.as<int64_t>();
|
||||
type_ = NodeType::Int64;
|
||||
} catch(const YAML::BadConversion& /*e*/) {
|
||||
try {
|
||||
value_ = v.as<double>();
|
||||
type_ = NodeType::Float64;
|
||||
} catch(const YAML::BadConversion& /*e*/) {
|
||||
try {
|
||||
value_ = v.as<std::string>();
|
||||
type_ = NodeType::String;
|
||||
} catch (const YAML::BadConversion& /*e*/) {
|
||||
ABORT("Cannot convert YAML node {}", v);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ABORT_IF(ph_, "ph_ should be undefined");
|
||||
ABORT_IF(!array_.empty(), "array_ should be empty");
|
||||
}
|
||||
|
||||
// Build a Sequence node, can by converted to std::vector<T> if elements can be converted to T.
|
||||
void makeSequence(const std::vector<YAML::Node>& v) {
|
||||
elements_ = v.size();
|
||||
ABORT_IF(!array_.empty(), "array_ is not empty??");
|
||||
for(size_t pos = 0; pos < v.size(); ++pos) {
|
||||
array_.emplace_back(new FastOpt(v[pos], pos));
|
||||
}
|
||||
type_ = NodeType::Sequence;
|
||||
|
||||
ABORT_IF(ph_, "ph_ should be undefined");
|
||||
}
|
||||
|
||||
// Build a Map node.
|
||||
void makeMap(const std::map<uint64_t, YAML::Node>& m) {
|
||||
std::vector<uint64_t> keys;
|
||||
for(const auto& it : m)
|
||||
keys.push_back(it.first);
|
||||
|
||||
ABORT_IF(ph_, "ph_ is already defined??");
|
||||
ph_.reset(new PerfectHash(keys));
|
||||
|
||||
ABORT_IF(!array_.empty(), "array_ is not empty??");
|
||||
|
||||
// for lack of resize_emplace
|
||||
for(int i = 0; i < ph_->size(); ++i)
|
||||
array_.emplace_back(nullptr);
|
||||
elements_ = keys.size();
|
||||
|
||||
for(const auto& it : m) {
|
||||
uint64_t key = it.first;
|
||||
size_t pos = (*ph_)[key];
|
||||
array_[pos].reset(new FastOpt(it.second, key));
|
||||
}
|
||||
|
||||
type_ = NodeType::Map;
|
||||
}
|
||||
|
||||
// Build a Map node, uses std::string as key, which gets hashed to size_t and used in the function above.
|
||||
void makeMap(const std::map<std::string, YAML::Node>& m) {
|
||||
std::map<uint64_t, YAML::Node> mi;
|
||||
for(const auto& it : m) {
|
||||
auto key = it.first.c_str();
|
||||
mi[crc::crc(key)] = it.second;
|
||||
}
|
||||
|
||||
makeMap(mi);
|
||||
}
|
||||
|
||||
// Only build from YAML::Node
|
||||
FastOpt(const FastOpt&) = delete;
|
||||
FastOpt() = delete;
|
||||
|
||||
void construct(const YAML::Node& node) {
|
||||
switch(node.Type()) {
|
||||
case YAML::NodeType::Scalar:
|
||||
makeScalar(node);
|
||||
break;
|
||||
case YAML::NodeType::Sequence: {
|
||||
std::vector<YAML::Node> nodesVec;
|
||||
for(auto&& n : node)
|
||||
nodesVec.push_back(n);
|
||||
makeSequence(nodesVec);
|
||||
} break;
|
||||
case YAML::NodeType::Map: {
|
||||
std::map<std::string, YAML::Node> nodesMap;
|
||||
for(auto& n : node) {
|
||||
auto key = n.first.as<std::string>();
|
||||
nodesMap[key] = n.second;
|
||||
}
|
||||
makeMap(nodesMap);
|
||||
} break;
|
||||
case YAML::NodeType::Undefined:
|
||||
case YAML::NodeType::Null:
|
||||
makeNull();
|
||||
}
|
||||
}
|
||||
|
||||
public:
|
||||
// Constructor to recursively create a FastOpt object from a YAML::Node following the yaml structure.
|
||||
FastOpt(const YAML::Node& node)
|
||||
{ construct(node); }
|
||||
|
||||
FastOpt(const YAML::Node& node, uint64_t fingerprint)
|
||||
: fingerprint_{fingerprint}
|
||||
{ construct(node); }
|
||||
|
||||
bool isSequence() const {
|
||||
return type_ == NodeType::Sequence;
|
||||
}
|
||||
|
||||
bool isMap() const {
|
||||
return type_ == NodeType::Map;
|
||||
}
|
||||
|
||||
bool isScalar() const {
|
||||
return type_ == NodeType::Bool
|
||||
|| type_ == NodeType::Float64
|
||||
|| type_ == NodeType::Int64
|
||||
|| type_ == NodeType::String;
|
||||
}
|
||||
|
||||
bool isNull() const {
|
||||
return type_ == NodeType::Null;
|
||||
}
|
||||
|
||||
bool isInt() const {
|
||||
return type_ == NodeType::Int64;
|
||||
}
|
||||
|
||||
bool isBool() const {
|
||||
return type_ == NodeType::Bool;
|
||||
}
|
||||
|
||||
bool isFloat() const {
|
||||
return type_ == NodeType::Float64;
|
||||
}
|
||||
|
||||
bool isString() const {
|
||||
return type_ == NodeType::String;
|
||||
}
|
||||
|
||||
// actual number of elements in a sequence or map, 0 (not 1) for scalar nodes.
|
||||
// 0 here means rather "not applicable".
|
||||
size_t size() const {
|
||||
return elements_;
|
||||
}
|
||||
|
||||
// replace current node with an externally built FastOpt object
|
||||
void swap(FastOpt& other) {
|
||||
std::swap(value_, other.value_);
|
||||
std::swap(ph_, other.ph_);
|
||||
std::swap(array_, other.array_);
|
||||
std::swap(type_, other.type_);
|
||||
std::swap(elements_, other.elements_);
|
||||
// leave fingerprint alone as it needed by parent node.
|
||||
}
|
||||
|
||||
// Is the hashed key in a map?
|
||||
bool has(size_t keyId) const {
|
||||
if(isMap() && elements_ > 0) {
|
||||
const auto& ptr = phLookup(keyId);
|
||||
return ptr ? ptr->fingerprint_ == keyId : false;
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
bool has(const char* const key) const {
|
||||
return has(crc::crc(key));
|
||||
}
|
||||
|
||||
bool has(const std::string& key) const {
|
||||
return has(key.c_str());
|
||||
}
|
||||
|
||||
// convert to requested type
|
||||
template <typename T>
|
||||
inline T as() const {
|
||||
return fastopt_helpers::As<T>::apply(*this);
|
||||
}
|
||||
|
||||
// access sequence or map element
|
||||
const FastOpt& operator[](size_t keyId) const {
|
||||
if(isSequence()) {
|
||||
const auto& ptr = arrayLookup(keyId);
|
||||
ABORT_IF(!ptr, "Unseen key {}" , keyId);
|
||||
return *ptr;
|
||||
} else if(isMap()) {
|
||||
const auto& ptr = phLookup(keyId);
|
||||
ABORT_IF(!ptr || ptr->fingerprint_ != keyId, "Unseen key {}", keyId);
|
||||
return *ptr;
|
||||
} else {
|
||||
ABORT("Not a sequence or map node");
|
||||
}
|
||||
}
|
||||
|
||||
const FastOpt& operator[](int key) const {
|
||||
return operator[]((size_t)key);
|
||||
}
|
||||
|
||||
const FastOpt& operator[](const char* const key) const {
|
||||
// MacOS requires explicit cast to size_t before we can use it.
|
||||
return operator[]((size_t)crc::crc(key));
|
||||
}
|
||||
|
||||
const FastOpt& operator[](const std::string& key) const {
|
||||
return operator[](key.c_str());
|
||||
}
|
||||
};
|
||||
|
||||
}
|
Некоторые файлы не были показаны из-за слишком большого количества измененных файлов Показать больше
Загрузка…
Ссылка в новой задаче