This commit is contained in:
Guolin Ke 2016-12-18 16:37:28 +08:00
Родитель 714c673257
Коммит c2e94f1748
14 изменённых файлов: 226 добавлений и 127 удалений

Просмотреть файл

@ -51,12 +51,6 @@ public:
*/
virtual void ResetTrainingData(const BoostingConfig* config, const Dataset* train_data, const ObjectiveFunction* object_function, const std::vector<const Metric*>& training_metrics) = 0;
/*!
* \brief Reset shrinkage_rate data for current boosting
* \param shrinkage_rate Configs for boosting
*/
virtual void ResetShrinkageRate(double shrinkage_rate) = 0;
/*!
* \brief Add a validation data
* \param valid_data Validation data

Просмотреть файл

@ -22,11 +22,17 @@ public:
virtual ~TreeLearner() {}
/*!
* \brief Initialize tree learner with training dataset and configs
* \brief Initialize tree learner with training dataset
* \param train_data The used training data
*/
virtual void Init(const Dataset* train_data) = 0;
/*!
* \brief Reset tree configs
* \param tree_config config of tree
*/
virtual void ResetConfig(const TreeConfig* tree_config) = 0;
/*!
* \brief training tree model on dataset
* \param gradients The first order gradients
@ -58,9 +64,10 @@ public:
/*!
* \brief Create object of tree learner
* \param type Type of tree learner
* \param tree_config config of tree
*/
static TreeLearner* CreateTreeLearner(TreeLearnerType type,
const TreeConfig& tree_config);
const TreeConfig* tree_config);
};
} // namespace LightGBM

Просмотреть файл

@ -35,7 +35,6 @@ public:
void Init(const BoostingConfig* config, const Dataset* train_data, const ObjectiveFunction* object_function,
const std::vector<const Metric*>& training_metrics) override {
GBDT::Init(config, train_data, object_function, training_metrics);
drop_rate_ = gbdt_config_->drop_rate;
shrinkage_rate_ = 1.0;
random_for_drop_ = Random(gbdt_config_->drop_seed);
}
@ -53,6 +52,14 @@ public:
return false;
}
}
void ResetTrainingData(const BoostingConfig* config, const Dataset* train_data, const ObjectiveFunction* object_function,
const std::vector<const Metric*>& training_metrics) {
GBDT::ResetTrainingData(config, train_data, object_function, training_metrics);
shrinkage_rate_ = 1.0;
random_for_drop_ = Random(gbdt_config_->drop_seed);
}
/*!
* \brief Get current training score
* \param out_len length of returned score
@ -81,9 +88,9 @@ private:
drop_index_.clear();
// select dropping tree indexes based on drop_rate
// if drop rate is too small, skip this step, drop one tree randomly
if (drop_rate_ > kEpsilon) {
if (gbdt_config_->drop_rate > kEpsilon) {
for (int i = 0; i < iter_; ++i) {
if (random_for_drop_.NextDouble() < drop_rate_) {
if (random_for_drop_.NextDouble() < gbdt_config_->drop_rate) {
drop_index_.push_back(i);
}
}
@ -123,8 +130,6 @@ private:
}
/*! \brief The indexes of dropping trees */
std::vector<int> drop_index_;
/*! \brief Dropping rate */
double drop_rate_;
/*! \brief Random generator, used to select dropping trees */
Random random_for_drop_;
/*! \brief Flag that the score is update on current iter or not*/

Просмотреть файл

