Mirror of https://github.com/microsoft/LightGBM.git
fixed cpplint issues (#2863)
* fixed cpplint errors
* fixed more cpplint errors
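Background for the change: cpplint's readability/check rule flags assertions written as CHECK(a == b) and recommends the dedicated comparison macros (CHECK_EQ, CHECK_LT, CHECK_NOTNULL, and so on), a convention glog-style frameworks also exploit to print both operand values on failure. Below is a minimal, self-contained C++ sketch of how such a macro family can be layered over a plain CHECK. The macro names mirror the ones used in this diff, but the bodies are illustrative assumptions, not LightGBM's actual definitions (those live in its logging header, and the simple sketch below only stringizes the whole expression; a production version would also wrap the if in do { ... } while (0) to avoid dangling-else pitfalls):

#include <cstdio>
#include <cstdlib>

// Hypothetical stand-in for a Log::Fatal-style function: report and abort.
[[noreturn]] inline void FatalCheck(const char* expr, const char* file, int line) {
  std::fprintf(stderr, "Check failed: %s at %s, line %d\n", expr, file, line);
  std::abort();
}

#define CHECK(condition) \
  if (!(condition)) FatalCheck(#condition, __FILE__, __LINE__)

// Comparison forms preferred by cpplint's readability/check rule; the
// extra parentheses keep operator precedence intact for compound arguments.
#define CHECK_EQ(a, b) CHECK((a) == (b))
#define CHECK_NE(a, b) CHECK((a) != (b))
#define CHECK_GE(a, b) CHECK((a) >= (b))
#define CHECK_GT(a, b) CHECK((a) > (b))
#define CHECK_LE(a, b) CHECK((a) <= (b))
#define CHECK_LT(a, b) CHECK((a) < (b))
#define CHECK_NOTNULL(pointer) CHECK((pointer) != nullptr)

int main() {
  int num_feature = 3;
  CHECK_EQ(num_feature, 3);     // passes
  CHECK_NOTNULL(&num_feature);  // passes
  CHECK_LT(num_feature, 2);     // fails: aborts with the stringized expression plus file and line
  return 0;
}

Reproducing the lint pass locally is typically something like "pip install cpplint" followed by "cpplint --recursive include src"; the exact filter set used by LightGBM's CI is not shown in this diff.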
Parent
2215d57114
Commit
d018d30a97
@@ -333,7 +333,7 @@ test_that("lgb.train() works as expected with sparse features", {
     num_obs <- 70000L
     trainDF <- data.frame(
         y = sample(c(0L, 1L), size = num_obs, replace = TRUE)
-        , x = sample(c(1.0:10.0, rep(NA_real_, 50L)), size = num_obs , replace = TRUE)
+        , x = sample(c(1.0:10.0, rep(NA_real_, 50L)), size = num_obs, replace = TRUE)
     )
     dtrain <- lgb.Dataset(
         data = as.matrix(trainDF[["x"]], drop = FALSE)
@@ -32,7 +32,7 @@ class FeatureGroup {
   FeatureGroup(int num_feature, bool is_multi_val,
     std::vector<std::unique_ptr<BinMapper>>* bin_mappers,
     data_size_t num_data) : num_feature_(num_feature), is_multi_val_(is_multi_val), is_sparse_(false) {
-    CHECK(static_cast<int>(bin_mappers->size()) == num_feature);
+    CHECK_EQ(static_cast<int>(bin_mappers->size()), num_feature);
     // use bin at zero to store most_freq_bin
     num_total_bin_ = 1;
     bin_offsets_.emplace_back(num_total_bin_);
@@ -65,7 +65,7 @@ class FeatureGroup {

   FeatureGroup(std::vector<std::unique_ptr<BinMapper>>* bin_mappers,
     data_size_t num_data) : num_feature_(1), is_multi_val_(false) {
-    CHECK(static_cast<int>(bin_mappers->size()) == 1);
+    CHECK_EQ(static_cast<int>(bin_mappers->size()), 1);
     // use bin at zero to store default_bin
     num_total_bin_ = 1;
     bin_offsets_.emplace_back(num_total_bin_);
@@ -509,7 +509,7 @@ inline static std::vector<T> StringToArray(const std::string& str, int n) {
     return std::vector<T>();
   }
   std::vector<std::string> strs = Split(str.c_str(), ' ');
-  CHECK(strs.size() == static_cast<size_t>(n));
+  CHECK_EQ(strs.size(), static_cast<size_t>(n));
   std::vector<T> ret;
   ret.reserve(strs.size());
   __StringToTHelper<T, std::is_floating_point<T>::value> helper;
@@ -41,13 +41,13 @@ GBDT::~GBDT() {

 void GBDT::Init(const Config* config, const Dataset* train_data, const ObjectiveFunction* objective_function,
                 const std::vector<const Metric*>& training_metrics) {
-  CHECK(train_data != nullptr);
+  CHECK_NOTNULL(train_data);
   train_data_ = train_data;
   if (!config->monotone_constraints.empty()) {
-    CHECK(static_cast<size_t>(train_data_->num_total_features()) == config->monotone_constraints.size());
+    CHECK_EQ(static_cast<size_t>(train_data_->num_total_features()), config->monotone_constraints.size());
   }
   if (!config->feature_contri.empty()) {
-    CHECK(static_cast<size_t>(train_data_->num_total_features()) == config->feature_contri.size());
+    CHECK_EQ(static_cast<size_t>(train_data_->num_total_features()), config->feature_contri.size());
   }
   iter_ = 0;
   num_iteration_for_pred_ = 0;
@@ -112,7 +112,7 @@ void GBDT::Init(const Config* config, const Dataset* train_data, const Objective

   class_need_train_ = std::vector<bool>(num_tree_per_iteration_, true);
   if (objective_function_ != nullptr && objective_function_->SkipEmptyClass()) {
-    CHECK(num_tree_per_iteration_ == num_class_);
+    CHECK_EQ(num_tree_per_iteration_, num_class_);
     for (int i = 0; i < num_class_; ++i) {
       class_need_train_[i] = objective_function_->ClassNeedTrain(i);
     }
@@ -277,7 +277,7 @@ void GBDT::RefitTree(const std::vector<std::vector<int>>& tree_leaf_prediction)
     #pragma omp parallel for schedule(static)
     for (int i = 0; i < num_data_; ++i) {
       leaf_pred[i] = tree_leaf_prediction[i][model_index];
-      CHECK(leaf_pred[i] < models_[model_index]->num_leaves());
+      CHECK_LT(leaf_pred[i], models_[model_index]->num_leaves());
     }
     size_t offset = static_cast<size_t>(tree_id) * num_data_;
     auto grad = gradients_.data() + offset;
@@ -654,7 +654,7 @@ void GBDT::ResetTrainingData(const Dataset* train_data, const ObjectiveFunction*
   objective_function_ = objective_function;
   if (objective_function_ != nullptr) {
     is_constant_hessian_ = objective_function_->IsConstantHessian();
-    CHECK(num_tree_per_iteration_ == objective_function_->NumModelPerIteration());
+    CHECK_EQ(num_tree_per_iteration_, objective_function_->NumModelPerIteration());
   } else {
     is_constant_hessian_ = false;
   }
@@ -704,10 +704,10 @@ void GBDT::ResetTrainingData(const Dataset* train_data, const ObjectiveFunction*
 void GBDT::ResetConfig(const Config* config) {
   auto new_config = std::unique_ptr<Config>(new Config(*config));
   if (!config->monotone_constraints.empty()) {
-    CHECK(static_cast<size_t>(train_data_->num_total_features()) == config->monotone_constraints.size());
+    CHECK_EQ(static_cast<size_t>(train_data_->num_total_features()), config->monotone_constraints.size());
   }
   if (!config->feature_contri.empty()) {
-    CHECK(static_cast<size_t>(train_data_->num_total_features()) == config->feature_contri.size());
+    CHECK_EQ(static_cast<size_t>(train_data_->num_total_features()), config->feature_contri.size());
   }
   early_stopping_round_ = new_config->early_stopping_round;
   shrinkage_rate_ = new_config->learning_rate;
@@ -50,7 +50,7 @@ class GOSS: public GBDT {
   }

   void ResetGoss() {
-    CHECK(config_->top_rate + config_->other_rate <= 1.0f);
+    CHECK_LE(config_->top_rate + config_->other_rate, 1.0f);
     CHECK(config_->top_rate > 0.0f && config_->other_rate > 0.0f);
     if (config_->bagging_freq > 0 && config_->bagging_fraction != 1.0f) {
       Log::Fatal("Cannot use bagging in GOSS");
@@ -41,9 +41,9 @@ class RF : public GBDT {
         MultiplyScore(cur_tree_id, 1.0f / num_init_iteration_);
       }
     } else {
-      CHECK(train_data->metadata().init_score() == nullptr);
+      CHECK_EQ(train_data->metadata().init_score(), nullptr);
     }
-    CHECK(num_tree_per_iteration_ == num_class_);
+    CHECK_EQ(num_tree_per_iteration_, num_class_);
     // not shrinkage rate for the RF
     shrinkage_rate_ = 1.0f;
     // only boosting one time
@@ -70,7 +70,7 @@ class RF : public GBDT {
         train_score_updater_->MultiplyScore(1.0f / (iter_ + num_init_iteration_), cur_tree_id);
       }
     }
-    CHECK(num_tree_per_iteration_ == num_class_);
+    CHECK_EQ(num_tree_per_iteration_, num_class_);
     // only boosting one time
     Boosting();
     if (is_use_subset_ && bag_data_cnt_ < num_data_) {
@@ -103,8 +103,8 @@ class RF : public GBDT {
   bool TrainOneIter(const score_t* gradients, const score_t* hessians) override {
     // bagging logic
     Bagging(iter_);
-    CHECK(gradients == nullptr);
-    CHECK(hessians == nullptr);
+    CHECK_EQ(gradients, nullptr);
+    CHECK_EQ(hessians, nullptr);

     gradients = gradients_.data();
     hessians = hessians_.data();
@@ -843,7 +843,7 @@ int LGBM_DatasetCreateFromCSR(const void* indptr,
     auto idx = sample_indices[i];
     auto row = get_row_fun(static_cast<int>(idx));
     for (std::pair<int, double>& inner_data : row) {
-      CHECK(inner_data.first < num_col);
+      CHECK_LT(inner_data.first, num_col);
       if (std::fabs(inner_data.second) > kZeroThreshold || std::isnan(inner_data.second)) {
         sample_values[inner_data.first].emplace_back(inner_data.second);
         sample_idx[inner_data.first].emplace_back(static_cast<int>(i));
@@ -911,7 +911,7 @@ int LGBM_DatasetCreateFromCSRFunc(void* get_row_funptr,
     auto idx = sample_indices[i];
     get_row_fun(static_cast<int>(idx), buffer);
     for (std::pair<int, double>& inner_data : buffer) {
-      CHECK(inner_data.first < num_col);
+      CHECK_LT(inner_data.first, num_col);
       if (std::fabs(inner_data.second) > kZeroThreshold || std::isnan(inner_data.second)) {
         sample_values[inner_data.first].emplace_back(inner_data.second);
         sample_idx[inner_data.first].emplace_back(static_cast<int>(i));
@@ -250,7 +250,7 @@ namespace LightGBM {
     }
     bin_upper_bound.insert(bin_upper_bound.end(), bounds_to_add.begin(), bounds_to_add.end());
     std::stable_sort(bin_upper_bound.begin(), bin_upper_bound.end());
-    CHECK(bin_upper_bound.size() <= static_cast<size_t>(max_bin));
+    CHECK_LE(bin_upper_bound.size(), static_cast<size_t>(max_bin));
     return bin_upper_bound;
   }

@@ -308,7 +308,7 @@ namespace LightGBM {
     } else {
       bin_upper_bound.push_back(std::numeric_limits<double>::infinity());
     }
-    CHECK(bin_upper_bound.size() <= static_cast<size_t>(max_bin));
+    CHECK_LE(bin_upper_bound.size(), static_cast<size_t>(max_bin));
     return bin_upper_bound;
   }

@@ -421,7 +421,7 @@ namespace LightGBM {
         cnt_in_bin[num_bin_ - 1] = na_cnt;
       }
     }
-    CHECK(num_bin_ <= max_bin);
+    CHECK_LE(num_bin_, max_bin);
   } else {
     // convert to int type first
     std::vector<int> distinct_values_int;
@@ -319,7 +319,7 @@ void Dataset::Construct(std::vector<std::unique_ptr<BinMapper>>* bin_mappers,
                         const int* num_per_col, int num_sample_col,
                         size_t total_sample_cnt, const Config& io_config) {
   num_total_features_ = num_total_features;
-  CHECK(num_total_features_ == static_cast<int>(bin_mappers->size()));
+  CHECK_EQ(num_total_features_, static_cast<int>(bin_mappers->size()));
   // get num_features
   std::vector<int> used_features;
   auto& ref_bin_mappers = *bin_mappers;
@@ -775,7 +775,7 @@ void Dataset::ReSize(data_size_t num_data) {
 void Dataset::CopySubrow(const Dataset* fullset,
                          const data_size_t* used_indices,
                          data_size_t num_used_indices, bool need_meta_data) {
-  CHECK(num_used_indices == num_data_);
+  CHECK_EQ(num_used_indices, num_data_);
   OMP_INIT_EX();
   #pragma omp parallel for schedule(static)
   for (int group = 0; group < num_groups_; ++group) {
@@ -1214,7 +1214,7 @@ std::vector<std::vector<double>> DatasetLoader::GetForcedBins(std::string forced
   std::vector<Json> forced_bins_arr = forced_bins_json.array_items();
   for (size_t i = 0; i < forced_bins_arr.size(); ++i) {
     int feature_num = forced_bins_arr[i]["feature"].int_value();
-    CHECK(feature_num < num_total_features);
+    CHECK_LT(feature_num, num_total_features);
     if (categorical_features.count(feature_num)) {
       Log::Warning("Feature %d is categorical. Will ignore forced bins for this feature.", feature_num);
     } else {
@@ -102,7 +102,7 @@ class MultiValDenseBin : public MultiValBin {
                           data_size_t end, const score_t* gradients,
                           const score_t* hessians, hist_t* out) const override {
     ConstructHistogramInner<true, true, false>(data_indices, start, end,
-        gradients, hessians, out);
+                                               gradients, hessians, out);
   }

   void ConstructHistogram(data_size_t start, data_size_t end,
@@ -118,7 +118,7 @@ class MultiValDenseBin : public MultiValBin {
                           const score_t* hessians,
                           hist_t* out) const override {
     ConstructHistogramInner<true, true, true>(data_indices, start, end,
-        gradients, hessians, out);
+                                              gradients, hessians, out);
   }

   MultiValBin* CreateLike(data_size_t num_data, int num_bin, int num_feature, double) const override {
@@ -144,7 +144,7 @@ class MultiValDenseBin : public MultiValBin {
     const auto other_bin =
         reinterpret_cast<const MultiValDenseBin<VAL_T>*>(full_bin);
     if (SUBROW) {
-      CHECK(num_data_ == num_used_indices);
+      CHECK_EQ(num_data_, num_used_indices);
     }
     int n_block = 1;
     data_size_t block_size = num_data_;
@@ -184,21 +184,21 @@ class MultiValDenseBin : public MultiValBin {
   }

   void CopySubcol(const MultiValBin* full_bin,
-      const std::vector<int>& used_feature_index,
-      const std::vector<uint32_t>&,
-      const std::vector<uint32_t>&,
-      const std::vector<uint32_t>& delta) override {
+                  const std::vector<int>& used_feature_index,
+                  const std::vector<uint32_t>&,
+                  const std::vector<uint32_t>&,
+                  const std::vector<uint32_t>& delta) override {
     CopyInner<false, true>(full_bin, nullptr, num_data_, used_feature_index,
                            delta);
   }

   void CopySubrowAndSubcol(const MultiValBin* full_bin,
-      const data_size_t* used_indices,
-      data_size_t num_used_indices,
-      const std::vector<int>& used_feature_index,
-      const std::vector<uint32_t>&,
-      const std::vector<uint32_t>&,
-      const std::vector<uint32_t>& delta) override {
+                           const data_size_t* used_indices,
+                           data_size_t num_used_indices,
+                           const std::vector<int>& used_feature_index,
+                           const std::vector<uint32_t>&,
+                           const std::vector<uint32_t>&,
+                           const std::vector<uint32_t>& delta) override {
     CopyInner<true, true>(full_bin, used_indices, num_used_indices,
                           used_feature_index, delta);
   }
@@ -163,7 +163,7 @@ class MultiValSparseBin : public MultiValBin {
                           data_size_t end, const score_t* gradients,
                           const score_t* hessians, hist_t* out) const override {
     ConstructHistogramInner<true, true, false>(data_indices, start, end,
-        gradients, hessians, out);
+                                               gradients, hessians, out);
   }

   void ConstructHistogram(data_size_t start, data_size_t end,
@@ -179,7 +179,7 @@ class MultiValSparseBin : public MultiValBin {
                           const score_t* hessians,
                           hist_t* out) const override {
     ConstructHistogramInner<true, true, true>(data_indices, start, end,
-        gradients, hessians, out);
+                                              gradients, hessians, out);
   }

   MultiValBin* CreateLike(data_size_t num_data, int num_bin, int,
@@ -219,7 +219,7 @@ class MultiValSparseBin : public MultiValBin {
     const auto other =
         reinterpret_cast<const MultiValSparseBin<INDEX_T, VAL_T>*>(full_bin);
     if (SUBROW) {
-      CHECK(num_data_ == num_used_indices);
+      CHECK_EQ(num_data_, num_used_indices);
     }
     int n_block = 1;
     data_size_t block_size = num_data_;
@@ -160,7 +160,7 @@ LGBM_SE LGBM_DatasetGetFeatureNames_R(LGBM_SE handle,
   int out_len;
   CHECK_CALL(LGBM_DatasetGetFeatureNames(R_GET_PTR(handle),
                                          ptr_names.data(), &out_len));
-  CHECK(len == out_len);
+  CHECK_EQ(len, out_len);
   auto merge_str = Common::Join<char*>(ptr_names, "\t");
   EncodeChar(feature_names, merge_str.c_str(), buf_len, actual_len, merge_str.size() + 1);
   R_API_END();
@@ -455,7 +455,7 @@ LGBM_SE LGBM_BoosterGetEvalNames_R(LGBM_SE handle,
   }
   int out_len;
   CHECK_CALL(LGBM_BoosterGetEvalNames(R_GET_PTR(handle), &out_len, ptr_names.data()));
-  CHECK(out_len == len);
+  CHECK_EQ(out_len, len);
   auto merge_names = Common::Join<char*>(ptr_names, "\t");
   EncodeChar(eval_names, merge_names.c_str(), buf_len, actual_len, merge_names.size() + 1);
   R_API_END();
@@ -471,7 +471,7 @@ LGBM_SE LGBM_BoosterGetEval_R(LGBM_SE handle,
   double* ptr_ret = R_REAL_PTR(out_result);
   int out_len;
   CHECK_CALL(LGBM_BoosterGetEval(R_GET_PTR(handle), R_AS_INT(data_idx), &out_len, ptr_ret));
-  CHECK(out_len == len);
+  CHECK_EQ(out_len, len);
   R_API_END();
 }

@@ -73,8 +73,8 @@ namespace LightGBM {
     if (pos == 0 || pos == static_cast<size_t>(cnt_data - 1)) { \
       return data_reader(sorted_idx[pos]); \
     } \
-    CHECK(threshold >= weighted_cdf[pos - 1]); \
-    CHECK(threshold < weighted_cdf[pos]); \
+    CHECK_GE(threshold, weighted_cdf[pos - 1]); \
+    CHECK_LT(threshold, weighted_cdf[pos]); \
     T v1 = data_reader(sorted_idx[pos - 1]); \
     T v2 = data_reader(sorted_idx[pos]); \
     if (weighted_cdf[pos + 1] - weighted_cdf[pos] >= 1.0f) { \
@@ -750,10 +750,10 @@ void GPUTreeLearner::ResetTrainingDataInner(const Dataset* train_data, bool is_c
 }

 void GPUTreeLearner::ResetIsConstantHessian(bool is_constant_hessian) {
-    if (is_constant_hessian != share_state_->is_constant_hessian) {
+  if (is_constant_hessian != share_state_->is_constant_hessian) {
     SerialTreeLearner::ResetIsConstantHessian(is_constant_hessian);
-      BuildGPUKernels();
-      SetupKernelArguments();
+    BuildGPUKernels();
+    SetupKernelArguments();
   }
 }

@@ -79,14 +79,14 @@ void SerialTreeLearner::GetShareStates(const Dataset* dataset,
         ordered_gradients_.data(), ordered_hessians_.data(), used_feature,
         is_constant_hessian, config_->force_col_wise, config_->force_row_wise));
   } else {
-    CHECK(share_state_ != nullptr);
+    CHECK_NOTNULL(share_state_);
     // cannot change is_hist_col_wise during training
     share_state_.reset(dataset->GetShareStates(
         ordered_gradients_.data(), ordered_hessians_.data(), is_feature_used_,
         is_constant_hessian, share_state_->is_colwise,
         !share_state_->is_colwise));
   }
-  CHECK(share_state_ != nullptr);
+  CHECK_NOTNULL(share_state_);
 }

 void SerialTreeLearner::ResetTrainingDataInner(const Dataset* train_data,
@@ -153,14 +153,14 @@ Tree* SerialTreeLearner::Train(const score_t* gradients, const score_t *hessians
   gradients_ = gradients;
   hessians_ = hessians;
   int num_threads = OMP_NUM_THREADS();
-  if (share_state_->num_threads != num_threads && share_state_->num_threads > 0){
+  if (share_state_->num_threads != num_threads && share_state_->num_threads > 0) {
     Log::Warning(
-        "Detect num_threads changed durning traing (from %d to %d), may cause "
-        "unexpected errors.",
+        "Detected that num_threads changed during training (from %d to %d), "
+        "it may cause unexpected errors.",
         share_state_->num_threads, num_threads);
   }
   share_state_->num_threads = num_threads;

   // some initial works before training
   BeforeTrain();

@@ -206,7 +206,7 @@ Tree* SerialTreeLearner::Train(const score_t* gradients, const score_t *hessians

 Tree* SerialTreeLearner::FitByExistingTree(const Tree* old_tree, const score_t* gradients, const score_t *hessians) const {
   auto tree = std::unique_ptr<Tree>(new Tree(*old_tree));
-  CHECK(data_partition_->num_leaves() >= tree->num_leaves());
+  CHECK_GE(data_partition_->num_leaves(), tree->num_leaves());
   OMP_INIT_EX();
   #pragma omp parallel for schedule(static)
   for (int i = 0; i < tree->num_leaves(); ++i) {
@@ -751,13 +751,12 @@ void SerialTreeLearner::SplitInner(Tree* tree, int best_leaf, int* left_leaf,
                       best_split_info.monotone_type,
                       best_split_info.right_output,
                       best_split_info.left_output);
-
 }

 void SerialTreeLearner::RenewTreeOutput(Tree* tree, const ObjectiveFunction* obj, std::function<double(const label_t*, int)> residual_getter,
                                         data_size_t total_num_data, const data_size_t* bag_indices, data_size_t bag_cnt) const {
   if (obj != nullptr && obj->IsRenewTreeOutput()) {
-    CHECK(tree->num_leaves() <= data_partition_->num_leaves());
+    CHECK_LE(tree->num_leaves(), data_partition_->num_leaves());
     const data_size_t* bag_mapper = nullptr;
     if (total_num_data != num_data_) {
       CHECK_EQ(bag_cnt, num_data_);
@@ -89,7 +89,7 @@ class SerialTreeLearner: public TreeLearner {
     if (tree->num_leaves() <= 1) {
       return;
     }
-    CHECK(tree->num_leaves() <= data_partition_->num_leaves());
+    CHECK_LE(tree->num_leaves(), data_partition_->num_leaves());
     #pragma omp parallel for schedule(static, 1)
     for (int i = 0; i < tree->num_leaves(); ++i) {
       double output = static_cast<double>(tree->LeafOutput(i));