From f901f47141700f01294c6766bd93bcc57ddb1960 Mon Sep 17 00:00:00 2001 From: shiyu1994 Date: Sun, 8 Oct 2023 23:25:46 +0800 Subject: [PATCH] [CUDA] CUDA Quantized Training (fixes #5606) (#5933) * add quantized training (first stage) * add histogram construction functions for integer gradients * add stochastic rounding * update docs * fix compilation errors by adding template instantiations * update files for compilation * fix compilation of gpu version * initialize gradient discretizer before share states * add a test case for quantized training * add quantized training for data distributed training * Delete origin.pred * Delete ifelse.pred * Delete LightGBM_model.txt * remove useless changes * fix lint error * remove debug loggings * fix mismatch of vector and allocator types * remove changes in main.cpp * fix bugs with uninitialized gradient discretizer * initialize ordered gradients in gradient discretizer * disable quantized training with gpu and cuda fix msvc compilation errors and warnings * fix bug in data parallel tree learner * make quantized training test deterministic * make quantized training in test case more accurate * refactor test_quantized_training * fix leaf splits initialization with quantized training * check distributed quantized training result * add cuda gradient discretizer * add quantized training for CUDA version in tree learner * remove cuda computability 6.1 and 6.2 * fix parts of gpu quantized training errors and warnings * fix build-python.sh to install locally built version * fix memory access bugs * fix lint errors * mark cuda quantized training on cuda with categorical features as unsupported * rename cuda_utils.h to cuda_utils.hu * enable quantized training with cuda * fix cuda quantized training with sparse row data * allow using global memory buffer in histogram construction with cuda quantized training * recover build-python.sh enlarge allowed package size to 100M --- .ci/check_python_dists.sh | 2 +- include/LightGBM/cuda/cuda_algorithms.hpp | 2 +- include/LightGBM/cuda/cuda_column_data.hpp | 2 +- include/LightGBM/cuda/cuda_metadata.hpp | 2 +- include/LightGBM/cuda/cuda_metric.hpp | 2 +- .../LightGBM/cuda/cuda_objective_function.hpp | 2 +- include/LightGBM/cuda/cuda_row_data.hpp | 2 +- include/LightGBM/cuda/cuda_split_info.hpp | 2 + .../cuda/{cuda_utils.h => cuda_utils.hu} | 18 +- include/LightGBM/sample_strategy.h | 2 +- src/boosting/cuda/cuda_score_updater.hpp | 2 +- src/cuda/cuda_utils.cpp | 2 +- src/io/config.cpp | 4 - src/metric/cuda/cuda_binary_metric.hpp | 2 +- src/metric/cuda/cuda_pointwise_metric.hpp | 2 +- src/metric/cuda/cuda_regression_metric.hpp | 2 +- .../cuda/cuda_best_split_finder.cpp | 19 +- .../cuda/cuda_best_split_finder.cu | 434 +++++++++++ .../cuda/cuda_best_split_finder.hpp | 31 +- src/treelearner/cuda/cuda_data_partition.cpp | 6 + src/treelearner/cuda/cuda_data_partition.cu | 47 ++ src/treelearner/cuda/cuda_data_partition.hpp | 8 + .../cuda/cuda_gradient_discretizer.cu | 171 +++++ .../cuda/cuda_gradient_discretizer.hpp | 118 +++ .../cuda/cuda_histogram_constructor.cpp | 110 ++- .../cuda/cuda_histogram_constructor.cu | 692 +++++++++++++++--- .../cuda/cuda_histogram_constructor.hpp | 61 +- src/treelearner/cuda/cuda_leaf_splits.cpp | 56 +- src/treelearner/cuda/cuda_leaf_splits.cu | 132 +++- src/treelearner/cuda/cuda_leaf_splits.hpp | 33 +- .../cuda/cuda_single_gpu_tree_learner.cpp | 154 +++- .../cuda/cuda_single_gpu_tree_learner.cu | 38 +- .../cuda/cuda_single_gpu_tree_learner.hpp | 11 +- 33 files changed, 1912 insertions(+), 259 
deletions(-) rename include/LightGBM/cuda/{cuda_utils.h => cuda_utils.hu} (91%) create mode 100644 src/treelearner/cuda/cuda_gradient_discretizer.cu create mode 100644 src/treelearner/cuda/cuda_gradient_discretizer.hpp diff --git a/.ci/check_python_dists.sh b/.ci/check_python_dists.sh index 217fd3317..1dd19679d 100644 --- a/.ci/check_python_dists.sh +++ b/.ci/check_python_dists.sh @@ -25,7 +25,7 @@ if [ $PY_MINOR_VER -gt 7 ]; then pydistcheck \ --inspect \ --ignore 'compiled-objects-have-debug-symbols,distro-too-large-compressed' \ - --max-allowed-size-uncompressed '70M' \ + --max-allowed-size-uncompressed '100M' \ --max-allowed-files 800 \ ${DIST_DIR}/* || exit -1 elif { test $(uname -m) = "aarch64"; }; then diff --git a/include/LightGBM/cuda/cuda_algorithms.hpp b/include/LightGBM/cuda/cuda_algorithms.hpp index ab3328bb5..f79fc57e4 100644 --- a/include/LightGBM/cuda/cuda_algorithms.hpp +++ b/include/LightGBM/cuda/cuda_algorithms.hpp @@ -13,7 +13,7 @@ #include #include -#include +#include #include #include diff --git a/include/LightGBM/cuda/cuda_column_data.hpp b/include/LightGBM/cuda/cuda_column_data.hpp index 5b2301ac8..314a17885 100644 --- a/include/LightGBM/cuda/cuda_column_data.hpp +++ b/include/LightGBM/cuda/cuda_column_data.hpp @@ -9,7 +9,7 @@ #define LIGHTGBM_CUDA_CUDA_COLUMN_DATA_HPP_ #include -#include +#include #include #include diff --git a/include/LightGBM/cuda/cuda_metadata.hpp b/include/LightGBM/cuda/cuda_metadata.hpp index bc7339a84..5882ce7e0 100644 --- a/include/LightGBM/cuda/cuda_metadata.hpp +++ b/include/LightGBM/cuda/cuda_metadata.hpp @@ -8,7 +8,7 @@ #ifndef LIGHTGBM_CUDA_CUDA_METADATA_HPP_ #define LIGHTGBM_CUDA_CUDA_METADATA_HPP_ -#include +#include #include #include diff --git a/include/LightGBM/cuda/cuda_metric.hpp b/include/LightGBM/cuda/cuda_metric.hpp index 9186ceea1..2540b0c1a 100644 --- a/include/LightGBM/cuda/cuda_metric.hpp +++ b/include/LightGBM/cuda/cuda_metric.hpp @@ -9,7 +9,7 @@ #ifdef USE_CUDA -#include +#include #include namespace LightGBM { diff --git a/include/LightGBM/cuda/cuda_objective_function.hpp b/include/LightGBM/cuda/cuda_objective_function.hpp index fae8aa7ec..465ed3341 100644 --- a/include/LightGBM/cuda/cuda_objective_function.hpp +++ b/include/LightGBM/cuda/cuda_objective_function.hpp @@ -9,7 +9,7 @@ #ifdef USE_CUDA -#include +#include #include #include diff --git a/include/LightGBM/cuda/cuda_row_data.hpp b/include/LightGBM/cuda/cuda_row_data.hpp index 0386db0dc..1d4cb2f73 100644 --- a/include/LightGBM/cuda/cuda_row_data.hpp +++ b/include/LightGBM/cuda/cuda_row_data.hpp @@ -10,7 +10,7 @@ #include #include -#include +#include #include #include #include diff --git a/include/LightGBM/cuda/cuda_split_info.hpp b/include/LightGBM/cuda/cuda_split_info.hpp index 46b35ca37..f01ce2b02 100644 --- a/include/LightGBM/cuda/cuda_split_info.hpp +++ b/include/LightGBM/cuda/cuda_split_info.hpp @@ -24,12 +24,14 @@ class CUDASplitInfo { double left_sum_gradients; double left_sum_hessians; + int64_t left_sum_of_gradients_hessians; data_size_t left_count; double left_gain; double left_value; double right_sum_gradients; double right_sum_hessians; + int64_t right_sum_of_gradients_hessians; data_size_t right_count; double right_gain; double right_value; diff --git a/include/LightGBM/cuda/cuda_utils.h b/include/LightGBM/cuda/cuda_utils.hu similarity index 91% rename from include/LightGBM/cuda/cuda_utils.h rename to include/LightGBM/cuda/cuda_utils.hu index 953bf9f12..4bd84aeb2 100644 --- a/include/LightGBM/cuda/cuda_utils.h +++ b/include/LightGBM/cuda/cuda_utils.hu 
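The cuda_split_info.hpp hunk above adds left_sum_of_gradients_hessians and right_sum_of_gradients_hessians, which store the quantized gradient and hessian sums of a child packed into a single int64_t (gradient in the high 32 bits, hessian in the low 32 bits). As a minimal decoding sketch, not part of the patch and with an illustrative helper name, assuming grad_scale and hess_scale are the per-bin step sizes used elsewhere in the patch:

#include <cstdint>

// Hypothetical helper mirroring the bit masks used in
// FindBestSplitsDiscretizedForLeafKernelInner below: the high 32 bits of the
// packed value hold the (signed) quantized gradient sum, the low 32 bits the
// quantized hessian sum.
inline void DecodePackedSum(int64_t packed, double grad_scale, double hess_scale,
                            double* sum_gradients, double* sum_hessians) {
  const int32_t grad_int =
      static_cast<int32_t>((static_cast<uint64_t>(packed) & 0xffffffff00000000ULL) >> 32);
  const int32_t hess_int =
      static_cast<int32_t>(static_cast<uint64_t>(packed) & 0x00000000ffffffffULL);
  *sum_gradients = static_cast<double>(grad_int) * grad_scale;
  *sum_hessians = static_cast<double>(hess_int) * hess_scale;
}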
@@ -7,15 +7,21 @@ #define LIGHTGBM_CUDA_CUDA_UTILS_H_ #ifdef USE_CUDA + #include #include #include + #include + +#include #include #include namespace LightGBM { +typedef unsigned long long atomic_add_long_t; + #define CUDASUCCESS_OR_FATAL(ans) { gpuAssert((ans), __FILE__, __LINE__); } inline void gpuAssert(cudaError_t code, const char *file, int line, bool abort = true) { if (code != cudaSuccess) { @@ -125,13 +131,19 @@ class CUDAVector { T* new_data = nullptr; AllocateCUDAMemory(&new_data, size, __FILE__, __LINE__); if (size_ > 0 && data_ != nullptr) { - CopyFromCUDADeviceToCUDADevice(new_data, data_, size, __FILE__, __LINE__); + const size_t size_for_old_content = std::min(size_, size); + CopyFromCUDADeviceToCUDADevice(new_data, data_, size_for_old_content, __FILE__, __LINE__); } DeallocateCUDAMemory(&data_, __FILE__, __LINE__); data_ = new_data; size_ = size; } + void InitFromHostVector(const std::vector& host_vector) { + Resize(host_vector.size()); + CopyFromHostToCUDADevice(data_, host_vector.data(), host_vector.size(), __FILE__, __LINE__); + } + void Clear() { if (size_ > 0 && data_ != nullptr) { DeallocateCUDAMemory(&data_, __FILE__, __LINE__); @@ -171,6 +183,10 @@ class CUDAVector { return data_; } + void SetValue(int value) { + SetCUDAMemory(data_, value, size_, __FILE__, __LINE__); + } + const T* RawDataReadOnly() const { return data_; } diff --git a/include/LightGBM/sample_strategy.h b/include/LightGBM/sample_strategy.h index 51d3cbc16..4ea5cfc5f 100644 --- a/include/LightGBM/sample_strategy.h +++ b/include/LightGBM/sample_strategy.h @@ -6,7 +6,7 @@ #ifndef LIGHTGBM_SAMPLE_STRATEGY_H_ #define LIGHTGBM_SAMPLE_STRATEGY_H_ -#include +#include #include #include #include diff --git a/src/boosting/cuda/cuda_score_updater.hpp b/src/boosting/cuda/cuda_score_updater.hpp index ec728777e..cb79b43b9 100644 --- a/src/boosting/cuda/cuda_score_updater.hpp +++ b/src/boosting/cuda/cuda_score_updater.hpp @@ -8,7 +8,7 @@ #ifdef USE_CUDA -#include +#include #include "../score_updater.hpp" diff --git a/src/cuda/cuda_utils.cpp b/src/cuda/cuda_utils.cpp index a7d0df697..b601f9395 100644 --- a/src/cuda/cuda_utils.cpp +++ b/src/cuda/cuda_utils.cpp @@ -5,7 +5,7 @@ #ifdef USE_CUDA -#include +#include namespace LightGBM { diff --git a/src/io/config.cpp b/src/io/config.cpp index e85780469..e25bb6d4f 100644 --- a/src/io/config.cpp +++ b/src/io/config.cpp @@ -389,10 +389,6 @@ void Config::CheckParamConflict() { if (deterministic) { Log::Warning("Although \"deterministic\" is set, the results ran by GPU may be non-deterministic."); } - if (use_quantized_grad) { - Log::Warning("Quantized training is not supported by CUDA tree learner. 
Switch to full precision training."); - use_quantized_grad = false; - } } // linear tree learner must be serial type and run on CPU device if (linear_tree) { diff --git a/src/metric/cuda/cuda_binary_metric.hpp b/src/metric/cuda/cuda_binary_metric.hpp index 86dbdce98..0f61063c6 100644 --- a/src/metric/cuda/cuda_binary_metric.hpp +++ b/src/metric/cuda/cuda_binary_metric.hpp @@ -10,7 +10,7 @@ #ifdef USE_CUDA #include -#include +#include #include diff --git a/src/metric/cuda/cuda_pointwise_metric.hpp b/src/metric/cuda/cuda_pointwise_metric.hpp index dae1c6a7f..fafeafe63 100644 --- a/src/metric/cuda/cuda_pointwise_metric.hpp +++ b/src/metric/cuda/cuda_pointwise_metric.hpp @@ -10,7 +10,7 @@ #ifdef USE_CUDA #include -#include +#include #include diff --git a/src/metric/cuda/cuda_regression_metric.hpp b/src/metric/cuda/cuda_regression_metric.hpp index 4cfd996d8..e69bd2215 100644 --- a/src/metric/cuda/cuda_regression_metric.hpp +++ b/src/metric/cuda/cuda_regression_metric.hpp @@ -10,7 +10,7 @@ #ifdef USE_CUDA #include -#include +#include #include diff --git a/src/treelearner/cuda/cuda_best_split_finder.cpp b/src/treelearner/cuda/cuda_best_split_finder.cpp index 761b62f21..957585428 100644 --- a/src/treelearner/cuda/cuda_best_split_finder.cpp +++ b/src/treelearner/cuda/cuda_best_split_finder.cpp @@ -40,6 +40,9 @@ CUDABestSplitFinder::CUDABestSplitFinder( select_features_by_node_(select_features_by_node), cuda_hist_(cuda_hist) { InitFeatureMetaInfo(train_data); + if (has_categorical_feature_ && config->use_quantized_grad) { + Log::Fatal("Quantized training on GPU with categorical features is not supported yet."); + } cuda_leaf_best_split_info_ = nullptr; cuda_best_split_info_ = nullptr; cuda_best_split_info_buffer_ = nullptr; @@ -326,13 +329,23 @@ void CUDABestSplitFinder::FindBestSplitsForLeaf( const data_size_t num_data_in_smaller_leaf, const data_size_t num_data_in_larger_leaf, const double sum_hessians_in_smaller_leaf, - const double sum_hessians_in_larger_leaf) { + const double sum_hessians_in_larger_leaf, + const score_t* grad_scale, + const score_t* hess_scale, + const uint8_t smaller_num_bits_in_histogram_bins, + const uint8_t larger_num_bits_in_histogram_bins) { const bool is_smaller_leaf_valid = (num_data_in_smaller_leaf > min_data_in_leaf_ && sum_hessians_in_smaller_leaf > min_sum_hessian_in_leaf_); const bool is_larger_leaf_valid = (num_data_in_larger_leaf > min_data_in_leaf_ && sum_hessians_in_larger_leaf > min_sum_hessian_in_leaf_ && larger_leaf_index >= 0); - LaunchFindBestSplitsForLeafKernel(smaller_leaf_splits, larger_leaf_splits, - smaller_leaf_index, larger_leaf_index, is_smaller_leaf_valid, is_larger_leaf_valid); + if (grad_scale != nullptr && hess_scale != nullptr) { + LaunchFindBestSplitsDiscretizedForLeafKernel(smaller_leaf_splits, larger_leaf_splits, + smaller_leaf_index, larger_leaf_index, is_smaller_leaf_valid, is_larger_leaf_valid, + grad_scale, hess_scale, smaller_num_bits_in_histogram_bins, larger_num_bits_in_histogram_bins); + } else { + LaunchFindBestSplitsForLeafKernel(smaller_leaf_splits, larger_leaf_splits, + smaller_leaf_index, larger_leaf_index, is_smaller_leaf_valid, is_larger_leaf_valid); + } global_timer.Start("CUDABestSplitFinder::LaunchSyncBestSplitForLeafKernel"); LaunchSyncBestSplitForLeafKernel(smaller_leaf_index, larger_leaf_index, is_smaller_leaf_valid, is_larger_leaf_valid); SynchronizeCUDADevice(__FILE__, __LINE__); diff --git a/src/treelearner/cuda/cuda_best_split_finder.cu b/src/treelearner/cuda/cuda_best_split_finder.cu index 3fee55629..d5c819d39 
100644 --- a/src/treelearner/cuda/cuda_best_split_finder.cu +++ b/src/treelearner/cuda/cuda_best_split_finder.cu @@ -320,6 +320,175 @@ __device__ void FindBestSplitsForLeafKernelInner( } } +template +__device__ void FindBestSplitsDiscretizedForLeafKernelInner( + // input feature information + const BIN_HIST_TYPE* feature_hist_ptr, + // input task information + const SplitFindTask* task, + CUDARandom* cuda_random, + // input config parameter values + const double lambda_l1, + const double lambda_l2, + const double path_smooth, + const data_size_t min_data_in_leaf, + const double min_sum_hessian_in_leaf, + const double min_gain_to_split, + // input parent node information + const double parent_gain, + const int64_t sum_gradients_hessians, + const data_size_t num_data, + const double parent_output, + // gradient scale + const double grad_scale, + const double hess_scale, + // output parameters + CUDASplitInfo* cuda_best_split_info) { + const double sum_hessians = static_cast(sum_gradients_hessians & 0x00000000ffffffff) * hess_scale; + const double cnt_factor = num_data / sum_hessians; + const double min_gain_shift = parent_gain + min_gain_to_split; + + cuda_best_split_info->is_valid = false; + + ACC_HIST_TYPE local_grad_hess_hist = 0; + double local_gain = 0.0f; + bool threshold_found = false; + uint32_t threshold_value = 0; + __shared__ int rand_threshold; + if (USE_RAND && threadIdx.x == 0) { + if (task->num_bin - 2 > 0) { + rand_threshold = cuda_random->NextInt(0, task->num_bin - 2); + } + } + __shared__ uint32_t best_thread_index; + __shared__ double shared_double_buffer[32]; + __shared__ bool shared_bool_buffer[32]; + __shared__ uint32_t shared_int_buffer[64]; + const unsigned int threadIdx_x = threadIdx.x; + const bool skip_sum = REVERSE ? + (task->skip_default_bin && (task->num_bin - 1 - threadIdx_x) == static_cast(task->default_bin)) : + (task->skip_default_bin && (threadIdx_x + task->mfb_offset) == static_cast(task->default_bin)); + const uint32_t feature_num_bin_minus_offset = task->num_bin - task->mfb_offset; + if (!REVERSE) { + if (threadIdx_x < feature_num_bin_minus_offset && !skip_sum) { + const unsigned int bin_offset = threadIdx_x; + if (USE_16BIT_BIN_HIST && !USE_16BIT_ACC_HIST) { + const int32_t local_grad_hess_hist_int32 = feature_hist_ptr[bin_offset]; + local_grad_hess_hist = (static_cast(static_cast(local_grad_hess_hist_int32 >> 16)) << 32) | (static_cast(local_grad_hess_hist_int32 & 0x0000ffff)); + } else { + local_grad_hess_hist = feature_hist_ptr[bin_offset]; + } + } + } else { + if (threadIdx_x >= static_cast(task->na_as_missing) && + threadIdx_x < feature_num_bin_minus_offset && !skip_sum) { + const unsigned int read_index = feature_num_bin_minus_offset - 1 - threadIdx_x; + if (USE_16BIT_BIN_HIST && !USE_16BIT_ACC_HIST) { + const int32_t local_grad_hess_hist_int32 = feature_hist_ptr[read_index]; + local_grad_hess_hist = (static_cast(static_cast(local_grad_hess_hist_int32 >> 16)) << 32) | (static_cast(local_grad_hess_hist_int32 & 0x0000ffff)); + } else { + local_grad_hess_hist = feature_hist_ptr[read_index]; + } + } + } + __syncthreads(); + local_gain = kMinScore; + local_grad_hess_hist = ShufflePrefixSum(local_grad_hess_hist, reinterpret_cast(shared_int_buffer)); + double sum_left_gradient = 0.0f; + double sum_left_hessian = 0.0f; + double sum_right_gradient = 0.0f; + double sum_right_hessian = 0.0f; + data_size_t left_count = 0; + data_size_t right_count = 0; + int64_t sum_left_gradient_hessian = 0; + int64_t sum_right_gradient_hessian = 0; + if (REVERSE) { + if 
(threadIdx_x >= static_cast(task->na_as_missing) && threadIdx_x <= task->num_bin - 2 && !skip_sum) { + sum_right_gradient_hessian = USE_16BIT_ACC_HIST ? + (static_cast(static_cast(local_grad_hess_hist >> 16)) << 32) | static_cast(local_grad_hess_hist & 0x0000ffff) : + local_grad_hess_hist; + sum_right_gradient = static_cast(static_cast((sum_right_gradient_hessian & 0xffffffff00000000) >> 32)) * grad_scale; + sum_right_hessian = static_cast(static_cast(sum_right_gradient_hessian & 0x00000000ffffffff)) * hess_scale; + right_count = static_cast(__double2int_rn(sum_right_hessian * cnt_factor)); + sum_left_gradient_hessian = sum_gradients_hessians - sum_right_gradient_hessian; + sum_left_gradient = static_cast(static_cast((sum_left_gradient_hessian & 0xffffffff00000000)>> 32)) * grad_scale; + sum_left_hessian = static_cast(static_cast(sum_left_gradient_hessian & 0x00000000ffffffff)) * hess_scale; + left_count = num_data - right_count; + if (sum_left_hessian >= min_sum_hessian_in_leaf && left_count >= min_data_in_leaf && + sum_right_hessian >= min_sum_hessian_in_leaf && right_count >= min_data_in_leaf && + (!USE_RAND || static_cast(task->num_bin - 2 - threadIdx_x) == rand_threshold)) { + double current_gain = CUDALeafSplits::GetSplitGains( + sum_left_gradient, sum_left_hessian + kEpsilon, sum_right_gradient, + sum_right_hessian + kEpsilon, lambda_l1, + lambda_l2, path_smooth, left_count, right_count, parent_output); + // gain with split is worse than without split + if (current_gain > min_gain_shift) { + local_gain = current_gain - min_gain_shift; + threshold_value = static_cast(task->num_bin - 2 - threadIdx_x); + threshold_found = true; + } + } + } + } else { + if (threadIdx_x <= feature_num_bin_minus_offset - 2 && !skip_sum) { + sum_left_gradient_hessian = USE_16BIT_ACC_HIST ? 
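// When USE_16BIT_ACC_HIST is true, each packed accumulator entry holds the quantized
// gradient sum in its high 16 bits and the quantized hessian sum in its low 16 bits;
// the expression below sign-extends and widens that pair into the 64-bit packed layout
// (gradient in the high 32 bits, hessian in the low 32 bits) before the sums are
// rescaled with grad_scale and hess_scale.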
+ (static_cast(static_cast(local_grad_hess_hist >> 16)) << 32) | static_cast(local_grad_hess_hist & 0x0000ffff) : + local_grad_hess_hist; + sum_left_gradient = static_cast(static_cast((sum_left_gradient_hessian & 0xffffffff00000000) >> 32)) * grad_scale; + sum_left_hessian = static_cast(static_cast(sum_left_gradient_hessian & 0x00000000ffffffff)) * hess_scale; + left_count = static_cast(__double2int_rn(sum_left_hessian * cnt_factor)); + sum_right_gradient_hessian = sum_gradients_hessians - sum_left_gradient_hessian; + sum_right_gradient = static_cast(static_cast((sum_right_gradient_hessian & 0xffffffff00000000) >> 32)) * grad_scale; + sum_right_hessian = static_cast(static_cast(sum_right_gradient_hessian & 0x00000000ffffffff)) * hess_scale; + right_count = num_data - left_count; + if (sum_left_hessian >= min_sum_hessian_in_leaf && left_count >= min_data_in_leaf && + sum_right_hessian >= min_sum_hessian_in_leaf && right_count >= min_data_in_leaf && + (!USE_RAND || static_cast(threadIdx_x + task->mfb_offset) == rand_threshold)) { + double current_gain = CUDALeafSplits::GetSplitGains( + sum_left_gradient, sum_left_hessian + kEpsilon, sum_right_gradient, + sum_right_hessian + kEpsilon, lambda_l1, + lambda_l2, path_smooth, left_count, right_count, parent_output); + // gain with split is worse than without split + if (current_gain > min_gain_shift) { + local_gain = current_gain - min_gain_shift; + threshold_value = static_cast(threadIdx_x + task->mfb_offset); + threshold_found = true; + } + } + } + } + __syncthreads(); + const uint32_t result = ReduceBestGain(local_gain, threshold_found, threadIdx_x, shared_double_buffer, shared_bool_buffer, shared_int_buffer); + if (threadIdx_x == 0) { + best_thread_index = result; + } + __syncthreads(); + if (threshold_found && threadIdx_x == best_thread_index) { + cuda_best_split_info->is_valid = true; + cuda_best_split_info->threshold = threshold_value; + cuda_best_split_info->gain = local_gain; + cuda_best_split_info->default_left = task->assume_out_default_left; + const double left_output = CUDALeafSplits::CalculateSplittedLeafOutput(sum_left_gradient, + sum_left_hessian, lambda_l1, lambda_l2, path_smooth, left_count, parent_output); + const double right_output = CUDALeafSplits::CalculateSplittedLeafOutput(sum_right_gradient, + sum_right_hessian, lambda_l1, lambda_l2, path_smooth, right_count, parent_output); + cuda_best_split_info->left_sum_gradients = sum_left_gradient; + cuda_best_split_info->left_sum_hessians = sum_left_hessian; + cuda_best_split_info->left_sum_of_gradients_hessians = sum_left_gradient_hessian; + cuda_best_split_info->left_count = left_count; + cuda_best_split_info->right_sum_gradients = sum_right_gradient; + cuda_best_split_info->right_sum_hessians = sum_right_hessian; + cuda_best_split_info->right_sum_of_gradients_hessians = sum_right_gradient_hessian; + cuda_best_split_info->right_count = right_count; + cuda_best_split_info->left_value = left_output; + cuda_best_split_info->left_gain = CUDALeafSplits::GetLeafGainGivenOutput(sum_left_gradient, + sum_left_hessian, lambda_l1, lambda_l2, left_output); + cuda_best_split_info->right_value = right_output; + cuda_best_split_info->right_gain = CUDALeafSplits::GetLeafGainGivenOutput(sum_right_gradient, + sum_right_hessian, lambda_l1, lambda_l2, right_output); + } +} + template __device__ void FindBestSplitsForLeafKernelCategoricalInner( // input feature information @@ -715,6 +884,169 @@ __global__ void FindBestSplitsForLeafKernel( } } + +template +__global__ void 
FindBestSplitsDiscretizedForLeafKernel( + // input feature information + const int8_t* is_feature_used_bytree, + // input task information + const int num_tasks, + const SplitFindTask* tasks, + CUDARandom* cuda_randoms, + // input leaf information + const CUDALeafSplitsStruct* smaller_leaf_splits, + const CUDALeafSplitsStruct* larger_leaf_splits, + const uint8_t smaller_leaf_num_bits_in_histogram_bin, + const uint8_t larger_leaf_num_bits_in_histogram_bin, + // input config parameter values + const data_size_t min_data_in_leaf, + const double min_sum_hessian_in_leaf, + const double min_gain_to_split, + const double lambda_l1, + const double lambda_l2, + const double path_smooth, + const double cat_smooth, + const double cat_l2, + const int max_cat_threshold, + const int min_data_per_group, + const int max_cat_to_onehot, + // gradient scale + const score_t* grad_scale, + const score_t* hess_scale, + // output + CUDASplitInfo* cuda_best_split_info) { + const unsigned int task_index = blockIdx.x; + const SplitFindTask* task = tasks + task_index; + const int inner_feature_index = task->inner_feature_index; + const double parent_gain = IS_LARGER ? larger_leaf_splits->gain : smaller_leaf_splits->gain; + const int64_t sum_gradients_hessians = IS_LARGER ? larger_leaf_splits->sum_of_gradients_hessians : smaller_leaf_splits->sum_of_gradients_hessians; + const data_size_t num_data = IS_LARGER ? larger_leaf_splits->num_data_in_leaf : smaller_leaf_splits->num_data_in_leaf; + const double parent_output = IS_LARGER ? larger_leaf_splits->leaf_value : smaller_leaf_splits->leaf_value; + const unsigned int output_offset = IS_LARGER ? (task_index + num_tasks) : task_index; + CUDASplitInfo* out = cuda_best_split_info + output_offset; + CUDARandom* cuda_random = USE_RAND ? + (IS_LARGER ? cuda_randoms + task_index * 2 + 1 : cuda_randoms + task_index * 2) : nullptr; + const bool use_16bit_bin = IS_LARGER ? (larger_leaf_num_bits_in_histogram_bin <= 16) : (smaller_leaf_num_bits_in_histogram_bin <= 16); + if (is_feature_used_bytree[inner_feature_index]) { + if (task->is_categorical) { + __threadfence(); // ensure store issued before trap + asm("trap;"); + } else { + if (!task->reverse) { + if (use_16bit_bin) { + const int32_t* hist_ptr = + reinterpret_cast(IS_LARGER ? larger_leaf_splits->hist_in_leaf : smaller_leaf_splits->hist_in_leaf) + task->hist_offset; + FindBestSplitsDiscretizedForLeafKernelInner( + // input feature information + hist_ptr, + // input task information + task, + cuda_random, + // input config parameter values + lambda_l1, + lambda_l2, + path_smooth, + min_data_in_leaf, + min_sum_hessian_in_leaf, + min_gain_to_split, + // input parent node information + parent_gain, + sum_gradients_hessians, + num_data, + parent_output, + // gradient scale + *grad_scale, + *hess_scale, + // output parameters + out); + } else { + const int32_t* hist_ptr = + reinterpret_cast(IS_LARGER ? 
larger_leaf_splits->hist_in_leaf : smaller_leaf_splits->hist_in_leaf) + task->hist_offset; + FindBestSplitsDiscretizedForLeafKernelInner( + // input feature information + hist_ptr, + // input task information + task, + cuda_random, + // input config parameter values + lambda_l1, + lambda_l2, + path_smooth, + min_data_in_leaf, + min_sum_hessian_in_leaf, + min_gain_to_split, + // input parent node information + parent_gain, + sum_gradients_hessians, + num_data, + parent_output, + // gradient scale + *grad_scale, + *hess_scale, + // output parameters + out); + } + } else { + if (use_16bit_bin) { + const int32_t* hist_ptr = + reinterpret_cast(IS_LARGER ? larger_leaf_splits->hist_in_leaf : smaller_leaf_splits->hist_in_leaf) + task->hist_offset; + FindBestSplitsDiscretizedForLeafKernelInner( + // input feature information + hist_ptr, + // input task information + task, + cuda_random, + // input config parameter values + lambda_l1, + lambda_l2, + path_smooth, + min_data_in_leaf, + min_sum_hessian_in_leaf, + min_gain_to_split, + // input parent node information + parent_gain, + sum_gradients_hessians, + num_data, + parent_output, + // gradient scale + *grad_scale, + *hess_scale, + // output parameters + out); + } else { + const int32_t* hist_ptr = + reinterpret_cast(IS_LARGER ? larger_leaf_splits->hist_in_leaf : smaller_leaf_splits->hist_in_leaf) + task->hist_offset; + FindBestSplitsDiscretizedForLeafKernelInner( + // input feature information + hist_ptr, + // input task information + task, + cuda_random, + // input config parameter values + lambda_l1, + lambda_l2, + path_smooth, + min_data_in_leaf, + min_sum_hessian_in_leaf, + min_gain_to_split, + // input parent node information + parent_gain, + sum_gradients_hessians, + num_data, + parent_output, + // gradient scale + *grad_scale, + *hess_scale, + // output parameters + out); + } + } + } + } else { + out->is_valid = false; + } +} + template __device__ void FindBestSplitsForLeafKernelInner_GlobalMemory( // input feature information @@ -1466,6 +1798,108 @@ void CUDABestSplitFinder::LaunchFindBestSplitsForLeafKernelInner2(LaunchFindBest #undef FindBestSplitsForLeafKernel_ARGS #undef GlobalMemory_Buffer_ARGS + +#define LaunchFindBestSplitsDiscretizedForLeafKernel_PARAMS \ + const CUDALeafSplitsStruct* smaller_leaf_splits, \ + const CUDALeafSplitsStruct* larger_leaf_splits, \ + const int smaller_leaf_index, \ + const int larger_leaf_index, \ + const bool is_smaller_leaf_valid, \ + const bool is_larger_leaf_valid, \ + const score_t* grad_scale, \ + const score_t* hess_scale, \ + const uint8_t smaller_num_bits_in_histogram_bins, \ + const uint8_t larger_num_bits_in_histogram_bins + +#define LaunchFindBestSplitsDiscretizedForLeafKernel_ARGS \ + smaller_leaf_splits, \ + larger_leaf_splits, \ + smaller_leaf_index, \ + larger_leaf_index, \ + is_smaller_leaf_valid, \ + is_larger_leaf_valid, \ + grad_scale, \ + hess_scale, \ + smaller_num_bits_in_histogram_bins, \ + larger_num_bits_in_histogram_bins + +#define FindBestSplitsDiscretizedForLeafKernel_ARGS \ + cuda_is_feature_used_bytree_, \ + num_tasks_, \ + cuda_split_find_tasks_.RawData(), \ + cuda_randoms_.RawData(), \ + smaller_leaf_splits, \ + larger_leaf_splits, \ + smaller_num_bits_in_histogram_bins, \ + larger_num_bits_in_histogram_bins, \ + min_data_in_leaf_, \ + min_sum_hessian_in_leaf_, \ + min_gain_to_split_, \ + lambda_l1_, \ + lambda_l2_, \ + path_smooth_, \ + cat_smooth_, \ + cat_l2_, \ + max_cat_threshold_, \ + min_data_per_group_, \ + max_cat_to_onehot_, \ + grad_scale, \ + hess_scale, \ + 
cuda_best_split_info_ + +void CUDABestSplitFinder::LaunchFindBestSplitsDiscretizedForLeafKernel(LaunchFindBestSplitsDiscretizedForLeafKernel_PARAMS) { + if (!is_smaller_leaf_valid && !is_larger_leaf_valid) { + return; + } + if (!extra_trees_) { + LaunchFindBestSplitsDiscretizedForLeafKernelInner0(LaunchFindBestSplitsDiscretizedForLeafKernel_ARGS); + } else { + LaunchFindBestSplitsDiscretizedForLeafKernelInner0(LaunchFindBestSplitsDiscretizedForLeafKernel_ARGS); + } +} + +template +void CUDABestSplitFinder::LaunchFindBestSplitsDiscretizedForLeafKernelInner0(LaunchFindBestSplitsDiscretizedForLeafKernel_PARAMS) { + if (lambda_l1_ <= 0.0f) { + LaunchFindBestSplitsDiscretizedForLeafKernelInner1(LaunchFindBestSplitsDiscretizedForLeafKernel_ARGS); + } else { + LaunchFindBestSplitsDiscretizedForLeafKernelInner1(LaunchFindBestSplitsDiscretizedForLeafKernel_ARGS); + } +} + +template +void CUDABestSplitFinder::LaunchFindBestSplitsDiscretizedForLeafKernelInner1(LaunchFindBestSplitsDiscretizedForLeafKernel_PARAMS) { + if (!use_smoothing_) { + LaunchFindBestSplitsDiscretizedForLeafKernelInner2(LaunchFindBestSplitsDiscretizedForLeafKernel_ARGS); + } else { + LaunchFindBestSplitsDiscretizedForLeafKernelInner2(LaunchFindBestSplitsDiscretizedForLeafKernel_ARGS); + } +} + +template +void CUDABestSplitFinder::LaunchFindBestSplitsDiscretizedForLeafKernelInner2(LaunchFindBestSplitsDiscretizedForLeafKernel_PARAMS) { + if (!use_global_memory_) { + if (is_smaller_leaf_valid) { + FindBestSplitsDiscretizedForLeafKernel + <<>> + (FindBestSplitsDiscretizedForLeafKernel_ARGS); + } + SynchronizeCUDADevice(__FILE__, __LINE__); + if (is_larger_leaf_valid) { + FindBestSplitsDiscretizedForLeafKernel + <<>> + (FindBestSplitsDiscretizedForLeafKernel_ARGS); + } + } else { + // TODO(shiyu1994) + } +} + +#undef LaunchFindBestSplitsDiscretizedForLeafKernel_PARAMS +#undef LaunchFindBestSplitsDiscretizedForLeafKernel_ARGS +#undef FindBestSplitsDiscretizedForLeafKernel_ARGS + + __device__ void ReduceBestSplit(bool* found, double* gain, uint32_t* shared_read_index, uint32_t num_features_aligned) { const uint32_t threadIdx_x = threadIdx.x; diff --git a/src/treelearner/cuda/cuda_best_split_finder.hpp b/src/treelearner/cuda/cuda_best_split_finder.hpp index 69f8169f8..2d9940312 100644 --- a/src/treelearner/cuda/cuda_best_split_finder.hpp +++ b/src/treelearner/cuda/cuda_best_split_finder.hpp @@ -67,7 +67,11 @@ class CUDABestSplitFinder { const data_size_t num_data_in_smaller_leaf, const data_size_t num_data_in_larger_leaf, const double sum_hessians_in_smaller_leaf, - const double sum_hessians_in_larger_leaf); + const double sum_hessians_in_larger_leaf, + const score_t* grad_scale, + const score_t* hess_scale, + const uint8_t smaller_num_bits_in_histogram_bins, + const uint8_t larger_num_bits_in_histogram_bins); const CUDASplitInfo* FindBestFromAllSplits( const int cur_num_leaves, @@ -114,6 +118,31 @@ class CUDABestSplitFinder { #undef LaunchFindBestSplitsForLeafKernel_PARAMS + #define LaunchFindBestSplitsDiscretizedForLeafKernel_PARAMS \ + const CUDALeafSplitsStruct* smaller_leaf_splits, \ + const CUDALeafSplitsStruct* larger_leaf_splits, \ + const int smaller_leaf_index, \ + const int larger_leaf_index, \ + const bool is_smaller_leaf_valid, \ + const bool is_larger_leaf_valid, \ + const score_t* grad_scale, \ + const score_t* hess_scale, \ + const uint8_t smaller_num_bits_in_histogram_bins, \ + const uint8_t larger_num_bits_in_histogram_bins + + void 
LaunchFindBestSplitsDiscretizedForLeafKernel(LaunchFindBestSplitsDiscretizedForLeafKernel_PARAMS); + + template + void LaunchFindBestSplitsDiscretizedForLeafKernelInner0(LaunchFindBestSplitsDiscretizedForLeafKernel_PARAMS); + + template + void LaunchFindBestSplitsDiscretizedForLeafKernelInner1(LaunchFindBestSplitsDiscretizedForLeafKernel_PARAMS); + + template + void LaunchFindBestSplitsDiscretizedForLeafKernelInner2(LaunchFindBestSplitsDiscretizedForLeafKernel_PARAMS); + + #undef LaunchFindBestSplitsDiscretizedForLeafKernel_PARAMS + void LaunchSyncBestSplitForLeafKernel( const int host_smaller_leaf_index, const int host_larger_leaf_index, diff --git a/src/treelearner/cuda/cuda_data_partition.cpp b/src/treelearner/cuda/cuda_data_partition.cpp index 3ad157ef0..c09021ad3 100644 --- a/src/treelearner/cuda/cuda_data_partition.cpp +++ b/src/treelearner/cuda/cuda_data_partition.cpp @@ -368,6 +368,12 @@ void CUDADataPartition::ResetByLeafPred(const std::vector& leaf_pred, int n cur_num_leaves_ = num_leaves; } +void CUDADataPartition::ReduceLeafGradStat( + const score_t* gradients, const score_t* hessians, + CUDATree* tree, double* leaf_grad_stat_buffer, double* leaf_hess_state_buffer) const { + LaunchReduceLeafGradStat(gradients, hessians, tree, leaf_grad_stat_buffer, leaf_hess_state_buffer); +} + } // namespace LightGBM #endif // USE_CUDA diff --git a/src/treelearner/cuda/cuda_data_partition.cu b/src/treelearner/cuda/cuda_data_partition.cu index b1d3fa496..3090b7a84 100644 --- a/src/treelearner/cuda/cuda_data_partition.cu +++ b/src/treelearner/cuda/cuda_data_partition.cu @@ -1069,6 +1069,53 @@ void CUDADataPartition::LaunchAddPredictionToScoreKernel(const double* leaf_valu global_timer.Stop("CUDADataPartition::AddPredictionToScoreKernel"); } +__global__ void RenewDiscretizedTreeLeavesKernel( + const score_t* gradients, + const score_t* hessians, + const data_size_t* data_indices, + const data_size_t* leaf_data_start, + const data_size_t* leaf_num_data, + double* leaf_grad_stat_buffer, + double* leaf_hess_stat_buffer, + double* leaf_values) { + __shared__ double shared_mem_buffer[32]; + const int leaf_index = static_cast(blockIdx.x); + const data_size_t* data_indices_in_leaf = data_indices + leaf_data_start[leaf_index]; + const data_size_t num_data_in_leaf = leaf_num_data[leaf_index]; + double sum_gradients = 0.0f; + double sum_hessians = 0.0f; + for (data_size_t inner_data_index = static_cast(threadIdx.x); + inner_data_index < num_data_in_leaf; inner_data_index += static_cast(blockDim.x)) { + const data_size_t data_index = data_indices_in_leaf[inner_data_index]; + const score_t gradient = gradients[data_index]; + const score_t hessian = hessians[data_index]; + sum_gradients += static_cast(gradient); + sum_hessians += static_cast(hessian); + } + sum_gradients = ShuffleReduceSum(sum_gradients, shared_mem_buffer, blockDim.x); + __syncthreads(); + sum_hessians = ShuffleReduceSum(sum_hessians, shared_mem_buffer, blockDim.x); + if (threadIdx.x == 0) { + leaf_grad_stat_buffer[leaf_index] = sum_gradients; + leaf_hess_stat_buffer[leaf_index] = sum_hessians; + } +} + +void CUDADataPartition::LaunchReduceLeafGradStat( + const score_t* gradients, const score_t* hessians, + CUDATree* tree, double* leaf_grad_stat_buffer, double* leaf_hess_state_buffer) const { + const int num_blocks = tree->num_leaves(); + RenewDiscretizedTreeLeavesKernel<<>>( + gradients, + hessians, + cuda_data_indices_, + cuda_leaf_data_start_, + cuda_leaf_num_data_, + leaf_grad_stat_buffer, + leaf_hess_state_buffer, + 
tree->cuda_leaf_value_ref()); +} + } // namespace LightGBM #endif // USE_CUDA diff --git a/src/treelearner/cuda/cuda_data_partition.hpp b/src/treelearner/cuda/cuda_data_partition.hpp index 84050565c..f6bbab9b8 100644 --- a/src/treelearner/cuda/cuda_data_partition.hpp +++ b/src/treelearner/cuda/cuda_data_partition.hpp @@ -78,6 +78,10 @@ class CUDADataPartition { void ResetByLeafPred(const std::vector& leaf_pred, int num_leaves); + void ReduceLeafGradStat( + const score_t* gradients, const score_t* hessians, + CUDATree* tree, double* leaf_grad_stat_buffer, double* leaf_hess_state_buffer) const; + data_size_t root_num_data() const { if (use_bagging_) { return num_used_indices_; @@ -292,6 +296,10 @@ class CUDADataPartition { void LaunchFillDataIndexToLeafIndex(); + void LaunchReduceLeafGradStat( + const score_t* gradients, const score_t* hessians, + CUDATree* tree, double* leaf_grad_stat_buffer, double* leaf_hess_state_buffer) const; + // Host memory // dataset information diff --git a/src/treelearner/cuda/cuda_gradient_discretizer.cu b/src/treelearner/cuda/cuda_gradient_discretizer.cu new file mode 100644 index 000000000..bcea706b4 --- /dev/null +++ b/src/treelearner/cuda/cuda_gradient_discretizer.cu @@ -0,0 +1,171 @@ +/*! + * Copyright (c) 2021 Microsoft Corporation. All rights reserved. + * Licensed under the MIT License. See LICENSE file in the project root for + * license information. + */ + +#ifdef USE_CUDA + +#include + +#include + +#include "cuda_gradient_discretizer.hpp" + +namespace LightGBM { + +__global__ void ReduceMinMaxKernel( + const data_size_t num_data, + const score_t* input_gradients, + const score_t* input_hessians, + score_t* grad_min_block_buffer, + score_t* grad_max_block_buffer, + score_t* hess_min_block_buffer, + score_t* hess_max_block_buffer) { + __shared__ score_t shared_mem_buffer[32]; + const data_size_t index = static_cast(threadIdx.x + blockIdx.x * blockDim.x); + score_t grad_max_val = kMinScore; + score_t grad_min_val = kMaxScore; + score_t hess_max_val = kMinScore; + score_t hess_min_val = kMaxScore; + if (index < num_data) { + grad_max_val = input_gradients[index]; + grad_min_val = input_gradients[index]; + hess_max_val = input_hessians[index]; + hess_min_val = input_hessians[index]; + } + grad_min_val = ShuffleReduceMin(grad_min_val, shared_mem_buffer, blockDim.x); + __syncthreads(); + grad_max_val = ShuffleReduceMax(grad_max_val, shared_mem_buffer, blockDim.x); + __syncthreads(); + hess_min_val = ShuffleReduceMin(hess_min_val, shared_mem_buffer, blockDim.x); + __syncthreads(); + hess_max_val = ShuffleReduceMax(hess_max_val, shared_mem_buffer, blockDim.x); + if (threadIdx.x == 0) { + grad_min_block_buffer[blockIdx.x] = grad_min_val; + grad_max_block_buffer[blockIdx.x] = grad_max_val; + hess_min_block_buffer[blockIdx.x] = hess_min_val; + hess_max_block_buffer[blockIdx.x] = hess_max_val; + } +} + +__global__ void ReduceBlockMinMaxKernel( + const int num_blocks, + const int grad_discretize_bins, + score_t* grad_min_block_buffer, + score_t* grad_max_block_buffer, + score_t* hess_min_block_buffer, + score_t* hess_max_block_buffer) { + __shared__ score_t shared_mem_buffer[32]; + score_t grad_max_val = kMinScore; + score_t grad_min_val = kMaxScore; + score_t hess_max_val = kMinScore; + score_t hess_min_val = kMaxScore; + for (int block_index = static_cast(threadIdx.x); block_index < num_blocks; block_index += static_cast(blockDim.x)) { + grad_min_val = min(grad_min_val, grad_min_block_buffer[block_index]); + grad_max_val = max(grad_max_val, 
grad_max_block_buffer[block_index]); + hess_min_val = min(hess_min_val, hess_min_block_buffer[block_index]); + hess_max_val = max(hess_max_val, hess_max_block_buffer[block_index]); + } + grad_min_val = ShuffleReduceMin(grad_min_val, shared_mem_buffer, blockDim.x); + __syncthreads(); + grad_max_val = ShuffleReduceMax(grad_max_val, shared_mem_buffer, blockDim.x); + __syncthreads(); + hess_min_val = ShuffleReduceMin(hess_min_val, shared_mem_buffer, blockDim.x); + __syncthreads(); + hess_max_val = ShuffleReduceMax(hess_max_val, shared_mem_buffer, blockDim.x); + if (threadIdx.x == 0) { + const score_t grad_abs_max = max(fabs(grad_min_val), fabs(grad_max_val)); + const score_t hess_abs_max = max(fabs(hess_min_val), fabs(hess_max_val)); + grad_min_block_buffer[0] = 1.0f / (grad_abs_max / (grad_discretize_bins / 2)); + grad_max_block_buffer[0] = (grad_abs_max / (grad_discretize_bins / 2)); + hess_min_block_buffer[0] = 1.0f / (hess_abs_max / (grad_discretize_bins)); + hess_max_block_buffer[0] = (hess_abs_max / (grad_discretize_bins)); + } +} + +template +__global__ void DiscretizeGradientsKernel( + const data_size_t num_data, + const score_t* input_gradients, + const score_t* input_hessians, + const score_t* grad_scale_ptr, + const score_t* hess_scale_ptr, + const int iter, + const int* random_values_use_start, + const score_t* gradient_random_values, + const score_t* hessian_random_values, + const int grad_discretize_bins, + int8_t* output_gradients_and_hessians) { + const int start = random_values_use_start[iter]; + const data_size_t index = static_cast(threadIdx.x + blockIdx.x * blockDim.x); + const score_t grad_scale = *grad_scale_ptr; + const score_t hess_scale = *hess_scale_ptr; + int16_t* output_gradients_and_hessians_ptr = reinterpret_cast(output_gradients_and_hessians); + if (index < num_data) { + if (STOCHASTIC_ROUNDING) { + const data_size_t index_offset = (index + start) % num_data; + const score_t gradient = input_gradients[index]; + const score_t hessian = input_hessians[index]; + const score_t gradient_random_value = gradient_random_values[index_offset]; + const score_t hessian_random_value = hessian_random_values[index_offset]; + output_gradients_and_hessians_ptr[2 * index + 1] = gradient > 0.0f ? + static_cast(gradient * grad_scale + gradient_random_value) : + static_cast(gradient * grad_scale - gradient_random_value); + output_gradients_and_hessians_ptr[2 * index] = static_cast(hessian * hess_scale + hessian_random_value); + } else { + const score_t gradient = input_gradients[index]; + const score_t hessian = input_hessians[index]; + output_gradients_and_hessians_ptr[2 * index + 1] = gradient > 0.0f ?
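// Unlike the stochastic-rounding branch above, which adds a uniform random offset in
// [0, 1) so that the quantized gradient and hessian are unbiased in expectation, this
// branch rounds to the nearest integer by adding (for positive gradients) or
// subtracting (for negative gradients) a fixed 0.5 before the cast truncates.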
+ static_cast(gradient * grad_scale + 0.5) : + static_cast(gradient * grad_scale - 0.5); + output_gradients_and_hessians_ptr[2 * index] = static_cast(hessian * hess_scale + 0.5); + } + } +} + +void CUDAGradientDiscretizer::DiscretizeGradients( + const data_size_t num_data, + const score_t* input_gradients, + const score_t* input_hessians) { + ReduceMinMaxKernel<<>>( + num_data, input_gradients, input_hessians, + grad_min_block_buffer_.RawData(), + grad_max_block_buffer_.RawData(), + hess_min_block_buffer_.RawData(), + hess_max_block_buffer_.RawData()); + SynchronizeCUDADevice(__FILE__, __LINE__); + ReduceBlockMinMaxKernel<<<1, CUDA_GRADIENT_DISCRETIZER_BLOCK_SIZE>>>( + num_reduce_blocks_, + num_grad_quant_bins_, + grad_min_block_buffer_.RawData(), + grad_max_block_buffer_.RawData(), + hess_min_block_buffer_.RawData(), + hess_max_block_buffer_.RawData()); + SynchronizeCUDADevice(__FILE__, __LINE__); + + #define DiscretizeGradientsKernel_ARGS \ + num_data, \ + input_gradients, \ + input_hessians, \ + grad_min_block_buffer_.RawData(), \ + hess_min_block_buffer_.RawData(), \ + iter_, \ + random_values_use_start_.RawData(), \ + gradient_random_values_.RawData(), \ + hessian_random_values_.RawData(), \ + num_grad_quant_bins_, \ + discretized_gradients_and_hessians_.RawData() + + if (stochastic_rounding_) { + DiscretizeGradientsKernel<<>>(DiscretizeGradientsKernel_ARGS); + } else { + DiscretizeGradientsKernel<<>>(DiscretizeGradientsKernel_ARGS); + } + SynchronizeCUDADevice(__FILE__, __LINE__); + ++iter_; +} + +} // namespace LightGBM + +#endif // USE_CUDA diff --git a/src/treelearner/cuda/cuda_gradient_discretizer.hpp b/src/treelearner/cuda/cuda_gradient_discretizer.hpp new file mode 100644 index 000000000..d5c2fb0e0 --- /dev/null +++ b/src/treelearner/cuda/cuda_gradient_discretizer.hpp @@ -0,0 +1,118 @@ +/*! + * Copyright (c) 2021 Microsoft Corporation. All rights reserved. + * Licensed under the MIT License. See LICENSE file in the project root for + * license information. 
+ */ + +#ifndef LIGHTGBM_TREELEARNER_CUDA_CUDA_GRADIENT_DISCRETIZER_HPP_ +#define LIGHTGBM_TREELEARNER_CUDA_CUDA_GRADIENT_DISCRETIZER_HPP_ + +#ifdef USE_CUDA + +#include +#include +#include +#include + +#include +#include +#include + +#include "cuda_leaf_splits.hpp" +#include "../gradient_discretizer.hpp" + +namespace LightGBM { + +#define CUDA_GRADIENT_DISCRETIZER_BLOCK_SIZE (1024) + +class CUDAGradientDiscretizer: public GradientDiscretizer { + public: + CUDAGradientDiscretizer(int num_grad_quant_bins, int num_trees, int random_seed, bool is_constant_hessian, bool stochastic_roudning): + GradientDiscretizer(num_grad_quant_bins, num_trees, random_seed, is_constant_hessian, stochastic_roudning) { + } + + void DiscretizeGradients( + const data_size_t num_data, + const score_t* input_gradients, + const score_t* input_hessians) override; + + const int8_t* discretized_gradients_and_hessians() const override { return discretized_gradients_and_hessians_.RawData(); } + + double grad_scale() const override { + Log::Fatal("grad_scale() of CUDAGradientDiscretizer should not be called."); + return 0.0; + } + + double hess_scale() const override { + Log::Fatal("hess_scale() of CUDAGradientDiscretizer should not be called."); + return 0.0; + } + + const score_t* grad_scale_ptr() const { return grad_max_block_buffer_.RawData(); } + + const score_t* hess_scale_ptr() const { return hess_max_block_buffer_.RawData(); } + + void Init(const data_size_t num_data, const int num_leaves, + const int num_features, const Dataset* train_data) override { + GradientDiscretizer::Init(num_data, num_leaves, num_features, train_data); + discretized_gradients_and_hessians_.Resize(num_data * 2); + num_reduce_blocks_ = (num_data + CUDA_GRADIENT_DISCRETIZER_BLOCK_SIZE - 1) / CUDA_GRADIENT_DISCRETIZER_BLOCK_SIZE; + grad_min_block_buffer_.Resize(num_reduce_blocks_); + grad_max_block_buffer_.Resize(num_reduce_blocks_); + hess_min_block_buffer_.Resize(num_reduce_blocks_); + hess_max_block_buffer_.Resize(num_reduce_blocks_); + random_values_use_start_.Resize(num_trees_); + gradient_random_values_.Resize(num_data); + hessian_random_values_.Resize(num_data); + + std::vector gradient_random_values(num_data, 0.0f); + std::vector hessian_random_values(num_data, 0.0f); + std::vector random_values_use_start(num_trees_, 0); + + const int num_threads = OMP_NUM_THREADS(); + + std::mt19937 random_values_use_start_eng = std::mt19937(random_seed_); + std::uniform_int_distribution random_values_use_start_dist = std::uniform_int_distribution(0, num_data); + for (int tree_index = 0; tree_index < num_trees_; ++tree_index) { + random_values_use_start[tree_index] = random_values_use_start_dist(random_values_use_start_eng); + } + + int num_blocks = 0; + data_size_t block_size = 0; + Threading::BlockInfo(num_data, 512, &num_blocks, &block_size); + #pragma omp parallel for schedule(static, 1) num_threads(num_threads) + for (int thread_id = 0; thread_id < num_blocks; ++thread_id) { + const data_size_t start = thread_id * block_size; + const data_size_t end = std::min(start + block_size, num_data); + std::mt19937 gradient_random_values_eng(random_seed_ + thread_id); + std::uniform_real_distribution gradient_random_values_dist(0.0f, 1.0f); + std::mt19937 hessian_random_values_eng(random_seed_ + thread_id + num_threads); + std::uniform_real_distribution hessian_random_values_dist(0.0f, 1.0f); + for (data_size_t i = start; i < end; ++i) { + gradient_random_values[i] = gradient_random_values_dist(gradient_random_values_eng); + hessian_random_values[i] = 
hessian_random_values_dist(hessian_random_values_eng); + } + } + + CopyFromHostToCUDADevice(gradient_random_values_.RawData(), gradient_random_values.data(), gradient_random_values.size(), __FILE__, __LINE__); + CopyFromHostToCUDADevice(hessian_random_values_.RawData(), hessian_random_values.data(), hessian_random_values.size(), __FILE__, __LINE__); + CopyFromHostToCUDADevice(random_values_use_start_.RawData(), random_values_use_start.data(), random_values_use_start.size(), __FILE__, __LINE__); + iter_ = 0; + } + + protected: + mutable CUDAVector discretized_gradients_and_hessians_; + mutable CUDAVector grad_min_block_buffer_; + mutable CUDAVector grad_max_block_buffer_; + mutable CUDAVector hess_min_block_buffer_; + mutable CUDAVector hess_max_block_buffer_; + CUDAVector random_values_use_start_; + CUDAVector gradient_random_values_; + CUDAVector hessian_random_values_; + int num_reduce_blocks_; +}; + +} // namespace LightGBM + +#endif // USE_CUDA +#endif // LIGHTGBM_TREELEARNER_CUDA_CUDA_GRADIENT_DISCRETIZER_HPP_ diff --git a/src/treelearner/cuda/cuda_histogram_constructor.cpp b/src/treelearner/cuda/cuda_histogram_constructor.cpp index 7e6be1c10..659db2aad 100644 --- a/src/treelearner/cuda/cuda_histogram_constructor.cpp +++ b/src/treelearner/cuda/cuda_histogram_constructor.cpp @@ -20,7 +20,9 @@ CUDAHistogramConstructor::CUDAHistogramConstructor( const int min_data_in_leaf, const double min_sum_hessian_in_leaf, const int gpu_device_id, - const bool gpu_use_dp): + const bool gpu_use_dp, + const bool use_quantized_grad, + const int num_grad_quant_bins): num_data_(train_data->num_data()), num_features_(train_data->num_features()), num_leaves_(num_leaves), @@ -28,24 +30,14 @@ CUDAHistogramConstructor::CUDAHistogramConstructor( min_data_in_leaf_(min_data_in_leaf), min_sum_hessian_in_leaf_(min_sum_hessian_in_leaf), gpu_device_id_(gpu_device_id), - gpu_use_dp_(gpu_use_dp) { + gpu_use_dp_(gpu_use_dp), + use_quantized_grad_(use_quantized_grad), + num_grad_quant_bins_(num_grad_quant_bins) { InitFeatureMetaInfo(train_data, feature_hist_offsets); cuda_row_data_.reset(nullptr); - cuda_feature_num_bins_ = nullptr; - cuda_feature_hist_offsets_ = nullptr; - cuda_feature_most_freq_bins_ = nullptr; - cuda_hist_ = nullptr; - cuda_need_fix_histogram_features_ = nullptr; - cuda_need_fix_histogram_features_num_bin_aligned_ = nullptr; } CUDAHistogramConstructor::~CUDAHistogramConstructor() { - DeallocateCUDAMemory(&cuda_feature_num_bins_, __FILE__, __LINE__); - DeallocateCUDAMemory(&cuda_feature_hist_offsets_, __FILE__, __LINE__); - DeallocateCUDAMemory(&cuda_feature_most_freq_bins_, __FILE__, __LINE__); - DeallocateCUDAMemory(&cuda_hist_, __FILE__, __LINE__); - DeallocateCUDAMemory(&cuda_need_fix_histogram_features_, __FILE__, __LINE__); - DeallocateCUDAMemory(&cuda_need_fix_histogram_features_num_bin_aligned_, __FILE__, __LINE__); gpuAssert(cudaStreamDestroy(cuda_stream_), __FILE__, __LINE__); } @@ -84,54 +76,70 @@ void CUDAHistogramConstructor::InitFeatureMetaInfo(const Dataset* train_data, co void CUDAHistogramConstructor::BeforeTrain(const score_t* gradients, const score_t* hessians) { cuda_gradients_ = gradients; cuda_hessians_ = hessians; - SetCUDAMemory(cuda_hist_, 0, num_total_bin_ * 2 * num_leaves_, __FILE__, __LINE__); + cuda_hist_.SetValue(0); } void CUDAHistogramConstructor::Init(const Dataset* train_data, TrainingShareStates* share_state) { - AllocateCUDAMemory(&cuda_hist_, num_total_bin_ * 2 * num_leaves_, __FILE__, __LINE__); - SetCUDAMemory(cuda_hist_, 0, num_total_bin_ * 2 * num_leaves_, 
__FILE__, __LINE__); + cuda_hist_.Resize(static_cast(num_total_bin_ * 2 * num_leaves_)); + cuda_hist_.SetValue(0); - InitCUDAMemoryFromHostMemory(&cuda_feature_num_bins_, - feature_num_bins_.data(), feature_num_bins_.size(), __FILE__, __LINE__); - - InitCUDAMemoryFromHostMemory(&cuda_feature_hist_offsets_, - feature_hist_offsets_.data(), feature_hist_offsets_.size(), __FILE__, __LINE__); - - InitCUDAMemoryFromHostMemory(&cuda_feature_most_freq_bins_, - feature_most_freq_bins_.data(), feature_most_freq_bins_.size(), __FILE__, __LINE__); + cuda_feature_num_bins_.InitFromHostVector(feature_num_bins_); + cuda_feature_hist_offsets_.InitFromHostVector(feature_hist_offsets_); + cuda_feature_most_freq_bins_.InitFromHostVector(feature_most_freq_bins_); cuda_row_data_.reset(new CUDARowData(train_data, share_state, gpu_device_id_, gpu_use_dp_)); cuda_row_data_->Init(train_data, share_state); CUDASUCCESS_OR_FATAL(cudaStreamCreate(&cuda_stream_)); - InitCUDAMemoryFromHostMemory(&cuda_need_fix_histogram_features_, need_fix_histogram_features_.data(), need_fix_histogram_features_.size(), __FILE__, __LINE__); - InitCUDAMemoryFromHostMemory(&cuda_need_fix_histogram_features_num_bin_aligned_, need_fix_histogram_features_num_bin_aligend_.data(), - need_fix_histogram_features_num_bin_aligend_.size(), __FILE__, __LINE__); + cuda_need_fix_histogram_features_.InitFromHostVector(need_fix_histogram_features_); + cuda_need_fix_histogram_features_num_bin_aligned_.InitFromHostVector(need_fix_histogram_features_num_bin_aligend_); if (cuda_row_data_->NumLargeBinPartition() > 0) { int grid_dim_x = 0, grid_dim_y = 0, block_dim_x = 0, block_dim_y = 0; CalcConstructHistogramKernelDim(&grid_dim_x, &grid_dim_y, &block_dim_x, &block_dim_y, num_data_); - const size_t buffer_size = static_cast(grid_dim_y) * static_cast(num_total_bin_) * 2; - AllocateCUDAMemory(&cuda_hist_buffer_, buffer_size, __FILE__, __LINE__); + const size_t buffer_size = static_cast(grid_dim_y) * static_cast(num_total_bin_); + if (!use_quantized_grad_) { + if (gpu_use_dp_) { + // need to double the size of histogram buffer in global memory when using double precision in histogram construction + cuda_hist_buffer_.Resize(buffer_size * 4); + } else { + cuda_hist_buffer_.Resize(buffer_size * 2); + } + } else { + // use only half the size of histogram buffer in global memory when quantized training since each gradient and hessian takes only 2 bytes + cuda_hist_buffer_.Resize(buffer_size); + } } + hist_buffer_for_num_bit_change_.Resize(num_total_bin_ * 2); } void CUDAHistogramConstructor::ConstructHistogramForLeaf( const CUDALeafSplitsStruct* cuda_smaller_leaf_splits, - const CUDALeafSplitsStruct* cuda_larger_leaf_splits, + const CUDALeafSplitsStruct* /*cuda_larger_leaf_splits*/, const data_size_t num_data_in_smaller_leaf, const data_size_t num_data_in_larger_leaf, const double sum_hessians_in_smaller_leaf, - const double sum_hessians_in_larger_leaf) { + const double sum_hessians_in_larger_leaf, + const uint8_t num_bits_in_histogram_bins) { if ((num_data_in_smaller_leaf <= min_data_in_leaf_ || sum_hessians_in_smaller_leaf <= min_sum_hessian_in_leaf_) && (num_data_in_larger_leaf <= min_data_in_leaf_ || sum_hessians_in_larger_leaf <= min_sum_hessian_in_leaf_)) { return; } - LaunchConstructHistogramKernel(cuda_smaller_leaf_splits, num_data_in_smaller_leaf); + LaunchConstructHistogramKernel(cuda_smaller_leaf_splits, num_data_in_smaller_leaf, num_bits_in_histogram_bins); SynchronizeCUDADevice(__FILE__, __LINE__); +} + +void 
CUDAHistogramConstructor::SubtractHistogramForLeaf( + const CUDALeafSplitsStruct* cuda_smaller_leaf_splits, + const CUDALeafSplitsStruct* cuda_larger_leaf_splits, + const bool use_quantized_grad, + const uint8_t parent_num_bits_in_histogram_bins, + const uint8_t smaller_num_bits_in_histogram_bins, + const uint8_t larger_num_bits_in_histogram_bins) { global_timer.Start("CUDAHistogramConstructor::ConstructHistogramForLeaf::LaunchSubtractHistogramKernel"); - LaunchSubtractHistogramKernel(cuda_smaller_leaf_splits, cuda_larger_leaf_splits); + LaunchSubtractHistogramKernel(cuda_smaller_leaf_splits, cuda_larger_leaf_splits, use_quantized_grad, + parent_num_bits_in_histogram_bins, smaller_num_bits_in_histogram_bins, larger_num_bits_in_histogram_bins); global_timer.Stop("CUDAHistogramConstructor::ConstructHistogramForLeaf::LaunchSubtractHistogramKernel"); } @@ -152,33 +160,18 @@ void CUDAHistogramConstructor::ResetTrainingData(const Dataset* train_data, Trai num_data_ = train_data->num_data(); num_features_ = train_data->num_features(); InitFeatureMetaInfo(train_data, share_states->feature_hist_offsets()); - if (feature_num_bins_.size() > 0) { - DeallocateCUDAMemory(&cuda_feature_num_bins_, __FILE__, __LINE__); - DeallocateCUDAMemory(&cuda_feature_hist_offsets_, __FILE__, __LINE__); - DeallocateCUDAMemory(&cuda_feature_most_freq_bins_, __FILE__, __LINE__); - DeallocateCUDAMemory(&cuda_need_fix_histogram_features_, __FILE__, __LINE__); - DeallocateCUDAMemory(&cuda_need_fix_histogram_features_num_bin_aligned_, __FILE__, __LINE__); - DeallocateCUDAMemory(&cuda_hist_, __FILE__, __LINE__); - } - AllocateCUDAMemory(&cuda_hist_, num_total_bin_ * 2 * num_leaves_, __FILE__, __LINE__); - SetCUDAMemory(cuda_hist_, 0, num_total_bin_ * 2 * num_leaves_, __FILE__, __LINE__); - - InitCUDAMemoryFromHostMemory(&cuda_feature_num_bins_, - feature_num_bins_.data(), feature_num_bins_.size(), __FILE__, __LINE__); - - InitCUDAMemoryFromHostMemory(&cuda_feature_hist_offsets_, - feature_hist_offsets_.data(), feature_hist_offsets_.size(), __FILE__, __LINE__); - - InitCUDAMemoryFromHostMemory(&cuda_feature_most_freq_bins_, - feature_most_freq_bins_.data(), feature_most_freq_bins_.size(), __FILE__, __LINE__); + cuda_hist_.Resize(static_cast(num_total_bin_ * 2 * num_leaves_)); + cuda_hist_.SetValue(0); + cuda_feature_num_bins_.InitFromHostVector(feature_num_bins_); + cuda_feature_hist_offsets_.InitFromHostVector(feature_hist_offsets_); + cuda_feature_most_freq_bins_.InitFromHostVector(feature_most_freq_bins_); cuda_row_data_.reset(new CUDARowData(train_data, share_states, gpu_device_id_, gpu_use_dp_)); cuda_row_data_->Init(train_data, share_states); - InitCUDAMemoryFromHostMemory(&cuda_need_fix_histogram_features_, need_fix_histogram_features_.data(), need_fix_histogram_features_.size(), __FILE__, __LINE__); - InitCUDAMemoryFromHostMemory(&cuda_need_fix_histogram_features_num_bin_aligned_, need_fix_histogram_features_num_bin_aligend_.data(), - need_fix_histogram_features_num_bin_aligend_.size(), __FILE__, __LINE__); + cuda_need_fix_histogram_features_.InitFromHostVector(need_fix_histogram_features_); + cuda_need_fix_histogram_features_num_bin_aligned_.InitFromHostVector(need_fix_histogram_features_num_bin_aligend_); } void CUDAHistogramConstructor::ResetConfig(const Config* config) { @@ -186,9 +179,8 @@ void CUDAHistogramConstructor::ResetConfig(const Config* config) { num_leaves_ = config->num_leaves; min_data_in_leaf_ = config->min_data_in_leaf; min_sum_hessian_in_leaf_ = config->min_sum_hessian_in_leaf; - 
DeallocateCUDAMemory(&cuda_hist_, __FILE__, __LINE__); - AllocateCUDAMemory(&cuda_hist_, num_total_bin_ * 2 * num_leaves_, __FILE__, __LINE__); - SetCUDAMemory(cuda_hist_, 0, num_total_bin_ * 2 * num_leaves_, __FILE__, __LINE__); + cuda_hist_.Resize(static_cast(num_total_bin_ * 2 * num_leaves_)); + cuda_hist_.SetValue(0); } } // namespace LightGBM diff --git a/src/treelearner/cuda/cuda_histogram_constructor.cu b/src/treelearner/cuda/cuda_histogram_constructor.cu index c88438330..03d3b8979 100644 --- a/src/treelearner/cuda/cuda_histogram_constructor.cu +++ b/src/treelearner/cuda/cuda_histogram_constructor.cu @@ -125,7 +125,7 @@ __global__ void CUDAConstructHistogramSparseKernel( } } -template +template __global__ void CUDAConstructHistogramDenseKernel_GlobalMemory( const CUDALeafSplitsStruct* smaller_leaf_splits, const score_t* cuda_gradients, @@ -135,7 +135,7 @@ __global__ void CUDAConstructHistogramDenseKernel_GlobalMemory( const uint32_t* column_hist_offsets_full, const int* feature_partition_column_index_offsets, const data_size_t num_data, - float* global_hist_buffer) { + HIST_TYPE* global_hist_buffer) { const int dim_y = static_cast(gridDim.y * blockDim.y); const data_size_t num_data_in_smaller_leaf = smaller_leaf_splits->num_data_in_leaf; const data_size_t num_data_per_thread = (num_data_in_smaller_leaf + dim_y - 1) / dim_y; @@ -150,7 +150,7 @@ __global__ void CUDAConstructHistogramDenseKernel_GlobalMemory( const uint32_t num_items_in_partition = (partition_hist_end - partition_hist_start) << 1; const unsigned int thread_idx = threadIdx.x + threadIdx.y * blockDim.x; const int num_total_bin = column_hist_offsets_full[gridDim.x]; - float* shared_hist = global_hist_buffer + (blockIdx.y * num_total_bin + partition_hist_start) * 2; + HIST_TYPE* shared_hist = global_hist_buffer + (blockIdx.y * num_total_bin + partition_hist_start) * 2; for (unsigned int i = thread_idx; i < num_items_in_partition; i += num_threads_per_block) { shared_hist[i] = 0.0f; } @@ -166,14 +166,14 @@ __global__ void CUDAConstructHistogramDenseKernel_GlobalMemory( data_size_t inner_data_index = static_cast(threadIdx_y); const int column_index = static_cast(threadIdx.x) + partition_column_start; if (threadIdx.x < static_cast(num_columns_in_partition)) { - float* shared_hist_ptr = shared_hist + (column_hist_offsets[column_index] << 1); + HIST_TYPE* shared_hist_ptr = shared_hist + (column_hist_offsets[column_index] << 1); for (data_size_t i = 0; i < num_iteration_this; ++i) { const data_size_t data_index = data_indices_ref_this_block[inner_data_index]; const score_t grad = cuda_gradients[data_index]; const score_t hess = cuda_hessians[data_index]; const uint32_t bin = static_cast(data_ptr[static_cast(data_index) * num_columns_in_partition + threadIdx.x]); const uint32_t pos = bin << 1; - float* pos_ptr = shared_hist_ptr + pos; + HIST_TYPE* pos_ptr = shared_hist_ptr + pos; atomicAdd_block(pos_ptr, grad); atomicAdd_block(pos_ptr + 1, hess); inner_data_index += blockDim.y; @@ -186,7 +186,7 @@ __global__ void CUDAConstructHistogramDenseKernel_GlobalMemory( } } -template +template __global__ void CUDAConstructHistogramSparseKernel_GlobalMemory( const CUDALeafSplitsStruct* smaller_leaf_splits, const score_t* cuda_gradients, @@ -196,7 +196,7 @@ __global__ void CUDAConstructHistogramSparseKernel_GlobalMemory( const DATA_PTR_TYPE* partition_ptr, const uint32_t* column_hist_offsets_full, const data_size_t num_data, - float* global_hist_buffer) { + HIST_TYPE* global_hist_buffer) { const int dim_y = static_cast(gridDim.y * 
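The refactor in this constructor replaces raw device pointers and the Allocate/Init/Set/Deallocate call pairs with CUDAVector members. A minimal stand-in for the interface the patch relies on (Resize, SetValue, InitFromHostVector, RawData) is sketched below; it is an illustrative approximation, not LightGBM's CUDAVector: error checking is omitted and Resize does not preserve existing contents.

    #include <cuda_runtime.h>
    #include <cstddef>
    #include <vector>

    template <typename T>
    class DeviceVector {
     public:
      ~DeviceVector() { if (data_ != nullptr) cudaFree(data_); }
      void Resize(size_t size) {
        if (data_ != nullptr) cudaFree(data_);
        cudaMalloc(reinterpret_cast<void**>(&data_), size * sizeof(T));
        size_ = size;
      }
      // Byte-fills the buffer; meaningful here only for value == 0.
      void SetValue(int value) { cudaMemset(data_, value, size_ * sizeof(T)); }
      void InitFromHostVector(const std::vector<T>& host) {
        Resize(host.size());
        cudaMemcpy(data_, host.data(), host.size() * sizeof(T), cudaMemcpyHostToDevice);
      }
      T* RawData() { return data_; }
      const T* RawDataReadOnly() const { return data_; }

     private:
      T* data_ = nullptr;
      size_t size_ = 0;
    };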
blockDim.y); const data_size_t num_data_in_smaller_leaf = smaller_leaf_splits->num_data_in_leaf; const data_size_t num_data_per_thread = (num_data_in_smaller_leaf + dim_y - 1) / dim_y; @@ -209,7 +209,7 @@ __global__ void CUDAConstructHistogramSparseKernel_GlobalMemory( const uint32_t num_items_in_partition = (partition_hist_end - partition_hist_start) << 1; const unsigned int thread_idx = threadIdx.x + threadIdx.y * blockDim.x; const int num_total_bin = column_hist_offsets_full[gridDim.x]; - float* shared_hist = global_hist_buffer + (blockIdx.y * num_total_bin + partition_hist_start) * 2; + HIST_TYPE* shared_hist = global_hist_buffer + (blockIdx.y * num_total_bin + partition_hist_start) * 2; for (unsigned int i = thread_idx; i < num_items_in_partition; i += num_threads_per_block) { shared_hist[i] = 0.0f; } @@ -233,7 +233,7 @@ __global__ void CUDAConstructHistogramSparseKernel_GlobalMemory( const score_t hess = cuda_hessians[data_index]; const uint32_t bin = static_cast(data_ptr[row_start + threadIdx.x]); const uint32_t pos = bin << 1; - float* pos_ptr = shared_hist + pos; + HIST_TYPE* pos_ptr = shared_hist + pos; atomicAdd_block(pos_ptr, grad); atomicAdd_block(pos_ptr + 1, hess); } @@ -246,13 +246,278 @@ __global__ void CUDAConstructHistogramSparseKernel_GlobalMemory( } } +template +__global__ void CUDAConstructDiscretizedHistogramDenseKernel( + const CUDALeafSplitsStruct* smaller_leaf_splits, + const int32_t* cuda_gradients_and_hessians, + const BIN_TYPE* data, + const uint32_t* column_hist_offsets, + const uint32_t* column_hist_offsets_full, + const int* feature_partition_column_index_offsets, + const data_size_t num_data) { + const int dim_y = static_cast(gridDim.y * blockDim.y); + const data_size_t num_data_in_smaller_leaf = smaller_leaf_splits->num_data_in_leaf; + const data_size_t num_data_per_thread = (num_data_in_smaller_leaf + dim_y - 1) / dim_y; + const data_size_t* data_indices_ref = smaller_leaf_splits->data_indices_in_leaf; + __shared__ int16_t shared_hist[SHARED_HIST_SIZE]; + int32_t* shared_hist_packed = reinterpret_cast(shared_hist); + const unsigned int num_threads_per_block = blockDim.x * blockDim.y; + const int partition_column_start = feature_partition_column_index_offsets[blockIdx.x]; + const int partition_column_end = feature_partition_column_index_offsets[blockIdx.x + 1]; + const BIN_TYPE* data_ptr = data + partition_column_start * num_data; + const int num_columns_in_partition = partition_column_end - partition_column_start; + const uint32_t partition_hist_start = column_hist_offsets_full[blockIdx.x]; + const uint32_t partition_hist_end = column_hist_offsets_full[blockIdx.x + 1]; + const uint32_t num_items_in_partition = (partition_hist_end - partition_hist_start); + const unsigned int thread_idx = threadIdx.x + threadIdx.y * blockDim.x; + for (unsigned int i = thread_idx; i < num_items_in_partition; i += num_threads_per_block) { + shared_hist_packed[i] = 0; + } + __syncthreads(); + const unsigned int threadIdx_y = threadIdx.y; + const unsigned int blockIdx_y = blockIdx.y; + const data_size_t block_start = (blockIdx_y * blockDim.y) * num_data_per_thread; + const data_size_t* data_indices_ref_this_block = data_indices_ref + block_start; + data_size_t block_num_data = max(0, min(num_data_in_smaller_leaf - block_start, num_data_per_thread * static_cast(blockDim.y))); + const data_size_t num_iteration_total = (block_num_data + blockDim.y - 1) / blockDim.y; + const data_size_t remainder = block_num_data % blockDim.y; + const data_size_t num_iteration_this = remainder == 
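The discretized kernels accumulate a quantized (gradient, hessian) pair with a single 32-bit atomic add: the signed gradient sits in the high 16 bits of each entry and the non-negative hessian count in the low 16 bits. The standalone sketch below shows the encoding and why one addition updates both fields at once; PackGradHess and AccumulatePacked are hypothetical helpers (the actual packing happens in the gradient discretizer), and the scheme is only valid while neither 16-bit field overflows, which is what the num_bits_in_histogram_bins bookkeeping guards against.

    #include <cstdint>

    // Gradient in the high 16 bits, hessian in the low 16 bits of one entry.
    int32_t PackGradHess(int16_t grad, uint16_t hess) {
      return (static_cast<int32_t>(grad) << 16) | static_cast<int32_t>(hess);
    }

    // One plain addition accumulates both fields, mirroring the single
    // atomicAdd_block issued per (data point, feature) in the kernels.
    int32_t AccumulatePacked(int32_t acc, int32_t packed_grad_hess) {
      return acc + packed_grad_hess;
    }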
0 ? num_iteration_total : num_iteration_total - static_cast(threadIdx_y >= remainder); + data_size_t inner_data_index = static_cast(threadIdx_y); + const int column_index = static_cast(threadIdx.x) + partition_column_start; + if (threadIdx.x < static_cast(num_columns_in_partition)) { + int32_t* shared_hist_ptr = shared_hist_packed + (column_hist_offsets[column_index]); + for (data_size_t i = 0; i < num_iteration_this; ++i) { + const data_size_t data_index = data_indices_ref_this_block[inner_data_index]; + const int32_t grad_and_hess = cuda_gradients_and_hessians[data_index]; + const uint32_t bin = static_cast(data_ptr[data_index * num_columns_in_partition + threadIdx.x]); + int32_t* pos_ptr = shared_hist_ptr + bin; + atomicAdd_block(pos_ptr, grad_and_hess); + inner_data_index += blockDim.y; + } + } + __syncthreads(); + if (USE_16BIT_HIST) { + int32_t* feature_histogram_ptr = reinterpret_cast(smaller_leaf_splits->hist_in_leaf) + partition_hist_start; + for (unsigned int i = thread_idx; i < num_items_in_partition; i += num_threads_per_block) { + const int32_t packed_grad_hess = shared_hist_packed[i]; + atomicAdd_system(feature_histogram_ptr + i, packed_grad_hess); + } + } else { + atomic_add_long_t* feature_histogram_ptr = reinterpret_cast(smaller_leaf_splits->hist_in_leaf) + partition_hist_start; + for (unsigned int i = thread_idx; i < num_items_in_partition; i += num_threads_per_block) { + const int32_t packed_grad_hess = shared_hist_packed[i]; + const int64_t packed_grad_hess_int64 = (static_cast(static_cast(packed_grad_hess >> 16)) << 32) | (static_cast(packed_grad_hess & 0x0000ffff)); + atomicAdd_system(feature_histogram_ptr + i, (atomic_add_long_t)(packed_grad_hess_int64)); + } + } +} + +template +__global__ void CUDAConstructDiscretizedHistogramSparseKernel( + const CUDALeafSplitsStruct* smaller_leaf_splits, + const int32_t* cuda_gradients_and_hessians, + const BIN_TYPE* data, + const DATA_PTR_TYPE* row_ptr, + const DATA_PTR_TYPE* partition_ptr, + const uint32_t* column_hist_offsets_full, + const data_size_t num_data) { + const int dim_y = static_cast(gridDim.y * blockDim.y); + const data_size_t num_data_in_smaller_leaf = smaller_leaf_splits->num_data_in_leaf; + const data_size_t num_data_per_thread = (num_data_in_smaller_leaf + dim_y - 1) / dim_y; + const data_size_t* data_indices_ref = smaller_leaf_splits->data_indices_in_leaf; + __shared__ int16_t shared_hist[SHARED_HIST_SIZE]; + int32_t* shared_hist_packed = reinterpret_cast(shared_hist); + const unsigned int num_threads_per_block = blockDim.x * blockDim.y; + const DATA_PTR_TYPE* block_row_ptr = row_ptr + blockIdx.x * (num_data + 1); + const BIN_TYPE* data_ptr = data + partition_ptr[blockIdx.x]; + const uint32_t partition_hist_start = column_hist_offsets_full[blockIdx.x]; + const uint32_t partition_hist_end = column_hist_offsets_full[blockIdx.x + 1]; + const uint32_t num_items_in_partition = (partition_hist_end - partition_hist_start); + const unsigned int thread_idx = threadIdx.x + threadIdx.y * blockDim.x; + for (unsigned int i = thread_idx; i < num_items_in_partition; i += num_threads_per_block) { + shared_hist_packed[i] = 0.0f; + } + __syncthreads(); + const unsigned int threadIdx_y = threadIdx.y; + const unsigned int blockIdx_y = blockIdx.y; + const data_size_t block_start = (blockIdx_y * blockDim.y) * num_data_per_thread; + const data_size_t* data_indices_ref_this_block = data_indices_ref + block_start; + data_size_t block_num_data = max(0, min(num_data_in_smaller_leaf - block_start, num_data_per_thread * 
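When the per-leaf histogram uses 64-bit packed bins, each 32-bit packed block result is widened before the atomicAdd_system into global memory: the gradient half is sign-extended into the high 32 bits and the hessian half is zero-extended into the low 32 bits. The host-side function below restates that expression; WidenPackedGradHess is a hypothetical name.

    #include <cstdint>

    // Widen one packed 32-bit histogram entry (high 16: signed gradient,
    // low 16: non-negative hessian) into the packed 64-bit per-leaf layout
    // (high 32: gradient, low 32: hessian).
    int64_t WidenPackedGradHess(int32_t packed_grad_hess) {
      const int64_t grad =
          static_cast<int64_t>(static_cast<int16_t>(packed_grad_hess >> 16));
      const int64_t hess = static_cast<int64_t>(packed_grad_hess & 0x0000ffff);
      return (grad << 32) | hess;
    }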
static_cast(blockDim.y))); + const data_size_t num_iteration_total = (block_num_data + blockDim.y - 1) / blockDim.y; + const data_size_t remainder = block_num_data % blockDim.y; + const data_size_t num_iteration_this = remainder == 0 ? num_iteration_total : num_iteration_total - static_cast(threadIdx_y >= remainder); + data_size_t inner_data_index = static_cast(threadIdx_y); + for (data_size_t i = 0; i < num_iteration_this; ++i) { + const data_size_t data_index = data_indices_ref_this_block[inner_data_index]; + const DATA_PTR_TYPE row_start = block_row_ptr[data_index]; + const DATA_PTR_TYPE row_end = block_row_ptr[data_index + 1]; + const DATA_PTR_TYPE row_size = row_end - row_start; + if (threadIdx.x < row_size) { + const int32_t grad_and_hess = cuda_gradients_and_hessians[data_index]; + const uint32_t bin = static_cast(data_ptr[row_start + threadIdx.x]); + int32_t* pos_ptr = shared_hist_packed + bin; + atomicAdd_block(pos_ptr, grad_and_hess); + } + inner_data_index += blockDim.y; + } + __syncthreads(); + if (USE_16BIT_HIST) { + int32_t* feature_histogram_ptr = reinterpret_cast(smaller_leaf_splits->hist_in_leaf) + partition_hist_start; + for (unsigned int i = thread_idx; i < num_items_in_partition; i += num_threads_per_block) { + const int32_t packed_grad_hess = shared_hist_packed[i]; + atomicAdd_system(feature_histogram_ptr + i, packed_grad_hess); + } + } else { + atomic_add_long_t* feature_histogram_ptr = reinterpret_cast(smaller_leaf_splits->hist_in_leaf) + partition_hist_start; + for (unsigned int i = thread_idx; i < num_items_in_partition; i += num_threads_per_block) { + const int32_t packed_grad_hess = shared_hist_packed[i]; + const int64_t packed_grad_hess_int64 = (static_cast(static_cast(packed_grad_hess >> 16)) << 32) | (static_cast(packed_grad_hess & 0x0000ffff)); + atomicAdd_system(feature_histogram_ptr + i, (atomic_add_long_t)(packed_grad_hess_int64)); + } + } +} + +template +__global__ void CUDAConstructDiscretizedHistogramDenseKernel_GlobalMemory( + const CUDALeafSplitsStruct* smaller_leaf_splits, + const int32_t* cuda_gradients_and_hessians, + const BIN_TYPE* data, + const uint32_t* column_hist_offsets, + const uint32_t* column_hist_offsets_full, + const int* feature_partition_column_index_offsets, + const data_size_t num_data, + int32_t* global_hist_buffer) { + const int dim_y = static_cast(gridDim.y * blockDim.y); + const data_size_t num_data_in_smaller_leaf = smaller_leaf_splits->num_data_in_leaf; + const data_size_t num_data_per_thread = (num_data_in_smaller_leaf + dim_y - 1) / dim_y; + const data_size_t* data_indices_ref = smaller_leaf_splits->data_indices_in_leaf; + const unsigned int num_threads_per_block = blockDim.x * blockDim.y; + const int partition_column_start = feature_partition_column_index_offsets[blockIdx.x]; + const int partition_column_end = feature_partition_column_index_offsets[blockIdx.x + 1]; + const BIN_TYPE* data_ptr = data + partition_column_start * num_data; + const int num_columns_in_partition = partition_column_end - partition_column_start; + const uint32_t partition_hist_start = column_hist_offsets_full[blockIdx.x]; + const uint32_t partition_hist_end = column_hist_offsets_full[blockIdx.x + 1]; + const uint32_t num_items_in_partition = (partition_hist_end - partition_hist_start); + const int num_total_bin = column_hist_offsets_full[gridDim.x]; + int32_t* shared_hist_packed = global_hist_buffer + (blockIdx.y * num_total_bin + partition_column_start); + const unsigned int thread_idx = threadIdx.x + threadIdx.y * blockDim.x; + for (unsigned int i 
= thread_idx; i < num_items_in_partition; i += num_threads_per_block) { + shared_hist_packed[i] = 0; + } + __syncthreads(); + const unsigned int threadIdx_y = threadIdx.y; + const unsigned int blockIdx_y = blockIdx.y; + const data_size_t block_start = (blockIdx_y * blockDim.y) * num_data_per_thread; + const data_size_t* data_indices_ref_this_block = data_indices_ref + block_start; + data_size_t block_num_data = max(0, min(num_data_in_smaller_leaf - block_start, num_data_per_thread * static_cast(blockDim.y))); + const data_size_t num_iteration_total = (block_num_data + blockDim.y - 1) / blockDim.y; + const data_size_t remainder = block_num_data % blockDim.y; + const data_size_t num_iteration_this = remainder == 0 ? num_iteration_total : num_iteration_total - static_cast(threadIdx_y >= remainder); + data_size_t inner_data_index = static_cast(threadIdx_y); + const int column_index = static_cast(threadIdx.x) + partition_column_start; + if (threadIdx.x < static_cast(num_columns_in_partition)) { + int32_t* shared_hist_ptr = shared_hist_packed + (column_hist_offsets[column_index]); + for (data_size_t i = 0; i < num_iteration_this; ++i) { + const data_size_t data_index = data_indices_ref_this_block[inner_data_index]; + const int32_t grad_and_hess = cuda_gradients_and_hessians[data_index]; + const uint32_t bin = static_cast(data_ptr[data_index * num_columns_in_partition + threadIdx.x]); + int32_t* pos_ptr = shared_hist_ptr + bin; + atomicAdd_block(pos_ptr, grad_and_hess); + inner_data_index += blockDim.y; + } + } + __syncthreads(); + if (USE_16BIT_HIST) { + int32_t* feature_histogram_ptr = reinterpret_cast(smaller_leaf_splits->hist_in_leaf) + partition_hist_start; + for (unsigned int i = thread_idx; i < num_items_in_partition; i += num_threads_per_block) { + const int32_t packed_grad_hess = shared_hist_packed[i]; + atomicAdd_system(feature_histogram_ptr + i, packed_grad_hess); + } + } else { + atomic_add_long_t* feature_histogram_ptr = reinterpret_cast(smaller_leaf_splits->hist_in_leaf) + partition_hist_start; + for (unsigned int i = thread_idx; i < num_items_in_partition; i += num_threads_per_block) { + const int32_t packed_grad_hess = shared_hist_packed[i]; + const int64_t packed_grad_hess_int64 = (static_cast(static_cast(packed_grad_hess >> 16)) << 32) | (static_cast(packed_grad_hess & 0x0000ffff)); + atomicAdd_system(feature_histogram_ptr + i, (atomic_add_long_t)(packed_grad_hess_int64)); + } + } +} + +template +__global__ void CUDAConstructDiscretizedHistogramSparseKernel_GlobalMemory( + const CUDALeafSplitsStruct* smaller_leaf_splits, + const int32_t* cuda_gradients_and_hessians, + const BIN_TYPE* data, + const DATA_PTR_TYPE* row_ptr, + const DATA_PTR_TYPE* partition_ptr, + const uint32_t* column_hist_offsets_full, + const data_size_t num_data, + int32_t* global_hist_buffer) { + const int dim_y = static_cast(gridDim.y * blockDim.y); + const data_size_t num_data_in_smaller_leaf = smaller_leaf_splits->num_data_in_leaf; + const data_size_t num_data_per_thread = (num_data_in_smaller_leaf + dim_y - 1) / dim_y; + const data_size_t* data_indices_ref = smaller_leaf_splits->data_indices_in_leaf; + const int num_total_bin = column_hist_offsets_full[gridDim.x]; + const unsigned int num_threads_per_block = blockDim.x * blockDim.y; + const DATA_PTR_TYPE* block_row_ptr = row_ptr + blockIdx.x * (num_data + 1); + const BIN_TYPE* data_ptr = data + partition_ptr[blockIdx.x]; + const uint32_t partition_hist_start = column_hist_offsets_full[blockIdx.x]; + const uint32_t partition_hist_end = 
column_hist_offsets_full[blockIdx.x + 1]; + const uint32_t num_items_in_partition = (partition_hist_end - partition_hist_start); + const unsigned int thread_idx = threadIdx.x + threadIdx.y * blockDim.x; + int32_t* shared_hist_packed = global_hist_buffer + (blockIdx.y * num_total_bin + partition_hist_start); + for (unsigned int i = thread_idx; i < num_items_in_partition; i += num_threads_per_block) { + shared_hist_packed[i] = 0.0f; + } + __syncthreads(); + const unsigned int threadIdx_y = threadIdx.y; + const unsigned int blockIdx_y = blockIdx.y; + const data_size_t block_start = (blockIdx_y * blockDim.y) * num_data_per_thread; + const data_size_t* data_indices_ref_this_block = data_indices_ref + block_start; + data_size_t block_num_data = max(0, min(num_data_in_smaller_leaf - block_start, num_data_per_thread * static_cast(blockDim.y))); + const data_size_t num_iteration_total = (block_num_data + blockDim.y - 1) / blockDim.y; + const data_size_t remainder = block_num_data % blockDim.y; + const data_size_t num_iteration_this = remainder == 0 ? num_iteration_total : num_iteration_total - static_cast(threadIdx_y >= remainder); + data_size_t inner_data_index = static_cast(threadIdx_y); + for (data_size_t i = 0; i < num_iteration_this; ++i) { + const data_size_t data_index = data_indices_ref_this_block[inner_data_index]; + const DATA_PTR_TYPE row_start = block_row_ptr[data_index]; + const DATA_PTR_TYPE row_end = block_row_ptr[data_index + 1]; + const DATA_PTR_TYPE row_size = row_end - row_start; + if (threadIdx.x < row_size) { + const int32_t grad_and_hess = cuda_gradients_and_hessians[data_index]; + const uint32_t bin = static_cast(data_ptr[row_start + threadIdx.x]); + int32_t* pos_ptr = shared_hist_packed + bin; + atomicAdd_block(pos_ptr, grad_and_hess); + } + inner_data_index += blockDim.y; + } + __syncthreads(); + if (USE_16BIT_HIST) { + int32_t* feature_histogram_ptr = reinterpret_cast(smaller_leaf_splits->hist_in_leaf) + partition_hist_start; + for (unsigned int i = thread_idx; i < num_items_in_partition; i += num_threads_per_block) { + const int32_t packed_grad_hess = shared_hist_packed[i]; + atomicAdd_system(feature_histogram_ptr + i, packed_grad_hess); + } + } else { + atomic_add_long_t* feature_histogram_ptr = reinterpret_cast(smaller_leaf_splits->hist_in_leaf) + partition_hist_start; + for (unsigned int i = thread_idx; i < num_items_in_partition; i += num_threads_per_block) { + const int32_t packed_grad_hess = shared_hist_packed[i]; + const int64_t packed_grad_hess_int64 = (static_cast(static_cast(packed_grad_hess >> 16)) << 32) | (static_cast(packed_grad_hess & 0x0000ffff)); + atomicAdd_system(feature_histogram_ptr + i, (atomic_add_long_t)(packed_grad_hess_int64)); + } + } +} + void CUDAHistogramConstructor::LaunchConstructHistogramKernel( const CUDALeafSplitsStruct* cuda_smaller_leaf_splits, - const data_size_t num_data_in_smaller_leaf) { + const data_size_t num_data_in_smaller_leaf, + const uint8_t num_bits_in_histogram_bins) { if (cuda_row_data_->shared_hist_size() == DP_SHARED_HIST_SIZE && gpu_use_dp_) { - LaunchConstructHistogramKernelInner(cuda_smaller_leaf_splits, num_data_in_smaller_leaf); + LaunchConstructHistogramKernelInner(cuda_smaller_leaf_splits, num_data_in_smaller_leaf, num_bits_in_histogram_bins); } else if (cuda_row_data_->shared_hist_size() == SP_SHARED_HIST_SIZE && !gpu_use_dp_) { - LaunchConstructHistogramKernelInner(cuda_smaller_leaf_splits, num_data_in_smaller_leaf); + LaunchConstructHistogramKernelInner(cuda_smaller_leaf_splits, num_data_in_smaller_leaf, 
num_bits_in_histogram_bins); } else { Log::Fatal("Unknown shared histogram size %d", cuda_row_data_->shared_hist_size()); } @@ -261,13 +526,14 @@ void CUDAHistogramConstructor::LaunchConstructHistogramKernel( template void CUDAHistogramConstructor::LaunchConstructHistogramKernelInner( const CUDALeafSplitsStruct* cuda_smaller_leaf_splits, - const data_size_t num_data_in_smaller_leaf) { + const data_size_t num_data_in_smaller_leaf, + const uint8_t num_bits_in_histogram_bins) { if (cuda_row_data_->bit_type() == 8) { - LaunchConstructHistogramKernelInner0(cuda_smaller_leaf_splits, num_data_in_smaller_leaf); + LaunchConstructHistogramKernelInner0(cuda_smaller_leaf_splits, num_data_in_smaller_leaf, num_bits_in_histogram_bins); } else if (cuda_row_data_->bit_type() == 16) { - LaunchConstructHistogramKernelInner0(cuda_smaller_leaf_splits, num_data_in_smaller_leaf); + LaunchConstructHistogramKernelInner0(cuda_smaller_leaf_splits, num_data_in_smaller_leaf, num_bits_in_histogram_bins); } else if (cuda_row_data_->bit_type() == 32) { - LaunchConstructHistogramKernelInner0(cuda_smaller_leaf_splits, num_data_in_smaller_leaf); + LaunchConstructHistogramKernelInner0(cuda_smaller_leaf_splits, num_data_in_smaller_leaf, num_bits_in_histogram_bins); } else { Log::Fatal("Unknown bit_type = %d", cuda_row_data_->bit_type()); } @@ -276,16 +542,17 @@ void CUDAHistogramConstructor::LaunchConstructHistogramKernelInner( template void CUDAHistogramConstructor::LaunchConstructHistogramKernelInner0( const CUDALeafSplitsStruct* cuda_smaller_leaf_splits, - const data_size_t num_data_in_smaller_leaf) { + const data_size_t num_data_in_smaller_leaf, + const uint8_t num_bits_in_histogram_bins) { if (cuda_row_data_->row_ptr_bit_type() == 16) { - LaunchConstructHistogramKernelInner1(cuda_smaller_leaf_splits, num_data_in_smaller_leaf); + LaunchConstructHistogramKernelInner1(cuda_smaller_leaf_splits, num_data_in_smaller_leaf, num_bits_in_histogram_bins); } else if (cuda_row_data_->row_ptr_bit_type() == 32) { - LaunchConstructHistogramKernelInner1(cuda_smaller_leaf_splits, num_data_in_smaller_leaf); + LaunchConstructHistogramKernelInner1(cuda_smaller_leaf_splits, num_data_in_smaller_leaf, num_bits_in_histogram_bins); } else if (cuda_row_data_->row_ptr_bit_type() == 64) { - LaunchConstructHistogramKernelInner1(cuda_smaller_leaf_splits, num_data_in_smaller_leaf); + LaunchConstructHistogramKernelInner1(cuda_smaller_leaf_splits, num_data_in_smaller_leaf, num_bits_in_histogram_bins); } else { if (!cuda_row_data_->is_sparse()) { - LaunchConstructHistogramKernelInner1(cuda_smaller_leaf_splits, num_data_in_smaller_leaf); + LaunchConstructHistogramKernelInner1(cuda_smaller_leaf_splits, num_data_in_smaller_leaf, num_bits_in_histogram_bins); } else { Log::Fatal("Unknown row_ptr_bit_type = %d", cuda_row_data_->row_ptr_bit_type()); } @@ -295,18 +562,20 @@ void CUDAHistogramConstructor::LaunchConstructHistogramKernelInner0( template void CUDAHistogramConstructor::LaunchConstructHistogramKernelInner1( const CUDALeafSplitsStruct* cuda_smaller_leaf_splits, - const data_size_t num_data_in_smaller_leaf) { + const data_size_t num_data_in_smaller_leaf, + const uint8_t num_bits_in_histogram_bins) { if (cuda_row_data_->NumLargeBinPartition() == 0) { - LaunchConstructHistogramKernelInner2(cuda_smaller_leaf_splits, num_data_in_smaller_leaf); + LaunchConstructHistogramKernelInner2(cuda_smaller_leaf_splits, num_data_in_smaller_leaf, num_bits_in_histogram_bins); } else { - LaunchConstructHistogramKernelInner2(cuda_smaller_leaf_splits, 
num_data_in_smaller_leaf); + LaunchConstructHistogramKernelInner2(cuda_smaller_leaf_splits, num_data_in_smaller_leaf, num_bits_in_histogram_bins); } } template void CUDAHistogramConstructor::LaunchConstructHistogramKernelInner2( const CUDALeafSplitsStruct* cuda_smaller_leaf_splits, - const data_size_t num_data_in_smaller_leaf) { + const data_size_t num_data_in_smaller_leaf, + const uint8_t num_bits_in_histogram_bins) { int grid_dim_x = 0; int grid_dim_y = 0; int block_dim_x = 0; @@ -314,47 +583,139 @@ void CUDAHistogramConstructor::LaunchConstructHistogramKernelInner2( CalcConstructHistogramKernelDim(&grid_dim_x, &grid_dim_y, &block_dim_x, &block_dim_y, num_data_in_smaller_leaf); dim3 grid_dim(grid_dim_x, grid_dim_y); dim3 block_dim(block_dim_x, block_dim_y); - if (!USE_GLOBAL_MEM_BUFFER) { - if (cuda_row_data_->is_sparse()) { - CUDAConstructHistogramSparseKernel<<>>( - cuda_smaller_leaf_splits, - cuda_gradients_, cuda_hessians_, - cuda_row_data_->GetBin(), - cuda_row_data_->GetRowPtr(), - cuda_row_data_->GetPartitionPtr(), - cuda_row_data_->cuda_partition_hist_offsets(), - num_data_); + if (use_quantized_grad_) { + if (USE_GLOBAL_MEM_BUFFER) { + if (cuda_row_data_->is_sparse()) { + if (num_bits_in_histogram_bins <= 16) { + CUDAConstructDiscretizedHistogramSparseKernel_GlobalMemory<<>>( + cuda_smaller_leaf_splits, + reinterpret_cast(cuda_gradients_), + cuda_row_data_->GetBin(), + cuda_row_data_->GetRowPtr(), + cuda_row_data_->GetPartitionPtr(), + cuda_row_data_->cuda_partition_hist_offsets(), + num_data_, + reinterpret_cast(cuda_hist_buffer_.RawData())); + } else { + CUDAConstructDiscretizedHistogramSparseKernel_GlobalMemory<<>>( + cuda_smaller_leaf_splits, + reinterpret_cast(cuda_gradients_), + cuda_row_data_->GetBin(), + cuda_row_data_->GetRowPtr(), + cuda_row_data_->GetPartitionPtr(), + cuda_row_data_->cuda_partition_hist_offsets(), + num_data_, + reinterpret_cast(cuda_hist_buffer_.RawData())); + } + } else { + if (num_bits_in_histogram_bins <= 16) { + CUDAConstructDiscretizedHistogramDenseKernel_GlobalMemory<<>>( + cuda_smaller_leaf_splits, + reinterpret_cast(cuda_gradients_), + cuda_row_data_->GetBin(), + cuda_row_data_->cuda_column_hist_offsets(), + cuda_row_data_->cuda_partition_hist_offsets(), + cuda_row_data_->cuda_feature_partition_column_index_offsets(), + num_data_, + reinterpret_cast(cuda_hist_buffer_.RawData())); + } else { + CUDAConstructDiscretizedHistogramDenseKernel_GlobalMemory<<>>( + cuda_smaller_leaf_splits, + reinterpret_cast(cuda_gradients_), + cuda_row_data_->GetBin(), + cuda_row_data_->cuda_column_hist_offsets(), + cuda_row_data_->cuda_partition_hist_offsets(), + cuda_row_data_->cuda_feature_partition_column_index_offsets(), + num_data_, + reinterpret_cast(cuda_hist_buffer_.RawData())); + } + } } else { - CUDAConstructHistogramDenseKernel<<>>( - cuda_smaller_leaf_splits, - cuda_gradients_, cuda_hessians_, - cuda_row_data_->GetBin(), - cuda_row_data_->cuda_column_hist_offsets(), - cuda_row_data_->cuda_partition_hist_offsets(), - cuda_row_data_->cuda_feature_partition_column_index_offsets(), - num_data_); + if (cuda_row_data_->is_sparse()) { + if (num_bits_in_histogram_bins <= 16) { + CUDAConstructDiscretizedHistogramSparseKernel<<>>( + cuda_smaller_leaf_splits, + reinterpret_cast(cuda_gradients_), + cuda_row_data_->GetBin(), + cuda_row_data_->GetRowPtr(), + cuda_row_data_->GetPartitionPtr(), + cuda_row_data_->cuda_partition_hist_offsets(), + num_data_); + } else { + CUDAConstructDiscretizedHistogramSparseKernel<<>>( + cuda_smaller_leaf_splits, + 
reinterpret_cast(cuda_gradients_), + cuda_row_data_->GetBin(), + cuda_row_data_->GetRowPtr(), + cuda_row_data_->GetPartitionPtr(), + cuda_row_data_->cuda_partition_hist_offsets(), + num_data_); + } + } else { + if (num_bits_in_histogram_bins <= 16) { + CUDAConstructDiscretizedHistogramDenseKernel<<>>( + cuda_smaller_leaf_splits, + reinterpret_cast(cuda_gradients_), + cuda_row_data_->GetBin(), + cuda_row_data_->cuda_column_hist_offsets(), + cuda_row_data_->cuda_partition_hist_offsets(), + cuda_row_data_->cuda_feature_partition_column_index_offsets(), + num_data_); + } else { + CUDAConstructDiscretizedHistogramDenseKernel<<>>( + cuda_smaller_leaf_splits, + reinterpret_cast(cuda_gradients_), + cuda_row_data_->GetBin(), + cuda_row_data_->cuda_column_hist_offsets(), + cuda_row_data_->cuda_partition_hist_offsets(), + cuda_row_data_->cuda_feature_partition_column_index_offsets(), + num_data_); + } + } } } else { - if (cuda_row_data_->is_sparse()) { - CUDAConstructHistogramSparseKernel_GlobalMemory<<>>( - cuda_smaller_leaf_splits, - cuda_gradients_, cuda_hessians_, - cuda_row_data_->GetBin(), - cuda_row_data_->GetRowPtr(), - cuda_row_data_->GetPartitionPtr(), - cuda_row_data_->cuda_partition_hist_offsets(), - num_data_, - cuda_hist_buffer_); + if (!USE_GLOBAL_MEM_BUFFER) { + if (cuda_row_data_->is_sparse()) { + CUDAConstructHistogramSparseKernel<<>>( + cuda_smaller_leaf_splits, + cuda_gradients_, cuda_hessians_, + cuda_row_data_->GetBin(), + cuda_row_data_->GetRowPtr(), + cuda_row_data_->GetPartitionPtr(), + cuda_row_data_->cuda_partition_hist_offsets(), + num_data_); + } else { + CUDAConstructHistogramDenseKernel<<>>( + cuda_smaller_leaf_splits, + cuda_gradients_, cuda_hessians_, + cuda_row_data_->GetBin(), + cuda_row_data_->cuda_column_hist_offsets(), + cuda_row_data_->cuda_partition_hist_offsets(), + cuda_row_data_->cuda_feature_partition_column_index_offsets(), + num_data_); + } } else { - CUDAConstructHistogramDenseKernel_GlobalMemory<<>>( - cuda_smaller_leaf_splits, - cuda_gradients_, cuda_hessians_, - cuda_row_data_->GetBin(), - cuda_row_data_->cuda_column_hist_offsets(), - cuda_row_data_->cuda_partition_hist_offsets(), - cuda_row_data_->cuda_feature_partition_column_index_offsets(), - num_data_, - cuda_hist_buffer_); + if (cuda_row_data_->is_sparse()) { + CUDAConstructHistogramSparseKernel_GlobalMemory<<>>( + cuda_smaller_leaf_splits, + cuda_gradients_, cuda_hessians_, + cuda_row_data_->GetBin(), + cuda_row_data_->GetRowPtr(), + cuda_row_data_->GetPartitionPtr(), + cuda_row_data_->cuda_partition_hist_offsets(), + num_data_, + reinterpret_cast(cuda_hist_buffer_.RawData())); + } else { + CUDAConstructHistogramDenseKernel_GlobalMemory<<>>( + cuda_smaller_leaf_splits, + cuda_gradients_, cuda_hessians_, + cuda_row_data_->GetBin(), + cuda_row_data_->cuda_column_hist_offsets(), + cuda_row_data_->cuda_partition_hist_offsets(), + cuda_row_data_->cuda_feature_partition_column_index_offsets(), + num_data_, + reinterpret_cast(cuda_hist_buffer_.RawData())); + } } } } @@ -403,28 +764,195 @@ __global__ void FixHistogramKernel( } } +template +__global__ void SubtractHistogramDiscretizedKernel( + const int num_total_bin, + const CUDALeafSplitsStruct* cuda_smaller_leaf_splits, + const CUDALeafSplitsStruct* cuda_larger_leaf_splits, + hist_t* num_bit_change_buffer) { + const unsigned int global_thread_index = threadIdx.x + blockIdx.x * blockDim.x; + const int cuda_larger_leaf_index_ref = cuda_larger_leaf_splits->leaf_index; + if (cuda_larger_leaf_index_ref >= 0) { + if (PARENT_USE_16BIT_HIST) { + const 
int32_t* smaller_leaf_hist = reinterpret_cast(cuda_smaller_leaf_splits->hist_in_leaf); + int32_t* larger_leaf_hist = reinterpret_cast(cuda_larger_leaf_splits->hist_in_leaf); + if (global_thread_index < num_total_bin) { + larger_leaf_hist[global_thread_index] -= smaller_leaf_hist[global_thread_index]; + } + } else if (LARGER_USE_16BIT_HIST) { + int32_t* buffer = reinterpret_cast(num_bit_change_buffer); + const int32_t* smaller_leaf_hist = reinterpret_cast(cuda_smaller_leaf_splits->hist_in_leaf); + int64_t* larger_leaf_hist = reinterpret_cast(cuda_larger_leaf_splits->hist_in_leaf); + if (global_thread_index < num_total_bin) { + const int64_t parent_hist_item = larger_leaf_hist[global_thread_index]; + const int32_t smaller_hist_item = smaller_leaf_hist[global_thread_index]; + const int64_t smaller_hist_item_int64 = (static_cast(static_cast(smaller_hist_item >> 16)) << 32) | + static_cast(smaller_hist_item & 0x0000ffff); + const int64_t larger_hist_item = parent_hist_item - smaller_hist_item_int64; + buffer[global_thread_index] = static_cast(static_cast(larger_hist_item >> 32) << 16) | + static_cast(larger_hist_item & 0x000000000000ffff); + } + } else if (SMALLER_USE_16BIT_HIST) { + const int32_t* smaller_leaf_hist = reinterpret_cast(cuda_smaller_leaf_splits->hist_in_leaf); + int64_t* larger_leaf_hist = reinterpret_cast(cuda_larger_leaf_splits->hist_in_leaf); + if (global_thread_index < num_total_bin) { + const int64_t parent_hist_item = larger_leaf_hist[global_thread_index]; + const int32_t smaller_hist_item = smaller_leaf_hist[global_thread_index]; + const int64_t smaller_hist_item_int64 = (static_cast(static_cast(smaller_hist_item >> 16)) << 32) | + static_cast(smaller_hist_item & 0x0000ffff); + const int64_t larger_hist_item = parent_hist_item - smaller_hist_item_int64; + larger_leaf_hist[global_thread_index] = larger_hist_item; + } + } else { + const int64_t* smaller_leaf_hist = reinterpret_cast(cuda_smaller_leaf_splits->hist_in_leaf); + int64_t* larger_leaf_hist = reinterpret_cast(cuda_larger_leaf_splits->hist_in_leaf); + if (global_thread_index < num_total_bin) { + larger_leaf_hist[global_thread_index] -= smaller_leaf_hist[global_thread_index]; + } + } + } +} + +__global__ void CopyChangedNumBitHistogram( + const int num_total_bin, + const CUDALeafSplitsStruct* cuda_larger_leaf_splits, + hist_t* num_bit_change_buffer) { + int32_t* hist_dst = reinterpret_cast(cuda_larger_leaf_splits->hist_in_leaf); + const int32_t* hist_src = reinterpret_cast(num_bit_change_buffer); + const unsigned int global_thread_index = threadIdx.x + blockIdx.x * blockDim.x; + if (global_thread_index < static_cast(num_total_bin)) { + hist_dst[global_thread_index] = hist_src[global_thread_index]; + } +} + +template +__global__ void FixHistogramDiscretizedKernel( + const uint32_t* cuda_feature_num_bins, + const uint32_t* cuda_feature_hist_offsets, + const uint32_t* cuda_feature_most_freq_bins, + const int* cuda_need_fix_histogram_features, + const uint32_t* cuda_need_fix_histogram_features_num_bin_aligned, + const CUDALeafSplitsStruct* cuda_smaller_leaf_splits) { + __shared__ int64_t shared_mem_buffer[32]; + const unsigned int blockIdx_x = blockIdx.x; + const int feature_index = cuda_need_fix_histogram_features[blockIdx_x]; + const uint32_t num_bin_aligned = cuda_need_fix_histogram_features_num_bin_aligned[blockIdx_x]; + const uint32_t feature_hist_offset = cuda_feature_hist_offsets[feature_index]; + const uint32_t most_freq_bin = cuda_feature_most_freq_bins[feature_index]; + if (USE_16BIT_HIST) { + const int64_t 
leaf_sum_gradients_hessians_int64 = cuda_smaller_leaf_splits->sum_of_gradients_hessians; + const int32_t leaf_sum_gradients_hessians = + (static_cast(leaf_sum_gradients_hessians_int64 >> 32) << 16) | static_cast(leaf_sum_gradients_hessians_int64 & 0x000000000000ffff); + int32_t* feature_hist = reinterpret_cast(cuda_smaller_leaf_splits->hist_in_leaf) + feature_hist_offset; + const unsigned int threadIdx_x = threadIdx.x; + const uint32_t num_bin = cuda_feature_num_bins[feature_index]; + const int32_t bin_gradient_hessian = (threadIdx_x < num_bin && threadIdx_x != most_freq_bin) ? feature_hist[threadIdx_x] : 0; + const int32_t sum_gradient_hessian = ShuffleReduceSum( + bin_gradient_hessian, + reinterpret_cast(shared_mem_buffer), + num_bin_aligned); + if (threadIdx_x == 0) { + feature_hist[most_freq_bin] = leaf_sum_gradients_hessians - sum_gradient_hessian; + } + } else { + const int64_t leaf_sum_gradients_hessians = cuda_smaller_leaf_splits->sum_of_gradients_hessians; + int64_t* feature_hist = reinterpret_cast(cuda_smaller_leaf_splits->hist_in_leaf) + feature_hist_offset; + const unsigned int threadIdx_x = threadIdx.x; + const uint32_t num_bin = cuda_feature_num_bins[feature_index]; + const int64_t bin_gradient_hessian = (threadIdx_x < num_bin && threadIdx_x != most_freq_bin) ? feature_hist[threadIdx_x] : 0; + const int64_t sum_gradient_hessian = ShuffleReduceSum(bin_gradient_hessian, shared_mem_buffer, num_bin_aligned); + if (threadIdx_x == 0) { + feature_hist[most_freq_bin] = leaf_sum_gradients_hessians - sum_gradient_hessian; + } + } +} + void CUDAHistogramConstructor::LaunchSubtractHistogramKernel( const CUDALeafSplitsStruct* cuda_smaller_leaf_splits, - const CUDALeafSplitsStruct* cuda_larger_leaf_splits) { - const int num_subtract_threads = 2 * num_total_bin_; - const int num_subtract_blocks = (num_subtract_threads + SUBTRACT_BLOCK_SIZE - 1) / SUBTRACT_BLOCK_SIZE; - global_timer.Start("CUDAHistogramConstructor::FixHistogramKernel"); - if (need_fix_histogram_features_.size() > 0) { - FixHistogramKernel<<>>( - cuda_feature_num_bins_, - cuda_feature_hist_offsets_, - cuda_feature_most_freq_bins_, - cuda_need_fix_histogram_features_, - cuda_need_fix_histogram_features_num_bin_aligned_, - cuda_smaller_leaf_splits); - } - global_timer.Stop("CUDAHistogramConstructor::FixHistogramKernel"); - global_timer.Start("CUDAHistogramConstructor::SubtractHistogramKernel"); - SubtractHistogramKernel<<>>( - num_total_bin_, - cuda_smaller_leaf_splits, - cuda_larger_leaf_splits); - global_timer.Stop("CUDAHistogramConstructor::SubtractHistogramKernel"); + const CUDALeafSplitsStruct* cuda_larger_leaf_splits, + const bool use_discretized_grad, + const uint8_t parent_num_bits_in_histogram_bins, + const uint8_t smaller_num_bits_in_histogram_bins, + const uint8_t larger_num_bits_in_histogram_bins) { + if (!use_discretized_grad) { + const int num_subtract_threads = 2 * num_total_bin_; + const int num_subtract_blocks = (num_subtract_threads + SUBTRACT_BLOCK_SIZE - 1) / SUBTRACT_BLOCK_SIZE; + global_timer.Start("CUDAHistogramConstructor::FixHistogramKernel"); + if (need_fix_histogram_features_.size() > 0) { + FixHistogramKernel<<>>( + cuda_feature_num_bins_.RawData(), + cuda_feature_hist_offsets_.RawData(), + cuda_feature_most_freq_bins_.RawData(), + cuda_need_fix_histogram_features_.RawData(), + cuda_need_fix_histogram_features_num_bin_aligned_.RawData(), + cuda_smaller_leaf_splits); + } + global_timer.Stop("CUDAHistogramConstructor::FixHistogramKernel"); + 
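FixHistogramDiscretizedKernel recovers the most-frequent bin of a feature, which is skipped during histogram construction, from the leaf totals. A host-side restatement of the 64-bit packed variant is sketched below; FixMostFreqBin is a hypothetical name, and the 16-bit variant is the same arithmetic after narrowing the leaf total to a packed 32-bit value.

    #include <cstdint>
    #include <vector>

    // The packed entry of the most-frequent bin equals the leaf total minus
    // the sum of every other bin; field-wise this holds because the hessian
    // halves never borrow (the leaf hessian total bounds the partial sums).
    void FixMostFreqBin(std::vector<int64_t>* feature_hist,
                        uint32_t most_freq_bin,
                        int64_t leaf_sum_gradients_hessians) {
      int64_t sum_other_bins = 0;
      for (size_t bin = 0; bin < feature_hist->size(); ++bin) {
        if (bin != most_freq_bin) sum_other_bins += (*feature_hist)[bin];
      }
      (*feature_hist)[most_freq_bin] = leaf_sum_gradients_hessians - sum_other_bins;
    }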
global_timer.Start("CUDAHistogramConstructor::SubtractHistogramKernel"); + SubtractHistogramKernel<<>>( + num_total_bin_, + cuda_smaller_leaf_splits, + cuda_larger_leaf_splits); + global_timer.Stop("CUDAHistogramConstructor::SubtractHistogramKernel"); + } else { + const int num_subtract_threads = num_total_bin_; + const int num_subtract_blocks = (num_subtract_threads + SUBTRACT_BLOCK_SIZE - 1) / SUBTRACT_BLOCK_SIZE; + global_timer.Start("CUDAHistogramConstructor::FixHistogramDiscretizedKernel"); + if (need_fix_histogram_features_.size() > 0) { + if (smaller_num_bits_in_histogram_bins <= 16) { + FixHistogramDiscretizedKernel<<>>( + cuda_feature_num_bins_.RawData(), + cuda_feature_hist_offsets_.RawData(), + cuda_feature_most_freq_bins_.RawData(), + cuda_need_fix_histogram_features_.RawData(), + cuda_need_fix_histogram_features_num_bin_aligned_.RawData(), + cuda_smaller_leaf_splits); + } else { + FixHistogramDiscretizedKernel<<>>( + cuda_feature_num_bins_.RawData(), + cuda_feature_hist_offsets_.RawData(), + cuda_feature_most_freq_bins_.RawData(), + cuda_need_fix_histogram_features_.RawData(), + cuda_need_fix_histogram_features_num_bin_aligned_.RawData(), + cuda_smaller_leaf_splits); + } + } + global_timer.Stop("CUDAHistogramConstructor::FixHistogramDiscretizedKernel"); + global_timer.Start("CUDAHistogramConstructor::SubtractHistogramDiscretizedKernel"); + if (parent_num_bits_in_histogram_bins <= 16) { + CHECK_LE(smaller_num_bits_in_histogram_bins, 16); + CHECK_LE(larger_num_bits_in_histogram_bins, 16); + SubtractHistogramDiscretizedKernel<<>>( + num_total_bin_, + cuda_smaller_leaf_splits, + cuda_larger_leaf_splits, + hist_buffer_for_num_bit_change_.RawData()); + } else if (larger_num_bits_in_histogram_bins <= 16) { + CHECK_LE(smaller_num_bits_in_histogram_bins, 16); + SubtractHistogramDiscretizedKernel<<>>( + num_total_bin_, + cuda_smaller_leaf_splits, + cuda_larger_leaf_splits, + hist_buffer_for_num_bit_change_.RawData()); + CopyChangedNumBitHistogram<<>>( + num_total_bin_, + cuda_larger_leaf_splits, + hist_buffer_for_num_bit_change_.RawData()); + } else if (smaller_num_bits_in_histogram_bins <= 16) { + SubtractHistogramDiscretizedKernel<<>>( + num_total_bin_, + cuda_smaller_leaf_splits, + cuda_larger_leaf_splits, + hist_buffer_for_num_bit_change_.RawData()); + } else { + SubtractHistogramDiscretizedKernel<<>>( + num_total_bin_, + cuda_smaller_leaf_splits, + cuda_larger_leaf_splits, + hist_buffer_for_num_bit_change_.RawData()); + } + global_timer.Stop("CUDAHistogramConstructor::SubtractHistogramDiscretizedKernel"); + } } } // namespace LightGBM diff --git a/src/treelearner/cuda/cuda_histogram_constructor.hpp b/src/treelearner/cuda/cuda_histogram_constructor.hpp index 7e600e7c0..ddc78cb17 100644 --- a/src/treelearner/cuda/cuda_histogram_constructor.hpp +++ b/src/treelearner/cuda/cuda_histogram_constructor.hpp @@ -9,6 +9,7 @@ #ifdef USE_CUDA #include +#include #include #include @@ -37,7 +38,9 @@ class CUDAHistogramConstructor { const int min_data_in_leaf, const double min_sum_hessian_in_leaf, const int gpu_device_id, - const bool gpu_use_dp); + const bool gpu_use_dp, + const bool use_discretized_grad, + const int grad_discretized_bins); ~CUDAHistogramConstructor(); @@ -49,7 +52,16 @@ class CUDAHistogramConstructor { const data_size_t num_data_in_smaller_leaf, const data_size_t num_data_in_larger_leaf, const double sum_hessians_in_smaller_leaf, - const double sum_hessians_in_larger_leaf); + const double sum_hessians_in_larger_leaf, + const uint8_t num_bits_in_histogram_bins); + + void 
SubtractHistogramForLeaf( + const CUDALeafSplitsStruct* cuda_smaller_leaf_splits, + const CUDALeafSplitsStruct* cuda_larger_leaf_splits, + const bool use_discretized_grad, + const uint8_t parent_num_bits_in_histogram_bins, + const uint8_t smaller_num_bits_in_histogram_bins, + const uint8_t larger_num_bits_in_histogram_bins); void ResetTrainingData(const Dataset* train_data, TrainingShareStates* share_states); @@ -57,9 +69,9 @@ class CUDAHistogramConstructor { void BeforeTrain(const score_t* gradients, const score_t* hessians); - const hist_t* cuda_hist() const { return cuda_hist_; } + const hist_t* cuda_hist() const { return cuda_hist_.RawData(); } - hist_t* cuda_hist_pointer() { return cuda_hist_; } + hist_t* cuda_hist_pointer() { return cuda_hist_.RawData(); } private: void InitFeatureMetaInfo(const Dataset* train_data, const std::vector& feature_hist_offsets); @@ -74,30 +86,39 @@ class CUDAHistogramConstructor { template void LaunchConstructHistogramKernelInner( const CUDALeafSplitsStruct* cuda_smaller_leaf_splits, - const data_size_t num_data_in_smaller_leaf); + const data_size_t num_data_in_smaller_leaf, + const uint8_t num_bits_in_histogram_bins); template void LaunchConstructHistogramKernelInner0( const CUDALeafSplitsStruct* cuda_smaller_leaf_splits, - const data_size_t num_data_in_smaller_leaf); + const data_size_t num_data_in_smaller_leaf, + const uint8_t num_bits_in_histogram_bins); template void LaunchConstructHistogramKernelInner1( const CUDALeafSplitsStruct* cuda_smaller_leaf_splits, - const data_size_t num_data_in_smaller_leaf); + const data_size_t num_data_in_smaller_leaf, + const uint8_t num_bits_in_histogram_bins); template void LaunchConstructHistogramKernelInner2( const CUDALeafSplitsStruct* cuda_smaller_leaf_splits, - const data_size_t num_data_in_smaller_leaf); + const data_size_t num_data_in_smaller_leaf, + const uint8_t num_bits_in_histogram_bins); void LaunchConstructHistogramKernel( const CUDALeafSplitsStruct* cuda_smaller_leaf_splits, - const data_size_t num_data_in_smaller_leaf); + const data_size_t num_data_in_smaller_leaf, + const uint8_t num_bits_in_histogram_bins); void LaunchSubtractHistogramKernel( const CUDALeafSplitsStruct* cuda_smaller_leaf_splits, - const CUDALeafSplitsStruct* cuda_larger_leaf_splits); + const CUDALeafSplitsStruct* cuda_larger_leaf_splits, + const bool use_discretized_grad, + const uint8_t parent_num_bits_in_histogram_bins, + const uint8_t smaller_num_bits_in_histogram_bins, + const uint8_t larger_num_bits_in_histogram_bins); // Host memory @@ -136,19 +157,21 @@ class CUDAHistogramConstructor { /*! \brief CUDA row wise data */ std::unique_ptr cuda_row_data_; /*! \brief number of bins per feature */ - uint32_t* cuda_feature_num_bins_; + CUDAVector cuda_feature_num_bins_; /*! \brief offsets in histogram of all features */ - uint32_t* cuda_feature_hist_offsets_; + CUDAVector cuda_feature_hist_offsets_; /*! \brief most frequent bins in each feature */ - uint32_t* cuda_feature_most_freq_bins_; + CUDAVector cuda_feature_most_freq_bins_; /*! \brief CUDA histograms */ - hist_t* cuda_hist_; + CUDAVector cuda_hist_; /*! \brief CUDA histograms buffer for each block */ - float* cuda_hist_buffer_; + CUDAVector cuda_hist_buffer_; /*! \brief indices of feature whose histograms need to be fixed */ - int* cuda_need_fix_histogram_features_; + CUDAVector cuda_need_fix_histogram_features_; /*! 
\brief aligned number of bins of the features whose histograms need to be fixed */ - uint32_t* cuda_need_fix_histogram_features_num_bin_aligned_; + CUDAVector cuda_need_fix_histogram_features_num_bin_aligned_; + /*! \brief histogram buffer used in histogram subtraction with different number of bits for histogram bins */ + CUDAVector hist_buffer_for_num_bit_change_; // CUDA memory, held by other object @@ -161,6 +184,10 @@ class CUDAHistogramConstructor { const int gpu_device_id_; /*! \brief use double precision histogram per block */ const bool gpu_use_dp_; + /*! \brief whether to use quantized gradients */ + const bool use_quantized_grad_; + /*! \brief the number of bins to quantized gradients */ + const int num_grad_quant_bins_; }; } // namespace LightGBM diff --git a/src/treelearner/cuda/cuda_leaf_splits.cpp b/src/treelearner/cuda/cuda_leaf_splits.cpp index 6aa020d9e..57b5b777c 100644 --- a/src/treelearner/cuda/cuda_leaf_splits.cpp +++ b/src/treelearner/cuda/cuda_leaf_splits.cpp @@ -11,27 +11,22 @@ namespace LightGBM { CUDALeafSplits::CUDALeafSplits(const data_size_t num_data): -num_data_(num_data) { - cuda_struct_ = nullptr; - cuda_sum_of_gradients_buffer_ = nullptr; - cuda_sum_of_hessians_buffer_ = nullptr; -} +num_data_(num_data) {} -CUDALeafSplits::~CUDALeafSplits() { - DeallocateCUDAMemory(&cuda_struct_, __FILE__, __LINE__); - DeallocateCUDAMemory(&cuda_sum_of_gradients_buffer_, __FILE__, __LINE__); - DeallocateCUDAMemory(&cuda_sum_of_hessians_buffer_, __FILE__, __LINE__); -} +CUDALeafSplits::~CUDALeafSplits() {} -void CUDALeafSplits::Init() { +void CUDALeafSplits::Init(const bool use_quantized_grad) { num_blocks_init_from_gradients_ = (num_data_ + NUM_THRADS_PER_BLOCK_LEAF_SPLITS - 1) / NUM_THRADS_PER_BLOCK_LEAF_SPLITS; // allocate more memory for sum reduction in CUDA // only the first element records the final sum - AllocateCUDAMemory(&cuda_sum_of_gradients_buffer_, num_blocks_init_from_gradients_, __FILE__, __LINE__); - AllocateCUDAMemory(&cuda_sum_of_hessians_buffer_, num_blocks_init_from_gradients_, __FILE__, __LINE__); + cuda_sum_of_gradients_buffer_.Resize(static_cast(num_blocks_init_from_gradients_)); + cuda_sum_of_hessians_buffer_.Resize(static_cast(num_blocks_init_from_gradients_)); + if (use_quantized_grad) { + cuda_sum_of_gradients_hessians_buffer_.Resize(static_cast(num_blocks_init_from_gradients_)); + } - AllocateCUDAMemory(&cuda_struct_, 1, __FILE__, __LINE__); + cuda_struct_.Resize(1); } void CUDALeafSplits::InitValues() { @@ -46,24 +41,33 @@ void CUDALeafSplits::InitValues( const data_size_t num_used_indices, hist_t* cuda_hist_in_leaf, double* root_sum_hessians) { cuda_gradients_ = cuda_gradients; cuda_hessians_ = cuda_hessians; - SetCUDAMemory(cuda_sum_of_gradients_buffer_, 0, num_blocks_init_from_gradients_, __FILE__, __LINE__); - SetCUDAMemory(cuda_sum_of_hessians_buffer_, 0, num_blocks_init_from_gradients_, __FILE__, __LINE__); + cuda_sum_of_gradients_buffer_.SetValue(0); + cuda_sum_of_hessians_buffer_.SetValue(0); LaunchInitValuesKernal(lambda_l1, lambda_l2, cuda_bagging_data_indices, cuda_data_indices_in_leaf, num_used_indices, cuda_hist_in_leaf); - CopyFromCUDADeviceToHost(root_sum_hessians, cuda_sum_of_hessians_buffer_, 1, __FILE__, __LINE__); + CopyFromCUDADeviceToHost(root_sum_hessians, cuda_sum_of_hessians_buffer_.RawData(), 1, __FILE__, __LINE__); + SynchronizeCUDADevice(__FILE__, __LINE__); +} + +void CUDALeafSplits::InitValues( + const double lambda_l1, const double lambda_l2, + const int16_t* cuda_gradients_and_hessians, + const data_size_t* 
cuda_bagging_data_indices, + const data_size_t* cuda_data_indices_in_leaf, const data_size_t num_used_indices, + hist_t* cuda_hist_in_leaf, double* root_sum_hessians, + const score_t* grad_scale, const score_t* hess_scale) { + cuda_gradients_ = reinterpret_cast(cuda_gradients_and_hessians); + cuda_hessians_ = nullptr; + LaunchInitValuesKernal(lambda_l1, lambda_l2, cuda_bagging_data_indices, cuda_data_indices_in_leaf, num_used_indices, cuda_hist_in_leaf, grad_scale, hess_scale); + CopyFromCUDADeviceToHost(root_sum_hessians, cuda_sum_of_hessians_buffer_.RawData(), 1, __FILE__, __LINE__); SynchronizeCUDADevice(__FILE__, __LINE__); } void CUDALeafSplits::Resize(const data_size_t num_data) { - if (num_data > num_data_) { - DeallocateCUDAMemory(&cuda_sum_of_gradients_buffer_, __FILE__, __LINE__); - DeallocateCUDAMemory(&cuda_sum_of_hessians_buffer_, __FILE__, __LINE__); - num_blocks_init_from_gradients_ = (num_data + NUM_THRADS_PER_BLOCK_LEAF_SPLITS - 1) / NUM_THRADS_PER_BLOCK_LEAF_SPLITS; - AllocateCUDAMemory(&cuda_sum_of_gradients_buffer_, num_blocks_init_from_gradients_, __FILE__, __LINE__); - AllocateCUDAMemory(&cuda_sum_of_hessians_buffer_, num_blocks_init_from_gradients_, __FILE__, __LINE__); - } else { - num_blocks_init_from_gradients_ = (num_data + NUM_THRADS_PER_BLOCK_LEAF_SPLITS - 1) / NUM_THRADS_PER_BLOCK_LEAF_SPLITS; - } num_data_ = num_data; + num_blocks_init_from_gradients_ = (num_data + NUM_THRADS_PER_BLOCK_LEAF_SPLITS - 1) / NUM_THRADS_PER_BLOCK_LEAF_SPLITS; + cuda_sum_of_gradients_buffer_.Resize(static_cast(num_blocks_init_from_gradients_)); + cuda_sum_of_hessians_buffer_.Resize(static_cast(num_blocks_init_from_gradients_)); + cuda_sum_of_gradients_hessians_buffer_.Resize(static_cast(num_blocks_init_from_gradients_)); } } // namespace LightGBM diff --git a/src/treelearner/cuda/cuda_leaf_splits.cu b/src/treelearner/cuda/cuda_leaf_splits.cu index 29e42f67e..ae505ecd5 100644 --- a/src/treelearner/cuda/cuda_leaf_splits.cu +++ b/src/treelearner/cuda/cuda_leaf_splits.cu @@ -81,6 +81,90 @@ __global__ void CUDAInitValuesKernel2( } } +template +__global__ void CUDAInitValuesKernel3(const int16_t* cuda_gradients_and_hessians, + const data_size_t num_data, const data_size_t* cuda_bagging_data_indices, + double* cuda_sum_of_gradients, double* cuda_sum_of_hessians, int64_t* cuda_sum_of_hessians_hessians, + const score_t* grad_scale_pointer, const score_t* hess_scale_pointer) { + const score_t grad_scale = *grad_scale_pointer; + const score_t hess_scale = *hess_scale_pointer; + __shared__ int64_t shared_mem_buffer[32]; + const data_size_t data_index = static_cast(threadIdx.x + blockIdx.x * blockDim.x); + int64_t int_gradient = 0; + int64_t int_hessian = 0; + if (data_index < num_data) { + int_gradient = USE_INDICES ? cuda_gradients_and_hessians[2 * cuda_bagging_data_indices[data_index] + 1] : + cuda_gradients_and_hessians[2 * data_index + 1]; + int_hessian = USE_INDICES ? 
cuda_gradients_and_hessians[2 * cuda_bagging_data_indices[data_index]] : + cuda_gradients_and_hessians[2 * data_index]; + } + const int64_t block_sum_gradient = ShuffleReduceSum(int_gradient, shared_mem_buffer, blockDim.x); + __syncthreads(); + const int64_t block_sum_hessian = ShuffleReduceSum(int_hessian, shared_mem_buffer, blockDim.x); + if (threadIdx.x == 0) { + cuda_sum_of_gradients[blockIdx.x] = block_sum_gradient * grad_scale; + cuda_sum_of_hessians[blockIdx.x] = block_sum_hessian * hess_scale; + cuda_sum_of_hessians_hessians[blockIdx.x] = ((block_sum_gradient << 32) | block_sum_hessian); + } +} + +__global__ void CUDAInitValuesKernel4( + const double lambda_l1, + const double lambda_l2, + const int num_blocks_to_reduce, + double* cuda_sum_of_gradients, + double* cuda_sum_of_hessians, + int64_t* cuda_sum_of_gradients_hessians, + const data_size_t num_data, + const data_size_t* cuda_data_indices_in_leaf, + hist_t* cuda_hist_in_leaf, + CUDALeafSplitsStruct* cuda_struct) { + __shared__ double shared_mem_buffer[32]; + double thread_sum_of_gradients = 0.0f; + double thread_sum_of_hessians = 0.0f; + int64_t thread_sum_of_gradients_hessians = 0; + for (int block_index = static_cast(threadIdx.x); block_index < num_blocks_to_reduce; block_index += static_cast(blockDim.x)) { + thread_sum_of_gradients += cuda_sum_of_gradients[block_index]; + thread_sum_of_hessians += cuda_sum_of_hessians[block_index]; + thread_sum_of_gradients_hessians += cuda_sum_of_gradients_hessians[block_index]; + } + const double sum_of_gradients = ShuffleReduceSum(thread_sum_of_gradients, shared_mem_buffer, blockDim.x); + __syncthreads(); + const double sum_of_hessians = ShuffleReduceSum(thread_sum_of_hessians, shared_mem_buffer, blockDim.x); + __syncthreads(); + const double sum_of_gradients_hessians = ShuffleReduceSum( + thread_sum_of_gradients_hessians, + reinterpret_cast(shared_mem_buffer), + blockDim.x); + if (threadIdx.x == 0) { + cuda_sum_of_hessians[0] = sum_of_hessians; + cuda_struct->leaf_index = 0; + cuda_struct->sum_of_gradients = sum_of_gradients; + cuda_struct->sum_of_hessians = sum_of_hessians; + cuda_struct->sum_of_gradients_hessians = sum_of_gradients_hessians; + cuda_struct->num_data_in_leaf = num_data; + const bool use_l1 = lambda_l1 > 0.0f; + if (!use_l1) { + // no smoothing on root node + cuda_struct->gain = CUDALeafSplits::GetLeafGain(sum_of_gradients, sum_of_hessians, lambda_l1, lambda_l2, 0.0f, 0, 0.0f); + } else { + // no smoothing on root node + cuda_struct->gain = CUDALeafSplits::GetLeafGain(sum_of_gradients, sum_of_hessians, lambda_l1, lambda_l2, 0.0f, 0, 0.0f); + } + if (!use_l1) { + // no smoothing on root node + cuda_struct->leaf_value = + CUDALeafSplits::CalculateSplittedLeafOutput(sum_of_gradients, sum_of_hessians, lambda_l1, lambda_l2, 0.0f, 0, 0.0f); + } else { + // no smoothing on root node + cuda_struct->leaf_value = + CUDALeafSplits::CalculateSplittedLeafOutput(sum_of_gradients, sum_of_hessians, lambda_l1, lambda_l2, 0.0f, 0, 0.0f); + } + cuda_struct->data_indices_in_leaf = cuda_data_indices_in_leaf; + cuda_struct->hist_in_leaf = cuda_hist_in_leaf; + } +} + __global__ void InitValuesEmptyKernel(CUDALeafSplitsStruct* cuda_struct) { cuda_struct->leaf_index = -1; cuda_struct->sum_of_gradients = 0.0f; @@ -93,7 +177,7 @@ __global__ void InitValuesEmptyKernel(CUDALeafSplitsStruct* cuda_struct) { } void CUDALeafSplits::LaunchInitValuesEmptyKernel() { - InitValuesEmptyKernel<<<1, 1>>>(cuda_struct_); + InitValuesEmptyKernel<<<1, 1>>>(cuda_struct_.RawData()); } void 
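CUDAInitValuesKernel3 and CUDAInitValuesKernel4 compute the root-leaf statistics directly from the discretized gradients: the interleaved int16 pairs (hessian at index 2*i, gradient at 2*i+1) are summed as 64-bit integers, the real-valued sums are recovered with the discretizer's grad_scale and hess_scale, and the integer sums are kept packed for quantized split evaluation. A single-threaded CPU restatement is sketched below; ReduceRoot and RootSums are hypothetical names, and the packed value follows the per-block convention, assuming the integer hessian sum fits in the low 32 bits.

    #include <cstdint>
    #include <vector>

    struct RootSums {
      double sum_gradients;
      double sum_hessians;
      int64_t packed_sums;  // (integer gradient sum) << 32 | (integer hessian sum)
    };

    RootSums ReduceRoot(const std::vector<int16_t>& grad_hess_interleaved,
                        float grad_scale, float hess_scale) {
      int64_t g = 0, h = 0;
      for (size_t i = 0; i + 1 < grad_hess_interleaved.size(); i += 2) {
        h += grad_hess_interleaved[i];      // even slots hold hessians
        g += grad_hess_interleaved[i + 1];  // odd slots hold gradients
      }
      return RootSums{g * static_cast<double>(grad_scale),
                      h * static_cast<double>(hess_scale),
                      (g << 32) | h};
    }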
CUDALeafSplits::LaunchInitValuesKernal( @@ -104,23 +188,55 @@ void CUDALeafSplits::LaunchInitValuesKernal( hist_t* cuda_hist_in_leaf) { if (cuda_bagging_data_indices == nullptr) { CUDAInitValuesKernel1<<>>( - cuda_gradients_, cuda_hessians_, num_used_indices, nullptr, cuda_sum_of_gradients_buffer_, - cuda_sum_of_hessians_buffer_); + cuda_gradients_, cuda_hessians_, num_used_indices, nullptr, cuda_sum_of_gradients_buffer_.RawData(), + cuda_sum_of_hessians_buffer_.RawData()); } else { CUDAInitValuesKernel1<<>>( - cuda_gradients_, cuda_hessians_, num_used_indices, cuda_bagging_data_indices, cuda_sum_of_gradients_buffer_, - cuda_sum_of_hessians_buffer_); + cuda_gradients_, cuda_hessians_, num_used_indices, cuda_bagging_data_indices, cuda_sum_of_gradients_buffer_.RawData(), + cuda_sum_of_hessians_buffer_.RawData()); } SynchronizeCUDADevice(__FILE__, __LINE__); CUDAInitValuesKernel2<<<1, NUM_THRADS_PER_BLOCK_LEAF_SPLITS>>>( lambda_l1, lambda_l2, num_blocks_init_from_gradients_, - cuda_sum_of_gradients_buffer_, - cuda_sum_of_hessians_buffer_, + cuda_sum_of_gradients_buffer_.RawData(), + cuda_sum_of_hessians_buffer_.RawData(), num_used_indices, cuda_data_indices_in_leaf, cuda_hist_in_leaf, - cuda_struct_); + cuda_struct_.RawData()); + SynchronizeCUDADevice(__FILE__, __LINE__); +} + +void CUDALeafSplits::LaunchInitValuesKernal( + const double lambda_l1, const double lambda_l2, + const data_size_t* cuda_bagging_data_indices, + const data_size_t* cuda_data_indices_in_leaf, + const data_size_t num_used_indices, + hist_t* cuda_hist_in_leaf, + const score_t* grad_scale, + const score_t* hess_scale) { + if (cuda_bagging_data_indices == nullptr) { + CUDAInitValuesKernel3<<>>( + reinterpret_cast(cuda_gradients_), num_used_indices, nullptr, cuda_sum_of_gradients_buffer_.RawData(), + cuda_sum_of_hessians_buffer_.RawData(), cuda_sum_of_gradients_hessians_buffer_.RawData(), grad_scale, hess_scale); + } else { + CUDAInitValuesKernel3<<>>( + reinterpret_cast(cuda_gradients_), num_used_indices, cuda_bagging_data_indices, cuda_sum_of_gradients_buffer_.RawData(), + cuda_sum_of_hessians_buffer_.RawData(), cuda_sum_of_gradients_hessians_buffer_.RawData(), grad_scale, hess_scale); + } + + SynchronizeCUDADevice(__FILE__, __LINE__); + CUDAInitValuesKernel4<<<1, NUM_THRADS_PER_BLOCK_LEAF_SPLITS>>>( + lambda_l1, lambda_l2, + num_blocks_init_from_gradients_, + cuda_sum_of_gradients_buffer_.RawData(), + cuda_sum_of_hessians_buffer_.RawData(), + cuda_sum_of_gradients_hessians_buffer_.RawData(), + num_used_indices, + cuda_data_indices_in_leaf, + cuda_hist_in_leaf, + cuda_struct_.RawData()); SynchronizeCUDADevice(__FILE__, __LINE__); } diff --git a/src/treelearner/cuda/cuda_leaf_splits.hpp b/src/treelearner/cuda/cuda_leaf_splits.hpp index 769f956b9..33a9ea578 100644 --- a/src/treelearner/cuda/cuda_leaf_splits.hpp +++ b/src/treelearner/cuda/cuda_leaf_splits.hpp @@ -8,7 +8,7 @@ #ifdef USE_CUDA -#include +#include #include #include #include @@ -23,6 +23,7 @@ struct CUDALeafSplitsStruct { int leaf_index; double sum_of_gradients; double sum_of_hessians; + int64_t sum_of_gradients_hessians; data_size_t num_data_in_leaf; double gain; double leaf_value; @@ -36,7 +37,7 @@ class CUDALeafSplits { ~CUDALeafSplits(); - void Init(); + void Init(const bool use_quantized_grad); void InitValues( const double lambda_l1, const double lambda_l2, @@ -45,11 +46,19 @@ class CUDALeafSplits { const data_size_t* cuda_data_indices_in_leaf, const data_size_t num_used_indices, hist_t* cuda_hist_in_leaf, double* root_sum_hessians); + void InitValues( + 
const double lambda_l1, const double lambda_l2, + const int16_t* cuda_gradients_and_hessians, + const data_size_t* cuda_bagging_data_indices, + const data_size_t* cuda_data_indices_in_leaf, const data_size_t num_used_indices, + hist_t* cuda_hist_in_leaf, double* root_sum_hessians, + const score_t* grad_scale, const score_t* hess_scale); + void InitValues(); - const CUDALeafSplitsStruct* GetCUDAStruct() const { return cuda_struct_; } + const CUDALeafSplitsStruct* GetCUDAStruct() const { return cuda_struct_.RawDataReadOnly(); } - CUDALeafSplitsStruct* GetCUDAStructRef() { return cuda_struct_; } + CUDALeafSplitsStruct* GetCUDAStructRef() { return cuda_struct_.RawData(); } void Resize(const data_size_t num_data); @@ -140,14 +149,24 @@ class CUDALeafSplits { const data_size_t num_used_indices, hist_t* cuda_hist_in_leaf); + void LaunchInitValuesKernal( + const double lambda_l1, const double lambda_l2, + const data_size_t* cuda_bagging_data_indices, + const data_size_t* cuda_data_indices_in_leaf, + const data_size_t num_used_indices, + hist_t* cuda_hist_in_leaf, + const score_t* grad_scale, + const score_t* hess_scale); + // Host memory data_size_t num_data_; int num_blocks_init_from_gradients_; // CUDA memory, held by this object - CUDALeafSplitsStruct* cuda_struct_; - double* cuda_sum_of_gradients_buffer_; - double* cuda_sum_of_hessians_buffer_; + CUDAVector cuda_struct_; + CUDAVector cuda_sum_of_gradients_buffer_; + CUDAVector cuda_sum_of_hessians_buffer_; + CUDAVector cuda_sum_of_gradients_hessians_buffer_; // CUDA memory, held by other object const score_t* cuda_gradients_; diff --git a/src/treelearner/cuda/cuda_single_gpu_tree_learner.cpp b/src/treelearner/cuda/cuda_single_gpu_tree_learner.cpp index 1600f3767..59d3d3ca1 100644 --- a/src/treelearner/cuda/cuda_single_gpu_tree_learner.cpp +++ b/src/treelearner/cuda/cuda_single_gpu_tree_learner.cpp @@ -9,7 +9,7 @@ #include "cuda_single_gpu_tree_learner.hpp" #include -#include +#include #include #include #include @@ -39,13 +39,14 @@ void CUDASingleGPUTreeLearner::Init(const Dataset* train_data, bool is_constant_ SetCUDADevice(gpu_device_id_, __FILE__, __LINE__); cuda_smaller_leaf_splits_.reset(new CUDALeafSplits(num_data_)); - cuda_smaller_leaf_splits_->Init(); + cuda_smaller_leaf_splits_->Init(config_->use_quantized_grad); cuda_larger_leaf_splits_.reset(new CUDALeafSplits(num_data_)); - cuda_larger_leaf_splits_->Init(); + cuda_larger_leaf_splits_->Init(config_->use_quantized_grad); cuda_histogram_constructor_.reset(new CUDAHistogramConstructor(train_data_, config_->num_leaves, num_threads_, share_state_->feature_hist_offsets(), - config_->min_data_in_leaf, config_->min_sum_hessian_in_leaf, gpu_device_id_, config_->gpu_use_dp)); + config_->min_data_in_leaf, config_->min_sum_hessian_in_leaf, gpu_device_id_, config_->gpu_use_dp, + config_->use_quantized_grad, config_->num_grad_quant_bins)); cuda_histogram_constructor_->Init(train_data_, share_state_.get()); const auto& feature_hist_offsets = share_state_->feature_hist_offsets(); @@ -73,11 +74,19 @@ void CUDASingleGPUTreeLearner::Init(const Dataset* train_data, bool is_constant_ } AllocateBitset(); - cuda_leaf_gradient_stat_buffer_ = nullptr; - cuda_leaf_hessian_stat_buffer_ = nullptr; leaf_stat_buffer_size_ = 0; num_cat_threshold_ = 0; + if (config_->use_quantized_grad) { + cuda_leaf_gradient_stat_buffer_.Resize(config_->num_leaves); + cuda_leaf_hessian_stat_buffer_.Resize(config_->num_leaves); + cuda_gradient_discretizer_.reset(new CUDAGradientDiscretizer( + config_->num_grad_quant_bins, 
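// The header above replaces raw device pointers with CUDAVector<T> members that
// are handed to kernels through RawData()/RawDataReadOnly(). The class below is a
// deliberately simplified sketch of such an owning device-buffer wrapper; it is
// not LightGBM's CUDAVector and omits its error checking, streams, and copy
// helpers.
#include <cstddef>
#include <cuda_runtime.h>

template <typename T>
class DeviceVectorSketch {
 public:
  DeviceVectorSketch() = default;
  ~DeviceVectorSketch() { Clear(); }
  DeviceVectorSketch(const DeviceVectorSketch&) = delete;
  DeviceVectorSketch& operator=(const DeviceVectorSketch&) = delete;

  void Resize(size_t size) {     // reallocating resize; old contents are discarded here
    Clear();
    if (size > 0) {
      cudaMalloc(reinterpret_cast<void**>(&data_), size * sizeof(T));
      size_ = size;
    }
  }

  void SetValue(int value) {     // byte-wise fill, e.g. SetValue(0) to zero the buffer
    if (data_ != nullptr) {
      cudaMemset(data_, value, size_ * sizeof(T));
    }
  }

  T* RawData() { return data_; }
  const T* RawDataReadOnly() const { return data_; }
  size_t Size() const { return size_; }

 private:
  void Clear() {
    if (data_ != nullptr) {
      cudaFree(data_);
      data_ = nullptr;
    }
    size_ = 0;
  }

  T* data_ = nullptr;
  size_t size_ = 0;
};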
config_->num_iterations, config_->seed, is_constant_hessian, config_->stochastic_rounding)); + cuda_gradient_discretizer_->Init(num_data_, config_->num_leaves, train_data_->num_features(), train_data_); + } else { + cuda_gradient_discretizer_.reset(nullptr); + } + #ifdef DEBUG host_gradients_.resize(num_data_, 0.0f); host_hessians_.resize(num_data_, 0.0f); @@ -101,19 +110,37 @@ void CUDASingleGPUTreeLearner::BeforeTrain() { const data_size_t* leaf_splits_init_indices = cuda_data_partition_->use_bagging() ? cuda_data_partition_->cuda_data_indices() : nullptr; cuda_data_partition_->BeforeTrain(); - cuda_smaller_leaf_splits_->InitValues( - config_->lambda_l1, - config_->lambda_l2, - gradients_, - hessians_, - leaf_splits_init_indices, - cuda_data_partition_->cuda_data_indices(), - root_num_data, - cuda_histogram_constructor_->cuda_hist_pointer(), - &leaf_sum_hessians_[0]); + if (config_->use_quantized_grad) { + cuda_gradient_discretizer_->DiscretizeGradients(num_data_, gradients_, hessians_); + cuda_histogram_constructor_->BeforeTrain( + reinterpret_cast(cuda_gradient_discretizer_->discretized_gradients_and_hessians()), nullptr); + cuda_smaller_leaf_splits_->InitValues( + config_->lambda_l1, + config_->lambda_l2, + reinterpret_cast(cuda_gradient_discretizer_->discretized_gradients_and_hessians()), + leaf_splits_init_indices, + cuda_data_partition_->cuda_data_indices(), + root_num_data, + cuda_histogram_constructor_->cuda_hist_pointer(), + &leaf_sum_hessians_[0], + cuda_gradient_discretizer_->grad_scale_ptr(), + cuda_gradient_discretizer_->hess_scale_ptr()); + cuda_gradient_discretizer_->SetNumBitsInHistogramBin(0, -1, root_num_data, 0); + } else { + cuda_histogram_constructor_->BeforeTrain(gradients_, hessians_); + cuda_smaller_leaf_splits_->InitValues( + config_->lambda_l1, + config_->lambda_l2, + gradients_, + hessians_, + leaf_splits_init_indices, + cuda_data_partition_->cuda_data_indices(), + root_num_data, + cuda_histogram_constructor_->cuda_hist_pointer(), + &leaf_sum_hessians_[0]); + } leaf_num_data_[0] = root_num_data; cuda_larger_leaf_splits_->InitValues(); - cuda_histogram_constructor_->BeforeTrain(gradients_, hessians_); col_sampler_.ResetByTree(); cuda_best_split_finder_->BeforeTrain(col_sampler_.is_feature_used_bytree()); leaf_data_start_[0] = 0; @@ -141,24 +168,70 @@ Tree* CUDASingleGPUTreeLearner::Train(const score_t* gradients, const data_size_t num_data_in_larger_leaf = larger_leaf_index_ < 0 ? 0 : leaf_num_data_[larger_leaf_index_]; const double sum_hessians_in_smaller_leaf = leaf_sum_hessians_[smaller_leaf_index_]; const double sum_hessians_in_larger_leaf = larger_leaf_index_ < 0 ? 0 : leaf_sum_hessians_[larger_leaf_index_]; + const uint8_t num_bits_in_histogram_bins = config_->use_quantized_grad ? 
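// BeforeTrain above first discretizes the float gradients and hessians into small
// integers (recovered later through grad_scale/hess_scale) before building
// histograms. The host-side sketch below illustrates the idea, symmetric
// quantization with optional stochastic rounding, using a made-up function name
// (QuantizeGradients) and a simplified scale rule; it is not the exact scheme
// implemented by CUDAGradientDiscretizer.
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <random>
#include <vector>

void QuantizeGradients(const std::vector<float>& gradients, int num_grad_quant_bins,
                       bool stochastic_rounding, std::vector<int8_t>* quantized,
                       double* grad_scale) {
  double max_abs = 0.0;
  for (float g : gradients) {
    max_abs = std::max(max_abs, std::fabs(static_cast<double>(g)));
  }
  // Map [-max_abs, max_abs] onto roughly num_grad_quant_bins integer levels.
  *grad_scale = (max_abs > 0.0) ? max_abs / (num_grad_quant_bins / 2) : 1.0;
  std::mt19937 rng(42);  // fixed seed here only for reproducibility of the sketch
  std::uniform_real_distribution<double> unif(0.0, 1.0);
  quantized->resize(gradients.size());
  for (size_t i = 0; i < gradients.size(); ++i) {
    const double scaled = gradients[i] / *grad_scale;
    double rounded;
    if (stochastic_rounding) {
      // Round up with probability equal to the fractional part, so the
      // quantization error is zero in expectation.
      const double floor_val = std::floor(scaled);
      rounded = floor_val + ((unif(rng) < scaled - floor_val) ? 1.0 : 0.0);
    } else {
      rounded = std::round(scaled);
    }
    (*quantized)[i] = static_cast<int8_t>(rounded);
  }
}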
cuda_gradient_discretizer_->GetHistBitsInLeaf(smaller_leaf_index_) : 0; cuda_histogram_constructor_->ConstructHistogramForLeaf( cuda_smaller_leaf_splits_->GetCUDAStruct(), cuda_larger_leaf_splits_->GetCUDAStruct(), num_data_in_smaller_leaf, num_data_in_larger_leaf, sum_hessians_in_smaller_leaf, - sum_hessians_in_larger_leaf); + sum_hessians_in_larger_leaf, + num_bits_in_histogram_bins); global_timer.Stop("CUDASingleGPUTreeLearner::ConstructHistogramForLeaf"); global_timer.Start("CUDASingleGPUTreeLearner::FindBestSplitsForLeaf"); - SelectFeatureByNode(tree.get()); - - cuda_best_split_finder_->FindBestSplitsForLeaf( + uint8_t parent_num_bits_bin = 0; + uint8_t smaller_num_bits_bin = 0; + uint8_t larger_num_bits_bin = 0; + if (config_->use_quantized_grad) { + if (larger_leaf_index_ != -1) { + const int parent_leaf_index = std::min(smaller_leaf_index_, larger_leaf_index_); + parent_num_bits_bin = cuda_gradient_discretizer_->GetHistBitsInNode(parent_leaf_index); + smaller_num_bits_bin = cuda_gradient_discretizer_->GetHistBitsInLeaf(smaller_leaf_index_); + larger_num_bits_bin = cuda_gradient_discretizer_->GetHistBitsInLeaf(larger_leaf_index_); + } else { + parent_num_bits_bin = cuda_gradient_discretizer_->GetHistBitsInLeaf(0); + smaller_num_bits_bin = cuda_gradient_discretizer_->GetHistBitsInLeaf(0); + larger_num_bits_bin = cuda_gradient_discretizer_->GetHistBitsInLeaf(0); + } + } else { + parent_num_bits_bin = 0; + smaller_num_bits_bin = 0; + larger_num_bits_bin = 0; + } + cuda_histogram_constructor_->SubtractHistogramForLeaf( cuda_smaller_leaf_splits_->GetCUDAStruct(), cuda_larger_leaf_splits_->GetCUDAStruct(), - smaller_leaf_index_, larger_leaf_index_, - num_data_in_smaller_leaf, num_data_in_larger_leaf, - sum_hessians_in_smaller_leaf, sum_hessians_in_larger_leaf); + config_->use_quantized_grad, + parent_num_bits_bin, + smaller_num_bits_bin, + larger_num_bits_bin); + + SelectFeatureByNode(tree.get()); + + if (config_->use_quantized_grad) { + const uint8_t smaller_leaf_num_bits_bin = cuda_gradient_discretizer_->GetHistBitsInLeaf(smaller_leaf_index_); + const uint8_t larger_leaf_num_bits_bin = larger_leaf_index_ < 0 ? 32 : cuda_gradient_discretizer_->GetHistBitsInLeaf(larger_leaf_index_); + cuda_best_split_finder_->FindBestSplitsForLeaf( + cuda_smaller_leaf_splits_->GetCUDAStruct(), + cuda_larger_leaf_splits_->GetCUDAStruct(), + smaller_leaf_index_, larger_leaf_index_, + num_data_in_smaller_leaf, num_data_in_larger_leaf, + sum_hessians_in_smaller_leaf, sum_hessians_in_larger_leaf, + cuda_gradient_discretizer_->grad_scale_ptr(), + cuda_gradient_discretizer_->hess_scale_ptr(), + smaller_leaf_num_bits_bin, + larger_leaf_num_bits_bin); + } else { + cuda_best_split_finder_->FindBestSplitsForLeaf( + cuda_smaller_leaf_splits_->GetCUDAStruct(), + cuda_larger_leaf_splits_->GetCUDAStruct(), + smaller_leaf_index_, larger_leaf_index_, + num_data_in_smaller_leaf, num_data_in_larger_leaf, + sum_hessians_in_smaller_leaf, sum_hessians_in_larger_leaf, + nullptr, nullptr, 0, 0); + } + global_timer.Stop("CUDASingleGPUTreeLearner::FindBestSplitsForLeaf"); global_timer.Start("CUDASingleGPUTreeLearner::FindBestFromAllSplits"); const CUDASplitInfo* best_split_info = nullptr; @@ -247,9 +320,19 @@ Tree* CUDASingleGPUTreeLearner::Train(const score_t* gradients, #endif // DEBUG smaller_leaf_index_ = (leaf_num_data_[best_leaf_index_] < leaf_num_data_[right_leaf_index] ? best_leaf_index_ : right_leaf_index); larger_leaf_index_ = (smaller_leaf_index_ == best_leaf_index_ ? 
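// GetHistBitsInLeaf/SetNumBitsInHistogramBin above choose a per-leaf histogram
// entry width (16-bit versus 32-bit packed integer bins) so the quantized per-bin
// sums cannot overflow. The sketch below shows one plausible selection rule based
// on the worst-case accumulated magnitude; ChooseHistBits is an illustrative name
// and the real LightGBM rule may differ in its exact threshold.
#include <cstdint>

inline uint8_t ChooseHistBits(int64_t num_data_in_leaf, int num_grad_quant_bins) {
  // Worst case: every row in the leaf contributes the largest quantized
  // gradient magnitude to the same histogram bin.
  const int64_t worst_case_abs_sum = num_data_in_leaf * (num_grad_quant_bins / 2);
  const int64_t int16_limit = (1LL << 15) - 1;
  return worst_case_abs_sum <= int16_limit ? 16 : 32;
}

// Under a rule like this, histogram subtraction has to cope with a parent
// accumulated at one width and children at another, which is why the subtraction
// call above receives parent_num_bits_bin, smaller_num_bits_bin, and
// larger_num_bits_bin separately.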
right_leaf_index : best_leaf_index_); + + if (config_->use_quantized_grad) { + cuda_gradient_discretizer_->SetNumBitsInHistogramBin( + best_leaf_index_, right_leaf_index, leaf_num_data_[best_leaf_index_], leaf_num_data_[right_leaf_index]); + } global_timer.Stop("CUDASingleGPUTreeLearner::Split"); } SynchronizeCUDADevice(__FILE__, __LINE__); + if (config_->use_quantized_grad && config_->quant_train_renew_leaf) { + global_timer.Start("CUDASingleGPUTreeLearner::RenewDiscretizedTreeLeaves"); + RenewDiscretizedTreeLeaves(tree.get()); + global_timer.Stop("CUDASingleGPUTreeLearner::RenewDiscretizedTreeLeaves"); + } tree->ToHost(); return tree.release(); } @@ -357,8 +440,8 @@ void CUDASingleGPUTreeLearner::RenewTreeOutput(Tree* tree, const ObjectiveFuncti Tree* CUDASingleGPUTreeLearner::FitByExistingTree(const Tree* old_tree, const score_t* gradients, const score_t* hessians) const { std::unique_ptr cuda_tree(new CUDATree(old_tree)); - SetCUDAMemory(cuda_leaf_gradient_stat_buffer_, 0, static_cast(old_tree->num_leaves()), __FILE__, __LINE__); - SetCUDAMemory(cuda_leaf_hessian_stat_buffer_, 0, static_cast(old_tree->num_leaves()), __FILE__, __LINE__); + cuda_leaf_gradient_stat_buffer_.SetValue(0); + cuda_leaf_hessian_stat_buffer_.SetValue(0); ReduceLeafStat(cuda_tree.get(), gradients, hessians, cuda_data_partition_->cuda_data_indices()); cuda_tree->SyncLeafOutputFromCUDAToHost(); return cuda_tree.release(); @@ -373,13 +456,9 @@ Tree* CUDASingleGPUTreeLearner::FitByExistingTree(const Tree* old_tree, const st const int num_block = (refit_num_data_ + CUDA_SINGLE_GPU_TREE_LEARNER_BLOCK_SIZE - 1) / CUDA_SINGLE_GPU_TREE_LEARNER_BLOCK_SIZE; buffer_size *= static_cast(num_block + 1); } - if (buffer_size != leaf_stat_buffer_size_) { - if (leaf_stat_buffer_size_ != 0) { - DeallocateCUDAMemory(&cuda_leaf_gradient_stat_buffer_, __FILE__, __LINE__); - DeallocateCUDAMemory(&cuda_leaf_hessian_stat_buffer_, __FILE__, __LINE__); - } - AllocateCUDAMemory(&cuda_leaf_gradient_stat_buffer_, static_cast(buffer_size), __FILE__, __LINE__); - AllocateCUDAMemory(&cuda_leaf_hessian_stat_buffer_, static_cast(buffer_size), __FILE__, __LINE__); + if (static_cast(buffer_size) > cuda_leaf_gradient_stat_buffer_.Size()) { + cuda_leaf_gradient_stat_buffer_.Resize(buffer_size); + cuda_leaf_hessian_stat_buffer_.Resize(buffer_size); } return FitByExistingTree(old_tree, gradients, hessians); } @@ -513,6 +592,15 @@ void CUDASingleGPUTreeLearner::CheckSplitValid( } #endif // DEBUG +void CUDASingleGPUTreeLearner::RenewDiscretizedTreeLeaves(CUDATree* cuda_tree) { + cuda_data_partition_->ReduceLeafGradStat( + gradients_, hessians_, cuda_tree, + cuda_leaf_gradient_stat_buffer_.RawData(), + cuda_leaf_hessian_stat_buffer_.RawData()); + LaunchCalcLeafValuesGivenGradStat(cuda_tree, cuda_data_partition_->cuda_data_indices()); + SynchronizeCUDADevice(__FILE__, __LINE__); +} + } // namespace LightGBM #endif // USE_CUDA diff --git a/src/treelearner/cuda/cuda_single_gpu_tree_learner.cu b/src/treelearner/cuda/cuda_single_gpu_tree_learner.cu index 8a558ddc4..670f1f36d 100644 --- a/src/treelearner/cuda/cuda_single_gpu_tree_learner.cu +++ b/src/treelearner/cuda/cuda_single_gpu_tree_learner.cu @@ -129,18 +129,18 @@ void CUDASingleGPUTreeLearner::LaunchReduceLeafStatKernel( if (num_leaves <= 2048) { ReduceLeafStatKernel_SharedMemory<<>>( gradients, hessians, num_leaves, num_data, cuda_data_partition_->cuda_data_index_to_leaf_index(), - cuda_leaf_gradient_stat_buffer_, cuda_leaf_hessian_stat_buffer_); + cuda_leaf_gradient_stat_buffer_.RawData(), 
cuda_leaf_hessian_stat_buffer_.RawData()); } else { ReduceLeafStatKernel_GlobalMemory<<>>( gradients, hessians, num_leaves, num_data, cuda_data_partition_->cuda_data_index_to_leaf_index(), - cuda_leaf_gradient_stat_buffer_, cuda_leaf_hessian_stat_buffer_); + cuda_leaf_gradient_stat_buffer_.RawData(), cuda_leaf_hessian_stat_buffer_.RawData()); } const bool use_l1 = config_->lambda_l1 > 0.0f; const bool use_smoothing = config_->path_smooth > 0.0f; num_block = (num_leaves + CUDA_SINGLE_GPU_TREE_LEARNER_BLOCK_SIZE - 1) / CUDA_SINGLE_GPU_TREE_LEARNER_BLOCK_SIZE; #define CalcRefitLeafOutputKernel_ARGS \ - num_leaves, cuda_leaf_gradient_stat_buffer_, cuda_leaf_hessian_stat_buffer_, num_data_in_leaf, \ + num_leaves, cuda_leaf_gradient_stat_buffer_.RawData(), cuda_leaf_hessian_stat_buffer_.RawData(), num_data_in_leaf, \ leaf_parent, left_child, right_child, \ config_->lambda_l1, config_->lambda_l2, config_->path_smooth, \ shrinkage_rate, config_->refit_decay_rate, cuda_leaf_value @@ -162,6 +162,7 @@ void CUDASingleGPUTreeLearner::LaunchReduceLeafStatKernel( <<>>(CalcRefitLeafOutputKernel_ARGS); } } + #undef CalcRefitLeafOutputKernel_ARGS } template @@ -256,6 +257,37 @@ void CUDASingleGPUTreeLearner::LaunchConstructBitsetForCategoricalSplitKernel( CUDAConstructBitset(best_split_info, num_cat_threshold_, cuda_bitset_, cuda_bitset_len_); } +void CUDASingleGPUTreeLearner::LaunchCalcLeafValuesGivenGradStat( + CUDATree* cuda_tree, const data_size_t* num_data_in_leaf) { + #define CalcRefitLeafOutputKernel_ARGS \ + cuda_tree->num_leaves(), cuda_leaf_gradient_stat_buffer_.RawData(), cuda_leaf_hessian_stat_buffer_.RawData(), num_data_in_leaf, \ + cuda_tree->cuda_leaf_parent(), cuda_tree->cuda_left_child(), cuda_tree->cuda_right_child(), \ + config_->lambda_l1, config_->lambda_l2, config_->path_smooth, \ + 1.0f, config_->refit_decay_rate, cuda_tree->cuda_leaf_value_ref() + const bool use_l1 = config_->lambda_l1 > 0.0f; + const bool use_smoothing = config_->path_smooth > 0.0f; + const int num_block = (cuda_tree->num_leaves() + CUDA_SINGLE_GPU_TREE_LEARNER_BLOCK_SIZE - 1) / CUDA_SINGLE_GPU_TREE_LEARNER_BLOCK_SIZE; + if (!use_l1) { + if (!use_smoothing) { + CalcRefitLeafOutputKernel + <<>>(CalcRefitLeafOutputKernel_ARGS); + } else { + CalcRefitLeafOutputKernel + <<>>(CalcRefitLeafOutputKernel_ARGS); + } + } else { + if (!use_smoothing) { + CalcRefitLeafOutputKernel + <<>>(CalcRefitLeafOutputKernel_ARGS); + } else { + CalcRefitLeafOutputKernel + <<>>(CalcRefitLeafOutputKernel_ARGS); + } + } + + #undef CalcRefitLeafOutputKernel_ARGS +} + } // namespace LightGBM #endif // USE_CUDA diff --git a/src/treelearner/cuda/cuda_single_gpu_tree_learner.hpp b/src/treelearner/cuda/cuda_single_gpu_tree_learner.hpp index 576d01ffe..fa782ebaa 100644 --- a/src/treelearner/cuda/cuda_single_gpu_tree_learner.hpp +++ b/src/treelearner/cuda/cuda_single_gpu_tree_learner.hpp @@ -16,6 +16,7 @@ #include "cuda_data_partition.hpp" #include "cuda_best_split_finder.hpp" +#include "cuda_gradient_discretizer.hpp" #include "../serial_tree_learner.h" namespace LightGBM { @@ -74,6 +75,10 @@ class CUDASingleGPUTreeLearner: public SerialTreeLearner { const double sum_left_gradients, const double sum_right_gradients); #endif // DEBUG + void RenewDiscretizedTreeLeaves(CUDATree* cuda_tree); + + void LaunchCalcLeafValuesGivenGradStat(CUDATree* cuda_tree, const data_size_t* num_data_in_leaf); + // GPU device ID int gpu_device_id_; // number of threads on CPU @@ -90,6 +95,8 @@ class CUDASingleGPUTreeLearner: public SerialTreeLearner { std::unique_ptr 
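// CalcRefitLeafOutputKernel above recomputes leaf values from the reduced
// per-leaf gradient and hessian sums, dispatching on USE_L1/USE_SMOOTHING. The
// sketch below shows the core closed form for one leaf, L1 soft-thresholding of
// the gradient sum followed by division by the L2-regularized hessian sum, and
// deliberately leaves out the path smoothing, shrinkage, and refit decay handled
// in the kernel; ThresholdL1Sketch/LeafOutputSketch are illustrative names.
#include <algorithm>
#include <cmath>

inline double ThresholdL1Sketch(double sum_gradients, double lambda_l1) {
  const double thresholded_abs = std::max(0.0, std::fabs(sum_gradients) - lambda_l1);
  return std::copysign(thresholded_abs, sum_gradients);
}

inline double LeafOutputSketch(double sum_gradients, double sum_hessians,
                               double lambda_l1, double lambda_l2) {
  if (sum_hessians + lambda_l2 <= 0.0) {
    return 0.0;
  }
  return -ThresholdL1Sketch(sum_gradients, lambda_l1) / (sum_hessians + lambda_l2);
}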
cuda_histogram_constructor_;
   // for best split information finding, given the histograms
   std::unique_ptr<CUDABestSplitFinder> cuda_best_split_finder_;
+  // gradient discretizer for quantized training
+  std::unique_ptr<CUDAGradientDiscretizer> cuda_gradient_discretizer_;
 
   std::vector<int> leaf_best_split_feature_;
   std::vector<uint32_t> leaf_best_split_threshold_;
@@ -108,8 +115,8 @@ class CUDASingleGPUTreeLearner: public SerialTreeLearner {
   std::vector<int> categorical_bin_to_value_;
   std::vector<int> categorical_bin_offsets_;
 
-  mutable double* cuda_leaf_gradient_stat_buffer_;
-  mutable double* cuda_leaf_hessian_stat_buffer_;
+  mutable CUDAVector<double> cuda_leaf_gradient_stat_buffer_;
+  mutable CUDAVector<double> cuda_leaf_hessian_stat_buffer_;
   mutable data_size_t leaf_stat_buffer_size_;
   mutable data_size_t refit_num_data_;
   uint32_t* cuda_bitset_;
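// The gradient/hessian stat buffers above are declared mutable CUDAVector<double>
// so that the const FitByExistingTree path can lazily grow and zero them. The
// sketch below isolates that caching pattern with a made-up LeafStatCache class
// and plain CUDA runtime calls; it is not the LightGBM implementation.
#include <cstddef>
#include <cuda_runtime.h>

class LeafStatCache {
 public:
  ~LeafStatCache() {
    if (buffer_ != nullptr) {
      cudaFree(buffer_);
    }
  }

  // Callable from const member functions of the owner because the cached buffer
  // is an implementation detail rather than observable state.
  double* PreparedBuffer(size_t required_size) const {
    if (required_size == 0) {
      return buffer_;
    }
    if (required_size > capacity_) {  // grow only when the request exceeds capacity
      if (buffer_ != nullptr) {
        cudaFree(buffer_);
      }
      cudaMalloc(reinterpret_cast<void**>(&buffer_), required_size * sizeof(double));
      capacity_ = required_size;
    }
    cudaMemset(buffer_, 0, required_size * sizeof(double));  // hand back zeroed stats
    return buffer_;
  }

 private:
  mutable double* buffer_ = nullptr;
  mutable size_t capacity_ = 0;
};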