diff --git a/CMakeLists.txt b/CMakeLists.txt index bd7a58cb1..167c625a8 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -6,7 +6,6 @@ option(USE_TIMETAG "Set to ON to output time costs" OFF) option(USE_CUDA "Enable CUDA-accelerated training " OFF) option(USE_DEBUG "Set to ON for Debug mode" OFF) option(USE_SANITIZER "Use santizer flags" OFF) -option(USE_HOMEBREW_FALLBACK "(macOS-only) also look in 'brew --prefix' for libraries (e.g. OpenMP)" ON) set( ENABLED_SANITIZERS "address" "leak" "undefined" @@ -15,7 +14,8 @@ set( "Semicolon separated list of sanitizer names, e.g., 'address;leak'. \ Supported sanitizers are address, leak, undefined and thread." ) -option(BUILD_CLI "Build the 'lightbgm' command-line interface in addition to lib_lightgbm" ON) +option(USE_HOMEBREW_FALLBACK "(macOS-only) also look in 'brew --prefix' for libraries (e.g. OpenMP)" ON) +option(BUILD_CLI "Build the 'lightgbm' command-line interface in addition to lib_lightgbm" ON) option(BUILD_CPP_TEST "Build C++ tests with Google Test" OFF) option(BUILD_STATIC_LIB "Build static library" OFF) option(INSTALL_HEADERS "Install headers to CMAKE_INSTALL_PREFIX (e.g. '/usr/local/include')" ON) diff --git a/README.md b/README.md index 53e688ba5..f151c9db2 100644 --- a/README.md +++ b/README.md @@ -63,16 +63,6 @@ External (Unofficial) Repositories Projects listed here offer alternative ways to use LightGBM. They are not maintained or officially endorsed by the `LightGBM` development team. -LightGBMLSS (An extension of LightGBM to probabilistic modelling from which prediction intervals and quantiles can be derived): https://github.com/StatMixedML/LightGBMLSS - -FLAML (AutoML library for hyperparameter optimization): https://github.com/microsoft/FLAML - -supertree (interactive visualization of decision trees): https://github.com/mljar/supertree - -Optuna (hyperparameter optimization framework): https://github.com/optuna/optuna - -Julia-package: https://github.com/IQVIA-ML/LightGBM.jl - JPMML (Java PMML converter): https://github.com/jpmml/jpmml-lightgbm Nyoka (Python PMML converter): https://github.com/SoftwareAG/nyoka @@ -99,6 +89,8 @@ Shapash (model visualization and interpretation): https://github.com/MAIF/shapas dtreeviz (decision tree visualization and model interpretation): https://github.com/parrt/dtreeviz +supertree (interactive visualization of decision trees): https://github.com/mljar/supertree + SynapseML (LightGBM on Spark): https://github.com/microsoft/SynapseML Kubeflow Fairing (LightGBM on Kubernetes): https://github.com/kubeflow/fairing @@ -113,14 +105,32 @@ ML.NET (.NET/C#-package): https://github.com/dotnet/machinelearning LightGBM.NET (.NET/C#-package): https://github.com/rca22/LightGBM.Net -Ruby gem: https://github.com/ankane/lightgbm-ruby +LightGBM Ruby (Ruby gem): https://github.com/ankane/lightgbm-ruby LightGBM4j (Java high-level binding): https://github.com/metarank/lightgbm4j +LightGBM4J (JVM interface for LightGBM written in Scala): https://github.com/seek-oss/lightgbm4j + +Julia-package: https://github.com/IQVIA-ML/LightGBM.jl + lightgbm3 (Rust binding): https://github.com/Mottl/lightgbm3-rs +MLServer (inference server for LightGBM): https://github.com/SeldonIO/MLServer + MLflow (experiment tracking, model monitoring framework): https://github.com/mlflow/mlflow +FLAML (AutoML library for hyperparameter optimization): https://github.com/microsoft/FLAML + +MLJAR AutoML (AutoML on tabular data): https://github.com/mljar/mljar-supervised + +Optuna (hyperparameter optimization framework): https://github.com/optuna/optuna + +LightGBMLSS (probabilistic modelling with LightGBM): https://github.com/StatMixedML/LightGBMLSS + +mlforecast (time series forecasting with LightGBM): https://github.com/Nixtla/mlforecast + +skforecast (time series forecasting with LightGBM): https://github.com/JoaquinAmatRodrigo/skforecast + `{bonsai}` (R `{parsnip}`-compliant interface): https://github.com/tidymodels/bonsai `{mlr3extralearners}` (R `{mlr3}`-compliant interface): https://github.com/mlr-org/mlr3extralearners diff --git a/src/io/config.cpp b/src/io/config.cpp index 20d327ca2..3f8e85d16 100644 --- a/src/io/config.cpp +++ b/src/io/config.cpp @@ -397,7 +397,7 @@ void Config::CheckParamConflict(const std::unordered_mapnum_leaves); +template +void LinearTreeLearner::Init(const Dataset* train_data, bool is_constant_hessian) { + TREE_LEARNER_TYPE::Init(train_data, is_constant_hessian); + LinearTreeLearner::InitLinear(train_data, this->config_->num_leaves); } -void LinearTreeLearner::InitLinear(const Dataset* train_data, const int max_leaves) { +template +void LinearTreeLearner::InitLinear(const Dataset* train_data, const int max_leaves) { leaf_map_ = std::vector(train_data->num_data(), -1); contains_nan_ = std::vector(train_data->num_features(), 0); // identify features containing nans #pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static) for (int feat = 0; feat < train_data->num_features(); ++feat) { - auto bin_mapper = train_data_->FeatureBinMapper(feat); + auto bin_mapper = this->train_data_->FeatureBinMapper(feat); if (bin_mapper->bin_type() == BinType::NumericalBin) { - const float* feat_ptr = train_data_->raw_index(feat); + const float* feat_ptr = this->train_data_->raw_index(feat); for (int i = 0; i < train_data->num_data(); ++i) { if (std::isnan(feat_ptr[i])) { contains_nan_[feat] = 1; @@ -40,7 +42,7 @@ void LinearTreeLearner::InitLinear(const Dataset* train_data, const int max_leav } } // preallocate the matrix used to calculate linear model coefficients - int max_num_feat = std::min(max_leaves, train_data_->num_numeric_features()); + int max_num_feat = std::min(max_leaves, this->train_data_->num_numeric_features()); XTHX_.clear(); XTg_.clear(); for (int i = 0; i < max_leaves; ++i) { @@ -59,25 +61,26 @@ void LinearTreeLearner::InitLinear(const Dataset* train_data, const int max_leav } } -Tree* LinearTreeLearner::Train(const score_t* gradients, const score_t *hessians, bool is_first_tree) { +template +Tree* LinearTreeLearner::Train(const score_t* gradients, const score_t *hessians, bool is_first_tree) { Common::FunctionTimer fun_timer("SerialTreeLearner::Train", global_timer); - gradients_ = gradients; - hessians_ = hessians; + this->gradients_ = gradients; + this->hessians_ = hessians; int num_threads = OMP_NUM_THREADS(); - if (share_state_->num_threads != num_threads && share_state_->num_threads > 0) { + if (this->share_state_->num_threads != num_threads && this->share_state_->num_threads > 0) { Log::Warning( "Detected that num_threads changed during training (from %d to %d), " "it may cause unexpected errors.", - share_state_->num_threads, num_threads); + this->share_state_->num_threads, num_threads); } - share_state_->num_threads = num_threads; + this->share_state_->num_threads = num_threads; // some initial works before training - BeforeTrain(); + this->BeforeTrain(); - auto tree = std::unique_ptr(new Tree(config_->num_leaves, true, true)); + auto tree = std::unique_ptr(new Tree(this->config_->num_leaves, true, true)); auto tree_ptr = tree.get(); - constraints_->ShareTreePointer(tree_ptr); + this->constraints_->ShareTreePointer(tree_ptr); // root leaf int left_leaf = 0; @@ -85,25 +88,25 @@ Tree* LinearTreeLearner::Train(const score_t* gradients, const score_t *hessians // only root leaf can be splitted on first time int right_leaf = -1; - int init_splits = ForceSplits(tree_ptr, &left_leaf, &right_leaf, &cur_depth); + int init_splits = this->ForceSplits(tree_ptr, &left_leaf, &right_leaf, &cur_depth); - for (int split = init_splits; split < config_->num_leaves - 1; ++split) { + for (int split = init_splits; split < this->config_->num_leaves - 1; ++split) { // some initial works before finding best split - if (BeforeFindBestSplit(tree_ptr, left_leaf, right_leaf)) { + if (this->BeforeFindBestSplit(tree_ptr, left_leaf, right_leaf)) { // find best threshold for every feature - FindBestSplits(tree_ptr); + this->FindBestSplits(tree_ptr); } // Get a leaf with max split gain - int best_leaf = static_cast(ArrayArgs::ArgMax(best_split_per_leaf_)); + int best_leaf = static_cast(ArrayArgs::ArgMax(this->best_split_per_leaf_)); // Get split information for best leaf - const SplitInfo& best_leaf_SplitInfo = best_split_per_leaf_[best_leaf]; + const SplitInfo& best_leaf_SplitInfo = this->best_split_per_leaf_[best_leaf]; // cannot split, quit if (best_leaf_SplitInfo.gain <= 0.0) { Log::Warning("No further splits with positive gain, best gain: %f", best_leaf_SplitInfo.gain); break; } // split tree with best leaf - Split(tree_ptr, best_leaf, &left_leaf, &right_leaf); + this->Split(tree_ptr, best_leaf, &left_leaf, &right_leaf); cur_depth = std::max(cur_depth, tree->leaf_depth(left_leaf)); } @@ -120,21 +123,22 @@ Tree* LinearTreeLearner::Train(const score_t* gradients, const score_t *hessians GetLeafMap(tree_ptr); if (has_nan) { - CalculateLinear(tree_ptr, false, gradients_, hessians_, is_first_tree); + CalculateLinear(tree_ptr, false, this->gradients_, this->hessians_, is_first_tree); } else { - CalculateLinear(tree_ptr, false, gradients_, hessians_, is_first_tree); + CalculateLinear(tree_ptr, false, this->gradients_, this->hessians_, is_first_tree); } Log::Debug("Trained a tree with leaves = %d and depth = %d", tree->num_leaves(), cur_depth); return tree.release(); } -Tree* LinearTreeLearner::FitByExistingTree(const Tree* old_tree, const score_t* gradients, const score_t *hessians) const { - auto tree = SerialTreeLearner::FitByExistingTree(old_tree, gradients, hessians); +template +Tree* LinearTreeLearner::FitByExistingTree(const Tree* old_tree, const score_t* gradients, const score_t *hessians) const { + auto tree = TREE_LEARNER_TYPE::FitByExistingTree(old_tree, gradients, hessians); bool has_nan = false; if (any_nan_) { for (int i = 0; i < tree->num_leaves() - 1 ; ++i) { - if (contains_nan_[train_data_->InnerFeatureIndex(tree->split_feature(i))]) { + if (contains_nan_[this->train_data_->InnerFeatureIndex(tree->split_feature(i))]) { has_nan = true; break; } @@ -149,28 +153,31 @@ Tree* LinearTreeLearner::FitByExistingTree(const Tree* old_tree, const score_t* return tree; } -Tree* LinearTreeLearner::FitByExistingTree(const Tree* old_tree, const std::vector& leaf_pred, +template +Tree* LinearTreeLearner::FitByExistingTree(const Tree* old_tree, const std::vector& leaf_pred, const score_t* gradients, const score_t *hessians) const { - data_partition_->ResetByLeafPred(leaf_pred, old_tree->num_leaves()); + this->data_partition_->ResetByLeafPred(leaf_pred, old_tree->num_leaves()); return LinearTreeLearner::FitByExistingTree(old_tree, gradients, hessians); } -void LinearTreeLearner::GetLeafMap(Tree* tree) const { +template +void LinearTreeLearner::GetLeafMap(Tree* tree) const { std::fill(leaf_map_.begin(), leaf_map_.end(), -1); // map data to leaf number - const data_size_t* ind = data_partition_->indices(); + const data_size_t* ind = this->data_partition_->indices(); #pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(dynamic) for (int i = 0; i < tree->num_leaves(); ++i) { - data_size_t idx = data_partition_->leaf_begin(i); - for (int j = 0; j < data_partition_->leaf_count(i); ++j) { + data_size_t idx = this->data_partition_->leaf_begin(i); + for (int j = 0; j < this->data_partition_->leaf_count(i); ++j) { leaf_map_[ind[idx + j]] = i; } } } -template -void LinearTreeLearner::CalculateLinear(Tree* tree, bool is_refit, const score_t* gradients, const score_t* hessians, bool is_first_tree) const { +template +template +void LinearTreeLearner::CalculateLinear(Tree* tree, bool is_refit, const score_t* gradients, const score_t* hessians, bool is_first_tree) const { tree->SetIsLinear(true); int num_leaves = tree->num_leaves(); int num_threads = OMP_NUM_THREADS(); @@ -209,11 +216,11 @@ void LinearTreeLearner::CalculateLinear(Tree* tree, bool is_refit, const score_t std::vector numerical_features; std::vector data_ptr; for (size_t j = 0; j < raw_features.size(); ++j) { - int feat = train_data_->InnerFeatureIndex(raw_features[j]); - auto bin_mapper = train_data_->FeatureBinMapper(feat); + int feat = this->train_data_->InnerFeatureIndex(raw_features[j]); + auto bin_mapper = this->train_data_->FeatureBinMapper(feat); if (bin_mapper->bin_type() == BinType::NumericalBin) { numerical_features.push_back(feat); - data_ptr.push_back(train_data_->raw_index(feat)); + data_ptr.push_back(this->train_data_->raw_index(feat)); } } leaf_features.push_back(numerical_features); @@ -245,12 +252,12 @@ void LinearTreeLearner::CalculateLinear(Tree* tree, bool is_refit, const score_t } } OMP_INIT_EX(); -#pragma omp parallel num_threads(OMP_NUM_THREADS()) if (num_data_ > 1024) +#pragma omp parallel num_threads(OMP_NUM_THREADS()) if (this->num_data_ > 1024) { std::vector curr_row(max_num_features + 1); int tid = omp_get_thread_num(); #pragma omp for schedule(static) - for (int i = 0; i < num_data_; ++i) { + for (int i = 0; i < this->num_data_; ++i) { OMP_LOOP_EX_BEGIN(); int leaf_num = leaf_map_[i]; if (leaf_num < 0) { @@ -312,11 +319,11 @@ void LinearTreeLearner::CalculateLinear(Tree* tree, bool is_refit, const score_t } if (!HAS_NAN) { for (int leaf_num = 0; leaf_num < num_leaves; ++leaf_num) { - total_nonzero[leaf_num] = data_partition_->leaf_count(leaf_num); + total_nonzero[leaf_num] = this->data_partition_->leaf_count(leaf_num); } } double shrinkage = tree->shrinkage(); - double decay_rate = config_->refit_decay_rate; + double decay_rate = this->config_->refit_decay_rate; // copy into eigen matrices and solve #pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static) for (int leaf_num = 0; leaf_num < num_leaves; ++leaf_num) { @@ -340,7 +347,7 @@ void LinearTreeLearner::CalculateLinear(Tree* tree, bool is_refit, const score_t XTHX_mat(feat1, feat2) = XTHX_[leaf_num][j]; XTHX_mat(feat2, feat1) = XTHX_mat(feat1, feat2); if ((feat1 == feat2) && (feat1 < num_feat)) { - XTHX_mat(feat1, feat2) += config_->linear_lambda; + XTHX_mat(feat1, feat2) += this->config_->linear_lambda; } ++j; } @@ -366,7 +373,7 @@ void LinearTreeLearner::CalculateLinear(Tree* tree, bool is_refit, const score_t tree->SetLeafFeaturesInner(leaf_num, features_new); std::vector features_raw(features_new.size()); for (size_t i = 0; i < features_new.size(); ++i) { - features_raw[i] = train_data_->RealFeatureIndex(features_new[i]); + features_raw[i] = this->train_data_->RealFeatureIndex(features_new[i]); } tree->SetLeafFeatures(leaf_num, features_raw); tree->SetLeafCoeffs(leaf_num, coeffs_vec); @@ -378,4 +385,19 @@ void LinearTreeLearner::CalculateLinear(Tree* tree, bool is_refit, const score_t } } } + +template void LinearTreeLearner::Init(const Dataset* train_data, bool is_constant_hessian); +template void LinearTreeLearner::InitLinear(const Dataset* train_data, const int max_leaves); +template Tree* LinearTreeLearner::Train(const score_t* gradients, const score_t *hessians, bool is_first_tree); +template Tree* LinearTreeLearner::FitByExistingTree(const Tree* old_tree, const score_t* gradients, const score_t *hessians) const; +template Tree* LinearTreeLearner::FitByExistingTree(const Tree* old_tree, const std::vector& leaf_pred, + const score_t* gradients, const score_t *hessians) const; + +template void LinearTreeLearner::Init(const Dataset* train_data, bool is_constant_hessian); +template void LinearTreeLearner::InitLinear(const Dataset* train_data, const int max_leaves); +template Tree* LinearTreeLearner::Train(const score_t* gradients, const score_t *hessians, bool is_first_tree); +template Tree* LinearTreeLearner::FitByExistingTree(const Tree* old_tree, const score_t* gradients, const score_t *hessians) const; +template Tree* LinearTreeLearner::FitByExistingTree(const Tree* old_tree, const std::vector& leaf_pred, + const score_t* gradients, const score_t *hessians) const; + } // namespace LightGBM diff --git a/src/treelearner/linear_tree_learner.h b/src/treelearner/linear_tree_learner.h index e20a80ad4..376040cc6 100644 --- a/src/treelearner/linear_tree_learner.h +++ b/src/treelearner/linear_tree_learner.h @@ -11,13 +11,15 @@ #include #include +#include "gpu_tree_learner.h" #include "serial_tree_learner.h" namespace LightGBM { -class LinearTreeLearner: public SerialTreeLearner { +template +class LinearTreeLearner: public TREE_LEARNER_TYPE { public: - explicit LinearTreeLearner(const Config* config) : SerialTreeLearner(config) {} + explicit LinearTreeLearner(const Config* config) : TREE_LEARNER_TYPE(config) {} void Init(const Dataset* train_data, bool is_constant_hessian) override; @@ -38,12 +40,12 @@ class LinearTreeLearner: public SerialTreeLearner { void AddPredictionToScore(const Tree* tree, double* out_score) const override { - CHECK_LE(tree->num_leaves(), data_partition_->num_leaves()); + CHECK_LE(tree->num_leaves(), this->data_partition_->num_leaves()); bool has_nan = false; if (any_nan_) { for (int i = 0; i < tree->num_leaves() - 1 ; ++i) { // use split_feature because split_feature_inner doesn't work when refitting existing tree - if (contains_nan_[train_data_->InnerFeatureIndex(tree->split_feature(i))]) { + if (contains_nan_[this->train_data_->InnerFeatureIndex(tree->split_feature(i))]) { has_nan = true; break; } @@ -69,13 +71,13 @@ class LinearTreeLearner: public SerialTreeLearner { leaf_coeff[leaf_num] = tree->LeafCoeffs(leaf_num); leaf_output[leaf_num] = tree->LeafOutput(leaf_num); for (int feat : tree->LeafFeaturesInner(leaf_num)) { - feat_ptr[leaf_num].push_back(train_data_->raw_index(feat)); + feat_ptr[leaf_num].push_back(this->train_data_->raw_index(feat)); } leaf_num_features[leaf_num] = static_cast(feat_ptr[leaf_num].size()); } OMP_INIT_EX(); -#pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static) if (num_data_ > 1024) - for (int i = 0; i < num_data_; ++i) { +#pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static) if (this->num_data_ > 1024) + for (int i = 0; i < this->num_data_; ++i) { OMP_LOOP_EX_BEGIN(); int leaf_num = leaf_map_[i]; if (leaf_num < 0) { diff --git a/src/treelearner/tree_learner.cpp b/src/treelearner/tree_learner.cpp index 2854b9876..7e8b9921e 100644 --- a/src/treelearner/tree_learner.cpp +++ b/src/treelearner/tree_learner.cpp @@ -17,7 +17,7 @@ TreeLearner* TreeLearner::CreateTreeLearner(const std::string& learner_type, con if (device_type == std::string("cpu")) { if (learner_type == std::string("serial")) { if (config->linear_tree) { - return new LinearTreeLearner(config); + return new LinearTreeLearner(config); } else { return new SerialTreeLearner(config); } @@ -30,7 +30,11 @@ TreeLearner* TreeLearner::CreateTreeLearner(const std::string& learner_type, con } } else if (device_type == std::string("gpu")) { if (learner_type == std::string("serial")) { - return new GPUTreeLearner(config); + if (config->linear_tree) { + return new LinearTreeLearner(config); + } else { + return new GPUTreeLearner(config); + } } else if (learner_type == std::string("feature")) { return new FeatureParallelTreeLearner(config); } else if (learner_type == std::string("data")) { diff --git a/windows/LightGBM.sln b/windows/LightGBM.sln index 42b090bec..ff9a97a0e 100644 --- a/windows/LightGBM.sln +++ b/windows/LightGBM.sln @@ -1,7 +1,7 @@ Microsoft Visual Studio Solution File, Format Version 12.00 # Visual Studio 14 VisualStudioVersion = 14.0.25420.1 -MinimumVisualStudioVersion = 10.0.40219.1 +MinimumVisualStudioVersion = 14.0.23107.0 Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "LightGBM", "LightGBM.vcxproj", "{F31C0B5D-715E-4953-AA1B-8D2AEEE4344C}" EndProject Global diff --git a/windows/LightGBM.vcxproj b/windows/LightGBM.vcxproj index 009c74496..bb4818572 100644 --- a/windows/LightGBM.vcxproj +++ b/windows/LightGBM.vcxproj @@ -106,7 +106,7 @@ - USE_MPI;%(PreprocessorDefinitions) + USE_MPI;USE_DEBUG;%(PreprocessorDefinitions) Level4 true Neither @@ -117,7 +117,7 @@ Disabled MultiThreadedDebugDLL true - $(ProjectDir)\..\external_libs\eigen;%(AdditionalIncludeDirectories) + $(ProjectDir)\..\external_libs\eigen;$(ProjectDir)\..\external_libs\fast_double_parser\include;$(ProjectDir)\..\external_libs\fmt\include;%(AdditionalIncludeDirectories) @@ -129,7 +129,7 @@ - USE_SOCKET;%(PreprocessorDefinitions) + USE_SOCKET;USE_DEBUG;%(PreprocessorDefinitions) Level4 true Neither @@ -140,7 +140,7 @@ Disabled MultiThreadedDebugDLL true - $(ProjectDir)\..\external_libs\eigen;%(AdditionalIncludeDirectories) + $(ProjectDir)\..\external_libs\eigen;$(ProjectDir)\..\external_libs\fast_double_parser\include;$(ProjectDir)\..\external_libs\fmt\include;%(AdditionalIncludeDirectories) @@ -149,7 +149,7 @@ - USE_SOCKET;%(PreprocessorDefinitions) + USE_SOCKET;USE_DEBUG;%(PreprocessorDefinitions) Level4 true Neither @@ -160,7 +160,7 @@ Disabled MultiThreadedDebugDLL true - $(ProjectDir)\..\external_libs\eigen;%(AdditionalIncludeDirectories) + $(ProjectDir)\..\external_libs\eigen;$(ProjectDir)\..\external_libs\fast_double_parser\include;$(ProjectDir)\..\external_libs\fmt\include;%(AdditionalIncludeDirectories) @@ -183,7 +183,7 @@ true MultiThreadedDLL true - $(ProjectDir)\..\external_libs\eigen;%(AdditionalIncludeDirectories) + $(ProjectDir)\..\external_libs\eigen;$(ProjectDir)\..\external_libs\fast_double_parser\include;$(ProjectDir)\..\external_libs\fmt\include;%(AdditionalIncludeDirectories) @@ -210,7 +210,7 @@ true true true - $(ProjectDir)\..\external_libs\eigen;%(AdditionalIncludeDirectories) + $(ProjectDir)\..\external_libs\eigen;$(ProjectDir)\..\external_libs\fast_double_parser\include;$(ProjectDir)\..\external_libs\fmt\include;%(AdditionalIncludeDirectories) @@ -231,7 +231,7 @@ true true true - $(ProjectDir)\..\external_libs\eigen;%(AdditionalIncludeDirectories) + $(ProjectDir)\..\external_libs\eigen;$(ProjectDir)\..\external_libs\fast_double_parser\include;$(ProjectDir)\..\external_libs\fmt\include;%(AdditionalIncludeDirectories)