Mirror of https://github.com/microsoft/LightGBM.git
* add quantized training (first stage)
* add histogram construction functions for integer gradients
* add stochastic rounding
* update docs
* fix compilation errors by adding template instantiations
* update files for compilation
* fix compilation of gpu version
* initialize gradient discretizer before share states
* add a test case for quantized training
* add quantized training for data distributed training
* Delete origin.pred
* Delete ifelse.pred
* Delete LightGBM_model.txt
* remove useless changes
* fix lint error
* remove debug loggings
* fix mismatch of vector and allocator types
* remove changes in main.cpp
* fix bugs with uninitialized gradient discretizer
* initialize ordered gradients in gradient discretizer
* disable quantized training with gpu and cuda
* fix msvc compilation errors and warnings
* fix bug in data parallel tree learner
* make quantized training test deterministic
* make quantized training in test case more accurate
* refactor test_quantized_training
* fix leaf splits initialization with quantized training
* check distributed quantized training result
* add cuda gradient discretizer
* add quantized training for CUDA version in tree learner
* remove cuda computability 6.1 and 6.2
* fix parts of gpu quantized training errors and warnings
* fix build-python.sh to install locally built version
* fix memory access bugs
* fix lint errors
* mark cuda quantized training on cuda with categorical features as unsupported
* rename cuda_utils.h to cuda_utils.hu
* enable quantized training with cuda
* fix cuda quantized training with sparse row data
* allow using global memory buffer in histogram construction with cuda quantized training
* recover build-python.sh
* enlarge allowed package size to 100M
This commit is contained in:
Parent: 3d9ada7657
Commit: f901f47141
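The squashed commit message above lists the main pieces of the feature: gradient quantization, histogram construction over integer gradients, and stochastic rounding. As a rough, standalone C++ illustration of the stochastic-rounding idea (not code from this commit; the bin count and gradient values are made up):

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <random>
#include <vector>

int main() {
  const int num_grad_quant_bins = 4;  // assumed bin count, for illustration only
  std::vector<double> gradients = {0.81, -0.33, 0.07, -0.95};
  double abs_max = 0.0;
  for (double g : gradients) abs_max = std::max(abs_max, std::fabs(g));
  const double scale = abs_max / (num_grad_quant_bins / 2);  // width of one integer bin
  std::mt19937 rng(42);
  std::uniform_real_distribution<double> unif(0.0, 1.0);
  for (double g : gradients) {
    const double scaled = g / scale;
    // stochastic rounding: the magnitude is rounded up with probability equal to
    // its fractional part, so the quantized gradient is unbiased in expectation
    const double r = unif(rng);
    const int16_t q = static_cast<int16_t>(g > 0.0 ? scaled + r : scaled - r);
    printf("g=%6.2f -> q=%d (dequantized %.2f)\n", g, q, q * scale);
  }
  return 0;
}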
@@ -25,7 +25,7 @@ if [ $PY_MINOR_VER -gt 7 ]; then
pydistcheck \
--inspect \
--ignore 'compiled-objects-have-debug-symbols,distro-too-large-compressed' \
-    --max-allowed-size-uncompressed '70M' \
+    --max-allowed-size-uncompressed '100M' \
--max-allowed-files 800 \
${DIST_DIR}/* || exit -1
elif { test $(uname -m) = "aarch64"; }; then

@@ -13,7 +13,7 @@
#include <stdio.h>

#include <LightGBM/bin.h>
-#include <LightGBM/cuda/cuda_utils.h>
+#include <LightGBM/cuda/cuda_utils.hu>
#include <LightGBM/utils/log.h>

#include <algorithm>

@@ -9,7 +9,7 @@
#define LIGHTGBM_CUDA_CUDA_COLUMN_DATA_HPP_

#include <LightGBM/config.h>
-#include <LightGBM/cuda/cuda_utils.h>
+#include <LightGBM/cuda/cuda_utils.hu>
#include <LightGBM/bin.h>
#include <LightGBM/utils/openmp_wrapper.h>

@@ -8,7 +8,7 @@
#ifndef LIGHTGBM_CUDA_CUDA_METADATA_HPP_
#define LIGHTGBM_CUDA_CUDA_METADATA_HPP_

-#include <LightGBM/cuda/cuda_utils.h>
+#include <LightGBM/cuda/cuda_utils.hu>
#include <LightGBM/meta.h>

#include <vector>

@@ -9,7 +9,7 @@

#ifdef USE_CUDA

-#include <LightGBM/cuda/cuda_utils.h>
+#include <LightGBM/cuda/cuda_utils.hu>
#include <LightGBM/metric.h>

namespace LightGBM {

@@ -9,7 +9,7 @@

#ifdef USE_CUDA

-#include <LightGBM/cuda/cuda_utils.h>
+#include <LightGBM/cuda/cuda_utils.hu>
#include <LightGBM/objective_function.h>
#include <LightGBM/meta.h>

@@ -10,7 +10,7 @@

#include <LightGBM/bin.h>
#include <LightGBM/config.h>
-#include <LightGBM/cuda/cuda_utils.h>
+#include <LightGBM/cuda/cuda_utils.hu>
#include <LightGBM/dataset.h>
#include <LightGBM/train_share_states.h>
#include <LightGBM/utils/openmp_wrapper.h>

@@ -24,12 +24,14 @@ class CUDASplitInfo {

double left_sum_gradients;
double left_sum_hessians;
+int64_t left_sum_of_gradients_hessians;
data_size_t left_count;
double left_gain;
double left_value;

double right_sum_gradients;
double right_sum_hessians;
+int64_t right_sum_of_gradients_hessians;
data_size_t right_count;
double right_gain;
double right_value;

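The new left_sum_of_gradients_hessians / right_sum_of_gradients_hessians fields carry a quantized gradient sum and hessian sum packed into one 64-bit integer (gradient in the high 32 bits, hessian in the low 32 bits), which the kernels later in this diff decode with shifts and masks. A minimal standalone C++ sketch of that packing, with made-up values (the kernels themselves work on signed int64_t; unsigned packing is used here to keep the sketch free of shift edge cases):

#include <cstdint>
#include <cstdio>

int main() {
  const int32_t grad_sum = -1234;  // quantized gradient sum (illustrative)
  const int32_t hess_sum = 5678;   // quantized hessian sum (illustrative)
  // pack: gradient in the high 32 bits, hessian in the low 32 bits
  const uint64_t packed =
      (static_cast<uint64_t>(static_cast<uint32_t>(grad_sum)) << 32) |
      static_cast<uint32_t>(hess_sum);
  // unpack with the same split the CUDA kernels use
  const int32_t grad_back = static_cast<int32_t>(packed >> 32);
  const int32_t hess_back = static_cast<int32_t>(packed & 0xffffffffULL);
  printf("grad=%d hess=%d\n", grad_back, hess_back);
  return 0;
}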
@@ -7,15 +7,21 @@
#define LIGHTGBM_CUDA_CUDA_UTILS_H_

#ifdef USE_CUDA

#include <cuda.h>
#include <cuda_runtime.h>
#include <stdio.h>

#include <LightGBM/utils/log.h>

#include <algorithm>
#include <vector>
#include <cmath>

namespace LightGBM {

typedef unsigned long long atomic_add_long_t;

#define CUDASUCCESS_OR_FATAL(ans) { gpuAssert((ans), __FILE__, __LINE__); }
inline void gpuAssert(cudaError_t code, const char *file, int line, bool abort = true) {
if (code != cudaSuccess) {

@@ -125,13 +131,19 @@ class CUDAVector {
T* new_data = nullptr;
AllocateCUDAMemory<T>(&new_data, size, __FILE__, __LINE__);
if (size_ > 0 && data_ != nullptr) {
-  CopyFromCUDADeviceToCUDADevice<T>(new_data, data_, size, __FILE__, __LINE__);
+  const size_t size_for_old_content = std::min<size_t>(size_, size);
+  CopyFromCUDADeviceToCUDADevice<T>(new_data, data_, size_for_old_content, __FILE__, __LINE__);
}
DeallocateCUDAMemory<T>(&data_, __FILE__, __LINE__);
data_ = new_data;
size_ = size;
}

+void InitFromHostVector(const std::vector<T>& host_vector) {
+  Resize(host_vector.size());
+  CopyFromHostToCUDADevice(data_, host_vector.data(), host_vector.size(), __FILE__, __LINE__);
+}
+
void Clear() {
if (size_ > 0 && data_ != nullptr) {
DeallocateCUDAMemory<T>(&data_, __FILE__, __LINE__);

@@ -171,6 +183,10 @@ class CUDAVector {
return data_;
}

+void SetValue(int value) {
+  SetCUDAMemory<T>(data_, value, size_, __FILE__, __LINE__);
+}
+
const T* RawDataReadOnly() const {
return data_;
}

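For context, the Resize change above clamps the device-to-device copy to std::min(size_, size) so that growing a buffer no longer copies past the end of the old allocation. A plain-C++ analogue of the same rule (ResizeBuffer is a hypothetical helper written for this sketch, not LightGBM API):

#include <algorithm>
#include <cstdio>
#include <vector>

// Resizes *data to new_size, preserving only the elements that fit.
static void ResizeBuffer(std::vector<int>* data, size_t new_size) {
  std::vector<int> new_data(new_size, 0);
  const size_t size_for_old_content = std::min(data->size(), new_size);
  std::copy_n(data->begin(), size_for_old_content, new_data.begin());
  data->swap(new_data);
}

int main() {
  std::vector<int> buf = {1, 2, 3};
  ResizeBuffer(&buf, 5);  // grow: copies the 3 old elements, pads with zeros
  ResizeBuffer(&buf, 2);  // shrink: copies only the first 2 elements
  for (int v : buf) printf("%d ", v);
  printf("\n");
  return 0;
}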
@@ -6,7 +6,7 @@
#ifndef LIGHTGBM_SAMPLE_STRATEGY_H_
#define LIGHTGBM_SAMPLE_STRATEGY_H_

-#include <LightGBM/cuda/cuda_utils.h>
+#include <LightGBM/cuda/cuda_utils.hu>
#include <LightGBM/utils/random.h>
#include <LightGBM/utils/common.h>
#include <LightGBM/utils/threading.h>

@@ -8,7 +8,7 @@

#ifdef USE_CUDA

-#include <LightGBM/cuda/cuda_utils.h>
+#include <LightGBM/cuda/cuda_utils.hu>

#include "../score_updater.hpp"

@@ -5,7 +5,7 @@

#ifdef USE_CUDA

-#include <LightGBM/cuda/cuda_utils.h>
+#include <LightGBM/cuda/cuda_utils.hu>

namespace LightGBM {

@@ -389,10 +389,6 @@ void Config::CheckParamConflict() {
if (deterministic) {
  Log::Warning("Although \"deterministic\" is set, the results ran by GPU may be non-deterministic.");
}
-if (use_quantized_grad) {
-  Log::Warning("Quantized training is not supported by CUDA tree learner. Switch to full precision training.");
-  use_quantized_grad = false;
-}
}
// linear tree learner must be serial type and run on CPU device
if (linear_tree) {

@@ -10,7 +10,7 @@
#ifdef USE_CUDA

#include <LightGBM/cuda/cuda_metric.hpp>
-#include <LightGBM/cuda/cuda_utils.h>
+#include <LightGBM/cuda/cuda_utils.hu>

#include <vector>

@@ -10,7 +10,7 @@
#ifdef USE_CUDA

#include <LightGBM/cuda/cuda_metric.hpp>
-#include <LightGBM/cuda/cuda_utils.h>
+#include <LightGBM/cuda/cuda_utils.hu>

#include <vector>

@@ -10,7 +10,7 @@
#ifdef USE_CUDA

#include <LightGBM/cuda/cuda_metric.hpp>
-#include <LightGBM/cuda/cuda_utils.h>
+#include <LightGBM/cuda/cuda_utils.hu>

#include <vector>

@@ -40,6 +40,9 @@ CUDABestSplitFinder::CUDABestSplitFinder(
select_features_by_node_(select_features_by_node),
cuda_hist_(cuda_hist) {
InitFeatureMetaInfo(train_data);
+if (has_categorical_feature_ && config->use_quantized_grad) {
+  Log::Fatal("Quantized training on GPU with categorical features is not supported yet.");
+}
cuda_leaf_best_split_info_ = nullptr;
cuda_best_split_info_ = nullptr;
cuda_best_split_info_buffer_ = nullptr;

@@ -326,13 +329,23 @@ void CUDABestSplitFinder::FindBestSplitsForLeaf(
const data_size_t num_data_in_smaller_leaf,
const data_size_t num_data_in_larger_leaf,
const double sum_hessians_in_smaller_leaf,
-const double sum_hessians_in_larger_leaf) {
+const double sum_hessians_in_larger_leaf,
+const score_t* grad_scale,
+const score_t* hess_scale,
+const uint8_t smaller_num_bits_in_histogram_bins,
+const uint8_t larger_num_bits_in_histogram_bins) {
const bool is_smaller_leaf_valid = (num_data_in_smaller_leaf > min_data_in_leaf_ &&
  sum_hessians_in_smaller_leaf > min_sum_hessian_in_leaf_);
const bool is_larger_leaf_valid = (num_data_in_larger_leaf > min_data_in_leaf_ &&
  sum_hessians_in_larger_leaf > min_sum_hessian_in_leaf_ && larger_leaf_index >= 0);
+if (grad_scale != nullptr && hess_scale != nullptr) {
+  LaunchFindBestSplitsDiscretizedForLeafKernel(smaller_leaf_splits, larger_leaf_splits,
+    smaller_leaf_index, larger_leaf_index, is_smaller_leaf_valid, is_larger_leaf_valid,
+    grad_scale, hess_scale, smaller_num_bits_in_histogram_bins, larger_num_bits_in_histogram_bins);
+} else {
  LaunchFindBestSplitsForLeafKernel(smaller_leaf_splits, larger_leaf_splits,
    smaller_leaf_index, larger_leaf_index, is_smaller_leaf_valid, is_larger_leaf_valid);
+}
global_timer.Start("CUDABestSplitFinder::LaunchSyncBestSplitForLeafKernel");
LaunchSyncBestSplitForLeafKernel(smaller_leaf_index, larger_leaf_index, is_smaller_leaf_valid, is_larger_leaf_valid);
SynchronizeCUDADevice(__FILE__, __LINE__);

@@ -320,6 +320,175 @@ __device__ void FindBestSplitsForLeafKernelInner(
}
}

template <bool USE_RAND, bool USE_L1, bool USE_SMOOTHING, bool REVERSE, typename BIN_HIST_TYPE, typename ACC_HIST_TYPE, bool USE_16BIT_BIN_HIST, bool USE_16BIT_ACC_HIST>
__device__ void FindBestSplitsDiscretizedForLeafKernelInner(
// input feature information
const BIN_HIST_TYPE* feature_hist_ptr,
// input task information
const SplitFindTask* task,
CUDARandom* cuda_random,
// input config parameter values
const double lambda_l1,
const double lambda_l2,
const double path_smooth,
const data_size_t min_data_in_leaf,
const double min_sum_hessian_in_leaf,
const double min_gain_to_split,
// input parent node information
const double parent_gain,
const int64_t sum_gradients_hessians,
const data_size_t num_data,
const double parent_output,
// gradient scale
const double grad_scale,
const double hess_scale,
// output parameters
CUDASplitInfo* cuda_best_split_info) {
const double sum_hessians = static_cast<double>(sum_gradients_hessians & 0x00000000ffffffff) * hess_scale;
const double cnt_factor = num_data / sum_hessians;
const double min_gain_shift = parent_gain + min_gain_to_split;

cuda_best_split_info->is_valid = false;

ACC_HIST_TYPE local_grad_hess_hist = 0;
double local_gain = 0.0f;
bool threshold_found = false;
uint32_t threshold_value = 0;
__shared__ int rand_threshold;
if (USE_RAND && threadIdx.x == 0) {
if (task->num_bin - 2 > 0) {
rand_threshold = cuda_random->NextInt(0, task->num_bin - 2);
}
}
__shared__ uint32_t best_thread_index;
__shared__ double shared_double_buffer[32];
__shared__ bool shared_bool_buffer[32];
__shared__ uint32_t shared_int_buffer[64];
const unsigned int threadIdx_x = threadIdx.x;
const bool skip_sum = REVERSE ?
(task->skip_default_bin && (task->num_bin - 1 - threadIdx_x) == static_cast<int>(task->default_bin)) :
(task->skip_default_bin && (threadIdx_x + task->mfb_offset) == static_cast<int>(task->default_bin));
const uint32_t feature_num_bin_minus_offset = task->num_bin - task->mfb_offset;
if (!REVERSE) {
if (threadIdx_x < feature_num_bin_minus_offset && !skip_sum) {
const unsigned int bin_offset = threadIdx_x;
if (USE_16BIT_BIN_HIST && !USE_16BIT_ACC_HIST) {
const int32_t local_grad_hess_hist_int32 = feature_hist_ptr[bin_offset];
local_grad_hess_hist = (static_cast<int64_t>(static_cast<int16_t>(local_grad_hess_hist_int32 >> 16)) << 32) | (static_cast<int64_t>(local_grad_hess_hist_int32 & 0x0000ffff));
} else {
local_grad_hess_hist = feature_hist_ptr[bin_offset];
}
}
} else {
if (threadIdx_x >= static_cast<unsigned int>(task->na_as_missing) &&
threadIdx_x < feature_num_bin_minus_offset && !skip_sum) {
const unsigned int read_index = feature_num_bin_minus_offset - 1 - threadIdx_x;
if (USE_16BIT_BIN_HIST && !USE_16BIT_ACC_HIST) {
const int32_t local_grad_hess_hist_int32 = feature_hist_ptr[read_index];
local_grad_hess_hist = (static_cast<int64_t>(static_cast<int16_t>(local_grad_hess_hist_int32 >> 16)) << 32) | (static_cast<int64_t>(local_grad_hess_hist_int32 & 0x0000ffff));
} else {
local_grad_hess_hist = feature_hist_ptr[read_index];
}
}
}
__syncthreads();
local_gain = kMinScore;
local_grad_hess_hist = ShufflePrefixSum<ACC_HIST_TYPE>(local_grad_hess_hist, reinterpret_cast<ACC_HIST_TYPE*>(shared_int_buffer));
double sum_left_gradient = 0.0f;
double sum_left_hessian = 0.0f;
double sum_right_gradient = 0.0f;
double sum_right_hessian = 0.0f;
data_size_t left_count = 0;
data_size_t right_count = 0;
int64_t sum_left_gradient_hessian = 0;
int64_t sum_right_gradient_hessian = 0;
if (REVERSE) {
if (threadIdx_x >= static_cast<unsigned int>(task->na_as_missing) && threadIdx_x <= task->num_bin - 2 && !skip_sum) {
sum_right_gradient_hessian = USE_16BIT_ACC_HIST ?
(static_cast<int64_t>(static_cast<int16_t>(local_grad_hess_hist >> 16)) << 32) | static_cast<int64_t>(local_grad_hess_hist & 0x0000ffff) :
local_grad_hess_hist;
sum_right_gradient = static_cast<double>(static_cast<int32_t>((sum_right_gradient_hessian & 0xffffffff00000000) >> 32)) * grad_scale;
sum_right_hessian = static_cast<double>(static_cast<int32_t>(sum_right_gradient_hessian & 0x00000000ffffffff)) * hess_scale;
right_count = static_cast<data_size_t>(__double2int_rn(sum_right_hessian * cnt_factor));
sum_left_gradient_hessian = sum_gradients_hessians - sum_right_gradient_hessian;
sum_left_gradient = static_cast<double>(static_cast<int32_t>((sum_left_gradient_hessian & 0xffffffff00000000)>> 32)) * grad_scale;
sum_left_hessian = static_cast<double>(static_cast<int32_t>(sum_left_gradient_hessian & 0x00000000ffffffff)) * hess_scale;
left_count = num_data - right_count;
if (sum_left_hessian >= min_sum_hessian_in_leaf && left_count >= min_data_in_leaf &&
sum_right_hessian >= min_sum_hessian_in_leaf && right_count >= min_data_in_leaf &&
(!USE_RAND || static_cast<int>(task->num_bin - 2 - threadIdx_x) == rand_threshold)) {
double current_gain = CUDALeafSplits::GetSplitGains<USE_L1, USE_SMOOTHING>(
sum_left_gradient, sum_left_hessian + kEpsilon, sum_right_gradient,
sum_right_hessian + kEpsilon, lambda_l1,
lambda_l2, path_smooth, left_count, right_count, parent_output);
// gain with split is worse than without split
if (current_gain > min_gain_shift) {
local_gain = current_gain - min_gain_shift;
threshold_value = static_cast<uint32_t>(task->num_bin - 2 - threadIdx_x);
threshold_found = true;
}
}
}
} else {
if (threadIdx_x <= feature_num_bin_minus_offset - 2 && !skip_sum) {
sum_left_gradient_hessian = USE_16BIT_ACC_HIST ?
(static_cast<int64_t>(static_cast<int16_t>(local_grad_hess_hist >> 16)) << 32) | static_cast<int64_t>(local_grad_hess_hist & 0x0000ffff) :
local_grad_hess_hist;
sum_left_gradient = static_cast<double>(static_cast<int32_t>((sum_left_gradient_hessian & 0xffffffff00000000) >> 32)) * grad_scale;
sum_left_hessian = static_cast<double>(static_cast<int32_t>(sum_left_gradient_hessian & 0x00000000ffffffff)) * hess_scale;
left_count = static_cast<data_size_t>(__double2int_rn(sum_left_hessian * cnt_factor));
sum_right_gradient_hessian = sum_gradients_hessians - sum_left_gradient_hessian;
sum_right_gradient = static_cast<double>(static_cast<int32_t>((sum_right_gradient_hessian & 0xffffffff00000000) >> 32)) * grad_scale;
sum_right_hessian = static_cast<double>(static_cast<int32_t>(sum_right_gradient_hessian & 0x00000000ffffffff)) * hess_scale;
right_count = num_data - left_count;
if (sum_left_hessian >= min_sum_hessian_in_leaf && left_count >= min_data_in_leaf &&
sum_right_hessian >= min_sum_hessian_in_leaf && right_count >= min_data_in_leaf &&
(!USE_RAND || static_cast<int>(threadIdx_x + task->mfb_offset) == rand_threshold)) {
double current_gain = CUDALeafSplits::GetSplitGains<USE_L1, USE_SMOOTHING>(
sum_left_gradient, sum_left_hessian + kEpsilon, sum_right_gradient,
sum_right_hessian + kEpsilon, lambda_l1,
lambda_l2, path_smooth, left_count, right_count, parent_output);
// gain with split is worse than without split
if (current_gain > min_gain_shift) {
local_gain = current_gain - min_gain_shift;
threshold_value = static_cast<uint32_t>(threadIdx_x + task->mfb_offset);
threshold_found = true;
}
}
}
}
__syncthreads();
const uint32_t result = ReduceBestGain(local_gain, threshold_found, threadIdx_x, shared_double_buffer, shared_bool_buffer, shared_int_buffer);
if (threadIdx_x == 0) {
best_thread_index = result;
}
__syncthreads();
if (threshold_found && threadIdx_x == best_thread_index) {
cuda_best_split_info->is_valid = true;
cuda_best_split_info->threshold = threshold_value;
cuda_best_split_info->gain = local_gain;
cuda_best_split_info->default_left = task->assume_out_default_left;
const double left_output = CUDALeafSplits::CalculateSplittedLeafOutput<USE_L1, USE_SMOOTHING>(sum_left_gradient,
  sum_left_hessian, lambda_l1, lambda_l2, path_smooth, left_count, parent_output);
const double right_output = CUDALeafSplits::CalculateSplittedLeafOutput<USE_L1, USE_SMOOTHING>(sum_right_gradient,
  sum_right_hessian, lambda_l1, lambda_l2, path_smooth, right_count, parent_output);
cuda_best_split_info->left_sum_gradients = sum_left_gradient;
cuda_best_split_info->left_sum_hessians = sum_left_hessian;
cuda_best_split_info->left_sum_of_gradients_hessians = sum_left_gradient_hessian;
cuda_best_split_info->left_count = left_count;
cuda_best_split_info->right_sum_gradients = sum_right_gradient;
cuda_best_split_info->right_sum_hessians = sum_right_hessian;
cuda_best_split_info->right_sum_of_gradients_hessians = sum_right_gradient_hessian;
cuda_best_split_info->right_count = right_count;
cuda_best_split_info->left_value = left_output;
cuda_best_split_info->left_gain = CUDALeafSplits::GetLeafGainGivenOutput<USE_L1>(sum_left_gradient,
  sum_left_hessian, lambda_l1, lambda_l2, left_output);
cuda_best_split_info->right_value = right_output;
cuda_best_split_info->right_gain = CUDALeafSplits::GetLeafGainGivenOutput<USE_L1>(sum_right_gradient,
  sum_right_hessian, lambda_l1, lambda_l2, right_output);
}
}

template <bool USE_RAND, bool USE_L1, bool USE_SMOOTHING>
__device__ void FindBestSplitsForLeafKernelCategoricalInner(
// input feature information

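A recurring pattern in the kernel above is widening a 16-bit-per-entry histogram bin (signed gradient sum in the high half, hessian sum in the low half) into the 64-bit accumulator layout via `static_cast<int16_t>(x >> 16)` and a mask. A small standalone C++ sketch of that widening step, with made-up values; it mirrors the kernel's shift/mask expressions and therefore assumes the usual two's-complement, arithmetic-shift behaviour:

#include <cstdint>
#include <cstdio>

int main() {
  // A 32-bit quantized histogram bin: 16-bit signed gradient sum in the high half,
  // 16-bit hessian sum (non-negative) in the low half.
  const int16_t grad16 = -300;
  const uint16_t hess16 = 420;
  const int32_t bin32 = static_cast<int32_t>(
      (static_cast<uint32_t>(static_cast<uint16_t>(grad16)) << 16) | hess16);
  // Widen to the 64-bit accumulator layout (32-bit gradient half, 32-bit hessian half).
  // The static_cast<int16_t> sign-extends the gradient; the mask keeps the hessian unsigned.
  const int64_t acc64 =
      (static_cast<int64_t>(static_cast<int16_t>(bin32 >> 16)) << 32) |
      static_cast<int64_t>(bin32 & 0x0000ffff);
  printf("grad=%d hess=%d\n",
         static_cast<int32_t>(acc64 >> 32),
         static_cast<int32_t>(acc64 & 0xffffffff));
  return 0;
}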
@@ -715,6 +884,169 @@ __global__ void FindBestSplitsForLeafKernel(
}
}


template <bool USE_RAND, bool USE_L1, bool USE_SMOOTHING, bool IS_LARGER>
__global__ void FindBestSplitsDiscretizedForLeafKernel(
// input feature information
const int8_t* is_feature_used_bytree,
// input task information
const int num_tasks,
const SplitFindTask* tasks,
CUDARandom* cuda_randoms,
// input leaf information
const CUDALeafSplitsStruct* smaller_leaf_splits,
const CUDALeafSplitsStruct* larger_leaf_splits,
const uint8_t smaller_leaf_num_bits_in_histogram_bin,
const uint8_t larger_leaf_num_bits_in_histogram_bin,
// input config parameter values
const data_size_t min_data_in_leaf,
const double min_sum_hessian_in_leaf,
const double min_gain_to_split,
const double lambda_l1,
const double lambda_l2,
const double path_smooth,
const double cat_smooth,
const double cat_l2,
const int max_cat_threshold,
const int min_data_per_group,
const int max_cat_to_onehot,
// gradient scale
const score_t* grad_scale,
const score_t* hess_scale,
// output
CUDASplitInfo* cuda_best_split_info) {
const unsigned int task_index = blockIdx.x;
const SplitFindTask* task = tasks + task_index;
const int inner_feature_index = task->inner_feature_index;
const double parent_gain = IS_LARGER ? larger_leaf_splits->gain : smaller_leaf_splits->gain;
const int64_t sum_gradients_hessians = IS_LARGER ? larger_leaf_splits->sum_of_gradients_hessians : smaller_leaf_splits->sum_of_gradients_hessians;
const data_size_t num_data = IS_LARGER ? larger_leaf_splits->num_data_in_leaf : smaller_leaf_splits->num_data_in_leaf;
const double parent_output = IS_LARGER ? larger_leaf_splits->leaf_value : smaller_leaf_splits->leaf_value;
const unsigned int output_offset = IS_LARGER ? (task_index + num_tasks) : task_index;
CUDASplitInfo* out = cuda_best_split_info + output_offset;
CUDARandom* cuda_random = USE_RAND ?
(IS_LARGER ? cuda_randoms + task_index * 2 + 1 : cuda_randoms + task_index * 2) : nullptr;
const bool use_16bit_bin = IS_LARGER ? (larger_leaf_num_bits_in_histogram_bin <= 16) : (smaller_leaf_num_bits_in_histogram_bin <= 16);
if (is_feature_used_bytree[inner_feature_index]) {
if (task->is_categorical) {
__threadfence();  // ensure store issued before trap
asm("trap;");
} else {
if (!task->reverse) {
if (use_16bit_bin) {
const int32_t* hist_ptr =
reinterpret_cast<const int32_t*>(IS_LARGER ? larger_leaf_splits->hist_in_leaf : smaller_leaf_splits->hist_in_leaf) + task->hist_offset;
FindBestSplitsDiscretizedForLeafKernelInner<USE_RAND, USE_L1, USE_SMOOTHING, false, int32_t, int32_t, true, true>(
// input feature information
hist_ptr,
// input task information
task,
cuda_random,
// input config parameter values
lambda_l1,
lambda_l2,
path_smooth,
min_data_in_leaf,
min_sum_hessian_in_leaf,
min_gain_to_split,
// input parent node information
parent_gain,
sum_gradients_hessians,
num_data,
parent_output,
// gradient scale
*grad_scale,
*hess_scale,
// output parameters
out);
} else {
const int32_t* hist_ptr =
reinterpret_cast<const int32_t*>(IS_LARGER ? larger_leaf_splits->hist_in_leaf : smaller_leaf_splits->hist_in_leaf) + task->hist_offset;
FindBestSplitsDiscretizedForLeafKernelInner<USE_RAND, USE_L1, USE_SMOOTHING, false, int32_t, int64_t, false, false>(
// input feature information
hist_ptr,
// input task information
task,
cuda_random,
// input config parameter values
lambda_l1,
lambda_l2,
path_smooth,
min_data_in_leaf,
min_sum_hessian_in_leaf,
min_gain_to_split,
// input parent node information
parent_gain,
sum_gradients_hessians,
num_data,
parent_output,
// gradient scale
*grad_scale,
*hess_scale,
// output parameters
out);
}
} else {
if (use_16bit_bin) {
const int32_t* hist_ptr =
reinterpret_cast<const int32_t*>(IS_LARGER ? larger_leaf_splits->hist_in_leaf : smaller_leaf_splits->hist_in_leaf) + task->hist_offset;
FindBestSplitsDiscretizedForLeafKernelInner<USE_RAND, USE_L1, USE_SMOOTHING, true, int32_t, int32_t, true, true>(
// input feature information
hist_ptr,
// input task information
task,
cuda_random,
// input config parameter values
lambda_l1,
lambda_l2,
path_smooth,
min_data_in_leaf,
min_sum_hessian_in_leaf,
min_gain_to_split,
// input parent node information
parent_gain,
sum_gradients_hessians,
num_data,
parent_output,
// gradient scale
*grad_scale,
*hess_scale,
// output parameters
out);
} else {
const int32_t* hist_ptr =
reinterpret_cast<const int32_t*>(IS_LARGER ? larger_leaf_splits->hist_in_leaf : smaller_leaf_splits->hist_in_leaf) + task->hist_offset;
FindBestSplitsDiscretizedForLeafKernelInner<USE_RAND, USE_L1, USE_SMOOTHING, true, int32_t, int64_t, false, false>(
// input feature information
hist_ptr,
// input task information
task,
cuda_random,
// input config parameter values
lambda_l1,
lambda_l2,
path_smooth,
min_data_in_leaf,
min_sum_hessian_in_leaf,
min_gain_to_split,
// input parent node information
parent_gain,
sum_gradients_hessians,
num_data,
parent_output,
// gradient scale
*grad_scale,
*hess_scale,
// output parameters
out);
}
}
}
} else {
out->is_valid = false;
}
}

template <bool USE_RAND, bool USE_L1, bool USE_SMOOTHING, bool REVERSE>
__device__ void FindBestSplitsForLeafKernelInner_GlobalMemory(
// input feature information

@@ -1466,6 +1798,108 @@ void CUDABestSplitFinder::LaunchFindBestSplitsForLeafKernelInner2(LaunchFindBest
#undef FindBestSplitsForLeafKernel_ARGS
#undef GlobalMemory_Buffer_ARGS


#define LaunchFindBestSplitsDiscretizedForLeafKernel_PARAMS \
const CUDALeafSplitsStruct* smaller_leaf_splits, \
const CUDALeafSplitsStruct* larger_leaf_splits, \
const int smaller_leaf_index, \
const int larger_leaf_index, \
const bool is_smaller_leaf_valid, \
const bool is_larger_leaf_valid, \
const score_t* grad_scale, \
const score_t* hess_scale, \
const uint8_t smaller_num_bits_in_histogram_bins, \
const uint8_t larger_num_bits_in_histogram_bins

#define LaunchFindBestSplitsDiscretizedForLeafKernel_ARGS \
smaller_leaf_splits, \
larger_leaf_splits, \
smaller_leaf_index, \
larger_leaf_index, \
is_smaller_leaf_valid, \
is_larger_leaf_valid, \
grad_scale, \
hess_scale, \
smaller_num_bits_in_histogram_bins, \
larger_num_bits_in_histogram_bins

#define FindBestSplitsDiscretizedForLeafKernel_ARGS \
cuda_is_feature_used_bytree_, \
num_tasks_, \
cuda_split_find_tasks_.RawData(), \
cuda_randoms_.RawData(), \
smaller_leaf_splits, \
larger_leaf_splits, \
smaller_num_bits_in_histogram_bins, \
larger_num_bits_in_histogram_bins, \
min_data_in_leaf_, \
min_sum_hessian_in_leaf_, \
min_gain_to_split_, \
lambda_l1_, \
lambda_l2_, \
path_smooth_, \
cat_smooth_, \
cat_l2_, \
max_cat_threshold_, \
min_data_per_group_, \
max_cat_to_onehot_, \
grad_scale, \
hess_scale, \
cuda_best_split_info_

void CUDABestSplitFinder::LaunchFindBestSplitsDiscretizedForLeafKernel(LaunchFindBestSplitsDiscretizedForLeafKernel_PARAMS) {
if (!is_smaller_leaf_valid && !is_larger_leaf_valid) {
return;
}
if (!extra_trees_) {
LaunchFindBestSplitsDiscretizedForLeafKernelInner0<false>(LaunchFindBestSplitsDiscretizedForLeafKernel_ARGS);
} else {
LaunchFindBestSplitsDiscretizedForLeafKernelInner0<true>(LaunchFindBestSplitsDiscretizedForLeafKernel_ARGS);
}
}

template <bool USE_RAND>
void CUDABestSplitFinder::LaunchFindBestSplitsDiscretizedForLeafKernelInner0(LaunchFindBestSplitsDiscretizedForLeafKernel_PARAMS) {
if (lambda_l1_ <= 0.0f) {
LaunchFindBestSplitsDiscretizedForLeafKernelInner1<USE_RAND, false>(LaunchFindBestSplitsDiscretizedForLeafKernel_ARGS);
} else {
LaunchFindBestSplitsDiscretizedForLeafKernelInner1<USE_RAND, true>(LaunchFindBestSplitsDiscretizedForLeafKernel_ARGS);
}
}

template <bool USE_RAND, bool USE_L1>
void CUDABestSplitFinder::LaunchFindBestSplitsDiscretizedForLeafKernelInner1(LaunchFindBestSplitsDiscretizedForLeafKernel_PARAMS) {
if (!use_smoothing_) {
LaunchFindBestSplitsDiscretizedForLeafKernelInner2<USE_RAND, USE_L1, false>(LaunchFindBestSplitsDiscretizedForLeafKernel_ARGS);
} else {
LaunchFindBestSplitsDiscretizedForLeafKernelInner2<USE_RAND, USE_L1, true>(LaunchFindBestSplitsDiscretizedForLeafKernel_ARGS);
}
}

template <bool USE_RAND, bool USE_L1, bool USE_SMOOTHING>
void CUDABestSplitFinder::LaunchFindBestSplitsDiscretizedForLeafKernelInner2(LaunchFindBestSplitsDiscretizedForLeafKernel_PARAMS) {
if (!use_global_memory_) {
if (is_smaller_leaf_valid) {
FindBestSplitsDiscretizedForLeafKernel<USE_RAND, USE_L1, USE_SMOOTHING, false>
<<<num_tasks_, NUM_THREADS_PER_BLOCK_BEST_SPLIT_FINDER, 0, cuda_streams_[0]>>>
(FindBestSplitsDiscretizedForLeafKernel_ARGS);
}
SynchronizeCUDADevice(__FILE__, __LINE__);
if (is_larger_leaf_valid) {
FindBestSplitsDiscretizedForLeafKernel<USE_RAND, USE_L1, USE_SMOOTHING, true>
<<<num_tasks_, NUM_THREADS_PER_BLOCK_BEST_SPLIT_FINDER, 0, cuda_streams_[1]>>>
(FindBestSplitsDiscretizedForLeafKernel_ARGS);
}
} else {
// TODO(shiyu1994)
}
}

#undef LaunchFindBestSplitsDiscretizedForLeafKernel_PARAMS
#undef LaunchFindBestSplitsDiscretizedForLeafKernel_ARGS
#undef FindBestSplitsDiscretizedForLeafKernel_ARGS


__device__ void ReduceBestSplit(bool* found, double* gain, uint32_t* shared_read_index,
uint32_t num_features_aligned) {
const uint32_t threadIdx_x = threadIdx.x;

@@ -67,7 +67,11 @@ class CUDABestSplitFinder {
const data_size_t num_data_in_smaller_leaf,
const data_size_t num_data_in_larger_leaf,
const double sum_hessians_in_smaller_leaf,
-const double sum_hessians_in_larger_leaf);
+const double sum_hessians_in_larger_leaf,
+const score_t* grad_scale,
+const score_t* hess_scale,
+const uint8_t smaller_num_bits_in_histogram_bins,
+const uint8_t larger_num_bits_in_histogram_bins);

const CUDASplitInfo* FindBestFromAllSplits(
const int cur_num_leaves,

@@ -114,6 +118,31 @@ class CUDABestSplitFinder {

#undef LaunchFindBestSplitsForLeafKernel_PARAMS

#define LaunchFindBestSplitsDiscretizedForLeafKernel_PARAMS \
const CUDALeafSplitsStruct* smaller_leaf_splits, \
const CUDALeafSplitsStruct* larger_leaf_splits, \
const int smaller_leaf_index, \
const int larger_leaf_index, \
const bool is_smaller_leaf_valid, \
const bool is_larger_leaf_valid, \
const score_t* grad_scale, \
const score_t* hess_scale, \
const uint8_t smaller_num_bits_in_histogram_bins, \
const uint8_t larger_num_bits_in_histogram_bins

void LaunchFindBestSplitsDiscretizedForLeafKernel(LaunchFindBestSplitsDiscretizedForLeafKernel_PARAMS);

template <bool USE_RAND>
void LaunchFindBestSplitsDiscretizedForLeafKernelInner0(LaunchFindBestSplitsDiscretizedForLeafKernel_PARAMS);

template <bool USE_RAND, bool USE_L1>
void LaunchFindBestSplitsDiscretizedForLeafKernelInner1(LaunchFindBestSplitsDiscretizedForLeafKernel_PARAMS);

template <bool USE_RAND, bool USE_L1, bool USE_SMOOTHING>
void LaunchFindBestSplitsDiscretizedForLeafKernelInner2(LaunchFindBestSplitsDiscretizedForLeafKernel_PARAMS);

#undef LaunchFindBestSplitsDiscretizedForLeafKernel_PARAMS

void LaunchSyncBestSplitForLeafKernel(
const int host_smaller_leaf_index,
const int host_larger_leaf_index,

@@ -368,6 +368,12 @@ void CUDADataPartition::ResetByLeafPred(const std::vector<int>& leaf_pred, int n
cur_num_leaves_ = num_leaves;
}

+void CUDADataPartition::ReduceLeafGradStat(
+  const score_t* gradients, const score_t* hessians,
+  CUDATree* tree, double* leaf_grad_stat_buffer, double* leaf_hess_state_buffer) const {
+  LaunchReduceLeafGradStat(gradients, hessians, tree, leaf_grad_stat_buffer, leaf_hess_state_buffer);
+}
+
}  // namespace LightGBM

#endif  // USE_CUDA

@@ -1069,6 +1069,53 @@ void CUDADataPartition::LaunchAddPredictionToScoreKernel(const double* leaf_valu
global_timer.Stop("CUDADataPartition::AddPredictionToScoreKernel");
}

__global__ void RenewDiscretizedTreeLeavesKernel(
const score_t* gradients,
const score_t* hessians,
const data_size_t* data_indices,
const data_size_t* leaf_data_start,
const data_size_t* leaf_num_data,
double* leaf_grad_stat_buffer,
double* leaf_hess_stat_buffer,
double* leaf_values) {
__shared__ double shared_mem_buffer[32];
const int leaf_index = static_cast<int>(blockIdx.x);
const data_size_t* data_indices_in_leaf = data_indices + leaf_data_start[leaf_index];
const data_size_t num_data_in_leaf = leaf_num_data[leaf_index];
double sum_gradients = 0.0f;
double sum_hessians = 0.0f;
for (data_size_t inner_data_index = static_cast<int>(threadIdx.x);
inner_data_index < num_data_in_leaf; inner_data_index += static_cast<int>(blockDim.x)) {
const data_size_t data_index = data_indices_in_leaf[inner_data_index];
const score_t gradient = gradients[data_index];
const score_t hessian = hessians[data_index];
sum_gradients += static_cast<double>(gradient);
sum_hessians += static_cast<double>(hessian);
}
sum_gradients = ShuffleReduceSum<double>(sum_gradients, shared_mem_buffer, blockDim.x);
__syncthreads();
sum_hessians = ShuffleReduceSum<double>(sum_hessians, shared_mem_buffer, blockDim.x);
if (threadIdx.x == 0) {
leaf_grad_stat_buffer[leaf_index] = sum_gradients;
leaf_hess_stat_buffer[leaf_index] = sum_hessians;
}
}

void CUDADataPartition::LaunchReduceLeafGradStat(
const score_t* gradients, const score_t* hessians,
CUDATree* tree, double* leaf_grad_stat_buffer, double* leaf_hess_state_buffer) const {
const int num_blocks = tree->num_leaves();
RenewDiscretizedTreeLeavesKernel<<<num_blocks, FILL_INDICES_BLOCK_SIZE_DATA_PARTITION>>>(
gradients,
hessians,
cuda_data_indices_,
cuda_leaf_data_start_,
cuda_leaf_num_data_,
leaf_grad_stat_buffer,
leaf_hess_state_buffer,
tree->cuda_leaf_value_ref());
}

}  // namespace LightGBM

#endif  // USE_CUDA

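RenewDiscretizedTreeLeavesKernel above only accumulates per-leaf sums of the original float gradients and hessians; the renewed leaf outputs are then derived from those sums. As a hedged, standalone sketch of the usual second-order leaf-value formula such a renewal step applies (lambda_l2 and the sums are made up, and this is not the exact LightGBM call chain):

#include <cstdio>
#include <vector>

int main() {
  // Per-leaf sums of the original (non-quantized) gradients/hessians,
  // as a leaf-wise reduction like the kernel above would produce.
  std::vector<double> leaf_grad_sum = {-12.5, 3.75};
  std::vector<double> leaf_hess_sum = {40.0, 9.5};
  const double lambda_l2 = 1.0;  // assumed regularization strength
  for (size_t i = 0; i < leaf_grad_sum.size(); ++i) {
    // Standard second-order leaf output: -G / (H + lambda_l2). Recomputing it from
    // full-precision sums removes the quantization error from the leaf values.
    const double leaf_value = -leaf_grad_sum[i] / (leaf_hess_sum[i] + lambda_l2);
    printf("leaf %zu: value %.4f\n", i, leaf_value);
  }
  return 0;
}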
@@ -78,6 +78,10 @@ class CUDADataPartition {

void ResetByLeafPred(const std::vector<int>& leaf_pred, int num_leaves);

+void ReduceLeafGradStat(
+  const score_t* gradients, const score_t* hessians,
+  CUDATree* tree, double* leaf_grad_stat_buffer, double* leaf_hess_state_buffer) const;
+
data_size_t root_num_data() const {
if (use_bagging_) {
return num_used_indices_;

@@ -292,6 +296,10 @@ class CUDADataPartition {

void LaunchFillDataIndexToLeafIndex();

+void LaunchReduceLeafGradStat(
+  const score_t* gradients, const score_t* hessians,
+  CUDATree* tree, double* leaf_grad_stat_buffer, double* leaf_hess_state_buffer) const;
+
// Host memory

// dataset information

@@ -0,0 +1,171 @@
/*!
 * Copyright (c) 2021 Microsoft Corporation. All rights reserved.
 * Licensed under the MIT License. See LICENSE file in the project root for
 * license information.
 */

#ifdef USE_CUDA

#include <algorithm>

#include <LightGBM/cuda/cuda_algorithms.hpp>

#include "cuda_gradient_discretizer.hpp"

namespace LightGBM {

__global__ void ReduceMinMaxKernel(
const data_size_t num_data,
const score_t* input_gradients,
const score_t* input_hessians,
score_t* grad_min_block_buffer,
score_t* grad_max_block_buffer,
score_t* hess_min_block_buffer,
score_t* hess_max_block_buffer) {
__shared__ score_t shared_mem_buffer[32];
const data_size_t index = static_cast<data_size_t>(threadIdx.x + blockIdx.x * blockDim.x);
score_t grad_max_val = kMinScore;
score_t grad_min_val = kMaxScore;
score_t hess_max_val = kMinScore;
score_t hess_min_val = kMaxScore;
if (index < num_data) {
grad_max_val = input_gradients[index];
grad_min_val = input_gradients[index];
hess_max_val = input_hessians[index];
hess_min_val = input_hessians[index];
}
grad_min_val = ShuffleReduceMin<score_t>(grad_min_val, shared_mem_buffer, blockDim.x);
__syncthreads();
grad_max_val = ShuffleReduceMax<score_t>(grad_max_val, shared_mem_buffer, blockDim.x);
__syncthreads();
hess_min_val = ShuffleReduceMin<score_t>(hess_min_val, shared_mem_buffer, blockDim.x);
__syncthreads();
hess_max_val = ShuffleReduceMax<score_t>(hess_max_val, shared_mem_buffer, blockDim.x);
if (threadIdx.x == 0) {
grad_min_block_buffer[blockIdx.x] = grad_min_val;
grad_max_block_buffer[blockIdx.x] = grad_max_val;
hess_min_block_buffer[blockIdx.x] = hess_min_val;
hess_max_block_buffer[blockIdx.x] = hess_max_val;
}
}

__global__ void ReduceBlockMinMaxKernel(
const int num_blocks,
const int grad_discretize_bins,
score_t* grad_min_block_buffer,
score_t* grad_max_block_buffer,
score_t* hess_min_block_buffer,
score_t* hess_max_block_buffer) {
__shared__ score_t shared_mem_buffer[32];
score_t grad_max_val = kMinScore;
score_t grad_min_val = kMaxScore;
score_t hess_max_val = kMinScore;
score_t hess_min_val = kMaxScore;
for (int block_index = static_cast<int>(threadIdx.x); block_index < num_blocks; block_index += static_cast<int>(blockDim.x)) {
grad_min_val = min(grad_min_val, grad_min_block_buffer[block_index]);
grad_max_val = max(grad_max_val, grad_max_block_buffer[block_index]);
hess_min_val = min(hess_min_val, hess_min_block_buffer[block_index]);
hess_max_val = max(hess_max_val, hess_max_block_buffer[block_index]);
}
grad_min_val = ShuffleReduceMin<score_t>(grad_min_val, shared_mem_buffer, blockDim.x);
__syncthreads();
grad_max_val = ShuffleReduceMax<score_t>(grad_max_val, shared_mem_buffer, blockDim.x);
__syncthreads();
hess_max_val = ShuffleReduceMax<score_t>(hess_max_val, shared_mem_buffer, blockDim.x);
__syncthreads();
hess_max_val = ShuffleReduceMax<score_t>(hess_max_val, shared_mem_buffer, blockDim.x);
if (threadIdx.x == 0) {
const score_t grad_abs_max = max(fabs(grad_min_val), fabs(grad_max_val));
const score_t hess_abs_max = max(fabs(hess_min_val), fabs(hess_max_val));
grad_min_block_buffer[0] = 1.0f / (grad_abs_max / (grad_discretize_bins / 2));
grad_max_block_buffer[0] = (grad_abs_max / (grad_discretize_bins / 2));
hess_min_block_buffer[0] = 1.0f / (hess_abs_max / (grad_discretize_bins));
hess_max_block_buffer[0] = (hess_abs_max / (grad_discretize_bins));
}
}

template <bool STOCHASTIC_ROUNDING>
__global__ void DiscretizeGradientsKernel(
const data_size_t num_data,
const score_t* input_gradients,
const score_t* input_hessians,
const score_t* grad_scale_ptr,
const score_t* hess_scale_ptr,
const int iter,
const int* random_values_use_start,
const score_t* gradient_random_values,
const score_t* hessian_random_values,
const int grad_discretize_bins,
int8_t* output_gradients_and_hessians) {
const int start = random_values_use_start[iter];
const data_size_t index = static_cast<data_size_t>(threadIdx.x + blockIdx.x * blockDim.x);
const score_t grad_scale = *grad_scale_ptr;
const score_t hess_scale = *hess_scale_ptr;
int16_t* output_gradients_and_hessians_ptr = reinterpret_cast<int16_t*>(output_gradients_and_hessians);
if (index < num_data) {
if (STOCHASTIC_ROUNDING) {
const data_size_t index_offset = (index + start) % num_data;
const score_t gradient = input_gradients[index];
const score_t hessian = input_hessians[index];
const score_t gradient_random_value = gradient_random_values[index_offset];
const score_t hessian_random_value = hessian_random_values[index_offset];
output_gradients_and_hessians_ptr[2 * index + 1] = gradient > 0.0f ?
static_cast<int16_t>(gradient * grad_scale + gradient_random_value) :
static_cast<int16_t>(gradient * grad_scale - gradient_random_value);
output_gradients_and_hessians_ptr[2 * index] = static_cast<int16_t>(hessian * hess_scale + hessian_random_value);
} else {
const score_t gradient = input_gradients[index];
const score_t hessian = input_hessians[index];
output_gradients_and_hessians_ptr[2 * index + 1] = gradient > 0.0f ?
static_cast<int16_t>(gradient * grad_scale + 0.5) :
static_cast<int16_t>(gradient * grad_scale - 0.5);
output_gradients_and_hessians_ptr[2 * index] = static_cast<int16_t>(hessian * hess_scale + 0.5);
}
}
}

void CUDAGradientDiscretizer::DiscretizeGradients(
const data_size_t num_data,
const score_t* input_gradients,
const score_t* input_hessians) {
ReduceMinMaxKernel<<<num_reduce_blocks_, CUDA_GRADIENT_DISCRETIZER_BLOCK_SIZE>>>(
num_data, input_gradients, input_hessians,
grad_min_block_buffer_.RawData(),
grad_max_block_buffer_.RawData(),
hess_min_block_buffer_.RawData(),
hess_max_block_buffer_.RawData());
SynchronizeCUDADevice(__FILE__, __LINE__);
ReduceBlockMinMaxKernel<<<1, CUDA_GRADIENT_DISCRETIZER_BLOCK_SIZE>>>(
num_reduce_blocks_,
num_grad_quant_bins_,
grad_min_block_buffer_.RawData(),
grad_max_block_buffer_.RawData(),
hess_min_block_buffer_.RawData(),
hess_max_block_buffer_.RawData());
SynchronizeCUDADevice(__FILE__, __LINE__);

#define DiscretizeGradientsKernel_ARGS \
num_data, \
input_gradients, \
input_hessians, \
grad_min_block_buffer_.RawData(), \
hess_min_block_buffer_.RawData(), \
iter_, \
random_values_use_start_.RawData(), \
gradient_random_values_.RawData(), \
hessian_random_values_.RawData(), \
num_grad_quant_bins_, \
discretized_gradients_and_hessians_.RawData()

if (stochastic_rounding_) {
DiscretizeGradientsKernel<true><<<num_reduce_blocks_, CUDA_GRADIENT_DISCRETIZER_BLOCK_SIZE>>>(DiscretizeGradientsKernel_ARGS);
} else {
DiscretizeGradientsKernel<false><<<num_reduce_blocks_, CUDA_GRADIENT_DISCRETIZER_BLOCK_SIZE>>>(DiscretizeGradientsKernel_ARGS);
}
SynchronizeCUDADevice(__FILE__, __LINE__);
++iter_;
}

}  // namespace LightGBM

#endif  // USE_CUDA

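ReduceBlockMinMaxKernel above turns the global absolute maxima into a quantization scale and its inverse: gradients get num_grad_quant_bins / 2 bins per sign, hessians (non-negative) get the full range, the inverse scale lands in the *_min buffers (multiplied in during quantization) and the plain scale in the *_max buffers (returned by grad_scale_ptr() / hess_scale_ptr() for dequantization). A standalone C++ sketch of that arithmetic with made-up maxima:

#include <cstdio>

int main() {
  // Mirrors the scale arithmetic in ReduceBlockMinMaxKernel; numbers are illustrative.
  const double grad_abs_max = 0.87;
  const double hess_abs_max = 0.25;
  const int num_grad_quant_bins = 4;
  const double grad_scale = grad_abs_max / (num_grad_quant_bins / 2);  // dequantize factor
  const double grad_inv_scale = 1.0 / grad_scale;                      // quantize factor
  const double hess_scale = hess_abs_max / num_grad_quant_bins;
  const double hess_inv_scale = 1.0 / hess_scale;
  printf("grad: quantize by %.4f, dequantize by %.4f\n", grad_inv_scale, grad_scale);
  printf("hess: quantize by %.4f, dequantize by %.4f\n", hess_inv_scale, hess_scale);
  return 0;
}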
@@ -0,0 +1,118 @@
/*!
 * Copyright (c) 2021 Microsoft Corporation. All rights reserved.
 * Licensed under the MIT License. See LICENSE file in the project root for
 * license information.
 */

#ifndef LIGHTGBM_TREELEARNER_CUDA_CUDA_GRADIENT_DISCRETIZER_HPP_
#define LIGHTGBM_TREELEARNER_CUDA_CUDA_GRADIENT_DISCRETIZER_HPP_

#ifdef USE_CUDA

#include <LightGBM/bin.h>
#include <LightGBM/meta.h>
#include <LightGBM/cuda/cuda_utils.hu>
#include <LightGBM/utils/threading.h>

#include <algorithm>
#include <random>
#include <vector>

#include "cuda_leaf_splits.hpp"
#include "../gradient_discretizer.hpp"

namespace LightGBM {

#define CUDA_GRADIENT_DISCRETIZER_BLOCK_SIZE (1024)

class CUDAGradientDiscretizer: public GradientDiscretizer {
 public:
CUDAGradientDiscretizer(int num_grad_quant_bins, int num_trees, int random_seed, bool is_constant_hessian, bool stochastic_roudning):
GradientDiscretizer(num_grad_quant_bins, num_trees, random_seed, is_constant_hessian, stochastic_roudning) {
}

void DiscretizeGradients(
const data_size_t num_data,
const score_t* input_gradients,
const score_t* input_hessians) override;

const int8_t* discretized_gradients_and_hessians() const override { return discretized_gradients_and_hessians_.RawData(); }

double grad_scale() const override {
Log::Fatal("grad_scale() of CUDAGradientDiscretizer should not be called.");
return 0.0;
}

double hess_scale() const override {
Log::Fatal("hess_scale() of CUDAGradientDiscretizer should not be called.");
return 0.0;
}

const score_t* grad_scale_ptr() const { return grad_max_block_buffer_.RawData(); }

const score_t* hess_scale_ptr() const { return hess_max_block_buffer_.RawData(); }

void Init(const data_size_t num_data, const int num_leaves,
const int num_features, const Dataset* train_data) override {
GradientDiscretizer::Init(num_data, num_leaves, num_features, train_data);
discretized_gradients_and_hessians_.Resize(num_data * 2);
num_reduce_blocks_ = (num_data + CUDA_GRADIENT_DISCRETIZER_BLOCK_SIZE - 1) / CUDA_GRADIENT_DISCRETIZER_BLOCK_SIZE;
grad_min_block_buffer_.Resize(num_reduce_blocks_);
grad_max_block_buffer_.Resize(num_reduce_blocks_);
hess_min_block_buffer_.Resize(num_reduce_blocks_);
hess_max_block_buffer_.Resize(num_reduce_blocks_);
random_values_use_start_.Resize(num_trees_);
gradient_random_values_.Resize(num_data);
hessian_random_values_.Resize(num_data);

std::vector<score_t> gradient_random_values(num_data, 0.0f);
std::vector<score_t> hessian_random_values(num_data, 0.0f);
std::vector<int> random_values_use_start(num_trees_, 0);

const int num_threads = OMP_NUM_THREADS();

std::mt19937 random_values_use_start_eng = std::mt19937(random_seed_);
std::uniform_int_distribution<data_size_t> random_values_use_start_dist = std::uniform_int_distribution<data_size_t>(0, num_data);
for (int tree_index = 0; tree_index < num_trees_; ++tree_index) {
random_values_use_start[tree_index] = random_values_use_start_dist(random_values_use_start_eng);
}

int num_blocks = 0;
data_size_t block_size = 0;
Threading::BlockInfo<data_size_t>(num_data, 512, &num_blocks, &block_size);
#pragma omp parallel for schedule(static, 1) num_threads(num_threads)
for (int thread_id = 0; thread_id < num_blocks; ++thread_id) {
const data_size_t start = thread_id * block_size;
const data_size_t end = std::min(start + block_size, num_data);
std::mt19937 gradient_random_values_eng(random_seed_ + thread_id);
std::uniform_real_distribution<double> gradient_random_values_dist(0.0f, 1.0f);
std::mt19937 hessian_random_values_eng(random_seed_ + thread_id + num_threads);
std::uniform_real_distribution<double> hessian_random_values_dist(0.0f, 1.0f);
for (data_size_t i = start; i < end; ++i) {
gradient_random_values[i] = gradient_random_values_dist(gradient_random_values_eng);
hessian_random_values[i] = hessian_random_values_dist(hessian_random_values_eng);
}
}

CopyFromHostToCUDADevice<score_t>(gradient_random_values_.RawData(), gradient_random_values.data(), gradient_random_values.size(), __FILE__, __LINE__);
CopyFromHostToCUDADevice<score_t>(hessian_random_values_.RawData(), hessian_random_values.data(), hessian_random_values.size(), __FILE__, __LINE__);
CopyFromHostToCUDADevice<int>(random_values_use_start_.RawData(), random_values_use_start.data(), random_values_use_start.size(), __FILE__, __LINE__);
iter_ = 0;
}

 protected:
mutable CUDAVector<int8_t> discretized_gradients_and_hessians_;
mutable CUDAVector<score_t> grad_min_block_buffer_;
mutable CUDAVector<score_t> grad_max_block_buffer_;
mutable CUDAVector<score_t> hess_min_block_buffer_;
mutable CUDAVector<score_t> hess_max_block_buffer_;
CUDAVector<int> random_values_use_start_;
CUDAVector<score_t> gradient_random_values_;
CUDAVector<score_t> hessian_random_values_;
int num_reduce_blocks_;
};

}  // namespace LightGBM

#endif  // USE_CUDA
#endif  // LIGHTGBM_TREELEARNER_CUDA_CUDA_GRADIENT_DISCRETIZER_HPP_

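Init() above pre-generates one pool of uniform random values plus a per-tree starting offset, so each boosting iteration reuses the same pool shifted by a random amount instead of drawing fresh values on the GPU. A standalone C++ sketch of that reuse scheme with tiny made-up sizes:

#include <cstdio>
#include <random>
#include <vector>

int main() {
  const int num_data = 10;
  const int num_trees = 3;
  std::mt19937 eng(7);
  std::uniform_real_distribution<double> unif(0.0, 1.0);
  // One fixed pool of rounding noise, generated once up front.
  std::vector<double> pool(num_data);
  for (double& v : pool) v = unif(eng);
  // One random starting offset per tree.
  std::uniform_int_distribution<int> offset_dist(0, num_data);
  std::vector<int> use_start(num_trees);
  for (int& s : use_start) s = offset_dist(eng);
  for (int iter = 0; iter < num_trees; ++iter) {
    printf("iter %d:", iter);
    for (int i = 0; i < num_data; ++i) {
      // Each iteration walks the same pool starting at its own offset, wrapping around.
      printf(" %.2f", pool[(i + use_start[iter]) % num_data]);
    }
    printf("\n");
  }
  return 0;
}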
@@ -20,7 +20,9 @@ CUDAHistogramConstructor::CUDAHistogramConstructor(
const int min_data_in_leaf,
const double min_sum_hessian_in_leaf,
const int gpu_device_id,
-const bool gpu_use_dp):
+const bool gpu_use_dp,
+const bool use_quantized_grad,
+const int num_grad_quant_bins):
num_data_(train_data->num_data()),
num_features_(train_data->num_features()),
num_leaves_(num_leaves),

@@ -28,24 +30,14 @@ CUDAHistogramConstructor::CUDAHistogramConstructor(
min_data_in_leaf_(min_data_in_leaf),
min_sum_hessian_in_leaf_(min_sum_hessian_in_leaf),
gpu_device_id_(gpu_device_id),
-gpu_use_dp_(gpu_use_dp) {
+gpu_use_dp_(gpu_use_dp),
+use_quantized_grad_(use_quantized_grad),
+num_grad_quant_bins_(num_grad_quant_bins) {
InitFeatureMetaInfo(train_data, feature_hist_offsets);
cuda_row_data_.reset(nullptr);
-cuda_feature_num_bins_ = nullptr;
-cuda_feature_hist_offsets_ = nullptr;
-cuda_feature_most_freq_bins_ = nullptr;
-cuda_hist_ = nullptr;
-cuda_need_fix_histogram_features_ = nullptr;
-cuda_need_fix_histogram_features_num_bin_aligned_ = nullptr;
}

CUDAHistogramConstructor::~CUDAHistogramConstructor() {
-DeallocateCUDAMemory<uint32_t>(&cuda_feature_num_bins_, __FILE__, __LINE__);
-DeallocateCUDAMemory<uint32_t>(&cuda_feature_hist_offsets_, __FILE__, __LINE__);
-DeallocateCUDAMemory<uint32_t>(&cuda_feature_most_freq_bins_, __FILE__, __LINE__);
-DeallocateCUDAMemory<hist_t>(&cuda_hist_, __FILE__, __LINE__);
-DeallocateCUDAMemory<int>(&cuda_need_fix_histogram_features_, __FILE__, __LINE__);
-DeallocateCUDAMemory<uint32_t>(&cuda_need_fix_histogram_features_num_bin_aligned_, __FILE__, __LINE__);
gpuAssert(cudaStreamDestroy(cuda_stream_), __FILE__, __LINE__);
}

@@ -84,54 +76,70 @@ void CUDAHistogramConstructor::InitFeatureMetaInfo(const Dataset* train_data, co
void CUDAHistogramConstructor::BeforeTrain(const score_t* gradients, const score_t* hessians) {
cuda_gradients_ = gradients;
cuda_hessians_ = hessians;
-SetCUDAMemory<hist_t>(cuda_hist_, 0, num_total_bin_ * 2 * num_leaves_, __FILE__, __LINE__);
+cuda_hist_.SetValue(0);
}

void CUDAHistogramConstructor::Init(const Dataset* train_data, TrainingShareStates* share_state) {
-AllocateCUDAMemory<hist_t>(&cuda_hist_, num_total_bin_ * 2 * num_leaves_, __FILE__, __LINE__);
-SetCUDAMemory<hist_t>(cuda_hist_, 0, num_total_bin_ * 2 * num_leaves_, __FILE__, __LINE__);
+cuda_hist_.Resize(static_cast<size_t>(num_total_bin_ * 2 * num_leaves_));
+cuda_hist_.SetValue(0);

-InitCUDAMemoryFromHostMemory<uint32_t>(&cuda_feature_num_bins_,
-  feature_num_bins_.data(), feature_num_bins_.size(), __FILE__, __LINE__);
-
-InitCUDAMemoryFromHostMemory<uint32_t>(&cuda_feature_hist_offsets_,
-  feature_hist_offsets_.data(), feature_hist_offsets_.size(), __FILE__, __LINE__);
-
-InitCUDAMemoryFromHostMemory<uint32_t>(&cuda_feature_most_freq_bins_,
-  feature_most_freq_bins_.data(), feature_most_freq_bins_.size(), __FILE__, __LINE__);
+cuda_feature_num_bins_.InitFromHostVector(feature_num_bins_);
+cuda_feature_hist_offsets_.InitFromHostVector(feature_hist_offsets_);
+cuda_feature_most_freq_bins_.InitFromHostVector(feature_most_freq_bins_);

cuda_row_data_.reset(new CUDARowData(train_data, share_state, gpu_device_id_, gpu_use_dp_));
cuda_row_data_->Init(train_data, share_state);

CUDASUCCESS_OR_FATAL(cudaStreamCreate(&cuda_stream_));

-InitCUDAMemoryFromHostMemory<int>(&cuda_need_fix_histogram_features_, need_fix_histogram_features_.data(), need_fix_histogram_features_.size(), __FILE__, __LINE__);
-InitCUDAMemoryFromHostMemory<uint32_t>(&cuda_need_fix_histogram_features_num_bin_aligned_, need_fix_histogram_features_num_bin_aligend_.data(),
-  need_fix_histogram_features_num_bin_aligend_.size(), __FILE__, __LINE__);
+cuda_need_fix_histogram_features_.InitFromHostVector(need_fix_histogram_features_);
+cuda_need_fix_histogram_features_num_bin_aligned_.InitFromHostVector(need_fix_histogram_features_num_bin_aligend_);

if (cuda_row_data_->NumLargeBinPartition() > 0) {
int grid_dim_x = 0, grid_dim_y = 0, block_dim_x = 0, block_dim_y = 0;
CalcConstructHistogramKernelDim(&grid_dim_x, &grid_dim_y, &block_dim_x, &block_dim_y, num_data_);
-const size_t buffer_size = static_cast<size_t>(grid_dim_y) * static_cast<size_t>(num_total_bin_) * 2;
-AllocateCUDAMemory<float>(&cuda_hist_buffer_, buffer_size, __FILE__, __LINE__);
+const size_t buffer_size = static_cast<size_t>(grid_dim_y) * static_cast<size_t>(num_total_bin_);
+if (!use_quantized_grad_) {
+  if (gpu_use_dp_) {
+    // need to double the size of histogram buffer in global memory when using double precision in histogram construction
+    cuda_hist_buffer_.Resize(buffer_size * 4);
+  } else {
+    cuda_hist_buffer_.Resize(buffer_size * 2);
+  }
+} else {
+  // use only half the size of histogram buffer in global memory when quantized training since each gradient and hessian takes only 2 bytes
+  cuda_hist_buffer_.Resize(buffer_size);
+}
}
+hist_buffer_for_num_bit_change_.Resize(num_total_bin_ * 2);
}

void CUDAHistogramConstructor::ConstructHistogramForLeaf(
const CUDALeafSplitsStruct* cuda_smaller_leaf_splits,
-const CUDALeafSplitsStruct* cuda_larger_leaf_splits,
+const CUDALeafSplitsStruct* /*cuda_larger_leaf_splits*/,
const data_size_t num_data_in_smaller_leaf,
const data_size_t num_data_in_larger_leaf,
const double sum_hessians_in_smaller_leaf,
-const double sum_hessians_in_larger_leaf) {
+const double sum_hessians_in_larger_leaf,
+const uint8_t num_bits_in_histogram_bins) {
if ((num_data_in_smaller_leaf <= min_data_in_leaf_ || sum_hessians_in_smaller_leaf <= min_sum_hessian_in_leaf_) &&
  (num_data_in_larger_leaf <= min_data_in_leaf_ || sum_hessians_in_larger_leaf <= min_sum_hessian_in_leaf_)) {
return;
}
-LaunchConstructHistogramKernel(cuda_smaller_leaf_splits, num_data_in_smaller_leaf);
+LaunchConstructHistogramKernel(cuda_smaller_leaf_splits, num_data_in_smaller_leaf, num_bits_in_histogram_bins);
SynchronizeCUDADevice(__FILE__, __LINE__);
}

void CUDAHistogramConstructor::SubtractHistogramForLeaf(
const CUDALeafSplitsStruct* cuda_smaller_leaf_splits,
const CUDALeafSplitsStruct* cuda_larger_leaf_splits,
+const bool use_quantized_grad,
+const uint8_t parent_num_bits_in_histogram_bins,
+const uint8_t smaller_num_bits_in_histogram_bins,
+const uint8_t larger_num_bits_in_histogram_bins) {
global_timer.Start("CUDAHistogramConstructor::ConstructHistogramForLeaf::LaunchSubtractHistogramKernel");
-LaunchSubtractHistogramKernel(cuda_smaller_leaf_splits, cuda_larger_leaf_splits);
+LaunchSubtractHistogramKernel(cuda_smaller_leaf_splits, cuda_larger_leaf_splits, use_quantized_grad,
+  parent_num_bits_in_histogram_bins, smaller_num_bits_in_histogram_bins, larger_num_bits_in_histogram_bins);
global_timer.Stop("CUDAHistogramConstructor::ConstructHistogramForLeaf::LaunchSubtractHistogramKernel");
}

@ -152,33 +160,18 @@ void CUDAHistogramConstructor::ResetTrainingData(const Dataset* train_data, Trai
  num_data_ = train_data->num_data();
  num_features_ = train_data->num_features();
  InitFeatureMetaInfo(train_data, share_states->feature_hist_offsets());
  if (feature_num_bins_.size() > 0) {
    DeallocateCUDAMemory<uint32_t>(&cuda_feature_num_bins_, __FILE__, __LINE__);
    DeallocateCUDAMemory<uint32_t>(&cuda_feature_hist_offsets_, __FILE__, __LINE__);
    DeallocateCUDAMemory<uint32_t>(&cuda_feature_most_freq_bins_, __FILE__, __LINE__);
    DeallocateCUDAMemory<int>(&cuda_need_fix_histogram_features_, __FILE__, __LINE__);
    DeallocateCUDAMemory<uint32_t>(&cuda_need_fix_histogram_features_num_bin_aligned_, __FILE__, __LINE__);
    DeallocateCUDAMemory<hist_t>(&cuda_hist_, __FILE__, __LINE__);
  }

  AllocateCUDAMemory<hist_t>(&cuda_hist_, num_total_bin_ * 2 * num_leaves_, __FILE__, __LINE__);
  SetCUDAMemory<hist_t>(cuda_hist_, 0, num_total_bin_ * 2 * num_leaves_, __FILE__, __LINE__);

  InitCUDAMemoryFromHostMemory<uint32_t>(&cuda_feature_num_bins_,
    feature_num_bins_.data(), feature_num_bins_.size(), __FILE__, __LINE__);

  InitCUDAMemoryFromHostMemory<uint32_t>(&cuda_feature_hist_offsets_,
    feature_hist_offsets_.data(), feature_hist_offsets_.size(), __FILE__, __LINE__);

  InitCUDAMemoryFromHostMemory<uint32_t>(&cuda_feature_most_freq_bins_,
    feature_most_freq_bins_.data(), feature_most_freq_bins_.size(), __FILE__, __LINE__);
  cuda_hist_.Resize(static_cast<size_t>(num_total_bin_ * 2 * num_leaves_));
  cuda_hist_.SetValue(0);
  cuda_feature_num_bins_.InitFromHostVector(feature_num_bins_);
  cuda_feature_hist_offsets_.InitFromHostVector(feature_hist_offsets_);
  cuda_feature_most_freq_bins_.InitFromHostVector(feature_most_freq_bins_);

  cuda_row_data_.reset(new CUDARowData(train_data, share_states, gpu_device_id_, gpu_use_dp_));
  cuda_row_data_->Init(train_data, share_states);

  InitCUDAMemoryFromHostMemory<int>(&cuda_need_fix_histogram_features_, need_fix_histogram_features_.data(), need_fix_histogram_features_.size(), __FILE__, __LINE__);
  InitCUDAMemoryFromHostMemory<uint32_t>(&cuda_need_fix_histogram_features_num_bin_aligned_, need_fix_histogram_features_num_bin_aligend_.data(),
    need_fix_histogram_features_num_bin_aligend_.size(), __FILE__, __LINE__);
  cuda_need_fix_histogram_features_.InitFromHostVector(need_fix_histogram_features_);
  cuda_need_fix_histogram_features_num_bin_aligned_.InitFromHostVector(need_fix_histogram_features_num_bin_aligend_);
}

void CUDAHistogramConstructor::ResetConfig(const Config* config) {
@ -186,9 +179,8 @@ void CUDAHistogramConstructor::ResetConfig(const Config* config) {
  num_leaves_ = config->num_leaves;
  min_data_in_leaf_ = config->min_data_in_leaf;
  min_sum_hessian_in_leaf_ = config->min_sum_hessian_in_leaf;
  DeallocateCUDAMemory<hist_t>(&cuda_hist_, __FILE__, __LINE__);
  AllocateCUDAMemory<hist_t>(&cuda_hist_, num_total_bin_ * 2 * num_leaves_, __FILE__, __LINE__);
  SetCUDAMemory<hist_t>(cuda_hist_, 0, num_total_bin_ * 2 * num_leaves_, __FILE__, __LINE__);
  cuda_hist_.Resize(static_cast<size_t>(num_total_bin_ * 2 * num_leaves_));
  cuda_hist_.SetValue(0);
}

} // namespace LightGBM

@ -125,7 +125,7 @@ __global__ void CUDAConstructHistogramSparseKernel(
|
|||
}
|
||||
}
|
||||
|
||||
template <typename BIN_TYPE>
|
||||
template <typename BIN_TYPE, typename HIST_TYPE>
|
||||
__global__ void CUDAConstructHistogramDenseKernel_GlobalMemory(
|
||||
const CUDALeafSplitsStruct* smaller_leaf_splits,
|
||||
const score_t* cuda_gradients,
|
||||
|
@ -135,7 +135,7 @@ __global__ void CUDAConstructHistogramDenseKernel_GlobalMemory(
|
|||
const uint32_t* column_hist_offsets_full,
|
||||
const int* feature_partition_column_index_offsets,
|
||||
const data_size_t num_data,
|
||||
float* global_hist_buffer) {
|
||||
HIST_TYPE* global_hist_buffer) {
|
||||
const int dim_y = static_cast<int>(gridDim.y * blockDim.y);
|
||||
const data_size_t num_data_in_smaller_leaf = smaller_leaf_splits->num_data_in_leaf;
|
||||
const data_size_t num_data_per_thread = (num_data_in_smaller_leaf + dim_y - 1) / dim_y;
|
||||
|
@ -150,7 +150,7 @@ __global__ void CUDAConstructHistogramDenseKernel_GlobalMemory(
|
|||
const uint32_t num_items_in_partition = (partition_hist_end - partition_hist_start) << 1;
|
||||
const unsigned int thread_idx = threadIdx.x + threadIdx.y * blockDim.x;
|
||||
const int num_total_bin = column_hist_offsets_full[gridDim.x];
|
||||
float* shared_hist = global_hist_buffer + (blockIdx.y * num_total_bin + partition_hist_start) * 2;
|
||||
HIST_TYPE* shared_hist = global_hist_buffer + (blockIdx.y * num_total_bin + partition_hist_start) * 2;
|
||||
for (unsigned int i = thread_idx; i < num_items_in_partition; i += num_threads_per_block) {
|
||||
shared_hist[i] = 0.0f;
|
||||
}
|
||||
|
@ -166,14 +166,14 @@ __global__ void CUDAConstructHistogramDenseKernel_GlobalMemory(
|
|||
data_size_t inner_data_index = static_cast<data_size_t>(threadIdx_y);
|
||||
const int column_index = static_cast<int>(threadIdx.x) + partition_column_start;
|
||||
if (threadIdx.x < static_cast<unsigned int>(num_columns_in_partition)) {
|
||||
float* shared_hist_ptr = shared_hist + (column_hist_offsets[column_index] << 1);
|
||||
HIST_TYPE* shared_hist_ptr = shared_hist + (column_hist_offsets[column_index] << 1);
|
||||
for (data_size_t i = 0; i < num_iteration_this; ++i) {
|
||||
const data_size_t data_index = data_indices_ref_this_block[inner_data_index];
|
||||
const score_t grad = cuda_gradients[data_index];
|
||||
const score_t hess = cuda_hessians[data_index];
|
||||
const uint32_t bin = static_cast<uint32_t>(data_ptr[static_cast<size_t>(data_index) * num_columns_in_partition + threadIdx.x]);
|
||||
const uint32_t pos = bin << 1;
|
||||
float* pos_ptr = shared_hist_ptr + pos;
|
||||
HIST_TYPE* pos_ptr = shared_hist_ptr + pos;
|
||||
atomicAdd_block(pos_ptr, grad);
|
||||
atomicAdd_block(pos_ptr + 1, hess);
|
||||
inner_data_index += blockDim.y;
|
||||
|
@ -186,7 +186,7 @@ __global__ void CUDAConstructHistogramDenseKernel_GlobalMemory(
|
|||
}
|
||||
}
|
||||
|
||||
template <typename BIN_TYPE, typename DATA_PTR_TYPE>
|
||||
template <typename BIN_TYPE, typename HIST_TYPE, typename DATA_PTR_TYPE>
|
||||
__global__ void CUDAConstructHistogramSparseKernel_GlobalMemory(
|
||||
const CUDALeafSplitsStruct* smaller_leaf_splits,
|
||||
const score_t* cuda_gradients,
|
||||
|
@ -196,7 +196,7 @@ __global__ void CUDAConstructHistogramSparseKernel_GlobalMemory(
|
|||
const DATA_PTR_TYPE* partition_ptr,
|
||||
const uint32_t* column_hist_offsets_full,
|
||||
const data_size_t num_data,
|
||||
float* global_hist_buffer) {
|
||||
HIST_TYPE* global_hist_buffer) {
|
||||
const int dim_y = static_cast<int>(gridDim.y * blockDim.y);
|
||||
const data_size_t num_data_in_smaller_leaf = smaller_leaf_splits->num_data_in_leaf;
|
||||
const data_size_t num_data_per_thread = (num_data_in_smaller_leaf + dim_y - 1) / dim_y;
|
||||
|
@ -209,7 +209,7 @@ __global__ void CUDAConstructHistogramSparseKernel_GlobalMemory(
|
|||
const uint32_t num_items_in_partition = (partition_hist_end - partition_hist_start) << 1;
|
||||
const unsigned int thread_idx = threadIdx.x + threadIdx.y * blockDim.x;
|
||||
const int num_total_bin = column_hist_offsets_full[gridDim.x];
|
||||
float* shared_hist = global_hist_buffer + (blockIdx.y * num_total_bin + partition_hist_start) * 2;
|
||||
HIST_TYPE* shared_hist = global_hist_buffer + (blockIdx.y * num_total_bin + partition_hist_start) * 2;
|
||||
for (unsigned int i = thread_idx; i < num_items_in_partition; i += num_threads_per_block) {
|
||||
shared_hist[i] = 0.0f;
|
||||
}
|
||||
|
@ -233,7 +233,7 @@ __global__ void CUDAConstructHistogramSparseKernel_GlobalMemory(
|
|||
const score_t hess = cuda_hessians[data_index];
|
||||
const uint32_t bin = static_cast<uint32_t>(data_ptr[row_start + threadIdx.x]);
|
||||
const uint32_t pos = bin << 1;
|
||||
float* pos_ptr = shared_hist + pos;
|
||||
HIST_TYPE* pos_ptr = shared_hist + pos;
|
||||
atomicAdd_block(pos_ptr, grad);
|
||||
atomicAdd_block(pos_ptr + 1, hess);
|
||||
}
|
||||
|
@ -246,13 +246,278 @@ __global__ void CUDAConstructHistogramSparseKernel_GlobalMemory(
|
|||
}
|
||||
}
|
||||
|
||||
template <typename BIN_TYPE, int SHARED_HIST_SIZE, bool USE_16BIT_HIST>
|
||||
__global__ void CUDAConstructDiscretizedHistogramDenseKernel(
|
||||
const CUDALeafSplitsStruct* smaller_leaf_splits,
|
||||
const int32_t* cuda_gradients_and_hessians,
|
||||
const BIN_TYPE* data,
|
||||
const uint32_t* column_hist_offsets,
|
||||
const uint32_t* column_hist_offsets_full,
|
||||
const int* feature_partition_column_index_offsets,
|
||||
const data_size_t num_data) {
|
||||
const int dim_y = static_cast<int>(gridDim.y * blockDim.y);
|
||||
const data_size_t num_data_in_smaller_leaf = smaller_leaf_splits->num_data_in_leaf;
|
||||
const data_size_t num_data_per_thread = (num_data_in_smaller_leaf + dim_y - 1) / dim_y;
|
||||
const data_size_t* data_indices_ref = smaller_leaf_splits->data_indices_in_leaf;
|
||||
__shared__ int16_t shared_hist[SHARED_HIST_SIZE];
|
||||
int32_t* shared_hist_packed = reinterpret_cast<int32_t*>(shared_hist);
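// Each bin occupies one packed 32-bit slot in shared memory: the discretized gradient is kept in
// the high 16 bits and the discretized hessian in the low 16 bits, which is why the 16-bit array
// above is reinterpreted as int32_t.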
|
||||
const unsigned int num_threads_per_block = blockDim.x * blockDim.y;
|
||||
const int partition_column_start = feature_partition_column_index_offsets[blockIdx.x];
|
||||
const int partition_column_end = feature_partition_column_index_offsets[blockIdx.x + 1];
|
||||
const BIN_TYPE* data_ptr = data + partition_column_start * num_data;
|
||||
const int num_columns_in_partition = partition_column_end - partition_column_start;
|
||||
const uint32_t partition_hist_start = column_hist_offsets_full[blockIdx.x];
|
||||
const uint32_t partition_hist_end = column_hist_offsets_full[blockIdx.x + 1];
|
||||
const uint32_t num_items_in_partition = (partition_hist_end - partition_hist_start);
|
||||
const unsigned int thread_idx = threadIdx.x + threadIdx.y * blockDim.x;
|
||||
for (unsigned int i = thread_idx; i < num_items_in_partition; i += num_threads_per_block) {
|
||||
shared_hist_packed[i] = 0;
|
||||
}
|
||||
__syncthreads();
|
||||
const unsigned int threadIdx_y = threadIdx.y;
|
||||
const unsigned int blockIdx_y = blockIdx.y;
|
||||
const data_size_t block_start = (blockIdx_y * blockDim.y) * num_data_per_thread;
|
||||
const data_size_t* data_indices_ref_this_block = data_indices_ref + block_start;
|
||||
data_size_t block_num_data = max(0, min(num_data_in_smaller_leaf - block_start, num_data_per_thread * static_cast<data_size_t>(blockDim.y)));
|
||||
const data_size_t num_iteration_total = (block_num_data + blockDim.y - 1) / blockDim.y;
|
||||
const data_size_t remainder = block_num_data % blockDim.y;
|
||||
const data_size_t num_iteration_this = remainder == 0 ? num_iteration_total : num_iteration_total - static_cast<data_size_t>(threadIdx_y >= remainder);
|
||||
data_size_t inner_data_index = static_cast<data_size_t>(threadIdx_y);
|
||||
const int column_index = static_cast<int>(threadIdx.x) + partition_column_start;
|
||||
if (threadIdx.x < static_cast<unsigned int>(num_columns_in_partition)) {
|
||||
int32_t* shared_hist_ptr = shared_hist_packed + (column_hist_offsets[column_index]);
|
||||
for (data_size_t i = 0; i < num_iteration_this; ++i) {
|
||||
const data_size_t data_index = data_indices_ref_this_block[inner_data_index];
|
||||
const int32_t grad_and_hess = cuda_gradients_and_hessians[data_index];
|
||||
const uint32_t bin = static_cast<uint32_t>(data_ptr[data_index * num_columns_in_partition + threadIdx.x]);
|
||||
int32_t* pos_ptr = shared_hist_ptr + bin;
|
||||
atomicAdd_block(pos_ptr, grad_and_hess);
|
||||
inner_data_index += blockDim.y;
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
if (USE_16BIT_HIST) {
|
||||
int32_t* feature_histogram_ptr = reinterpret_cast<int32_t*>(smaller_leaf_splits->hist_in_leaf) + partition_hist_start;
|
||||
for (unsigned int i = thread_idx; i < num_items_in_partition; i += num_threads_per_block) {
|
||||
const int32_t packed_grad_hess = shared_hist_packed[i];
|
||||
atomicAdd_system(feature_histogram_ptr + i, packed_grad_hess);
|
||||
}
|
||||
} else {
|
||||
atomic_add_long_t* feature_histogram_ptr = reinterpret_cast<atomic_add_long_t*>(smaller_leaf_splits->hist_in_leaf) + partition_hist_start;
|
||||
for (unsigned int i = thread_idx; i < num_items_in_partition; i += num_threads_per_block) {
|
||||
const int32_t packed_grad_hess = shared_hist_packed[i];
|
||||
const int64_t packed_grad_hess_int64 = (static_cast<int64_t>(static_cast<int16_t>(packed_grad_hess >> 16)) << 32) | (static_cast<int64_t>(packed_grad_hess & 0x0000ffff));
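// Widen the packed 16-bit/16-bit accumulator to a packed 64-bit value (sign-extended gradient in
// the high 32 bits, hessian in the low 32 bits) before the atomic add into the leaf histogram,
// whose entries use 32 bits per component in this branch.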
|
||||
atomicAdd_system(feature_histogram_ptr + i, (atomic_add_long_t)(packed_grad_hess_int64));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename BIN_TYPE, typename DATA_PTR_TYPE, int SHARED_HIST_SIZE, bool USE_16BIT_HIST>
|
||||
__global__ void CUDAConstructDiscretizedHistogramSparseKernel(
|
||||
const CUDALeafSplitsStruct* smaller_leaf_splits,
|
||||
const int32_t* cuda_gradients_and_hessians,
|
||||
const BIN_TYPE* data,
|
||||
const DATA_PTR_TYPE* row_ptr,
|
||||
const DATA_PTR_TYPE* partition_ptr,
|
||||
const uint32_t* column_hist_offsets_full,
|
||||
const data_size_t num_data) {
|
||||
const int dim_y = static_cast<int>(gridDim.y * blockDim.y);
|
||||
const data_size_t num_data_in_smaller_leaf = smaller_leaf_splits->num_data_in_leaf;
|
||||
const data_size_t num_data_per_thread = (num_data_in_smaller_leaf + dim_y - 1) / dim_y;
|
||||
const data_size_t* data_indices_ref = smaller_leaf_splits->data_indices_in_leaf;
|
||||
__shared__ int16_t shared_hist[SHARED_HIST_SIZE];
|
||||
int32_t* shared_hist_packed = reinterpret_cast<int32_t*>(shared_hist);
|
||||
const unsigned int num_threads_per_block = blockDim.x * blockDim.y;
|
||||
const DATA_PTR_TYPE* block_row_ptr = row_ptr + blockIdx.x * (num_data + 1);
|
||||
const BIN_TYPE* data_ptr = data + partition_ptr[blockIdx.x];
|
||||
const uint32_t partition_hist_start = column_hist_offsets_full[blockIdx.x];
|
||||
const uint32_t partition_hist_end = column_hist_offsets_full[blockIdx.x + 1];
|
||||
const uint32_t num_items_in_partition = (partition_hist_end - partition_hist_start);
|
||||
const unsigned int thread_idx = threadIdx.x + threadIdx.y * blockDim.x;
|
||||
for (unsigned int i = thread_idx; i < num_items_in_partition; i += num_threads_per_block) {
|
||||
shared_hist_packed[i] = 0;  // integer zero: this buffer holds packed int32 gradient/hessian values
|
||||
}
|
||||
__syncthreads();
|
||||
const unsigned int threadIdx_y = threadIdx.y;
|
||||
const unsigned int blockIdx_y = blockIdx.y;
|
||||
const data_size_t block_start = (blockIdx_y * blockDim.y) * num_data_per_thread;
|
||||
const data_size_t* data_indices_ref_this_block = data_indices_ref + block_start;
|
||||
data_size_t block_num_data = max(0, min(num_data_in_smaller_leaf - block_start, num_data_per_thread * static_cast<data_size_t>(blockDim.y)));
|
||||
const data_size_t num_iteration_total = (block_num_data + blockDim.y - 1) / blockDim.y;
|
||||
const data_size_t remainder = block_num_data % blockDim.y;
|
||||
const data_size_t num_iteration_this = remainder == 0 ? num_iteration_total : num_iteration_total - static_cast<data_size_t>(threadIdx_y >= remainder);
|
||||
data_size_t inner_data_index = static_cast<data_size_t>(threadIdx_y);
|
||||
for (data_size_t i = 0; i < num_iteration_this; ++i) {
|
||||
const data_size_t data_index = data_indices_ref_this_block[inner_data_index];
|
||||
const DATA_PTR_TYPE row_start = block_row_ptr[data_index];
|
||||
const DATA_PTR_TYPE row_end = block_row_ptr[data_index + 1];
|
||||
const DATA_PTR_TYPE row_size = row_end - row_start;
|
||||
if (threadIdx.x < row_size) {
|
||||
const int32_t grad_and_hess = cuda_gradients_and_hessians[data_index];
|
||||
const uint32_t bin = static_cast<uint32_t>(data_ptr[row_start + threadIdx.x]);
|
||||
int32_t* pos_ptr = shared_hist_packed + bin;
|
||||
atomicAdd_block(pos_ptr, grad_and_hess);
|
||||
}
|
||||
inner_data_index += blockDim.y;
|
||||
}
|
||||
__syncthreads();
|
||||
if (USE_16BIT_HIST) {
|
||||
int32_t* feature_histogram_ptr = reinterpret_cast<int32_t*>(smaller_leaf_splits->hist_in_leaf) + partition_hist_start;
|
||||
for (unsigned int i = thread_idx; i < num_items_in_partition; i += num_threads_per_block) {
|
||||
const int32_t packed_grad_hess = shared_hist_packed[i];
|
||||
atomicAdd_system(feature_histogram_ptr + i, packed_grad_hess);
|
||||
}
|
||||
} else {
|
||||
atomic_add_long_t* feature_histogram_ptr = reinterpret_cast<atomic_add_long_t*>(smaller_leaf_splits->hist_in_leaf) + partition_hist_start;
|
||||
for (unsigned int i = thread_idx; i < num_items_in_partition; i += num_threads_per_block) {
|
||||
const int32_t packed_grad_hess = shared_hist_packed[i];
|
||||
const int64_t packed_grad_hess_int64 = (static_cast<int64_t>(static_cast<int16_t>(packed_grad_hess >> 16)) << 32) | (static_cast<int64_t>(packed_grad_hess & 0x0000ffff));
|
||||
atomicAdd_system(feature_histogram_ptr + i, (atomic_add_long_t)(packed_grad_hess_int64));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename BIN_TYPE, int SHARED_HIST_SIZE, bool USE_16BIT_HIST>
|
||||
__global__ void CUDAConstructDiscretizedHistogramDenseKernel_GlobalMemory(
|
||||
const CUDALeafSplitsStruct* smaller_leaf_splits,
|
||||
const int32_t* cuda_gradients_and_hessians,
|
||||
const BIN_TYPE* data,
|
||||
const uint32_t* column_hist_offsets,
|
||||
const uint32_t* column_hist_offsets_full,
|
||||
const int* feature_partition_column_index_offsets,
|
||||
const data_size_t num_data,
|
||||
int32_t* global_hist_buffer) {
|
||||
const int dim_y = static_cast<int>(gridDim.y * blockDim.y);
|
||||
const data_size_t num_data_in_smaller_leaf = smaller_leaf_splits->num_data_in_leaf;
|
||||
const data_size_t num_data_per_thread = (num_data_in_smaller_leaf + dim_y - 1) / dim_y;
|
||||
const data_size_t* data_indices_ref = smaller_leaf_splits->data_indices_in_leaf;
|
||||
const unsigned int num_threads_per_block = blockDim.x * blockDim.y;
|
||||
const int partition_column_start = feature_partition_column_index_offsets[blockIdx.x];
|
||||
const int partition_column_end = feature_partition_column_index_offsets[blockIdx.x + 1];
|
||||
const BIN_TYPE* data_ptr = data + partition_column_start * num_data;
|
||||
const int num_columns_in_partition = partition_column_end - partition_column_start;
|
||||
const uint32_t partition_hist_start = column_hist_offsets_full[blockIdx.x];
|
||||
const uint32_t partition_hist_end = column_hist_offsets_full[blockIdx.x + 1];
|
||||
const uint32_t num_items_in_partition = (partition_hist_end - partition_hist_start);
|
||||
const int num_total_bin = column_hist_offsets_full[gridDim.x];
|
||||
int32_t* shared_hist_packed = global_hist_buffer + (blockIdx.y * num_total_bin + partition_column_start);
|
||||
const unsigned int thread_idx = threadIdx.x + threadIdx.y * blockDim.x;
|
||||
for (unsigned int i = thread_idx; i < num_items_in_partition; i += num_threads_per_block) {
|
||||
shared_hist_packed[i] = 0;
|
||||
}
|
||||
__syncthreads();
|
||||
const unsigned int threadIdx_y = threadIdx.y;
|
||||
const unsigned int blockIdx_y = blockIdx.y;
|
||||
const data_size_t block_start = (blockIdx_y * blockDim.y) * num_data_per_thread;
|
||||
const data_size_t* data_indices_ref_this_block = data_indices_ref + block_start;
|
||||
data_size_t block_num_data = max(0, min(num_data_in_smaller_leaf - block_start, num_data_per_thread * static_cast<data_size_t>(blockDim.y)));
|
||||
const data_size_t num_iteration_total = (block_num_data + blockDim.y - 1) / blockDim.y;
|
||||
const data_size_t remainder = block_num_data % blockDim.y;
|
||||
const data_size_t num_iteration_this = remainder == 0 ? num_iteration_total : num_iteration_total - static_cast<data_size_t>(threadIdx_y >= remainder);
|
||||
data_size_t inner_data_index = static_cast<data_size_t>(threadIdx_y);
|
||||
const int column_index = static_cast<int>(threadIdx.x) + partition_column_start;
|
||||
if (threadIdx.x < static_cast<unsigned int>(num_columns_in_partition)) {
|
||||
int32_t* shared_hist_ptr = shared_hist_packed + (column_hist_offsets[column_index]);
|
||||
for (data_size_t i = 0; i < num_iteration_this; ++i) {
|
||||
const data_size_t data_index = data_indices_ref_this_block[inner_data_index];
|
||||
const int32_t grad_and_hess = cuda_gradients_and_hessians[data_index];
|
||||
const uint32_t bin = static_cast<uint32_t>(data_ptr[data_index * num_columns_in_partition + threadIdx.x]);
|
||||
int32_t* pos_ptr = shared_hist_ptr + bin;
|
||||
atomicAdd_block(pos_ptr, grad_and_hess);
|
||||
inner_data_index += blockDim.y;
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
if (USE_16BIT_HIST) {
|
||||
int32_t* feature_histogram_ptr = reinterpret_cast<int32_t*>(smaller_leaf_splits->hist_in_leaf) + partition_hist_start;
|
||||
for (unsigned int i = thread_idx; i < num_items_in_partition; i += num_threads_per_block) {
|
||||
const int32_t packed_grad_hess = shared_hist_packed[i];
|
||||
atomicAdd_system(feature_histogram_ptr + i, packed_grad_hess);
|
||||
}
|
||||
} else {
|
||||
atomic_add_long_t* feature_histogram_ptr = reinterpret_cast<atomic_add_long_t*>(smaller_leaf_splits->hist_in_leaf) + partition_hist_start;
|
||||
for (unsigned int i = thread_idx; i < num_items_in_partition; i += num_threads_per_block) {
|
||||
const int32_t packed_grad_hess = shared_hist_packed[i];
|
||||
const int64_t packed_grad_hess_int64 = (static_cast<int64_t>(static_cast<int16_t>(packed_grad_hess >> 16)) << 32) | (static_cast<int64_t>(packed_grad_hess & 0x0000ffff));
|
||||
atomicAdd_system(feature_histogram_ptr + i, (atomic_add_long_t)(packed_grad_hess_int64));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename BIN_TYPE, typename DATA_PTR_TYPE, int SHARED_HIST_SIZE, bool USE_16BIT_HIST>
|
||||
__global__ void CUDAConstructDiscretizedHistogramSparseKernel_GlobalMemory(
|
||||
const CUDALeafSplitsStruct* smaller_leaf_splits,
|
||||
const int32_t* cuda_gradients_and_hessians,
|
||||
const BIN_TYPE* data,
|
||||
const DATA_PTR_TYPE* row_ptr,
|
||||
const DATA_PTR_TYPE* partition_ptr,
|
||||
const uint32_t* column_hist_offsets_full,
|
||||
const data_size_t num_data,
|
||||
int32_t* global_hist_buffer) {
|
||||
const int dim_y = static_cast<int>(gridDim.y * blockDim.y);
|
||||
const data_size_t num_data_in_smaller_leaf = smaller_leaf_splits->num_data_in_leaf;
|
||||
const data_size_t num_data_per_thread = (num_data_in_smaller_leaf + dim_y - 1) / dim_y;
|
||||
const data_size_t* data_indices_ref = smaller_leaf_splits->data_indices_in_leaf;
|
||||
const int num_total_bin = column_hist_offsets_full[gridDim.x];
|
||||
const unsigned int num_threads_per_block = blockDim.x * blockDim.y;
|
||||
const DATA_PTR_TYPE* block_row_ptr = row_ptr + blockIdx.x * (num_data + 1);
|
||||
const BIN_TYPE* data_ptr = data + partition_ptr[blockIdx.x];
|
||||
const uint32_t partition_hist_start = column_hist_offsets_full[blockIdx.x];
|
||||
const uint32_t partition_hist_end = column_hist_offsets_full[blockIdx.x + 1];
|
||||
const uint32_t num_items_in_partition = (partition_hist_end - partition_hist_start);
|
||||
const unsigned int thread_idx = threadIdx.x + threadIdx.y * blockDim.x;
|
||||
int32_t* shared_hist_packed = global_hist_buffer + (blockIdx.y * num_total_bin + partition_hist_start);
|
||||
for (unsigned int i = thread_idx; i < num_items_in_partition; i += num_threads_per_block) {
|
||||
shared_hist_packed[i] = 0;  // integer zero: this buffer holds packed int32 gradient/hessian values
|
||||
}
|
||||
__syncthreads();
|
||||
const unsigned int threadIdx_y = threadIdx.y;
|
||||
const unsigned int blockIdx_y = blockIdx.y;
|
||||
const data_size_t block_start = (blockIdx_y * blockDim.y) * num_data_per_thread;
|
||||
const data_size_t* data_indices_ref_this_block = data_indices_ref + block_start;
|
||||
data_size_t block_num_data = max(0, min(num_data_in_smaller_leaf - block_start, num_data_per_thread * static_cast<data_size_t>(blockDim.y)));
|
||||
const data_size_t num_iteration_total = (block_num_data + blockDim.y - 1) / blockDim.y;
|
||||
const data_size_t remainder = block_num_data % blockDim.y;
|
||||
const data_size_t num_iteration_this = remainder == 0 ? num_iteration_total : num_iteration_total - static_cast<data_size_t>(threadIdx_y >= remainder);
|
||||
data_size_t inner_data_index = static_cast<data_size_t>(threadIdx_y);
|
||||
for (data_size_t i = 0; i < num_iteration_this; ++i) {
|
||||
const data_size_t data_index = data_indices_ref_this_block[inner_data_index];
|
||||
const DATA_PTR_TYPE row_start = block_row_ptr[data_index];
|
||||
const DATA_PTR_TYPE row_end = block_row_ptr[data_index + 1];
|
||||
const DATA_PTR_TYPE row_size = row_end - row_start;
|
||||
if (threadIdx.x < row_size) {
|
||||
const int32_t grad_and_hess = cuda_gradients_and_hessians[data_index];
|
||||
const uint32_t bin = static_cast<uint32_t>(data_ptr[row_start + threadIdx.x]);
|
||||
int32_t* pos_ptr = shared_hist_packed + bin;
|
||||
atomicAdd_block(pos_ptr, grad_and_hess);
|
||||
}
|
||||
inner_data_index += blockDim.y;
|
||||
}
|
||||
__syncthreads();
|
||||
if (USE_16BIT_HIST) {
|
||||
int32_t* feature_histogram_ptr = reinterpret_cast<int32_t*>(smaller_leaf_splits->hist_in_leaf) + partition_hist_start;
|
||||
for (unsigned int i = thread_idx; i < num_items_in_partition; i += num_threads_per_block) {
|
||||
const int32_t packed_grad_hess = shared_hist_packed[i];
|
||||
atomicAdd_system(feature_histogram_ptr + i, packed_grad_hess);
|
||||
}
|
||||
} else {
|
||||
atomic_add_long_t* feature_histogram_ptr = reinterpret_cast<atomic_add_long_t*>(smaller_leaf_splits->hist_in_leaf) + partition_hist_start;
|
||||
for (unsigned int i = thread_idx; i < num_items_in_partition; i += num_threads_per_block) {
|
||||
const int32_t packed_grad_hess = shared_hist_packed[i];
|
||||
const int64_t packed_grad_hess_int64 = (static_cast<int64_t>(static_cast<int16_t>(packed_grad_hess >> 16)) << 32) | (static_cast<int64_t>(packed_grad_hess & 0x0000ffff));
|
||||
atomicAdd_system(feature_histogram_ptr + i, (atomic_add_long_t)(packed_grad_hess_int64));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void CUDAHistogramConstructor::LaunchConstructHistogramKernel(
  const CUDALeafSplitsStruct* cuda_smaller_leaf_splits,
  const data_size_t num_data_in_smaller_leaf) {
  const data_size_t num_data_in_smaller_leaf,
  const uint8_t num_bits_in_histogram_bins) {
  if (cuda_row_data_->shared_hist_size() == DP_SHARED_HIST_SIZE && gpu_use_dp_) {
    LaunchConstructHistogramKernelInner<double, DP_SHARED_HIST_SIZE>(cuda_smaller_leaf_splits, num_data_in_smaller_leaf);
    LaunchConstructHistogramKernelInner<double, DP_SHARED_HIST_SIZE>(cuda_smaller_leaf_splits, num_data_in_smaller_leaf, num_bits_in_histogram_bins);
  } else if (cuda_row_data_->shared_hist_size() == SP_SHARED_HIST_SIZE && !gpu_use_dp_) {
    LaunchConstructHistogramKernelInner<float, SP_SHARED_HIST_SIZE>(cuda_smaller_leaf_splits, num_data_in_smaller_leaf);
    LaunchConstructHistogramKernelInner<float, SP_SHARED_HIST_SIZE>(cuda_smaller_leaf_splits, num_data_in_smaller_leaf, num_bits_in_histogram_bins);
  } else {
    Log::Fatal("Unknown shared histogram size %d", cuda_row_data_->shared_hist_size());
  }
@ -261,13 +526,14 @@ void CUDAHistogramConstructor::LaunchConstructHistogramKernel(
|
|||
template <typename HIST_TYPE, size_t SHARED_HIST_SIZE>
|
||||
void CUDAHistogramConstructor::LaunchConstructHistogramKernelInner(
|
||||
const CUDALeafSplitsStruct* cuda_smaller_leaf_splits,
|
||||
const data_size_t num_data_in_smaller_leaf) {
|
||||
const data_size_t num_data_in_smaller_leaf,
|
||||
const uint8_t num_bits_in_histogram_bins) {
|
||||
if (cuda_row_data_->bit_type() == 8) {
|
||||
LaunchConstructHistogramKernelInner0<HIST_TYPE, SHARED_HIST_SIZE, uint8_t>(cuda_smaller_leaf_splits, num_data_in_smaller_leaf);
|
||||
LaunchConstructHistogramKernelInner0<HIST_TYPE, SHARED_HIST_SIZE, uint8_t>(cuda_smaller_leaf_splits, num_data_in_smaller_leaf, num_bits_in_histogram_bins);
|
||||
} else if (cuda_row_data_->bit_type() == 16) {
|
||||
LaunchConstructHistogramKernelInner0<HIST_TYPE, SHARED_HIST_SIZE, uint16_t>(cuda_smaller_leaf_splits, num_data_in_smaller_leaf);
|
||||
LaunchConstructHistogramKernelInner0<HIST_TYPE, SHARED_HIST_SIZE, uint16_t>(cuda_smaller_leaf_splits, num_data_in_smaller_leaf, num_bits_in_histogram_bins);
|
||||
} else if (cuda_row_data_->bit_type() == 32) {
|
||||
LaunchConstructHistogramKernelInner0<HIST_TYPE, SHARED_HIST_SIZE, uint32_t>(cuda_smaller_leaf_splits, num_data_in_smaller_leaf);
|
||||
LaunchConstructHistogramKernelInner0<HIST_TYPE, SHARED_HIST_SIZE, uint32_t>(cuda_smaller_leaf_splits, num_data_in_smaller_leaf, num_bits_in_histogram_bins);
|
||||
} else {
|
||||
Log::Fatal("Unknown bit_type = %d", cuda_row_data_->bit_type());
|
||||
}
|
||||
|
@ -276,16 +542,17 @@ void CUDAHistogramConstructor::LaunchConstructHistogramKernelInner(
|
|||
template <typename HIST_TYPE, size_t SHARED_HIST_SIZE, typename BIN_TYPE>
|
||||
void CUDAHistogramConstructor::LaunchConstructHistogramKernelInner0(
|
||||
const CUDALeafSplitsStruct* cuda_smaller_leaf_splits,
|
||||
const data_size_t num_data_in_smaller_leaf) {
|
||||
const data_size_t num_data_in_smaller_leaf,
|
||||
const uint8_t num_bits_in_histogram_bins) {
|
||||
if (cuda_row_data_->row_ptr_bit_type() == 16) {
|
||||
LaunchConstructHistogramKernelInner1<HIST_TYPE, SHARED_HIST_SIZE, BIN_TYPE, uint16_t>(cuda_smaller_leaf_splits, num_data_in_smaller_leaf);
|
||||
LaunchConstructHistogramKernelInner1<HIST_TYPE, SHARED_HIST_SIZE, BIN_TYPE, uint16_t>(cuda_smaller_leaf_splits, num_data_in_smaller_leaf, num_bits_in_histogram_bins);
|
||||
} else if (cuda_row_data_->row_ptr_bit_type() == 32) {
|
||||
LaunchConstructHistogramKernelInner1<HIST_TYPE, SHARED_HIST_SIZE, BIN_TYPE, uint32_t>(cuda_smaller_leaf_splits, num_data_in_smaller_leaf);
|
||||
LaunchConstructHistogramKernelInner1<HIST_TYPE, SHARED_HIST_SIZE, BIN_TYPE, uint32_t>(cuda_smaller_leaf_splits, num_data_in_smaller_leaf, num_bits_in_histogram_bins);
|
||||
} else if (cuda_row_data_->row_ptr_bit_type() == 64) {
|
||||
LaunchConstructHistogramKernelInner1<HIST_TYPE, SHARED_HIST_SIZE, BIN_TYPE, uint64_t>(cuda_smaller_leaf_splits, num_data_in_smaller_leaf);
|
||||
LaunchConstructHistogramKernelInner1<HIST_TYPE, SHARED_HIST_SIZE, BIN_TYPE, uint64_t>(cuda_smaller_leaf_splits, num_data_in_smaller_leaf, num_bits_in_histogram_bins);
|
||||
} else {
|
||||
if (!cuda_row_data_->is_sparse()) {
|
||||
LaunchConstructHistogramKernelInner1<HIST_TYPE, SHARED_HIST_SIZE, BIN_TYPE, uint16_t>(cuda_smaller_leaf_splits, num_data_in_smaller_leaf);
|
||||
LaunchConstructHistogramKernelInner1<HIST_TYPE, SHARED_HIST_SIZE, BIN_TYPE, uint16_t>(cuda_smaller_leaf_splits, num_data_in_smaller_leaf, num_bits_in_histogram_bins);
|
||||
} else {
|
||||
Log::Fatal("Unknown row_ptr_bit_type = %d", cuda_row_data_->row_ptr_bit_type());
|
||||
}
|
||||
|
@ -295,18 +562,20 @@ void CUDAHistogramConstructor::LaunchConstructHistogramKernelInner0(
|
|||
template <typename HIST_TYPE, size_t SHARED_HIST_SIZE, typename BIN_TYPE, typename PTR_TYPE>
|
||||
void CUDAHistogramConstructor::LaunchConstructHistogramKernelInner1(
|
||||
const CUDALeafSplitsStruct* cuda_smaller_leaf_splits,
|
||||
const data_size_t num_data_in_smaller_leaf) {
|
||||
const data_size_t num_data_in_smaller_leaf,
|
||||
const uint8_t num_bits_in_histogram_bins) {
|
||||
if (cuda_row_data_->NumLargeBinPartition() == 0) {
|
||||
LaunchConstructHistogramKernelInner2<HIST_TYPE, SHARED_HIST_SIZE, BIN_TYPE, PTR_TYPE, false>(cuda_smaller_leaf_splits, num_data_in_smaller_leaf);
|
||||
LaunchConstructHistogramKernelInner2<HIST_TYPE, SHARED_HIST_SIZE, BIN_TYPE, PTR_TYPE, false>(cuda_smaller_leaf_splits, num_data_in_smaller_leaf, num_bits_in_histogram_bins);
|
||||
} else {
|
||||
LaunchConstructHistogramKernelInner2<HIST_TYPE, SHARED_HIST_SIZE, BIN_TYPE, PTR_TYPE, true>(cuda_smaller_leaf_splits, num_data_in_smaller_leaf);
|
||||
LaunchConstructHistogramKernelInner2<HIST_TYPE, SHARED_HIST_SIZE, BIN_TYPE, PTR_TYPE, true>(cuda_smaller_leaf_splits, num_data_in_smaller_leaf, num_bits_in_histogram_bins);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename HIST_TYPE, size_t SHARED_HIST_SIZE, typename BIN_TYPE, typename PTR_TYPE, bool USE_GLOBAL_MEM_BUFFER>
|
||||
void CUDAHistogramConstructor::LaunchConstructHistogramKernelInner2(
|
||||
const CUDALeafSplitsStruct* cuda_smaller_leaf_splits,
|
||||
const data_size_t num_data_in_smaller_leaf) {
|
||||
const data_size_t num_data_in_smaller_leaf,
|
||||
const uint8_t num_bits_in_histogram_bins) {
|
||||
int grid_dim_x = 0;
|
||||
int grid_dim_y = 0;
|
||||
int block_dim_x = 0;
|
||||
|
@ -314,6 +583,97 @@ void CUDAHistogramConstructor::LaunchConstructHistogramKernelInner2(
|
|||
CalcConstructHistogramKernelDim(&grid_dim_x, &grid_dim_y, &block_dim_x, &block_dim_y, num_data_in_smaller_leaf);
|
||||
dim3 grid_dim(grid_dim_x, grid_dim_y);
|
||||
dim3 block_dim(block_dim_x, block_dim_y);
|
||||
if (use_quantized_grad_) {
|
||||
if (USE_GLOBAL_MEM_BUFFER) {
|
||||
if (cuda_row_data_->is_sparse()) {
|
||||
if (num_bits_in_histogram_bins <= 16) {
|
||||
CUDAConstructDiscretizedHistogramSparseKernel_GlobalMemory<BIN_TYPE, PTR_TYPE, SHARED_HIST_SIZE, true><<<grid_dim, block_dim, 0, cuda_stream_>>>(
|
||||
cuda_smaller_leaf_splits,
|
||||
reinterpret_cast<const int32_t*>(cuda_gradients_),
|
||||
cuda_row_data_->GetBin<BIN_TYPE>(),
|
||||
cuda_row_data_->GetRowPtr<PTR_TYPE>(),
|
||||
cuda_row_data_->GetPartitionPtr<PTR_TYPE>(),
|
||||
cuda_row_data_->cuda_partition_hist_offsets(),
|
||||
num_data_,
|
||||
reinterpret_cast<int32_t*>(cuda_hist_buffer_.RawData()));
|
||||
} else {
|
||||
CUDAConstructDiscretizedHistogramSparseKernel_GlobalMemory<BIN_TYPE, PTR_TYPE, SHARED_HIST_SIZE, false><<<grid_dim, block_dim, 0, cuda_stream_>>>(
|
||||
cuda_smaller_leaf_splits,
|
||||
reinterpret_cast<const int32_t*>(cuda_gradients_),
|
||||
cuda_row_data_->GetBin<BIN_TYPE>(),
|
||||
cuda_row_data_->GetRowPtr<PTR_TYPE>(),
|
||||
cuda_row_data_->GetPartitionPtr<PTR_TYPE>(),
|
||||
cuda_row_data_->cuda_partition_hist_offsets(),
|
||||
num_data_,
|
||||
reinterpret_cast<int32_t*>(cuda_hist_buffer_.RawData()));
|
||||
}
|
||||
} else {
|
||||
if (num_bits_in_histogram_bins <= 16) {
|
||||
CUDAConstructDiscretizedHistogramDenseKernel_GlobalMemory<BIN_TYPE, SHARED_HIST_SIZE, true><<<grid_dim, block_dim, 0, cuda_stream_>>>(
|
||||
cuda_smaller_leaf_splits,
|
||||
reinterpret_cast<const int32_t*>(cuda_gradients_),
|
||||
cuda_row_data_->GetBin<BIN_TYPE>(),
|
||||
cuda_row_data_->cuda_column_hist_offsets(),
|
||||
cuda_row_data_->cuda_partition_hist_offsets(),
|
||||
cuda_row_data_->cuda_feature_partition_column_index_offsets(),
|
||||
num_data_,
|
||||
reinterpret_cast<int32_t*>(cuda_hist_buffer_.RawData()));
|
||||
} else {
|
||||
CUDAConstructDiscretizedHistogramDenseKernel_GlobalMemory<BIN_TYPE, SHARED_HIST_SIZE, false><<<grid_dim, block_dim, 0, cuda_stream_>>>(
|
||||
cuda_smaller_leaf_splits,
|
||||
reinterpret_cast<const int32_t*>(cuda_gradients_),
|
||||
cuda_row_data_->GetBin<BIN_TYPE>(),
|
||||
cuda_row_data_->cuda_column_hist_offsets(),
|
||||
cuda_row_data_->cuda_partition_hist_offsets(),
|
||||
cuda_row_data_->cuda_feature_partition_column_index_offsets(),
|
||||
num_data_,
|
||||
reinterpret_cast<int32_t*>(cuda_hist_buffer_.RawData()));
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if (cuda_row_data_->is_sparse()) {
|
||||
if (num_bits_in_histogram_bins <= 16) {
|
||||
CUDAConstructDiscretizedHistogramSparseKernel<BIN_TYPE, PTR_TYPE, SHARED_HIST_SIZE, true><<<grid_dim, block_dim, 0, cuda_stream_>>>(
|
||||
cuda_smaller_leaf_splits,
|
||||
reinterpret_cast<const int32_t*>(cuda_gradients_),
|
||||
cuda_row_data_->GetBin<BIN_TYPE>(),
|
||||
cuda_row_data_->GetRowPtr<PTR_TYPE>(),
|
||||
cuda_row_data_->GetPartitionPtr<PTR_TYPE>(),
|
||||
cuda_row_data_->cuda_partition_hist_offsets(),
|
||||
num_data_);
|
||||
} else {
|
||||
CUDAConstructDiscretizedHistogramSparseKernel<BIN_TYPE, PTR_TYPE, SHARED_HIST_SIZE, false><<<grid_dim, block_dim, 0, cuda_stream_>>>(
|
||||
cuda_smaller_leaf_splits,
|
||||
reinterpret_cast<const int32_t*>(cuda_gradients_),
|
||||
cuda_row_data_->GetBin<BIN_TYPE>(),
|
||||
cuda_row_data_->GetRowPtr<PTR_TYPE>(),
|
||||
cuda_row_data_->GetPartitionPtr<PTR_TYPE>(),
|
||||
cuda_row_data_->cuda_partition_hist_offsets(),
|
||||
num_data_);
|
||||
}
|
||||
} else {
|
||||
if (num_bits_in_histogram_bins <= 16) {
|
||||
CUDAConstructDiscretizedHistogramDenseKernel<BIN_TYPE, SHARED_HIST_SIZE, true><<<grid_dim, block_dim, 0, cuda_stream_>>>(
|
||||
cuda_smaller_leaf_splits,
|
||||
reinterpret_cast<const int32_t*>(cuda_gradients_),
|
||||
cuda_row_data_->GetBin<BIN_TYPE>(),
|
||||
cuda_row_data_->cuda_column_hist_offsets(),
|
||||
cuda_row_data_->cuda_partition_hist_offsets(),
|
||||
cuda_row_data_->cuda_feature_partition_column_index_offsets(),
|
||||
num_data_);
|
||||
} else {
|
||||
CUDAConstructDiscretizedHistogramDenseKernel<BIN_TYPE, SHARED_HIST_SIZE, false><<<grid_dim, block_dim, 0, cuda_stream_>>>(
|
||||
cuda_smaller_leaf_splits,
|
||||
reinterpret_cast<const int32_t*>(cuda_gradients_),
|
||||
cuda_row_data_->GetBin<BIN_TYPE>(),
|
||||
cuda_row_data_->cuda_column_hist_offsets(),
|
||||
cuda_row_data_->cuda_partition_hist_offsets(),
|
||||
cuda_row_data_->cuda_feature_partition_column_index_offsets(),
|
||||
num_data_);
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if (!USE_GLOBAL_MEM_BUFFER) {
|
||||
if (cuda_row_data_->is_sparse()) {
|
||||
CUDAConstructHistogramSparseKernel<BIN_TYPE, PTR_TYPE, HIST_TYPE, SHARED_HIST_SIZE><<<grid_dim, block_dim, 0, cuda_stream_>>>(
|
||||
|
@ -336,7 +696,7 @@ void CUDAHistogramConstructor::LaunchConstructHistogramKernelInner2(
|
|||
}
|
||||
} else {
|
||||
if (cuda_row_data_->is_sparse()) {
|
||||
CUDAConstructHistogramSparseKernel_GlobalMemory<BIN_TYPE, PTR_TYPE><<<grid_dim, block_dim, 0, cuda_stream_>>>(
|
||||
CUDAConstructHistogramSparseKernel_GlobalMemory<BIN_TYPE, HIST_TYPE, PTR_TYPE><<<grid_dim, block_dim, 0, cuda_stream_>>>(
|
||||
cuda_smaller_leaf_splits,
|
||||
cuda_gradients_, cuda_hessians_,
|
||||
cuda_row_data_->GetBin<BIN_TYPE>(),
|
||||
|
@ -344,9 +704,9 @@ void CUDAHistogramConstructor::LaunchConstructHistogramKernelInner2(
|
|||
cuda_row_data_->GetPartitionPtr<PTR_TYPE>(),
|
||||
cuda_row_data_->cuda_partition_hist_offsets(),
|
||||
num_data_,
|
||||
cuda_hist_buffer_);
|
||||
reinterpret_cast<HIST_TYPE*>(cuda_hist_buffer_.RawData()));
|
||||
} else {
|
||||
CUDAConstructHistogramDenseKernel_GlobalMemory<BIN_TYPE><<<grid_dim, block_dim, 0, cuda_stream_>>>(
|
||||
CUDAConstructHistogramDenseKernel_GlobalMemory<BIN_TYPE, HIST_TYPE><<<grid_dim, block_dim, 0, cuda_stream_>>>(
|
||||
cuda_smaller_leaf_splits,
|
||||
cuda_gradients_, cuda_hessians_,
|
||||
cuda_row_data_->GetBin<BIN_TYPE>(),
|
||||
|
@ -354,7 +714,8 @@ void CUDAHistogramConstructor::LaunchConstructHistogramKernelInner2(
|
|||
cuda_row_data_->cuda_partition_hist_offsets(),
|
||||
cuda_row_data_->cuda_feature_partition_column_index_offsets(),
|
||||
num_data_,
|
||||
cuda_hist_buffer_);
|
||||
reinterpret_cast<HIST_TYPE*>(cuda_hist_buffer_.RawData()));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -403,19 +764,127 @@ __global__ void FixHistogramKernel(
|
|||
}
|
||||
}
|
||||
|
||||
template <bool SMALLER_USE_16BIT_HIST, bool LARGER_USE_16BIT_HIST, bool PARENT_USE_16BIT_HIST>
|
||||
__global__ void SubtractHistogramDiscretizedKernel(
|
||||
const int num_total_bin,
|
||||
const CUDALeafSplitsStruct* cuda_smaller_leaf_splits,
|
||||
const CUDALeafSplitsStruct* cuda_larger_leaf_splits,
|
||||
hist_t* num_bit_change_buffer) {
|
||||
const unsigned int global_thread_index = threadIdx.x + blockIdx.x * blockDim.x;
|
||||
const int cuda_larger_leaf_index_ref = cuda_larger_leaf_splits->leaf_index;
|
||||
if (cuda_larger_leaf_index_ref >= 0) {
|
||||
if (PARENT_USE_16BIT_HIST) {
|
||||
const int32_t* smaller_leaf_hist = reinterpret_cast<const int32_t*>(cuda_smaller_leaf_splits->hist_in_leaf);
|
||||
int32_t* larger_leaf_hist = reinterpret_cast<int32_t*>(cuda_larger_leaf_splits->hist_in_leaf);
|
||||
if (global_thread_index < num_total_bin) {
|
||||
larger_leaf_hist[global_thread_index] -= smaller_leaf_hist[global_thread_index];
|
||||
}
|
||||
} else if (LARGER_USE_16BIT_HIST) {
|
||||
int32_t* buffer = reinterpret_cast<int32_t*>(num_bit_change_buffer);
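// The parent histogram stored in the larger leaf uses 64-bit packed entries while the result must be
// 16-bit packed, so the subtraction is staged in this buffer and copied back into the larger leaf's
// histogram by the CopyChangedNumBitHistogram kernel afterwards.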
|
||||
const int32_t* smaller_leaf_hist = reinterpret_cast<const int32_t*>(cuda_smaller_leaf_splits->hist_in_leaf);
|
||||
int64_t* larger_leaf_hist = reinterpret_cast<int64_t*>(cuda_larger_leaf_splits->hist_in_leaf);
|
||||
if (global_thread_index < num_total_bin) {
|
||||
const int64_t parent_hist_item = larger_leaf_hist[global_thread_index];
|
||||
const int32_t smaller_hist_item = smaller_leaf_hist[global_thread_index];
|
||||
const int64_t smaller_hist_item_int64 = (static_cast<int64_t>(static_cast<int16_t>(smaller_hist_item >> 16)) << 32) |
|
||||
static_cast<int64_t>(smaller_hist_item & 0x0000ffff);
|
||||
const int64_t larger_hist_item = parent_hist_item - smaller_hist_item_int64;
|
||||
buffer[global_thread_index] = static_cast<int32_t>(static_cast<int16_t>(larger_hist_item >> 32) << 16) |
|
||||
static_cast<int32_t>(larger_hist_item & 0x000000000000ffff);
|
||||
}
|
||||
} else if (SMALLER_USE_16BIT_HIST) {
|
||||
const int32_t* smaller_leaf_hist = reinterpret_cast<const int32_t*>(cuda_smaller_leaf_splits->hist_in_leaf);
|
||||
int64_t* larger_leaf_hist = reinterpret_cast<int64_t*>(cuda_larger_leaf_splits->hist_in_leaf);
|
||||
if (global_thread_index < num_total_bin) {
|
||||
const int64_t parent_hist_item = larger_leaf_hist[global_thread_index];
|
||||
const int32_t smaller_hist_item = smaller_leaf_hist[global_thread_index];
|
||||
const int64_t smaller_hist_item_int64 = (static_cast<int64_t>(static_cast<int16_t>(smaller_hist_item >> 16)) << 32) |
|
||||
static_cast<int64_t>(smaller_hist_item & 0x0000ffff);
|
||||
const int64_t larger_hist_item = parent_hist_item - smaller_hist_item_int64;
|
||||
larger_leaf_hist[global_thread_index] = larger_hist_item;
|
||||
}
|
||||
} else {
|
||||
const int64_t* smaller_leaf_hist = reinterpret_cast<const int64_t*>(cuda_smaller_leaf_splits->hist_in_leaf);
|
||||
int64_t* larger_leaf_hist = reinterpret_cast<int64_t*>(cuda_larger_leaf_splits->hist_in_leaf);
|
||||
if (global_thread_index < num_total_bin) {
|
||||
larger_leaf_hist[global_thread_index] -= smaller_leaf_hist[global_thread_index];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
__global__ void CopyChangedNumBitHistogram(
|
||||
const int num_total_bin,
|
||||
const CUDALeafSplitsStruct* cuda_larger_leaf_splits,
|
||||
hist_t* num_bit_change_buffer) {
|
||||
int32_t* hist_dst = reinterpret_cast<int32_t*>(cuda_larger_leaf_splits->hist_in_leaf);
|
||||
const int32_t* hist_src = reinterpret_cast<const int32_t*>(num_bit_change_buffer);
|
||||
const unsigned int global_thread_index = threadIdx.x + blockIdx.x * blockDim.x;
|
||||
if (global_thread_index < static_cast<unsigned int>(num_total_bin)) {
|
||||
hist_dst[global_thread_index] = hist_src[global_thread_index];
|
||||
}
|
||||
}
|
||||
|
||||
template <bool USE_16BIT_HIST>
|
||||
__global__ void FixHistogramDiscretizedKernel(
|
||||
const uint32_t* cuda_feature_num_bins,
|
||||
const uint32_t* cuda_feature_hist_offsets,
|
||||
const uint32_t* cuda_feature_most_freq_bins,
|
||||
const int* cuda_need_fix_histogram_features,
|
||||
const uint32_t* cuda_need_fix_histogram_features_num_bin_aligned,
|
||||
const CUDALeafSplitsStruct* cuda_smaller_leaf_splits) {
|
||||
__shared__ int64_t shared_mem_buffer[32];
|
||||
const unsigned int blockIdx_x = blockIdx.x;
|
||||
const int feature_index = cuda_need_fix_histogram_features[blockIdx_x];
|
||||
const uint32_t num_bin_aligned = cuda_need_fix_histogram_features_num_bin_aligned[blockIdx_x];
|
||||
const uint32_t feature_hist_offset = cuda_feature_hist_offsets[feature_index];
|
||||
const uint32_t most_freq_bin = cuda_feature_most_freq_bins[feature_index];
|
||||
if (USE_16BIT_HIST) {
|
||||
const int64_t leaf_sum_gradients_hessians_int64 = cuda_smaller_leaf_splits->sum_of_gradients_hessians;
|
||||
const int32_t leaf_sum_gradients_hessians =
|
||||
(static_cast<int32_t>(leaf_sum_gradients_hessians_int64 >> 32) << 16) | static_cast<int32_t>(leaf_sum_gradients_hessians_int64 & 0x000000000000ffff);
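// The 64-bit leaf total (32-bit gradient | 32-bit hessian) is narrowed to the same 16|16 packed
// layout as the histogram entries, so the most-frequent bin can be reconstructed as the leaf total
// minus the sum over all other bins.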
|
||||
int32_t* feature_hist = reinterpret_cast<int32_t*>(cuda_smaller_leaf_splits->hist_in_leaf) + feature_hist_offset;
|
||||
const unsigned int threadIdx_x = threadIdx.x;
|
||||
const uint32_t num_bin = cuda_feature_num_bins[feature_index];
|
||||
const int32_t bin_gradient_hessian = (threadIdx_x < num_bin && threadIdx_x != most_freq_bin) ? feature_hist[threadIdx_x] : 0;
|
||||
const int32_t sum_gradient_hessian = ShuffleReduceSum<int32_t>(
|
||||
bin_gradient_hessian,
|
||||
reinterpret_cast<int32_t*>(shared_mem_buffer),
|
||||
num_bin_aligned);
|
||||
if (threadIdx_x == 0) {
|
||||
feature_hist[most_freq_bin] = leaf_sum_gradients_hessians - sum_gradient_hessian;
|
||||
}
|
||||
} else {
|
||||
const int64_t leaf_sum_gradients_hessians = cuda_smaller_leaf_splits->sum_of_gradients_hessians;
|
||||
int64_t* feature_hist = reinterpret_cast<int64_t*>(cuda_smaller_leaf_splits->hist_in_leaf) + feature_hist_offset;
|
||||
const unsigned int threadIdx_x = threadIdx.x;
|
||||
const uint32_t num_bin = cuda_feature_num_bins[feature_index];
|
||||
const int64_t bin_gradient_hessian = (threadIdx_x < num_bin && threadIdx_x != most_freq_bin) ? feature_hist[threadIdx_x] : 0;
|
||||
const int64_t sum_gradient_hessian = ShuffleReduceSum<int64_t>(bin_gradient_hessian, shared_mem_buffer, num_bin_aligned);
|
||||
if (threadIdx_x == 0) {
|
||||
feature_hist[most_freq_bin] = leaf_sum_gradients_hessians - sum_gradient_hessian;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void CUDAHistogramConstructor::LaunchSubtractHistogramKernel(
  const CUDALeafSplitsStruct* cuda_smaller_leaf_splits,
  const CUDALeafSplitsStruct* cuda_larger_leaf_splits) {
  const CUDALeafSplitsStruct* cuda_larger_leaf_splits,
  const bool use_discretized_grad,
  const uint8_t parent_num_bits_in_histogram_bins,
  const uint8_t smaller_num_bits_in_histogram_bins,
  const uint8_t larger_num_bits_in_histogram_bins) {
  if (!use_discretized_grad) {
    const int num_subtract_threads = 2 * num_total_bin_;
    const int num_subtract_blocks = (num_subtract_threads + SUBTRACT_BLOCK_SIZE - 1) / SUBTRACT_BLOCK_SIZE;
    global_timer.Start("CUDAHistogramConstructor::FixHistogramKernel");
    if (need_fix_histogram_features_.size() > 0) {
      FixHistogramKernel<<<need_fix_histogram_features_.size(), FIX_HISTOGRAM_BLOCK_SIZE, 0, cuda_stream_>>>(
        cuda_feature_num_bins_,
        cuda_feature_hist_offsets_,
        cuda_feature_most_freq_bins_,
        cuda_need_fix_histogram_features_,
        cuda_need_fix_histogram_features_num_bin_aligned_,
        cuda_feature_num_bins_.RawData(),
        cuda_feature_hist_offsets_.RawData(),
        cuda_feature_most_freq_bins_.RawData(),
        cuda_need_fix_histogram_features_.RawData(),
        cuda_need_fix_histogram_features_num_bin_aligned_.RawData(),
        cuda_smaller_leaf_splits);
    }
    global_timer.Stop("CUDAHistogramConstructor::FixHistogramKernel");
@ -425,6 +894,65 @@ void CUDAHistogramConstructor::LaunchSubtractHistogramKernel(
|
|||
cuda_smaller_leaf_splits,
|
||||
cuda_larger_leaf_splits);
|
||||
global_timer.Stop("CUDAHistogramConstructor::SubtractHistogramKernel");
|
||||
} else {
|
||||
const int num_subtract_threads = num_total_bin_;
|
||||
const int num_subtract_blocks = (num_subtract_threads + SUBTRACT_BLOCK_SIZE - 1) / SUBTRACT_BLOCK_SIZE;
|
||||
global_timer.Start("CUDAHistogramConstructor::FixHistogramDiscretizedKernel");
|
||||
if (need_fix_histogram_features_.size() > 0) {
|
||||
if (smaller_num_bits_in_histogram_bins <= 16) {
|
||||
FixHistogramDiscretizedKernel<true><<<need_fix_histogram_features_.size(), FIX_HISTOGRAM_BLOCK_SIZE, 0, cuda_stream_>>>(
|
||||
cuda_feature_num_bins_.RawData(),
|
||||
cuda_feature_hist_offsets_.RawData(),
|
||||
cuda_feature_most_freq_bins_.RawData(),
|
||||
cuda_need_fix_histogram_features_.RawData(),
|
||||
cuda_need_fix_histogram_features_num_bin_aligned_.RawData(),
|
||||
cuda_smaller_leaf_splits);
|
||||
} else {
|
||||
FixHistogramDiscretizedKernel<false><<<need_fix_histogram_features_.size(), FIX_HISTOGRAM_BLOCK_SIZE, 0, cuda_stream_>>>(
|
||||
cuda_feature_num_bins_.RawData(),
|
||||
cuda_feature_hist_offsets_.RawData(),
|
||||
cuda_feature_most_freq_bins_.RawData(),
|
||||
cuda_need_fix_histogram_features_.RawData(),
|
||||
cuda_need_fix_histogram_features_num_bin_aligned_.RawData(),
|
||||
cuda_smaller_leaf_splits);
|
||||
}
|
||||
}
|
||||
global_timer.Stop("CUDAHistogramConstructor::FixHistogramDiscretizedKernel");
|
||||
global_timer.Start("CUDAHistogramConstructor::SubtractHistogramDiscretizedKernel");
|
||||
if (parent_num_bits_in_histogram_bins <= 16) {
|
||||
CHECK_LE(smaller_num_bits_in_histogram_bins, 16);
|
||||
CHECK_LE(larger_num_bits_in_histogram_bins, 16);
|
||||
SubtractHistogramDiscretizedKernel<true, true, true><<<num_subtract_blocks, SUBTRACT_BLOCK_SIZE, 0, cuda_stream_>>>(
|
||||
num_total_bin_,
|
||||
cuda_smaller_leaf_splits,
|
||||
cuda_larger_leaf_splits,
|
||||
hist_buffer_for_num_bit_change_.RawData());
|
||||
} else if (larger_num_bits_in_histogram_bins <= 16) {
|
||||
CHECK_LE(smaller_num_bits_in_histogram_bins, 16);
|
||||
SubtractHistogramDiscretizedKernel<true, true, false><<<num_subtract_blocks, SUBTRACT_BLOCK_SIZE, 0, cuda_stream_>>>(
|
||||
num_total_bin_,
|
||||
cuda_smaller_leaf_splits,
|
||||
cuda_larger_leaf_splits,
|
||||
hist_buffer_for_num_bit_change_.RawData());
|
||||
CopyChangedNumBitHistogram<<<num_subtract_blocks, SUBTRACT_BLOCK_SIZE, 0, cuda_stream_>>>(
|
||||
num_total_bin_,
|
||||
cuda_larger_leaf_splits,
|
||||
hist_buffer_for_num_bit_change_.RawData());
|
||||
} else if (smaller_num_bits_in_histogram_bins <= 16) {
|
||||
SubtractHistogramDiscretizedKernel<true, false, false><<<num_subtract_blocks, SUBTRACT_BLOCK_SIZE, 0, cuda_stream_>>>(
|
||||
num_total_bin_,
|
||||
cuda_smaller_leaf_splits,
|
||||
cuda_larger_leaf_splits,
|
||||
hist_buffer_for_num_bit_change_.RawData());
|
||||
} else {
|
||||
SubtractHistogramDiscretizedKernel<false, false, false><<<num_subtract_blocks, SUBTRACT_BLOCK_SIZE, 0, cuda_stream_>>>(
|
||||
num_total_bin_,
|
||||
cuda_smaller_leaf_splits,
|
||||
cuda_larger_leaf_splits,
|
||||
hist_buffer_for_num_bit_change_.RawData());
|
||||
}
|
||||
global_timer.Stop("CUDAHistogramConstructor::SubtractHistogramDiscretizedKernel");
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace LightGBM
@ -9,6 +9,7 @@
#ifdef USE_CUDA

#include <LightGBM/cuda/cuda_row_data.hpp>
#include <LightGBM/cuda/cuda_utils.hu>
#include <LightGBM/feature_group.h>
#include <LightGBM/tree.h>

@ -37,7 +38,9 @@ class CUDAHistogramConstructor {
    const int min_data_in_leaf,
    const double min_sum_hessian_in_leaf,
    const int gpu_device_id,
    const bool gpu_use_dp);
    const bool gpu_use_dp,
    const bool use_discretized_grad,
    const int grad_discretized_bins);

  ~CUDAHistogramConstructor();

@ -49,7 +52,16 @@ class CUDAHistogramConstructor {
    const data_size_t num_data_in_smaller_leaf,
    const data_size_t num_data_in_larger_leaf,
    const double sum_hessians_in_smaller_leaf,
    const double sum_hessians_in_larger_leaf);
    const double sum_hessians_in_larger_leaf,
    const uint8_t num_bits_in_histogram_bins);

  void SubtractHistogramForLeaf(
    const CUDALeafSplitsStruct* cuda_smaller_leaf_splits,
    const CUDALeafSplitsStruct* cuda_larger_leaf_splits,
    const bool use_discretized_grad,
    const uint8_t parent_num_bits_in_histogram_bins,
    const uint8_t smaller_num_bits_in_histogram_bins,
    const uint8_t larger_num_bits_in_histogram_bins);

  void ResetTrainingData(const Dataset* train_data, TrainingShareStates* share_states);

@ -57,9 +69,9 @@ class CUDAHistogramConstructor {

  void BeforeTrain(const score_t* gradients, const score_t* hessians);

  const hist_t* cuda_hist() const { return cuda_hist_; }
  const hist_t* cuda_hist() const { return cuda_hist_.RawData(); }

  hist_t* cuda_hist_pointer() { return cuda_hist_; }
  hist_t* cuda_hist_pointer() { return cuda_hist_.RawData(); }

 private:
  void InitFeatureMetaInfo(const Dataset* train_data, const std::vector<uint32_t>& feature_hist_offsets);

@ -74,30 +86,39 @@ class CUDAHistogramConstructor {
  template <typename HIST_TYPE, size_t SHARED_HIST_SIZE>
  void LaunchConstructHistogramKernelInner(
    const CUDALeafSplitsStruct* cuda_smaller_leaf_splits,
    const data_size_t num_data_in_smaller_leaf);
    const data_size_t num_data_in_smaller_leaf,
    const uint8_t num_bits_in_histogram_bins);

  template <typename HIST_TYPE, size_t SHARED_HIST_SIZE, typename BIN_TYPE>
  void LaunchConstructHistogramKernelInner0(
    const CUDALeafSplitsStruct* cuda_smaller_leaf_splits,
    const data_size_t num_data_in_smaller_leaf);
    const data_size_t num_data_in_smaller_leaf,
    const uint8_t num_bits_in_histogram_bins);

  template <typename HIST_TYPE, size_t SHARED_HIST_SIZE, typename BIN_TYPE, typename PTR_TYPE>
  void LaunchConstructHistogramKernelInner1(
    const CUDALeafSplitsStruct* cuda_smaller_leaf_splits,
    const data_size_t num_data_in_smaller_leaf);
    const data_size_t num_data_in_smaller_leaf,
    const uint8_t num_bits_in_histogram_bins);

  template <typename HIST_TYPE, size_t SHARED_HIST_SIZE, typename BIN_TYPE, typename PTR_TYPE, bool USE_GLOBAL_MEM_BUFFER>
  void LaunchConstructHistogramKernelInner2(
    const CUDALeafSplitsStruct* cuda_smaller_leaf_splits,
    const data_size_t num_data_in_smaller_leaf);
    const data_size_t num_data_in_smaller_leaf,
    const uint8_t num_bits_in_histogram_bins);

  void LaunchConstructHistogramKernel(
    const CUDALeafSplitsStruct* cuda_smaller_leaf_splits,
    const data_size_t num_data_in_smaller_leaf);
    const data_size_t num_data_in_smaller_leaf,
    const uint8_t num_bits_in_histogram_bins);

  void LaunchSubtractHistogramKernel(
    const CUDALeafSplitsStruct* cuda_smaller_leaf_splits,
    const CUDALeafSplitsStruct* cuda_larger_leaf_splits);
    const CUDALeafSplitsStruct* cuda_larger_leaf_splits,
    const bool use_discretized_grad,
    const uint8_t parent_num_bits_in_histogram_bins,
    const uint8_t smaller_num_bits_in_histogram_bins,
    const uint8_t larger_num_bits_in_histogram_bins);

  // Host memory

@ -136,19 +157,21 @@ class CUDAHistogramConstructor {
  /*! \brief CUDA row wise data */
  std::unique_ptr<CUDARowData> cuda_row_data_;
  /*! \brief number of bins per feature */
  uint32_t* cuda_feature_num_bins_;
  CUDAVector<uint32_t> cuda_feature_num_bins_;
  /*! \brief offsets in histogram of all features */
  uint32_t* cuda_feature_hist_offsets_;
  CUDAVector<uint32_t> cuda_feature_hist_offsets_;
  /*! \brief most frequent bins in each feature */
  uint32_t* cuda_feature_most_freq_bins_;
  CUDAVector<uint32_t> cuda_feature_most_freq_bins_;
  /*! \brief CUDA histograms */
  hist_t* cuda_hist_;
  CUDAVector<hist_t> cuda_hist_;
  /*! \brief CUDA histogram buffer for each block */
  float* cuda_hist_buffer_;
  CUDAVector<float> cuda_hist_buffer_;
  /*! \brief indices of features whose histograms need to be fixed */
  int* cuda_need_fix_histogram_features_;
  CUDAVector<int> cuda_need_fix_histogram_features_;
  /*! \brief aligned number of bins of the features whose histograms need to be fixed */
  uint32_t* cuda_need_fix_histogram_features_num_bin_aligned_;
  CUDAVector<uint32_t> cuda_need_fix_histogram_features_num_bin_aligned_;
  /*! \brief histogram buffer used in histogram subtraction when histogram bins use different numbers of bits */
  CUDAVector<hist_t> hist_buffer_for_num_bit_change_;

  // CUDA memory, held by other object

@ -161,6 +184,10 @@ class CUDAHistogramConstructor {
|
|||
const int gpu_device_id_;
|
||||
/*! \brief use double precision histogram per block */
|
||||
const bool gpu_use_dp_;
|
||||
/*! \brief whether to use quantized gradients */
|
||||
const bool use_quantized_grad_;
|
||||
/*! \brief the number of bins to quantized gradients */
|
||||
const int num_grad_quant_bins_;
|
||||
};
|
||||
|
||||
} // namespace LightGBM
|
||||
|
|
|
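The hunks above swap raw device pointers that were paired with AllocateCUDAMemory / DeallocateCUDAMemory calls for the CUDAVector<T> container declared in LightGBM/cuda/cuda_utils.hu, so device buffers are released automatically and exposed through RawData() / RawDataReadOnly(). The sketch below only illustrates that RAII pattern under assumed semantics: the class name DeviceBufferSketch is hypothetical, and unlike the real CUDAVector this Resize does not preserve existing contents.

// --- illustrative sketch, not part of the commit ---
#include <cuda_runtime.h>
#include <cstddef>

template <typename T>
class DeviceBufferSketch {
 public:
  DeviceBufferSketch() = default;
  ~DeviceBufferSketch() { Clear(); }

  // Reallocate to `size` elements; contents are discarded in this simplified sketch.
  void Resize(size_t size) {
    Clear();
    if (size > 0) {
      cudaMalloc(reinterpret_cast<void**>(&data_), size * sizeof(T));
    }
    size_ = size;
  }

  // Byte-wise fill, mirroring the SetValue(0) calls used to zero buffers.
  void SetValue(int value) {
    if (data_ != nullptr) {
      cudaMemset(data_, value, size_ * sizeof(T));
    }
  }

  T* RawData() { return data_; }
  const T* RawDataReadOnly() const { return data_; }
  size_t Size() const { return size_; }

 private:
  void Clear() {
    if (data_ != nullptr) {
      cudaFree(data_);
      data_ = nullptr;
    }
    size_ = 0;
  }

  T* data_ = nullptr;
  size_t size_ = 0;
};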
@ -11,27 +11,22 @@
namespace LightGBM {

CUDALeafSplits::CUDALeafSplits(const data_size_t num_data):
  num_data_(num_data) {
  cuda_struct_ = nullptr;
  cuda_sum_of_gradients_buffer_ = nullptr;
  cuda_sum_of_hessians_buffer_ = nullptr;
}
  num_data_(num_data) {}

CUDALeafSplits::~CUDALeafSplits() {
  DeallocateCUDAMemory<CUDALeafSplitsStruct>(&cuda_struct_, __FILE__, __LINE__);
  DeallocateCUDAMemory<double>(&cuda_sum_of_gradients_buffer_, __FILE__, __LINE__);
  DeallocateCUDAMemory<double>(&cuda_sum_of_hessians_buffer_, __FILE__, __LINE__);
}
CUDALeafSplits::~CUDALeafSplits() {}

void CUDALeafSplits::Init() {
void CUDALeafSplits::Init(const bool use_quantized_grad) {
  num_blocks_init_from_gradients_ = (num_data_ + NUM_THRADS_PER_BLOCK_LEAF_SPLITS - 1) / NUM_THRADS_PER_BLOCK_LEAF_SPLITS;

  // allocate more memory for sum reduction in CUDA
  // only the first element records the final sum
  AllocateCUDAMemory<double>(&cuda_sum_of_gradients_buffer_, num_blocks_init_from_gradients_, __FILE__, __LINE__);
  AllocateCUDAMemory<double>(&cuda_sum_of_hessians_buffer_, num_blocks_init_from_gradients_, __FILE__, __LINE__);
  cuda_sum_of_gradients_buffer_.Resize(static_cast<size_t>(num_blocks_init_from_gradients_));
  cuda_sum_of_hessians_buffer_.Resize(static_cast<size_t>(num_blocks_init_from_gradients_));
  if (use_quantized_grad) {
    cuda_sum_of_gradients_hessians_buffer_.Resize(static_cast<size_t>(num_blocks_init_from_gradients_));
  }

  AllocateCUDAMemory<CUDALeafSplitsStruct>(&cuda_struct_, 1, __FILE__, __LINE__);
  cuda_struct_.Resize(1);
}

void CUDALeafSplits::InitValues() {

@ -46,24 +41,33 @@ void CUDALeafSplits::InitValues(
  const data_size_t num_used_indices, hist_t* cuda_hist_in_leaf, double* root_sum_hessians) {
  cuda_gradients_ = cuda_gradients;
  cuda_hessians_ = cuda_hessians;
  SetCUDAMemory<double>(cuda_sum_of_gradients_buffer_, 0, num_blocks_init_from_gradients_, __FILE__, __LINE__);
  SetCUDAMemory<double>(cuda_sum_of_hessians_buffer_, 0, num_blocks_init_from_gradients_, __FILE__, __LINE__);
  cuda_sum_of_gradients_buffer_.SetValue(0);
  cuda_sum_of_hessians_buffer_.SetValue(0);
  LaunchInitValuesKernal(lambda_l1, lambda_l2, cuda_bagging_data_indices, cuda_data_indices_in_leaf, num_used_indices, cuda_hist_in_leaf);
  CopyFromCUDADeviceToHost<double>(root_sum_hessians, cuda_sum_of_hessians_buffer_, 1, __FILE__, __LINE__);
  CopyFromCUDADeviceToHost<double>(root_sum_hessians, cuda_sum_of_hessians_buffer_.RawData(), 1, __FILE__, __LINE__);
  SynchronizeCUDADevice(__FILE__, __LINE__);
}

void CUDALeafSplits::InitValues(
  const double lambda_l1, const double lambda_l2,
  const int16_t* cuda_gradients_and_hessians,
  const data_size_t* cuda_bagging_data_indices,
  const data_size_t* cuda_data_indices_in_leaf, const data_size_t num_used_indices,
  hist_t* cuda_hist_in_leaf, double* root_sum_hessians,
  const score_t* grad_scale, const score_t* hess_scale) {
  cuda_gradients_ = reinterpret_cast<const score_t*>(cuda_gradients_and_hessians);
  cuda_hessians_ = nullptr;
  LaunchInitValuesKernal(lambda_l1, lambda_l2, cuda_bagging_data_indices, cuda_data_indices_in_leaf, num_used_indices, cuda_hist_in_leaf, grad_scale, hess_scale);
  CopyFromCUDADeviceToHost<double>(root_sum_hessians, cuda_sum_of_hessians_buffer_.RawData(), 1, __FILE__, __LINE__);
  SynchronizeCUDADevice(__FILE__, __LINE__);
}

void CUDALeafSplits::Resize(const data_size_t num_data) {
  if (num_data > num_data_) {
    DeallocateCUDAMemory<double>(&cuda_sum_of_gradients_buffer_, __FILE__, __LINE__);
    DeallocateCUDAMemory<double>(&cuda_sum_of_hessians_buffer_, __FILE__, __LINE__);
    num_blocks_init_from_gradients_ = (num_data + NUM_THRADS_PER_BLOCK_LEAF_SPLITS - 1) / NUM_THRADS_PER_BLOCK_LEAF_SPLITS;
    AllocateCUDAMemory<double>(&cuda_sum_of_gradients_buffer_, num_blocks_init_from_gradients_, __FILE__, __LINE__);
    AllocateCUDAMemory<double>(&cuda_sum_of_hessians_buffer_, num_blocks_init_from_gradients_, __FILE__, __LINE__);
  } else {
    num_blocks_init_from_gradients_ = (num_data + NUM_THRADS_PER_BLOCK_LEAF_SPLITS - 1) / NUM_THRADS_PER_BLOCK_LEAF_SPLITS;
  }
  num_data_ = num_data;
  num_blocks_init_from_gradients_ = (num_data + NUM_THRADS_PER_BLOCK_LEAF_SPLITS - 1) / NUM_THRADS_PER_BLOCK_LEAF_SPLITS;
  cuda_sum_of_gradients_buffer_.Resize(static_cast<size_t>(num_blocks_init_from_gradients_));
  cuda_sum_of_hessians_buffer_.Resize(static_cast<size_t>(num_blocks_init_from_gradients_));
  cuda_sum_of_gradients_hessians_buffer_.Resize(static_cast<size_t>(num_blocks_init_from_gradients_));
}

} // namespace LightGBM
@ -81,6 +81,90 @@ __global__ void CUDAInitValuesKernel2(
  }
}

template <bool USE_INDICES>
__global__ void CUDAInitValuesKernel3(const int16_t* cuda_gradients_and_hessians,
  const data_size_t num_data, const data_size_t* cuda_bagging_data_indices,
  double* cuda_sum_of_gradients, double* cuda_sum_of_hessians, int64_t* cuda_sum_of_hessians_hessians,
  const score_t* grad_scale_pointer, const score_t* hess_scale_pointer) {
  const score_t grad_scale = *grad_scale_pointer;
  const score_t hess_scale = *hess_scale_pointer;
  __shared__ int64_t shared_mem_buffer[32];
  const data_size_t data_index = static_cast<data_size_t>(threadIdx.x + blockIdx.x * blockDim.x);
  int64_t int_gradient = 0;
  int64_t int_hessian = 0;
  if (data_index < num_data) {
    int_gradient = USE_INDICES ? cuda_gradients_and_hessians[2 * cuda_bagging_data_indices[data_index] + 1] :
      cuda_gradients_and_hessians[2 * data_index + 1];
    int_hessian = USE_INDICES ? cuda_gradients_and_hessians[2 * cuda_bagging_data_indices[data_index]] :
      cuda_gradients_and_hessians[2 * data_index];
  }
  const int64_t block_sum_gradient = ShuffleReduceSum<int64_t>(int_gradient, shared_mem_buffer, blockDim.x);
  __syncthreads();
  const int64_t block_sum_hessian = ShuffleReduceSum<int64_t>(int_hessian, shared_mem_buffer, blockDim.x);
  if (threadIdx.x == 0) {
    cuda_sum_of_gradients[blockIdx.x] = block_sum_gradient * grad_scale;
    cuda_sum_of_hessians[blockIdx.x] = block_sum_hessian * hess_scale;
    cuda_sum_of_hessians_hessians[blockIdx.x] = ((block_sum_gradient << 32) | block_sum_hessian);
  }
}
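CUDAInitValuesKernel3 above reads the interleaved int16 gradient/hessian pairs, reduces them per block, and stores the two integer sums packed into a single int64_t (gradient sum in the high 32 bits, hessian sum in the low 32 bits) next to the float sums rescaled by grad_scale and hess_scale. A small host-side sketch of that packing, assuming both quantized sums fit in signed 32 bits; the helper names are hypothetical.

// --- illustrative sketch, not part of the commit ---
#include <cstdint>
#include <cstdio>

// Mirror of `(block_sum_gradient << 32) | block_sum_hessian` in CUDAInitValuesKernel3.
inline int64_t PackGradHess(int64_t sum_gradient, int64_t sum_hessian) {
  return (sum_gradient << 32) | (sum_hessian & 0xffffffffLL);
}

inline int32_t UnpackGradient(int64_t packed) {
  return static_cast<int32_t>(packed >> 32);  // arithmetic shift keeps the sign
}

inline int32_t UnpackHessian(int64_t packed) {
  return static_cast<int32_t>(packed & 0xffffffffLL);  // hessian sums are non-negative
}

int main() {
  const int64_t packed = PackGradHess(-1234, 5678);
  const double grad_scale = 0.01, hess_scale = 0.02;  // illustrative scales only
  std::printf("gradient sum = %f, hessian sum = %f\n",
              UnpackGradient(packed) * grad_scale,
              UnpackHessian(packed) * hess_scale);
  return 0;
}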
__global__ void CUDAInitValuesKernel4(
  const double lambda_l1,
  const double lambda_l2,
  const int num_blocks_to_reduce,
  double* cuda_sum_of_gradients,
  double* cuda_sum_of_hessians,
  int64_t* cuda_sum_of_gradients_hessians,
  const data_size_t num_data,
  const data_size_t* cuda_data_indices_in_leaf,
  hist_t* cuda_hist_in_leaf,
  CUDALeafSplitsStruct* cuda_struct) {
  __shared__ double shared_mem_buffer[32];
  double thread_sum_of_gradients = 0.0f;
  double thread_sum_of_hessians = 0.0f;
  int64_t thread_sum_of_gradients_hessians = 0;
  for (int block_index = static_cast<int>(threadIdx.x); block_index < num_blocks_to_reduce; block_index += static_cast<int>(blockDim.x)) {
    thread_sum_of_gradients += cuda_sum_of_gradients[block_index];
    thread_sum_of_hessians += cuda_sum_of_hessians[block_index];
    thread_sum_of_gradients_hessians += cuda_sum_of_gradients_hessians[block_index];
  }
  const double sum_of_gradients = ShuffleReduceSum<double>(thread_sum_of_gradients, shared_mem_buffer, blockDim.x);
  __syncthreads();
  const double sum_of_hessians = ShuffleReduceSum<double>(thread_sum_of_hessians, shared_mem_buffer, blockDim.x);
  __syncthreads();
  const int64_t sum_of_gradients_hessians = ShuffleReduceSum<int64_t>(
    thread_sum_of_gradients_hessians,
    reinterpret_cast<int64_t*>(shared_mem_buffer),
    blockDim.x);
  if (threadIdx.x == 0) {
    cuda_sum_of_hessians[0] = sum_of_hessians;
    cuda_struct->leaf_index = 0;
    cuda_struct->sum_of_gradients = sum_of_gradients;
    cuda_struct->sum_of_hessians = sum_of_hessians;
    cuda_struct->sum_of_gradients_hessians = sum_of_gradients_hessians;
    cuda_struct->num_data_in_leaf = num_data;
    const bool use_l1 = lambda_l1 > 0.0f;
    if (!use_l1) {
      // no smoothing on root node
      cuda_struct->gain = CUDALeafSplits::GetLeafGain<false, false>(sum_of_gradients, sum_of_hessians, lambda_l1, lambda_l2, 0.0f, 0, 0.0f);
    } else {
      // no smoothing on root node
      cuda_struct->gain = CUDALeafSplits::GetLeafGain<true, false>(sum_of_gradients, sum_of_hessians, lambda_l1, lambda_l2, 0.0f, 0, 0.0f);
    }
    if (!use_l1) {
      // no smoothing on root node
      cuda_struct->leaf_value =
        CUDALeafSplits::CalculateSplittedLeafOutput<false, false>(sum_of_gradients, sum_of_hessians, lambda_l1, lambda_l2, 0.0f, 0, 0.0f);
    } else {
      // no smoothing on root node
      cuda_struct->leaf_value =
        CUDALeafSplits::CalculateSplittedLeafOutput<true, false>(sum_of_gradients, sum_of_hessians, lambda_l1, lambda_l2, 0.0f, 0, 0.0f);
    }
    cuda_struct->data_indices_in_leaf = cuda_data_indices_in_leaf;
    cuda_struct->hist_in_leaf = cuda_hist_in_leaf;
  }
}

__global__ void InitValuesEmptyKernel(CUDALeafSplitsStruct* cuda_struct) {
  cuda_struct->leaf_index = -1;
  cuda_struct->sum_of_gradients = 0.0f;
@ -93,7 +177,7 @@ __global__ void InitValuesEmptyKernel(CUDALeafSplitsStruct* cuda_struct) {
}

void CUDALeafSplits::LaunchInitValuesEmptyKernel() {
  InitValuesEmptyKernel<<<1, 1>>>(cuda_struct_);
  InitValuesEmptyKernel<<<1, 1>>>(cuda_struct_.RawData());
}

void CUDALeafSplits::LaunchInitValuesKernal(

@ -104,23 +188,55 @@ void CUDALeafSplits::LaunchInitValuesKernal(
  hist_t* cuda_hist_in_leaf) {
  if (cuda_bagging_data_indices == nullptr) {
    CUDAInitValuesKernel1<false><<<num_blocks_init_from_gradients_, NUM_THRADS_PER_BLOCK_LEAF_SPLITS>>>(
      cuda_gradients_, cuda_hessians_, num_used_indices, nullptr, cuda_sum_of_gradients_buffer_,
      cuda_sum_of_hessians_buffer_);
      cuda_gradients_, cuda_hessians_, num_used_indices, nullptr, cuda_sum_of_gradients_buffer_.RawData(),
      cuda_sum_of_hessians_buffer_.RawData());
  } else {
    CUDAInitValuesKernel1<true><<<num_blocks_init_from_gradients_, NUM_THRADS_PER_BLOCK_LEAF_SPLITS>>>(
      cuda_gradients_, cuda_hessians_, num_used_indices, cuda_bagging_data_indices, cuda_sum_of_gradients_buffer_,
      cuda_sum_of_hessians_buffer_);
      cuda_gradients_, cuda_hessians_, num_used_indices, cuda_bagging_data_indices, cuda_sum_of_gradients_buffer_.RawData(),
      cuda_sum_of_hessians_buffer_.RawData());
  }
  SynchronizeCUDADevice(__FILE__, __LINE__);
  CUDAInitValuesKernel2<<<1, NUM_THRADS_PER_BLOCK_LEAF_SPLITS>>>(
    lambda_l1, lambda_l2,
    num_blocks_init_from_gradients_,
    cuda_sum_of_gradients_buffer_,
    cuda_sum_of_hessians_buffer_,
    cuda_sum_of_gradients_buffer_.RawData(),
    cuda_sum_of_hessians_buffer_.RawData(),
    num_used_indices,
    cuda_data_indices_in_leaf,
    cuda_hist_in_leaf,
    cuda_struct_);
    cuda_struct_.RawData());
  SynchronizeCUDADevice(__FILE__, __LINE__);
}

void CUDALeafSplits::LaunchInitValuesKernal(
  const double lambda_l1, const double lambda_l2,
  const data_size_t* cuda_bagging_data_indices,
  const data_size_t* cuda_data_indices_in_leaf,
  const data_size_t num_used_indices,
  hist_t* cuda_hist_in_leaf,
  const score_t* grad_scale,
  const score_t* hess_scale) {
  if (cuda_bagging_data_indices == nullptr) {
    CUDAInitValuesKernel3<false><<<num_blocks_init_from_gradients_, NUM_THRADS_PER_BLOCK_LEAF_SPLITS>>>(
      reinterpret_cast<const int16_t*>(cuda_gradients_), num_used_indices, nullptr, cuda_sum_of_gradients_buffer_.RawData(),
      cuda_sum_of_hessians_buffer_.RawData(), cuda_sum_of_gradients_hessians_buffer_.RawData(), grad_scale, hess_scale);
  } else {
    CUDAInitValuesKernel3<true><<<num_blocks_init_from_gradients_, NUM_THRADS_PER_BLOCK_LEAF_SPLITS>>>(
      reinterpret_cast<const int16_t*>(cuda_gradients_), num_used_indices, cuda_bagging_data_indices, cuda_sum_of_gradients_buffer_.RawData(),
      cuda_sum_of_hessians_buffer_.RawData(), cuda_sum_of_gradients_hessians_buffer_.RawData(), grad_scale, hess_scale);
  }

  SynchronizeCUDADevice(__FILE__, __LINE__);
  CUDAInitValuesKernel4<<<1, NUM_THRADS_PER_BLOCK_LEAF_SPLITS>>>(
    lambda_l1, lambda_l2,
    num_blocks_init_from_gradients_,
    cuda_sum_of_gradients_buffer_.RawData(),
    cuda_sum_of_hessians_buffer_.RawData(),
    cuda_sum_of_gradients_hessians_buffer_.RawData(),
    num_used_indices,
    cuda_data_indices_in_leaf,
    cuda_hist_in_leaf,
    cuda_struct_.RawData());
  SynchronizeCUDADevice(__FILE__, __LINE__);
}
@ -8,7 +8,7 @@

#ifdef USE_CUDA

#include <LightGBM/cuda/cuda_utils.h>
#include <LightGBM/cuda/cuda_utils.hu>
#include <LightGBM/bin.h>
#include <LightGBM/utils/log.h>
#include <LightGBM/meta.h>

@ -23,6 +23,7 @@ struct CUDALeafSplitsStruct {
  int leaf_index;
  double sum_of_gradients;
  double sum_of_hessians;
  int64_t sum_of_gradients_hessians;
  data_size_t num_data_in_leaf;
  double gain;
  double leaf_value;

@ -36,7 +37,7 @@ class CUDALeafSplits {

  ~CUDALeafSplits();

  void Init();
  void Init(const bool use_quantized_grad);

  void InitValues(
    const double lambda_l1, const double lambda_l2,

@ -45,11 +46,19 @@ class CUDALeafSplits {
    const data_size_t* cuda_data_indices_in_leaf, const data_size_t num_used_indices,
    hist_t* cuda_hist_in_leaf, double* root_sum_hessians);

  void InitValues(
    const double lambda_l1, const double lambda_l2,
    const int16_t* cuda_gradients_and_hessians,
    const data_size_t* cuda_bagging_data_indices,
    const data_size_t* cuda_data_indices_in_leaf, const data_size_t num_used_indices,
    hist_t* cuda_hist_in_leaf, double* root_sum_hessians,
    const score_t* grad_scale, const score_t* hess_scale);

  void InitValues();

  const CUDALeafSplitsStruct* GetCUDAStruct() const { return cuda_struct_; }
  const CUDALeafSplitsStruct* GetCUDAStruct() const { return cuda_struct_.RawDataReadOnly(); }

  CUDALeafSplitsStruct* GetCUDAStructRef() { return cuda_struct_; }
  CUDALeafSplitsStruct* GetCUDAStructRef() { return cuda_struct_.RawData(); }

  void Resize(const data_size_t num_data);

@ -140,14 +149,24 @@ class CUDALeafSplits {
    const data_size_t num_used_indices,
    hist_t* cuda_hist_in_leaf);

  void LaunchInitValuesKernal(
    const double lambda_l1, const double lambda_l2,
    const data_size_t* cuda_bagging_data_indices,
    const data_size_t* cuda_data_indices_in_leaf,
    const data_size_t num_used_indices,
    hist_t* cuda_hist_in_leaf,
    const score_t* grad_scale,
    const score_t* hess_scale);

  // Host memory
  data_size_t num_data_;
  int num_blocks_init_from_gradients_;

  // CUDA memory, held by this object
  CUDALeafSplitsStruct* cuda_struct_;
  double* cuda_sum_of_gradients_buffer_;
  double* cuda_sum_of_hessians_buffer_;
  CUDAVector<CUDALeafSplitsStruct> cuda_struct_;
  CUDAVector<double> cuda_sum_of_gradients_buffer_;
  CUDAVector<double> cuda_sum_of_hessians_buffer_;
  CUDAVector<int64_t> cuda_sum_of_gradients_hessians_buffer_;

  // CUDA memory, held by other object
  const score_t* cuda_gradients_;
@ -9,7 +9,7 @@
#include "cuda_single_gpu_tree_learner.hpp"

#include <LightGBM/cuda/cuda_tree.hpp>
#include <LightGBM/cuda/cuda_utils.h>
#include <LightGBM/cuda/cuda_utils.hu>
#include <LightGBM/feature_group.h>
#include <LightGBM/network.h>
#include <LightGBM/objective_function.h>

@ -39,13 +39,14 @@ void CUDASingleGPUTreeLearner::Init(const Dataset* train_data, bool is_constant_
  SetCUDADevice(gpu_device_id_, __FILE__, __LINE__);

  cuda_smaller_leaf_splits_.reset(new CUDALeafSplits(num_data_));
  cuda_smaller_leaf_splits_->Init();
  cuda_smaller_leaf_splits_->Init(config_->use_quantized_grad);
  cuda_larger_leaf_splits_.reset(new CUDALeafSplits(num_data_));
  cuda_larger_leaf_splits_->Init();
  cuda_larger_leaf_splits_->Init(config_->use_quantized_grad);

  cuda_histogram_constructor_.reset(new CUDAHistogramConstructor(train_data_, config_->num_leaves, num_threads_,
    share_state_->feature_hist_offsets(),
    config_->min_data_in_leaf, config_->min_sum_hessian_in_leaf, gpu_device_id_, config_->gpu_use_dp));
    config_->min_data_in_leaf, config_->min_sum_hessian_in_leaf, gpu_device_id_, config_->gpu_use_dp,
    config_->use_quantized_grad, config_->num_grad_quant_bins));
  cuda_histogram_constructor_->Init(train_data_, share_state_.get());

  const auto& feature_hist_offsets = share_state_->feature_hist_offsets();

@ -73,11 +74,19 @@ void CUDASingleGPUTreeLearner::Init(const Dataset* train_data, bool is_constant_
  }
  AllocateBitset();

  cuda_leaf_gradient_stat_buffer_ = nullptr;
  cuda_leaf_hessian_stat_buffer_ = nullptr;
  leaf_stat_buffer_size_ = 0;
  num_cat_threshold_ = 0;

  if (config_->use_quantized_grad) {
    cuda_leaf_gradient_stat_buffer_.Resize(config_->num_leaves);
    cuda_leaf_hessian_stat_buffer_.Resize(config_->num_leaves);
    cuda_gradient_discretizer_.reset(new CUDAGradientDiscretizer(
      config_->num_grad_quant_bins, config_->num_iterations, config_->seed, is_constant_hessian, config_->stochastic_rounding));
    cuda_gradient_discretizer_->Init(num_data_, config_->num_leaves, train_data_->num_features(), train_data_);
  } else {
    cuda_gradient_discretizer_.reset(nullptr);
  }

#ifdef DEBUG
  host_gradients_.resize(num_data_, 0.0f);
  host_hessians_.resize(num_data_, 0.0f);

@ -101,6 +110,24 @@ void CUDASingleGPUTreeLearner::BeforeTrain() {
  const data_size_t* leaf_splits_init_indices =
    cuda_data_partition_->use_bagging() ? cuda_data_partition_->cuda_data_indices() : nullptr;
  cuda_data_partition_->BeforeTrain();
  if (config_->use_quantized_grad) {
    cuda_gradient_discretizer_->DiscretizeGradients(num_data_, gradients_, hessians_);
    cuda_histogram_constructor_->BeforeTrain(
      reinterpret_cast<const score_t*>(cuda_gradient_discretizer_->discretized_gradients_and_hessians()), nullptr);
    cuda_smaller_leaf_splits_->InitValues(
      config_->lambda_l1,
      config_->lambda_l2,
      reinterpret_cast<const int16_t*>(cuda_gradient_discretizer_->discretized_gradients_and_hessians()),
      leaf_splits_init_indices,
      cuda_data_partition_->cuda_data_indices(),
      root_num_data,
      cuda_histogram_constructor_->cuda_hist_pointer(),
      &leaf_sum_hessians_[0],
      cuda_gradient_discretizer_->grad_scale_ptr(),
      cuda_gradient_discretizer_->hess_scale_ptr());
    cuda_gradient_discretizer_->SetNumBitsInHistogramBin<false>(0, -1, root_num_data, 0);
  } else {
    cuda_histogram_constructor_->BeforeTrain(gradients_, hessians_);
    cuda_smaller_leaf_splits_->InitValues(
      config_->lambda_l1,
      config_->lambda_l2,

@ -111,9 +138,9 @@ void CUDASingleGPUTreeLearner::BeforeTrain() {
      root_num_data,
      cuda_histogram_constructor_->cuda_hist_pointer(),
      &leaf_sum_hessians_[0]);
  }
  leaf_num_data_[0] = root_num_data;
  cuda_larger_leaf_splits_->InitValues();
  cuda_histogram_constructor_->BeforeTrain(gradients_, hessians_);
  col_sampler_.ResetByTree();
  cuda_best_split_finder_->BeforeTrain(col_sampler_.is_feature_used_bytree());
  leaf_data_start_[0] = 0;
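In the quantized branch of BeforeTrain above, DiscretizeGradients (implemented by CUDAGradientDiscretizer in cuda_gradient_discretizer.hpp/.cu, which is not part of these hunks) converts the float gradients and hessians into small integers before histogram construction, and grad_scale_ptr() / hess_scale_ptr() expose the factors needed to turn integer sums back into real-valued ones. The host-side sketch below only illustrates the stochastic-rounding idea, which keeps the quantized gradient unbiased in expectation; the function name and the max-abs scale choice are assumptions, not the discretizer's exact formula.

// --- illustrative sketch, not part of the commit ---
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <random>
#include <vector>

// Quantize gradients to integer levels with stochastic rounding, so that
// E[quantized * scale] equals the original gradient.
std::vector<int8_t> DiscretizeGradientsSketch(const std::vector<float>& gradients,
                                              int num_grad_quant_bins,
                                              double* grad_scale_out,
                                              std::mt19937* rng) {
  float max_abs = 0.0f;
  for (float g : gradients) max_abs = std::max(max_abs, std::fabs(g));
  // One quantization step covers 2 * max_abs / num_grad_quant_bins of the range.
  const double grad_scale =
      max_abs > 0.0f ? 2.0 * max_abs / num_grad_quant_bins : 1.0;
  *grad_scale_out = grad_scale;
  std::uniform_real_distribution<double> unif(0.0, 1.0);
  std::vector<int8_t> out(gradients.size());
  for (size_t i = 0; i < gradients.size(); ++i) {
    const double scaled = gradients[i] / grad_scale;
    const double low = std::floor(scaled);
    // Round up with probability equal to the fractional part.
    out[i] = static_cast<int8_t>(low + ((scaled - low) > unif(*rng) ? 1.0 : 0.0));
  }
  return out;
}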
@ -141,24 +168,70 @@ Tree* CUDASingleGPUTreeLearner::Train(const score_t* gradients,
    const data_size_t num_data_in_larger_leaf = larger_leaf_index_ < 0 ? 0 : leaf_num_data_[larger_leaf_index_];
    const double sum_hessians_in_smaller_leaf = leaf_sum_hessians_[smaller_leaf_index_];
    const double sum_hessians_in_larger_leaf = larger_leaf_index_ < 0 ? 0 : leaf_sum_hessians_[larger_leaf_index_];
    const uint8_t num_bits_in_histogram_bins = config_->use_quantized_grad ? cuda_gradient_discretizer_->GetHistBitsInLeaf<false>(smaller_leaf_index_) : 0;
    cuda_histogram_constructor_->ConstructHistogramForLeaf(
      cuda_smaller_leaf_splits_->GetCUDAStruct(),
      cuda_larger_leaf_splits_->GetCUDAStruct(),
      num_data_in_smaller_leaf,
      num_data_in_larger_leaf,
      sum_hessians_in_smaller_leaf,
      sum_hessians_in_larger_leaf);
      sum_hessians_in_larger_leaf,
      num_bits_in_histogram_bins);
    global_timer.Stop("CUDASingleGPUTreeLearner::ConstructHistogramForLeaf");
    global_timer.Start("CUDASingleGPUTreeLearner::FindBestSplitsForLeaf");

    uint8_t parent_num_bits_bin = 0;
    uint8_t smaller_num_bits_bin = 0;
    uint8_t larger_num_bits_bin = 0;
    if (config_->use_quantized_grad) {
      if (larger_leaf_index_ != -1) {
        const int parent_leaf_index = std::min(smaller_leaf_index_, larger_leaf_index_);
        parent_num_bits_bin = cuda_gradient_discretizer_->GetHistBitsInNode<false>(parent_leaf_index);
        smaller_num_bits_bin = cuda_gradient_discretizer_->GetHistBitsInLeaf<false>(smaller_leaf_index_);
        larger_num_bits_bin = cuda_gradient_discretizer_->GetHistBitsInLeaf<false>(larger_leaf_index_);
      } else {
        parent_num_bits_bin = cuda_gradient_discretizer_->GetHistBitsInLeaf<false>(0);
        smaller_num_bits_bin = cuda_gradient_discretizer_->GetHistBitsInLeaf<false>(0);
        larger_num_bits_bin = cuda_gradient_discretizer_->GetHistBitsInLeaf<false>(0);
      }
    } else {
      parent_num_bits_bin = 0;
      smaller_num_bits_bin = 0;
      larger_num_bits_bin = 0;
    }
    cuda_histogram_constructor_->SubtractHistogramForLeaf(
      cuda_smaller_leaf_splits_->GetCUDAStruct(),
      cuda_larger_leaf_splits_->GetCUDAStruct(),
      config_->use_quantized_grad,
      parent_num_bits_bin,
      smaller_num_bits_bin,
      larger_num_bits_bin);

    SelectFeatureByNode(tree.get());

    if (config_->use_quantized_grad) {
      const uint8_t smaller_leaf_num_bits_bin = cuda_gradient_discretizer_->GetHistBitsInLeaf<false>(smaller_leaf_index_);
      const uint8_t larger_leaf_num_bits_bin = larger_leaf_index_ < 0 ? 32 : cuda_gradient_discretizer_->GetHistBitsInLeaf<false>(larger_leaf_index_);
      cuda_best_split_finder_->FindBestSplitsForLeaf(
        cuda_smaller_leaf_splits_->GetCUDAStruct(),
        cuda_larger_leaf_splits_->GetCUDAStruct(),
        smaller_leaf_index_, larger_leaf_index_,
        num_data_in_smaller_leaf, num_data_in_larger_leaf,
        sum_hessians_in_smaller_leaf, sum_hessians_in_larger_leaf);
        sum_hessians_in_smaller_leaf, sum_hessians_in_larger_leaf,
        cuda_gradient_discretizer_->grad_scale_ptr(),
        cuda_gradient_discretizer_->hess_scale_ptr(),
        smaller_leaf_num_bits_bin,
        larger_leaf_num_bits_bin);
    } else {
      cuda_best_split_finder_->FindBestSplitsForLeaf(
        cuda_smaller_leaf_splits_->GetCUDAStruct(),
        cuda_larger_leaf_splits_->GetCUDAStruct(),
        smaller_leaf_index_, larger_leaf_index_,
        num_data_in_smaller_leaf, num_data_in_larger_leaf,
        sum_hessians_in_smaller_leaf, sum_hessians_in_larger_leaf,
        nullptr, nullptr, 0, 0);
    }

    global_timer.Stop("CUDASingleGPUTreeLearner::FindBestSplitsForLeaf");
    global_timer.Start("CUDASingleGPUTreeLearner::FindBestFromAllSplits");
    const CUDASplitInfo* best_split_info = nullptr;
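The Train hunk above threads per-leaf histogram bit widths into SubtractHistogramForLeaf because, with quantized gradients, every histogram bin holds a packed integer accumulator, and the parent, smaller, and larger leaves may have been accumulated at different widths (hist_buffer_for_num_bit_change_ exists for exactly that case). The sketch below only illustrates the subtraction idea under an assumed 16-bit-gradient / 16-bit-hessian packing widened to 32/32 bits; the real layout and the width selection in CUDAGradientDiscretizer are not shown in this diff.

// --- illustrative sketch, not part of the commit ---
#include <cstdint>
#include <vector>

// Widen a 32-bit packed bin (signed 16-bit gradient half, unsigned 16-bit hessian half)
// to a 64-bit packed bin (32-bit halves). Assumed layout for illustration only.
inline int64_t WidenPacked32To64(int32_t packed) {
  const int32_t grad = packed >> 16;     // arithmetic shift keeps the gradient sign
  const int32_t hess = packed & 0xffff;  // hessian half is non-negative
  return (static_cast<int64_t>(grad) << 32) | static_cast<int64_t>(hess);
}

// larger = parent - smaller, bin by bin. Packed subtraction stays valid as long as the
// hessian halves cannot borrow, which holds because the parent's counts dominate per bin.
void SubtractHistogramSketch(const std::vector<int64_t>& parent_hist,
                             const std::vector<int32_t>& smaller_hist,
                             std::vector<int64_t>* larger_hist) {
  larger_hist->resize(parent_hist.size());
  for (size_t bin = 0; bin < parent_hist.size(); ++bin) {
    (*larger_hist)[bin] = parent_hist[bin] - WidenPacked32To64(smaller_hist[bin]);
  }
}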
@ -247,9 +320,19 @@ Tree* CUDASingleGPUTreeLearner::Train(const score_t* gradients,
#endif  // DEBUG
    smaller_leaf_index_ = (leaf_num_data_[best_leaf_index_] < leaf_num_data_[right_leaf_index] ? best_leaf_index_ : right_leaf_index);
    larger_leaf_index_ = (smaller_leaf_index_ == best_leaf_index_ ? right_leaf_index : best_leaf_index_);

    if (config_->use_quantized_grad) {
      cuda_gradient_discretizer_->SetNumBitsInHistogramBin<false>(
        best_leaf_index_, right_leaf_index, leaf_num_data_[best_leaf_index_], leaf_num_data_[right_leaf_index]);
    }
    global_timer.Stop("CUDASingleGPUTreeLearner::Split");
  }
  SynchronizeCUDADevice(__FILE__, __LINE__);
  if (config_->use_quantized_grad && config_->quant_train_renew_leaf) {
    global_timer.Start("CUDASingleGPUTreeLearner::RenewDiscretizedTreeLeaves");
    RenewDiscretizedTreeLeaves(tree.get());
    global_timer.Stop("CUDASingleGPUTreeLearner::RenewDiscretizedTreeLeaves");
  }
  tree->ToHost();
  return tree.release();
}

@ -357,8 +440,8 @@ void CUDASingleGPUTreeLearner::RenewTreeOutput(Tree* tree, const ObjectiveFuncti

Tree* CUDASingleGPUTreeLearner::FitByExistingTree(const Tree* old_tree, const score_t* gradients, const score_t* hessians) const {
  std::unique_ptr<CUDATree> cuda_tree(new CUDATree(old_tree));
  SetCUDAMemory<double>(cuda_leaf_gradient_stat_buffer_, 0, static_cast<size_t>(old_tree->num_leaves()), __FILE__, __LINE__);
  SetCUDAMemory<double>(cuda_leaf_hessian_stat_buffer_, 0, static_cast<size_t>(old_tree->num_leaves()), __FILE__, __LINE__);
  cuda_leaf_gradient_stat_buffer_.SetValue(0);
  cuda_leaf_hessian_stat_buffer_.SetValue(0);
  ReduceLeafStat(cuda_tree.get(), gradients, hessians, cuda_data_partition_->cuda_data_indices());
  cuda_tree->SyncLeafOutputFromCUDAToHost();
  return cuda_tree.release();

@ -373,13 +456,9 @@ Tree* CUDASingleGPUTreeLearner::FitByExistingTree(const Tree* old_tree, const st
    const int num_block = (refit_num_data_ + CUDA_SINGLE_GPU_TREE_LEARNER_BLOCK_SIZE - 1) / CUDA_SINGLE_GPU_TREE_LEARNER_BLOCK_SIZE;
    buffer_size *= static_cast<data_size_t>(num_block + 1);
  }
  if (buffer_size != leaf_stat_buffer_size_) {
    if (leaf_stat_buffer_size_ != 0) {
      DeallocateCUDAMemory<double>(&cuda_leaf_gradient_stat_buffer_, __FILE__, __LINE__);
      DeallocateCUDAMemory<double>(&cuda_leaf_hessian_stat_buffer_, __FILE__, __LINE__);
    }
    AllocateCUDAMemory<double>(&cuda_leaf_gradient_stat_buffer_, static_cast<size_t>(buffer_size), __FILE__, __LINE__);
    AllocateCUDAMemory<double>(&cuda_leaf_hessian_stat_buffer_, static_cast<size_t>(buffer_size), __FILE__, __LINE__);
  if (static_cast<size_t>(buffer_size) > cuda_leaf_gradient_stat_buffer_.Size()) {
    cuda_leaf_gradient_stat_buffer_.Resize(buffer_size);
    cuda_leaf_hessian_stat_buffer_.Resize(buffer_size);
  }
  return FitByExistingTree(old_tree, gradients, hessians);
}

@ -513,6 +592,15 @@ void CUDASingleGPUTreeLearner::CheckSplitValid(
}
#endif  // DEBUG

void CUDASingleGPUTreeLearner::RenewDiscretizedTreeLeaves(CUDATree* cuda_tree) {
  cuda_data_partition_->ReduceLeafGradStat(
    gradients_, hessians_, cuda_tree,
    cuda_leaf_gradient_stat_buffer_.RawData(),
    cuda_leaf_hessian_stat_buffer_.RawData());
  LaunchCalcLeafValuesGivenGradStat(cuda_tree, cuda_data_partition_->cuda_data_indices());
  SynchronizeCUDADevice(__FILE__, __LINE__);
}

} // namespace LightGBM

#endif  // USE_CUDA
@ -129,18 +129,18 @@ void CUDASingleGPUTreeLearner::LaunchReduceLeafStatKernel(
  if (num_leaves <= 2048) {
    ReduceLeafStatKernel_SharedMemory<<<num_block, CUDA_SINGLE_GPU_TREE_LEARNER_BLOCK_SIZE, 2 * num_leaves * sizeof(double)>>>(
      gradients, hessians, num_leaves, num_data, cuda_data_partition_->cuda_data_index_to_leaf_index(),
      cuda_leaf_gradient_stat_buffer_, cuda_leaf_hessian_stat_buffer_);
      cuda_leaf_gradient_stat_buffer_.RawData(), cuda_leaf_hessian_stat_buffer_.RawData());
  } else {
    ReduceLeafStatKernel_GlobalMemory<<<num_block, CUDA_SINGLE_GPU_TREE_LEARNER_BLOCK_SIZE>>>(
      gradients, hessians, num_leaves, num_data, cuda_data_partition_->cuda_data_index_to_leaf_index(),
      cuda_leaf_gradient_stat_buffer_, cuda_leaf_hessian_stat_buffer_);
      cuda_leaf_gradient_stat_buffer_.RawData(), cuda_leaf_hessian_stat_buffer_.RawData());
  }
  const bool use_l1 = config_->lambda_l1 > 0.0f;
  const bool use_smoothing = config_->path_smooth > 0.0f;
  num_block = (num_leaves + CUDA_SINGLE_GPU_TREE_LEARNER_BLOCK_SIZE - 1) / CUDA_SINGLE_GPU_TREE_LEARNER_BLOCK_SIZE;

#define CalcRefitLeafOutputKernel_ARGS \
  num_leaves, cuda_leaf_gradient_stat_buffer_, cuda_leaf_hessian_stat_buffer_, num_data_in_leaf, \
  num_leaves, cuda_leaf_gradient_stat_buffer_.RawData(), cuda_leaf_hessian_stat_buffer_.RawData(), num_data_in_leaf, \
  leaf_parent, left_child, right_child, \
  config_->lambda_l1, config_->lambda_l2, config_->path_smooth, \
  shrinkage_rate, config_->refit_decay_rate, cuda_leaf_value

@ -162,6 +162,7 @@ void CUDASingleGPUTreeLearner::LaunchReduceLeafStatKernel(
        <<<num_block, CUDA_SINGLE_GPU_TREE_LEARNER_BLOCK_SIZE>>>(CalcRefitLeafOutputKernel_ARGS);
    }
  }
#undef CalcRefitLeafOutputKernel_ARGS
}

template <typename T, bool IS_INNER>

@ -256,6 +257,37 @@ void CUDASingleGPUTreeLearner::LaunchConstructBitsetForCategoricalSplitKernel(
  CUDAConstructBitset<int, false>(best_split_info, num_cat_threshold_, cuda_bitset_, cuda_bitset_len_);
}

void CUDASingleGPUTreeLearner::LaunchCalcLeafValuesGivenGradStat(
  CUDATree* cuda_tree, const data_size_t* num_data_in_leaf) {
#define CalcRefitLeafOutputKernel_ARGS \
  cuda_tree->num_leaves(), cuda_leaf_gradient_stat_buffer_.RawData(), cuda_leaf_hessian_stat_buffer_.RawData(), num_data_in_leaf, \
  cuda_tree->cuda_leaf_parent(), cuda_tree->cuda_left_child(), cuda_tree->cuda_right_child(), \
  config_->lambda_l1, config_->lambda_l2, config_->path_smooth, \
  1.0f, config_->refit_decay_rate, cuda_tree->cuda_leaf_value_ref()
  const bool use_l1 = config_->lambda_l1 > 0.0f;
  const bool use_smoothing = config_->path_smooth > 0.0f;
  const int num_block = (cuda_tree->num_leaves() + CUDA_SINGLE_GPU_TREE_LEARNER_BLOCK_SIZE - 1) / CUDA_SINGLE_GPU_TREE_LEARNER_BLOCK_SIZE;
  if (!use_l1) {
    if (!use_smoothing) {
      CalcRefitLeafOutputKernel<false, false>
        <<<num_block, CUDA_SINGLE_GPU_TREE_LEARNER_BLOCK_SIZE>>>(CalcRefitLeafOutputKernel_ARGS);
    } else {
      CalcRefitLeafOutputKernel<false, true>
        <<<num_block, CUDA_SINGLE_GPU_TREE_LEARNER_BLOCK_SIZE>>>(CalcRefitLeafOutputKernel_ARGS);
    }
  } else {
    if (!use_smoothing) {
      CalcRefitLeafOutputKernel<true, false>
        <<<num_block, CUDA_SINGLE_GPU_TREE_LEARNER_BLOCK_SIZE>>>(CalcRefitLeafOutputKernel_ARGS);
    } else {
      CalcRefitLeafOutputKernel<true, true>
        <<<num_block, CUDA_SINGLE_GPU_TREE_LEARNER_BLOCK_SIZE>>>(CalcRefitLeafOutputKernel_ARGS);
    }
  }

#undef CalcRefitLeafOutputKernel_ARGS
}

} // namespace LightGBM

#endif  // USE_CUDA
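LaunchCalcLeafValuesGivenGradStat above dispatches CalcRefitLeafOutputKernel over the four USE_L1 / USE_SMOOTHING template combinations with a shrinkage of 1.0f; this is how RenewDiscretizedTreeLeaves refreshes leaf values from the true gradient and hessian sums after a tree was grown on quantized gradients. For the plain case (no L1, no path smoothing) the per-leaf computation reduces to the usual second-order output blended with the old leaf value through refit_decay_rate; the blending form in this sketch is an assumption based on LightGBM's refit behaviour, not a quote of the kernel.

// --- illustrative sketch, not part of the commit ---
// Refit a single leaf: new_output = -G / (H + lambda_l2) * shrinkage,
// then blend with the previous output via refit_decay_rate.
double RefitLeafOutputSketch(double sum_gradients, double sum_hessians,
                             double lambda_l2, double shrinkage_rate,
                             double refit_decay_rate, double old_leaf_value) {
  const double new_output =
      -sum_gradients / (sum_hessians + lambda_l2) * shrinkage_rate;
  return refit_decay_rate * old_leaf_value + (1.0 - refit_decay_rate) * new_output;
}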
@ -16,6 +16,7 @@
#include "cuda_data_partition.hpp"
#include "cuda_best_split_finder.hpp"

#include "cuda_gradient_discretizer.hpp"
#include "../serial_tree_learner.h"

namespace LightGBM {

@ -74,6 +75,10 @@ class CUDASingleGPUTreeLearner: public SerialTreeLearner {
    const double sum_left_gradients, const double sum_right_gradients);
#endif  // DEBUG

  void RenewDiscretizedTreeLeaves(CUDATree* cuda_tree);

  void LaunchCalcLeafValuesGivenGradStat(CUDATree* cuda_tree, const data_size_t* num_data_in_leaf);

  // GPU device ID
  int gpu_device_id_;
  // number of threads on CPU

@ -90,6 +95,8 @@ class CUDASingleGPUTreeLearner: public SerialTreeLearner {
  std::unique_ptr<CUDAHistogramConstructor> cuda_histogram_constructor_;
  // for best split information finding, given the histograms
  std::unique_ptr<CUDABestSplitFinder> cuda_best_split_finder_;
  // gradient discretizer for quantized training
  std::unique_ptr<CUDAGradientDiscretizer> cuda_gradient_discretizer_;

  std::vector<int> leaf_best_split_feature_;
  std::vector<uint32_t> leaf_best_split_threshold_;

@ -108,8 +115,8 @@ class CUDASingleGPUTreeLearner: public SerialTreeLearner {
  std::vector<int> categorical_bin_to_value_;
  std::vector<int> categorical_bin_offsets_;

  mutable double* cuda_leaf_gradient_stat_buffer_;
  mutable double* cuda_leaf_hessian_stat_buffer_;
  mutable CUDAVector<double> cuda_leaf_gradient_stat_buffer_;
  mutable CUDAVector<double> cuda_leaf_hessian_stat_buffer_;
  mutable data_size_t leaf_stat_buffer_size_;
  mutable data_size_t refit_num_data_;
  uint32_t* cuda_bitset_;