[CUDA] CUDA Quantized Training (fixes #5606) (#5933)

* add quantized training (first stage)

* add histogram construction functions for integer gradients

* add stochastic rounding

* update docs

* fix compilation errors by adding template instantiations

* update files for compilation

* fix compilation of gpu version

* initialize gradient discretizer before share states

* add a test case for quantized training

* add quantized training for data distributed training

* Delete origin.pred

* Delete ifelse.pred

* Delete LightGBM_model.txt

* remove useless changes

* fix lint error

* remove debug loggings

* fix mismatch of vector and allocator types

* remove changes in main.cpp

* fix bugs with uninitialized gradient discretizer

* initialize ordered gradients in gradient discretizer

* disable quantized training with gpu and cuda

fix msvc compilation errors and warnings

* fix bug in data parallel tree learner

* make quantized training test deterministic

* make quantized training in test case more accurate

* refactor test_quantized_training

* fix leaf splits initialization with quantized training

* check distributed quantized training result

* add cuda gradient discretizer

* add quantized training for CUDA version in tree learner

* remove cuda compute capability 6.1 and 6.2

* fix parts of gpu quantized training errors and warnings

* fix build-python.sh to install locally built version

* fix memory access bugs

* fix lint errors

* mark cuda quantized training with categorical features as unsupported

* rename cuda_utils.h to cuda_utils.hu

* enable quantized training with cuda

* fix cuda quantized training with sparse row data

* allow using global memory buffer in histogram construction with cuda quantized training

* recover build-python.sh

enlarge allowed package size to 100M
This commit is contained in:
shiyu1994 2023-10-08 23:25:46 +08:00 committed by GitHub
Parent 3d9ada7657
Commit f901f47141
No key found matching this signature
GPG key ID: 4AEE18F83AFDEB23
33 changed files: 1912 additions and 259 deletions
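The commits above implement gradient quantization: each float gradient/hessian pair is scaled and rounded to a small signed integer so that histograms can be accumulated with integer arithmetic. A minimal host-side sketch of that idea, using hypothetical names (QuantizeGradients, num_grad_quant_bins) and plain round-half-away-from-zero instead of the stochastic rounding used in the CUDA kernels below:

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <vector>

// Illustrative sketch only: scale gradients by (num_grad_quant_bins / 2) / abs_max
// and round to int16, mirroring the scale computed in ReduceBlockMinMaxKernel below.
std::vector<int16_t> QuantizeGradients(const std::vector<float>& gradients,
                                       int num_grad_quant_bins) {
  float abs_max = 0.0f;
  for (float g : gradients) abs_max = std::max(abs_max, std::fabs(g));
  if (abs_max == 0.0f) return std::vector<int16_t>(gradients.size(), 0);
  const float grad_scale = (num_grad_quant_bins / 2) / abs_max;
  std::vector<int16_t> quantized(gradients.size());
  for (std::size_t i = 0; i < gradients.size(); ++i) {
    const float scaled = gradients[i] * grad_scale;
    // deterministic rounding; the CUDA path can replace the 0.5 with a
    // uniform random value for stochastic rounding
    quantized[i] = static_cast<int16_t>(scaled >= 0.0f ? scaled + 0.5f : scaled - 0.5f);
  }
  return quantized;
}

Hessians are handled the same way but with the full bin count rather than half of it, since they are non-negative.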

@ -25,7 +25,7 @@ if [ $PY_MINOR_VER -gt 7 ]; then
pydistcheck \
--inspect \
--ignore 'compiled-objects-have-debug-symbols,distro-too-large-compressed' \
--max-allowed-size-uncompressed '70M' \
--max-allowed-size-uncompressed '100M' \
--max-allowed-files 800 \
${DIST_DIR}/* || exit -1
elif { test $(uname -m) = "aarch64"; }; then

@ -13,7 +13,7 @@
#include <stdio.h>
#include <LightGBM/bin.h>
#include <LightGBM/cuda/cuda_utils.h>
#include <LightGBM/cuda/cuda_utils.hu>
#include <LightGBM/utils/log.h>
#include <algorithm>

@ -9,7 +9,7 @@
#define LIGHTGBM_CUDA_CUDA_COLUMN_DATA_HPP_
#include <LightGBM/config.h>
#include <LightGBM/cuda/cuda_utils.h>
#include <LightGBM/cuda/cuda_utils.hu>
#include <LightGBM/bin.h>
#include <LightGBM/utils/openmp_wrapper.h>

@ -8,7 +8,7 @@
#ifndef LIGHTGBM_CUDA_CUDA_METADATA_HPP_
#define LIGHTGBM_CUDA_CUDA_METADATA_HPP_
#include <LightGBM/cuda/cuda_utils.h>
#include <LightGBM/cuda/cuda_utils.hu>
#include <LightGBM/meta.h>
#include <vector>

@ -9,7 +9,7 @@
#ifdef USE_CUDA
#include <LightGBM/cuda/cuda_utils.h>
#include <LightGBM/cuda/cuda_utils.hu>
#include <LightGBM/metric.h>
namespace LightGBM {

@ -9,7 +9,7 @@
#ifdef USE_CUDA
#include <LightGBM/cuda/cuda_utils.h>
#include <LightGBM/cuda/cuda_utils.hu>
#include <LightGBM/objective_function.h>
#include <LightGBM/meta.h>

@ -10,7 +10,7 @@
#include <LightGBM/bin.h>
#include <LightGBM/config.h>
#include <LightGBM/cuda/cuda_utils.h>
#include <LightGBM/cuda/cuda_utils.hu>
#include <LightGBM/dataset.h>
#include <LightGBM/train_share_states.h>
#include <LightGBM/utils/openmp_wrapper.h>

@ -24,12 +24,14 @@ class CUDASplitInfo {
double left_sum_gradients;
double left_sum_hessians;
int64_t left_sum_of_gradients_hessians;
data_size_t left_count;
double left_gain;
double left_value;
double right_sum_gradients;
double right_sum_hessians;
int64_t right_sum_of_gradients_hessians;
data_size_t right_count;
double right_gain;
double right_value;

@ -7,15 +7,21 @@
#define LIGHTGBM_CUDA_CUDA_UTILS_H_
#ifdef USE_CUDA
#include <cuda.h>
#include <cuda_runtime.h>
#include <stdio.h>
#include <LightGBM/utils/log.h>
#include <algorithm>
#include <vector>
#include <cmath>
namespace LightGBM {
typedef unsigned long long atomic_add_long_t;
#define CUDASUCCESS_OR_FATAL(ans) { gpuAssert((ans), __FILE__, __LINE__); }
inline void gpuAssert(cudaError_t code, const char *file, int line, bool abort = true) {
if (code != cudaSuccess) {
@ -125,13 +131,19 @@ class CUDAVector {
T* new_data = nullptr;
AllocateCUDAMemory<T>(&new_data, size, __FILE__, __LINE__);
if (size_ > 0 && data_ != nullptr) {
CopyFromCUDADeviceToCUDADevice<T>(new_data, data_, size, __FILE__, __LINE__);
const size_t size_for_old_content = std::min<size_t>(size_, size);
CopyFromCUDADeviceToCUDADevice<T>(new_data, data_, size_for_old_content, __FILE__, __LINE__);
}
DeallocateCUDAMemory<T>(&data_, __FILE__, __LINE__);
data_ = new_data;
size_ = size;
}
void InitFromHostVector(const std::vector<T>& host_vector) {
Resize(host_vector.size());
CopyFromHostToCUDADevice(data_, host_vector.data(), host_vector.size(), __FILE__, __LINE__);
}
void Clear() {
if (size_ > 0 && data_ != nullptr) {
DeallocateCUDAMemory<T>(&data_, __FILE__, __LINE__);
@ -171,6 +183,10 @@ class CUDAVector {
return data_;
}
void SetValue(int value) {
SetCUDAMemory<T>(data_, value, size_, __FILE__, __LINE__);
}
const T* RawDataReadOnly() const {
return data_;
}
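The CUDAVector additions above (content-preserving Resize, InitFromHostVector, SetValue) are what the tree learner code further down migrates to, replacing raw AllocateCUDAMemory / InitCUDAMemoryFromHostMemory / DeallocateCUDAMemory call sites. A hedged usage sketch, with the sizes invented for illustration:

// Illustrative only; mirrors the migration in cuda_histogram_constructor.cpp below.
std::vector<uint32_t> feature_hist_offsets = {0, 16, 48, 64};
CUDAVector<uint32_t> cuda_feature_hist_offsets;
cuda_feature_hist_offsets.InitFromHostVector(feature_hist_offsets);  // resize + host-to-device copy

const int num_total_bin = 64;
const int num_leaves = 31;
CUDAVector<hist_t> cuda_hist;
cuda_hist.Resize(static_cast<size_t>(num_total_bin) * 2 * num_leaves);
cuda_hist.SetValue(0);  // zero the histogram storage before training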

@ -6,7 +6,7 @@
#ifndef LIGHTGBM_SAMPLE_STRATEGY_H_
#define LIGHTGBM_SAMPLE_STRATEGY_H_
#include <LightGBM/cuda/cuda_utils.h>
#include <LightGBM/cuda/cuda_utils.hu>
#include <LightGBM/utils/random.h>
#include <LightGBM/utils/common.h>
#include <LightGBM/utils/threading.h>

@ -8,7 +8,7 @@
#ifdef USE_CUDA
#include <LightGBM/cuda/cuda_utils.h>
#include <LightGBM/cuda/cuda_utils.hu>
#include "../score_updater.hpp"

@ -5,7 +5,7 @@
#ifdef USE_CUDA
#include <LightGBM/cuda/cuda_utils.h>
#include <LightGBM/cuda/cuda_utils.hu>
namespace LightGBM {

@ -389,10 +389,6 @@ void Config::CheckParamConflict() {
if (deterministic) {
Log::Warning("Although \"deterministic\" is set, the results ran by GPU may be non-deterministic.");
}
if (use_quantized_grad) {
Log::Warning("Quantized training is not supported by CUDA tree learner. Switch to full precision training.");
use_quantized_grad = false;
}
}
// linear tree learner must be serial type and run on CPU device
if (linear_tree) {

@ -10,7 +10,7 @@
#ifdef USE_CUDA
#include <LightGBM/cuda/cuda_metric.hpp>
#include <LightGBM/cuda/cuda_utils.h>
#include <LightGBM/cuda/cuda_utils.hu>
#include <vector>

@ -10,7 +10,7 @@
#ifdef USE_CUDA
#include <LightGBM/cuda/cuda_metric.hpp>
#include <LightGBM/cuda/cuda_utils.h>
#include <LightGBM/cuda/cuda_utils.hu>
#include <vector>

@ -10,7 +10,7 @@
#ifdef USE_CUDA
#include <LightGBM/cuda/cuda_metric.hpp>
#include <LightGBM/cuda/cuda_utils.h>
#include <LightGBM/cuda/cuda_utils.hu>
#include <vector>

@ -40,6 +40,9 @@ CUDABestSplitFinder::CUDABestSplitFinder(
select_features_by_node_(select_features_by_node),
cuda_hist_(cuda_hist) {
InitFeatureMetaInfo(train_data);
if (has_categorical_feature_ && config->use_quantized_grad) {
Log::Fatal("Quantized training on GPU with categorical features is not supported yet.");
}
cuda_leaf_best_split_info_ = nullptr;
cuda_best_split_info_ = nullptr;
cuda_best_split_info_buffer_ = nullptr;
@ -326,13 +329,23 @@ void CUDABestSplitFinder::FindBestSplitsForLeaf(
const data_size_t num_data_in_smaller_leaf,
const data_size_t num_data_in_larger_leaf,
const double sum_hessians_in_smaller_leaf,
const double sum_hessians_in_larger_leaf) {
const double sum_hessians_in_larger_leaf,
const score_t* grad_scale,
const score_t* hess_scale,
const uint8_t smaller_num_bits_in_histogram_bins,
const uint8_t larger_num_bits_in_histogram_bins) {
const bool is_smaller_leaf_valid = (num_data_in_smaller_leaf > min_data_in_leaf_ &&
sum_hessians_in_smaller_leaf > min_sum_hessian_in_leaf_);
const bool is_larger_leaf_valid = (num_data_in_larger_leaf > min_data_in_leaf_ &&
sum_hessians_in_larger_leaf > min_sum_hessian_in_leaf_ && larger_leaf_index >= 0);
LaunchFindBestSplitsForLeafKernel(smaller_leaf_splits, larger_leaf_splits,
smaller_leaf_index, larger_leaf_index, is_smaller_leaf_valid, is_larger_leaf_valid);
if (grad_scale != nullptr && hess_scale != nullptr) {
LaunchFindBestSplitsDiscretizedForLeafKernel(smaller_leaf_splits, larger_leaf_splits,
smaller_leaf_index, larger_leaf_index, is_smaller_leaf_valid, is_larger_leaf_valid,
grad_scale, hess_scale, smaller_num_bits_in_histogram_bins, larger_num_bits_in_histogram_bins);
} else {
LaunchFindBestSplitsForLeafKernel(smaller_leaf_splits, larger_leaf_splits,
smaller_leaf_index, larger_leaf_index, is_smaller_leaf_valid, is_larger_leaf_valid);
}
global_timer.Start("CUDABestSplitFinder::LaunchSyncBestSplitForLeafKernel");
LaunchSyncBestSplitForLeafKernel(smaller_leaf_index, larger_leaf_index, is_smaller_leaf_valid, is_larger_leaf_valid);
SynchronizeCUDADevice(__FILE__, __LINE__);

@ -320,6 +320,175 @@ __device__ void FindBestSplitsForLeafKernelInner(
}
}
template <bool USE_RAND, bool USE_L1, bool USE_SMOOTHING, bool REVERSE, typename BIN_HIST_TYPE, typename ACC_HIST_TYPE, bool USE_16BIT_BIN_HIST, bool USE_16BIT_ACC_HIST>
__device__ void FindBestSplitsDiscretizedForLeafKernelInner(
// input feature information
const BIN_HIST_TYPE* feature_hist_ptr,
// input task information
const SplitFindTask* task,
CUDARandom* cuda_random,
// input config parameter values
const double lambda_l1,
const double lambda_l2,
const double path_smooth,
const data_size_t min_data_in_leaf,
const double min_sum_hessian_in_leaf,
const double min_gain_to_split,
// input parent node information
const double parent_gain,
const int64_t sum_gradients_hessians,
const data_size_t num_data,
const double parent_output,
// gradient scale
const double grad_scale,
const double hess_scale,
// output parameters
CUDASplitInfo* cuda_best_split_info) {
const double sum_hessians = static_cast<double>(sum_gradients_hessians & 0x00000000ffffffff) * hess_scale;
const double cnt_factor = num_data / sum_hessians;
const double min_gain_shift = parent_gain + min_gain_to_split;
cuda_best_split_info->is_valid = false;
ACC_HIST_TYPE local_grad_hess_hist = 0;
double local_gain = 0.0f;
bool threshold_found = false;
uint32_t threshold_value = 0;
__shared__ int rand_threshold;
if (USE_RAND && threadIdx.x == 0) {
if (task->num_bin - 2 > 0) {
rand_threshold = cuda_random->NextInt(0, task->num_bin - 2);
}
}
__shared__ uint32_t best_thread_index;
__shared__ double shared_double_buffer[32];
__shared__ bool shared_bool_buffer[32];
__shared__ uint32_t shared_int_buffer[64];
const unsigned int threadIdx_x = threadIdx.x;
const bool skip_sum = REVERSE ?
(task->skip_default_bin && (task->num_bin - 1 - threadIdx_x) == static_cast<int>(task->default_bin)) :
(task->skip_default_bin && (threadIdx_x + task->mfb_offset) == static_cast<int>(task->default_bin));
const uint32_t feature_num_bin_minus_offset = task->num_bin - task->mfb_offset;
if (!REVERSE) {
if (threadIdx_x < feature_num_bin_minus_offset && !skip_sum) {
const unsigned int bin_offset = threadIdx_x;
if (USE_16BIT_BIN_HIST && !USE_16BIT_ACC_HIST) {
const int32_t local_grad_hess_hist_int32 = feature_hist_ptr[bin_offset];
local_grad_hess_hist = (static_cast<int64_t>(static_cast<int16_t>(local_grad_hess_hist_int32 >> 16)) << 32) | (static_cast<int64_t>(local_grad_hess_hist_int32 & 0x0000ffff));
} else {
local_grad_hess_hist = feature_hist_ptr[bin_offset];
}
}
} else {
if (threadIdx_x >= static_cast<unsigned int>(task->na_as_missing) &&
threadIdx_x < feature_num_bin_minus_offset && !skip_sum) {
const unsigned int read_index = feature_num_bin_minus_offset - 1 - threadIdx_x;
if (USE_16BIT_BIN_HIST && !USE_16BIT_ACC_HIST) {
const int32_t local_grad_hess_hist_int32 = feature_hist_ptr[read_index];
local_grad_hess_hist = (static_cast<int64_t>(static_cast<int16_t>(local_grad_hess_hist_int32 >> 16)) << 32) | (static_cast<int64_t>(local_grad_hess_hist_int32 & 0x0000ffff));
} else {
local_grad_hess_hist = feature_hist_ptr[read_index];
}
}
}
__syncthreads();
local_gain = kMinScore;
local_grad_hess_hist = ShufflePrefixSum<ACC_HIST_TYPE>(local_grad_hess_hist, reinterpret_cast<ACC_HIST_TYPE*>(shared_int_buffer));
double sum_left_gradient = 0.0f;
double sum_left_hessian = 0.0f;
double sum_right_gradient = 0.0f;
double sum_right_hessian = 0.0f;
data_size_t left_count = 0;
data_size_t right_count = 0;
int64_t sum_left_gradient_hessian = 0;
int64_t sum_right_gradient_hessian = 0;
if (REVERSE) {
if (threadIdx_x >= static_cast<unsigned int>(task->na_as_missing) && threadIdx_x <= task->num_bin - 2 && !skip_sum) {
sum_right_gradient_hessian = USE_16BIT_ACC_HIST ?
(static_cast<int64_t>(static_cast<int16_t>(local_grad_hess_hist >> 16)) << 32) | static_cast<int64_t>(local_grad_hess_hist & 0x0000ffff) :
local_grad_hess_hist;
sum_right_gradient = static_cast<double>(static_cast<int32_t>((sum_right_gradient_hessian & 0xffffffff00000000) >> 32)) * grad_scale;
sum_right_hessian = static_cast<double>(static_cast<int32_t>(sum_right_gradient_hessian & 0x00000000ffffffff)) * hess_scale;
right_count = static_cast<data_size_t>(__double2int_rn(sum_right_hessian * cnt_factor));
sum_left_gradient_hessian = sum_gradients_hessians - sum_right_gradient_hessian;
sum_left_gradient = static_cast<double>(static_cast<int32_t>((sum_left_gradient_hessian & 0xffffffff00000000)>> 32)) * grad_scale;
sum_left_hessian = static_cast<double>(static_cast<int32_t>(sum_left_gradient_hessian & 0x00000000ffffffff)) * hess_scale;
left_count = num_data - right_count;
if (sum_left_hessian >= min_sum_hessian_in_leaf && left_count >= min_data_in_leaf &&
sum_right_hessian >= min_sum_hessian_in_leaf && right_count >= min_data_in_leaf &&
(!USE_RAND || static_cast<int>(task->num_bin - 2 - threadIdx_x) == rand_threshold)) {
double current_gain = CUDALeafSplits::GetSplitGains<USE_L1, USE_SMOOTHING>(
sum_left_gradient, sum_left_hessian + kEpsilon, sum_right_gradient,
sum_right_hessian + kEpsilon, lambda_l1,
lambda_l2, path_smooth, left_count, right_count, parent_output);
// gain with split is worse than without split
if (current_gain > min_gain_shift) {
local_gain = current_gain - min_gain_shift;
threshold_value = static_cast<uint32_t>(task->num_bin - 2 - threadIdx_x);
threshold_found = true;
}
}
}
} else {
if (threadIdx_x <= feature_num_bin_minus_offset - 2 && !skip_sum) {
sum_left_gradient_hessian = USE_16BIT_ACC_HIST ?
(static_cast<int64_t>(static_cast<int16_t>(local_grad_hess_hist >> 16)) << 32) | static_cast<int64_t>(local_grad_hess_hist & 0x0000ffff) :
local_grad_hess_hist;
sum_left_gradient = static_cast<double>(static_cast<int32_t>((sum_left_gradient_hessian & 0xffffffff00000000) >> 32)) * grad_scale;
sum_left_hessian = static_cast<double>(static_cast<int32_t>(sum_left_gradient_hessian & 0x00000000ffffffff)) * hess_scale;
left_count = static_cast<data_size_t>(__double2int_rn(sum_left_hessian * cnt_factor));
sum_right_gradient_hessian = sum_gradients_hessians - sum_left_gradient_hessian;
sum_right_gradient = static_cast<double>(static_cast<int32_t>((sum_right_gradient_hessian & 0xffffffff00000000) >> 32)) * grad_scale;
sum_right_hessian = static_cast<double>(static_cast<int32_t>(sum_right_gradient_hessian & 0x00000000ffffffff)) * hess_scale;
right_count = num_data - left_count;
if (sum_left_hessian >= min_sum_hessian_in_leaf && left_count >= min_data_in_leaf &&
sum_right_hessian >= min_sum_hessian_in_leaf && right_count >= min_data_in_leaf &&
(!USE_RAND || static_cast<int>(threadIdx_x + task->mfb_offset) == rand_threshold)) {
double current_gain = CUDALeafSplits::GetSplitGains<USE_L1, USE_SMOOTHING>(
sum_left_gradient, sum_left_hessian + kEpsilon, sum_right_gradient,
sum_right_hessian + kEpsilon, lambda_l1,
lambda_l2, path_smooth, left_count, right_count, parent_output);
// gain with split is worse than without split
if (current_gain > min_gain_shift) {
local_gain = current_gain - min_gain_shift;
threshold_value = static_cast<uint32_t>(threadIdx_x + task->mfb_offset);
threshold_found = true;
}
}
}
}
__syncthreads();
const uint32_t result = ReduceBestGain(local_gain, threshold_found, threadIdx_x, shared_double_buffer, shared_bool_buffer, shared_int_buffer);
if (threadIdx_x == 0) {
best_thread_index = result;
}
__syncthreads();
if (threshold_found && threadIdx_x == best_thread_index) {
cuda_best_split_info->is_valid = true;
cuda_best_split_info->threshold = threshold_value;
cuda_best_split_info->gain = local_gain;
cuda_best_split_info->default_left = task->assume_out_default_left;
const double left_output = CUDALeafSplits::CalculateSplittedLeafOutput<USE_L1, USE_SMOOTHING>(sum_left_gradient,
sum_left_hessian, lambda_l1, lambda_l2, path_smooth, left_count, parent_output);
const double right_output = CUDALeafSplits::CalculateSplittedLeafOutput<USE_L1, USE_SMOOTHING>(sum_right_gradient,
sum_right_hessian, lambda_l1, lambda_l2, path_smooth, right_count, parent_output);
cuda_best_split_info->left_sum_gradients = sum_left_gradient;
cuda_best_split_info->left_sum_hessians = sum_left_hessian;
cuda_best_split_info->left_sum_of_gradients_hessians = sum_left_gradient_hessian;
cuda_best_split_info->left_count = left_count;
cuda_best_split_info->right_sum_gradients = sum_right_gradient;
cuda_best_split_info->right_sum_hessians = sum_right_hessian;
cuda_best_split_info->right_sum_of_gradients_hessians = sum_right_gradient_hessian;
cuda_best_split_info->right_count = right_count;
cuda_best_split_info->left_value = left_output;
cuda_best_split_info->left_gain = CUDALeafSplits::GetLeafGainGivenOutput<USE_L1>(sum_left_gradient,
sum_left_hessian, lambda_l1, lambda_l2, left_output);
cuda_best_split_info->right_value = right_output;
cuda_best_split_info->right_gain = CUDALeafSplits::GetLeafGainGivenOutput<USE_L1>(sum_right_gradient,
sum_right_hessian, lambda_l1, lambda_l2, right_output);
}
}
template <bool USE_RAND, bool USE_L1, bool USE_SMOOTHING>
__device__ void FindBestSplitsForLeafKernelCategoricalInner(
// input feature information
@ -715,6 +884,169 @@ __global__ void FindBestSplitsForLeafKernel(
}
}
template <bool USE_RAND, bool USE_L1, bool USE_SMOOTHING, bool IS_LARGER>
__global__ void FindBestSplitsDiscretizedForLeafKernel(
// input feature information
const int8_t* is_feature_used_bytree,
// input task information
const int num_tasks,
const SplitFindTask* tasks,
CUDARandom* cuda_randoms,
// input leaf information
const CUDALeafSplitsStruct* smaller_leaf_splits,
const CUDALeafSplitsStruct* larger_leaf_splits,
const uint8_t smaller_leaf_num_bits_in_histogram_bin,
const uint8_t larger_leaf_num_bits_in_histogram_bin,
// input config parameter values
const data_size_t min_data_in_leaf,
const double min_sum_hessian_in_leaf,
const double min_gain_to_split,
const double lambda_l1,
const double lambda_l2,
const double path_smooth,
const double cat_smooth,
const double cat_l2,
const int max_cat_threshold,
const int min_data_per_group,
const int max_cat_to_onehot,
// gradient scale
const score_t* grad_scale,
const score_t* hess_scale,
// output
CUDASplitInfo* cuda_best_split_info) {
const unsigned int task_index = blockIdx.x;
const SplitFindTask* task = tasks + task_index;
const int inner_feature_index = task->inner_feature_index;
const double parent_gain = IS_LARGER ? larger_leaf_splits->gain : smaller_leaf_splits->gain;
const int64_t sum_gradients_hessians = IS_LARGER ? larger_leaf_splits->sum_of_gradients_hessians : smaller_leaf_splits->sum_of_gradients_hessians;
const data_size_t num_data = IS_LARGER ? larger_leaf_splits->num_data_in_leaf : smaller_leaf_splits->num_data_in_leaf;
const double parent_output = IS_LARGER ? larger_leaf_splits->leaf_value : smaller_leaf_splits->leaf_value;
const unsigned int output_offset = IS_LARGER ? (task_index + num_tasks) : task_index;
CUDASplitInfo* out = cuda_best_split_info + output_offset;
CUDARandom* cuda_random = USE_RAND ?
(IS_LARGER ? cuda_randoms + task_index * 2 + 1 : cuda_randoms + task_index * 2) : nullptr;
const bool use_16bit_bin = IS_LARGER ? (larger_leaf_num_bits_in_histogram_bin <= 16) : (smaller_leaf_num_bits_in_histogram_bin <= 16);
if (is_feature_used_bytree[inner_feature_index]) {
if (task->is_categorical) {
__threadfence(); // ensure store issued before trap
asm("trap;");
} else {
if (!task->reverse) {
if (use_16bit_bin) {
const int32_t* hist_ptr =
reinterpret_cast<const int32_t*>(IS_LARGER ? larger_leaf_splits->hist_in_leaf : smaller_leaf_splits->hist_in_leaf) + task->hist_offset;
FindBestSplitsDiscretizedForLeafKernelInner<USE_RAND, USE_L1, USE_SMOOTHING, false, int32_t, int32_t, true, true>(
// input feature information
hist_ptr,
// input task information
task,
cuda_random,
// input config parameter values
lambda_l1,
lambda_l2,
path_smooth,
min_data_in_leaf,
min_sum_hessian_in_leaf,
min_gain_to_split,
// input parent node information
parent_gain,
sum_gradients_hessians,
num_data,
parent_output,
// gradient scale
*grad_scale,
*hess_scale,
// output parameters
out);
} else {
const int32_t* hist_ptr =
reinterpret_cast<const int32_t*>(IS_LARGER ? larger_leaf_splits->hist_in_leaf : smaller_leaf_splits->hist_in_leaf) + task->hist_offset;
FindBestSplitsDiscretizedForLeafKernelInner<USE_RAND, USE_L1, USE_SMOOTHING, false, int32_t, int64_t, false, false>(
// input feature information
hist_ptr,
// input task information
task,
cuda_random,
// input config parameter values
lambda_l1,
lambda_l2,
path_smooth,
min_data_in_leaf,
min_sum_hessian_in_leaf,
min_gain_to_split,
// input parent node information
parent_gain,
sum_gradients_hessians,
num_data,
parent_output,
// gradient scale
*grad_scale,
*hess_scale,
// output parameters
out);
}
} else {
if (use_16bit_bin) {
const int32_t* hist_ptr =
reinterpret_cast<const int32_t*>(IS_LARGER ? larger_leaf_splits->hist_in_leaf : smaller_leaf_splits->hist_in_leaf) + task->hist_offset;
FindBestSplitsDiscretizedForLeafKernelInner<USE_RAND, USE_L1, USE_SMOOTHING, true, int32_t, int32_t, true, true>(
// input feature information
hist_ptr,
// input task information
task,
cuda_random,
// input config parameter values
lambda_l1,
lambda_l2,
path_smooth,
min_data_in_leaf,
min_sum_hessian_in_leaf,
min_gain_to_split,
// input parent node information
parent_gain,
sum_gradients_hessians,
num_data,
parent_output,
// gradient scale
*grad_scale,
*hess_scale,
// output parameters
out);
} else {
const int32_t* hist_ptr =
reinterpret_cast<const int32_t*>(IS_LARGER ? larger_leaf_splits->hist_in_leaf : smaller_leaf_splits->hist_in_leaf) + task->hist_offset;
FindBestSplitsDiscretizedForLeafKernelInner<USE_RAND, USE_L1, USE_SMOOTHING, true, int32_t, int64_t, false, false>(
// input feature information
hist_ptr,
// input task information
task,
cuda_random,
// input config parameter values
lambda_l1,
lambda_l2,
path_smooth,
min_data_in_leaf,
min_sum_hessian_in_leaf,
min_gain_to_split,
// input parent node information
parent_gain,
sum_gradients_hessians,
num_data,
parent_output,
// gradient scale
*grad_scale,
*hess_scale,
// output parameters
out);
}
}
}
} else {
out->is_valid = false;
}
}
template <bool USE_RAND, bool USE_L1, bool USE_SMOOTHING, bool REVERSE>
__device__ void FindBestSplitsForLeafKernelInner_GlobalMemory(
// input feature information
@ -1466,6 +1798,108 @@ void CUDABestSplitFinder::LaunchFindBestSplitsForLeafKernelInner2(LaunchFindBest
#undef FindBestSplitsForLeafKernel_ARGS
#undef GlobalMemory_Buffer_ARGS
#define LaunchFindBestSplitsDiscretizedForLeafKernel_PARAMS \
const CUDALeafSplitsStruct* smaller_leaf_splits, \
const CUDALeafSplitsStruct* larger_leaf_splits, \
const int smaller_leaf_index, \
const int larger_leaf_index, \
const bool is_smaller_leaf_valid, \
const bool is_larger_leaf_valid, \
const score_t* grad_scale, \
const score_t* hess_scale, \
const uint8_t smaller_num_bits_in_histogram_bins, \
const uint8_t larger_num_bits_in_histogram_bins
#define LaunchFindBestSplitsDiscretizedForLeafKernel_ARGS \
smaller_leaf_splits, \
larger_leaf_splits, \
smaller_leaf_index, \
larger_leaf_index, \
is_smaller_leaf_valid, \
is_larger_leaf_valid, \
grad_scale, \
hess_scale, \
smaller_num_bits_in_histogram_bins, \
larger_num_bits_in_histogram_bins
#define FindBestSplitsDiscretizedForLeafKernel_ARGS \
cuda_is_feature_used_bytree_, \
num_tasks_, \
cuda_split_find_tasks_.RawData(), \
cuda_randoms_.RawData(), \
smaller_leaf_splits, \
larger_leaf_splits, \
smaller_num_bits_in_histogram_bins, \
larger_num_bits_in_histogram_bins, \
min_data_in_leaf_, \
min_sum_hessian_in_leaf_, \
min_gain_to_split_, \
lambda_l1_, \
lambda_l2_, \
path_smooth_, \
cat_smooth_, \
cat_l2_, \
max_cat_threshold_, \
min_data_per_group_, \
max_cat_to_onehot_, \
grad_scale, \
hess_scale, \
cuda_best_split_info_
void CUDABestSplitFinder::LaunchFindBestSplitsDiscretizedForLeafKernel(LaunchFindBestSplitsDiscretizedForLeafKernel_PARAMS) {
if (!is_smaller_leaf_valid && !is_larger_leaf_valid) {
return;
}
if (!extra_trees_) {
LaunchFindBestSplitsDiscretizedForLeafKernelInner0<false>(LaunchFindBestSplitsDiscretizedForLeafKernel_ARGS);
} else {
LaunchFindBestSplitsDiscretizedForLeafKernelInner0<true>(LaunchFindBestSplitsDiscretizedForLeafKernel_ARGS);
}
}
template <bool USE_RAND>
void CUDABestSplitFinder::LaunchFindBestSplitsDiscretizedForLeafKernelInner0(LaunchFindBestSplitsDiscretizedForLeafKernel_PARAMS) {
if (lambda_l1_ <= 0.0f) {
LaunchFindBestSplitsDiscretizedForLeafKernelInner1<USE_RAND, false>(LaunchFindBestSplitsDiscretizedForLeafKernel_ARGS);
} else {
LaunchFindBestSplitsDiscretizedForLeafKernelInner1<USE_RAND, true>(LaunchFindBestSplitsDiscretizedForLeafKernel_ARGS);
}
}
template <bool USE_RAND, bool USE_L1>
void CUDABestSplitFinder::LaunchFindBestSplitsDiscretizedForLeafKernelInner1(LaunchFindBestSplitsDiscretizedForLeafKernel_PARAMS) {
if (!use_smoothing_) {
LaunchFindBestSplitsDiscretizedForLeafKernelInner2<USE_RAND, USE_L1, false>(LaunchFindBestSplitsDiscretizedForLeafKernel_ARGS);
} else {
LaunchFindBestSplitsDiscretizedForLeafKernelInner2<USE_RAND, USE_L1, true>(LaunchFindBestSplitsDiscretizedForLeafKernel_ARGS);
}
}
template <bool USE_RAND, bool USE_L1, bool USE_SMOOTHING>
void CUDABestSplitFinder::LaunchFindBestSplitsDiscretizedForLeafKernelInner2(LaunchFindBestSplitsDiscretizedForLeafKernel_PARAMS) {
if (!use_global_memory_) {
if (is_smaller_leaf_valid) {
FindBestSplitsDiscretizedForLeafKernel<USE_RAND, USE_L1, USE_SMOOTHING, false>
<<<num_tasks_, NUM_THREADS_PER_BLOCK_BEST_SPLIT_FINDER, 0, cuda_streams_[0]>>>
(FindBestSplitsDiscretizedForLeafKernel_ARGS);
}
SynchronizeCUDADevice(__FILE__, __LINE__);
if (is_larger_leaf_valid) {
FindBestSplitsDiscretizedForLeafKernel<USE_RAND, USE_L1, USE_SMOOTHING, true>
<<<num_tasks_, NUM_THREADS_PER_BLOCK_BEST_SPLIT_FINDER, 0, cuda_streams_[1]>>>
(FindBestSplitsDiscretizedForLeafKernel_ARGS);
}
} else {
// TODO(shiyu1994)
}
}
#undef LaunchFindBestSplitsDiscretizedForLeafKernel_PARAMS
#undef LaunchFindBestSplitsDiscretizedForLeafKernel_ARGS
#undef FindBestSplitsDiscretizedForLeafKernel_ARGS
__device__ void ReduceBestSplit(bool* found, double* gain, uint32_t* shared_read_index,
uint32_t num_features_aligned) {
const uint32_t threadIdx_x = threadIdx.x;
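In the discretized split finder above, each accumulated sum packs the integer gradient total into the high 32 bits of an int64 and the integer hessian total into the low 32 bits; grad_scale and hess_scale then convert both back to floating point before the usual gain computation. A hedged standalone sketch of that unpacking, with the function name invented for illustration:

#include <cstdint>

// Decode a packed (gradient, hessian) integer sum as done in
// FindBestSplitsDiscretizedForLeafKernelInner: high 32 bits are the gradient
// sum, low 32 bits the hessian sum, both scaled back to doubles.
void UnpackGradHessSum(int64_t packed, double grad_scale, double hess_scale,
                       double* sum_gradient, double* sum_hessian) {
  const int32_t grad_int = static_cast<int32_t>((packed & 0xffffffff00000000ULL) >> 32);
  const int32_t hess_int = static_cast<int32_t>(packed & 0x00000000ffffffffULL);
  *sum_gradient = static_cast<double>(grad_int) * grad_scale;
  *sum_hessian = static_cast<double>(hess_int) * hess_scale;
}

When 16-bit per-bin histograms are in use (USE_16BIT_ACC_HIST), the same idea applies to a 32-bit word holding two int16 halves, as the (local_grad_hess_hist >> 16) paths above show.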

@ -67,7 +67,11 @@ class CUDABestSplitFinder {
const data_size_t num_data_in_smaller_leaf,
const data_size_t num_data_in_larger_leaf,
const double sum_hessians_in_smaller_leaf,
const double sum_hessians_in_larger_leaf);
const double sum_hessians_in_larger_leaf,
const score_t* grad_scale,
const score_t* hess_scale,
const uint8_t smaller_num_bits_in_histogram_bins,
const uint8_t larger_num_bits_in_histogram_bins);
const CUDASplitInfo* FindBestFromAllSplits(
const int cur_num_leaves,
@ -114,6 +118,31 @@ class CUDABestSplitFinder {
#undef LaunchFindBestSplitsForLeafKernel_PARAMS
#define LaunchFindBestSplitsDiscretizedForLeafKernel_PARAMS \
const CUDALeafSplitsStruct* smaller_leaf_splits, \
const CUDALeafSplitsStruct* larger_leaf_splits, \
const int smaller_leaf_index, \
const int larger_leaf_index, \
const bool is_smaller_leaf_valid, \
const bool is_larger_leaf_valid, \
const score_t* grad_scale, \
const score_t* hess_scale, \
const uint8_t smaller_num_bits_in_histogram_bins, \
const uint8_t larger_num_bits_in_histogram_bins
void LaunchFindBestSplitsDiscretizedForLeafKernel(LaunchFindBestSplitsDiscretizedForLeafKernel_PARAMS);
template <bool USE_RAND>
void LaunchFindBestSplitsDiscretizedForLeafKernelInner0(LaunchFindBestSplitsDiscretizedForLeafKernel_PARAMS);
template <bool USE_RAND, bool USE_L1>
void LaunchFindBestSplitsDiscretizedForLeafKernelInner1(LaunchFindBestSplitsDiscretizedForLeafKernel_PARAMS);
template <bool USE_RAND, bool USE_L1, bool USE_SMOOTHING>
void LaunchFindBestSplitsDiscretizedForLeafKernelInner2(LaunchFindBestSplitsDiscretizedForLeafKernel_PARAMS);
#undef LaunchFindBestSplitsDiscretizedForLeafKernel_PARAMS
void LaunchSyncBestSplitForLeafKernel(
const int host_smaller_leaf_index,
const int host_larger_leaf_index,

@ -368,6 +368,12 @@ void CUDADataPartition::ResetByLeafPred(const std::vector<int>& leaf_pred, int n
cur_num_leaves_ = num_leaves;
}
void CUDADataPartition::ReduceLeafGradStat(
const score_t* gradients, const score_t* hessians,
CUDATree* tree, double* leaf_grad_stat_buffer, double* leaf_hess_state_buffer) const {
LaunchReduceLeafGradStat(gradients, hessians, tree, leaf_grad_stat_buffer, leaf_hess_state_buffer);
}
} // namespace LightGBM
#endif // USE_CUDA

@ -1069,6 +1069,53 @@ void CUDADataPartition::LaunchAddPredictionToScoreKernel(const double* leaf_valu
global_timer.Stop("CUDADataPartition::AddPredictionToScoreKernel");
}
__global__ void RenewDiscretizedTreeLeavesKernel(
const score_t* gradients,
const score_t* hessians,
const data_size_t* data_indices,
const data_size_t* leaf_data_start,
const data_size_t* leaf_num_data,
double* leaf_grad_stat_buffer,
double* leaf_hess_stat_buffer,
double* leaf_values) {
__shared__ double shared_mem_buffer[32];
const int leaf_index = static_cast<int>(blockIdx.x);
const data_size_t* data_indices_in_leaf = data_indices + leaf_data_start[leaf_index];
const data_size_t num_data_in_leaf = leaf_num_data[leaf_index];
double sum_gradients = 0.0f;
double sum_hessians = 0.0f;
for (data_size_t inner_data_index = static_cast<int>(threadIdx.x);
inner_data_index < num_data_in_leaf; inner_data_index += static_cast<int>(blockDim.x)) {
const data_size_t data_index = data_indices_in_leaf[inner_data_index];
const score_t gradient = gradients[data_index];
const score_t hessian = hessians[data_index];
sum_gradients += static_cast<double>(gradient);
sum_hessians += static_cast<double>(hessian);
}
sum_gradients = ShuffleReduceSum<double>(sum_gradients, shared_mem_buffer, blockDim.x);
__syncthreads();
sum_hessians = ShuffleReduceSum<double>(sum_hessians, shared_mem_buffer, blockDim.x);
if (threadIdx.x == 0) {
leaf_grad_stat_buffer[leaf_index] = sum_gradients;
leaf_hess_stat_buffer[leaf_index] = sum_hessians;
}
}
void CUDADataPartition::LaunchReduceLeafGradStat(
const score_t* gradients, const score_t* hessians,
CUDATree* tree, double* leaf_grad_stat_buffer, double* leaf_hess_state_buffer) const {
const int num_blocks = tree->num_leaves();
RenewDiscretizedTreeLeavesKernel<<<num_blocks, FILL_INDICES_BLOCK_SIZE_DATA_PARTITION>>>(
gradients,
hessians,
cuda_data_indices_,
cuda_leaf_data_start_,
cuda_leaf_num_data_,
leaf_grad_stat_buffer,
leaf_hess_state_buffer,
tree->cuda_leaf_value_ref());
}
} // namespace LightGBM
#endif // USE_CUDA

@ -78,6 +78,10 @@ class CUDADataPartition {
void ResetByLeafPred(const std::vector<int>& leaf_pred, int num_leaves);
void ReduceLeafGradStat(
const score_t* gradients, const score_t* hessians,
CUDATree* tree, double* leaf_grad_stat_buffer, double* leaf_hess_state_buffer) const;
data_size_t root_num_data() const {
if (use_bagging_) {
return num_used_indices_;
@ -292,6 +296,10 @@ class CUDADataPartition {
void LaunchFillDataIndexToLeafIndex();
void LaunchReduceLeafGradStat(
const score_t* gradients, const score_t* hessians,
CUDATree* tree, double* leaf_grad_stat_buffer, double* leaf_hess_state_buffer) const;
// Host memory
// dataset information

@ -0,0 +1,171 @@
/*!
* Copyright (c) 2021 Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See LICENSE file in the project root for
* license information.
*/
#ifdef USE_CUDA
#include <algorithm>
#include <LightGBM/cuda/cuda_algorithms.hpp>
#include "cuda_gradient_discretizer.hpp"
namespace LightGBM {
__global__ void ReduceMinMaxKernel(
const data_size_t num_data,
const score_t* input_gradients,
const score_t* input_hessians,
score_t* grad_min_block_buffer,
score_t* grad_max_block_buffer,
score_t* hess_min_block_buffer,
score_t* hess_max_block_buffer) {
__shared__ score_t shared_mem_buffer[32];
const data_size_t index = static_cast<data_size_t>(threadIdx.x + blockIdx.x * blockDim.x);
score_t grad_max_val = kMinScore;
score_t grad_min_val = kMaxScore;
score_t hess_max_val = kMinScore;
score_t hess_min_val = kMaxScore;
if (index < num_data) {
grad_max_val = input_gradients[index];
grad_min_val = input_gradients[index];
hess_max_val = input_hessians[index];
hess_min_val = input_hessians[index];
}
grad_min_val = ShuffleReduceMin<score_t>(grad_min_val, shared_mem_buffer, blockDim.x);
__syncthreads();
grad_max_val = ShuffleReduceMax<score_t>(grad_max_val, shared_mem_buffer, blockDim.x);
__syncthreads();
hess_min_val = ShuffleReduceMin<score_t>(hess_min_val, shared_mem_buffer, blockDim.x);
__syncthreads();
hess_max_val = ShuffleReduceMax<score_t>(hess_max_val, shared_mem_buffer, blockDim.x);
if (threadIdx.x == 0) {
grad_min_block_buffer[blockIdx.x] = grad_min_val;
grad_max_block_buffer[blockIdx.x] = grad_max_val;
hess_min_block_buffer[blockIdx.x] = hess_min_val;
hess_max_block_buffer[blockIdx.x] = hess_max_val;
}
}
__global__ void ReduceBlockMinMaxKernel(
const int num_blocks,
const int grad_discretize_bins,
score_t* grad_min_block_buffer,
score_t* grad_max_block_buffer,
score_t* hess_min_block_buffer,
score_t* hess_max_block_buffer) {
__shared__ score_t shared_mem_buffer[32];
score_t grad_max_val = kMinScore;
score_t grad_min_val = kMaxScore;
score_t hess_max_val = kMinScore;
score_t hess_min_val = kMaxScore;
for (int block_index = static_cast<int>(threadIdx.x); block_index < num_blocks; block_index += static_cast<int>(blockDim.x)) {
grad_min_val = min(grad_min_val, grad_min_block_buffer[block_index]);
grad_max_val = max(grad_max_val, grad_max_block_buffer[block_index]);
hess_min_val = min(hess_min_val, hess_min_block_buffer[block_index]);
hess_max_val = max(hess_max_val, hess_max_block_buffer[block_index]);
}
grad_min_val = ShuffleReduceMin<score_t>(grad_min_val, shared_mem_buffer, blockDim.x);
__syncthreads();
grad_max_val = ShuffleReduceMax<score_t>(grad_max_val, shared_mem_buffer, blockDim.x);
__syncthreads();
hess_min_val = ShuffleReduceMin<score_t>(hess_min_val, shared_mem_buffer, blockDim.x);
__syncthreads();
hess_max_val = ShuffleReduceMax<score_t>(hess_max_val, shared_mem_buffer, blockDim.x);
if (threadIdx.x == 0) {
const score_t grad_abs_max = max(fabs(grad_min_val), fabs(grad_max_val));
const score_t hess_abs_max = max(fabs(hess_min_val), fabs(hess_max_val));
grad_min_block_buffer[0] = 1.0f / (grad_abs_max / (grad_discretize_bins / 2));
grad_max_block_buffer[0] = (grad_abs_max / (grad_discretize_bins / 2));
hess_min_block_buffer[0] = 1.0f / (hess_abs_max / (grad_discretize_bins));
hess_max_block_buffer[0] = (hess_abs_max / (grad_discretize_bins));
}
}
template <bool STOCHASTIC_ROUNDING>
__global__ void DiscretizeGradientsKernel(
const data_size_t num_data,
const score_t* input_gradients,
const score_t* input_hessians,
const score_t* grad_scale_ptr,
const score_t* hess_scale_ptr,
const int iter,
const int* random_values_use_start,
const score_t* gradient_random_values,
const score_t* hessian_random_values,
const int grad_discretize_bins,
int8_t* output_gradients_and_hessians) {
const int start = random_values_use_start[iter];
const data_size_t index = static_cast<data_size_t>(threadIdx.x + blockIdx.x * blockDim.x);
const score_t grad_scale = *grad_scale_ptr;
const score_t hess_scale = *hess_scale_ptr;
int16_t* output_gradients_and_hessians_ptr = reinterpret_cast<int16_t*>(output_gradients_and_hessians);
if (index < num_data) {
if (STOCHASTIC_ROUNDING) {
const data_size_t index_offset = (index + start) % num_data;
const score_t gradient = input_gradients[index];
const score_t hessian = input_hessians[index];
const score_t gradient_random_value = gradient_random_values[index_offset];
const score_t hessian_random_value = hessian_random_values[index_offset];
output_gradients_and_hessians_ptr[2 * index + 1] = gradient > 0.0f ?
static_cast<int16_t>(gradient * grad_scale + gradient_random_value) :
static_cast<int16_t>(gradient * grad_scale - gradient_random_value);
output_gradients_and_hessians_ptr[2 * index] = static_cast<int16_t>(hessian * hess_scale + hessian_random_value);
} else {
const score_t gradient = input_gradients[index];
const score_t hessian = input_hessians[index];
output_gradients_and_hessians_ptr[2 * index + 1] = gradient > 0.0f ?
static_cast<int16_t>(gradient * grad_scale + 0.5) :
static_cast<int16_t>(gradient * grad_scale - 0.5);
output_gradients_and_hessians_ptr[2 * index] = static_cast<int16_t>(hessian * hess_scale + 0.5);
}
}
}
void CUDAGradientDiscretizer::DiscretizeGradients(
const data_size_t num_data,
const score_t* input_gradients,
const score_t* input_hessians) {
ReduceMinMaxKernel<<<num_reduce_blocks_, CUDA_GRADIENT_DISCRETIZER_BLOCK_SIZE>>>(
num_data, input_gradients, input_hessians,
grad_min_block_buffer_.RawData(),
grad_max_block_buffer_.RawData(),
hess_min_block_buffer_.RawData(),
hess_max_block_buffer_.RawData());
SynchronizeCUDADevice(__FILE__, __LINE__);
ReduceBlockMinMaxKernel<<<1, CUDA_GRADIENT_DISCRETIZER_BLOCK_SIZE>>>(
num_reduce_blocks_,
num_grad_quant_bins_,
grad_min_block_buffer_.RawData(),
grad_max_block_buffer_.RawData(),
hess_min_block_buffer_.RawData(),
hess_max_block_buffer_.RawData());
SynchronizeCUDADevice(__FILE__, __LINE__);
#define DiscretizeGradientsKernel_ARGS \
num_data, \
input_gradients, \
input_hessians, \
grad_min_block_buffer_.RawData(), \
hess_min_block_buffer_.RawData(), \
iter_, \
random_values_use_start_.RawData(), \
gradient_random_values_.RawData(), \
hessian_random_values_.RawData(), \
num_grad_quant_bins_, \
discretized_gradients_and_hessians_.RawData()
if (stochastic_rounding_) {
DiscretizeGradientsKernel<true><<<num_reduce_blocks_, CUDA_GRADIENT_DISCRETIZER_BLOCK_SIZE>>>(DiscretizeGradientsKernel_ARGS);
} else {
DiscretizeGradientsKernel<false><<<num_reduce_blocks_, CUDA_GRADIENT_DISCRETIZER_BLOCK_SIZE>>>(DiscretizeGradientsKernel_ARGS);
}
SynchronizeCUDADevice(__FILE__, __LINE__);
++iter_;
}
} // namespace LightGBM
#endif // USE_CUDA
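The stochastic rounding above replaces the fixed +0.5 offset with a pre-generated uniform random value in [0, 1), so a scaled gradient is rounded up with probability equal to its fractional part and the quantization error is unbiased in expectation. A hedged host-side sketch of the same rule (helper name invented):

#include <cstdint>
#include <random>

// Rounds a scaled gradient the way DiscretizeGradientsKernel does when
// stochastic rounding is enabled: add (or subtract) a uniform [0,1) value
// before truncating toward zero.
int16_t StochasticRound(float scaled_gradient, std::mt19937* rng) {
  std::uniform_real_distribution<float> dist(0.0f, 1.0f);
  const float r = dist(*rng);
  return scaled_gradient > 0.0f
      ? static_cast<int16_t>(scaled_gradient + r)
      : static_cast<int16_t>(scaled_gradient - r);
}

In the kernel the random values are generated once on the host (see cuda_gradient_discretizer.hpp below) and indexed through random_values_use_start[iter], which keeps the quantization reproducible across runs for a given seed.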

@ -0,0 +1,118 @@
/*!
* Copyright (c) 2021 Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See LICENSE file in the project root for
* license information.
*/
#ifndef LIGHTGBM_TREELEARNER_CUDA_CUDA_GRADIENT_DISCRETIZER_HPP_
#define LIGHTGBM_TREELEARNER_CUDA_CUDA_GRADIENT_DISCRETIZER_HPP_
#ifdef USE_CUDA
#include <LightGBM/bin.h>
#include <LightGBM/meta.h>
#include <LightGBM/cuda/cuda_utils.hu>
#include <LightGBM/utils/threading.h>
#include <algorithm>
#include <random>
#include <vector>
#include "cuda_leaf_splits.hpp"
#include "../gradient_discretizer.hpp"
namespace LightGBM {
#define CUDA_GRADIENT_DISCRETIZER_BLOCK_SIZE (1024)
class CUDAGradientDiscretizer: public GradientDiscretizer {
public:
CUDAGradientDiscretizer(int num_grad_quant_bins, int num_trees, int random_seed, bool is_constant_hessian, bool stochastic_rounding):
GradientDiscretizer(num_grad_quant_bins, num_trees, random_seed, is_constant_hessian, stochastic_rounding) {
}
void DiscretizeGradients(
const data_size_t num_data,
const score_t* input_gradients,
const score_t* input_hessians) override;
const int8_t* discretized_gradients_and_hessians() const override { return discretized_gradients_and_hessians_.RawData(); }
double grad_scale() const override {
Log::Fatal("grad_scale() of CUDAGradientDiscretizer should not be called.");
return 0.0;
}
double hess_scale() const override {
Log::Fatal("hess_scale() of CUDAGradientDiscretizer should not be called.");
return 0.0;
}
const score_t* grad_scale_ptr() const { return grad_max_block_buffer_.RawData(); }
const score_t* hess_scale_ptr() const { return hess_max_block_buffer_.RawData(); }
void Init(const data_size_t num_data, const int num_leaves,
const int num_features, const Dataset* train_data) override {
GradientDiscretizer::Init(num_data, num_leaves, num_features, train_data);
discretized_gradients_and_hessians_.Resize(num_data * 2);
num_reduce_blocks_ = (num_data + CUDA_GRADIENT_DISCRETIZER_BLOCK_SIZE - 1) / CUDA_GRADIENT_DISCRETIZER_BLOCK_SIZE;
grad_min_block_buffer_.Resize(num_reduce_blocks_);
grad_max_block_buffer_.Resize(num_reduce_blocks_);
hess_min_block_buffer_.Resize(num_reduce_blocks_);
hess_max_block_buffer_.Resize(num_reduce_blocks_);
random_values_use_start_.Resize(num_trees_);
gradient_random_values_.Resize(num_data);
hessian_random_values_.Resize(num_data);
std::vector<score_t> gradient_random_values(num_data, 0.0f);
std::vector<score_t> hessian_random_values(num_data, 0.0f);
std::vector<int> random_values_use_start(num_trees_, 0);
const int num_threads = OMP_NUM_THREADS();
std::mt19937 random_values_use_start_eng = std::mt19937(random_seed_);
std::uniform_int_distribution<data_size_t> random_values_use_start_dist = std::uniform_int_distribution<data_size_t>(0, num_data);
for (int tree_index = 0; tree_index < num_trees_; ++tree_index) {
random_values_use_start[tree_index] = random_values_use_start_dist(random_values_use_start_eng);
}
int num_blocks = 0;
data_size_t block_size = 0;
Threading::BlockInfo<data_size_t>(num_data, 512, &num_blocks, &block_size);
#pragma omp parallel for schedule(static, 1) num_threads(num_threads)
for (int thread_id = 0; thread_id < num_blocks; ++thread_id) {
const data_size_t start = thread_id * block_size;
const data_size_t end = std::min(start + block_size, num_data);
std::mt19937 gradient_random_values_eng(random_seed_ + thread_id);
std::uniform_real_distribution<double> gradient_random_values_dist(0.0f, 1.0f);
std::mt19937 hessian_random_values_eng(random_seed_ + thread_id + num_threads);
std::uniform_real_distribution<double> hessian_random_values_dist(0.0f, 1.0f);
for (data_size_t i = start; i < end; ++i) {
gradient_random_values[i] = gradient_random_values_dist(gradient_random_values_eng);
hessian_random_values[i] = hessian_random_values_dist(hessian_random_values_eng);
}
}
CopyFromHostToCUDADevice<score_t>(gradient_random_values_.RawData(), gradient_random_values.data(), gradient_random_values.size(), __FILE__, __LINE__);
CopyFromHostToCUDADevice<score_t>(hessian_random_values_.RawData(), hessian_random_values.data(), hessian_random_values.size(), __FILE__, __LINE__);
CopyFromHostToCUDADevice<int>(random_values_use_start_.RawData(), random_values_use_start.data(), random_values_use_start.size(), __FILE__, __LINE__);
iter_ = 0;
}
protected:
mutable CUDAVector<int8_t> discretized_gradients_and_hessians_;
mutable CUDAVector<score_t> grad_min_block_buffer_;
mutable CUDAVector<score_t> grad_max_block_buffer_;
mutable CUDAVector<score_t> hess_min_block_buffer_;
mutable CUDAVector<score_t> hess_max_block_buffer_;
CUDAVector<int> random_values_use_start_;
CUDAVector<score_t> gradient_random_values_;
CUDAVector<score_t> hessian_random_values_;
int num_reduce_blocks_;
};
} // namespace LightGBM
#endif // USE_CUDA
#endif // LIGHTGBM_TREELEARNER_CUDA_CUDA_GRADIENT_DISCRETIZER_HPP_

@ -20,7 +20,9 @@ CUDAHistogramConstructor::CUDAHistogramConstructor(
const int min_data_in_leaf,
const double min_sum_hessian_in_leaf,
const int gpu_device_id,
const bool gpu_use_dp):
const bool gpu_use_dp,
const bool use_quantized_grad,
const int num_grad_quant_bins):
num_data_(train_data->num_data()),
num_features_(train_data->num_features()),
num_leaves_(num_leaves),
@ -28,24 +30,14 @@ CUDAHistogramConstructor::CUDAHistogramConstructor(
min_data_in_leaf_(min_data_in_leaf),
min_sum_hessian_in_leaf_(min_sum_hessian_in_leaf),
gpu_device_id_(gpu_device_id),
gpu_use_dp_(gpu_use_dp) {
gpu_use_dp_(gpu_use_dp),
use_quantized_grad_(use_quantized_grad),
num_grad_quant_bins_(num_grad_quant_bins) {
InitFeatureMetaInfo(train_data, feature_hist_offsets);
cuda_row_data_.reset(nullptr);
cuda_feature_num_bins_ = nullptr;
cuda_feature_hist_offsets_ = nullptr;
cuda_feature_most_freq_bins_ = nullptr;
cuda_hist_ = nullptr;
cuda_need_fix_histogram_features_ = nullptr;
cuda_need_fix_histogram_features_num_bin_aligned_ = nullptr;
}
CUDAHistogramConstructor::~CUDAHistogramConstructor() {
DeallocateCUDAMemory<uint32_t>(&cuda_feature_num_bins_, __FILE__, __LINE__);
DeallocateCUDAMemory<uint32_t>(&cuda_feature_hist_offsets_, __FILE__, __LINE__);
DeallocateCUDAMemory<uint32_t>(&cuda_feature_most_freq_bins_, __FILE__, __LINE__);
DeallocateCUDAMemory<hist_t>(&cuda_hist_, __FILE__, __LINE__);
DeallocateCUDAMemory<int>(&cuda_need_fix_histogram_features_, __FILE__, __LINE__);
DeallocateCUDAMemory<uint32_t>(&cuda_need_fix_histogram_features_num_bin_aligned_, __FILE__, __LINE__);
gpuAssert(cudaStreamDestroy(cuda_stream_), __FILE__, __LINE__);
}
@ -84,54 +76,70 @@ void CUDAHistogramConstructor::InitFeatureMetaInfo(const Dataset* train_data, co
void CUDAHistogramConstructor::BeforeTrain(const score_t* gradients, const score_t* hessians) {
cuda_gradients_ = gradients;
cuda_hessians_ = hessians;
SetCUDAMemory<hist_t>(cuda_hist_, 0, num_total_bin_ * 2 * num_leaves_, __FILE__, __LINE__);
cuda_hist_.SetValue(0);
}
void CUDAHistogramConstructor::Init(const Dataset* train_data, TrainingShareStates* share_state) {
AllocateCUDAMemory<hist_t>(&cuda_hist_, num_total_bin_ * 2 * num_leaves_, __FILE__, __LINE__);
SetCUDAMemory<hist_t>(cuda_hist_, 0, num_total_bin_ * 2 * num_leaves_, __FILE__, __LINE__);
cuda_hist_.Resize(static_cast<size_t>(num_total_bin_ * 2 * num_leaves_));
cuda_hist_.SetValue(0);
InitCUDAMemoryFromHostMemory<uint32_t>(&cuda_feature_num_bins_,
feature_num_bins_.data(), feature_num_bins_.size(), __FILE__, __LINE__);
InitCUDAMemoryFromHostMemory<uint32_t>(&cuda_feature_hist_offsets_,
feature_hist_offsets_.data(), feature_hist_offsets_.size(), __FILE__, __LINE__);
InitCUDAMemoryFromHostMemory<uint32_t>(&cuda_feature_most_freq_bins_,
feature_most_freq_bins_.data(), feature_most_freq_bins_.size(), __FILE__, __LINE__);
cuda_feature_num_bins_.InitFromHostVector(feature_num_bins_);
cuda_feature_hist_offsets_.InitFromHostVector(feature_hist_offsets_);
cuda_feature_most_freq_bins_.InitFromHostVector(feature_most_freq_bins_);
cuda_row_data_.reset(new CUDARowData(train_data, share_state, gpu_device_id_, gpu_use_dp_));
cuda_row_data_->Init(train_data, share_state);
CUDASUCCESS_OR_FATAL(cudaStreamCreate(&cuda_stream_));
InitCUDAMemoryFromHostMemory<int>(&cuda_need_fix_histogram_features_, need_fix_histogram_features_.data(), need_fix_histogram_features_.size(), __FILE__, __LINE__);
InitCUDAMemoryFromHostMemory<uint32_t>(&cuda_need_fix_histogram_features_num_bin_aligned_, need_fix_histogram_features_num_bin_aligend_.data(),
need_fix_histogram_features_num_bin_aligend_.size(), __FILE__, __LINE__);
cuda_need_fix_histogram_features_.InitFromHostVector(need_fix_histogram_features_);
cuda_need_fix_histogram_features_num_bin_aligned_.InitFromHostVector(need_fix_histogram_features_num_bin_aligend_);
if (cuda_row_data_->NumLargeBinPartition() > 0) {
int grid_dim_x = 0, grid_dim_y = 0, block_dim_x = 0, block_dim_y = 0;
CalcConstructHistogramKernelDim(&grid_dim_x, &grid_dim_y, &block_dim_x, &block_dim_y, num_data_);
const size_t buffer_size = static_cast<size_t>(grid_dim_y) * static_cast<size_t>(num_total_bin_) * 2;
AllocateCUDAMemory<float>(&cuda_hist_buffer_, buffer_size, __FILE__, __LINE__);
const size_t buffer_size = static_cast<size_t>(grid_dim_y) * static_cast<size_t>(num_total_bin_);
if (!use_quantized_grad_) {
if (gpu_use_dp_) {
// need to double the size of histogram buffer in global memory when using double precision in histogram construction
cuda_hist_buffer_.Resize(buffer_size * 4);
} else {
cuda_hist_buffer_.Resize(buffer_size * 2);
}
} else {
// use only half the size of histogram buffer in global memory when quantized training since each gradient and hessian takes only 2 bytes
cuda_hist_buffer_.Resize(buffer_size);
}
}
hist_buffer_for_num_bit_change_.Resize(num_total_bin_ * 2);
}
void CUDAHistogramConstructor::ConstructHistogramForLeaf(
const CUDALeafSplitsStruct* cuda_smaller_leaf_splits,
const CUDALeafSplitsStruct* cuda_larger_leaf_splits,
const CUDALeafSplitsStruct* /*cuda_larger_leaf_splits*/,
const data_size_t num_data_in_smaller_leaf,
const data_size_t num_data_in_larger_leaf,
const double sum_hessians_in_smaller_leaf,
const double sum_hessians_in_larger_leaf) {
const double sum_hessians_in_larger_leaf,
const uint8_t num_bits_in_histogram_bins) {
if ((num_data_in_smaller_leaf <= min_data_in_leaf_ || sum_hessians_in_smaller_leaf <= min_sum_hessian_in_leaf_) &&
(num_data_in_larger_leaf <= min_data_in_leaf_ || sum_hessians_in_larger_leaf <= min_sum_hessian_in_leaf_)) {
return;
}
LaunchConstructHistogramKernel(cuda_smaller_leaf_splits, num_data_in_smaller_leaf);
LaunchConstructHistogramKernel(cuda_smaller_leaf_splits, num_data_in_smaller_leaf, num_bits_in_histogram_bins);
SynchronizeCUDADevice(__FILE__, __LINE__);
}
void CUDAHistogramConstructor::SubtractHistogramForLeaf(
const CUDALeafSplitsStruct* cuda_smaller_leaf_splits,
const CUDALeafSplitsStruct* cuda_larger_leaf_splits,
const bool use_quantized_grad,
const uint8_t parent_num_bits_in_histogram_bins,
const uint8_t smaller_num_bits_in_histogram_bins,
const uint8_t larger_num_bits_in_histogram_bins) {
global_timer.Start("CUDAHistogramConstructor::ConstructHistogramForLeaf::LaunchSubtractHistogramKernel");
LaunchSubtractHistogramKernel(cuda_smaller_leaf_splits, cuda_larger_leaf_splits);
LaunchSubtractHistogramKernel(cuda_smaller_leaf_splits, cuda_larger_leaf_splits, use_quantized_grad,
parent_num_bits_in_histogram_bins, smaller_num_bits_in_histogram_bins, larger_num_bits_in_histogram_bins);
global_timer.Stop("CUDAHistogramConstructor::ConstructHistogramForLeaf::LaunchSubtractHistogramKernel");
}
@ -152,33 +160,18 @@ void CUDAHistogramConstructor::ResetTrainingData(const Dataset* train_data, Trai
num_data_ = train_data->num_data();
num_features_ = train_data->num_features();
InitFeatureMetaInfo(train_data, share_states->feature_hist_offsets());
if (feature_num_bins_.size() > 0) {
DeallocateCUDAMemory<uint32_t>(&cuda_feature_num_bins_, __FILE__, __LINE__);
DeallocateCUDAMemory<uint32_t>(&cuda_feature_hist_offsets_, __FILE__, __LINE__);
DeallocateCUDAMemory<uint32_t>(&cuda_feature_most_freq_bins_, __FILE__, __LINE__);
DeallocateCUDAMemory<int>(&cuda_need_fix_histogram_features_, __FILE__, __LINE__);
DeallocateCUDAMemory<uint32_t>(&cuda_need_fix_histogram_features_num_bin_aligned_, __FILE__, __LINE__);
DeallocateCUDAMemory<hist_t>(&cuda_hist_, __FILE__, __LINE__);
}
AllocateCUDAMemory<hist_t>(&cuda_hist_, num_total_bin_ * 2 * num_leaves_, __FILE__, __LINE__);
SetCUDAMemory<hist_t>(cuda_hist_, 0, num_total_bin_ * 2 * num_leaves_, __FILE__, __LINE__);
InitCUDAMemoryFromHostMemory<uint32_t>(&cuda_feature_num_bins_,
feature_num_bins_.data(), feature_num_bins_.size(), __FILE__, __LINE__);
InitCUDAMemoryFromHostMemory<uint32_t>(&cuda_feature_hist_offsets_,
feature_hist_offsets_.data(), feature_hist_offsets_.size(), __FILE__, __LINE__);
InitCUDAMemoryFromHostMemory<uint32_t>(&cuda_feature_most_freq_bins_,
feature_most_freq_bins_.data(), feature_most_freq_bins_.size(), __FILE__, __LINE__);
cuda_hist_.Resize(static_cast<size_t>(num_total_bin_ * 2 * num_leaves_));
cuda_hist_.SetValue(0);
cuda_feature_num_bins_.InitFromHostVector(feature_num_bins_);
cuda_feature_hist_offsets_.InitFromHostVector(feature_hist_offsets_);
cuda_feature_most_freq_bins_.InitFromHostVector(feature_most_freq_bins_);
cuda_row_data_.reset(new CUDARowData(train_data, share_states, gpu_device_id_, gpu_use_dp_));
cuda_row_data_->Init(train_data, share_states);
InitCUDAMemoryFromHostMemory<int>(&cuda_need_fix_histogram_features_, need_fix_histogram_features_.data(), need_fix_histogram_features_.size(), __FILE__, __LINE__);
InitCUDAMemoryFromHostMemory<uint32_t>(&cuda_need_fix_histogram_features_num_bin_aligned_, need_fix_histogram_features_num_bin_aligend_.data(),
need_fix_histogram_features_num_bin_aligend_.size(), __FILE__, __LINE__);
cuda_need_fix_histogram_features_.InitFromHostVector(need_fix_histogram_features_);
cuda_need_fix_histogram_features_num_bin_aligned_.InitFromHostVector(need_fix_histogram_features_num_bin_aligend_);
}
void CUDAHistogramConstructor::ResetConfig(const Config* config) {
@ -186,9 +179,8 @@ void CUDAHistogramConstructor::ResetConfig(const Config* config) {
num_leaves_ = config->num_leaves;
min_data_in_leaf_ = config->min_data_in_leaf;
min_sum_hessian_in_leaf_ = config->min_sum_hessian_in_leaf;
DeallocateCUDAMemory<hist_t>(&cuda_hist_, __FILE__, __LINE__);
AllocateCUDAMemory<hist_t>(&cuda_hist_, num_total_bin_ * 2 * num_leaves_, __FILE__, __LINE__);
SetCUDAMemory<hist_t>(cuda_hist_, 0, num_total_bin_ * 2 * num_leaves_, __FILE__, __LINE__);
cuda_hist_.Resize(static_cast<size_t>(num_total_bin_ * 2 * num_leaves_));
cuda_hist_.SetValue(0);
}
} // namespace LightGBM
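The buffer sizing branch above can be read as: the global histogram buffer is measured in 4-byte elements, each bin needs a gradient and a hessian slot, doubles take two elements each, and with quantized training the pair is packed into a single 32-bit integer per bin. A hedged sketch of that arithmetic (function name invented):

#include <cstddef>

// Number of 4-byte buffer elements needed per histogram copy, following the
// Resize calls in CUDAHistogramConstructor::Init above.
size_t HistBufferElements(size_t grid_dim_y, size_t num_total_bin,
                          bool use_quantized_grad, bool gpu_use_dp) {
  const size_t base_size = grid_dim_y * num_total_bin;
  if (use_quantized_grad) {
    return base_size;                  // one packed int32 (int16 grad + int16 hess) per bin
  }
  return gpu_use_dp ? base_size * 4    // two doubles per bin = four floats' worth
                    : base_size * 2;   // two floats per bin
}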

@ -125,7 +125,7 @@ __global__ void CUDAConstructHistogramSparseKernel(
}
}
template <typename BIN_TYPE>
template <typename BIN_TYPE, typename HIST_TYPE>
__global__ void CUDAConstructHistogramDenseKernel_GlobalMemory(
const CUDALeafSplitsStruct* smaller_leaf_splits,
const score_t* cuda_gradients,
@ -135,7 +135,7 @@ __global__ void CUDAConstructHistogramDenseKernel_GlobalMemory(
const uint32_t* column_hist_offsets_full,
const int* feature_partition_column_index_offsets,
const data_size_t num_data,
float* global_hist_buffer) {
HIST_TYPE* global_hist_buffer) {
const int dim_y = static_cast<int>(gridDim.y * blockDim.y);
const data_size_t num_data_in_smaller_leaf = smaller_leaf_splits->num_data_in_leaf;
const data_size_t num_data_per_thread = (num_data_in_smaller_leaf + dim_y - 1) / dim_y;
@ -150,7 +150,7 @@ __global__ void CUDAConstructHistogramDenseKernel_GlobalMemory(
const uint32_t num_items_in_partition = (partition_hist_end - partition_hist_start) << 1;
const unsigned int thread_idx = threadIdx.x + threadIdx.y * blockDim.x;
const int num_total_bin = column_hist_offsets_full[gridDim.x];
float* shared_hist = global_hist_buffer + (blockIdx.y * num_total_bin + partition_hist_start) * 2;
HIST_TYPE* shared_hist = global_hist_buffer + (blockIdx.y * num_total_bin + partition_hist_start) * 2;
for (unsigned int i = thread_idx; i < num_items_in_partition; i += num_threads_per_block) {
shared_hist[i] = 0.0f;
}
@ -166,14 +166,14 @@ __global__ void CUDAConstructHistogramDenseKernel_GlobalMemory(
data_size_t inner_data_index = static_cast<data_size_t>(threadIdx_y);
const int column_index = static_cast<int>(threadIdx.x) + partition_column_start;
if (threadIdx.x < static_cast<unsigned int>(num_columns_in_partition)) {
float* shared_hist_ptr = shared_hist + (column_hist_offsets[column_index] << 1);
HIST_TYPE* shared_hist_ptr = shared_hist + (column_hist_offsets[column_index] << 1);
for (data_size_t i = 0; i < num_iteration_this; ++i) {
const data_size_t data_index = data_indices_ref_this_block[inner_data_index];
const score_t grad = cuda_gradients[data_index];
const score_t hess = cuda_hessians[data_index];
const uint32_t bin = static_cast<uint32_t>(data_ptr[static_cast<size_t>(data_index) * num_columns_in_partition + threadIdx.x]);
const uint32_t pos = bin << 1;
float* pos_ptr = shared_hist_ptr + pos;
HIST_TYPE* pos_ptr = shared_hist_ptr + pos;
atomicAdd_block(pos_ptr, grad);
atomicAdd_block(pos_ptr + 1, hess);
inner_data_index += blockDim.y;
@ -186,7 +186,7 @@ __global__ void CUDAConstructHistogramDenseKernel_GlobalMemory(
}
}
template <typename BIN_TYPE, typename DATA_PTR_TYPE>
template <typename BIN_TYPE, typename HIST_TYPE, typename DATA_PTR_TYPE>
__global__ void CUDAConstructHistogramSparseKernel_GlobalMemory(
const CUDALeafSplitsStruct* smaller_leaf_splits,
const score_t* cuda_gradients,
@ -196,7 +196,7 @@ __global__ void CUDAConstructHistogramSparseKernel_GlobalMemory(
const DATA_PTR_TYPE* partition_ptr,
const uint32_t* column_hist_offsets_full,
const data_size_t num_data,
float* global_hist_buffer) {
HIST_TYPE* global_hist_buffer) {
const int dim_y = static_cast<int>(gridDim.y * blockDim.y);
const data_size_t num_data_in_smaller_leaf = smaller_leaf_splits->num_data_in_leaf;
const data_size_t num_data_per_thread = (num_data_in_smaller_leaf + dim_y - 1) / dim_y;
@ -209,7 +209,7 @@ __global__ void CUDAConstructHistogramSparseKernel_GlobalMemory(
const uint32_t num_items_in_partition = (partition_hist_end - partition_hist_start) << 1;
const unsigned int thread_idx = threadIdx.x + threadIdx.y * blockDim.x;
const int num_total_bin = column_hist_offsets_full[gridDim.x];
float* shared_hist = global_hist_buffer + (blockIdx.y * num_total_bin + partition_hist_start) * 2;
HIST_TYPE* shared_hist = global_hist_buffer + (blockIdx.y * num_total_bin + partition_hist_start) * 2;
for (unsigned int i = thread_idx; i < num_items_in_partition; i += num_threads_per_block) {
shared_hist[i] = 0.0f;
}
@ -233,7 +233,7 @@ __global__ void CUDAConstructHistogramSparseKernel_GlobalMemory(
const score_t hess = cuda_hessians[data_index];
const uint32_t bin = static_cast<uint32_t>(data_ptr[row_start + threadIdx.x]);
const uint32_t pos = bin << 1;
float* pos_ptr = shared_hist + pos;
HIST_TYPE* pos_ptr = shared_hist + pos;
atomicAdd_block(pos_ptr, grad);
atomicAdd_block(pos_ptr + 1, hess);
}
@ -246,13 +246,278 @@ __global__ void CUDAConstructHistogramSparseKernel_GlobalMemory(
}
}
template <typename BIN_TYPE, int SHARED_HIST_SIZE, bool USE_16BIT_HIST>
__global__ void CUDAConstructDiscretizedHistogramDenseKernel(
const CUDALeafSplitsStruct* smaller_leaf_splits,
const int32_t* cuda_gradients_and_hessians,
const BIN_TYPE* data,
const uint32_t* column_hist_offsets,
const uint32_t* column_hist_offsets_full,
const int* feature_partition_column_index_offsets,
const data_size_t num_data) {
const int dim_y = static_cast<int>(gridDim.y * blockDim.y);
const data_size_t num_data_in_smaller_leaf = smaller_leaf_splits->num_data_in_leaf;
const data_size_t num_data_per_thread = (num_data_in_smaller_leaf + dim_y - 1) / dim_y;
const data_size_t* data_indices_ref = smaller_leaf_splits->data_indices_in_leaf;
__shared__ int16_t shared_hist[SHARED_HIST_SIZE];
int32_t* shared_hist_packed = reinterpret_cast<int32_t*>(shared_hist);
const unsigned int num_threads_per_block = blockDim.x * blockDim.y;
const int partition_column_start = feature_partition_column_index_offsets[blockIdx.x];
const int partition_column_end = feature_partition_column_index_offsets[blockIdx.x + 1];
const BIN_TYPE* data_ptr = data + partition_column_start * num_data;
const int num_columns_in_partition = partition_column_end - partition_column_start;
const uint32_t partition_hist_start = column_hist_offsets_full[blockIdx.x];
const uint32_t partition_hist_end = column_hist_offsets_full[blockIdx.x + 1];
const uint32_t num_items_in_partition = (partition_hist_end - partition_hist_start);
const unsigned int thread_idx = threadIdx.x + threadIdx.y * blockDim.x;
for (unsigned int i = thread_idx; i < num_items_in_partition; i += num_threads_per_block) {
shared_hist_packed[i] = 0;
}
__syncthreads();
const unsigned int threadIdx_y = threadIdx.y;
const unsigned int blockIdx_y = blockIdx.y;
const data_size_t block_start = (blockIdx_y * blockDim.y) * num_data_per_thread;
const data_size_t* data_indices_ref_this_block = data_indices_ref + block_start;
data_size_t block_num_data = max(0, min(num_data_in_smaller_leaf - block_start, num_data_per_thread * static_cast<data_size_t>(blockDim.y)));
const data_size_t num_iteration_total = (block_num_data + blockDim.y - 1) / blockDim.y;
const data_size_t remainder = block_num_data % blockDim.y;
const data_size_t num_iteration_this = remainder == 0 ? num_iteration_total : num_iteration_total - static_cast<data_size_t>(threadIdx_y >= remainder);
data_size_t inner_data_index = static_cast<data_size_t>(threadIdx_y);
const int column_index = static_cast<int>(threadIdx.x) + partition_column_start;
if (threadIdx.x < static_cast<unsigned int>(num_columns_in_partition)) {
int32_t* shared_hist_ptr = shared_hist_packed + (column_hist_offsets[column_index]);
for (data_size_t i = 0; i < num_iteration_this; ++i) {
const data_size_t data_index = data_indices_ref_this_block[inner_data_index];
const int32_t grad_and_hess = cuda_gradients_and_hessians[data_index];
const uint32_t bin = static_cast<uint32_t>(data_ptr[data_index * num_columns_in_partition + threadIdx.x]);
int32_t* pos_ptr = shared_hist_ptr + bin;
atomicAdd_block(pos_ptr, grad_and_hess);
inner_data_index += blockDim.y;
}
}
__syncthreads();
if (USE_16BIT_HIST) {
int32_t* feature_histogram_ptr = reinterpret_cast<int32_t*>(smaller_leaf_splits->hist_in_leaf) + partition_hist_start;
for (unsigned int i = thread_idx; i < num_items_in_partition; i += num_threads_per_block) {
const int32_t packed_grad_hess = shared_hist_packed[i];
atomicAdd_system(feature_histogram_ptr + i, packed_grad_hess);
}
} else {
atomic_add_long_t* feature_histogram_ptr = reinterpret_cast<atomic_add_long_t*>(smaller_leaf_splits->hist_in_leaf) + partition_hist_start;
for (unsigned int i = thread_idx; i < num_items_in_partition; i += num_threads_per_block) {
const int32_t packed_grad_hess = shared_hist_packed[i];
const int64_t packed_grad_hess_int64 = (static_cast<int64_t>(static_cast<int16_t>(packed_grad_hess >> 16)) << 32) | (static_cast<int64_t>(packed_grad_hess & 0x0000ffff));
atomicAdd_system(feature_histogram_ptr + i, (atomic_add_long_t)(packed_grad_hess_int64));
}
}
}
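The discretized kernels above treat each histogram slot as a packed integer counter: a 16-bit gradient and a 16-bit hessian share one int32_t (gradient in the high half), and the pair is widened to an int64_t with the gradient in the high 32 bits when the shared buffer is flushed into the leaf histogram. The host-side sketch below is illustrative only (the helper names are not part of this commit) and simply replays that pack/widen round trip:

#include <cassert>
#include <cstdint>

// Pack a 16-bit gradient/hessian pair into one 32-bit counter,
// mirroring the layout the kernels accumulate with atomicAdd_block.
int32_t PackGradHess16(int16_t grad, int16_t hess) {
  return (static_cast<int32_t>(grad) << 16) | (static_cast<int32_t>(hess) & 0x0000ffff);
}

// Widen a 32-bit packed counter into the 64-bit packed layout of the leaf
// histogram, matching the shift/mask expression in the flush loops above.
int64_t WidenPacked16To32(int32_t packed) {
  return (static_cast<int64_t>(static_cast<int16_t>(packed >> 16)) << 32) |
         static_cast<int64_t>(packed & 0x0000ffff);
}

int main() {
  const int32_t packed = PackGradHess16(-3, 7);
  const int64_t widened = WidenPacked16To32(packed);
  assert(static_cast<int32_t>(widened >> 32) == -3);        // gradient keeps its sign
  assert(static_cast<int32_t>(widened & 0xffffffff) == 7);  // hessian stays non-negative
  return 0;
}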
template <typename BIN_TYPE, typename DATA_PTR_TYPE, int SHARED_HIST_SIZE, bool USE_16BIT_HIST>
__global__ void CUDAConstructDiscretizedHistogramSparseKernel(
const CUDALeafSplitsStruct* smaller_leaf_splits,
const int32_t* cuda_gradients_and_hessians,
const BIN_TYPE* data,
const DATA_PTR_TYPE* row_ptr,
const DATA_PTR_TYPE* partition_ptr,
const uint32_t* column_hist_offsets_full,
const data_size_t num_data) {
const int dim_y = static_cast<int>(gridDim.y * blockDim.y);
const data_size_t num_data_in_smaller_leaf = smaller_leaf_splits->num_data_in_leaf;
const data_size_t num_data_per_thread = (num_data_in_smaller_leaf + dim_y - 1) / dim_y;
const data_size_t* data_indices_ref = smaller_leaf_splits->data_indices_in_leaf;
__shared__ int16_t shared_hist[SHARED_HIST_SIZE];
int32_t* shared_hist_packed = reinterpret_cast<int32_t*>(shared_hist);
const unsigned int num_threads_per_block = blockDim.x * blockDim.y;
const DATA_PTR_TYPE* block_row_ptr = row_ptr + blockIdx.x * (num_data + 1);
const BIN_TYPE* data_ptr = data + partition_ptr[blockIdx.x];
const uint32_t partition_hist_start = column_hist_offsets_full[blockIdx.x];
const uint32_t partition_hist_end = column_hist_offsets_full[blockIdx.x + 1];
const uint32_t num_items_in_partition = (partition_hist_end - partition_hist_start);
const unsigned int thread_idx = threadIdx.x + threadIdx.y * blockDim.x;
for (unsigned int i = thread_idx; i < num_items_in_partition; i += num_threads_per_block) {
shared_hist_packed[i] = 0;
}
__syncthreads();
const unsigned int threadIdx_y = threadIdx.y;
const unsigned int blockIdx_y = blockIdx.y;
const data_size_t block_start = (blockIdx_y * blockDim.y) * num_data_per_thread;
const data_size_t* data_indices_ref_this_block = data_indices_ref + block_start;
data_size_t block_num_data = max(0, min(num_data_in_smaller_leaf - block_start, num_data_per_thread * static_cast<data_size_t>(blockDim.y)));
const data_size_t num_iteration_total = (block_num_data + blockDim.y - 1) / blockDim.y;
const data_size_t remainder = block_num_data % blockDim.y;
const data_size_t num_iteration_this = remainder == 0 ? num_iteration_total : num_iteration_total - static_cast<data_size_t>(threadIdx_y >= remainder);
data_size_t inner_data_index = static_cast<data_size_t>(threadIdx_y);
for (data_size_t i = 0; i < num_iteration_this; ++i) {
const data_size_t data_index = data_indices_ref_this_block[inner_data_index];
const DATA_PTR_TYPE row_start = block_row_ptr[data_index];
const DATA_PTR_TYPE row_end = block_row_ptr[data_index + 1];
const DATA_PTR_TYPE row_size = row_end - row_start;
if (threadIdx.x < row_size) {
const int32_t grad_and_hess = cuda_gradients_and_hessians[data_index];
const uint32_t bin = static_cast<uint32_t>(data_ptr[row_start + threadIdx.x]);
int32_t* pos_ptr = shared_hist_packed + bin;
atomicAdd_block(pos_ptr, grad_and_hess);
}
inner_data_index += blockDim.y;
}
__syncthreads();
if (USE_16BIT_HIST) {
int32_t* feature_histogram_ptr = reinterpret_cast<int32_t*>(smaller_leaf_splits->hist_in_leaf) + partition_hist_start;
for (unsigned int i = thread_idx; i < num_items_in_partition; i += num_threads_per_block) {
const int32_t packed_grad_hess = shared_hist_packed[i];
atomicAdd_system(feature_histogram_ptr + i, packed_grad_hess);
}
} else {
atomic_add_long_t* feature_histogram_ptr = reinterpret_cast<atomic_add_long_t*>(smaller_leaf_splits->hist_in_leaf) + partition_hist_start;
for (unsigned int i = thread_idx; i < num_items_in_partition; i += num_threads_per_block) {
const int32_t packed_grad_hess = shared_hist_packed[i];
const int64_t packed_grad_hess_int64 = (static_cast<int64_t>(static_cast<int16_t>(packed_grad_hess >> 16)) << 32) | (static_cast<int64_t>(packed_grad_hess & 0x0000ffff));
atomicAdd_system(feature_histogram_ptr + i, (atomic_add_long_t)(packed_grad_hess_int64));
}
}
}
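All of the construction kernels in this file partition the leaf's rows the same way: gridDim.y * blockDim.y thread rows split num_data_in_smaller_leaf, each y-block owns a contiguous chunk, each thread row strides through that chunk, and threads whose threadIdx.y is at or beyond the remainder drop one iteration when the chunk does not divide evenly. A small host-side sketch of that arithmetic with assumed example dimensions (illustrative only):

#include <algorithm>

int main() {
  const int num_data_in_smaller_leaf = 1000;
  const int grid_dim_y = 4, block_dim_y = 8;
  const int dim_y = grid_dim_y * block_dim_y;                                        // 32 thread rows
  const int num_data_per_thread = (num_data_in_smaller_leaf + dim_y - 1) / dim_y;    // 32 rows per thread row
  const int block_idx_y = 3;                                                          // last y-block
  const int block_start = block_idx_y * block_dim_y * num_data_per_thread;            // 768
  const int block_num_data =
      std::max(0, std::min(num_data_in_smaller_leaf - block_start,
                           num_data_per_thread * block_dim_y));                       // 232 rows for this block
  const int num_iteration_total = (block_num_data + block_dim_y - 1) / block_dim_y;   // 29 strided iterations
  const int remainder = block_num_data % block_dim_y;                                  // 0 here; if nonzero,
  // thread rows with threadIdx.y >= remainder would do one fewer iteration
  return (num_iteration_total == 29 && remainder == 0) ? 0 : 1;
}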
template <typename BIN_TYPE, int SHARED_HIST_SIZE, bool USE_16BIT_HIST>
__global__ void CUDAConstructDiscretizedHistogramDenseKernel_GlobalMemory(
const CUDALeafSplitsStruct* smaller_leaf_splits,
const int32_t* cuda_gradients_and_hessians,
const BIN_TYPE* data,
const uint32_t* column_hist_offsets,
const uint32_t* column_hist_offsets_full,
const int* feature_partition_column_index_offsets,
const data_size_t num_data,
int32_t* global_hist_buffer) {
const int dim_y = static_cast<int>(gridDim.y * blockDim.y);
const data_size_t num_data_in_smaller_leaf = smaller_leaf_splits->num_data_in_leaf;
const data_size_t num_data_per_thread = (num_data_in_smaller_leaf + dim_y - 1) / dim_y;
const data_size_t* data_indices_ref = smaller_leaf_splits->data_indices_in_leaf;
const unsigned int num_threads_per_block = blockDim.x * blockDim.y;
const int partition_column_start = feature_partition_column_index_offsets[blockIdx.x];
const int partition_column_end = feature_partition_column_index_offsets[blockIdx.x + 1];
const BIN_TYPE* data_ptr = data + partition_column_start * num_data;
const int num_columns_in_partition = partition_column_end - partition_column_start;
const uint32_t partition_hist_start = column_hist_offsets_full[blockIdx.x];
const uint32_t partition_hist_end = column_hist_offsets_full[blockIdx.x + 1];
const uint32_t num_items_in_partition = (partition_hist_end - partition_hist_start);
const int num_total_bin = column_hist_offsets_full[gridDim.x];
int32_t* shared_hist_packed = global_hist_buffer + (blockIdx.y * num_total_bin + partition_hist_start);
const unsigned int thread_idx = threadIdx.x + threadIdx.y * blockDim.x;
for (unsigned int i = thread_idx; i < num_items_in_partition; i += num_threads_per_block) {
shared_hist_packed[i] = 0;
}
__syncthreads();
const unsigned int threadIdx_y = threadIdx.y;
const unsigned int blockIdx_y = blockIdx.y;
const data_size_t block_start = (blockIdx_y * blockDim.y) * num_data_per_thread;
const data_size_t* data_indices_ref_this_block = data_indices_ref + block_start;
data_size_t block_num_data = max(0, min(num_data_in_smaller_leaf - block_start, num_data_per_thread * static_cast<data_size_t>(blockDim.y)));
const data_size_t num_iteration_total = (block_num_data + blockDim.y - 1) / blockDim.y;
const data_size_t remainder = block_num_data % blockDim.y;
const data_size_t num_iteration_this = remainder == 0 ? num_iteration_total : num_iteration_total - static_cast<data_size_t>(threadIdx_y >= remainder);
data_size_t inner_data_index = static_cast<data_size_t>(threadIdx_y);
const int column_index = static_cast<int>(threadIdx.x) + partition_column_start;
if (threadIdx.x < static_cast<unsigned int>(num_columns_in_partition)) {
int32_t* shared_hist_ptr = shared_hist_packed + (column_hist_offsets[column_index]);
for (data_size_t i = 0; i < num_iteration_this; ++i) {
const data_size_t data_index = data_indices_ref_this_block[inner_data_index];
const int32_t grad_and_hess = cuda_gradients_and_hessians[data_index];
const uint32_t bin = static_cast<uint32_t>(data_ptr[data_index * num_columns_in_partition + threadIdx.x]);
int32_t* pos_ptr = shared_hist_ptr + bin;
atomicAdd_block(pos_ptr, grad_and_hess);
inner_data_index += blockDim.y;
}
}
__syncthreads();
if (USE_16BIT_HIST) {
int32_t* feature_histogram_ptr = reinterpret_cast<int32_t*>(smaller_leaf_splits->hist_in_leaf) + partition_hist_start;
for (unsigned int i = thread_idx; i < num_items_in_partition; i += num_threads_per_block) {
const int32_t packed_grad_hess = shared_hist_packed[i];
atomicAdd_system(feature_histogram_ptr + i, packed_grad_hess);
}
} else {
atomic_add_long_t* feature_histogram_ptr = reinterpret_cast<atomic_add_long_t*>(smaller_leaf_splits->hist_in_leaf) + partition_hist_start;
for (unsigned int i = thread_idx; i < num_items_in_partition; i += num_threads_per_block) {
const int32_t packed_grad_hess = shared_hist_packed[i];
const int64_t packed_grad_hess_int64 = (static_cast<int64_t>(static_cast<int16_t>(packed_grad_hess >> 16)) << 32) | (static_cast<int64_t>(packed_grad_hess & 0x0000ffff));
atomicAdd_system(feature_histogram_ptr + i, (atomic_add_long_t)(packed_grad_hess_int64));
}
}
}
template <typename BIN_TYPE, typename DATA_PTR_TYPE, int SHARED_HIST_SIZE, bool USE_16BIT_HIST>
__global__ void CUDAConstructDiscretizedHistogramSparseKernel_GlobalMemory(
const CUDALeafSplitsStruct* smaller_leaf_splits,
const int32_t* cuda_gradients_and_hessians,
const BIN_TYPE* data,
const DATA_PTR_TYPE* row_ptr,
const DATA_PTR_TYPE* partition_ptr,
const uint32_t* column_hist_offsets_full,
const data_size_t num_data,
int32_t* global_hist_buffer) {
const int dim_y = static_cast<int>(gridDim.y * blockDim.y);
const data_size_t num_data_in_smaller_leaf = smaller_leaf_splits->num_data_in_leaf;
const data_size_t num_data_per_thread = (num_data_in_smaller_leaf + dim_y - 1) / dim_y;
const data_size_t* data_indices_ref = smaller_leaf_splits->data_indices_in_leaf;
const int num_total_bin = column_hist_offsets_full[gridDim.x];
const unsigned int num_threads_per_block = blockDim.x * blockDim.y;
const DATA_PTR_TYPE* block_row_ptr = row_ptr + blockIdx.x * (num_data + 1);
const BIN_TYPE* data_ptr = data + partition_ptr[blockIdx.x];
const uint32_t partition_hist_start = column_hist_offsets_full[blockIdx.x];
const uint32_t partition_hist_end = column_hist_offsets_full[blockIdx.x + 1];
const uint32_t num_items_in_partition = (partition_hist_end - partition_hist_start);
const unsigned int thread_idx = threadIdx.x + threadIdx.y * blockDim.x;
int32_t* shared_hist_packed = global_hist_buffer + (blockIdx.y * num_total_bin + partition_hist_start);
for (unsigned int i = thread_idx; i < num_items_in_partition; i += num_threads_per_block) {
shared_hist_packed[i] = 0;
}
__syncthreads();
const unsigned int threadIdx_y = threadIdx.y;
const unsigned int blockIdx_y = blockIdx.y;
const data_size_t block_start = (blockIdx_y * blockDim.y) * num_data_per_thread;
const data_size_t* data_indices_ref_this_block = data_indices_ref + block_start;
data_size_t block_num_data = max(0, min(num_data_in_smaller_leaf - block_start, num_data_per_thread * static_cast<data_size_t>(blockDim.y)));
const data_size_t num_iteration_total = (block_num_data + blockDim.y - 1) / blockDim.y;
const data_size_t remainder = block_num_data % blockDim.y;
const data_size_t num_iteration_this = remainder == 0 ? num_iteration_total : num_iteration_total - static_cast<data_size_t>(threadIdx_y >= remainder);
data_size_t inner_data_index = static_cast<data_size_t>(threadIdx_y);
for (data_size_t i = 0; i < num_iteration_this; ++i) {
const data_size_t data_index = data_indices_ref_this_block[inner_data_index];
const DATA_PTR_TYPE row_start = block_row_ptr[data_index];
const DATA_PTR_TYPE row_end = block_row_ptr[data_index + 1];
const DATA_PTR_TYPE row_size = row_end - row_start;
if (threadIdx.x < row_size) {
const int32_t grad_and_hess = cuda_gradients_and_hessians[data_index];
const uint32_t bin = static_cast<uint32_t>(data_ptr[row_start + threadIdx.x]);
int32_t* pos_ptr = shared_hist_packed + bin;
atomicAdd_block(pos_ptr, grad_and_hess);
}
inner_data_index += blockDim.y;
}
__syncthreads();
if (USE_16BIT_HIST) {
int32_t* feature_histogram_ptr = reinterpret_cast<int32_t*>(smaller_leaf_splits->hist_in_leaf) + partition_hist_start;
for (unsigned int i = thread_idx; i < num_items_in_partition; i += num_threads_per_block) {
const int32_t packed_grad_hess = shared_hist_packed[i];
atomicAdd_system(feature_histogram_ptr + i, packed_grad_hess);
}
} else {
atomic_add_long_t* feature_histogram_ptr = reinterpret_cast<atomic_add_long_t*>(smaller_leaf_splits->hist_in_leaf) + partition_hist_start;
for (unsigned int i = thread_idx; i < num_items_in_partition; i += num_threads_per_block) {
const int32_t packed_grad_hess = shared_hist_packed[i];
const int64_t packed_grad_hess_int64 = (static_cast<int64_t>(static_cast<int16_t>(packed_grad_hess >> 16)) << 32) | (static_cast<int64_t>(packed_grad_hess & 0x0000ffff));
atomicAdd_system(feature_histogram_ptr + i, (atomic_add_long_t)(packed_grad_hess_int64));
}
}
}
void CUDAHistogramConstructor::LaunchConstructHistogramKernel(
const CUDALeafSplitsStruct* cuda_smaller_leaf_splits,
const data_size_t num_data_in_smaller_leaf) {
const data_size_t num_data_in_smaller_leaf,
const uint8_t num_bits_in_histogram_bins) {
if (cuda_row_data_->shared_hist_size() == DP_SHARED_HIST_SIZE && gpu_use_dp_) {
LaunchConstructHistogramKernelInner<double, DP_SHARED_HIST_SIZE>(cuda_smaller_leaf_splits, num_data_in_smaller_leaf);
LaunchConstructHistogramKernelInner<double, DP_SHARED_HIST_SIZE>(cuda_smaller_leaf_splits, num_data_in_smaller_leaf, num_bits_in_histogram_bins);
} else if (cuda_row_data_->shared_hist_size() == SP_SHARED_HIST_SIZE && !gpu_use_dp_) {
LaunchConstructHistogramKernelInner<float, SP_SHARED_HIST_SIZE>(cuda_smaller_leaf_splits, num_data_in_smaller_leaf);
LaunchConstructHistogramKernelInner<float, SP_SHARED_HIST_SIZE>(cuda_smaller_leaf_splits, num_data_in_smaller_leaf, num_bits_in_histogram_bins);
} else {
Log::Fatal("Unknown shared histogram size %d", cuda_row_data_->shared_hist_size());
}
@ -261,13 +526,14 @@ void CUDAHistogramConstructor::LaunchConstructHistogramKernel(
template <typename HIST_TYPE, size_t SHARED_HIST_SIZE>
void CUDAHistogramConstructor::LaunchConstructHistogramKernelInner(
const CUDALeafSplitsStruct* cuda_smaller_leaf_splits,
const data_size_t num_data_in_smaller_leaf) {
const data_size_t num_data_in_smaller_leaf,
const uint8_t num_bits_in_histogram_bins) {
if (cuda_row_data_->bit_type() == 8) {
LaunchConstructHistogramKernelInner0<HIST_TYPE, SHARED_HIST_SIZE, uint8_t>(cuda_smaller_leaf_splits, num_data_in_smaller_leaf);
LaunchConstructHistogramKernelInner0<HIST_TYPE, SHARED_HIST_SIZE, uint8_t>(cuda_smaller_leaf_splits, num_data_in_smaller_leaf, num_bits_in_histogram_bins);
} else if (cuda_row_data_->bit_type() == 16) {
LaunchConstructHistogramKernelInner0<HIST_TYPE, SHARED_HIST_SIZE, uint16_t>(cuda_smaller_leaf_splits, num_data_in_smaller_leaf);
LaunchConstructHistogramKernelInner0<HIST_TYPE, SHARED_HIST_SIZE, uint16_t>(cuda_smaller_leaf_splits, num_data_in_smaller_leaf, num_bits_in_histogram_bins);
} else if (cuda_row_data_->bit_type() == 32) {
LaunchConstructHistogramKernelInner0<HIST_TYPE, SHARED_HIST_SIZE, uint32_t>(cuda_smaller_leaf_splits, num_data_in_smaller_leaf);
LaunchConstructHistogramKernelInner0<HIST_TYPE, SHARED_HIST_SIZE, uint32_t>(cuda_smaller_leaf_splits, num_data_in_smaller_leaf, num_bits_in_histogram_bins);
} else {
Log::Fatal("Unknown bit_type = %d", cuda_row_data_->bit_type());
}
@ -276,16 +542,17 @@ void CUDAHistogramConstructor::LaunchConstructHistogramKernelInner(
template <typename HIST_TYPE, size_t SHARED_HIST_SIZE, typename BIN_TYPE>
void CUDAHistogramConstructor::LaunchConstructHistogramKernelInner0(
const CUDALeafSplitsStruct* cuda_smaller_leaf_splits,
const data_size_t num_data_in_smaller_leaf) {
const data_size_t num_data_in_smaller_leaf,
const uint8_t num_bits_in_histogram_bins) {
if (cuda_row_data_->row_ptr_bit_type() == 16) {
LaunchConstructHistogramKernelInner1<HIST_TYPE, SHARED_HIST_SIZE, BIN_TYPE, uint16_t>(cuda_smaller_leaf_splits, num_data_in_smaller_leaf);
LaunchConstructHistogramKernelInner1<HIST_TYPE, SHARED_HIST_SIZE, BIN_TYPE, uint16_t>(cuda_smaller_leaf_splits, num_data_in_smaller_leaf, num_bits_in_histogram_bins);
} else if (cuda_row_data_->row_ptr_bit_type() == 32) {
LaunchConstructHistogramKernelInner1<HIST_TYPE, SHARED_HIST_SIZE, BIN_TYPE, uint32_t>(cuda_smaller_leaf_splits, num_data_in_smaller_leaf);
LaunchConstructHistogramKernelInner1<HIST_TYPE, SHARED_HIST_SIZE, BIN_TYPE, uint32_t>(cuda_smaller_leaf_splits, num_data_in_smaller_leaf, num_bits_in_histogram_bins);
} else if (cuda_row_data_->row_ptr_bit_type() == 64) {
LaunchConstructHistogramKernelInner1<HIST_TYPE, SHARED_HIST_SIZE, BIN_TYPE, uint64_t>(cuda_smaller_leaf_splits, num_data_in_smaller_leaf);
LaunchConstructHistogramKernelInner1<HIST_TYPE, SHARED_HIST_SIZE, BIN_TYPE, uint64_t>(cuda_smaller_leaf_splits, num_data_in_smaller_leaf, num_bits_in_histogram_bins);
} else {
if (!cuda_row_data_->is_sparse()) {
LaunchConstructHistogramKernelInner1<HIST_TYPE, SHARED_HIST_SIZE, BIN_TYPE, uint16_t>(cuda_smaller_leaf_splits, num_data_in_smaller_leaf);
LaunchConstructHistogramKernelInner1<HIST_TYPE, SHARED_HIST_SIZE, BIN_TYPE, uint16_t>(cuda_smaller_leaf_splits, num_data_in_smaller_leaf, num_bits_in_histogram_bins);
} else {
Log::Fatal("Unknown row_ptr_bit_type = %d", cuda_row_data_->row_ptr_bit_type());
}
@ -295,18 +562,20 @@ void CUDAHistogramConstructor::LaunchConstructHistogramKernelInner0(
template <typename HIST_TYPE, size_t SHARED_HIST_SIZE, typename BIN_TYPE, typename PTR_TYPE>
void CUDAHistogramConstructor::LaunchConstructHistogramKernelInner1(
const CUDALeafSplitsStruct* cuda_smaller_leaf_splits,
const data_size_t num_data_in_smaller_leaf) {
const data_size_t num_data_in_smaller_leaf,
const uint8_t num_bits_in_histogram_bins) {
if (cuda_row_data_->NumLargeBinPartition() == 0) {
LaunchConstructHistogramKernelInner2<HIST_TYPE, SHARED_HIST_SIZE, BIN_TYPE, PTR_TYPE, false>(cuda_smaller_leaf_splits, num_data_in_smaller_leaf);
LaunchConstructHistogramKernelInner2<HIST_TYPE, SHARED_HIST_SIZE, BIN_TYPE, PTR_TYPE, false>(cuda_smaller_leaf_splits, num_data_in_smaller_leaf, num_bits_in_histogram_bins);
} else {
LaunchConstructHistogramKernelInner2<HIST_TYPE, SHARED_HIST_SIZE, BIN_TYPE, PTR_TYPE, true>(cuda_smaller_leaf_splits, num_data_in_smaller_leaf);
LaunchConstructHistogramKernelInner2<HIST_TYPE, SHARED_HIST_SIZE, BIN_TYPE, PTR_TYPE, true>(cuda_smaller_leaf_splits, num_data_in_smaller_leaf, num_bits_in_histogram_bins);
}
}
template <typename HIST_TYPE, size_t SHARED_HIST_SIZE, typename BIN_TYPE, typename PTR_TYPE, bool USE_GLOBAL_MEM_BUFFER>
void CUDAHistogramConstructor::LaunchConstructHistogramKernelInner2(
const CUDALeafSplitsStruct* cuda_smaller_leaf_splits,
const data_size_t num_data_in_smaller_leaf) {
const data_size_t num_data_in_smaller_leaf,
const uint8_t num_bits_in_histogram_bins) {
int grid_dim_x = 0;
int grid_dim_y = 0;
int block_dim_x = 0;
@ -314,47 +583,139 @@ void CUDAHistogramConstructor::LaunchConstructHistogramKernelInner2(
CalcConstructHistogramKernelDim(&grid_dim_x, &grid_dim_y, &block_dim_x, &block_dim_y, num_data_in_smaller_leaf);
dim3 grid_dim(grid_dim_x, grid_dim_y);
dim3 block_dim(block_dim_x, block_dim_y);
if (!USE_GLOBAL_MEM_BUFFER) {
if (cuda_row_data_->is_sparse()) {
CUDAConstructHistogramSparseKernel<BIN_TYPE, PTR_TYPE, HIST_TYPE, SHARED_HIST_SIZE><<<grid_dim, block_dim, 0, cuda_stream_>>>(
cuda_smaller_leaf_splits,
cuda_gradients_, cuda_hessians_,
cuda_row_data_->GetBin<BIN_TYPE>(),
cuda_row_data_->GetRowPtr<PTR_TYPE>(),
cuda_row_data_->GetPartitionPtr<PTR_TYPE>(),
cuda_row_data_->cuda_partition_hist_offsets(),
num_data_);
if (use_quantized_grad_) {
if (USE_GLOBAL_MEM_BUFFER) {
if (cuda_row_data_->is_sparse()) {
if (num_bits_in_histogram_bins <= 16) {
CUDAConstructDiscretizedHistogramSparseKernel_GlobalMemory<BIN_TYPE, PTR_TYPE, SHARED_HIST_SIZE, true><<<grid_dim, block_dim, 0, cuda_stream_>>>(
cuda_smaller_leaf_splits,
reinterpret_cast<const int32_t*>(cuda_gradients_),
cuda_row_data_->GetBin<BIN_TYPE>(),
cuda_row_data_->GetRowPtr<PTR_TYPE>(),
cuda_row_data_->GetPartitionPtr<PTR_TYPE>(),
cuda_row_data_->cuda_partition_hist_offsets(),
num_data_,
reinterpret_cast<int32_t*>(cuda_hist_buffer_.RawData()));
} else {
CUDAConstructDiscretizedHistogramSparseKernel_GlobalMemory<BIN_TYPE, PTR_TYPE, SHARED_HIST_SIZE, false><<<grid_dim, block_dim, 0, cuda_stream_>>>(
cuda_smaller_leaf_splits,
reinterpret_cast<const int32_t*>(cuda_gradients_),
cuda_row_data_->GetBin<BIN_TYPE>(),
cuda_row_data_->GetRowPtr<PTR_TYPE>(),
cuda_row_data_->GetPartitionPtr<PTR_TYPE>(),
cuda_row_data_->cuda_partition_hist_offsets(),
num_data_,
reinterpret_cast<int32_t*>(cuda_hist_buffer_.RawData()));
}
} else {
if (num_bits_in_histogram_bins <= 16) {
CUDAConstructDiscretizedHistogramDenseKernel_GlobalMemory<BIN_TYPE, SHARED_HIST_SIZE, true><<<grid_dim, block_dim, 0, cuda_stream_>>>(
cuda_smaller_leaf_splits,
reinterpret_cast<const int32_t*>(cuda_gradients_),
cuda_row_data_->GetBin<BIN_TYPE>(),
cuda_row_data_->cuda_column_hist_offsets(),
cuda_row_data_->cuda_partition_hist_offsets(),
cuda_row_data_->cuda_feature_partition_column_index_offsets(),
num_data_,
reinterpret_cast<int32_t*>(cuda_hist_buffer_.RawData()));
} else {
CUDAConstructDiscretizedHistogramDenseKernel_GlobalMemory<BIN_TYPE, SHARED_HIST_SIZE, false><<<grid_dim, block_dim, 0, cuda_stream_>>>(
cuda_smaller_leaf_splits,
reinterpret_cast<const int32_t*>(cuda_gradients_),
cuda_row_data_->GetBin<BIN_TYPE>(),
cuda_row_data_->cuda_column_hist_offsets(),
cuda_row_data_->cuda_partition_hist_offsets(),
cuda_row_data_->cuda_feature_partition_column_index_offsets(),
num_data_,
reinterpret_cast<int32_t*>(cuda_hist_buffer_.RawData()));
}
}
} else {
CUDAConstructHistogramDenseKernel<BIN_TYPE, HIST_TYPE, SHARED_HIST_SIZE><<<grid_dim, block_dim, 0, cuda_stream_>>>(
cuda_smaller_leaf_splits,
cuda_gradients_, cuda_hessians_,
cuda_row_data_->GetBin<BIN_TYPE>(),
cuda_row_data_->cuda_column_hist_offsets(),
cuda_row_data_->cuda_partition_hist_offsets(),
cuda_row_data_->cuda_feature_partition_column_index_offsets(),
num_data_);
if (cuda_row_data_->is_sparse()) {
if (num_bits_in_histogram_bins <= 16) {
CUDAConstructDiscretizedHistogramSparseKernel<BIN_TYPE, PTR_TYPE, SHARED_HIST_SIZE, true><<<grid_dim, block_dim, 0, cuda_stream_>>>(
cuda_smaller_leaf_splits,
reinterpret_cast<const int32_t*>(cuda_gradients_),
cuda_row_data_->GetBin<BIN_TYPE>(),
cuda_row_data_->GetRowPtr<PTR_TYPE>(),
cuda_row_data_->GetPartitionPtr<PTR_TYPE>(),
cuda_row_data_->cuda_partition_hist_offsets(),
num_data_);
} else {
CUDAConstructDiscretizedHistogramSparseKernel<BIN_TYPE, PTR_TYPE, SHARED_HIST_SIZE, false><<<grid_dim, block_dim, 0, cuda_stream_>>>(
cuda_smaller_leaf_splits,
reinterpret_cast<const int32_t*>(cuda_gradients_),
cuda_row_data_->GetBin<BIN_TYPE>(),
cuda_row_data_->GetRowPtr<PTR_TYPE>(),
cuda_row_data_->GetPartitionPtr<PTR_TYPE>(),
cuda_row_data_->cuda_partition_hist_offsets(),
num_data_);
}
} else {
if (num_bits_in_histogram_bins <= 16) {
CUDAConstructDiscretizedHistogramDenseKernel<BIN_TYPE, SHARED_HIST_SIZE, true><<<grid_dim, block_dim, 0, cuda_stream_>>>(
cuda_smaller_leaf_splits,
reinterpret_cast<const int32_t*>(cuda_gradients_),
cuda_row_data_->GetBin<BIN_TYPE>(),
cuda_row_data_->cuda_column_hist_offsets(),
cuda_row_data_->cuda_partition_hist_offsets(),
cuda_row_data_->cuda_feature_partition_column_index_offsets(),
num_data_);
} else {
CUDAConstructDiscretizedHistogramDenseKernel<BIN_TYPE, SHARED_HIST_SIZE, false><<<grid_dim, block_dim, 0, cuda_stream_>>>(
cuda_smaller_leaf_splits,
reinterpret_cast<const int32_t*>(cuda_gradients_),
cuda_row_data_->GetBin<BIN_TYPE>(),
cuda_row_data_->cuda_column_hist_offsets(),
cuda_row_data_->cuda_partition_hist_offsets(),
cuda_row_data_->cuda_feature_partition_column_index_offsets(),
num_data_);
}
}
}
} else {
if (cuda_row_data_->is_sparse()) {
CUDAConstructHistogramSparseKernel_GlobalMemory<BIN_TYPE, PTR_TYPE><<<grid_dim, block_dim, 0, cuda_stream_>>>(
cuda_smaller_leaf_splits,
cuda_gradients_, cuda_hessians_,
cuda_row_data_->GetBin<BIN_TYPE>(),
cuda_row_data_->GetRowPtr<PTR_TYPE>(),
cuda_row_data_->GetPartitionPtr<PTR_TYPE>(),
cuda_row_data_->cuda_partition_hist_offsets(),
num_data_,
cuda_hist_buffer_);
if (!USE_GLOBAL_MEM_BUFFER) {
if (cuda_row_data_->is_sparse()) {
CUDAConstructHistogramSparseKernel<BIN_TYPE, PTR_TYPE, HIST_TYPE, SHARED_HIST_SIZE><<<grid_dim, block_dim, 0, cuda_stream_>>>(
cuda_smaller_leaf_splits,
cuda_gradients_, cuda_hessians_,
cuda_row_data_->GetBin<BIN_TYPE>(),
cuda_row_data_->GetRowPtr<PTR_TYPE>(),
cuda_row_data_->GetPartitionPtr<PTR_TYPE>(),
cuda_row_data_->cuda_partition_hist_offsets(),
num_data_);
} else {
CUDAConstructHistogramDenseKernel<BIN_TYPE, HIST_TYPE, SHARED_HIST_SIZE><<<grid_dim, block_dim, 0, cuda_stream_>>>(
cuda_smaller_leaf_splits,
cuda_gradients_, cuda_hessians_,
cuda_row_data_->GetBin<BIN_TYPE>(),
cuda_row_data_->cuda_column_hist_offsets(),
cuda_row_data_->cuda_partition_hist_offsets(),
cuda_row_data_->cuda_feature_partition_column_index_offsets(),
num_data_);
}
} else {
CUDAConstructHistogramDenseKernel_GlobalMemory<BIN_TYPE><<<grid_dim, block_dim, 0, cuda_stream_>>>(
cuda_smaller_leaf_splits,
cuda_gradients_, cuda_hessians_,
cuda_row_data_->GetBin<BIN_TYPE>(),
cuda_row_data_->cuda_column_hist_offsets(),
cuda_row_data_->cuda_partition_hist_offsets(),
cuda_row_data_->cuda_feature_partition_column_index_offsets(),
num_data_,
cuda_hist_buffer_);
if (cuda_row_data_->is_sparse()) {
CUDAConstructHistogramSparseKernel_GlobalMemory<BIN_TYPE, HIST_TYPE, PTR_TYPE><<<grid_dim, block_dim, 0, cuda_stream_>>>(
cuda_smaller_leaf_splits,
cuda_gradients_, cuda_hessians_,
cuda_row_data_->GetBin<BIN_TYPE>(),
cuda_row_data_->GetRowPtr<PTR_TYPE>(),
cuda_row_data_->GetPartitionPtr<PTR_TYPE>(),
cuda_row_data_->cuda_partition_hist_offsets(),
num_data_,
reinterpret_cast<HIST_TYPE*>(cuda_hist_buffer_.RawData()));
} else {
CUDAConstructHistogramDenseKernel_GlobalMemory<BIN_TYPE, HIST_TYPE><<<grid_dim, block_dim, 0, cuda_stream_>>>(
cuda_smaller_leaf_splits,
cuda_gradients_, cuda_hessians_,
cuda_row_data_->GetBin<BIN_TYPE>(),
cuda_row_data_->cuda_column_hist_offsets(),
cuda_row_data_->cuda_partition_hist_offsets(),
cuda_row_data_->cuda_feature_partition_column_index_offsets(),
num_data_,
reinterpret_cast<HIST_TYPE*>(cuda_hist_buffer_.RawData()));
}
}
}
}
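LaunchConstructHistogramKernelInner2 ends a chain of dispatch functions that turn runtime properties of the row data into template instantiations; with this commit the final level also branches on use_quantized_grad_ and on whether 16-bit packed histogram bins are sufficient. The host-only sketch below (illustrative, not the actual dispatcher) condenses that decision tree:

#include <cstdint>
#include <cstdio>

// Print which branch of LaunchConstructHistogramKernelInner2 a configuration
// takes; mirrors the if/else structure above without launching anything.
void DescribeKernelChoice(bool use_quantized_grad, bool use_global_mem_buffer,
                          bool is_sparse, uint8_t num_bits_in_histogram_bins) {
  const char* layout = is_sparse ? "sparse" : "dense";
  const char* buffer = use_global_mem_buffer ? "global-memory buffer" : "shared-memory buffer";
  if (!use_quantized_grad) {
    std::printf("float/double %s histogram, %s\n", layout, buffer);
  } else if (num_bits_in_histogram_bins <= 16) {
    std::printf("discretized %s histogram, %s, 16-bit packed bins\n", layout, buffer);
  } else {
    std::printf("discretized %s histogram, %s, 32-bit packed bins\n", layout, buffer);
  }
}

int main() {
  DescribeKernelChoice(true, false, true, 16);   // quantized, shared memory, sparse
  DescribeKernelChoice(true, true, false, 32);   // quantized, global buffer, dense
  DescribeKernelChoice(false, false, false, 0);  // ordinary float/double histograms
  return 0;
}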
@ -403,28 +764,195 @@ __global__ void FixHistogramKernel(
}
}
template <bool SMALLER_USE_16BIT_HIST, bool LARGER_USE_16BIT_HIST, bool PARENT_USE_16BIT_HIST>
__global__ void SubtractHistogramDiscretizedKernel(
const int num_total_bin,
const CUDALeafSplitsStruct* cuda_smaller_leaf_splits,
const CUDALeafSplitsStruct* cuda_larger_leaf_splits,
hist_t* num_bit_change_buffer) {
const unsigned int global_thread_index = threadIdx.x + blockIdx.x * blockDim.x;
const int cuda_larger_leaf_index_ref = cuda_larger_leaf_splits->leaf_index;
if (cuda_larger_leaf_index_ref >= 0) {
if (PARENT_USE_16BIT_HIST) {
const int32_t* smaller_leaf_hist = reinterpret_cast<const int32_t*>(cuda_smaller_leaf_splits->hist_in_leaf);
int32_t* larger_leaf_hist = reinterpret_cast<int32_t*>(cuda_larger_leaf_splits->hist_in_leaf);
if (global_thread_index < num_total_bin) {
larger_leaf_hist[global_thread_index] -= smaller_leaf_hist[global_thread_index];
}
} else if (LARGER_USE_16BIT_HIST) {
int32_t* buffer = reinterpret_cast<int32_t*>(num_bit_change_buffer);
const int32_t* smaller_leaf_hist = reinterpret_cast<const int32_t*>(cuda_smaller_leaf_splits->hist_in_leaf);
int64_t* larger_leaf_hist = reinterpret_cast<int64_t*>(cuda_larger_leaf_splits->hist_in_leaf);
if (global_thread_index < num_total_bin) {
const int64_t parent_hist_item = larger_leaf_hist[global_thread_index];
const int32_t smaller_hist_item = smaller_leaf_hist[global_thread_index];
const int64_t smaller_hist_item_int64 = (static_cast<int64_t>(static_cast<int16_t>(smaller_hist_item >> 16)) << 32) |
static_cast<int64_t>(smaller_hist_item & 0x0000ffff);
const int64_t larger_hist_item = parent_hist_item - smaller_hist_item_int64;
buffer[global_thread_index] = static_cast<int32_t>(static_cast<int16_t>(larger_hist_item >> 32) << 16) |
static_cast<int32_t>(larger_hist_item & 0x000000000000ffff);
}
} else if (SMALLER_USE_16BIT_HIST) {
const int32_t* smaller_leaf_hist = reinterpret_cast<const int32_t*>(cuda_smaller_leaf_splits->hist_in_leaf);
int64_t* larger_leaf_hist = reinterpret_cast<int64_t*>(cuda_larger_leaf_splits->hist_in_leaf);
if (global_thread_index < num_total_bin) {
const int64_t parent_hist_item = larger_leaf_hist[global_thread_index];
const int32_t smaller_hist_item = smaller_leaf_hist[global_thread_index];
const int64_t smaller_hist_item_int64 = (static_cast<int64_t>(static_cast<int16_t>(smaller_hist_item >> 16)) << 32) |
static_cast<int64_t>(smaller_hist_item & 0x0000ffff);
const int64_t larger_hist_item = parent_hist_item - smaller_hist_item_int64;
larger_leaf_hist[global_thread_index] = larger_hist_item;
}
} else {
const int64_t* smaller_leaf_hist = reinterpret_cast<const int64_t*>(cuda_smaller_leaf_splits->hist_in_leaf);
int64_t* larger_leaf_hist = reinterpret_cast<int64_t*>(cuda_larger_leaf_splits->hist_in_leaf);
if (global_thread_index < num_total_bin) {
larger_leaf_hist[global_thread_index] -= smaller_leaf_hist[global_thread_index];
}
}
}
}
__global__ void CopyChangedNumBitHistogram(
const int num_total_bin,
const CUDALeafSplitsStruct* cuda_larger_leaf_splits,
hist_t* num_bit_change_buffer) {
int32_t* hist_dst = reinterpret_cast<int32_t*>(cuda_larger_leaf_splits->hist_in_leaf);
const int32_t* hist_src = reinterpret_cast<const int32_t*>(num_bit_change_buffer);
const unsigned int global_thread_index = threadIdx.x + blockIdx.x * blockDim.x;
if (global_thread_index < static_cast<unsigned int>(num_total_bin)) {
hist_dst[global_thread_index] = hist_src[global_thread_index];
}
}
template <bool USE_16BIT_HIST>
__global__ void FixHistogramDiscretizedKernel(
const uint32_t* cuda_feature_num_bins,
const uint32_t* cuda_feature_hist_offsets,
const uint32_t* cuda_feature_most_freq_bins,
const int* cuda_need_fix_histogram_features,
const uint32_t* cuda_need_fix_histogram_features_num_bin_aligned,
const CUDALeafSplitsStruct* cuda_smaller_leaf_splits) {
__shared__ int64_t shared_mem_buffer[32];
const unsigned int blockIdx_x = blockIdx.x;
const int feature_index = cuda_need_fix_histogram_features[blockIdx_x];
const uint32_t num_bin_aligned = cuda_need_fix_histogram_features_num_bin_aligned[blockIdx_x];
const uint32_t feature_hist_offset = cuda_feature_hist_offsets[feature_index];
const uint32_t most_freq_bin = cuda_feature_most_freq_bins[feature_index];
if (USE_16BIT_HIST) {
const int64_t leaf_sum_gradients_hessians_int64 = cuda_smaller_leaf_splits->sum_of_gradients_hessians;
const int32_t leaf_sum_gradients_hessians =
(static_cast<int32_t>(leaf_sum_gradients_hessians_int64 >> 32) << 16) | static_cast<int32_t>(leaf_sum_gradients_hessians_int64 & 0x000000000000ffff);
int32_t* feature_hist = reinterpret_cast<int32_t*>(cuda_smaller_leaf_splits->hist_in_leaf) + feature_hist_offset;
const unsigned int threadIdx_x = threadIdx.x;
const uint32_t num_bin = cuda_feature_num_bins[feature_index];
const int32_t bin_gradient_hessian = (threadIdx_x < num_bin && threadIdx_x != most_freq_bin) ? feature_hist[threadIdx_x] : 0;
const int32_t sum_gradient_hessian = ShuffleReduceSum<int32_t>(
bin_gradient_hessian,
reinterpret_cast<int32_t*>(shared_mem_buffer),
num_bin_aligned);
if (threadIdx_x == 0) {
feature_hist[most_freq_bin] = leaf_sum_gradients_hessians - sum_gradient_hessian;
}
} else {
const int64_t leaf_sum_gradients_hessians = cuda_smaller_leaf_splits->sum_of_gradients_hessians;
int64_t* feature_hist = reinterpret_cast<int64_t*>(cuda_smaller_leaf_splits->hist_in_leaf) + feature_hist_offset;
const unsigned int threadIdx_x = threadIdx.x;
const uint32_t num_bin = cuda_feature_num_bins[feature_index];
const int64_t bin_gradient_hessian = (threadIdx_x < num_bin && threadIdx_x != most_freq_bin) ? feature_hist[threadIdx_x] : 0;
const int64_t sum_gradient_hessian = ShuffleReduceSum<int64_t>(bin_gradient_hessian, shared_mem_buffer, num_bin_aligned);
if (threadIdx_x == 0) {
feature_hist[most_freq_bin] = leaf_sum_gradients_hessians - sum_gradient_hessian;
}
}
}
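FixHistogramDiscretizedKernel restores the entry of each feature's most frequent bin, which the construction kernels skip: the missing packed (gradient, hessian) count equals the leaf total minus the sum of all other bins, and in the 16-bit branch the int64 leaf total is first narrowed to the 32-bit packed layout. A host-side worked example of the 64-bit branch (illustrative only):

#include <cstdint>
#include <vector>

// Pack an integer gradient/hessian pair in the 64-bit leaf-histogram layout.
int64_t PackGradHess32(int32_t grad, int32_t hess) {
  return (static_cast<int64_t>(grad) << 32) |
         static_cast<int64_t>(static_cast<uint32_t>(hess));
}

int main() {
  const uint32_t most_freq_bin = 2;  // skipped during histogram construction
  std::vector<int64_t> feature_hist = {
      PackGradHess32(3, 5), PackGradHess32(-1, 2), 0, PackGradHess32(2, 1)};
  const int64_t leaf_total = PackGradHess32(6, 11);  // sum over all rows in the leaf
  int64_t sum_of_other_bins = 0;
  for (uint32_t bin = 0; bin < feature_hist.size(); ++bin) {
    if (bin != most_freq_bin) sum_of_other_bins += feature_hist[bin];
  }
  feature_hist[most_freq_bin] = leaf_total - sum_of_other_bins;  // gradient 2, hessian 3
  return feature_hist[most_freq_bin] == PackGradHess32(2, 3) ? 0 : 1;
}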
void CUDAHistogramConstructor::LaunchSubtractHistogramKernel(
const CUDALeafSplitsStruct* cuda_smaller_leaf_splits,
const CUDALeafSplitsStruct* cuda_larger_leaf_splits) {
const int num_subtract_threads = 2 * num_total_bin_;
const int num_subtract_blocks = (num_subtract_threads + SUBTRACT_BLOCK_SIZE - 1) / SUBTRACT_BLOCK_SIZE;
global_timer.Start("CUDAHistogramConstructor::FixHistogramKernel");
if (need_fix_histogram_features_.size() > 0) {
FixHistogramKernel<<<need_fix_histogram_features_.size(), FIX_HISTOGRAM_BLOCK_SIZE, 0, cuda_stream_>>>(
cuda_feature_num_bins_,
cuda_feature_hist_offsets_,
cuda_feature_most_freq_bins_,
cuda_need_fix_histogram_features_,
cuda_need_fix_histogram_features_num_bin_aligned_,
cuda_smaller_leaf_splits);
}
global_timer.Stop("CUDAHistogramConstructor::FixHistogramKernel");
global_timer.Start("CUDAHistogramConstructor::SubtractHistogramKernel");
SubtractHistogramKernel<<<num_subtract_blocks, SUBTRACT_BLOCK_SIZE, 0, cuda_stream_>>>(
num_total_bin_,
cuda_smaller_leaf_splits,
cuda_larger_leaf_splits);
global_timer.Stop("CUDAHistogramConstructor::SubtractHistogramKernel");
const CUDALeafSplitsStruct* cuda_larger_leaf_splits,
const bool use_discretized_grad,
const uint8_t parent_num_bits_in_histogram_bins,
const uint8_t smaller_num_bits_in_histogram_bins,
const uint8_t larger_num_bits_in_histogram_bins) {
if (!use_discretized_grad) {
const int num_subtract_threads = 2 * num_total_bin_;
const int num_subtract_blocks = (num_subtract_threads + SUBTRACT_BLOCK_SIZE - 1) / SUBTRACT_BLOCK_SIZE;
global_timer.Start("CUDAHistogramConstructor::FixHistogramKernel");
if (need_fix_histogram_features_.size() > 0) {
FixHistogramKernel<<<need_fix_histogram_features_.size(), FIX_HISTOGRAM_BLOCK_SIZE, 0, cuda_stream_>>>(
cuda_feature_num_bins_.RawData(),
cuda_feature_hist_offsets_.RawData(),
cuda_feature_most_freq_bins_.RawData(),
cuda_need_fix_histogram_features_.RawData(),
cuda_need_fix_histogram_features_num_bin_aligned_.RawData(),
cuda_smaller_leaf_splits);
}
global_timer.Stop("CUDAHistogramConstructor::FixHistogramKernel");
global_timer.Start("CUDAHistogramConstructor::SubtractHistogramKernel");
SubtractHistogramKernel<<<num_subtract_blocks, SUBTRACT_BLOCK_SIZE, 0, cuda_stream_>>>(
num_total_bin_,
cuda_smaller_leaf_splits,
cuda_larger_leaf_splits);
global_timer.Stop("CUDAHistogramConstructor::SubtractHistogramKernel");
} else {
const int num_subtract_threads = num_total_bin_;
const int num_subtract_blocks = (num_subtract_threads + SUBTRACT_BLOCK_SIZE - 1) / SUBTRACT_BLOCK_SIZE;
global_timer.Start("CUDAHistogramConstructor::FixHistogramDiscretizedKernel");
if (need_fix_histogram_features_.size() > 0) {
if (smaller_num_bits_in_histogram_bins <= 16) {
FixHistogramDiscretizedKernel<true><<<need_fix_histogram_features_.size(), FIX_HISTOGRAM_BLOCK_SIZE, 0, cuda_stream_>>>(
cuda_feature_num_bins_.RawData(),
cuda_feature_hist_offsets_.RawData(),
cuda_feature_most_freq_bins_.RawData(),
cuda_need_fix_histogram_features_.RawData(),
cuda_need_fix_histogram_features_num_bin_aligned_.RawData(),
cuda_smaller_leaf_splits);
} else {
FixHistogramDiscretizedKernel<false><<<need_fix_histogram_features_.size(), FIX_HISTOGRAM_BLOCK_SIZE, 0, cuda_stream_>>>(
cuda_feature_num_bins_.RawData(),
cuda_feature_hist_offsets_.RawData(),
cuda_feature_most_freq_bins_.RawData(),
cuda_need_fix_histogram_features_.RawData(),
cuda_need_fix_histogram_features_num_bin_aligned_.RawData(),
cuda_smaller_leaf_splits);
}
}
global_timer.Stop("CUDAHistogramConstructor::FixHistogramDiscretizedKernel");
global_timer.Start("CUDAHistogramConstructor::SubtractHistogramDiscretizedKernel");
if (parent_num_bits_in_histogram_bins <= 16) {
CHECK_LE(smaller_num_bits_in_histogram_bins, 16);
CHECK_LE(larger_num_bits_in_histogram_bins, 16);
SubtractHistogramDiscretizedKernel<true, true, true><<<num_subtract_blocks, SUBTRACT_BLOCK_SIZE, 0, cuda_stream_>>>(
num_total_bin_,
cuda_smaller_leaf_splits,
cuda_larger_leaf_splits,
hist_buffer_for_num_bit_change_.RawData());
} else if (larger_num_bits_in_histogram_bins <= 16) {
CHECK_LE(smaller_num_bits_in_histogram_bins, 16);
SubtractHistogramDiscretizedKernel<true, true, false><<<num_subtract_blocks, SUBTRACT_BLOCK_SIZE, 0, cuda_stream_>>>(
num_total_bin_,
cuda_smaller_leaf_splits,
cuda_larger_leaf_splits,
hist_buffer_for_num_bit_change_.RawData());
CopyChangedNumBitHistogram<<<num_subtract_blocks, SUBTRACT_BLOCK_SIZE, 0, cuda_stream_>>>(
num_total_bin_,
cuda_larger_leaf_splits,
hist_buffer_for_num_bit_change_.RawData());
} else if (smaller_num_bits_in_histogram_bins <= 16) {
SubtractHistogramDiscretizedKernel<true, false, false><<<num_subtract_blocks, SUBTRACT_BLOCK_SIZE, 0, cuda_stream_>>>(
num_total_bin_,
cuda_smaller_leaf_splits,
cuda_larger_leaf_splits,
hist_buffer_for_num_bit_change_.RawData());
} else {
SubtractHistogramDiscretizedKernel<false, false, false><<<num_subtract_blocks, SUBTRACT_BLOCK_SIZE, 0, cuda_stream_>>>(
num_total_bin_,
cuda_smaller_leaf_splits,
cuda_larger_leaf_splits,
hist_buffer_for_num_bit_change_.RawData());
}
global_timer.Stop("CUDAHistogramConstructor::SubtractHistogramDiscretizedKernel");
}
}
} // namespace LightGBM
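In the discretized subtraction path above, the larger leaf's histogram may need fewer bits per bin than its parent: when the parent uses 64-bit packed bins but the larger leaf only needs 32-bit packed bins, SubtractHistogramDiscretizedKernel writes the narrowed result into hist_buffer_for_num_bit_change_ and CopyChangedNumBitHistogram copies it back, because the per-bin stride of the leaf's storage changes. A minimal sketch of the per-bin arithmetic (illustrative, not part of the commit):

#include <cstdint>

// Widen a 32-bit packed (16-bit grad, 16-bit hess) bin to the 64-bit layout.
int64_t WidenPacked(int32_t packed16) {
  return (static_cast<int64_t>(static_cast<int16_t>(packed16 >> 16)) << 32) |
         static_cast<int64_t>(packed16 & 0x0000ffff);
}

// Narrow a 64-bit packed bin back to the 32-bit layout, as done when the
// larger leaf switches to 16-bit histogram bins.
int32_t NarrowPacked(int64_t packed32) {
  return (static_cast<int32_t>(static_cast<int16_t>(packed32 >> 32)) << 16) |
         static_cast<int32_t>(packed32 & 0x000000000000ffff);
}

int main() {
  const int64_t parent_bin = (static_cast<int64_t>(10) << 32) | 40;  // gradient 10, hessian 40
  const int32_t smaller_bin = (static_cast<int32_t>(4) << 16) | 15;  // gradient 4, hessian 15
  const int64_t larger_bin = parent_bin - WidenPacked(smaller_bin);  // gradient 6, hessian 25
  const int32_t larger_bin_16bit = NarrowPacked(larger_bin);         // fits: both halves below 2^15
  return larger_bin_16bit == ((6 << 16) | 25) ? 0 : 1;
}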

View file

@ -9,6 +9,7 @@
#ifdef USE_CUDA
#include <LightGBM/cuda/cuda_row_data.hpp>
#include <LightGBM/cuda/cuda_utils.hu>
#include <LightGBM/feature_group.h>
#include <LightGBM/tree.h>
@ -37,7 +38,9 @@ class CUDAHistogramConstructor {
const int min_data_in_leaf,
const double min_sum_hessian_in_leaf,
const int gpu_device_id,
const bool gpu_use_dp);
const bool gpu_use_dp,
const bool use_discretized_grad,
const int grad_discretized_bins);
~CUDAHistogramConstructor();
@ -49,7 +52,16 @@ class CUDAHistogramConstructor {
const data_size_t num_data_in_smaller_leaf,
const data_size_t num_data_in_larger_leaf,
const double sum_hessians_in_smaller_leaf,
const double sum_hessians_in_larger_leaf);
const double sum_hessians_in_larger_leaf,
const uint8_t num_bits_in_histogram_bins);
void SubtractHistogramForLeaf(
const CUDALeafSplitsStruct* cuda_smaller_leaf_splits,
const CUDALeafSplitsStruct* cuda_larger_leaf_splits,
const bool use_discretized_grad,
const uint8_t parent_num_bits_in_histogram_bins,
const uint8_t smaller_num_bits_in_histogram_bins,
const uint8_t larger_num_bits_in_histogram_bins);
void ResetTrainingData(const Dataset* train_data, TrainingShareStates* share_states);
@ -57,9 +69,9 @@ class CUDAHistogramConstructor {
void BeforeTrain(const score_t* gradients, const score_t* hessians);
const hist_t* cuda_hist() const { return cuda_hist_; }
const hist_t* cuda_hist() const { return cuda_hist_.RawData(); }
hist_t* cuda_hist_pointer() { return cuda_hist_; }
hist_t* cuda_hist_pointer() { return cuda_hist_.RawData(); }
private:
void InitFeatureMetaInfo(const Dataset* train_data, const std::vector<uint32_t>& feature_hist_offsets);
@ -74,30 +86,39 @@ class CUDAHistogramConstructor {
template <typename HIST_TYPE, size_t SHARED_HIST_SIZE>
void LaunchConstructHistogramKernelInner(
const CUDALeafSplitsStruct* cuda_smaller_leaf_splits,
const data_size_t num_data_in_smaller_leaf);
const data_size_t num_data_in_smaller_leaf,
const uint8_t num_bits_in_histogram_bins);
template <typename HIST_TYPE, size_t SHARED_HIST_SIZE, typename BIN_TYPE>
void LaunchConstructHistogramKernelInner0(
const CUDALeafSplitsStruct* cuda_smaller_leaf_splits,
const data_size_t num_data_in_smaller_leaf);
const data_size_t num_data_in_smaller_leaf,
const uint8_t num_bits_in_histogram_bins);
template <typename HIST_TYPE, size_t SHARED_HIST_SIZE, typename BIN_TYPE, typename PTR_TYPE>
void LaunchConstructHistogramKernelInner1(
const CUDALeafSplitsStruct* cuda_smaller_leaf_splits,
const data_size_t num_data_in_smaller_leaf);
const data_size_t num_data_in_smaller_leaf,
const uint8_t num_bits_in_histogram_bins);
template <typename HIST_TYPE, size_t SHARED_HIST_SIZE, typename BIN_TYPE, typename PTR_TYPE, bool USE_GLOBAL_MEM_BUFFER>
void LaunchConstructHistogramKernelInner2(
const CUDALeafSplitsStruct* cuda_smaller_leaf_splits,
const data_size_t num_data_in_smaller_leaf);
const data_size_t num_data_in_smaller_leaf,
const uint8_t num_bits_in_histogram_bins);
void LaunchConstructHistogramKernel(
const CUDALeafSplitsStruct* cuda_smaller_leaf_splits,
const data_size_t num_data_in_smaller_leaf);
const data_size_t num_data_in_smaller_leaf,
const uint8_t num_bits_in_histogram_bins);
void LaunchSubtractHistogramKernel(
const CUDALeafSplitsStruct* cuda_smaller_leaf_splits,
const CUDALeafSplitsStruct* cuda_larger_leaf_splits);
const CUDALeafSplitsStruct* cuda_larger_leaf_splits,
const bool use_discretized_grad,
const uint8_t parent_num_bits_in_histogram_bins,
const uint8_t smaller_num_bits_in_histogram_bins,
const uint8_t larger_num_bits_in_histogram_bins);
// Host memory
@ -136,19 +157,21 @@ class CUDAHistogramConstructor {
/*! \brief CUDA row wise data */
std::unique_ptr<CUDARowData> cuda_row_data_;
/*! \brief number of bins per feature */
uint32_t* cuda_feature_num_bins_;
CUDAVector<uint32_t> cuda_feature_num_bins_;
/*! \brief offsets in histogram of all features */
uint32_t* cuda_feature_hist_offsets_;
CUDAVector<uint32_t> cuda_feature_hist_offsets_;
/*! \brief most frequent bins in each feature */
uint32_t* cuda_feature_most_freq_bins_;
CUDAVector<uint32_t> cuda_feature_most_freq_bins_;
/*! \brief CUDA histograms */
hist_t* cuda_hist_;
CUDAVector<hist_t> cuda_hist_;
/*! \brief CUDA histograms buffer for each block */
float* cuda_hist_buffer_;
CUDAVector<float> cuda_hist_buffer_;
/*! \brief indices of features whose histograms need to be fixed */
int* cuda_need_fix_histogram_features_;
CUDAVector<int> cuda_need_fix_histogram_features_;
/*! \brief aligned number of bins of the features whose histograms need to be fixed */
uint32_t* cuda_need_fix_histogram_features_num_bin_aligned_;
CUDAVector<uint32_t> cuda_need_fix_histogram_features_num_bin_aligned_;
/*! \brief histogram buffer used in histogram subtraction when the number of bits per histogram bin differs between leaves */
CUDAVector<hist_t> hist_buffer_for_num_bit_change_;
// CUDA memory, held by other object
@ -161,6 +184,10 @@ class CUDAHistogramConstructor {
const int gpu_device_id_;
/*! \brief use double precision histogram per block */
const bool gpu_use_dp_;
/*! \brief whether to use quantized gradients */
const bool use_quantized_grad_;
/*! \brief the number of bins used to quantize gradients */
const int num_grad_quant_bins_;
};
} // namespace LightGBM
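The header also replaces its raw device pointers with CUDAVector<T> members, so allocation, zero-fill, and release follow the object's lifetime instead of explicit Allocate/Deallocate calls. The sketch below only illustrates that RAII pattern, restricted to the calls this diff actually uses (Resize, SetValue, RawData, RawDataReadOnly); it is not LightGBM's CUDAVector implementation:

#include <cuda_runtime.h>
#include <cstddef>

// Minimal RAII device-buffer sketch; illustrative only, not the real CUDAVector.
template <typename T>
class DeviceBufferSketch {
 public:
  DeviceBufferSketch() = default;
  ~DeviceBufferSketch() {
    if (data_ != nullptr) cudaFree(data_);
  }
  void Resize(size_t size) {
    if (data_ != nullptr) cudaFree(data_);  // a fuller implementation would preserve old contents
    cudaMalloc(reinterpret_cast<void**>(&data_), size * sizeof(T));
    size_ = size;
  }
  void SetValue(int value) { cudaMemset(data_, value, size_ * sizeof(T)); }
  T* RawData() { return data_; }
  const T* RawDataReadOnly() const { return data_; }

 private:
  T* data_ = nullptr;
  size_t size_ = 0;
};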

View file

@ -11,27 +11,22 @@
namespace LightGBM {
CUDALeafSplits::CUDALeafSplits(const data_size_t num_data):
num_data_(num_data) {
cuda_struct_ = nullptr;
cuda_sum_of_gradients_buffer_ = nullptr;
cuda_sum_of_hessians_buffer_ = nullptr;
}
num_data_(num_data) {}
CUDALeafSplits::~CUDALeafSplits() {
DeallocateCUDAMemory<CUDALeafSplitsStruct>(&cuda_struct_, __FILE__, __LINE__);
DeallocateCUDAMemory<double>(&cuda_sum_of_gradients_buffer_, __FILE__, __LINE__);
DeallocateCUDAMemory<double>(&cuda_sum_of_hessians_buffer_, __FILE__, __LINE__);
}
CUDALeafSplits::~CUDALeafSplits() {}
void CUDALeafSplits::Init() {
void CUDALeafSplits::Init(const bool use_quantized_grad) {
num_blocks_init_from_gradients_ = (num_data_ + NUM_THRADS_PER_BLOCK_LEAF_SPLITS - 1) / NUM_THRADS_PER_BLOCK_LEAF_SPLITS;
// allocate more memory for sum reduction in CUDA
// only the first element records the final sum
AllocateCUDAMemory<double>(&cuda_sum_of_gradients_buffer_, num_blocks_init_from_gradients_, __FILE__, __LINE__);
AllocateCUDAMemory<double>(&cuda_sum_of_hessians_buffer_, num_blocks_init_from_gradients_, __FILE__, __LINE__);
cuda_sum_of_gradients_buffer_.Resize(static_cast<size_t>(num_blocks_init_from_gradients_));
cuda_sum_of_hessians_buffer_.Resize(static_cast<size_t>(num_blocks_init_from_gradients_));
if (use_quantized_grad) {
cuda_sum_of_gradients_hessians_buffer_.Resize(static_cast<size_t>(num_blocks_init_from_gradients_));
}
AllocateCUDAMemory<CUDALeafSplitsStruct>(&cuda_struct_, 1, __FILE__, __LINE__);
cuda_struct_.Resize(1);
}
void CUDALeafSplits::InitValues() {
@ -46,24 +41,33 @@ void CUDALeafSplits::InitValues(
const data_size_t num_used_indices, hist_t* cuda_hist_in_leaf, double* root_sum_hessians) {
cuda_gradients_ = cuda_gradients;
cuda_hessians_ = cuda_hessians;
SetCUDAMemory<double>(cuda_sum_of_gradients_buffer_, 0, num_blocks_init_from_gradients_, __FILE__, __LINE__);
SetCUDAMemory<double>(cuda_sum_of_hessians_buffer_, 0, num_blocks_init_from_gradients_, __FILE__, __LINE__);
cuda_sum_of_gradients_buffer_.SetValue(0);
cuda_sum_of_hessians_buffer_.SetValue(0);
LaunchInitValuesKernal(lambda_l1, lambda_l2, cuda_bagging_data_indices, cuda_data_indices_in_leaf, num_used_indices, cuda_hist_in_leaf);
CopyFromCUDADeviceToHost<double>(root_sum_hessians, cuda_sum_of_hessians_buffer_, 1, __FILE__, __LINE__);
CopyFromCUDADeviceToHost<double>(root_sum_hessians, cuda_sum_of_hessians_buffer_.RawData(), 1, __FILE__, __LINE__);
SynchronizeCUDADevice(__FILE__, __LINE__);
}
void CUDALeafSplits::InitValues(
const double lambda_l1, const double lambda_l2,
const int16_t* cuda_gradients_and_hessians,
const data_size_t* cuda_bagging_data_indices,
const data_size_t* cuda_data_indices_in_leaf, const data_size_t num_used_indices,
hist_t* cuda_hist_in_leaf, double* root_sum_hessians,
const score_t* grad_scale, const score_t* hess_scale) {
cuda_gradients_ = reinterpret_cast<const score_t*>(cuda_gradients_and_hessians);
cuda_hessians_ = nullptr;
LaunchInitValuesKernal(lambda_l1, lambda_l2, cuda_bagging_data_indices, cuda_data_indices_in_leaf, num_used_indices, cuda_hist_in_leaf, grad_scale, hess_scale);
CopyFromCUDADeviceToHost<double>(root_sum_hessians, cuda_sum_of_hessians_buffer_.RawData(), 1, __FILE__, __LINE__);
SynchronizeCUDADevice(__FILE__, __LINE__);
}
void CUDALeafSplits::Resize(const data_size_t num_data) {
if (num_data > num_data_) {
DeallocateCUDAMemory<double>(&cuda_sum_of_gradients_buffer_, __FILE__, __LINE__);
DeallocateCUDAMemory<double>(&cuda_sum_of_hessians_buffer_, __FILE__, __LINE__);
num_blocks_init_from_gradients_ = (num_data + NUM_THRADS_PER_BLOCK_LEAF_SPLITS - 1) / NUM_THRADS_PER_BLOCK_LEAF_SPLITS;
AllocateCUDAMemory<double>(&cuda_sum_of_gradients_buffer_, num_blocks_init_from_gradients_, __FILE__, __LINE__);
AllocateCUDAMemory<double>(&cuda_sum_of_hessians_buffer_, num_blocks_init_from_gradients_, __FILE__, __LINE__);
} else {
num_blocks_init_from_gradients_ = (num_data + NUM_THRADS_PER_BLOCK_LEAF_SPLITS - 1) / NUM_THRADS_PER_BLOCK_LEAF_SPLITS;
}
num_data_ = num_data;
num_blocks_init_from_gradients_ = (num_data + NUM_THRADS_PER_BLOCK_LEAF_SPLITS - 1) / NUM_THRADS_PER_BLOCK_LEAF_SPLITS;
cuda_sum_of_gradients_buffer_.Resize(static_cast<size_t>(num_blocks_init_from_gradients_));
cuda_sum_of_hessians_buffer_.Resize(static_cast<size_t>(num_blocks_init_from_gradients_));
cuda_sum_of_gradients_hessians_buffer_.Resize(static_cast<size_t>(num_blocks_init_from_gradients_));
}
} // namespace LightGBM

View file

@ -81,6 +81,90 @@ __global__ void CUDAInitValuesKernel2(
}
}
template <bool USE_INDICES>
__global__ void CUDAInitValuesKernel3(const int16_t* cuda_gradients_and_hessians,
const data_size_t num_data, const data_size_t* cuda_bagging_data_indices,
double* cuda_sum_of_gradients, double* cuda_sum_of_hessians, int64_t* cuda_sum_of_hessians_hessians,
const score_t* grad_scale_pointer, const score_t* hess_scale_pointer) {
const score_t grad_scale = *grad_scale_pointer;
const score_t hess_scale = *hess_scale_pointer;
__shared__ int64_t shared_mem_buffer[32];
const data_size_t data_index = static_cast<data_size_t>(threadIdx.x + blockIdx.x * blockDim.x);
int64_t int_gradient = 0;
int64_t int_hessian = 0;
if (data_index < num_data) {
int_gradient = USE_INDICES ? cuda_gradients_and_hessians[2 * cuda_bagging_data_indices[data_index] + 1] :
cuda_gradients_and_hessians[2 * data_index + 1];
int_hessian = USE_INDICES ? cuda_gradients_and_hessians[2 * cuda_bagging_data_indices[data_index]] :
cuda_gradients_and_hessians[2 * data_index];
}
const int64_t block_sum_gradient = ShuffleReduceSum<int64_t>(int_gradient, shared_mem_buffer, blockDim.x);
__syncthreads();
const int64_t block_sum_hessian = ShuffleReduceSum<int64_t>(int_hessian, shared_mem_buffer, blockDim.x);
if (threadIdx.x == 0) {
cuda_sum_of_gradients[blockIdx.x] = block_sum_gradient * grad_scale;
cuda_sum_of_hessians[blockIdx.x] = block_sum_hessian * hess_scale;
cuda_sum_of_hessians_hessians[blockIdx.x] = ((block_sum_gradient << 32) | block_sum_hessian);
}
}
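CUDAInitValuesKernel3 keeps two views of each block's totals: real-valued sums obtained by multiplying the integer sums by the gradient and hessian scales, and the exact integer totals packed into one int64_t for quantized split evaluation. The tiny host-side sketch below uses assumed scale values (the scales come from the CUDA gradient discretizer and are not part of this diff):

#include <cstdint>

int main() {
  const int64_t int_grad_sum = -1234;   // sum of int16 gradients in one block
  const int64_t int_hess_sum = 5678;    // sum of int16 hessians in one block
  const float grad_scale = 0.0125f;     // assumed discretizer scales
  const float hess_scale = 0.0031f;
  const double sum_gradients = int_grad_sum * grad_scale;      // ~ -15.425
  const double sum_hessians = int_hess_sum * hess_scale;       // ~ 17.60
  const int64_t packed = (int_grad_sum << 32) | int_hess_sum;  // exact integer totals, as stored per block
  return ((packed >> 32) == int_grad_sum && sum_gradients < 0.0 && sum_hessians > 0.0) ? 0 : 1;
}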
__global__ void CUDAInitValuesKernel4(
const double lambda_l1,
const double lambda_l2,
const int num_blocks_to_reduce,
double* cuda_sum_of_gradients,
double* cuda_sum_of_hessians,
int64_t* cuda_sum_of_gradients_hessians,
const data_size_t num_data,
const data_size_t* cuda_data_indices_in_leaf,
hist_t* cuda_hist_in_leaf,
CUDALeafSplitsStruct* cuda_struct) {
__shared__ double shared_mem_buffer[32];
double thread_sum_of_gradients = 0.0f;
double thread_sum_of_hessians = 0.0f;
int64_t thread_sum_of_gradients_hessians = 0;
for (int block_index = static_cast<int>(threadIdx.x); block_index < num_blocks_to_reduce; block_index += static_cast<int>(blockDim.x)) {
thread_sum_of_gradients += cuda_sum_of_gradients[block_index];
thread_sum_of_hessians += cuda_sum_of_hessians[block_index];
thread_sum_of_gradients_hessians += cuda_sum_of_gradients_hessians[block_index];
}
const double sum_of_gradients = ShuffleReduceSum<double>(thread_sum_of_gradients, shared_mem_buffer, blockDim.x);
__syncthreads();
const double sum_of_hessians = ShuffleReduceSum<double>(thread_sum_of_hessians, shared_mem_buffer, blockDim.x);
__syncthreads();
const int64_t sum_of_gradients_hessians = ShuffleReduceSum<int64_t>(
thread_sum_of_gradients_hessians,
reinterpret_cast<int64_t*>(shared_mem_buffer),
blockDim.x);
if (threadIdx.x == 0) {
cuda_sum_of_hessians[0] = sum_of_hessians;
cuda_struct->leaf_index = 0;
cuda_struct->sum_of_gradients = sum_of_gradients;
cuda_struct->sum_of_hessians = sum_of_hessians;
cuda_struct->sum_of_gradients_hessians = sum_of_gradients_hessians;
cuda_struct->num_data_in_leaf = num_data;
const bool use_l1 = lambda_l1 > 0.0f;
if (!use_l1) {
// no smoothing on root node
cuda_struct->gain = CUDALeafSplits::GetLeafGain<false, false>(sum_of_gradients, sum_of_hessians, lambda_l1, lambda_l2, 0.0f, 0, 0.0f);
} else {
// no smoothing on root node
cuda_struct->gain = CUDALeafSplits::GetLeafGain<true, false>(sum_of_gradients, sum_of_hessians, lambda_l1, lambda_l2, 0.0f, 0, 0.0f);
}
if (!use_l1) {
// no smoothing on root node
cuda_struct->leaf_value =
CUDALeafSplits::CalculateSplittedLeafOutput<false, false>(sum_of_gradients, sum_of_hessians, lambda_l1, lambda_l2, 0.0f, 0, 0.0f);
} else {
// no smoothing on root node
cuda_struct->leaf_value =
CUDALeafSplits::CalculateSplittedLeafOutput<true, false>(sum_of_gradients, sum_of_hessians, lambda_l1, lambda_l2, 0.0f, 0, 0.0f);
}
cuda_struct->data_indices_in_leaf = cuda_data_indices_in_leaf;
cuda_struct->hist_in_leaf = cuda_hist_in_leaf;
}
}
__global__ void InitValuesEmptyKernel(CUDALeafSplitsStruct* cuda_struct) {
cuda_struct->leaf_index = -1;
cuda_struct->sum_of_gradients = 0.0f;
@ -93,7 +177,7 @@ __global__ void InitValuesEmptyKernel(CUDALeafSplitsStruct* cuda_struct) {
}
void CUDALeafSplits::LaunchInitValuesEmptyKernel() {
InitValuesEmptyKernel<<<1, 1>>>(cuda_struct_);
InitValuesEmptyKernel<<<1, 1>>>(cuda_struct_.RawData());
}
void CUDALeafSplits::LaunchInitValuesKernal(
@ -104,23 +188,55 @@ void CUDALeafSplits::LaunchInitValuesKernal(
hist_t* cuda_hist_in_leaf) {
if (cuda_bagging_data_indices == nullptr) {
CUDAInitValuesKernel1<false><<<num_blocks_init_from_gradients_, NUM_THRADS_PER_BLOCK_LEAF_SPLITS>>>(
cuda_gradients_, cuda_hessians_, num_used_indices, nullptr, cuda_sum_of_gradients_buffer_,
cuda_sum_of_hessians_buffer_);
cuda_gradients_, cuda_hessians_, num_used_indices, nullptr, cuda_sum_of_gradients_buffer_.RawData(),
cuda_sum_of_hessians_buffer_.RawData());
} else {
CUDAInitValuesKernel1<true><<<num_blocks_init_from_gradients_, NUM_THRADS_PER_BLOCK_LEAF_SPLITS>>>(
cuda_gradients_, cuda_hessians_, num_used_indices, cuda_bagging_data_indices, cuda_sum_of_gradients_buffer_,
cuda_sum_of_hessians_buffer_);
cuda_gradients_, cuda_hessians_, num_used_indices, cuda_bagging_data_indices, cuda_sum_of_gradients_buffer_.RawData(),
cuda_sum_of_hessians_buffer_.RawData());
}
SynchronizeCUDADevice(__FILE__, __LINE__);
CUDAInitValuesKernel2<<<1, NUM_THRADS_PER_BLOCK_LEAF_SPLITS>>>(
lambda_l1, lambda_l2,
num_blocks_init_from_gradients_,
cuda_sum_of_gradients_buffer_,
cuda_sum_of_hessians_buffer_,
cuda_sum_of_gradients_buffer_.RawData(),
cuda_sum_of_hessians_buffer_.RawData(),
num_used_indices,
cuda_data_indices_in_leaf,
cuda_hist_in_leaf,
cuda_struct_);
cuda_struct_.RawData());
SynchronizeCUDADevice(__FILE__, __LINE__);
}
void CUDALeafSplits::LaunchInitValuesKernal(
const double lambda_l1, const double lambda_l2,
const data_size_t* cuda_bagging_data_indices,
const data_size_t* cuda_data_indices_in_leaf,
const data_size_t num_used_indices,
hist_t* cuda_hist_in_leaf,
const score_t* grad_scale,
const score_t* hess_scale) {
if (cuda_bagging_data_indices == nullptr) {
CUDAInitValuesKernel3<false><<<num_blocks_init_from_gradients_, NUM_THRADS_PER_BLOCK_LEAF_SPLITS>>>(
reinterpret_cast<const int16_t*>(cuda_gradients_), num_used_indices, nullptr, cuda_sum_of_gradients_buffer_.RawData(),
cuda_sum_of_hessians_buffer_.RawData(), cuda_sum_of_gradients_hessians_buffer_.RawData(), grad_scale, hess_scale);
} else {
CUDAInitValuesKernel3<true><<<num_blocks_init_from_gradients_, NUM_THRADS_PER_BLOCK_LEAF_SPLITS>>>(
reinterpret_cast<const int16_t*>(cuda_gradients_), num_used_indices, cuda_bagging_data_indices, cuda_sum_of_gradients_buffer_.RawData(),
cuda_sum_of_hessians_buffer_.RawData(), cuda_sum_of_gradients_hessians_buffer_.RawData(), grad_scale, hess_scale);
}
SynchronizeCUDADevice(__FILE__, __LINE__);
CUDAInitValuesKernel4<<<1, NUM_THRADS_PER_BLOCK_LEAF_SPLITS>>>(
lambda_l1, lambda_l2,
num_blocks_init_from_gradients_,
cuda_sum_of_gradients_buffer_.RawData(),
cuda_sum_of_hessians_buffer_.RawData(),
cuda_sum_of_gradients_hessians_buffer_.RawData(),
num_used_indices,
cuda_data_indices_in_leaf,
cuda_hist_in_leaf,
cuda_struct_.RawData());
SynchronizeCUDADevice(__FILE__, __LINE__);
}
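The quantized path above works entirely on discretized statistics: per-block partial sums of the int16 gradient/hessian pairs are produced by CUDAInitValuesKernel3, and the final kernel reduces them (including a combined 64-bit sum) and rescales the totals before filling the root CUDALeafSplitsStruct; this is the role of the grad_scale / hess_scale arguments threaded through LaunchInitValuesKernal. The sketch below illustrates only the accumulate-then-rescale idea; the interleaved (gradient, hessian) layout and the atomic reduction are assumptions for illustration, not the kernels added by this commit.

#include <cstdint>

// Illustration only: accumulate int16 gradient/hessian pairs into 64-bit integer sums,
// then rescale afterwards. The pair layout below is an assumption; the commit's kernels
// reduce per-block buffers with ShuffleReduceSum instead of atomics.
__global__ void SumDiscretizedGradHessSketch(
    const int16_t* grad_hess_pairs,   // assumed layout: [2 * i] = gradient, [2 * i + 1] = hessian
    const int num_data,
    long long* sum_gradients,         // integer accumulators, rescaled afterwards
    long long* sum_hessians) {
  long long local_grad = 0;
  long long local_hess = 0;
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < num_data; i += gridDim.x * blockDim.x) {
    local_grad += static_cast<long long>(grad_hess_pairs[2 * i]);
    local_hess += static_cast<long long>(grad_hess_pairs[2 * i + 1]);
  }
  // two's complement wraparound keeps the unsigned atomicAdd correct for signed partial sums
  atomicAdd(reinterpret_cast<unsigned long long*>(sum_gradients), static_cast<unsigned long long>(local_grad));
  atomicAdd(reinterpret_cast<unsigned long long*>(sum_hessians), static_cast<unsigned long long>(local_hess));
}
// Afterwards: sum_of_gradients = grad_scale * sum_gradients and sum_of_hessians = hess_scale * sum_hessians.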

View file

@ -8,7 +8,7 @@
#ifdef USE_CUDA
#include <LightGBM/cuda/cuda_utils.h>
#include <LightGBM/cuda/cuda_utils.hu>
#include <LightGBM/bin.h>
#include <LightGBM/utils/log.h>
#include <LightGBM/meta.h>
@ -23,6 +23,7 @@ struct CUDALeafSplitsStruct {
int leaf_index;
double sum_of_gradients;
double sum_of_hessians;
int64_t sum_of_gradients_hessians;
data_size_t num_data_in_leaf;
double gain;
double leaf_value;
@ -36,7 +37,7 @@ class CUDALeafSplits {
~CUDALeafSplits();
void Init();
void Init(const bool use_quantized_grad);
void InitValues(
const double lambda_l1, const double lambda_l2,
@ -45,11 +46,19 @@ class CUDALeafSplits {
const data_size_t* cuda_data_indices_in_leaf, const data_size_t num_used_indices,
hist_t* cuda_hist_in_leaf, double* root_sum_hessians);
void InitValues(
const double lambda_l1, const double lambda_l2,
const int16_t* cuda_gradients_and_hessians,
const data_size_t* cuda_bagging_data_indices,
const data_size_t* cuda_data_indices_in_leaf, const data_size_t num_used_indices,
hist_t* cuda_hist_in_leaf, double* root_sum_hessians,
const score_t* grad_scale, const score_t* hess_scale);
void InitValues();
const CUDALeafSplitsStruct* GetCUDAStruct() const { return cuda_struct_; }
const CUDALeafSplitsStruct* GetCUDAStruct() const { return cuda_struct_.RawDataReadOnly(); }
CUDALeafSplitsStruct* GetCUDAStructRef() { return cuda_struct_; }
CUDALeafSplitsStruct* GetCUDAStructRef() { return cuda_struct_.RawData(); }
void Resize(const data_size_t num_data);
@ -140,14 +149,24 @@ class CUDALeafSplits {
const data_size_t num_used_indices,
hist_t* cuda_hist_in_leaf);
void LaunchInitValuesKernal(
const double lambda_l1, const double lambda_l2,
const data_size_t* cuda_bagging_data_indices,
const data_size_t* cuda_data_indices_in_leaf,
const data_size_t num_used_indices,
hist_t* cuda_hist_in_leaf,
const score_t* grad_scale,
const score_t* hess_scale);
// Host memory
data_size_t num_data_;
int num_blocks_init_from_gradients_;
// CUDA memory, held by this object
CUDALeafSplitsStruct* cuda_struct_;
double* cuda_sum_of_gradients_buffer_;
double* cuda_sum_of_hessians_buffer_;
CUDAVector<CUDALeafSplitsStruct> cuda_struct_;
CUDAVector<double> cuda_sum_of_gradients_buffer_;
CUDAVector<double> cuda_sum_of_hessians_buffer_;
CUDAVector<int64_t> cuda_sum_of_gradients_hessians_buffer_;
// CUDA memory, held by other object
const score_t* cuda_gradients_;
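Throughout this commit, raw device pointers that were allocated and freed by hand (AllocateCUDAMemory / DeallocateCUDAMemory) become CUDAVector<T> members, and .RawData() / .RawDataReadOnly() is used wherever a kernel launch still needs a plain pointer. A rough sketch of such a wrapper follows; the real CUDAVector is declared in LightGBM/cuda/cuda_utils.hu and differs in detail (error checking, host/device copy helpers, a Resize that preserves contents).

#include <cuda_runtime.h>
#include <cstddef>

// Simplified RAII device-buffer wrapper in the spirit of CUDAVector (illustration only).
template <typename T>
class DeviceVectorSketch {
 public:
  DeviceVectorSketch() : data_(nullptr), size_(0) {}
  ~DeviceVectorSketch() { Clear(); }
  // NOTE: unlike a real Resize, this sketch discards the previous contents.
  void Resize(size_t size) {
    Clear();
    cudaMalloc(reinterpret_cast<void**>(&data_), size * sizeof(T));
    size_ = size;
  }
  // byte-wise fill; adequate for the zero-initialization used by the tree learner above
  void SetValue(int value) { cudaMemset(data_, value, size_ * sizeof(T)); }
  T* RawData() { return data_; }
  const T* RawDataReadOnly() const { return data_; }
  size_t Size() const { return size_; }

 private:
  void Clear() {
    if (data_ != nullptr) {
      cudaFree(data_);
      data_ = nullptr;
      size_ = 0;
    }
  }
  T* data_;
  size_t size_;
};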

View file

@ -9,7 +9,7 @@
#include "cuda_single_gpu_tree_learner.hpp"
#include <LightGBM/cuda/cuda_tree.hpp>
#include <LightGBM/cuda/cuda_utils.h>
#include <LightGBM/cuda/cuda_utils.hu>
#include <LightGBM/feature_group.h>
#include <LightGBM/network.h>
#include <LightGBM/objective_function.h>
@ -39,13 +39,14 @@ void CUDASingleGPUTreeLearner::Init(const Dataset* train_data, bool is_constant_
SetCUDADevice(gpu_device_id_, __FILE__, __LINE__);
cuda_smaller_leaf_splits_.reset(new CUDALeafSplits(num_data_));
cuda_smaller_leaf_splits_->Init();
cuda_smaller_leaf_splits_->Init(config_->use_quantized_grad);
cuda_larger_leaf_splits_.reset(new CUDALeafSplits(num_data_));
cuda_larger_leaf_splits_->Init();
cuda_larger_leaf_splits_->Init(config_->use_quantized_grad);
cuda_histogram_constructor_.reset(new CUDAHistogramConstructor(train_data_, config_->num_leaves, num_threads_,
share_state_->feature_hist_offsets(),
config_->min_data_in_leaf, config_->min_sum_hessian_in_leaf, gpu_device_id_, config_->gpu_use_dp));
config_->min_data_in_leaf, config_->min_sum_hessian_in_leaf, gpu_device_id_, config_->gpu_use_dp,
config_->use_quantized_grad, config_->num_grad_quant_bins));
cuda_histogram_constructor_->Init(train_data_, share_state_.get());
const auto& feature_hist_offsets = share_state_->feature_hist_offsets();
@ -73,11 +74,19 @@ void CUDASingleGPUTreeLearner::Init(const Dataset* train_data, bool is_constant_
}
AllocateBitset();
cuda_leaf_gradient_stat_buffer_ = nullptr;
cuda_leaf_hessian_stat_buffer_ = nullptr;
leaf_stat_buffer_size_ = 0;
num_cat_threshold_ = 0;
if (config_->use_quantized_grad) {
cuda_leaf_gradient_stat_buffer_.Resize(config_->num_leaves);
cuda_leaf_hessian_stat_buffer_.Resize(config_->num_leaves);
cuda_gradient_discretizer_.reset(new CUDAGradientDiscretizer(
config_->num_grad_quant_bins, config_->num_iterations, config_->seed, is_constant_hessian, config_->stochastic_rounding));
cuda_gradient_discretizer_->Init(num_data_, config_->num_leaves, train_data_->num_features(), train_data_);
} else {
cuda_gradient_discretizer_.reset(nullptr);
}
#ifdef DEBUG
host_gradients_.resize(num_data_, 0.0f);
host_hessians_.resize(num_data_, 0.0f);
@ -101,19 +110,37 @@ void CUDASingleGPUTreeLearner::BeforeTrain() {
const data_size_t* leaf_splits_init_indices =
cuda_data_partition_->use_bagging() ? cuda_data_partition_->cuda_data_indices() : nullptr;
cuda_data_partition_->BeforeTrain();
cuda_smaller_leaf_splits_->InitValues(
config_->lambda_l1,
config_->lambda_l2,
gradients_,
hessians_,
leaf_splits_init_indices,
cuda_data_partition_->cuda_data_indices(),
root_num_data,
cuda_histogram_constructor_->cuda_hist_pointer(),
&leaf_sum_hessians_[0]);
if (config_->use_quantized_grad) {
cuda_gradient_discretizer_->DiscretizeGradients(num_data_, gradients_, hessians_);
cuda_histogram_constructor_->BeforeTrain(
reinterpret_cast<const score_t*>(cuda_gradient_discretizer_->discretized_gradients_and_hessians()), nullptr);
cuda_smaller_leaf_splits_->InitValues(
config_->lambda_l1,
config_->lambda_l2,
reinterpret_cast<const int16_t*>(cuda_gradient_discretizer_->discretized_gradients_and_hessians()),
leaf_splits_init_indices,
cuda_data_partition_->cuda_data_indices(),
root_num_data,
cuda_histogram_constructor_->cuda_hist_pointer(),
&leaf_sum_hessians_[0],
cuda_gradient_discretizer_->grad_scale_ptr(),
cuda_gradient_discretizer_->hess_scale_ptr());
cuda_gradient_discretizer_->SetNumBitsInHistogramBin<false>(0, -1, root_num_data, 0);
} else {
cuda_histogram_constructor_->BeforeTrain(gradients_, hessians_);
cuda_smaller_leaf_splits_->InitValues(
config_->lambda_l1,
config_->lambda_l2,
gradients_,
hessians_,
leaf_splits_init_indices,
cuda_data_partition_->cuda_data_indices(),
root_num_data,
cuda_histogram_constructor_->cuda_hist_pointer(),
&leaf_sum_hessians_[0]);
}
leaf_num_data_[0] = root_num_data;
cuda_larger_leaf_splits_->InitValues();
cuda_histogram_constructor_->BeforeTrain(gradients_, hessians_);
col_sampler_.ResetByTree();
cuda_best_split_finder_->BeforeTrain(col_sampler_.is_feature_used_bytree());
leaf_data_start_[0] = 0;
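BeforeTrain's quantized branch first calls DiscretizeGradients, which maps every float gradient and hessian to a small signed integer using per-iteration scales; with stochastic_rounding enabled the rounding direction is randomized so that the quantization error is unbiased in expectation. The snippet below sketches only that rounding step, with hypothetical names and a host-side RNG; the CUDA discretizer does this on the device and stores each gradient/hessian pair packed together.

#include <cmath>
#include <cstdint>
#include <random>

// Hypothetical helper (not CUDAGradientDiscretizer): round value / scale up with probability
// equal to the fractional part, down otherwise, so E[rounded] equals value * inv_scale exactly.
inline int16_t StochasticRoundSketch(float value, float inv_scale, std::mt19937* rng) {
  const float scaled = value * inv_scale;
  const float low = std::floor(scaled);
  std::uniform_real_distribution<float> unif(0.0f, 1.0f);
  return static_cast<int16_t>(low + ((unif(*rng) < scaled - low) ? 1.0f : 0.0f));
}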
@ -141,24 +168,70 @@ Tree* CUDASingleGPUTreeLearner::Train(const score_t* gradients,
const data_size_t num_data_in_larger_leaf = larger_leaf_index_ < 0 ? 0 : leaf_num_data_[larger_leaf_index_];
const double sum_hessians_in_smaller_leaf = leaf_sum_hessians_[smaller_leaf_index_];
const double sum_hessians_in_larger_leaf = larger_leaf_index_ < 0 ? 0 : leaf_sum_hessians_[larger_leaf_index_];
const uint8_t num_bits_in_histogram_bins = config_->use_quantized_grad ? cuda_gradient_discretizer_->GetHistBitsInLeaf<false>(smaller_leaf_index_) : 0;
cuda_histogram_constructor_->ConstructHistogramForLeaf(
cuda_smaller_leaf_splits_->GetCUDAStruct(),
cuda_larger_leaf_splits_->GetCUDAStruct(),
num_data_in_smaller_leaf,
num_data_in_larger_leaf,
sum_hessians_in_smaller_leaf,
sum_hessians_in_larger_leaf);
sum_hessians_in_larger_leaf,
num_bits_in_histogram_bins);
global_timer.Stop("CUDASingleGPUTreeLearner::ConstructHistogramForLeaf");
global_timer.Start("CUDASingleGPUTreeLearner::FindBestSplitsForLeaf");
SelectFeatureByNode(tree.get());
cuda_best_split_finder_->FindBestSplitsForLeaf(
uint8_t parent_num_bits_bin = 0;
uint8_t smaller_num_bits_bin = 0;
uint8_t larger_num_bits_bin = 0;
if (config_->use_quantized_grad) {
if (larger_leaf_index_ != -1) {
const int parent_leaf_index = std::min(smaller_leaf_index_, larger_leaf_index_);
parent_num_bits_bin = cuda_gradient_discretizer_->GetHistBitsInNode<false>(parent_leaf_index);
smaller_num_bits_bin = cuda_gradient_discretizer_->GetHistBitsInLeaf<false>(smaller_leaf_index_);
larger_num_bits_bin = cuda_gradient_discretizer_->GetHistBitsInLeaf<false>(larger_leaf_index_);
} else {
parent_num_bits_bin = cuda_gradient_discretizer_->GetHistBitsInLeaf<false>(0);
smaller_num_bits_bin = cuda_gradient_discretizer_->GetHistBitsInLeaf<false>(0);
larger_num_bits_bin = cuda_gradient_discretizer_->GetHistBitsInLeaf<false>(0);
}
} else {
parent_num_bits_bin = 0;
smaller_num_bits_bin = 0;
larger_num_bits_bin = 0;
}
cuda_histogram_constructor_->SubtractHistogramForLeaf(
cuda_smaller_leaf_splits_->GetCUDAStruct(),
cuda_larger_leaf_splits_->GetCUDAStruct(),
smaller_leaf_index_, larger_leaf_index_,
num_data_in_smaller_leaf, num_data_in_larger_leaf,
sum_hessians_in_smaller_leaf, sum_hessians_in_larger_leaf);
config_->use_quantized_grad,
parent_num_bits_bin,
smaller_num_bits_bin,
larger_num_bits_bin);
SelectFeatureByNode(tree.get());
if (config_->use_quantized_grad) {
const uint8_t smaller_leaf_num_bits_bin = cuda_gradient_discretizer_->GetHistBitsInLeaf<false>(smaller_leaf_index_);
const uint8_t larger_leaf_num_bits_bin = larger_leaf_index_ < 0 ? 32 : cuda_gradient_discretizer_->GetHistBitsInLeaf<false>(larger_leaf_index_);
cuda_best_split_finder_->FindBestSplitsForLeaf(
cuda_smaller_leaf_splits_->GetCUDAStruct(),
cuda_larger_leaf_splits_->GetCUDAStruct(),
smaller_leaf_index_, larger_leaf_index_,
num_data_in_smaller_leaf, num_data_in_larger_leaf,
sum_hessians_in_smaller_leaf, sum_hessians_in_larger_leaf,
cuda_gradient_discretizer_->grad_scale_ptr(),
cuda_gradient_discretizer_->hess_scale_ptr(),
smaller_leaf_num_bits_bin,
larger_leaf_num_bits_bin);
} else {
cuda_best_split_finder_->FindBestSplitsForLeaf(
cuda_smaller_leaf_splits_->GetCUDAStruct(),
cuda_larger_leaf_splits_->GetCUDAStruct(),
smaller_leaf_index_, larger_leaf_index_,
num_data_in_smaller_leaf, num_data_in_larger_leaf,
sum_hessians_in_smaller_leaf, sum_hessians_in_larger_leaf,
nullptr, nullptr, 0, 0);
}
global_timer.Stop("CUDASingleGPUTreeLearner::FindBestSplitsForLeaf");
global_timer.Start("CUDASingleGPUTreeLearner::FindBestFromAllSplits");
const CUDASplitInfo* best_split_info = nullptr;
@ -247,9 +320,19 @@ Tree* CUDASingleGPUTreeLearner::Train(const score_t* gradients,
#endif // DEBUG
smaller_leaf_index_ = (leaf_num_data_[best_leaf_index_] < leaf_num_data_[right_leaf_index] ? best_leaf_index_ : right_leaf_index);
larger_leaf_index_ = (smaller_leaf_index_ == best_leaf_index_ ? right_leaf_index : best_leaf_index_);
if (config_->use_quantized_grad) {
cuda_gradient_discretizer_->SetNumBitsInHistogramBin<false>(
best_leaf_index_, right_leaf_index, leaf_num_data_[best_leaf_index_], leaf_num_data_[right_leaf_index]);
}
global_timer.Stop("CUDASingleGPUTreeLearner::Split");
}
SynchronizeCUDADevice(__FILE__, __LINE__);
if (config_->use_quantized_grad && config_->quant_train_renew_leaf) {
global_timer.Start("CUDASingleGPUTreeLearner::RenewDiscretizedTreeLeaves");
RenewDiscretizedTreeLeaves(tree.get());
global_timer.Stop("CUDASingleGPUTreeLearner::RenewDiscretizedTreeLeaves");
}
tree->ToHost();
return tree.release();
}
@ -357,8 +440,8 @@ void CUDASingleGPUTreeLearner::RenewTreeOutput(Tree* tree, const ObjectiveFuncti
Tree* CUDASingleGPUTreeLearner::FitByExistingTree(const Tree* old_tree, const score_t* gradients, const score_t* hessians) const {
std::unique_ptr<CUDATree> cuda_tree(new CUDATree(old_tree));
SetCUDAMemory<double>(cuda_leaf_gradient_stat_buffer_, 0, static_cast<size_t>(old_tree->num_leaves()), __FILE__, __LINE__);
SetCUDAMemory<double>(cuda_leaf_hessian_stat_buffer_, 0, static_cast<size_t>(old_tree->num_leaves()), __FILE__, __LINE__);
cuda_leaf_gradient_stat_buffer_.SetValue(0);
cuda_leaf_hessian_stat_buffer_.SetValue(0);
ReduceLeafStat(cuda_tree.get(), gradients, hessians, cuda_data_partition_->cuda_data_indices());
cuda_tree->SyncLeafOutputFromCUDAToHost();
return cuda_tree.release();
@ -373,13 +456,9 @@ Tree* CUDASingleGPUTreeLearner::FitByExistingTree(const Tree* old_tree, const st
const int num_block = (refit_num_data_ + CUDA_SINGLE_GPU_TREE_LEARNER_BLOCK_SIZE - 1) / CUDA_SINGLE_GPU_TREE_LEARNER_BLOCK_SIZE;
buffer_size *= static_cast<data_size_t>(num_block + 1);
}
if (buffer_size != leaf_stat_buffer_size_) {
if (leaf_stat_buffer_size_ != 0) {
DeallocateCUDAMemory<double>(&cuda_leaf_gradient_stat_buffer_, __FILE__, __LINE__);
DeallocateCUDAMemory<double>(&cuda_leaf_hessian_stat_buffer_, __FILE__, __LINE__);
}
AllocateCUDAMemory<double>(&cuda_leaf_gradient_stat_buffer_, static_cast<size_t>(buffer_size), __FILE__, __LINE__);
AllocateCUDAMemory<double>(&cuda_leaf_hessian_stat_buffer_, static_cast<size_t>(buffer_size), __FILE__, __LINE__);
if (static_cast<size_t>(buffer_size) > cuda_leaf_gradient_stat_buffer_.Size()) {
cuda_leaf_gradient_stat_buffer_.Resize(buffer_size);
cuda_leaf_hessian_stat_buffer_.Resize(buffer_size);
}
return FitByExistingTree(old_tree, gradients, hessians);
}
@ -513,6 +592,15 @@ void CUDASingleGPUTreeLearner::CheckSplitValid(
}
#endif // DEBUG
void CUDASingleGPUTreeLearner::RenewDiscretizedTreeLeaves(CUDATree* cuda_tree) {
cuda_data_partition_->ReduceLeafGradStat(
gradients_, hessians_, cuda_tree,
cuda_leaf_gradient_stat_buffer_.RawData(),
cuda_leaf_hessian_stat_buffer_.RawData());
LaunchCalcLeafValuesGivenGradStat(cuda_tree, cuda_data_partition_->cuda_data_indices());
SynchronizeCUDADevice(__FILE__, __LINE__);
}
} // namespace LightGBM
#endif // USE_CUDA
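In Train() above, the learner asks the gradient discretizer for per-leaf and per-node histogram bit widths (GetHistBitsInLeaf / GetHistBitsInNode) and records new widths after every split via SetNumBitsInHistogramBin. The motivation is that packed integer histogram bins have limited headroom: a leaf with few samples provably cannot overflow a 16-bit accumulator, while a large leaf needs 32-bit bins. The helper below is a hypothetical illustration of such a rule; the thresholds actually used by CUDAGradientDiscretizer are not reproduced here.

#include <cstdint>

// Hypothetical bit-width chooser (not LightGBM's actual rule): use 16-bit histogram bins
// when the worst-case per-bin sum of discretized gradients is guaranteed to fit, else 32.
inline uint8_t ChooseHistBitsSketch(int64_t num_data_in_leaf, int num_grad_quant_bins) {
  // each discretized gradient has magnitude at most num_grad_quant_bins / 2
  const int64_t max_abs_bin_sum = num_data_in_leaf * static_cast<int64_t>(num_grad_quant_bins / 2);
  return max_abs_bin_sum < (1LL << 15) ? 16 : 32;
}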

View file

@ -129,18 +129,18 @@ void CUDASingleGPUTreeLearner::LaunchReduceLeafStatKernel(
if (num_leaves <= 2048) {
ReduceLeafStatKernel_SharedMemory<<<num_block, CUDA_SINGLE_GPU_TREE_LEARNER_BLOCK_SIZE, 2 * num_leaves * sizeof(double)>>>(
gradients, hessians, num_leaves, num_data, cuda_data_partition_->cuda_data_index_to_leaf_index(),
cuda_leaf_gradient_stat_buffer_, cuda_leaf_hessian_stat_buffer_);
cuda_leaf_gradient_stat_buffer_.RawData(), cuda_leaf_hessian_stat_buffer_.RawData());
} else {
ReduceLeafStatKernel_GlobalMemory<<<num_block, CUDA_SINGLE_GPU_TREE_LEARNER_BLOCK_SIZE>>>(
gradients, hessians, num_leaves, num_data, cuda_data_partition_->cuda_data_index_to_leaf_index(),
cuda_leaf_gradient_stat_buffer_, cuda_leaf_hessian_stat_buffer_);
cuda_leaf_gradient_stat_buffer_.RawData(), cuda_leaf_hessian_stat_buffer_.RawData());
}
const bool use_l1 = config_->lambda_l1 > 0.0f;
const bool use_smoothing = config_->path_smooth > 0.0f;
num_block = (num_leaves + CUDA_SINGLE_GPU_TREE_LEARNER_BLOCK_SIZE - 1) / CUDA_SINGLE_GPU_TREE_LEARNER_BLOCK_SIZE;
#define CalcRefitLeafOutputKernel_ARGS \
num_leaves, cuda_leaf_gradient_stat_buffer_, cuda_leaf_hessian_stat_buffer_, num_data_in_leaf, \
num_leaves, cuda_leaf_gradient_stat_buffer_.RawData(), cuda_leaf_hessian_stat_buffer_.RawData(), num_data_in_leaf, \
leaf_parent, left_child, right_child, \
config_->lambda_l1, config_->lambda_l2, config_->path_smooth, \
shrinkage_rate, config_->refit_decay_rate, cuda_leaf_value
@ -162,6 +162,7 @@ void CUDASingleGPUTreeLearner::LaunchReduceLeafStatKernel(
<<<num_block, CUDA_SINGLE_GPU_TREE_LEARNER_BLOCK_SIZE>>>(CalcRefitLeafOutputKernel_ARGS);
}
}
#undef CalcRefitLeafOutputKernel_ARGS
}
template <typename T, bool IS_INNER>
@ -256,6 +257,37 @@ void CUDASingleGPUTreeLearner::LaunchConstructBitsetForCategoricalSplitKernel(
CUDAConstructBitset<int, false>(best_split_info, num_cat_threshold_, cuda_bitset_, cuda_bitset_len_);
}
void CUDASingleGPUTreeLearner::LaunchCalcLeafValuesGivenGradStat(
CUDATree* cuda_tree, const data_size_t* num_data_in_leaf) {
#define CalcRefitLeafOutputKernel_ARGS \
cuda_tree->num_leaves(), cuda_leaf_gradient_stat_buffer_.RawData(), cuda_leaf_hessian_stat_buffer_.RawData(), num_data_in_leaf, \
cuda_tree->cuda_leaf_parent(), cuda_tree->cuda_left_child(), cuda_tree->cuda_right_child(), \
config_->lambda_l1, config_->lambda_l2, config_->path_smooth, \
1.0f, config_->refit_decay_rate, cuda_tree->cuda_leaf_value_ref()
const bool use_l1 = config_->lambda_l1 > 0.0f;
const bool use_smoothing = config_->path_smooth > 0.0f;
const int num_block = (cuda_tree->num_leaves() + CUDA_SINGLE_GPU_TREE_LEARNER_BLOCK_SIZE - 1) / CUDA_SINGLE_GPU_TREE_LEARNER_BLOCK_SIZE;
if (!use_l1) {
if (!use_smoothing) {
CalcRefitLeafOutputKernel<false, false>
<<<num_block, CUDA_SINGLE_GPU_TREE_LEARNER_BLOCK_SIZE>>>(CalcRefitLeafOutputKernel_ARGS);
} else {
CalcRefitLeafOutputKernel<false, true>
<<<num_block, CUDA_SINGLE_GPU_TREE_LEARNER_BLOCK_SIZE>>>(CalcRefitLeafOutputKernel_ARGS);
}
} else {
if (!use_smoothing) {
CalcRefitLeafOutputKernel<true, false>
<<<num_block, CUDA_SINGLE_GPU_TREE_LEARNER_BLOCK_SIZE>>>(CalcRefitLeafOutputKernel_ARGS);
} else {
CalcRefitLeafOutputKernel<true, true>
<<<num_block, CUDA_SINGLE_GPU_TREE_LEARNER_BLOCK_SIZE>>>(CalcRefitLeafOutputKernel_ARGS);
}
}
#undef CalcRefitLeafOutputKernel_ARGS
}
} // namespace LightGBM
#endif // USE_CUDA
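Both ReduceLeafStat paths and LaunchCalcLeafValuesGivenGradStat feed the accumulated per-leaf gradient/hessian sums into CalcRefitLeafOutputKernel, which applies the usual L1/L2-regularized leaf-output rule. Ignoring path smoothing and the refit decay blending, the per-leaf computation reduces to the sketch below.

// Sketch of the regularized leaf output (path smoothing and refit_decay_rate omitted):
//   leaf_value = -shrinkage * ThresholdL1(sum_grad, lambda_l1) / (sum_hess + lambda_l2)
__host__ __device__ inline double LeafOutputSketch(double sum_gradients, double sum_hessians,
                                                   double lambda_l1, double lambda_l2,
                                                   double shrinkage_rate) {
  double thresholded = 0.0;  // soft-threshold by lambda_l1; a no-op when lambda_l1 == 0
  if (sum_gradients > lambda_l1) {
    thresholded = sum_gradients - lambda_l1;
  } else if (sum_gradients < -lambda_l1) {
    thresholded = sum_gradients + lambda_l1;
  }
  return -shrinkage_rate * thresholded / (sum_hessians + lambda_l2);
}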

View file

@ -16,6 +16,7 @@
#include "cuda_data_partition.hpp"
#include "cuda_best_split_finder.hpp"
#include "cuda_gradient_discretizer.hpp"
#include "../serial_tree_learner.h"
namespace LightGBM {
@ -74,6 +75,10 @@ class CUDASingleGPUTreeLearner: public SerialTreeLearner {
const double sum_left_gradients, const double sum_right_gradients);
#endif // DEBUG
void RenewDiscretizedTreeLeaves(CUDATree* cuda_tree);
void LaunchCalcLeafValuesGivenGradStat(CUDATree* cuda_tree, const data_size_t* num_data_in_leaf);
// GPU device ID
int gpu_device_id_;
// number of threads on CPU
@ -90,6 +95,8 @@ class CUDASingleGPUTreeLearner: public SerialTreeLearner {
std::unique_ptr<CUDAHistogramConstructor> cuda_histogram_constructor_;
// for best split information finding, given the histograms
std::unique_ptr<CUDABestSplitFinder> cuda_best_split_finder_;
// gradient discretizer for quantized training
std::unique_ptr<CUDAGradientDiscretizer> cuda_gradient_discretizer_;
std::vector<int> leaf_best_split_feature_;
std::vector<uint32_t> leaf_best_split_threshold_;
@ -108,8 +115,8 @@ class CUDASingleGPUTreeLearner: public SerialTreeLearner {
std::vector<int> categorical_bin_to_value_;
std::vector<int> categorical_bin_offsets_;
mutable double* cuda_leaf_gradient_stat_buffer_;
mutable double* cuda_leaf_hessian_stat_buffer_;
mutable CUDAVector<double> cuda_leaf_gradient_stat_buffer_;
mutable CUDAVector<double> cuda_leaf_hessian_stat_buffer_;
mutable data_size_t leaf_stat_buffer_size_;
mutable data_size_t refit_num_data_;
uint32_t* cuda_bitset_;