Mirror of https://github.com/microsoft/LightGBM.git
refine reset_parameters logic
This commit is contained in:
Parent
714c673257
Commit
c2e94f1748
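The diff below reworks how parameters are reset during training: TreeLearner now takes its TreeConfig by pointer and gains ResetConfig, GBDT keeps its own copy of the BoostingConfig so ResetTrainingData can be called repeatedly, and the C API Booster::ResetConfig only rebuilds the state that actually changed. A minimal sketch of the intended call flow (hypothetical booster, cfg, train_data, objective and metrics objects; simplified, not the exact LightGBM signatures):

    // Fast path: a parameter change that touches only the learning rate.
    booster->ResetShrinkageRate(cfg.learning_rate);

    // General path: copy the new config into the GBDT, reset the tree
    // learners' configs, and re-initialize them only if the data changed.
    booster->ResetTrainingData(&cfg, train_data, objective, metrics);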
@@ -51,12 +51,6 @@ public:
*/
virtual void ResetTrainingData(const BoostingConfig* config, const Dataset* train_data, const ObjectiveFunction* object_function, const std::vector<const Metric*>& training_metrics) = 0;

/*!
* \brief Reset shrinkage_rate data for current boosting
* \param shrinkage_rate Configs for boosting
*/
virtual void ResetShrinkageRate(double shrinkage_rate) = 0;

/*!
* \brief Add a validation data
* \param valid_data Validation data
@@ -22,11 +22,17 @@ public:
virtual ~TreeLearner() {}

/*!
* \brief Initialize tree learner with training dataset and configs
* \brief Initialize tree learner with training dataset
* \param train_data The used training data
*/
virtual void Init(const Dataset* train_data) = 0;

/*!
* \brief Reset tree configs
* \param tree_config config of tree
*/
virtual void ResetConfig(const TreeConfig* tree_config) = 0;

/*!
* \brief training tree model on dataset
* \param gradients The first order gradients

@@ -58,9 +64,10 @@ public:
/*!
* \brief Create object of tree learner
* \param type Type of tree learner
* \param tree_config config of tree
*/
static TreeLearner* CreateTreeLearner(TreeLearnerType type,
const TreeConfig& tree_config);
const TreeConfig* tree_config);
};

} // namespace LightGBM
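The tree learner now receives its tree configuration by pointer, so the same learner object can later be pointed at a new configuration through ResetConfig. A minimal sketch of how a caller such as GBDT is expected to drive this interface (hypothetical type, cfg, new_cfg and data values; error handling omitted):

    #include <memory>

    // Create the learner once, passing the tree config by pointer ...
    std::unique_ptr<TreeLearner> learner(
        TreeLearner::CreateTreeLearner(type, &cfg.tree_config));
    learner->Init(data);                         // bind the training data
    // ... and later swap in new tree parameters without recreating it.
    learner->ResetConfig(&new_cfg.tree_config);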
@@ -35,7 +35,6 @@ public:
void Init(const BoostingConfig* config, const Dataset* train_data, const ObjectiveFunction* object_function,
const std::vector<const Metric*>& training_metrics) override {
GBDT::Init(config, train_data, object_function, training_metrics);
drop_rate_ = gbdt_config_->drop_rate;
shrinkage_rate_ = 1.0;
random_for_drop_ = Random(gbdt_config_->drop_seed);
}

@@ -53,6 +52,14 @@ public:
return false;
}
}

void ResetTrainingData(const BoostingConfig* config, const Dataset* train_data, const ObjectiveFunction* object_function,
const std::vector<const Metric*>& training_metrics) {
GBDT::ResetTrainingData(config, train_data, object_function, training_metrics);
shrinkage_rate_ = 1.0;
random_for_drop_ = Random(gbdt_config_->drop_seed);
}

/*!
* \brief Get current training score
* \param out_len length of returned score

@@ -81,9 +88,9 @@ private:
drop_index_.clear();
// select dropping tree indexes based on drop_rate
// if drop rate is too small, skip this step, drop one tree randomly
if (drop_rate_ > kEpsilon) {
if (gbdt_config_->drop_rate > kEpsilon) {
for (int i = 0; i < iter_; ++i) {
if (random_for_drop_.NextDouble() < drop_rate_) {
if (random_for_drop_.NextDouble() < gbdt_config_->drop_rate) {
drop_index_.push_back(i);
}
}

@@ -123,8 +130,6 @@ private:
}
/*! \brief The indexes of dropping trees */
std::vector<int> drop_index_;
/*! \brief Dropping rate */
double drop_rate_;
/*! \brief Random generator, used to select dropping trees */
Random random_for_drop_;
/*! \brief Flag that the score is update on current iter or not*/
@@ -33,41 +33,57 @@ void GBDT::Init(const BoostingConfig* config, const Dataset* train_data, const O
max_feature_idx_ = 0;
num_class_ = config->num_class;
train_data_ = nullptr;
gbdt_config_ = nullptr;
ResetTrainingData(config, train_data, object_function, training_metrics);
}

void GBDT::ResetTrainingData(const BoostingConfig* config, const Dataset* train_data, const ObjectiveFunction* object_function,
const std::vector<const Metric*>& training_metrics) {
if (train_data == nullptr) { return; }
auto new_config = std::unique_ptr<BoostingConfig>(new BoostingConfig(*config));
if (train_data_ != nullptr && !train_data_->CheckAlign(*train_data)) {
Log::Fatal("cannot reset training data, since new training data has different bin mappers");
}
gbdt_config_ = config;
early_stopping_round_ = gbdt_config_->early_stopping_round;
shrinkage_rate_ = gbdt_config_->learning_rate;
random_ = Random(gbdt_config_->bagging_seed);
// create tree learner
early_stopping_round_ = new_config->early_stopping_round;
shrinkage_rate_ = new_config->learning_rate;
random_ = Random(new_config->bagging_seed);

// create tree learner, only create once
if (gbdt_config_ == nullptr) {
tree_learner_.clear();
for (int i = 0; i < num_class_; ++i) {
auto new_tree_learner = std::unique_ptr<TreeLearner>(TreeLearner::CreateTreeLearner(gbdt_config_->tree_learner_type, gbdt_config_->tree_config));
new_tree_learner->Init(train_data);
// init tree learner
auto new_tree_learner = std::unique_ptr<TreeLearner>(TreeLearner::CreateTreeLearner(new_config->tree_learner_type, &new_config->tree_config));
tree_learner_.push_back(std::move(new_tree_learner));
}
tree_learner_.shrink_to_fit();
}
// init tree learner
if (train_data_ != train_data) {
for (int i = 0; i < num_class_; ++i) {
tree_learner_[i]->Init(train_data);
}
}
// reset config for tree learner
for (int i = 0; i < num_class_; ++i) {
tree_learner_[i]->ResetConfig(&new_config->tree_config);
}

object_function_ = object_function;

sigmoid_ = -1.0f;
if (object_function_ != nullptr
&& std::string(object_function_->GetName()) == std::string("binary")) {
// only binary classification need sigmoid transform
sigmoid_ = new_config->sigmoid;
}

if (train_data_ != train_data) {
// push training metrics
training_metrics_.clear();
for (const auto& metric : training_metrics) {
training_metrics_.push_back(metric);
}
training_metrics_.shrink_to_fit();
sigmoid_ = -1.0f;
if (object_function_ != nullptr
&& std::string(object_function_->GetName()) == std::string("binary")) {
// only binary classification need sigmoid transform
sigmoid_ = gbdt_config_->sigmoid;
}
if (train_data_ != train_data) {
// not same training data, need reset score and others
// create score tracker
train_score_updater_.reset(new ScoreUpdater(train_data, num_class_));

@@ -88,8 +104,13 @@ void GBDT::ResetTrainingData(const BoostingConfig* config, const Dataset* train_
max_feature_idx_ = train_data->num_total_features() - 1;
// get label index
label_idx_ = train_data->label_idx();
}

if (train_data_ != train_data
|| gbdt_config_ == nullptr
|| (gbdt_config_->bagging_fraction != new_config->bagging_fraction)) {
// if need bagging, create buffer
if (gbdt_config_->bagging_fraction < 1.0 && gbdt_config_->bagging_freq > 0) {
if (new_config->bagging_fraction < 1.0 && new_config->bagging_freq > 0) {
out_of_bag_data_indices_ = std::vector<data_size_t>(num_data_);
bag_data_indices_ = std::vector<data_size_t>(num_data_);
} else {

@@ -100,6 +121,7 @@ void GBDT::ResetTrainingData(const BoostingConfig* config, const Dataset* train_
}
}
train_data_ = train_data;
gbdt_config_.reset(new_config.release());
}

void GBDT::AddValidDataset(const Dataset* valid_data,
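GBDT now owns a copy of its boosting config, which is what makes repeated resets safe: the incoming config is copied first, compared against the old one (for example to decide whether the bagging buffers must be reallocated), and only swapped in at the very end via gbdt_config_.reset(new_config.release()). A compact sketch of that copy-then-swap pattern under simplified, hypothetical types (Config and Model are stand-ins, not LightGBM classes):

    #include <memory>

    struct Config { double bagging_fraction = 1.0; };

    class Model {
     public:
      void ResetConfig(const Config* config) {
        // Copy first, so the previous config stays available for comparison.
        auto new_config = std::unique_ptr<Config>(new Config(*config));
        if (config_ == nullptr ||
            config_->bagging_fraction != new_config->bagging_fraction) {
          // ... reallocate bagging buffers here ...
        }
        // Take ownership of the new config only after all checks are done.
        config_.reset(new_config.release());
      }
     private:
      std::unique_ptr<Config> config_;
    };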
@@ -68,14 +68,6 @@ public:
*/
void ResetTrainingData(const BoostingConfig* config, const Dataset* train_data, const ObjectiveFunction* object_function, const std::vector<const Metric*>& training_metrics) override;

/*!
* \brief Reset shrinkage_rate data for current boosting
* \param shrinkage_rate Configs for boosting
*/
void ResetShrinkageRate(double shrinkage_rate) override {
shrinkage_rate_ = shrinkage_rate;
}

/*!
* \brief Adding a validation dataset
* \param valid_data Validation dataset

@@ -245,7 +237,7 @@ protected:
/*! \brief Pointer to training data */
const Dataset* train_data_;
/*! \brief Config of gbdt */
const BoostingConfig* gbdt_config_;
std::unique_ptr<BoostingConfig> gbdt_config_;
/*! \brief Tree learner, will use this class to learn trees */
std::vector<std::unique_ptr<TreeLearner>> tree_learner_;
/*! \brief Objective function */
@@ -40,12 +40,14 @@ public:
Log::Warning("continued train from model is not support for c_api, \
please use continued train with input score");
}

boosting_.reset(Boosting::CreateBoosting(config_.boosting_type, nullptr));
train_data_ = train_data;
ConstructObjectAndTrainingMetrics(train_data);

// initialize the boosting
boosting_->Init(&config_.boosting_config, train_data, objective_fun_.get(),
boosting_->Init(&config_.boosting_config, nullptr, objective_fun_.get(),
Common::ConstPtrInVectorWrapper<Metric>(train_metric_));

ResetTrainingData(train_data);
}

void MergeFrom(const Booster* other) {

@@ -60,13 +62,34 @@ public:
void ResetTrainingData(const Dataset* train_data) {
std::lock_guard<std::mutex> lock(mutex_);
train_data_ = train_data;
ConstructObjectAndTrainingMetrics(train_data_);
// initialize the boosting
// create objective function
objective_fun_.reset(ObjectiveFunction::CreateObjectiveFunction(config_.objective_type,
config_.objective_config));
if (objective_fun_ == nullptr) {
Log::Warning("Using self-defined objective function");
}
// initialize the objective function
if (objective_fun_ != nullptr) {
objective_fun_->Init(train_data_->metadata(), train_data_->num_data());
}

// create training metric
train_metric_.clear();
for (auto metric_type : config_.metric_types) {
auto metric = std::unique_ptr<Metric>(
Metric::CreateMetric(metric_type, config_.metric_config));
if (metric == nullptr) { continue; }
metric->Init(train_data_->metadata(), train_data_->num_data());
train_metric_.push_back(std::move(metric));
}
train_metric_.shrink_to_fit();
// reset the boosting
boosting_->ResetTrainingData(&config_.boosting_config, train_data_,
objective_fun_.get(), Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
}

void ResetConfig(const char* parameters) {
std::lock_guard<std::mutex> lock(mutex_);
auto param = ConfigBase::Str2Map(parameters);
if (param.count("num_class")) {
Log::Fatal("cannot change num class during training");

@@ -77,21 +100,28 @@ public:
if (param.count("metric")) {
Log::Fatal("cannot change metric during training");
}
{
std::lock_guard<std::mutex> lock(mutex_);

config_.Set(param);
}
if (config_.num_threads > 0) {
std::lock_guard<std::mutex> lock(mutex_);
omp_set_num_threads(config_.num_threads);
}
if (param.size() == 1 && (param.count("learning_rate") || param.count("shrinkage_rate"))) {
// only need to set learning rate
std::lock_guard<std::mutex> lock(mutex_);
boosting_->ResetShrinkageRate(config_.boosting_config.learning_rate);
} else {
ResetTrainingData(train_data_);

if (param.count("objective")) {
// create objective function
objective_fun_.reset(ObjectiveFunction::CreateObjectiveFunction(config_.objective_type,
config_.objective_config));
if (objective_fun_ == nullptr) {
Log::Warning("Using self-defined objective function");
}
// initialize the objective function
if (objective_fun_ != nullptr) {
objective_fun_->Init(train_data_->metadata(), train_data_->num_data());
}
}

boosting_->ResetTrainingData(&config_.boosting_config, train_data_,
objective_fun_.get(), Common::ConstPtrInVectorWrapper<Metric>(train_metric_));

}

void AddValidData(const Dataset* valid_data) {

@@ -107,6 +137,7 @@ public:
boosting_->AddValidDataset(valid_data,
Common::ConstPtrInVectorWrapper<Metric>(valid_metrics_.back()));
}

bool TrainOneIter() {
std::lock_guard<std::mutex> lock(mutex_);
return boosting_->TrainOneIter(nullptr, nullptr, false);

@@ -142,10 +173,12 @@ public:
}

std::vector<double> Predict(const std::vector<std::pair<int, double>>& features) {
std::lock_guard<std::mutex> lock(mutex_);
return predictor_->GetPredictFunction()(features);
}

void PredictForFile(const char* data_filename, const char* result_filename, bool data_has_header) {
std::lock_guard<std::mutex> lock(mutex_);
predictor_->Predict(data_filename, result_filename, data_has_header);
}

@@ -180,29 +213,6 @@ public:

private:

void ConstructObjectAndTrainingMetrics(const Dataset* train_data) {
// create objective function
objective_fun_.reset(ObjectiveFunction::CreateObjectiveFunction(config_.objective_type,
config_.objective_config));
if (objective_fun_ == nullptr) {
Log::Warning("Using self-defined objective functions");
}
// create training metric
train_metric_.clear();
for (auto metric_type : config_.metric_types) {
auto metric = std::unique_ptr<Metric>(
Metric::CreateMetric(metric_type, config_.metric_config));
if (metric == nullptr) { continue; }
metric->Init(train_data->metadata(), train_data->num_data());
train_metric_.push_back(std::move(metric));
}
train_metric_.shrink_to_fit();
// initialize the objective function
if (objective_fun_ != nullptr) {
objective_fun_->Init(train_data->metadata(), train_data->num_data());
}
}

const Dataset* train_data_;
std::unique_ptr<Boosting> boosting_;
/*! \brief All configs */
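On the C API side, ResetConfig now distinguishes a pure learning-rate update from every other parameter change: the former only calls ResetShrinkageRate, the latter goes through the full ResetTrainingData path. A hedged sketch of that decision using a hypothetical standalone helper (the real code works on the map returned by ConfigBase::Str2Map shown above):

    #include <string>
    #include <unordered_map>

    // "learning_rate=0.05"              -> cheap ResetShrinkageRate path
    // "num_leaves=63 learning_rate=..." -> full ResetTrainingData path
    bool OnlyLearningRateChanged(
        const std::unordered_map<std::string, std::string>& param) {
      return param.size() == 1 &&
             (param.count("learning_rate") || param.count("shrinkage_rate"));
    }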
@@ -7,7 +7,7 @@

namespace LightGBM {

DataParallelTreeLearner::DataParallelTreeLearner(const TreeConfig& tree_config)
DataParallelTreeLearner::DataParallelTreeLearner(const TreeConfig* tree_config)
:SerialTreeLearner(tree_config) {
}

@@ -37,10 +37,13 @@ void DataParallelTreeLearner::Init(const Dataset* train_data) {

buffer_write_start_pos_.resize(num_features_);
buffer_read_start_pos_.resize(num_features_);
global_data_count_in_leaf_.resize(tree_config_.num_leaves);
global_data_count_in_leaf_.resize(tree_config_->num_leaves);
}

void DataParallelTreeLearner::ResetConfig(const TreeConfig* tree_config) {
SerialTreeLearner::ResetConfig(tree_config);
global_data_count_in_leaf_.resize(tree_config_->num_leaves);
}

void DataParallelTreeLearner::BeforeTrain() {
SerialTreeLearner::BeforeTrain();
@@ -276,6 +276,10 @@ public:
*/
void set_is_splittable(bool val) { is_splittable_ = val; }

void ResetConfig(const TreeConfig* tree_config) {
tree_config_ = tree_config;
}

private:
/*!
* \brief Calculate the split gain based on regularized sum_gradients and sum_hessians

@@ -336,6 +340,8 @@ public:
* \brief Constructor
*/
HistogramPool() {
cache_size_ = 0;
total_size_ = 0;
}

/*!

@@ -348,7 +354,7 @@ public:
* \param cache_size Max cache size
* \param total_size Total size will be used
*/
void ResetSize(int cache_size, int total_size) {
void Reset(int cache_size, int total_size) {
cache_size_ = cache_size;
// at least need 2 bucket to store smaller leaf and larger leaf
CHECK(cache_size_ >= 2);

@@ -382,6 +388,7 @@ public:
* \param obj_create_fun that used to generate object
*/
void Fill(std::function<FeatureHistogram*()> obj_create_fun) {
fill_func_ = obj_create_fun;
pool_.clear();
pool_.resize(cache_size_);
for (int i = 0; i < cache_size_; ++i) {

@@ -389,6 +396,23 @@ public:
}
}

void DynamicChangeSize(int cache_size, int total_size) {
int old_cache_size = cache_size_;
Reset(cache_size, total_size);
pool_.resize(cache_size_);
for (int i = old_cache_size; i < cache_size_; ++i) {
pool_[i].reset(fill_func_());
}
}

void ResetConfig(const TreeConfig* tree_config, int array_size) {
for (int i = 0; i < cache_size_; ++i) {
auto data_ptr = pool_[i].get();
for (int j = 0; j < array_size; ++j) {
data_ptr[j].ResetConfig(tree_config);
}
}
}
/*!
* \brief Get data for the specific index
* \param idx which index want to get

@@ -446,6 +470,7 @@ public:
private:

std::vector<std::unique_ptr<FeatureHistogram[]>> pool_;
std::function<FeatureHistogram*()> fill_func_;
int cache_size_;
int total_size_;
bool is_enough_ = false;
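HistogramPool keeps the fill function it was given, so DynamicChangeSize can grow the pool in place: only the newly added slots are allocated, and existing histograms keep whatever they have already cached. A small generic sketch of that grow-only pattern (a GrowablePool stand-in for illustration, not the actual FeatureHistogram machinery):

    #include <functional>
    #include <memory>
    #include <vector>

    template <typename T>
    class GrowablePool {
     public:
      void Fill(std::function<T*()> create) { create_ = create; }
      void DynamicChangeSize(int new_size) {
        int old_size = static_cast<int>(pool_.size());
        pool_.resize(new_size);
        for (int i = old_size; i < new_size; ++i) {
          pool_[i].reset(create_());  // allocate only the newly added slots
        }
      }
     private:
      std::vector<std::unique_ptr<T>> pool_;
      std::function<T*()> create_;
    };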
@@ -6,7 +6,7 @@

namespace LightGBM {

FeatureParallelTreeLearner::FeatureParallelTreeLearner(const TreeConfig& tree_config)
FeatureParallelTreeLearner::FeatureParallelTreeLearner(const TreeConfig* tree_config)
:SerialTreeLearner(tree_config) {
}
@@ -20,7 +20,7 @@ namespace LightGBM {
*/
class FeatureParallelTreeLearner: public SerialTreeLearner {
public:
explicit FeatureParallelTreeLearner(const TreeConfig& tree_config);
explicit FeatureParallelTreeLearner(const TreeConfig* tree_config);
~FeatureParallelTreeLearner();
virtual void Init(const Dataset* train_data);

@@ -45,9 +45,10 @@ private:
*/
class DataParallelTreeLearner: public SerialTreeLearner {
public:
explicit DataParallelTreeLearner(const TreeConfig& tree_config);
explicit DataParallelTreeLearner(const TreeConfig* tree_config);
~DataParallelTreeLearner();
void Init(const Dataset* train_data) override;
void ResetConfig(const TreeConfig* tree_config) override;
protected:
void BeforeTrain() override;
void FindBestThresholds() override;

@@ -96,10 +97,10 @@ private:
*/
class VotingParallelTreeLearner: public SerialTreeLearner {
public:
explicit VotingParallelTreeLearner(const TreeConfig& tree_config);
explicit VotingParallelTreeLearner(const TreeConfig* tree_config);
~VotingParallelTreeLearner() { }
void Init(const Dataset* train_data) override;

void ResetConfig(const TreeConfig* tree_config) override;
protected:
void BeforeTrain() override;
bool BeforeFindBestSplit(int left_leaf, int right_leaf) override;
@@ -7,9 +7,9 @@

namespace LightGBM {

SerialTreeLearner::SerialTreeLearner(const TreeConfig& tree_config)
SerialTreeLearner::SerialTreeLearner(const TreeConfig* tree_config)
:tree_config_(tree_config){
random_ = Random(tree_config.feature_fraction_seed);
random_ = Random(tree_config_->feature_fraction_seed);
}

SerialTreeLearner::~SerialTreeLearner() {

@@ -22,32 +22,32 @@ void SerialTreeLearner::Init(const Dataset* train_data) {
num_features_ = train_data_->num_features();
int max_cache_size = 0;
// Get the max size of pool
if (tree_config_.histogram_pool_size < 0) {
max_cache_size = tree_config_.num_leaves;
if (tree_config_->histogram_pool_size <= 0) {
max_cache_size = tree_config_->num_leaves;
} else {
size_t total_histogram_size = 0;
for (int i = 0; i < train_data_->num_features(); ++i) {
total_histogram_size += sizeof(HistogramBinEntry) * train_data_->FeatureAt(i)->num_bin();
}
max_cache_size = static_cast<int>(tree_config_.histogram_pool_size * 1024 * 1024 / total_histogram_size);
max_cache_size = static_cast<int>(tree_config_->histogram_pool_size * 1024 * 1024 / total_histogram_size);
}
// at least need 2 leaves
max_cache_size = std::max(2, max_cache_size);
max_cache_size = std::min(max_cache_size, tree_config_.num_leaves);
histogram_pool_.ResetSize(max_cache_size, tree_config_.num_leaves);
max_cache_size = std::min(max_cache_size, tree_config_->num_leaves);
histogram_pool_.Reset(max_cache_size, tree_config_->num_leaves);

auto histogram_create_function = [this]() {
auto tmp_histogram_array = std::unique_ptr<FeatureHistogram[]>(new FeatureHistogram[train_data_->num_features()]);
for (int j = 0; j < train_data_->num_features(); ++j) {
tmp_histogram_array[j].Init(train_data_->FeatureAt(j),
j, &tree_config_);
j, tree_config_);
}
return tmp_histogram_array.release();
};
histogram_pool_.Fill(histogram_create_function);

// push split information for all leaves
best_split_per_leaf_.resize(tree_config_.num_leaves);
best_split_per_leaf_.resize(tree_config_->num_leaves);
// initialize ordered_bins_ with nullptr
ordered_bins_.resize(num_features_);

@@ -69,7 +69,7 @@ void SerialTreeLearner::Init(const Dataset* train_data) {
larger_leaf_splits_.reset(new LeafSplits(train_data_->num_features(), train_data_->num_data()));

// initialize data partition
data_partition_.reset(new DataPartition(num_data_, tree_config_.num_leaves));
data_partition_.reset(new DataPartition(num_data_, tree_config_->num_leaves));

is_feature_used_.resize(num_features_);

@@ -84,19 +84,49 @@ void SerialTreeLearner::Init(const Dataset* train_data) {
}

void SerialTreeLearner::ResetConfig(const TreeConfig* tree_config) {
if (tree_config_->num_leaves != tree_config->num_leaves) {
tree_config_ = tree_config;
int max_cache_size = 0;
// Get the max size of pool
if (tree_config->histogram_pool_size <= 0) {
max_cache_size = tree_config_->num_leaves;
} else {
size_t total_histogram_size = 0;
for (int i = 0; i < train_data_->num_features(); ++i) {
total_histogram_size += sizeof(HistogramBinEntry) * train_data_->FeatureAt(i)->num_bin();
}
max_cache_size = static_cast<int>(tree_config_->histogram_pool_size * 1024 * 1024 / total_histogram_size);
}
// at least need 2 leaves
max_cache_size = std::max(2, max_cache_size);
max_cache_size = std::min(max_cache_size, tree_config_->num_leaves);
histogram_pool_.DynamicChangeSize(max_cache_size, tree_config_->num_leaves);

// push split information for all leaves
best_split_per_leaf_.resize(tree_config_->num_leaves);
data_partition_.reset(new DataPartition(num_data_, tree_config_->num_leaves));
} else {
tree_config_ = tree_config;
}

histogram_pool_.ResetConfig(tree_config_, train_data_->num_features());
random_ = Random(tree_config_->feature_fraction_seed);
}

Tree* SerialTreeLearner::Train(const score_t* gradients, const score_t *hessians) {
gradients_ = gradients;
hessians_ = hessians;
// some initial works before training
BeforeTrain();
auto tree = std::unique_ptr<Tree>(new Tree(tree_config_.num_leaves));
auto tree = std::unique_ptr<Tree>(new Tree(tree_config_->num_leaves));
// save pointer to last trained tree
last_trained_tree_ = tree.get();
// root leaf
int left_leaf = 0;
// only root leaf can be splitted on first time
int right_leaf = -1;
for (int split = 0; split < tree_config_.num_leaves - 1; split++) {
for (int split = 0; split < tree_config_->num_leaves - 1; split++) {
// some initial works before finding best split
if (BeforeFindBestSplit(left_leaf, right_leaf)) {
// find best threshold for every feature

@@ -121,6 +151,7 @@ Tree* SerialTreeLearner::Train(const score_t* gradients, const score_t *hessians
}

void SerialTreeLearner::BeforeTrain() {

// reset histogram pool
histogram_pool_.ResetMap();
// initialize used features

@@ -128,7 +159,7 @@ void SerialTreeLearner::BeforeTrain() {
is_feature_used_[i] = false;
}
// Get used feature at current tree
int used_feature_cnt = static_cast<int>(num_features_*tree_config_.feature_fraction);
int used_feature_cnt = static_cast<int>(num_features_*tree_config_->feature_fraction);
auto used_feature_indices = random_.Sample(num_features_, used_feature_cnt);
for (auto idx : used_feature_indices) {
is_feature_used_[idx] = true;

@@ -138,7 +169,7 @@ void SerialTreeLearner::BeforeTrain() {
data_partition_->Init();

// reset the splits for leaves
for (int i = 0; i < tree_config_.num_leaves; ++i) {
for (int i = 0; i < tree_config_->num_leaves; ++i) {
best_split_per_leaf_[i].Reset();
}

@@ -177,7 +208,7 @@ void SerialTreeLearner::BeforeTrain() {
#pragma omp parallel for schedule(guided)
for (int i = 0; i < num_features_; ++i) {
if (ordered_bins_[i] != nullptr) {
ordered_bins_[i]->Init(nullptr, tree_config_.num_leaves);
ordered_bins_[i]->Init(nullptr, tree_config_->num_leaves);
}
}
} else {

@@ -196,7 +227,7 @@ void SerialTreeLearner::BeforeTrain() {
#pragma omp parallel for schedule(guided)
for (int i = 0; i < num_features_; ++i) {
if (ordered_bins_[i] != nullptr) {
ordered_bins_[i]->Init(is_data_in_leaf_.data(), tree_config_.num_leaves);
ordered_bins_[i]->Init(is_data_in_leaf_.data(), tree_config_->num_leaves);
}
}
}

@@ -205,9 +236,9 @@ void SerialTreeLearner::BeforeTrain() {

bool SerialTreeLearner::BeforeFindBestSplit(int left_leaf, int right_leaf) {
// check depth of current leaf
if (tree_config_.max_depth > 0) {
if (tree_config_->max_depth > 0) {
// only need to check left leaf, since right leaf is in same level of left leaf
if (last_trained_tree_->leaf_depth(left_leaf) >= tree_config_.max_depth) {
if (last_trained_tree_->leaf_depth(left_leaf) >= tree_config_->max_depth) {
best_split_per_leaf_[left_leaf].gain = kMinScore;
if (right_leaf >= 0) {
best_split_per_leaf_[right_leaf].gain = kMinScore;

@@ -218,8 +249,8 @@ bool SerialTreeLearner::BeforeFindBestSplit(int left_leaf, int right_leaf) {
data_size_t num_data_in_left_child = GetGlobalDataCountInLeaf(left_leaf);
data_size_t num_data_in_right_child = GetGlobalDataCountInLeaf(right_leaf);
// no enough data to continue
if (num_data_in_right_child < static_cast<data_size_t>(tree_config_.min_data_in_leaf * 2)
&& num_data_in_left_child < static_cast<data_size_t>(tree_config_.min_data_in_leaf * 2)) {
if (num_data_in_right_child < static_cast<data_size_t>(tree_config_->min_data_in_leaf * 2)
&& num_data_in_left_child < static_cast<data_size_t>(tree_config_->min_data_in_leaf * 2)) {
best_split_per_leaf_[left_leaf].gain = kMinScore;
if (right_leaf >= 0) {
best_split_per_leaf_[right_leaf].gain = kMinScore;
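SerialTreeLearner::ResetConfig rebuilds the per-leaf structures only when num_leaves actually changes; in that case the histogram cache budget is recomputed with the same arithmetic used in Init. A hypothetical standalone helper showing that arithmetic (a non-positive histogram_pool_size means "cache every leaf", otherwise the MB budget is divided by the size of one full histogram and clamped to [2, num_leaves]):

    #include <algorithm>
    #include <cstddef>

    int MaxHistogramCacheSize(double histogram_pool_size_mb,
                              size_t one_histogram_bytes, int num_leaves) {
      int max_cache_size;
      if (histogram_pool_size_mb <= 0) {
        max_cache_size = num_leaves;
      } else {
        max_cache_size = static_cast<int>(
            histogram_pool_size_mb * 1024 * 1024 / one_histogram_bytes);
      }
      max_cache_size = std::max(2, max_cache_size);  // smaller + larger leaf
      return std::min(max_cache_size, num_leaves);
    }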
@@ -26,12 +26,14 @@ namespace LightGBM {
*/
class SerialTreeLearner: public TreeLearner {
public:
explicit SerialTreeLearner(const TreeConfig& tree_config);
explicit SerialTreeLearner(const TreeConfig* tree_config);

~SerialTreeLearner();

void Init(const Dataset* train_data) override;

void ResetConfig(const TreeConfig* tree_config) override;

Tree* Train(const score_t* gradients, const score_t *hessians) override;

void SetBaggingData(const data_size_t* used_indices, data_size_t num_data) override {

@@ -153,7 +155,7 @@ protected:
/*! \brief used to cache historical histogram to speed up*/
HistogramPool histogram_pool_;
/*! \brief config of tree learner*/
const TreeConfig& tree_config_;
const TreeConfig* tree_config_;
};
@@ -5,7 +5,7 @@

namespace LightGBM {

TreeLearner* TreeLearner::CreateTreeLearner(TreeLearnerType type, const TreeConfig& tree_config) {
TreeLearner* TreeLearner::CreateTreeLearner(TreeLearnerType type, const TreeConfig* tree_config) {
if (type == TreeLearnerType::kSerialTreeLearner) {
return new SerialTreeLearner(tree_config);
} else if (type == TreeLearnerType::kFeatureParallelTreelearner) {
@@ -9,9 +9,9 @@

namespace LightGBM {

VotingParallelTreeLearner::VotingParallelTreeLearner(const TreeConfig& tree_config)
VotingParallelTreeLearner::VotingParallelTreeLearner(const TreeConfig* tree_config)
:SerialTreeLearner(tree_config) {
top_k_ = tree_config.top_k;
top_k_ = tree_config_->top_k;
}

void VotingParallelTreeLearner::Init(const Dataset* train_data) {

@@ -44,34 +44,41 @@ void VotingParallelTreeLearner::Init(const Dataset* train_data) {

smaller_buffer_read_start_pos_.resize(num_features_);
larger_buffer_read_start_pos_.resize(num_features_);
global_data_count_in_leaf_.resize(tree_config_.num_leaves);
global_data_count_in_leaf_.resize(tree_config_->num_leaves);

smaller_leaf_splits_global_.reset(new LeafSplits(train_data_->num_features(), train_data_->num_data()));
larger_leaf_splits_global_.reset(new LeafSplits(train_data_->num_features(), train_data_->num_data()));

local_tree_config_ = tree_config_;
local_tree_config_ = *tree_config_;
local_tree_config_.min_data_in_leaf /= num_machines_;
local_tree_config_.min_sum_hessian_in_leaf /= num_machines_;

auto histogram_create_function = [this]() {
auto tmp_histogram_array = std::unique_ptr<FeatureHistogram[]>(new FeatureHistogram[train_data_->num_features()]);
for (int j = 0; j < train_data_->num_features(); ++j) {
tmp_histogram_array[j].Init(train_data_->FeatureAt(j),
j, &local_tree_config_);
}
return tmp_histogram_array.release();
};
histogram_pool_.Fill(histogram_create_function);
histogram_pool_.ResetConfig(&local_tree_config_, train_data_->num_features());

// initialize histograms for global
smaller_leaf_histogram_array_global_.reset(new FeatureHistogram[num_features_]);
larger_leaf_histogram_array_global_.reset(new FeatureHistogram[num_features_]);
for (int j = 0; j < num_features_; ++j) {
smaller_leaf_histogram_array_global_[j].Init(train_data_->FeatureAt(j), j, &tree_config_);
larger_leaf_histogram_array_global_[j].Init(train_data_->FeatureAt(j), j, &tree_config_);
smaller_leaf_histogram_array_global_[j].Init(train_data_->FeatureAt(j), j, tree_config_);
larger_leaf_histogram_array_global_[j].Init(train_data_->FeatureAt(j), j, tree_config_);
}
}

void VotingParallelTreeLearner::ResetConfig(const TreeConfig* tree_config) {
SerialTreeLearner::ResetConfig(tree_config);

local_tree_config_ = *tree_config_;
local_tree_config_.min_data_in_leaf /= num_machines_;
local_tree_config_.min_sum_hessian_in_leaf /= num_machines_;

histogram_pool_.ResetConfig(&local_tree_config_, train_data_->num_features());
global_data_count_in_leaf_.resize(tree_config_->num_leaves);

for (int j = 0; j < num_features_; ++j) {
smaller_leaf_histogram_array_global_[j].ResetConfig(tree_config_);
larger_leaf_histogram_array_global_[j].ResetConfig(tree_config_);
}
}

void VotingParallelTreeLearner::BeforeTrain() {
SerialTreeLearner::BeforeTrain();
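The voting-parallel learner works from a local copy of the tree config whose per-leaf constraints are scaled down by the number of machines, and ResetConfig now re-derives that copy as well. A tiny illustrative sketch with a minimal stand-in struct (not the real TreeConfig):

    // Minimal stand-in; only the two scaled fields matter here.
    struct TreeConfigLite {
      int min_data_in_leaf = 20;
      double min_sum_hessian_in_leaf = 10.0;
    };

    TreeConfigLite MakeLocalConfig(const TreeConfigLite& global, int num_machines) {
      TreeConfigLite local = global;                  // copy the global config
      local.min_data_in_leaf /= num_machines;         // per-machine data threshold
      local.min_sum_hessian_in_leaf /= num_machines;  // per-machine hessian threshold
      return local;
    }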