This commit is contained in:
Yu Shi 2024-10-25 06:29:46 +00:00
Родитель d11991aaf3 f30ee85beb
Коммит 4bb4411d1e
8 изменённых файлов: 118 добавлений и 80 удалений

Просмотреть файл

@ -6,7 +6,6 @@ option(USE_TIMETAG "Set to ON to output time costs" OFF)
option(USE_CUDA "Enable CUDA-accelerated training " OFF) option(USE_CUDA "Enable CUDA-accelerated training " OFF)
option(USE_DEBUG "Set to ON for Debug mode" OFF) option(USE_DEBUG "Set to ON for Debug mode" OFF)
option(USE_SANITIZER "Use santizer flags" OFF) option(USE_SANITIZER "Use santizer flags" OFF)
option(USE_HOMEBREW_FALLBACK "(macOS-only) also look in 'brew --prefix' for libraries (e.g. OpenMP)" ON)
set( set(
ENABLED_SANITIZERS ENABLED_SANITIZERS
"address" "leak" "undefined" "address" "leak" "undefined"
@ -15,7 +14,8 @@ set(
"Semicolon separated list of sanitizer names, e.g., 'address;leak'. \ "Semicolon separated list of sanitizer names, e.g., 'address;leak'. \
Supported sanitizers are address, leak, undefined and thread." Supported sanitizers are address, leak, undefined and thread."
) )
option(BUILD_CLI "Build the 'lightbgm' command-line interface in addition to lib_lightgbm" ON) option(USE_HOMEBREW_FALLBACK "(macOS-only) also look in 'brew --prefix' for libraries (e.g. OpenMP)" ON)
option(BUILD_CLI "Build the 'lightgbm' command-line interface in addition to lib_lightgbm" ON)
option(BUILD_CPP_TEST "Build C++ tests with Google Test" OFF) option(BUILD_CPP_TEST "Build C++ tests with Google Test" OFF)
option(BUILD_STATIC_LIB "Build static library" OFF) option(BUILD_STATIC_LIB "Build static library" OFF)
option(INSTALL_HEADERS "Install headers to CMAKE_INSTALL_PREFIX (e.g. '/usr/local/include')" ON) option(INSTALL_HEADERS "Install headers to CMAKE_INSTALL_PREFIX (e.g. '/usr/local/include')" ON)

Просмотреть файл

@ -63,16 +63,6 @@ External (Unofficial) Repositories
Projects listed here offer alternative ways to use LightGBM. Projects listed here offer alternative ways to use LightGBM.
They are not maintained or officially endorsed by the `LightGBM` development team. They are not maintained or officially endorsed by the `LightGBM` development team.
LightGBMLSS (An extension of LightGBM to probabilistic modelling from which prediction intervals and quantiles can be derived): https://github.com/StatMixedML/LightGBMLSS
FLAML (AutoML library for hyperparameter optimization): https://github.com/microsoft/FLAML
supertree (interactive visualization of decision trees): https://github.com/mljar/supertree
Optuna (hyperparameter optimization framework): https://github.com/optuna/optuna
Julia-package: https://github.com/IQVIA-ML/LightGBM.jl
JPMML (Java PMML converter): https://github.com/jpmml/jpmml-lightgbm JPMML (Java PMML converter): https://github.com/jpmml/jpmml-lightgbm
Nyoka (Python PMML converter): https://github.com/SoftwareAG/nyoka Nyoka (Python PMML converter): https://github.com/SoftwareAG/nyoka
@ -99,6 +89,8 @@ Shapash (model visualization and interpretation): https://github.com/MAIF/shapas
dtreeviz (decision tree visualization and model interpretation): https://github.com/parrt/dtreeviz dtreeviz (decision tree visualization and model interpretation): https://github.com/parrt/dtreeviz
supertree (interactive visualization of decision trees): https://github.com/mljar/supertree
SynapseML (LightGBM on Spark): https://github.com/microsoft/SynapseML SynapseML (LightGBM on Spark): https://github.com/microsoft/SynapseML
Kubeflow Fairing (LightGBM on Kubernetes): https://github.com/kubeflow/fairing Kubeflow Fairing (LightGBM on Kubernetes): https://github.com/kubeflow/fairing
@ -113,14 +105,32 @@ ML.NET (.NET/C#-package): https://github.com/dotnet/machinelearning
LightGBM.NET (.NET/C#-package): https://github.com/rca22/LightGBM.Net LightGBM.NET (.NET/C#-package): https://github.com/rca22/LightGBM.Net
Ruby gem: https://github.com/ankane/lightgbm-ruby LightGBM Ruby (Ruby gem): https://github.com/ankane/lightgbm-ruby
LightGBM4j (Java high-level binding): https://github.com/metarank/lightgbm4j LightGBM4j (Java high-level binding): https://github.com/metarank/lightgbm4j
LightGBM4J (JVM interface for LightGBM written in Scala): https://github.com/seek-oss/lightgbm4j
Julia-package: https://github.com/IQVIA-ML/LightGBM.jl
lightgbm3 (Rust binding): https://github.com/Mottl/lightgbm3-rs lightgbm3 (Rust binding): https://github.com/Mottl/lightgbm3-rs
MLServer (inference server for LightGBM): https://github.com/SeldonIO/MLServer
MLflow (experiment tracking, model monitoring framework): https://github.com/mlflow/mlflow MLflow (experiment tracking, model monitoring framework): https://github.com/mlflow/mlflow
FLAML (AutoML library for hyperparameter optimization): https://github.com/microsoft/FLAML
MLJAR AutoML (AutoML on tabular data): https://github.com/mljar/mljar-supervised
Optuna (hyperparameter optimization framework): https://github.com/optuna/optuna
LightGBMLSS (probabilistic modelling with LightGBM): https://github.com/StatMixedML/LightGBMLSS
mlforecast (time series forecasting with LightGBM): https://github.com/Nixtla/mlforecast
skforecast (time series forecasting with LightGBM): https://github.com/JoaquinAmatRodrigo/skforecast
`{bonsai}` (R `{parsnip}`-compliant interface): https://github.com/tidymodels/bonsai `{bonsai}` (R `{parsnip}`-compliant interface): https://github.com/tidymodels/bonsai
`{mlr3extralearners}` (R `{mlr3}`-compliant interface): https://github.com/mlr-org/mlr3extralearners `{mlr3extralearners}` (R `{mlr3}`-compliant interface): https://github.com/mlr-org/mlr3extralearners

Просмотреть файл

@ -397,7 +397,7 @@ void Config::CheckParamConflict(const std::unordered_map<std::string, std::strin
} }
} }
if (device_type == std::string("gpu")) { if (device_type == std::string("gpu")) {
// force col-wise for gpu, and cuda version // force col-wise for gpu version
force_col_wise = true; force_col_wise = true;
force_row_wise = false; force_row_wise = false;
if (deterministic) { if (deterministic) {
@ -417,9 +417,9 @@ void Config::CheckParamConflict(const std::unordered_map<std::string, std::strin
} }
// linear tree learner must be serial type and run on CPU device // linear tree learner must be serial type and run on CPU device
if (linear_tree) { if (linear_tree) {
if (device_type != std::string("cpu")) { if (device_type != std::string("cpu") && device_type != std::string("gpu")) {
device_type = "cpu"; device_type = "cpu";
Log::Warning("Linear tree learner only works with CPU."); Log::Warning("Linear tree learner only works with CPU and GPU. Falling back to CPU now.");
} }
if (tree_learner != std::string("serial")) { if (tree_learner != std::string("serial")) {
tree_learner = "serial"; tree_learner = "serial";

Просмотреть файл

@ -10,20 +10,22 @@
namespace LightGBM { namespace LightGBM {
void LinearTreeLearner::Init(const Dataset* train_data, bool is_constant_hessian) { template <typename TREE_LEARNER_TYPE>
SerialTreeLearner::Init(train_data, is_constant_hessian); void LinearTreeLearner<TREE_LEARNER_TYPE>::Init(const Dataset* train_data, bool is_constant_hessian) {
LinearTreeLearner::InitLinear(train_data, config_->num_leaves); TREE_LEARNER_TYPE::Init(train_data, is_constant_hessian);
LinearTreeLearner::InitLinear(train_data, this->config_->num_leaves);
} }
void LinearTreeLearner::InitLinear(const Dataset* train_data, const int max_leaves) { template <typename TREE_LEARNER_TYPE>
void LinearTreeLearner<TREE_LEARNER_TYPE>::InitLinear(const Dataset* train_data, const int max_leaves) {
leaf_map_ = std::vector<int>(train_data->num_data(), -1); leaf_map_ = std::vector<int>(train_data->num_data(), -1);
contains_nan_ = std::vector<int8_t>(train_data->num_features(), 0); contains_nan_ = std::vector<int8_t>(train_data->num_features(), 0);
// identify features containing nans // identify features containing nans
#pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static) #pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static)
for (int feat = 0; feat < train_data->num_features(); ++feat) { for (int feat = 0; feat < train_data->num_features(); ++feat) {
auto bin_mapper = train_data_->FeatureBinMapper(feat); auto bin_mapper = this->train_data_->FeatureBinMapper(feat);
if (bin_mapper->bin_type() == BinType::NumericalBin) { if (bin_mapper->bin_type() == BinType::NumericalBin) {
const float* feat_ptr = train_data_->raw_index(feat); const float* feat_ptr = this->train_data_->raw_index(feat);
for (int i = 0; i < train_data->num_data(); ++i) { for (int i = 0; i < train_data->num_data(); ++i) {
if (std::isnan(feat_ptr[i])) { if (std::isnan(feat_ptr[i])) {
contains_nan_[feat] = 1; contains_nan_[feat] = 1;
@ -40,7 +42,7 @@ void LinearTreeLearner::InitLinear(const Dataset* train_data, const int max_leav
} }
} }
// preallocate the matrix used to calculate linear model coefficients // preallocate the matrix used to calculate linear model coefficients
int max_num_feat = std::min(max_leaves, train_data_->num_numeric_features()); int max_num_feat = std::min(max_leaves, this->train_data_->num_numeric_features());
XTHX_.clear(); XTHX_.clear();
XTg_.clear(); XTg_.clear();
for (int i = 0; i < max_leaves; ++i) { for (int i = 0; i < max_leaves; ++i) {
@ -59,25 +61,26 @@ void LinearTreeLearner::InitLinear(const Dataset* train_data, const int max_leav
} }
} }
Tree* LinearTreeLearner::Train(const score_t* gradients, const score_t *hessians, bool is_first_tree) { template <typename TREE_LEARNER_TYPE>
Tree* LinearTreeLearner<TREE_LEARNER_TYPE>::Train(const score_t* gradients, const score_t *hessians, bool is_first_tree) {
Common::FunctionTimer fun_timer("SerialTreeLearner::Train", global_timer); Common::FunctionTimer fun_timer("SerialTreeLearner::Train", global_timer);
gradients_ = gradients; this->gradients_ = gradients;
hessians_ = hessians; this->hessians_ = hessians;
int num_threads = OMP_NUM_THREADS(); int num_threads = OMP_NUM_THREADS();
if (share_state_->num_threads != num_threads && share_state_->num_threads > 0) { if (this->share_state_->num_threads != num_threads && this->share_state_->num_threads > 0) {
Log::Warning( Log::Warning(
"Detected that num_threads changed during training (from %d to %d), " "Detected that num_threads changed during training (from %d to %d), "
"it may cause unexpected errors.", "it may cause unexpected errors.",
share_state_->num_threads, num_threads); this->share_state_->num_threads, num_threads);
} }
share_state_->num_threads = num_threads; this->share_state_->num_threads = num_threads;
// some initial works before training // some initial works before training
BeforeTrain(); this->BeforeTrain();
auto tree = std::unique_ptr<Tree>(new Tree(config_->num_leaves, true, true)); auto tree = std::unique_ptr<Tree>(new Tree(this->config_->num_leaves, true, true));
auto tree_ptr = tree.get(); auto tree_ptr = tree.get();
constraints_->ShareTreePointer(tree_ptr); this->constraints_->ShareTreePointer(tree_ptr);
// root leaf // root leaf
int left_leaf = 0; int left_leaf = 0;
@ -85,25 +88,25 @@ Tree* LinearTreeLearner::Train(const score_t* gradients, const score_t *hessians
// only root leaf can be splitted on first time // only root leaf can be splitted on first time
int right_leaf = -1; int right_leaf = -1;
int init_splits = ForceSplits(tree_ptr, &left_leaf, &right_leaf, &cur_depth); int init_splits = this->ForceSplits(tree_ptr, &left_leaf, &right_leaf, &cur_depth);
for (int split = init_splits; split < config_->num_leaves - 1; ++split) { for (int split = init_splits; split < this->config_->num_leaves - 1; ++split) {
// some initial works before finding best split // some initial works before finding best split
if (BeforeFindBestSplit(tree_ptr, left_leaf, right_leaf)) { if (this->BeforeFindBestSplit(tree_ptr, left_leaf, right_leaf)) {
// find best threshold for every feature // find best threshold for every feature
FindBestSplits(tree_ptr); this->FindBestSplits(tree_ptr);
} }
// Get a leaf with max split gain // Get a leaf with max split gain
int best_leaf = static_cast<int>(ArrayArgs<SplitInfo>::ArgMax(best_split_per_leaf_)); int best_leaf = static_cast<int>(ArrayArgs<SplitInfo>::ArgMax(this->best_split_per_leaf_));
// Get split information for best leaf // Get split information for best leaf
const SplitInfo& best_leaf_SplitInfo = best_split_per_leaf_[best_leaf]; const SplitInfo& best_leaf_SplitInfo = this->best_split_per_leaf_[best_leaf];
// cannot split, quit // cannot split, quit
if (best_leaf_SplitInfo.gain <= 0.0) { if (best_leaf_SplitInfo.gain <= 0.0) {
Log::Warning("No further splits with positive gain, best gain: %f", best_leaf_SplitInfo.gain); Log::Warning("No further splits with positive gain, best gain: %f", best_leaf_SplitInfo.gain);
break; break;
} }
// split tree with best leaf // split tree with best leaf
Split(tree_ptr, best_leaf, &left_leaf, &right_leaf); this->Split(tree_ptr, best_leaf, &left_leaf, &right_leaf);
cur_depth = std::max(cur_depth, tree->leaf_depth(left_leaf)); cur_depth = std::max(cur_depth, tree->leaf_depth(left_leaf));
} }
@ -120,21 +123,22 @@ Tree* LinearTreeLearner::Train(const score_t* gradients, const score_t *hessians
GetLeafMap(tree_ptr); GetLeafMap(tree_ptr);
if (has_nan) { if (has_nan) {
CalculateLinear<true>(tree_ptr, false, gradients_, hessians_, is_first_tree); CalculateLinear<true>(tree_ptr, false, this->gradients_, this->hessians_, is_first_tree);
} else { } else {
CalculateLinear<false>(tree_ptr, false, gradients_, hessians_, is_first_tree); CalculateLinear<false>(tree_ptr, false, this->gradients_, this->hessians_, is_first_tree);
} }
Log::Debug("Trained a tree with leaves = %d and depth = %d", tree->num_leaves(), cur_depth); Log::Debug("Trained a tree with leaves = %d and depth = %d", tree->num_leaves(), cur_depth);
return tree.release(); return tree.release();
} }
Tree* LinearTreeLearner::FitByExistingTree(const Tree* old_tree, const score_t* gradients, const score_t *hessians) const { template <typename TREE_LEARNER_TYPE>
auto tree = SerialTreeLearner::FitByExistingTree(old_tree, gradients, hessians); Tree* LinearTreeLearner<TREE_LEARNER_TYPE>::FitByExistingTree(const Tree* old_tree, const score_t* gradients, const score_t *hessians) const {
auto tree = TREE_LEARNER_TYPE::FitByExistingTree(old_tree, gradients, hessians);
bool has_nan = false; bool has_nan = false;
if (any_nan_) { if (any_nan_) {
for (int i = 0; i < tree->num_leaves() - 1 ; ++i) { for (int i = 0; i < tree->num_leaves() - 1 ; ++i) {
if (contains_nan_[train_data_->InnerFeatureIndex(tree->split_feature(i))]) { if (contains_nan_[this->train_data_->InnerFeatureIndex(tree->split_feature(i))]) {
has_nan = true; has_nan = true;
break; break;
} }
@ -149,28 +153,31 @@ Tree* LinearTreeLearner::FitByExistingTree(const Tree* old_tree, const score_t*
return tree; return tree;
} }
Tree* LinearTreeLearner::FitByExistingTree(const Tree* old_tree, const std::vector<int>& leaf_pred, template <typename TREE_LEARNER_TYPE>
Tree* LinearTreeLearner<TREE_LEARNER_TYPE>::FitByExistingTree(const Tree* old_tree, const std::vector<int>& leaf_pred,
const score_t* gradients, const score_t *hessians) const { const score_t* gradients, const score_t *hessians) const {
data_partition_->ResetByLeafPred(leaf_pred, old_tree->num_leaves()); this->data_partition_->ResetByLeafPred(leaf_pred, old_tree->num_leaves());
return LinearTreeLearner::FitByExistingTree(old_tree, gradients, hessians); return LinearTreeLearner::FitByExistingTree(old_tree, gradients, hessians);
} }
void LinearTreeLearner::GetLeafMap(Tree* tree) const { template <typename TREE_LEARNER_TYPE>
void LinearTreeLearner<TREE_LEARNER_TYPE>::GetLeafMap(Tree* tree) const {
std::fill(leaf_map_.begin(), leaf_map_.end(), -1); std::fill(leaf_map_.begin(), leaf_map_.end(), -1);
// map data to leaf number // map data to leaf number
const data_size_t* ind = data_partition_->indices(); const data_size_t* ind = this->data_partition_->indices();
#pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(dynamic) #pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(dynamic)
for (int i = 0; i < tree->num_leaves(); ++i) { for (int i = 0; i < tree->num_leaves(); ++i) {
data_size_t idx = data_partition_->leaf_begin(i); data_size_t idx = this->data_partition_->leaf_begin(i);
for (int j = 0; j < data_partition_->leaf_count(i); ++j) { for (int j = 0; j < this->data_partition_->leaf_count(i); ++j) {
leaf_map_[ind[idx + j]] = i; leaf_map_[ind[idx + j]] = i;
} }
} }
} }
template<bool HAS_NAN> template<typename TREE_LEARNER_TYPE>
void LinearTreeLearner::CalculateLinear(Tree* tree, bool is_refit, const score_t* gradients, const score_t* hessians, bool is_first_tree) const { template <bool HAS_NAN>
void LinearTreeLearner<TREE_LEARNER_TYPE>::CalculateLinear(Tree* tree, bool is_refit, const score_t* gradients, const score_t* hessians, bool is_first_tree) const {
tree->SetIsLinear(true); tree->SetIsLinear(true);
int num_leaves = tree->num_leaves(); int num_leaves = tree->num_leaves();
int num_threads = OMP_NUM_THREADS(); int num_threads = OMP_NUM_THREADS();
@ -209,11 +216,11 @@ void LinearTreeLearner::CalculateLinear(Tree* tree, bool is_refit, const score_t
std::vector<int> numerical_features; std::vector<int> numerical_features;
std::vector<const float*> data_ptr; std::vector<const float*> data_ptr;
for (size_t j = 0; j < raw_features.size(); ++j) { for (size_t j = 0; j < raw_features.size(); ++j) {
int feat = train_data_->InnerFeatureIndex(raw_features[j]); int feat = this->train_data_->InnerFeatureIndex(raw_features[j]);
auto bin_mapper = train_data_->FeatureBinMapper(feat); auto bin_mapper = this->train_data_->FeatureBinMapper(feat);
if (bin_mapper->bin_type() == BinType::NumericalBin) { if (bin_mapper->bin_type() == BinType::NumericalBin) {
numerical_features.push_back(feat); numerical_features.push_back(feat);
data_ptr.push_back(train_data_->raw_index(feat)); data_ptr.push_back(this->train_data_->raw_index(feat));
} }
} }
leaf_features.push_back(numerical_features); leaf_features.push_back(numerical_features);
@ -245,12 +252,12 @@ void LinearTreeLearner::CalculateLinear(Tree* tree, bool is_refit, const score_t
} }
} }
OMP_INIT_EX(); OMP_INIT_EX();
#pragma omp parallel num_threads(OMP_NUM_THREADS()) if (num_data_ > 1024) #pragma omp parallel num_threads(OMP_NUM_THREADS()) if (this->num_data_ > 1024)
{ {
std::vector<float> curr_row(max_num_features + 1); std::vector<float> curr_row(max_num_features + 1);
int tid = omp_get_thread_num(); int tid = omp_get_thread_num();
#pragma omp for schedule(static) #pragma omp for schedule(static)
for (int i = 0; i < num_data_; ++i) { for (int i = 0; i < this->num_data_; ++i) {
OMP_LOOP_EX_BEGIN(); OMP_LOOP_EX_BEGIN();
int leaf_num = leaf_map_[i]; int leaf_num = leaf_map_[i];
if (leaf_num < 0) { if (leaf_num < 0) {
@ -312,11 +319,11 @@ void LinearTreeLearner::CalculateLinear(Tree* tree, bool is_refit, const score_t
} }
if (!HAS_NAN) { if (!HAS_NAN) {
for (int leaf_num = 0; leaf_num < num_leaves; ++leaf_num) { for (int leaf_num = 0; leaf_num < num_leaves; ++leaf_num) {
total_nonzero[leaf_num] = data_partition_->leaf_count(leaf_num); total_nonzero[leaf_num] = this->data_partition_->leaf_count(leaf_num);
} }
} }
double shrinkage = tree->shrinkage(); double shrinkage = tree->shrinkage();
double decay_rate = config_->refit_decay_rate; double decay_rate = this->config_->refit_decay_rate;
// copy into eigen matrices and solve // copy into eigen matrices and solve
#pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static) #pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static)
for (int leaf_num = 0; leaf_num < num_leaves; ++leaf_num) { for (int leaf_num = 0; leaf_num < num_leaves; ++leaf_num) {
@ -340,7 +347,7 @@ void LinearTreeLearner::CalculateLinear(Tree* tree, bool is_refit, const score_t
XTHX_mat(feat1, feat2) = XTHX_[leaf_num][j]; XTHX_mat(feat1, feat2) = XTHX_[leaf_num][j];
XTHX_mat(feat2, feat1) = XTHX_mat(feat1, feat2); XTHX_mat(feat2, feat1) = XTHX_mat(feat1, feat2);
if ((feat1 == feat2) && (feat1 < num_feat)) { if ((feat1 == feat2) && (feat1 < num_feat)) {
XTHX_mat(feat1, feat2) += config_->linear_lambda; XTHX_mat(feat1, feat2) += this->config_->linear_lambda;
} }
++j; ++j;
} }
@ -366,7 +373,7 @@ void LinearTreeLearner::CalculateLinear(Tree* tree, bool is_refit, const score_t
tree->SetLeafFeaturesInner(leaf_num, features_new); tree->SetLeafFeaturesInner(leaf_num, features_new);
std::vector<int> features_raw(features_new.size()); std::vector<int> features_raw(features_new.size());
for (size_t i = 0; i < features_new.size(); ++i) { for (size_t i = 0; i < features_new.size(); ++i) {
features_raw[i] = train_data_->RealFeatureIndex(features_new[i]); features_raw[i] = this->train_data_->RealFeatureIndex(features_new[i]);
} }
tree->SetLeafFeatures(leaf_num, features_raw); tree->SetLeafFeatures(leaf_num, features_raw);
tree->SetLeafCoeffs(leaf_num, coeffs_vec); tree->SetLeafCoeffs(leaf_num, coeffs_vec);
@ -378,4 +385,19 @@ void LinearTreeLearner::CalculateLinear(Tree* tree, bool is_refit, const score_t
} }
} }
} }
template void LinearTreeLearner<SerialTreeLearner>::Init(const Dataset* train_data, bool is_constant_hessian);
template void LinearTreeLearner<SerialTreeLearner>::InitLinear(const Dataset* train_data, const int max_leaves);
template Tree* LinearTreeLearner<SerialTreeLearner>::Train(const score_t* gradients, const score_t *hessians, bool is_first_tree);
template Tree* LinearTreeLearner<SerialTreeLearner>::FitByExistingTree(const Tree* old_tree, const score_t* gradients, const score_t *hessians) const;
template Tree* LinearTreeLearner<SerialTreeLearner>::FitByExistingTree(const Tree* old_tree, const std::vector<int>& leaf_pred,
const score_t* gradients, const score_t *hessians) const;
template void LinearTreeLearner<GPUTreeLearner>::Init(const Dataset* train_data, bool is_constant_hessian);
template void LinearTreeLearner<GPUTreeLearner>::InitLinear(const Dataset* train_data, const int max_leaves);
template Tree* LinearTreeLearner<GPUTreeLearner>::Train(const score_t* gradients, const score_t *hessians, bool is_first_tree);
template Tree* LinearTreeLearner<GPUTreeLearner>::FitByExistingTree(const Tree* old_tree, const score_t* gradients, const score_t *hessians) const;
template Tree* LinearTreeLearner<GPUTreeLearner>::FitByExistingTree(const Tree* old_tree, const std::vector<int>& leaf_pred,
const score_t* gradients, const score_t *hessians) const;
} // namespace LightGBM } // namespace LightGBM

Просмотреть файл

@ -11,13 +11,15 @@
#include <random> #include <random>
#include <vector> #include <vector>
#include "gpu_tree_learner.h"
#include "serial_tree_learner.h" #include "serial_tree_learner.h"
namespace LightGBM { namespace LightGBM {
class LinearTreeLearner: public SerialTreeLearner { template <typename TREE_LEARNER_TYPE>
class LinearTreeLearner: public TREE_LEARNER_TYPE {
public: public:
explicit LinearTreeLearner(const Config* config) : SerialTreeLearner(config) {} explicit LinearTreeLearner(const Config* config) : TREE_LEARNER_TYPE(config) {}
void Init(const Dataset* train_data, bool is_constant_hessian) override; void Init(const Dataset* train_data, bool is_constant_hessian) override;
@ -38,12 +40,12 @@ class LinearTreeLearner: public SerialTreeLearner {
void AddPredictionToScore(const Tree* tree, void AddPredictionToScore(const Tree* tree,
double* out_score) const override { double* out_score) const override {
CHECK_LE(tree->num_leaves(), data_partition_->num_leaves()); CHECK_LE(tree->num_leaves(), this->data_partition_->num_leaves());
bool has_nan = false; bool has_nan = false;
if (any_nan_) { if (any_nan_) {
for (int i = 0; i < tree->num_leaves() - 1 ; ++i) { for (int i = 0; i < tree->num_leaves() - 1 ; ++i) {
// use split_feature because split_feature_inner doesn't work when refitting existing tree // use split_feature because split_feature_inner doesn't work when refitting existing tree
if (contains_nan_[train_data_->InnerFeatureIndex(tree->split_feature(i))]) { if (contains_nan_[this->train_data_->InnerFeatureIndex(tree->split_feature(i))]) {
has_nan = true; has_nan = true;
break; break;
} }
@ -69,13 +71,13 @@ class LinearTreeLearner: public SerialTreeLearner {
leaf_coeff[leaf_num] = tree->LeafCoeffs(leaf_num); leaf_coeff[leaf_num] = tree->LeafCoeffs(leaf_num);
leaf_output[leaf_num] = tree->LeafOutput(leaf_num); leaf_output[leaf_num] = tree->LeafOutput(leaf_num);
for (int feat : tree->LeafFeaturesInner(leaf_num)) { for (int feat : tree->LeafFeaturesInner(leaf_num)) {
feat_ptr[leaf_num].push_back(train_data_->raw_index(feat)); feat_ptr[leaf_num].push_back(this->train_data_->raw_index(feat));
} }
leaf_num_features[leaf_num] = static_cast<int>(feat_ptr[leaf_num].size()); leaf_num_features[leaf_num] = static_cast<int>(feat_ptr[leaf_num].size());
} }
OMP_INIT_EX(); OMP_INIT_EX();
#pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static) if (num_data_ > 1024) #pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static) if (this->num_data_ > 1024)
for (int i = 0; i < num_data_; ++i) { for (int i = 0; i < this->num_data_; ++i) {
OMP_LOOP_EX_BEGIN(); OMP_LOOP_EX_BEGIN();
int leaf_num = leaf_map_[i]; int leaf_num = leaf_map_[i];
if (leaf_num < 0) { if (leaf_num < 0) {

Просмотреть файл

@ -17,7 +17,7 @@ TreeLearner* TreeLearner::CreateTreeLearner(const std::string& learner_type, con
if (device_type == std::string("cpu")) { if (device_type == std::string("cpu")) {
if (learner_type == std::string("serial")) { if (learner_type == std::string("serial")) {
if (config->linear_tree) { if (config->linear_tree) {
return new LinearTreeLearner(config); return new LinearTreeLearner<SerialTreeLearner>(config);
} else { } else {
return new SerialTreeLearner(config); return new SerialTreeLearner(config);
} }
@ -30,7 +30,11 @@ TreeLearner* TreeLearner::CreateTreeLearner(const std::string& learner_type, con
} }
} else if (device_type == std::string("gpu")) { } else if (device_type == std::string("gpu")) {
if (learner_type == std::string("serial")) { if (learner_type == std::string("serial")) {
return new GPUTreeLearner(config); if (config->linear_tree) {
return new LinearTreeLearner<GPUTreeLearner>(config);
} else {
return new GPUTreeLearner(config);
}
} else if (learner_type == std::string("feature")) { } else if (learner_type == std::string("feature")) {
return new FeatureParallelTreeLearner<GPUTreeLearner>(config); return new FeatureParallelTreeLearner<GPUTreeLearner>(config);
} else if (learner_type == std::string("data")) { } else if (learner_type == std::string("data")) {

Просмотреть файл

@ -1,7 +1,7 @@
Microsoft Visual Studio Solution File, Format Version 12.00 Microsoft Visual Studio Solution File, Format Version 12.00
# Visual Studio 14 # Visual Studio 14
VisualStudioVersion = 14.0.25420.1 VisualStudioVersion = 14.0.25420.1
MinimumVisualStudioVersion = 10.0.40219.1 MinimumVisualStudioVersion = 14.0.23107.0
Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "LightGBM", "LightGBM.vcxproj", "{F31C0B5D-715E-4953-AA1B-8D2AEEE4344C}" Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "LightGBM", "LightGBM.vcxproj", "{F31C0B5D-715E-4953-AA1B-8D2AEEE4344C}"
EndProject EndProject
Global Global

Просмотреть файл

@ -106,7 +106,7 @@
</ItemDefinitionGroup> </ItemDefinitionGroup>
<ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Debug_mpi|x64'"> <ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Debug_mpi|x64'">
<ClCompile> <ClCompile>
<PreprocessorDefinitions>USE_MPI;%(PreprocessorDefinitions)</PreprocessorDefinitions> <PreprocessorDefinitions>USE_MPI;USE_DEBUG;%(PreprocessorDefinitions)</PreprocessorDefinitions>
<WarningLevel>Level4</WarningLevel> <WarningLevel>Level4</WarningLevel>
<OpenMPSupport>true</OpenMPSupport> <OpenMPSupport>true</OpenMPSupport>
<FavorSizeOrSpeed>Neither</FavorSizeOrSpeed> <FavorSizeOrSpeed>Neither</FavorSizeOrSpeed>
@ -117,7 +117,7 @@
<Optimization>Disabled</Optimization> <Optimization>Disabled</Optimization>
<RuntimeLibrary>MultiThreadedDebugDLL</RuntimeLibrary> <RuntimeLibrary>MultiThreadedDebugDLL</RuntimeLibrary>
<MultiProcessorCompilation>true</MultiProcessorCompilation> <MultiProcessorCompilation>true</MultiProcessorCompilation>
<AdditionalIncludeDirectories>$(ProjectDir)\..\external_libs\eigen;%(AdditionalIncludeDirectories)</AdditionalIncludeDirectories> <AdditionalIncludeDirectories>$(ProjectDir)\..\external_libs\eigen;$(ProjectDir)\..\external_libs\fast_double_parser\include;$(ProjectDir)\..\external_libs\fmt\include;%(AdditionalIncludeDirectories)</AdditionalIncludeDirectories>
</ClCompile> </ClCompile>
<Link> <Link>
<AdditionalLibraryDirectories> <AdditionalLibraryDirectories>
@ -129,7 +129,7 @@
</ItemDefinitionGroup> </ItemDefinitionGroup>
<ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'"> <ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">
<ClCompile> <ClCompile>
<PreprocessorDefinitions>USE_SOCKET;%(PreprocessorDefinitions)</PreprocessorDefinitions> <PreprocessorDefinitions>USE_SOCKET;USE_DEBUG;%(PreprocessorDefinitions)</PreprocessorDefinitions>
<WarningLevel>Level4</WarningLevel> <WarningLevel>Level4</WarningLevel>
<OpenMPSupport>true</OpenMPSupport> <OpenMPSupport>true</OpenMPSupport>
<FavorSizeOrSpeed>Neither</FavorSizeOrSpeed> <FavorSizeOrSpeed>Neither</FavorSizeOrSpeed>
@ -140,7 +140,7 @@
<Optimization>Disabled</Optimization> <Optimization>Disabled</Optimization>
<RuntimeLibrary>MultiThreadedDebugDLL</RuntimeLibrary> <RuntimeLibrary>MultiThreadedDebugDLL</RuntimeLibrary>
<MultiProcessorCompilation>true</MultiProcessorCompilation> <MultiProcessorCompilation>true</MultiProcessorCompilation>
<AdditionalIncludeDirectories>$(ProjectDir)\..\external_libs\eigen;%(AdditionalIncludeDirectories)</AdditionalIncludeDirectories> <AdditionalIncludeDirectories>$(ProjectDir)\..\external_libs\eigen;$(ProjectDir)\..\external_libs\fast_double_parser\include;$(ProjectDir)\..\external_libs\fmt\include;%(AdditionalIncludeDirectories)</AdditionalIncludeDirectories>
</ClCompile> </ClCompile>
<Link> <Link>
<AdditionalDependencies> <AdditionalDependencies>
@ -149,7 +149,7 @@
</ItemDefinitionGroup> </ItemDefinitionGroup>
<ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Debug_DLL|x64'"> <ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Debug_DLL|x64'">
<ClCompile> <ClCompile>
<PreprocessorDefinitions>USE_SOCKET;%(PreprocessorDefinitions)</PreprocessorDefinitions> <PreprocessorDefinitions>USE_SOCKET;USE_DEBUG;%(PreprocessorDefinitions)</PreprocessorDefinitions>
<WarningLevel>Level4</WarningLevel> <WarningLevel>Level4</WarningLevel>
<OpenMPSupport>true</OpenMPSupport> <OpenMPSupport>true</OpenMPSupport>
<FavorSizeOrSpeed>Neither</FavorSizeOrSpeed> <FavorSizeOrSpeed>Neither</FavorSizeOrSpeed>
@ -160,7 +160,7 @@
<Optimization>Disabled</Optimization> <Optimization>Disabled</Optimization>
<RuntimeLibrary>MultiThreadedDebugDLL</RuntimeLibrary> <RuntimeLibrary>MultiThreadedDebugDLL</RuntimeLibrary>
<MultiProcessorCompilation>true</MultiProcessorCompilation> <MultiProcessorCompilation>true</MultiProcessorCompilation>
<AdditionalIncludeDirectories>$(ProjectDir)\..\external_libs\eigen;%(AdditionalIncludeDirectories)</AdditionalIncludeDirectories> <AdditionalIncludeDirectories>$(ProjectDir)\..\external_libs\eigen;$(ProjectDir)\..\external_libs\fast_double_parser\include;$(ProjectDir)\..\external_libs\fmt\include;%(AdditionalIncludeDirectories)</AdditionalIncludeDirectories>
</ClCompile> </ClCompile>
<Link> <Link>
<AdditionalDependencies> <AdditionalDependencies>
@ -183,7 +183,7 @@
<OmitFramePointers>true</OmitFramePointers> <OmitFramePointers>true</OmitFramePointers>
<RuntimeLibrary>MultiThreadedDLL</RuntimeLibrary> <RuntimeLibrary>MultiThreadedDLL</RuntimeLibrary>
<MultiProcessorCompilation>true</MultiProcessorCompilation> <MultiProcessorCompilation>true</MultiProcessorCompilation>
<AdditionalIncludeDirectories>$(ProjectDir)\..\external_libs\eigen;%(AdditionalIncludeDirectories)</AdditionalIncludeDirectories> <AdditionalIncludeDirectories>$(ProjectDir)\..\external_libs\eigen;$(ProjectDir)\..\external_libs\fast_double_parser\include;$(ProjectDir)\..\external_libs\fmt\include;%(AdditionalIncludeDirectories)</AdditionalIncludeDirectories>
</ClCompile> </ClCompile>
<Link> <Link>
<AdditionalLibraryDirectories> <AdditionalLibraryDirectories>
@ -210,7 +210,7 @@
<OmitFramePointers>true</OmitFramePointers> <OmitFramePointers>true</OmitFramePointers>
<FunctionLevelLinking>true</FunctionLevelLinking> <FunctionLevelLinking>true</FunctionLevelLinking>
<MultiProcessorCompilation>true</MultiProcessorCompilation> <MultiProcessorCompilation>true</MultiProcessorCompilation>
<AdditionalIncludeDirectories>$(ProjectDir)\..\external_libs\eigen;%(AdditionalIncludeDirectories)</AdditionalIncludeDirectories> <AdditionalIncludeDirectories>$(ProjectDir)\..\external_libs\eigen;$(ProjectDir)\..\external_libs\fast_double_parser\include;$(ProjectDir)\..\external_libs\fmt\include;%(AdditionalIncludeDirectories)</AdditionalIncludeDirectories>
</ClCompile> </ClCompile>
<Link> <Link>
<AdditionalDependencies /> <AdditionalDependencies />
@ -231,7 +231,7 @@
<OmitFramePointers>true</OmitFramePointers> <OmitFramePointers>true</OmitFramePointers>
<FunctionLevelLinking>true</FunctionLevelLinking> <FunctionLevelLinking>true</FunctionLevelLinking>
<MultiProcessorCompilation>true</MultiProcessorCompilation> <MultiProcessorCompilation>true</MultiProcessorCompilation>
<AdditionalIncludeDirectories>$(ProjectDir)\..\external_libs\eigen;%(AdditionalIncludeDirectories)</AdditionalIncludeDirectories> <AdditionalIncludeDirectories>$(ProjectDir)\..\external_libs\eigen;$(ProjectDir)\..\external_libs\fast_double_parser\include;$(ProjectDir)\..\external_libs\fmt\include;%(AdditionalIncludeDirectories)</AdditionalIncludeDirectories>
</ClCompile> </ClCompile>
<Link> <Link>
<AdditionalDependencies> <AdditionalDependencies>