@ -33,41 +33,57 @@ void GBDT::Init(const BoostingConfig* config, const Dataset* train_data, const O
max_feature_idx_ = 0;
num_class_ = config->num_class;
train_data_ = nullptr;
gbdt_config_ = nullptr;
ResetTrainingData(config, train_data, object_function, training_metrics);
}
void GBDT::ResetTrainingData(const BoostingConfig* config, const Dataset* train_data, const ObjectiveFunction* object_function,
const std::vector<const Metric*>& training_metrics) {
if (train_data == nullptr) { return; }
auto new_config = std::unique_ptr<BoostingConfig>(new BoostingConfig(*config));
if (train_data_ != nullptr && !train_data_->CheckAlign(*train_data)) {
Log::Fatal("cannot reset training data, since new training data has different bin mappers");
}
gbdt_config_ = config;
early_stopping_round_ = gbdt_config_->early_stopping_round;
shrinkage_rate_ = gbdt_config_->learning_rate;
random_ = Random(gbdt_config_->bagging_seed);
// create tree learner
early_stopping_round_ = new_config->early_stopping_round;
shrinkage_rate_ = new_config->learning_rate;
random_ = Random(new_config->bagging_seed);
// create tree learner, only create once
if (gbdt_config_ == nullptr) {
tree_learner_.clear();
for (int i = 0; i < num_class_; ++i) {
auto new_tree_learner = std::unique_ptr<TreeLearner>(TreeLearner::CreateTreeLearner(gbdt_config_->tree_learner_type, gbdt_config_->tree_config));
new_tree_learner->Init(train_data);
// init tree learner
auto new_tree_learner = std::unique_ptr<TreeLearner>(TreeLearner::CreateTreeLearner(new_config->tree_learner_type, &new_config->tree_config));
tree_learner_.push_back(std::move(new_tree_learner));
}
tree_learner_.shrink_to_fit();
}
// init tree learner
if (train_data_ != train_data) {
for (int i = 0; i < num_class_; ++i) {
tree_learner_[i]->Init(train_data);
}
}
// reset config for tree learner
for (int i = 0; i < num_class_; ++i) {
tree_learner_[i]->ResetConfig(&new_config->tree_config);
}
object_function_ = object_function;
sigmoid_ = -1.0f;
if (object_function_ != nullptr
&& std::string(object_function_->GetName()) == std::string("binary")) {
// only binary classification need sigmoid transform
sigmoid_ = new_config->sigmoid;
}
if (train_data_ != train_data) {
// push training metrics
training_metrics_.clear();
for (const auto& metric : training_metrics) {
training_metrics_.push_back(metric);
}
training_metrics_.shrink_to_fit();
sigmoid_ = -1.0f;
if (object_function_ != nullptr
&& std::string(object_function_->GetName()) == std::string("binary")) {
// only binary classification need sigmoid transform
sigmoid_ = gbdt_config_->sigmoid;
}
if (train_data_ != train_data) {
// not same training data, need reset score and others
// create score tracker
train_score_updater_.reset(new ScoreUpdater(train_data, num_class_));
@ -88,8 +104,13 @@ void GBDT::ResetTrainingData(const BoostingConfig* config, const Dataset* train_
max_feature_idx_ = train_data->num_total_features() - 1;
// get label index
label_idx_ = train_data->label_idx();
}
if (train_data_ != train_data
|| gbdt_config_ == nullptr
|| (gbdt_config_->bagging_fraction != new_config->bagging_fraction)) {
// if need bagging, create buffer
if (gbdt_config_->bagging_fraction < 1.0 && gbdt_config_->bagging_freq > 0) {
if (new_config->bagging_fraction < 1.0 && new_config->bagging_freq > 0) {
out_of_bag_data_indices_ = std::vector<data_size_t>(num_data_);
bag_data_indices_ = std::vector<data_size_t>(num_data_);
} else {
@ -100,6 +121,7 @@ void GBDT::ResetTrainingData(const BoostingConfig* config, const Dataset* train_
}
}
train_data_ = train_data;
gbdt_config_.reset(new_config.release());
}
void GBDT::AddValidDataset(const Dataset* valid_data,

Просмотреть файл

@ -68,14 +68,6 @@ public:
*/
void ResetTrainingData(const BoostingConfig* config, const Dataset* train_data, const ObjectiveFunction* object_function, const std::vector<const Metric*>& training_metrics) override;
/*!
* \brief Reset shrinkage_rate data for current boosting
* \param shrinkage_rate Configs for boosting
*/
void ResetShrinkageRate(double shrinkage_rate) override {
shrinkage_rate_ = shrinkage_rate;
}
/*!
* \brief Adding a validation dataset
* \param valid_data Validation dataset
@ -245,7 +237,7 @@ protected:
/*! \brief Pointer to training data */
const Dataset* train_data_;
/*! \brief Config of gbdt */
const BoostingConfig* gbdt_config_;
std::unique_ptr<BoostingConfig> gbdt_config_;
/*! \brief Tree learner, will use this class to learn trees */
std::vector<std::unique_ptr<TreeLearner>> tree_learner_;
/*! \brief Objective function */

Просмотреть файл

@ -40,12 +40,14 @@ public:
Log::Warning("continued train from model is not support for c_api, \
please use continued train with input score");
}
boosting_.reset(Boosting::CreateBoosting(config_.boosting_type, nullptr));
train_data_ = train_data;
ConstructObjectAndTrainingMetrics(train_data);
// initialize the boosting
boosting_->Init(&config_.boosting_config, train_data, objective_fun_.get(),
boosting_->Init(&config_.boosting_config, nullptr, objective_fun_.get(),
Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
ResetTrainingData(train_data);
}
void MergeFrom(const Booster* other) {
@ -60,13 +62,34 @@ public:
void ResetTrainingData(const Dataset* train_data) {
std::lock_guard<std::mutex> lock(mutex_);
train_data_ = train_data;
ConstructObjectAndTrainingMetrics(train_data_);
// initialize the boosting
// create objective function
objective_fun_.reset(ObjectiveFunction::CreateObjectiveFunction(config_.objective_type,
config_.objective_config));
if (objective_fun_ == nullptr) {
Log::Warning("Using self-defined objective function");
}
// initialize the objective function
if (objective_fun_ != nullptr) {
objective_fun_->Init(train_data_->metadata(), train_data_->num_data());
}
// create training metric
train_metric_.clear();
for (auto metric_type : config_.metric_types) {
auto metric = std::unique_ptr<Metric>(
Metric::CreateMetric(metric_type, config_.metric_config));
if (metric == nullptr) { continue; }
metric->Init(train_data_->metadata(), train_data_->num_data());
train_metric_.push_back(std::move(metric));
}
train_metric_.shrink_to_fit();
// reset the boosting
boosting_->ResetTrainingData(&config_.boosting_config, train_data_,
objective_fun_.get(), Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
}
void ResetConfig(const char* parameters) {
std::lock_guard<std::mutex> lock(mutex_);
auto param = ConfigBase::Str2Map(parameters);
if (param.count("num_class")) {
Log::Fatal("cannot change num class during training");
@ -77,21 +100,28 @@ public:
if (param.count("metric")) {
Log::Fatal("cannot change metric during training");
}
{
std::lock_guard<std::mutex> lock(mutex_);
config_.Set(param);
}
if (config_.num_threads > 0) {
std::lock_guard<std::mutex> lock(mutex_);
omp_set_num_threads(config_.num_threads);
}
if (param.size() == 1 && (param.count("learning_rate") || param.count("shrinkage_rate"))) {
// only need to set learning rate
std::lock_guard<std::mutex> lock(mutex_);
boosting_->ResetShrinkageRate(config_.boosting_config.learning_rate);
} else {
ResetTrainingData(train_data_);
if (param.count("objective")) {
// create objective function
objective_fun_.reset(ObjectiveFunction::CreateObjectiveFunction(config_.objective_type,
config_.objective_config));
if (objective_fun_ == nullptr) {
Log::Warning("Using self-defined objective function");
}
// initialize the objective function
if (objective_fun_ != nullptr) {
objective_fun_->Init(train_data_->metadata(), train_data_->num_data());
}
}
boosting_->ResetTrainingData(&config_.boosting_config, train_data_,
objective_fun_.get(), Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
}
void AddValidData(const Dataset* valid_data) {
@ -107,6 +137,7 @@ public:
boosting_->AddValidDataset(valid_data,
Common::ConstPtrInVectorWrapper<Metric>(valid_metrics_.back()));
}
bool TrainOneIter() {
std::lock_guard<std::mutex> lock(mutex_);
return boosting_->TrainOneIter(nullptr, nullptr, false);
@ -142,10 +173,12 @@ public:
}
std::vector<double> Predict(const std::vector<std::pair<int, double>>& features) {
std::lock_guard<std::mutex> lock(mutex_);
return predictor_->GetPredictFunction()(features);
}
void PredictForFile(const char* data_filename, const char* result_filename, bool data_has_header) {
std::lock_guard<std::mutex> lock(mutex_);
predictor_->Predict(data_filename, result_filename, data_has_header);
}
@ -180,29 +213,6 @@ public:
private:
void ConstructObjectAndTrainingMetrics(const Dataset* train_data) {
// create objective function
objective_fun_.reset(ObjectiveFunction::CreateObjectiveFunction(config_.objective_type,
config_.objective_config));
if (objective_fun_ == nullptr) {
Log::Warning("Using self-defined objective functions");
}
// create training metric
train_metric_.clear();
for (auto metric_type : config_.metric_types) {
auto metric = std::unique_ptr<Metric>(
Metric::CreateMetric(metric_type, config_.metric_config));
if (metric == nullptr) { continue; }
metric->Init(train_data->metadata(), train_data->num_data());
train_metric_.push_back(std::move(metric));
}
train_metric_.shrink_to_fit();
// initialize the objective function
if (objective_fun_ != nullptr) {
objective_fun_->Init(train_data->metadata(), train_data->num_data());
}
}
const Dataset* train_data_;
std::unique_ptr<Boosting> boosting_;
/*! \brief All configs */

Просмотреть файл

@ -7,7 +7,7 @@
namespace LightGBM {
DataParallelTreeLearner::DataParallelTreeLearner(const TreeConfig& tree_config)
DataParallelTreeLearner::DataParallelTreeLearner(const TreeConfig* tree_config)
:SerialTreeLearner(tree_config) {
}
@ -37,10 +37,13 @@ void DataParallelTreeLearner::Init(const Dataset* train_data) {
buffer_write_start_pos_.resize(num_features_);
buffer_read_start_pos_.resize(num_features_);
global_data_count_in_leaf_.resize(tree_config_.num_leaves);
global_data_count_in_leaf_.resize(tree_config_->num_leaves);
}
void DataParallelTreeLearner::ResetConfig(const TreeConfig* tree_config) {
SerialTreeLearner::ResetConfig(tree_config);
global_data_count_in_leaf_.resize(tree_config_->num_leaves);
}
void DataParallelTreeLearner::BeforeTrain() {
SerialTreeLearner::BeforeTrain();

Просмотреть файл

@ -276,6 +276,10 @@ public:
*/
void set_is_splittable(bool val) { is_splittable_ = val; }
void ResetConfig(const TreeConfig* tree_config) {
tree_config_ = tree_config;
}
private:
/*!
* \brief Calculate the split gain based on regularized sum_gradients and sum_hessians
@ -336,6 +340,8 @@ public:
* \brief Constructor
*/
HistogramPool() {
cache_size_ = 0;
total_size_ = 0;
}
/*!
@ -348,7 +354,7 @@ public:
* \param cache_size Max cache size
* \param total_size Total size will be used
*/
void ResetSize(int cache_size, int total_size) {
void Reset(int cache_size, int total_size) {
cache_size_ = cache_size;
// at least need 2 bucket to store smaller leaf and larger leaf
CHECK(cache_size_ >= 2);
@ -382,6 +388,7 @@ public:
* \param obj_create_fun that used to generate object
*/
void Fill(std::function<FeatureHistogram*()> obj_create_fun) {
fill_func_ = obj_create_fun;
pool_.clear();
pool_.resize(cache_size_);
for (int i = 0; i < cache_size_; ++i) {
@ -389,6 +396,23 @@ public:
}
}
void DynamicChangeSize(int cache_size, int total_size) {
int old_cache_size = cache_size_;
Reset(cache_size, total_size);
pool_.resize(cache_size_);
for (int i = old_cache_size; i < cache_size_; ++i) {
pool_[i].reset(fill_func_());
}
}
void ResetConfig(const TreeConfig* tree_config, int array_size) {
for (int i = 0; i < cache_size_; ++i) {
auto data_ptr = pool_[i].get();
for (int j = 0; j < array_size; ++j) {
data_ptr[j].ResetConfig(tree_config);
}
}
}
/*!
* \brief Get data for the specific index
* \param idx which index want to get
@ -446,6 +470,7 @@ public:
private:
std::vector<std::unique_ptr<FeatureHistogram[]>> pool_;
std::function<FeatureHistogram*()> fill_func_;
int cache_size_;
int total_size_;
bool is_enough_ = false;

Просмотреть файл

@ -6,7 +6,7 @@
namespace LightGBM {
FeatureParallelTreeLearner::FeatureParallelTreeLearner(const TreeConfig& tree_config)
FeatureParallelTreeLearner::FeatureParallelTreeLearner(const TreeConfig* tree_config)
:SerialTreeLearner(tree_config) {
}

Просмотреть файл

@ -20,7 +20,7 @@ namespace LightGBM {
*/
class FeatureParallelTreeLearner: public SerialTreeLearner {
public:
explicit FeatureParallelTreeLearner(const TreeConfig& tree_config);
explicit FeatureParallelTreeLearner(const TreeConfig* tree_config);
~FeatureParallelTreeLearner();
virtual void Init(const Dataset* train_data);
@ -45,9 +45,10 @@ private:
*/
class DataParallelTreeLearner: public SerialTreeLearner {
public:
explicit DataParallelTreeLearner(const TreeConfig& tree_config);
explicit DataParallelTreeLearner(const TreeConfig* tree_config);
~DataParallelTreeLearner();
void Init(const Dataset* train_data) override;
void ResetConfig(const TreeConfig* tree_config) override;
protected:
void BeforeTrain() override;
void FindBestThresholds() override;
@ -96,10 +97,10 @@ private:
*/
class VotingParallelTreeLearner: public SerialTreeLearner {
public:
explicit VotingParallelTreeLearner(const TreeConfig& tree_config);
explicit VotingParallelTreeLearner(const TreeConfig* tree_config);
~VotingParallelTreeLearner() { }
void Init(const Dataset* train_data) override;
void ResetConfig(const TreeConfig* tree_config) override;
protected:
void BeforeTrain() override;
bool BeforeFindBestSplit(int left_leaf, int right_leaf) override;

Просмотреть файл

@ -7,9 +7,9 @@
namespace LightGBM {
SerialTreeLearner::SerialTreeLearner(const TreeConfig& tree_config)
SerialTreeLearner::SerialTreeLearner(const TreeConfig* tree_config)
:tree_config_(tree_config){
random_ = Random(tree_config.feature_fraction_seed);
random_ = Random(tree_config_->feature_fraction_seed);
}
SerialTreeLearner::~SerialTreeLearner() {
@ -22,32 +22,32 @@ void SerialTreeLearner::Init(const Dataset* train_data) {
num_features_ = train_data_->num_features();
int max_cache_size = 0;
// Get the max size of pool
if (tree_config_.histogram_pool_size < 0) {
max_cache_size = tree_config_.num_leaves;
if (tree_config_->histogram_pool_size <= 0) {
max_cache_size = tree_config_->num_leaves;
} else {
size_t total_histogram_size = 0;
for (int i = 0; i < train_data_->num_features(); ++i) {
total_histogram_size += sizeof(HistogramBinEntry) * train_data_->FeatureAt(i)->num_bin();
}
max_cache_size = static_cast<int>(tree_config_.histogram_pool_size * 1024 * 1024 / total_histogram_size);
max_cache_size = static_cast<int>(tree_config_->histogram_pool_size * 1024 * 1024 / total_histogram_size);
}
// at least need 2 leaves
max_cache_size = std::max(2, max_cache_size);
max_cache_size = std::min(max_cache_size, tree_config_.num_leaves);
histogram_pool_.ResetSize(max_cache_size, tree_config_.num_leaves);
max_cache_size = std::min(max_cache_size, tree_config_->num_leaves);
histogram_pool_.Reset(max_cache_size, tree_config_->num_leaves);
auto histogram_create_function = [this]() {
auto tmp_histogram_array = std::unique_ptr<FeatureHistogram[]>(new FeatureHistogram[train_data_->num_features()]);
for (int j = 0; j < train_data_->num_features(); ++j) {
tmp_histogram_array[j].Init(train_data_->FeatureAt(j),
j, &tree_config_);
j, tree_config_);
}
return tmp_histogram_array.release();
};
histogram_pool_.Fill(histogram_create_function);
// push split information for all leaves
best_split_per_leaf_.resize(tree_config_.num_leaves);
best_split_per_leaf_.resize(tree_config_->num_leaves);
// initialize ordered_bins_ with nullptr
ordered_bins_.resize(num_features_);
@ -69,7 +69,7 @@ void SerialTreeLearner::Init(const Dataset* train_data) {
larger_leaf_splits_.reset(new LeafSplits(train_data_->num_features(), train_data_->num_data()));
// initialize data partition
data_partition_.reset(new DataPartition(num_data_, tree_config_.num_leaves));
data_partition_.reset(new DataPartition(num_data_, tree_config_->num_leaves));
is_feature_used_.resize(num_features_);
@ -84,19 +84,49 @@ void SerialTreeLearner::Init(const Dataset* train_data) {
}
void SerialTreeLearner::ResetConfig(const TreeConfig* tree_config) {
if (tree_config_->num_leaves != tree_config->num_leaves) {
tree_config_ = tree_config;
int max_cache_size = 0;
// Get the max size of pool
if (tree_config->histogram_pool_size <= 0) {
max_cache_size = tree_config_->num_leaves;
} else {
size_t total_histogram_size = 0;
for (int i = 0; i < train_data_->num_features(); ++i) {
total_histogram_size += sizeof(HistogramBinEntry) * train_data_->FeatureAt(i)->num_bin();
}
max_cache_size = static_cast<int>(tree_config_->histogram_pool_size * 1024 * 1024 / total_histogram_size);
}
// at least need 2 leaves
max_cache_size = std::max(2, max_cache_size);
max_cache_size = std::min(max_cache_size, tree_config_->num_leaves);
histogram_pool_.DynamicChangeSize(max_cache_size, tree_config_->num_leaves);
// push split information for all leaves
best_split_per_leaf_.resize(tree_config_->num_leaves);
data_partition_.reset(new DataPartition(num_data_, tree_config_->num_leaves));
} else {
tree_config_ = tree_config;
}
histogram_pool_.ResetConfig(tree_config_, train_data_->num_features());
random_ = Random(tree_config_->feature_fraction_seed);
}
Tree* SerialTreeLearner::Train(const score_t* gradients, const score_t *hessians) {
gradients_ = gradients;
hessians_ = hessians;
// some initial works before training
BeforeTrain();
auto tree = std::unique_ptr<Tree>(new Tree(tree_config_.num_leaves));
auto tree = std::unique_ptr<Tree>(new Tree(tree_config_->num_leaves));
// save pointer to last trained tree
last_trained_tree_ = tree.get();
// root leaf
int left_leaf = 0;
// only root leaf can be splitted on first time
int right_leaf = -1;
for (int split = 0; split < tree_config_.num_leaves - 1; split++) {
for (int split = 0; split < tree_config_->num_leaves - 1; split++) {
// some initial works before finding best split
if (BeforeFindBestSplit(left_leaf, right_leaf)) {
// find best threshold for every feature
@ -121,6 +151,7 @@ Tree* SerialTreeLearner::Train(const score_t* gradients, const score_t *hessians
}
void SerialTreeLearner::BeforeTrain() {
// reset histogram pool
histogram_pool_.ResetMap();
// initialize used features
@ -128,7 +159,7 @@ void SerialTreeLearner::BeforeTrain() {
is_feature_used_[i] = false;
}
// Get used feature at current tree
int used_feature_cnt = static_cast<int>(num_features_*tree_config_.feature_fraction);
int used_feature_cnt = static_cast<int>(num_features_*tree_config_->feature_fraction);
auto used_feature_indices = random_.Sample(num_features_, used_feature_cnt);
for (auto idx : used_feature_indices) {
is_feature_used_[idx] = true;
@ -138,7 +169,7 @@ void SerialTreeLearner::BeforeTrain() {
data_partition_->Init();
// reset the splits for leaves
for (int i = 0; i < tree_config_.num_leaves; ++i) {
for (int i = 0; i < tree_config_->num_leaves; ++i) {
best_split_per_leaf_[i].Reset();
}
@ -177,7 +208,7 @@ void SerialTreeLearner::BeforeTrain() {
#pragma omp parallel for schedule(guided)
for (int i = 0; i < num_features_; ++i) {
if (ordered_bins_[i] != nullptr) {
ordered_bins_[i]->Init(nullptr, tree_config_.num_leaves);
ordered_bins_[i]->Init(nullptr, tree_config_->num_leaves);
}
}
} else {
@ -196,7 +227,7 @@ void SerialTreeLearner::BeforeTrain() {
#pragma omp parallel for schedule(guided)
for (int i = 0; i < num_features_; ++i) {
if (ordered_bins_[i] != nullptr) {
ordered_bins_[i]->Init(is_data_in_leaf_.data(), tree_config_.num_leaves);
ordered_bins_[i]->Init(is_data_in_leaf_.data(), tree_config_->num_leaves);
}
}
}
@ -205,9 +236,9 @@ void SerialTreeLearner::BeforeTrain() {
bool SerialTreeLearner::BeforeFindBestSplit(int left_leaf, int right_leaf) {
// check depth of current leaf
if (tree_config_.max_depth > 0) {
if (tree_config_->max_depth > 0) {
// only need to check left leaf, since right leaf is in same level of left leaf
if (last_trained_tree_->leaf_depth(left_leaf) >= tree_config_.max_depth) {
if (last_trained_tree_->leaf_depth(left_leaf) >= tree_config_->max_depth) {
best_split_per_leaf_[left_leaf].gain = kMinScore;
if (right_leaf >= 0) {
best_split_per_leaf_[right_leaf].gain = kMinScore;
@ -218,8 +249,8 @@ bool SerialTreeLearner::BeforeFindBestSplit(int left_leaf, int right_leaf) {
data_size_t num_data_in_left_child = GetGlobalDataCountInLeaf(left_leaf);
data_size_t num_data_in_right_child = GetGlobalDataCountInLeaf(right_leaf);
// no enough data to continue
if (num_data_in_right_child < static_cast<data_size_t>(tree_config_.min_data_in_leaf * 2)
&& num_data_in_left_child < static_cast<data_size_t>(tree_config_.min_data_in_leaf * 2)) {
if (num_data_in_right_child < static_cast<data_size_t>(tree_config_->min_data_in_leaf * 2)
&& num_data_in_left_child < static_cast<data_size_t>(tree_config_->min_data_in_leaf * 2)) {
best_split_per_leaf_[left_leaf].gain = kMinScore;
if (right_leaf >= 0) {
best_split_per_leaf_[right_leaf].gain = kMinScore;

Просмотреть файл

@ -26,12 +26,14 @@ namespace LightGBM {
*/
class SerialTreeLearner: public TreeLearner {
public:
explicit SerialTreeLearner(const TreeConfig& tree_config);
explicit SerialTreeLearner(const TreeConfig* tree_config);
~SerialTreeLearner();
void Init(const Dataset* train_data) override;
void ResetConfig(const TreeConfig* tree_config) override;
Tree* Train(const score_t* gradients, const score_t *hessians) override;
void SetBaggingData(const data_size_t* used_indices, data_size_t num_data) override {
@ -153,7 +155,7 @@ protected:
/*! \brief used to cache historical histogram to speed up*/
HistogramPool histogram_pool_;
/*! \brief config of tree learner*/
const TreeConfig& tree_config_;
const TreeConfig* tree_config_;
};

Просмотреть файл

@ -5,7 +5,7 @@
namespace LightGBM {
TreeLearner* TreeLearner::CreateTreeLearner(TreeLearnerType type, const TreeConfig& tree_config) {
TreeLearner* TreeLearner::CreateTreeLearner(TreeLearnerType type, const TreeConfig* tree_config) {
if (type == TreeLearnerType::kSerialTreeLearner) {
return new SerialTreeLearner(tree_config);
} else if (type == TreeLearnerType::kFeatureParallelTreelearner) {

Просмотреть файл

@ -9,9 +9,9 @@
namespace LightGBM {
VotingParallelTreeLearner::VotingParallelTreeLearner(const TreeConfig& tree_config)
VotingParallelTreeLearner::VotingParallelTreeLearner(const TreeConfig* tree_config)
:SerialTreeLearner(tree_config) {
top_k_ = tree_config.top_k;
top_k_ = tree_config_->top_k;
}
void VotingParallelTreeLearner::Init(const Dataset* train_data) {
@ -44,34 +44,41 @@ void VotingParallelTreeLearner::Init(const Dataset* train_data) {
smaller_buffer_read_start_pos_.resize(num_features_);
larger_buffer_read_start_pos_.resize(num_features_);
global_data_count_in_leaf_.resize(tree_config_.num_leaves);
global_data_count_in_leaf_.resize(tree_config_->num_leaves);
smaller_leaf_splits_global_.reset(new LeafSplits(train_data_->num_features(), train_data_->num_data()));
larger_leaf_splits_global_.reset(new LeafSplits(train_data_->num_features(), train_data_->num_data()));
local_tree_config_ = tree_config_;
local_tree_config_ = *tree_config_;
local_tree_config_.min_data_in_leaf /= num_machines_;
local_tree_config_.min_sum_hessian_in_leaf /= num_machines_;
auto histogram_create_function = [this]() {
auto tmp_histogram_array = std::unique_ptr<FeatureHistogram[]>(new FeatureHistogram[train_data_->num_features()]);
for (int j = 0; j < train_data_->num_features(); ++j) {
tmp_histogram_array[j].Init(train_data_->FeatureAt(j),
j, &local_tree_config_);
}
return tmp_histogram_array.release();
};
histogram_pool_.Fill(histogram_create_function);
histogram_pool_.ResetConfig(&local_tree_config_, train_data_->num_features());
// initialize histograms for global
smaller_leaf_histogram_array_global_.reset(new FeatureHistogram[num_features_]);
larger_leaf_histogram_array_global_.reset(new FeatureHistogram[num_features_]);
for (int j = 0; j < num_features_; ++j) {
smaller_leaf_histogram_array_global_[j].Init(train_data_->FeatureAt(j), j, &tree_config_);
larger_leaf_histogram_array_global_[j].Init(train_data_->FeatureAt(j), j, &tree_config_);
smaller_leaf_histogram_array_global_[j].Init(train_data_->FeatureAt(j), j, tree_config_);
larger_leaf_histogram_array_global_[j].Init(train_data_->FeatureAt(j), j, tree_config_);
}
}
void VotingParallelTreeLearner::ResetConfig(const TreeConfig* tree_config) {
SerialTreeLearner::ResetConfig(tree_config);
local_tree_config_ = *tree_config_;
local_tree_config_.min_data_in_leaf /= num_machines_;
local_tree_config_.min_sum_hessian_in_leaf /= num_machines_;
histogram_pool_.ResetConfig(&local_tree_config_, train_data_->num_features());
global_data_count_in_leaf_.resize(tree_config_->num_leaves);
for (int j = 0; j < num_features_; ++j) {
smaller_leaf_histogram_array_global_[j].ResetConfig(tree_config_);
larger_leaf_histogram_array_global_[j].ResetConfig(tree_config_);
}
}
void VotingParallelTreeLearner::BeforeTrain() {
SerialTreeLearner::BeforeTrain();