[fix] fix quantized training (fixes #5982) (fixes #5994) (#6092)

* fix leaf splits update after split in quantized training

* fix preparation ordered gradients for quantized training

* remove force_row_wise in distributed test for quantized training

* Update src/treelearner/leaf_splits.hpp

---------

Co-authored-by: James Lamb <jaylamb20@gmail.com>
Authored by shiyu1994 on 2023-09-13 01:06:20 +08:00, committed by GitHub
Parent: cd39520c5e
Commit: a92bf3742b
5 changed files: 142 additions and 32 deletions


@@ -1278,21 +1278,34 @@ void Dataset::ConstructHistogramsInner(
   auto ptr_ordered_grad = gradients;
   auto ptr_ordered_hess = hessians;
   if (num_used_dense_group > 0) {
-    if (USE_INDICES) {
-      if (USE_HESSIAN) {
-        #pragma omp parallel for schedule(static, 512) if (num_data >= 1024)
-        for (data_size_t i = 0; i < num_data; ++i) {
-          ordered_gradients[i] = gradients[data_indices[i]];
-          ordered_hessians[i] = hessians[data_indices[i]];
-        }
-        ptr_ordered_grad = ordered_gradients;
-        ptr_ordered_hess = ordered_hessians;
-      } else {
-        #pragma omp parallel for schedule(static, 512) if (num_data >= 1024)
-        for (data_size_t i = 0; i < num_data; ++i) {
-          ordered_gradients[i] = gradients[data_indices[i]];
-        }
-        ptr_ordered_grad = ordered_gradients;
+    if (USE_QUANT_GRAD) {
+      int16_t* ordered_gradients_and_hessians = reinterpret_cast<int16_t*>(ordered_gradients);
+      const int16_t* gradients_and_hessians = reinterpret_cast<const int16_t*>(gradients);
+      if (USE_INDICES) {
+        #pragma omp parallel for schedule(static, 512) if (num_data >= 1024)
+        for (data_size_t i = 0; i < num_data; ++i) {
+          ordered_gradients_and_hessians[i] = gradients_and_hessians[data_indices[i]];
+        }
+        ptr_ordered_grad = reinterpret_cast<const score_t*>(ordered_gradients);
+        ptr_ordered_hess = nullptr;
+      }
+    } else {
+      if (USE_INDICES) {
+        if (USE_HESSIAN) {
+          #pragma omp parallel for schedule(static, 512) if (num_data >= 1024)
+          for (data_size_t i = 0; i < num_data; ++i) {
+            ordered_gradients[i] = gradients[data_indices[i]];
+            ordered_hessians[i] = hessians[data_indices[i]];
+          }
+          ptr_ordered_grad = ordered_gradients;
+          ptr_ordered_hess = ordered_hessians;
+        } else {
+          #pragma omp parallel for schedule(static, 512) if (num_data >= 1024)
+          for (data_size_t i = 0; i < num_data; ++i) {
+            ordered_gradients[i] = gradients[data_indices[i]];
+          }
+          ptr_ordered_grad = ordered_gradients;
+        }
       }
     }
     OMP_INIT_EX();
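For context on this hunk: in quantized training each row's discretized gradient and hessian are stored as two adjacent int8 values, so reinterpreting the score_t buffers as int16_t lets the USE_QUANT_GRAD branch gather both statistics for a row with a single 16-bit copy. Below is a minimal standalone sketch of that gather, assuming a little-endian machine and the byte layout implied by the CheckSplit code later in this commit (low byte = hessian, high byte = gradient); the toy values and file name are illustrative, not LightGBM API.

// illustrative_gather.cpp -- standalone sketch, not part of the commit
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Toy packed buffer: byte 2*i is row i's int8 hessian, byte 2*i + 1 its
  // int8 gradient, so each row occupies one little-endian int16.
  std::vector<int8_t> grad_hess = {3, -2,   // row 0: hessian 3, gradient -2
                                   5,  7,   // row 1: hessian 5, gradient  7
                                   1, -4};  // row 2: hessian 1, gradient -4
  std::vector<int> data_indices = {2, 0};   // rows that fall into one leaf

  // Gather gradient and hessian together with one 16-bit copy per row,
  // mirroring the USE_QUANT_GRAD branch above (vector storage is suitably
  // aligned for the reinterpret_cast, the same trick the diff relies on).
  const int16_t* packed = reinterpret_cast<const int16_t*>(grad_hess.data());
  std::vector<int16_t> ordered(data_indices.size());
  for (std::size_t i = 0; i < data_indices.size(); ++i) {
    ordered[i] = packed[data_indices[i]];
  }

  // Unpack to verify the layout: low byte = hessian, high byte = gradient.
  for (std::size_t i = 0; i < ordered.size(); ++i) {
    const int8_t hess = static_cast<int8_t>(ordered[i] & 0xff);
    const int8_t grad = static_cast<int8_t>((ordered[i] >> 8) & 0xff);
    std::cout << "row " << data_indices[i] << ": grad=" << static_cast<int>(grad)
              << " hess=" << static_cast<int>(hess) << "\n";
  }
  return 0;
}

Built with any C++11 compiler, this prints the gradient/hessian pairs of rows 2 and 0, which is exactly what the ordered-gradient preparation has to produce before histogram construction.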


@@ -53,6 +53,25 @@ class LeafSplits {
     weight_ = weight;
   }
 
+  /*!
+  * \brief Init split on current leaf on partial data.
+  * \param leaf Index of current leaf
+  * \param data_partition current data partition
+  * \param sum_gradients
+  * \param sum_hessians
+  * \param sum_gradients_and_hessians
+  * \param weight
+  */
+  void Init(int leaf, const DataPartition* data_partition, double sum_gradients,
+            double sum_hessians, int64_t sum_gradients_and_hessians, double weight) {
+    leaf_index_ = leaf;
+    data_indices_ = data_partition->GetIndexOnLeaf(leaf, &num_data_in_leaf_);
+    sum_gradients_ = sum_gradients;
+    sum_hessians_ = sum_hessians;
+    int_sum_gradients_and_hessians_ = sum_gradients_and_hessians;
+    weight_ = weight;
+  }
+
   /*!
   * \brief Init split on current leaf on partial data.
   * \param leaf Index of current leaf
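The new Init overload carries the leaf's integer gradient/hessian sums in a single int64_t. As the CheckSplit code later in this commit unpacks it, the 32-bit gradient sum sits in the high half and the 32-bit hessian sum in the low half, which lets one integer addition update both statistics at once. A small sketch of that packing follows; the helper names are hypothetical, only the bit layout comes from the commit.

// illustrative_packed_sums.cpp -- standalone sketch, not part of the commit
#include <cassert>
#include <cstdint>

// Hypothetical helpers: pack two 32-bit integer sums into one int64_t,
// gradient sum in the high 32 bits, hessian sum in the low 32 bits.
int64_t PackSums(int32_t sum_gradient, int32_t sum_hessian) {
  const uint64_t hi = static_cast<uint64_t>(static_cast<uint32_t>(sum_gradient)) << 32;
  const uint64_t lo = static_cast<uint32_t>(sum_hessian);
  return static_cast<int64_t>(hi | lo);
}

int32_t UnpackGradientSum(int64_t packed) {
  return static_cast<int32_t>(packed >> 32);
}

int32_t UnpackHessianSum(int64_t packed) {
  return static_cast<int32_t>(packed & 0x00000000ffffffff);
}

int main() {
  // A leaf whose quantized gradient sum is negative and hessian sum positive.
  const int64_t packed = PackSums(-1234, 5678);
  assert(UnpackGradientSum(packed) == -1234);
  assert(UnpackHessianSum(packed) == 5678);
  return 0;
}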


@@ -841,32 +841,65 @@ void SerialTreeLearner::SplitInner(Tree* tree, int best_leaf, int* left_leaf,
 #endif
   // init the leaves that used on next iteration
-  if (best_split_info.left_count < best_split_info.right_count) {
-    CHECK_GT(best_split_info.left_count, 0);
-    smaller_leaf_splits_->Init(*left_leaf, data_partition_.get(),
-                               best_split_info.left_sum_gradient,
-                               best_split_info.left_sum_hessian,
-                               best_split_info.left_output);
-    larger_leaf_splits_->Init(*right_leaf, data_partition_.get(),
-                              best_split_info.right_sum_gradient,
-                              best_split_info.right_sum_hessian,
-                              best_split_info.right_output);
+  if (!config_->use_quantized_grad) {
+    if (best_split_info.left_count < best_split_info.right_count) {
+      CHECK_GT(best_split_info.left_count, 0);
+      smaller_leaf_splits_->Init(*left_leaf, data_partition_.get(),
+                                 best_split_info.left_sum_gradient,
+                                 best_split_info.left_sum_hessian,
+                                 best_split_info.left_output);
+      larger_leaf_splits_->Init(*right_leaf, data_partition_.get(),
+                                best_split_info.right_sum_gradient,
+                                best_split_info.right_sum_hessian,
+                                best_split_info.right_output);
+    } else {
+      CHECK_GT(best_split_info.right_count, 0);
+      smaller_leaf_splits_->Init(*right_leaf, data_partition_.get(),
+                                 best_split_info.right_sum_gradient,
+                                 best_split_info.right_sum_hessian,
+                                 best_split_info.right_output);
+      larger_leaf_splits_->Init(*left_leaf, data_partition_.get(),
+                                best_split_info.left_sum_gradient,
+                                best_split_info.left_sum_hessian,
+                                best_split_info.left_output);
+    }
   } else {
-    CHECK_GT(best_split_info.right_count, 0);
-    smaller_leaf_splits_->Init(*right_leaf, data_partition_.get(),
-                               best_split_info.right_sum_gradient,
-                               best_split_info.right_sum_hessian,
-                               best_split_info.right_output);
-    larger_leaf_splits_->Init(*left_leaf, data_partition_.get(),
-                              best_split_info.left_sum_gradient,
-                              best_split_info.left_sum_hessian,
-                              best_split_info.left_output);
+    if (best_split_info.left_count < best_split_info.right_count) {
+      CHECK_GT(best_split_info.left_count, 0);
+      smaller_leaf_splits_->Init(*left_leaf, data_partition_.get(),
+                                 best_split_info.left_sum_gradient,
+                                 best_split_info.left_sum_hessian,
+                                 best_split_info.left_sum_gradient_and_hessian,
+                                 best_split_info.left_output);
+      larger_leaf_splits_->Init(*right_leaf, data_partition_.get(),
+                                best_split_info.right_sum_gradient,
+                                best_split_info.right_sum_hessian,
+                                best_split_info.right_sum_gradient_and_hessian,
+                                best_split_info.right_output);
+    } else {
+      CHECK_GT(best_split_info.right_count, 0);
+      smaller_leaf_splits_->Init(*right_leaf, data_partition_.get(),
+                                 best_split_info.right_sum_gradient,
+                                 best_split_info.right_sum_hessian,
+                                 best_split_info.right_sum_gradient_and_hessian,
+                                 best_split_info.right_output);
+      larger_leaf_splits_->Init(*left_leaf, data_partition_.get(),
+                                best_split_info.left_sum_gradient,
+                                best_split_info.left_sum_hessian,
+                                best_split_info.left_sum_gradient_and_hessian,
+                                best_split_info.left_output);
+    }
   }
   if (config_->use_quantized_grad && config_->tree_learner != std::string("data")) {
     gradient_discretizer_->SetNumBitsInHistogramBin<false>(*left_leaf, *right_leaf,
                                                            data_partition_->leaf_count(*left_leaf),
                                                            data_partition_->leaf_count(*right_leaf));
   }
+#ifdef DEBUG
+  CheckSplit(best_split_info, *left_leaf, *right_leaf);
+#endif
   auto leaves_need_update = constraints_->Update(
       is_numerical_split, *left_leaf, *right_leaf,
       best_split_info.monotone_type, best_split_info.right_output,
@@ -1024,4 +1057,48 @@ std::vector<int8_t> node_used_features = col_sampler_.GetByNode(tree, leaf);
   *split = bests[best_idx];
 }
 
+#ifdef DEBUG
+void SerialTreeLearner::CheckSplit(const SplitInfo& best_split_info, const int left_leaf_index, const int right_leaf_index) {
+  data_size_t num_data_in_left = 0;
+  data_size_t num_data_in_right = 0;
+  const data_size_t* data_indices_in_left = data_partition_->GetIndexOnLeaf(left_leaf_index, &num_data_in_left);
+  const data_size_t* data_indices_in_right = data_partition_->GetIndexOnLeaf(right_leaf_index, &num_data_in_right);
+  if (config_->use_quantized_grad) {
+    int32_t sum_left_gradient = 0;
+    int32_t sum_left_hessian = 0;
+    int32_t sum_right_gradient = 0;
+    int32_t sum_right_hessian = 0;
+    const int8_t* discretized_grad_and_hess = gradient_discretizer_->discretized_gradients_and_hessians();
+    for (data_size_t i = 0; i < num_data_in_left; ++i) {
+      const data_size_t index = data_indices_in_left[i];
+      sum_left_gradient += discretized_grad_and_hess[2 * index + 1];
+      sum_left_hessian += discretized_grad_and_hess[2 * index];
+    }
+    for (data_size_t i = 0; i < num_data_in_right; ++i) {
+      const data_size_t index = data_indices_in_right[i];
+      sum_right_gradient += discretized_grad_and_hess[2 * index + 1];
+      sum_right_hessian += discretized_grad_and_hess[2 * index];
+    }
+    Log::Warning("============================ start leaf split info ============================");
+    Log::Warning("left_leaf_index = %d, right_leaf_index = %d", left_leaf_index, right_leaf_index);
+    Log::Warning("num_data_in_left = %d, num_data_in_right = %d", num_data_in_left, num_data_in_right);
+    Log::Warning("sum_left_gradient = %d, best_split_info->left_sum_gradient_and_hessian.gradient = %d", sum_left_gradient,
+                 static_cast<int32_t>(best_split_info.left_sum_gradient_and_hessian >> 32));
+    Log::Warning("sum_left_hessian = %d, best_split_info->left_sum_gradient_and_hessian.hessian = %d", sum_left_hessian,
+                 static_cast<int32_t>(best_split_info.left_sum_gradient_and_hessian & 0x00000000ffffffff));
+    Log::Warning("sum_right_gradient = %d, best_split_info->right_sum_gradient_and_hessian.gradient = %d", sum_right_gradient,
+                 static_cast<int32_t>(best_split_info.right_sum_gradient_and_hessian >> 32));
+    Log::Warning("sum_right_hessian = %d, best_split_info->right_sum_gradient_and_hessian.hessian = %d", sum_right_hessian,
+                 static_cast<int32_t>(best_split_info.right_sum_gradient_and_hessian & 0x00000000ffffffff));
+    CHECK_EQ(num_data_in_left, best_split_info.left_count);
+    CHECK_EQ(num_data_in_right, best_split_info.right_count);
+    CHECK_EQ(sum_left_gradient, static_cast<int32_t>(best_split_info.left_sum_gradient_and_hessian >> 32));
+    CHECK_EQ(sum_left_hessian, static_cast<int32_t>(best_split_info.left_sum_gradient_and_hessian & 0x00000000ffffffff));
+    CHECK_EQ(sum_right_gradient, static_cast<int32_t>(best_split_info.right_sum_gradient_and_hessian >> 32));
+    CHECK_EQ(sum_right_hessian, static_cast<int32_t>(best_split_info.right_sum_gradient_and_hessian & 0x00000000ffffffff));
+    Log::Warning("============================ end leaf split info ============================");
+  }
+}
+#endif
+
 } // namespace LightGBM
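CheckSplit is a DEBUG-only consistency check: it re-walks the data indices of the two children, re-accumulates the int8 discretized gradients and hessians, and compares the totals with the packed 64-bit sums carried in SplitInfo. The same check on toy data, stripped of the tree-learner plumbing, is sketched below; every name and value is illustrative, and only the even/odd slot indexing and the high/low bit split follow the code above.

// illustrative_check_split.cpp -- standalone sketch, not part of the commit
#include <cassert>
#include <cstdint>
#include <vector>

// Re-accumulate a leaf's integer sums from the packed int8 buffer and return
// them packed the way SplitInfo's *_sum_gradient_and_hessian is read above:
// gradient sum in the high 32 bits, hessian sum in the low 32 bits.
uint64_t SumLeaf(const std::vector<int8_t>& grad_hess,
                 const std::vector<int>& indices) {
  int32_t sum_gradient = 0;
  int32_t sum_hessian = 0;
  for (int idx : indices) {
    sum_gradient += grad_hess[2 * idx + 1];  // odd slot: gradient
    sum_hessian += grad_hess[2 * idx];       // even slot: hessian
  }
  return (static_cast<uint64_t>(static_cast<uint32_t>(sum_gradient)) << 32) |
         static_cast<uint32_t>(sum_hessian);
}

int main() {
  // Four toy rows of discretized statistics: {hessian, gradient} per row.
  const std::vector<int8_t> grad_hess = {2, -3, 4, 5, 1, -1, 3, 6};
  const std::vector<int> left_indices = {0, 2};   // rows in the left child
  const std::vector<int> right_indices = {1, 3};  // rows in the right child

  // Pretend these packed values came from split finding.
  const uint64_t left_packed = SumLeaf(grad_hess, left_indices);
  const uint64_t right_packed = SumLeaf(grad_hess, right_indices);

  // The CheckSplit-style assertions: recomputed sums must match the split info.
  assert(static_cast<int32_t>(left_packed >> 32) == -3 + -1);         // left gradients
  assert(static_cast<int32_t>(left_packed & 0xffffffff) == 2 + 1);    // left hessians
  assert(static_cast<int32_t>(right_packed >> 32) == 5 + 6);          // right gradients
  assert(static_cast<int32_t>(right_packed & 0xffffffff) == 4 + 3);   // right hessians
  return 0;
}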


@@ -171,7 +171,9 @@ class SerialTreeLearner: public TreeLearner {
   std::set<int> FindAllForceFeatures(Json force_split_leaf_setting);
 
+#ifdef DEBUG
   void CheckSplit(const SplitInfo& best_split_info, const int left_leaf_index, const int right_leaf_index);
+#endif
 
   /*!
   * \brief Get the number of data in a leaf


@@ -1838,7 +1838,6 @@ def test_distributed_quantized_training(cluster):
         'num_grad_quant_bins': 30,
         'quant_train_renew_leaf': True,
         'verbose': -1,
-        'force_row_wise': True,
     }
     quant_dask_classifier = lgb.DaskLGBMRegressor(