зеркало из https://github.com/microsoft/LightGBM.git
* fix leaf splits update after split in quantized training * fix preparation ordered gradients for quantized training * remove force_row_wise in distributed test for quantized training * Update src/treelearner/leaf_splits.hpp --------- Co-authored-by: James Lamb <jaylamb20@gmail.com>
This commit is contained in:
Родитель
cd39520c5e
Коммит
a92bf3742b
|
@ -1278,21 +1278,34 @@ void Dataset::ConstructHistogramsInner(
|
|||
auto ptr_ordered_grad = gradients;
|
||||
auto ptr_ordered_hess = hessians;
|
||||
if (num_used_dense_group > 0) {
|
||||
if (USE_INDICES) {
|
||||
if (USE_HESSIAN) {
|
||||
#pragma omp parallel for schedule(static, 512) if (num_data >= 1024)
|
||||
if (USE_QUANT_GRAD) {
|
||||
int16_t* ordered_gradients_and_hessians = reinterpret_cast<int16_t*>(ordered_gradients);
|
||||
const int16_t* gradients_and_hessians = reinterpret_cast<const int16_t*>(gradients);
|
||||
if (USE_INDICES) {
|
||||
#pragma omp parallel for schedule(static, 512) if (num_data >= 1024)
|
||||
for (data_size_t i = 0; i < num_data; ++i) {
|
||||
ordered_gradients[i] = gradients[data_indices[i]];
|
||||
ordered_hessians[i] = hessians[data_indices[i]];
|
||||
ordered_gradients_and_hessians[i] = gradients_and_hessians[data_indices[i]];
|
||||
}
|
||||
ptr_ordered_grad = ordered_gradients;
|
||||
ptr_ordered_hess = ordered_hessians;
|
||||
} else {
|
||||
#pragma omp parallel for schedule(static, 512) if (num_data >= 1024)
|
||||
for (data_size_t i = 0; i < num_data; ++i) {
|
||||
ordered_gradients[i] = gradients[data_indices[i]];
|
||||
ptr_ordered_grad = reinterpret_cast<const score_t*>(ordered_gradients);
|
||||
ptr_ordered_hess = nullptr;
|
||||
}
|
||||
} else {
|
||||
if (USE_INDICES) {
|
||||
if (USE_HESSIAN) {
|
||||
#pragma omp parallel for schedule(static, 512) if (num_data >= 1024)
|
||||
for (data_size_t i = 0; i < num_data; ++i) {
|
||||
ordered_gradients[i] = gradients[data_indices[i]];
|
||||
ordered_hessians[i] = hessians[data_indices[i]];
|
||||
}
|
||||
ptr_ordered_grad = ordered_gradients;
|
||||
ptr_ordered_hess = ordered_hessians;
|
||||
} else {
|
||||
#pragma omp parallel for schedule(static, 512) if (num_data >= 1024)
|
||||
for (data_size_t i = 0; i < num_data; ++i) {
|
||||
ordered_gradients[i] = gradients[data_indices[i]];
|
||||
}
|
||||
ptr_ordered_grad = ordered_gradients;
|
||||
}
|
||||
ptr_ordered_grad = ordered_gradients;
|
||||
}
|
||||
}
|
||||
OMP_INIT_EX();
|
||||
|
|
|
@ -53,6 +53,25 @@ class LeafSplits {
|
|||
weight_ = weight;
|
||||
}
|
||||
|
||||
/*!
|
||||
* \brief Init split on current leaf on partial data.
|
||||
* \param leaf Index of current leaf
|
||||
* \param data_partition current data partition
|
||||
* \param sum_gradients
|
||||
* \param sum_hessians
|
||||
* \param sum_gradients_and_hessians
|
||||
* \param weight
|
||||
*/
|
||||
void Init(int leaf, const DataPartition* data_partition, double sum_gradients,
|
||||
double sum_hessians, int64_t sum_gradients_and_hessians, double weight) {
|
||||
leaf_index_ = leaf;
|
||||
data_indices_ = data_partition->GetIndexOnLeaf(leaf, &num_data_in_leaf_);
|
||||
sum_gradients_ = sum_gradients;
|
||||
sum_hessians_ = sum_hessians;
|
||||
int_sum_gradients_and_hessians_ = sum_gradients_and_hessians;
|
||||
weight_ = weight;
|
||||
}
|
||||
|
||||
/*!
|
||||
* \brief Init split on current leaf on partial data.
|
||||
* \param leaf Index of current leaf
|
||||
|
|
|
@ -841,32 +841,65 @@ void SerialTreeLearner::SplitInner(Tree* tree, int best_leaf, int* left_leaf,
|
|||
#endif
|
||||
|
||||
// init the leaves that used on next iteration
|
||||
if (best_split_info.left_count < best_split_info.right_count) {
|
||||
CHECK_GT(best_split_info.left_count, 0);
|
||||
smaller_leaf_splits_->Init(*left_leaf, data_partition_.get(),
|
||||
best_split_info.left_sum_gradient,
|
||||
best_split_info.left_sum_hessian,
|
||||
best_split_info.left_output);
|
||||
larger_leaf_splits_->Init(*right_leaf, data_partition_.get(),
|
||||
best_split_info.right_sum_gradient,
|
||||
best_split_info.right_sum_hessian,
|
||||
best_split_info.right_output);
|
||||
if (!config_->use_quantized_grad) {
|
||||
if (best_split_info.left_count < best_split_info.right_count) {
|
||||
CHECK_GT(best_split_info.left_count, 0);
|
||||
smaller_leaf_splits_->Init(*left_leaf, data_partition_.get(),
|
||||
best_split_info.left_sum_gradient,
|
||||
best_split_info.left_sum_hessian,
|
||||
best_split_info.left_output);
|
||||
larger_leaf_splits_->Init(*right_leaf, data_partition_.get(),
|
||||
best_split_info.right_sum_gradient,
|
||||
best_split_info.right_sum_hessian,
|
||||
best_split_info.right_output);
|
||||
} else {
|
||||
CHECK_GT(best_split_info.right_count, 0);
|
||||
smaller_leaf_splits_->Init(*right_leaf, data_partition_.get(),
|
||||
best_split_info.right_sum_gradient,
|
||||
best_split_info.right_sum_hessian,
|
||||
best_split_info.right_output);
|
||||
larger_leaf_splits_->Init(*left_leaf, data_partition_.get(),
|
||||
best_split_info.left_sum_gradient,
|
||||
best_split_info.left_sum_hessian,
|
||||
best_split_info.left_output);
|
||||
}
|
||||
} else {
|
||||
CHECK_GT(best_split_info.right_count, 0);
|
||||
smaller_leaf_splits_->Init(*right_leaf, data_partition_.get(),
|
||||
best_split_info.right_sum_gradient,
|
||||
best_split_info.right_sum_hessian,
|
||||
best_split_info.right_output);
|
||||
larger_leaf_splits_->Init(*left_leaf, data_partition_.get(),
|
||||
best_split_info.left_sum_gradient,
|
||||
best_split_info.left_sum_hessian,
|
||||
best_split_info.left_output);
|
||||
if (best_split_info.left_count < best_split_info.right_count) {
|
||||
CHECK_GT(best_split_info.left_count, 0);
|
||||
smaller_leaf_splits_->Init(*left_leaf, data_partition_.get(),
|
||||
best_split_info.left_sum_gradient,
|
||||
best_split_info.left_sum_hessian,
|
||||
best_split_info.left_sum_gradient_and_hessian,
|
||||
best_split_info.left_output);
|
||||
larger_leaf_splits_->Init(*right_leaf, data_partition_.get(),
|
||||
best_split_info.right_sum_gradient,
|
||||
best_split_info.right_sum_hessian,
|
||||
best_split_info.right_sum_gradient_and_hessian,
|
||||
best_split_info.right_output);
|
||||
} else {
|
||||
CHECK_GT(best_split_info.right_count, 0);
|
||||
smaller_leaf_splits_->Init(*right_leaf, data_partition_.get(),
|
||||
best_split_info.right_sum_gradient,
|
||||
best_split_info.right_sum_hessian,
|
||||
best_split_info.right_sum_gradient_and_hessian,
|
||||
best_split_info.right_output);
|
||||
larger_leaf_splits_->Init(*left_leaf, data_partition_.get(),
|
||||
best_split_info.left_sum_gradient,
|
||||
best_split_info.left_sum_hessian,
|
||||
best_split_info.left_sum_gradient_and_hessian,
|
||||
best_split_info.left_output);
|
||||
}
|
||||
}
|
||||
if (config_->use_quantized_grad && config_->tree_learner != std::string("data")) {
|
||||
gradient_discretizer_->SetNumBitsInHistogramBin<false>(*left_leaf, *right_leaf,
|
||||
data_partition_->leaf_count(*left_leaf),
|
||||
data_partition_->leaf_count(*right_leaf));
|
||||
}
|
||||
|
||||
#ifdef DEBUG
|
||||
CheckSplit(best_split_info, *left_leaf, *right_leaf);
|
||||
#endif
|
||||
|
||||
auto leaves_need_update = constraints_->Update(
|
||||
is_numerical_split, *left_leaf, *right_leaf,
|
||||
best_split_info.monotone_type, best_split_info.right_output,
|
||||
|
@ -1024,4 +1057,48 @@ std::vector<int8_t> node_used_features = col_sampler_.GetByNode(tree, leaf);
|
|||
*split = bests[best_idx];
|
||||
}
|
||||
|
||||
#ifdef DEBUG
|
||||
void SerialTreeLearner::CheckSplit(const SplitInfo& best_split_info, const int left_leaf_index, const int right_leaf_index) {
|
||||
data_size_t num_data_in_left = 0;
|
||||
data_size_t num_data_in_right = 0;
|
||||
const data_size_t* data_indices_in_left = data_partition_->GetIndexOnLeaf(left_leaf_index, &num_data_in_left);
|
||||
const data_size_t* data_indices_in_right = data_partition_->GetIndexOnLeaf(right_leaf_index, &num_data_in_right);
|
||||
if (config_->use_quantized_grad) {
|
||||
int32_t sum_left_gradient = 0;
|
||||
int32_t sum_left_hessian = 0;
|
||||
int32_t sum_right_gradient = 0;
|
||||
int32_t sum_right_hessian = 0;
|
||||
const int8_t* discretized_grad_and_hess = gradient_discretizer_->discretized_gradients_and_hessians();
|
||||
for (data_size_t i = 0; i < num_data_in_left; ++i) {
|
||||
const data_size_t index = data_indices_in_left[i];
|
||||
sum_left_gradient += discretized_grad_and_hess[2 * index + 1];
|
||||
sum_left_hessian += discretized_grad_and_hess[2 * index];
|
||||
}
|
||||
for (data_size_t i = 0; i < num_data_in_right; ++i) {
|
||||
const data_size_t index = data_indices_in_right[i];
|
||||
sum_right_gradient += discretized_grad_and_hess[2 * index + 1];
|
||||
sum_right_hessian += discretized_grad_and_hess[2 * index];
|
||||
}
|
||||
Log::Warning("============================ start leaf split info ============================");
|
||||
Log::Warning("left_leaf_index = %d, right_leaf_index = %d", left_leaf_index, right_leaf_index);
|
||||
Log::Warning("num_data_in_left = %d, num_data_in_right = %d", num_data_in_left, num_data_in_right);
|
||||
Log::Warning("sum_left_gradient = %d, best_split_info->left_sum_gradient_and_hessian.gradient = %d", sum_left_gradient,
|
||||
static_cast<int32_t>(best_split_info.left_sum_gradient_and_hessian >> 32));
|
||||
Log::Warning("sum_left_hessian = %d, best_split_info->left_sum_gradient_and_hessian.hessian = %d", sum_left_hessian,
|
||||
static_cast<int32_t>(best_split_info.left_sum_gradient_and_hessian & 0x00000000ffffffff));
|
||||
Log::Warning("sum_right_gradient = %d, best_split_info->right_sum_gradient_and_hessian.gradient = %d", sum_right_gradient,
|
||||
static_cast<int32_t>(best_split_info.right_sum_gradient_and_hessian >> 32));
|
||||
Log::Warning("sum_right_hessian = %d, best_split_info->right_sum_gradient_and_hessian.hessian = %d", sum_right_hessian,
|
||||
static_cast<int32_t>(best_split_info.right_sum_gradient_and_hessian & 0x00000000ffffffff));
|
||||
CHECK_EQ(num_data_in_left, best_split_info.left_count);
|
||||
CHECK_EQ(num_data_in_right, best_split_info.right_count);
|
||||
CHECK_EQ(sum_left_gradient, static_cast<int32_t>(best_split_info.left_sum_gradient_and_hessian >> 32))
|
||||
CHECK_EQ(sum_left_hessian, static_cast<int32_t>(best_split_info.left_sum_gradient_and_hessian & 0x00000000ffffffff));
|
||||
CHECK_EQ(sum_right_gradient, static_cast<int32_t>(best_split_info.right_sum_gradient_and_hessian >> 32));
|
||||
CHECK_EQ(sum_right_hessian, static_cast<int32_t>(best_split_info.right_sum_gradient_and_hessian & 0x00000000ffffffff));
|
||||
Log::Warning("============================ end leaf split info ============================");
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
} // namespace LightGBM
|
||||
|
|
|
@ -171,7 +171,9 @@ class SerialTreeLearner: public TreeLearner {
|
|||
|
||||
std::set<int> FindAllForceFeatures(Json force_split_leaf_setting);
|
||||
|
||||
#ifdef DEBUG
|
||||
/*!
 * \brief Debug-only consistency check of a split (declared only when DEBUG is defined)
 * \param best_split_info Split information to check
 * \param left_leaf_index Index of the left leaf produced by the split
 * \param right_leaf_index Index of the right leaf produced by the split
 */
void CheckSplit(const SplitInfo& best_split_info, const int left_leaf_index, const int right_leaf_index);
|
||||
#endif
|
||||
|
||||
/*!
|
||||
* \brief Get the number of data in a leaf
|
||||
|
|
|
@ -1838,7 +1838,6 @@ def test_distributed_quantized_training(cluster):
|
|||
'num_grad_quant_bins': 30,
|
||||
'quant_train_renew_leaf': True,
|
||||
'verbose': -1,
|
||||
'force_row_wise': True,
|
||||
}
|
||||
|
||||
quant_dask_classifier = lgb.DaskLGBMRegressor(
|
||||
|
|
Загрузка…
Ссылка в новой задаче