зеркало из https://github.com/microsoft/LightGBM.git
[python] [R-package] Use the same address when updated label/weight/query (#2662)
* Update metadata.cpp * add version for training set, for efficiently update label/weight/... during training. * Update lgb.Booster.R
This commit is contained in:
Родитель
350d56d593
Коммит
82886ba644
|
@ -55,6 +55,7 @@ Booster <- R6::R6Class(
|
|||
|
||||
# Create private booster information
|
||||
private$train_set <- train_set
|
||||
private$train_set_version <- train_set$.__enclos_env__$private$version
|
||||
private$num_dataset <- 1L
|
||||
private$init_predictor <- train_set$.__enclos_env__$private$predictor
|
||||
|
||||
|
@ -207,6 +208,12 @@ Booster <- R6::R6Class(
|
|||
# Perform boosting update iteration
|
||||
update = function(train_set = NULL, fobj = NULL) {
|
||||
|
||||
if (is.null(train_set)) {
|
||||
if (private$train_set$.__enclos_env__$private$version != private$train_set_version) {
|
||||
train_set <- private$train_set
|
||||
}
|
||||
}
|
||||
|
||||
# Check if training set is not null
|
||||
if (!is.null(train_set)) {
|
||||
|
||||
|
@ -230,6 +237,7 @@ Booster <- R6::R6Class(
|
|||
|
||||
# Store private train set
|
||||
private$train_set <- train_set
|
||||
private$train_set_version <- train_set$.__enclos_env__$private$version
|
||||
|
||||
}
|
||||
|
||||
|
@ -497,6 +505,7 @@ Booster <- R6::R6Class(
|
|||
eval_names = NULL,
|
||||
higher_better_inner_eval = NULL,
|
||||
set_objective_to_none = FALSE,
|
||||
train_set_version = 0L,
|
||||
# Predict data
|
||||
inner_predict = function(idx) {
|
||||
|
||||
|
|
|
@ -89,6 +89,7 @@ Dataset <- R6::R6Class(
|
|||
private$free_raw_data <- free_raw_data
|
||||
private$used_indices <- sort(used_indices, decreasing = FALSE)
|
||||
private$info <- info
|
||||
private$version <- 0L
|
||||
|
||||
},
|
||||
|
||||
|
@ -503,6 +504,8 @@ Dataset <- R6::R6Class(
|
|||
, length(info)
|
||||
)
|
||||
|
||||
private$version <- private$version + 1L
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -638,6 +641,7 @@ Dataset <- R6::R6Class(
|
|||
free_raw_data = TRUE,
|
||||
used_indices = NULL,
|
||||
info = NULL,
|
||||
version = 0L,
|
||||
|
||||
# Get handle
|
||||
get_handle = function() {
|
||||
|
|
|
@ -771,6 +771,7 @@ class Dataset(object):
|
|||
self.params_back_up = None
|
||||
self.feature_penalty = None
|
||||
self.monotone_constraints = None
|
||||
self.version = 0
|
||||
|
||||
def __del__(self):
|
||||
try:
|
||||
|
@ -1233,6 +1234,7 @@ class Dataset(object):
|
|||
ptr_data,
|
||||
ctypes.c_int(len(data)),
|
||||
ctypes.c_int(type_data)))
|
||||
self.version += 1
|
||||
return self
|
||||
|
||||
def get_field(self, field_name):
|
||||
|
@ -1740,6 +1742,7 @@ class Booster(object):
|
|||
self.__is_predicted_cur_iter = [False]
|
||||
self.__get_eval_info()
|
||||
self.pandas_categorical = train_set.pandas_categorical
|
||||
self.train_set_version = train_set.version
|
||||
elif model_file is not None:
|
||||
# Prediction task
|
||||
out_num_iterations = ctypes.c_int(0)
|
||||
|
@ -2076,7 +2079,12 @@ class Booster(object):
|
|||
Whether the update was successfully finished.
|
||||
"""
|
||||
# need reset training data
|
||||
if train_set is not None and train_set is not self.train_set:
|
||||
if train_set is None and self.train_set_version != self.train_set.version:
|
||||
train_set = self.train_set
|
||||
is_the_same_train_set = False
|
||||
else:
|
||||
is_the_same_train_set = train_set is self.train_set and self.train_set_version == train_set.version
|
||||
if train_set is not None and not is_the_same_train_set:
|
||||
if not isinstance(train_set, Dataset):
|
||||
raise TypeError('Training data should be Dataset instance, met {}'
|
||||
.format(type(train_set).__name__))
|
||||
|
@ -2088,6 +2096,7 @@ class Booster(object):
|
|||
self.handle,
|
||||
self.train_set.construct().handle))
|
||||
self.__inner_predict_buffer[0] = None
|
||||
self.train_set_version = self.train_set.version
|
||||
is_finished = ctypes.c_int(0)
|
||||
if fobj is None:
|
||||
if self.__set_objective_to_none:
|
||||
|
|
|
@ -290,9 +290,9 @@ void Metadata::SetInitScore(const double* init_score, data_size_t len) {
|
|||
if ((len % num_data_) != 0) {
|
||||
Log::Fatal("Initial score size doesn't match data size");
|
||||
}
|
||||
if (!init_score_.empty()) { init_score_.clear(); }
|
||||
if (init_score_.empty()) { init_score_.resize(len); }
|
||||
num_init_score_ = len;
|
||||
init_score_ = std::vector<double>(len);
|
||||
|
||||
#pragma omp parallel for schedule(static)
|
||||
for (int64_t i = 0; i < num_init_score_; ++i) {
|
||||
init_score_[i] = Common::AvoidInf(init_score[i]);
|
||||
|
@ -308,8 +308,8 @@ void Metadata::SetLabel(const label_t* label, data_size_t len) {
|
|||
if (num_data_ != len) {
|
||||
Log::Fatal("Length of label is not same with #data");
|
||||
}
|
||||
if (!label_.empty()) { label_.clear(); }
|
||||
label_ = std::vector<label_t>(num_data_);
|
||||
if (label_.empty()) { label_.resize(num_data_); }
|
||||
|
||||
#pragma omp parallel for schedule(static)
|
||||
for (data_size_t i = 0; i < num_data_; ++i) {
|
||||
label_[i] = Common::AvoidInf(label[i]);
|
||||
|
@ -327,9 +327,9 @@ void Metadata::SetWeights(const label_t* weights, data_size_t len) {
|
|||
if (num_data_ != len) {
|
||||
Log::Fatal("Length of weights is not same with #data");
|
||||
}
|
||||
if (!weights_.empty()) { weights_.clear(); }
|
||||
if (weights_.empty()) { weights_.resize(num_data_); }
|
||||
num_weights_ = num_data_;
|
||||
weights_ = std::vector<label_t>(num_weights_);
|
||||
|
||||
#pragma omp parallel for schedule(static)
|
||||
for (data_size_t i = 0; i < num_weights_; ++i) {
|
||||
weights_[i] = Common::AvoidInf(weights[i]);
|
||||
|
@ -354,9 +354,8 @@ void Metadata::SetQuery(const data_size_t* query, data_size_t len) {
|
|||
if (num_data_ != sum) {
|
||||
Log::Fatal("Sum of query counts is not same with #data");
|
||||
}
|
||||
if (!query_boundaries_.empty()) { query_boundaries_.clear(); }
|
||||
num_queries_ = len;
|
||||
query_boundaries_ = std::vector<data_size_t>(num_queries_ + 1);
|
||||
query_boundaries_.resize(num_queries_ + 1);
|
||||
query_boundaries_[0] = 0;
|
||||
for (data_size_t i = 0; i < num_queries_; ++i) {
|
||||
query_boundaries_[i + 1] = query_boundaries_[i] + query[i];
|
||||
|
|
Загрузка…
Ссылка в новой задаче