This commit is contained in:
Guolin Ke 2016-12-13 15:53:01 +08:00
Parent 26d3323212
Commit 762b5707df
7 changed files with 442 additions and 3 deletions

View file

@@ -181,13 +181,14 @@ public:
// the maximum number of leaves will be min(num_leaves, pow(2, max_depth - 1))
// max_depth < 0 means no limit
int max_depth = -1;
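// number of top features kept in each voting round of the voting parallel learner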
int top_k = 20;
void Set(const std::unordered_map<std::string, std::string>& params) override;
};
/*! \brief Types of tree learning algorithms */
enum TreeLearnerType {
kSerialTreeLearner, kFeatureParallelTreelearner,
kDataParallelTreeLearner
kDataParallelTreeLearner, kVotingParallelTreeLearner
};
/*! \brief Config for Boosting */
@@ -385,6 +386,7 @@ struct ParameterAlias {
{ "raw_score", "is_predict_raw_score" },
{ "leaf_index", "is_predict_leaf_index" },
{ "min_split_gain", "min_gain_to_split" },
{ "topk", "top_k" },
{ "reg_alpha", "lambda_l1" },
{ "reg_lambda", "lambda_l2" },
{ "num_classes", "num_class" }

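For orientation, a minimal sketch of how the new options would reach TreeConfig::Set. The parameter map below is illustrative, but the keys follow the alias table above ("topk" resolves to "top_k") and the learner-type strings added in config.cpp below:

#include <string>
#include <unordered_map>

// Hypothetical user-supplied parameters; the values are examples only.
std::unordered_map<std::string, std::string> params = {
  {"tree_learner", "voting"},  // or "voting_parallel"; selects the voting parallel learner
  {"topk", "20"}               // resolved to "top_k" via ParameterAlias
};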
View file

@@ -289,6 +289,7 @@ void TreeConfig::Set(const std::unordered_map<std::string, std::string>& params)
CHECK(feature_fraction > 0.0f && feature_fraction <= 1.0f);
GetDouble(params, "histogram_pool_size", &histogram_pool_size);
GetInt(params, "max_depth", &max_depth);
GetInt(params, "top_k", &top_k);
CHECK(max_depth > 1 || max_depth < 0);
}
@@ -327,8 +328,9 @@ void BoostingConfig::GetTreeLearnerType(const std::unordered_map<std::string, std::string>& params)
tree_learner_type = TreeLearnerType::kFeatureParallelTreelearner;
} else if (value == std::string("data") || value == std::string("data_parallel")) {
tree_learner_type = TreeLearnerType::kDataParallelTreeLearner;
}
else {
} else if (value == std::string("voting") || value == std::string("voting_parallel")) {
tree_learner_type = TreeLearnerType::kVotingParallelTreeLearner;
} else {
Log::Fatal("Unknown tree learner type %s", value.c_str());
}
}

View file

@@ -88,6 +88,88 @@ private:
std::vector<data_size_t> global_data_count_in_leaf_;
};
/*!
* \brief Voting-based data parallel learning algorithm.
* Like data parallel, but it does not aggregate histograms for all features:
* voting is used to select a subset of features, and histograms are aggregated only for those features.
* When both #data and #feature are large, this learner can achieve a better speed-up.
*/
class VotingParallelTreeLearner: public SerialTreeLearner {
public:
explicit VotingParallelTreeLearner(const TreeConfig& tree_config);
~VotingParallelTreeLearner() { }
void Init(const Dataset* train_data) override;
protected:
void BeforeTrain() override;
bool BeforeFindBestSplit(int left_leaf, int right_leaf) override;
void FindBestThresholds() override;
void FindBestSplitsForLeaves() override;
void Split(Tree* tree, int best_Leaf, int* left_leaf, int* right_leaf) override;
inline data_size_t GetGlobalDataCountInLeaf(int leaf_idx) const override {
if (leaf_idx >= 0) {
return global_data_count_in_leaf_[leaf_idx];
} else {
return 0;
}
}
/*!
* \brief Perform global voting
* \param leaf_idx index of leaf
* \param splits All splits from local voting
* \param out Result of global voting; only stores feature indices
*/
void GlobalVoting(int leaf_idx, const std::vector<SplitInfo>& splits,
std::vector<int>* out);
/*!
* \brief Copy local histograms to buffer
* \param smaller_top_features Selected features for smaller leaf
* \param larger_top_features Selected features for larger leaf
*/
void CopyLocalHistogram(const std::vector<int>& smaller_top_features,
const std::vector<int>& larger_top_features);
private:
/*! \brief Tree config used in local mode */
TreeConfig local_tree_config_;
/*! \brief Voting size */
int top_k_;
/*! \brief Rank of local machine */
int rank_;
/*! \brief Number of machines */
int num_machines_;
/*! \brief Buffer for network send */
std::vector<char> input_buffer_;
/*! \brief Buffer for network receive */
std::vector<char> output_buffer_;
/*! \brief Different machines aggregate histograms for different features;
marks which features are aggregated locally for the smaller leaf */
std::vector<bool> smaller_is_feature_aggregated_;
/*! \brief Different machines aggregate histograms for different features;
marks which features are aggregated locally for the larger leaf */
std::vector<bool> larger_is_feature_aggregated_;
/*! \brief Block start index for reduce scatter */
std::vector<int> block_start_;
/*! \brief Block size for reduce scatter */
std::vector<int> block_len_;
/*! \brief Read positions for feature histograms at the smaller leaf */
std::vector<int> smaller_buffer_read_start_pos_;
/*! \brief Read positions for feature histograms at the larger leaf */
std::vector<int> larger_buffer_read_start_pos_;
/*! \brief Size for reduce scatter */
int reduce_scatter_size_;
/*! \brief Store global number of data in leaves */
std::vector<data_size_t> global_data_count_in_leaf_;
/*! \brief Store global split information for smaller leaf */
std::unique_ptr<LeafSplits> smaller_leaf_splits_global_;
/*! \brief Store global split information for larger leaf */
std::unique_ptr<LeafSplits> larger_leaf_splits_global_;
/*! \brief Store global histogram for smaller leaf */
std::unique_ptr<FeatureHistogram[]> smaller_leaf_histogram_array_global_;
/*! \brief Store global histogram for larger leaf */
std::unique_ptr<FeatureHistogram[]> larger_leaf_histogram_array_global_;
};
} // namespace LightGBM
#endif // LightGBM_TREELEARNER_PARALLEL_TREE_LEARNER_H_
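Read together, the overridden methods above implement a two-level voting protocol. A sketch of the per-iteration flow, as a reading of voting_parallel_tree_learner.cpp below rather than text from the commit itself:

// 1. FindBestThresholds(): run the serial learner on local data, then vote locally:
//    keep the top_k splits per leaf via ArrayArgs<SplitInfo>::MaxK.
// 2. Network::Allgather(): collect every machine's local top-k SplitInfo lists.
// 3. GlobalVoting(): re-rank the gathered splits by data-weighted gain and keep
//    the winning feature indices.
// 4. CopyLocalHistogram() + Network::ReduceScatter(): aggregate full histograms
//    only for the winning features, with each machine owning a feature subset.
// 5. FindBestSplitsForLeaves(): Network::Allreduce() with SplitInfo::MaxReducer
//    picks the global best split per leaf.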

View file

@@ -12,6 +12,8 @@ TreeLearner* TreeLearner::CreateTreeLearner(TreeLearnerType type, const TreeConfig& tree_config)
return new FeatureParallelTreeLearner(tree_config);
} else if (type == TreeLearnerType::kDataParallelTreeLearner) {
return new DataParallelTreeLearner(tree_config);
} else if (type == TreeLearnerType::kVotingParallelTreeLearner) {
return new VotingParallelTreeLearner(tree_config);
}
return nullptr;
}

View file

@@ -0,0 +1,347 @@
#include "parallel_tree_learner.h"
#include <LightGBM/utils/common.h>
#include <cstring>
#include <tuple>
#include <vector>
namespace LightGBM {
VotingParallelTreeLearner::VotingParallelTreeLearner(const TreeConfig& tree_config)
:SerialTreeLearner(tree_config) {
top_k_ = tree_config.top_k;
}
void VotingParallelTreeLearner::Init(const Dataset* train_data) {
SerialTreeLearner::Init(train_data);
rank_ = Network::rank();
num_machines_ = Network::num_machines();
// limit top k
if (top_k_ > num_features_) {
top_k_ = num_features_;
}
// get max bin
int max_bin = 0;
for (int i = 0; i < num_features_; ++i) {
if (max_bin < train_data_->FeatureAt(i)->num_bin()) {
max_bin = train_data_->FeatureAt(i)->num_bin();
}
}
// calculate buffer size
size_t buffer_size = 2 * top_k_ * std::max(max_bin * sizeof(HistogramBinEntry), sizeof(SplitInfo) * num_machines_);
// the smaller and larger leaf are processed at the same time, hence the factor of two
input_buffer_.resize(buffer_size);
output_buffer_.resize(buffer_size);
smaller_is_feature_aggregated_.resize(num_features_);
larger_is_feature_aggregated_.resize(num_features_);
block_start_.resize(num_machines_);
block_len_.resize(num_machines_);
smaller_buffer_read_start_pos_.resize(num_features_);
larger_buffer_read_start_pos_.resize(num_features_);
global_data_count_in_leaf_.resize(tree_config_.num_leaves);
smaller_leaf_splits_global_.reset(new LeafSplits(train_data_->num_features(), train_data_->num_data()));
larger_leaf_splits_global_.reset(new LeafSplits(train_data_->num_features(), train_data_->num_data()));
local_tree_config_ = tree_config_;
local_tree_config_.min_data_in_leaf /= num_machines_;
local_tree_config_.min_sum_hessian_in_leaf /= num_machines_;
auto histogram_create_function = [this]() {
auto tmp_histogram_array = std::unique_ptr<FeatureHistogram[]>(new FeatureHistogram[train_data_->num_features()]);
for (int j = 0; j < train_data_->num_features(); ++j) {
tmp_histogram_array[j].Init(train_data_->FeatureAt(j),
j, &local_tree_config_);
}
return tmp_histogram_array.release();
};
histogram_pool_.Fill(histogram_create_function);
// initialize the global histogram arrays
smaller_leaf_histogram_array_global_.reset(new FeatureHistogram[num_features_]);
larger_leaf_histogram_array_global_.reset(new FeatureHistogram[num_features_]);
for (int j = 0; j < num_features_; ++j) {
smaller_leaf_histogram_array_global_[j].Init(train_data_->FeatureAt(j), j, &tree_config_);
larger_leaf_histogram_array_global_[j].Init(train_data_->FeatureAt(j), j, &tree_config_);
}
}
void VotingParallelTreeLearner::BeforeTrain() {
SerialTreeLearner::BeforeTrain();
// sync global data sumup info
std::tuple<data_size_t, double, double> data(smaller_leaf_splits_->num_data_in_leaf(), smaller_leaf_splits_->sum_gradients(), smaller_leaf_splits_->sum_hessians());
int size = sizeof(std::tuple<data_size_t, double, double>);
std::memcpy(input_buffer_.data(), &data, size);
Network::Allreduce(input_buffer_.data(), size, size, output_buffer_.data(), [](const char *src, char *dst, int len) {
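// element-wise reducer: src and dst hold packed (num_data, sum_gradients, sum_hessians) tuples; sum them component-wise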
int used_size = 0;
int type_size = sizeof(std::tuple<data_size_t, double, double>);
const std::tuple<data_size_t, double, double> *p1;
std::tuple<data_size_t, double, double> *p2;
while (used_size < len) {
p1 = reinterpret_cast<const std::tuple<data_size_t, double, double> *>(src);
p2 = reinterpret_cast<std::tuple<data_size_t, double, double> *>(dst);
std::get<0>(*p2) = std::get<0>(*p2) + std::get<0>(*p1);
std::get<1>(*p2) = std::get<1>(*p2) + std::get<1>(*p1);
std::get<2>(*p2) = std::get<2>(*p2) + std::get<2>(*p1);
src += type_size;
dst += type_size;
used_size += type_size;
}
});
std::memcpy(&data, output_buffer_.data(), size);
// set global sumup info
smaller_leaf_splits_global_->Init(std::get<1>(data), std::get<2>(data));
larger_leaf_splits_global_->Init();
// init global data count in leaf
global_data_count_in_leaf_[0] = std::get<0>(data);
}
bool VotingParallelTreeLearner::BeforeFindBestSplit(int left_leaf, int right_leaf) {
if (SerialTreeLearner::BeforeFindBestSplit(left_leaf, right_leaf)) {
data_size_t num_data_in_left_child = GetGlobalDataCountInLeaf(left_leaf);
data_size_t num_data_in_right_child = GetGlobalDataCountInLeaf(right_leaf);
if (right_leaf < 0) {
return true;
} else if (num_data_in_left_child < num_data_in_right_child) {
// get local sumup
smaller_leaf_splits_->Init(left_leaf, data_partition_.get(), gradients_, hessians_);
larger_leaf_splits_->Init(right_leaf, data_partition_.get(), gradients_, hessians_);
} else {
// get local sumup
smaller_leaf_splits_->Init(right_leaf, data_partition_.get(), gradients_, hessians_);
larger_leaf_splits_->Init(left_leaf, data_partition_.get(), gradients_, hessians_);
}
return true;
} else {
return false;
}
}
void VotingParallelTreeLearner::GlobalVoting(int leaf_idx, const std::vector<SplitInfo>& splits, std::vector<int>* out) {
out->clear();
if (leaf_idx < 0) {
return;
}
// mean number of data points per machine in this leaf
score_t mean_num_data = GetGlobalDataCountInLeaf(leaf_idx) / static_cast<score_t>(num_machines_);
std::vector<SplitInfo> feature_best_split(num_features_, SplitInfo());
for (auto & split : splits) {
int fid = split.feature;
if (fid < 0) {
continue;
}
// weight the gain by the split's local data count relative to the per-machine mean
double gain = split.gain * (split.left_count + split.right_count) / mean_num_data;
if (gain > feature_best_split[fid].gain) {
feature_best_split[fid] = split;
feature_best_split[fid].gain = gain;
}
}
// get top k
std::vector<SplitInfo> top_k_splits;
ArrayArgs<SplitInfo>::MaxK(feature_best_split, top_k_, &top_k_splits);
for (auto& split : top_k_splits) {
if (split.gain == kMinScore || split.feature == -1) {
continue;
}
out->push_back(split.feature);
}
}
void VotingParallelTreeLearner::CopyLocalHistogram(const std::vector<int>& smaller_top_features, const std::vector<int>& larger_top_features) {
for (int i = 0; i < num_features_; ++i) {
smaller_is_feature_aggregated_[i] = false;
larger_is_feature_aggregated_[i] = false;
}
size_t total_num_features = smaller_top_features.size() + larger_top_features.size();
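// ceiling division: spread the selected features evenly across machines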
size_t average_feature = (total_num_features + num_machines_ - 1) / num_machines_;
size_t used_num_features = 0, smaller_idx = 0, larger_idx = 0;
block_start_[0] = 0;
reduce_scatter_size_ = 0;
// copy histograms into the send buffer and mark which features this machine aggregates
for (int i = 0; i < num_machines_; ++i) {
size_t cur_size = 0, cur_used_features = 0;
size_t cur_total_feature = std::min(average_feature, total_num_features - used_num_features);
// copy histograms.
while (cur_used_features < cur_total_feature) {
// copy smaller leaf histograms first
if (smaller_idx < smaller_top_features.size()) {
int fid = smaller_top_features[smaller_idx];
++cur_used_features;
// mark local aggregated feature
if (i == rank_) {
smaller_is_feature_aggregated_[fid] = true;
smaller_buffer_read_start_pos_[fid] = static_cast<int>(cur_size);
}
// copy
std::memcpy(input_buffer_.data() + reduce_scatter_size_, smaller_leaf_histogram_array_[fid].HistogramData(), smaller_leaf_histogram_array_[fid].SizeOfHistgram());
cur_size += smaller_leaf_histogram_array_[fid].SizeOfHistgram();
reduce_scatter_size_ += smaller_leaf_histogram_array_[fid].SizeOfHistgram();
++smaller_idx;
}
if (cur_used_features >= cur_total_feature) {
break;
}
// then copy larger leaf histograms
if (larger_idx < larger_top_features.size()) {
int fid = larger_top_features[larger_idx];
++cur_used_features;
// mark local aggregated feature
if (i == rank_) {
larger_is_feature_aggregated_[fid] = true;
larger_buffer_read_start_pos_[fid] = static_cast<int>(cur_size);
}
// copy
std::memcpy(input_buffer_.data() + reduce_scatter_size_, larger_leaf_histogram_array_[fid].HistogramData(), larger_leaf_histogram_array_[fid].SizeOfHistgram());
cur_size += larger_leaf_histogram_array_[fid].SizeOfHistgram();
reduce_scatter_size_ += larger_leaf_histogram_array_[fid].SizeOfHistgram();
++larger_idx;
}
}
used_num_features += cur_used_features;
block_len_[i] = static_cast<int>(cur_size);
if (i < num_machines_ - 1) {
block_start_[i + 1] = block_start_[i] + block_len_[i];
}
}
}
void VotingParallelTreeLearner::FindBestThresholds() {
// use local data to find local best splits
SerialTreeLearner::FindBestThresholds();
std::vector<SplitInfo> smaller_top_k_splits, larger_top_k_splits;
// local voting
ArrayArgs<SplitInfo>::MaxK(smaller_leaf_splits_->BestSplitPerFeature(), top_k_, &smaller_top_k_splits);
ArrayArgs<SplitInfo>::MaxK(larger_leaf_splits_->BestSplitPerFeature(), top_k_, &larger_top_k_splits);
// pack the two top-k lists and gather them from all machines
int offset = 0;
for (int i = 0; i < top_k_; ++i) {
std::memcpy(input_buffer_.data() + offset, &smaller_top_k_splits[i], sizeof(SplitInfo));
offset += sizeof(SplitInfo);
std::memcpy(input_buffer_.data() + offset, &larger_top_k_splits[i], sizeof(SplitInfo));
offset += sizeof(SplitInfo);
}
Network::Allgather(input_buffer_.data(), offset, output_buffer_.data());
// get all top-k from all machines
std::vector<SplitInfo> smaller_top_k_splits_global;
std::vector<SplitInfo> larger_top_k_splits_global;
offset = 0;
for (int i = 0; i < num_machines_; ++i) {
for (int j = 0; j < top_k_; ++j) {
smaller_top_k_splits_global.push_back(SplitInfo());
std::memcpy(&smaller_top_k_splits_global.back(), output_buffer_.data() + offset, sizeof(SplitInfo));
offset += sizeof(SplitInfo);
larger_top_k_splits_global.push_back(SplitInfo());
std::memcpy(&larger_top_k_splits_global.back(), output_buffer_.data() + offset, sizeof(SplitInfo));
offset += sizeof(SplitInfo);
}
}
// global voting
std::vector<int> smaller_top_features, larger_top_features;
GlobalVoting(smaller_leaf_splits_->LeafIndex(), smaller_top_k_splits_global, &smaller_top_features);
GlobalVoting(larger_leaf_splits_->LeafIndex(), larger_top_k_splits_global, &larger_top_features);
// copy local histograms to buffer
CopyLocalHistogram(smaller_top_features, larger_top_features);
// Reduce scatter for histogram
Network::ReduceScatter(input_buffer_.data(), reduce_scatter_size_, block_start_.data(), block_len_.data(),
output_buffer_.data(), &HistogramBinEntry::SumReducer);
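// output_buffer_ now holds fully aggregated histograms for the features assigned to this machine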
// find best split from local aggregated histograms
#pragma omp parallel for schedule(guided)
for (int feature_index = 0; feature_index < num_features_; ++feature_index) {
if (smaller_is_feature_aggregated_[feature_index]) {
smaller_leaf_histogram_array_global_[feature_index].SetSumup(
GetGlobalDataCountInLeaf(smaller_leaf_splits_global_->LeafIndex()),
smaller_leaf_splits_global_->sum_gradients(),
smaller_leaf_splits_global_->sum_hessians());
// restore from buffer
smaller_leaf_histogram_array_global_[feature_index].FromMemory(
output_buffer_.data() + smaller_buffer_read_start_pos_[feature_index]);
// find best threshold
smaller_leaf_histogram_array_global_[feature_index].FindBestThreshold(
&smaller_leaf_splits_global_->BestSplitPerFeature()[feature_index]);
}
if (larger_is_feature_aggregated_[feature_index]) {
larger_leaf_histogram_array_global_[feature_index].SetSumup(GetGlobalDataCountInLeaf(larger_leaf_splits_global_->LeafIndex()),
larger_leaf_splits_global_->sum_gradients(), larger_leaf_splits_global_->sum_hessians());
// restore from buffer
larger_leaf_histogram_array_global_[feature_index].FromMemory(output_buffer_.data() + larger_buffer_read_start_pos_[feature_index]);
// find best threshold
larger_leaf_histogram_array_global_[feature_index].FindBestThreshold(&larger_leaf_splits_global_->BestSplitPerFeature()[feature_index]);
}
}
}
void VotingParallelTreeLearner::FindBestSplitsForLeaves() {
int smaller_best_feature = -1, larger_best_feature = -1;
// find local best
SplitInfo smaller_best, larger_best;
std::vector<double> gains;
for (size_t i = 0; i < smaller_leaf_splits_global_->BestSplitPerFeature().size(); ++i) {
gains.push_back(smaller_leaf_splits_global_->BestSplitPerFeature()[i].gain);
}
smaller_best_feature = static_cast<int>(ArrayArgs<double>::ArgMax(gains));
smaller_best = smaller_leaf_splits_global_->BestSplitPerFeature()[smaller_best_feature];
if (larger_leaf_splits_global_->LeafIndex() >= 0) {
gains.clear();
for (size_t i = 0; i < larger_leaf_splits_global_->BestSplitPerFeature().size(); ++i) {
gains.push_back(larger_leaf_splits_global_->BestSplitPerFeature()[i].gain);
}
larger_best_feature = static_cast<int>(ArrayArgs<double>::ArgMax(gains));
larger_best = larger_leaf_splits_global_->BestSplitPerFeature()[larger_best_feature];
}
// sync global best info
std::memcpy(input_buffer_.data(), &smaller_best, sizeof(SplitInfo));
std::memcpy(input_buffer_.data() + sizeof(SplitInfo), &larger_best, sizeof(SplitInfo));
Network::Allreduce(input_buffer_.data(), sizeof(SplitInfo) * 2, sizeof(SplitInfo), output_buffer_.data(), &SplitInfo::MaxReducer);
std::memcpy(&smaller_best, output_buffer_.data(), sizeof(SplitInfo));
std::memcpy(&larger_best, output_buffer_.data() + sizeof(SplitInfo), sizeof(SplitInfo));
// copy back
best_split_per_leaf_[smaller_leaf_splits_global_->LeafIndex()] = smaller_best;
if (larger_best.feature >= 0 && larger_leaf_splits_global_->LeafIndex() >= 0) {
best_split_per_leaf_[larger_leaf_splits_global_->LeafIndex()] = larger_best;
}
}
void VotingParallelTreeLearner::Split(Tree* tree, int best_Leaf, int* left_leaf, int* right_leaf) {
SerialTreeLearner::Split(tree, best_Leaf, left_leaf, right_leaf);
const SplitInfo& best_split_info = best_split_per_leaf_[best_Leaf];
// set the global number of data for leaves
global_data_count_in_leaf_[*left_leaf] = best_split_info.left_count;
global_data_count_in_leaf_[*right_leaf] = best_split_info.right_count;
// init the global sumup info
if (best_split_info.left_count < best_split_info.right_count) {
smaller_leaf_splits_global_->Init(*left_leaf, data_partition_.get(),
best_split_info.left_sum_gradient,
best_split_info.left_sum_hessian);
larger_leaf_splits_global_->Init(*right_leaf, data_partition_.get(),
best_split_info.right_sum_gradient,
best_split_info.right_sum_hessian);
} else {
smaller_leaf_splits_global_->Init(*right_leaf, data_partition_.get(),
best_split_info.right_sum_gradient,
best_split_info.right_sum_hessian);
larger_leaf_splits_global_->Init(*left_leaf, data_partition_.get(),
best_split_info.left_sum_gradient,
best_split_info.left_sum_hessian);
}
}
}  // namespace LightGBM

View file

@@ -259,6 +259,7 @@
<ClCompile Include="..\src\treelearner\feature_parallel_tree_learner.cpp" />
<ClCompile Include="..\src\treelearner\serial_tree_learner.cpp" />
<ClCompile Include="..\src\treelearner\tree_learner.cpp" />
<ClCompile Include="..\src\treelearner\voting_parallel_tree_learner.cpp" />
</ItemGroup>
<Import Project="$(VCTargetsPath)\Microsoft.Cpp.targets" />
<ImportGroup Label="ExtensionTargets">

View file

@@ -242,5 +242,8 @@
<ClCompile Include="..\src\io\dataset_loader.cpp">
<Filter>src\io</Filter>
</ClCompile>
<ClCompile Include="..\src\treelearner\voting_parallel_tree_learner.cpp">
<Filter>src\treelearner</Filter>
</ClCompile>
</ItemGroup>
</Project>