[python] [R-package] Use the same address when updated label/weight/query (#2662)

* Update metadata.cpp

* add a version counter for the training set, to efficiently update label/weight/... during training.

* Update lgb.Booster.R
Guolin Ke 2020-01-14 18:44:57 +08:00 committed by GitHub
Parent 350d56d593
Commit 82886ba644
4 changed files with 30 additions and 9 deletions
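The intended workflow, sketched below (illustrative example; the data, parameters, and loop are made up, but Dataset.set_label and Booster.update are the APIs touched by this change):

import numpy as np
import lightgbm as lgb

X = np.random.rand(1000, 10)
train_set = lgb.Dataset(X, label=np.random.rand(1000))
booster = lgb.Booster(params={"objective": "regression", "verbose": -1},
                      train_set=train_set)

for i in range(10):
    booster.update()  # ordinary boosting round
    if i == 4:
        # Mutate the *same* Dataset in place; this bumps its version counter,
        # so the next update() resets the training data instead of training on
        # the stale labels, and no new Dataset has to be built.
        train_set.set_label(np.random.rand(1000))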

R-package/R/lgb.Booster.R

@@ -55,6 +55,7 @@ Booster <- R6::R6Class(
# Create private booster information
private$train_set <- train_set
private$train_set_version <- train_set$.__enclos_env__$private$version
private$num_dataset <- 1L
private$init_predictor <- train_set$.__enclos_env__$private$predictor
@@ -207,6 +208,12 @@ Booster <- R6::R6Class(
# Perform boosting update iteration
update = function(train_set = NULL, fobj = NULL) {
if (is.null(train_set)) {
if (private$train_set$.__enclos_env__$private$version != private$train_set_version) {
train_set <- private$train_set
}
}
# Check if training set is not null
if (!is.null(train_set)) {
@@ -230,6 +237,7 @@ Booster <- R6::R6Class(
# Store private train set
private$train_set <- train_set
private$train_set_version <- train_set$.__enclos_env__$private$version
}
@@ -497,6 +505,7 @@ Booster <- R6::R6Class(
eval_names = NULL,
higher_better_inner_eval = NULL,
set_objective_to_none = FALSE,
train_set_version = 0L,
# Predict data
inner_predict = function(idx) {

R-package/R/lgb.Dataset.R

@@ -89,6 +89,7 @@ Dataset <- R6::R6Class(
private$free_raw_data <- free_raw_data
private$used_indices <- sort(used_indices, decreasing = FALSE)
private$info <- info
private$version <- 0L
},
@@ -503,6 +504,8 @@ Dataset <- R6::R6Class(
, length(info)
)
private$version <- private$version + 1L
}
}
@@ -638,6 +641,7 @@ Dataset <- R6::R6Class(
free_raw_data = TRUE,
used_indices = NULL,
info = NULL,
version = 0L,
# Get handle
get_handle = function() {

python-package/lightgbm/basic.py

@@ -771,6 +771,7 @@ class Dataset(object):
self.params_back_up = None
self.feature_penalty = None
self.monotone_constraints = None
self.version = 0
def __del__(self):
try:
@@ -1233,6 +1234,7 @@ class Dataset(object):
ptr_data,
ctypes.c_int(len(data)),
ctypes.c_int(type_data)))
self.version += 1
return self
def get_field(self, field_name):
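Per the hunk above, every metadata write that goes through set_field bumps the counter. A small sketch of the observable effect (assuming a constructed Dataset; set_label forwards to set_field once the handle exists):

import numpy as np
import lightgbm as lgb

ds = lgb.Dataset(np.random.rand(200, 5), label=np.zeros(200)).construct()
v = ds.version
ds.set_label(np.ones(200))   # forwards to set_field, which does version += 1
assert ds.version == v + 1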
@@ -1740,6 +1742,7 @@ class Booster(object):
self.__is_predicted_cur_iter = [False]
self.__get_eval_info()
self.pandas_categorical = train_set.pandas_categorical
self.train_set_version = train_set.version
elif model_file is not None:
# Prediction task
out_num_iterations = ctypes.c_int(0)
@@ -2076,7 +2079,12 @@
             Whether the update was successfully finished.
         """
         # need reset training data
-        if train_set is not None and train_set is not self.train_set:
+        if train_set is None and self.train_set_version != self.train_set.version:
+            train_set = self.train_set
+            is_the_same_train_set = False
+        else:
+            is_the_same_train_set = train_set is self.train_set and self.train_set_version == train_set.version
+        if train_set is not None and not is_the_same_train_set:
             if not isinstance(train_set, Dataset):
                 raise TypeError('Training data should be Dataset instance, met {}'
                                 .format(type(train_set).__name__))
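Distilled from the hunk above, the reset decision now behaves roughly as follows (a readability paraphrase, not the actual method body):

def needs_reset(booster, train_set):
    if train_set is None:
        # No explicit dataset passed: reset only if the cached dataset was
        # mutated, i.e. its version moved past the one the booster remembers.
        return booster.train_set_version != booster.train_set.version
    # An explicit dataset was passed: skip the reset only when it is the very
    # same object and its metadata has not changed since the booster last saw it.
    is_the_same_train_set = (train_set is booster.train_set
                             and booster.train_set_version == train_set.version)
    return not is_the_same_train_set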
@@ -2088,6 +2096,7 @@
self.handle,
self.train_set.construct().handle))
self.__inner_predict_buffer[0] = None
self.train_set_version = self.train_set.version
is_finished = ctypes.c_int(0)
if fobj is None:
if self.__set_objective_to_none:

src/io/metadata.cpp

@@ -290,9 +290,9 @@ void Metadata::SetInitScore(const double* init_score, data_size_t len) {
   if ((len % num_data_) != 0) {
     Log::Fatal("Initial score size doesn't match data size");
   }
-  if (!init_score_.empty()) { init_score_.clear(); }
+  if (init_score_.empty()) { init_score_.resize(len); }
   num_init_score_ = len;
-  init_score_ = std::vector<double>(len);
   #pragma omp parallel for schedule(static)
   for (int64_t i = 0; i < num_init_score_; ++i) {
     init_score_[i] = Common::AvoidInf(init_score[i]);
@@ -308,8 +308,8 @@ void Metadata::SetLabel(const label_t* label, data_size_t len) {
   if (num_data_ != len) {
     Log::Fatal("Length of label is not same with #data");
   }
-  if (!label_.empty()) { label_.clear(); }
-  label_ = std::vector<label_t>(num_data_);
+  if (label_.empty()) { label_.resize(num_data_); }
   #pragma omp parallel for schedule(static)
   for (data_size_t i = 0; i < num_data_; ++i) {
     label_[i] = Common::AvoidInf(label[i]);
@@ -327,9 +327,9 @@ void Metadata::SetWeights(const label_t* weights, data_size_t len) {
   if (num_data_ != len) {
     Log::Fatal("Length of weights is not same with #data");
   }
-  if (!weights_.empty()) { weights_.clear(); }
+  if (weights_.empty()) { weights_.resize(num_data_); }
   num_weights_ = num_data_;
-  weights_ = std::vector<label_t>(num_weights_);
   #pragma omp parallel for schedule(static)
   for (data_size_t i = 0; i < num_weights_; ++i) {
     weights_[i] = Common::AvoidInf(weights[i]);
@@ -354,9 +354,8 @@ void Metadata::SetQuery(const data_size_t* query, data_size_t len) {
   if (num_data_ != sum) {
     Log::Fatal("Sum of query counts is not same with #data");
   }
-  if (!query_boundaries_.empty()) { query_boundaries_.clear(); }
   num_queries_ = len;
-  query_boundaries_ = std::vector<data_size_t>(num_queries_ + 1);
+  query_boundaries_.resize(num_queries_ + 1);
   query_boundaries_[0] = 0;
   for (data_size_t i = 0; i < num_queries_; ++i) {
     query_boundaries_[i + 1] = query_boundaries_[i] + query[i];
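This C++ change is what the PR title refers to: instead of clearing the old vector and assigning a freshly allocated one, the buffers are resize()d, so repeated SetLabel/SetWeights/SetQuery calls of the same length keep writing to the same address. A rough analogy in Python (illustrative only, not LightGBM code):

import numpy as np

labels = np.zeros(1000, dtype=np.float32)
addr = labels.ctypes.data

# New behaviour, roughly: overwrite the existing storage in place.
labels[:] = np.random.rand(1000)
assert labels.ctypes.data == addr        # same buffer, same address

# Old behaviour, roughly: drop the storage and allocate a new buffer.
labels = np.random.rand(1000).astype(np.float32)
assert labels.ctypes.data != addr        # a different allocation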