[R-package] move creation of character vectors in some methods to C++ side (#4256)

* [R-package] move creation of character vectors in some methods to C++ side

* convert LGBM_BoosterGetEvalNames_R

* convert LGBM_BoosterDumpModel_R and LGBM_BoosterSaveModelToString_R

* remove debugging code

* update docs

* remove comment

* add handling for larger model strings

* handle large strings in feature and eval names

* got long feature names working

* more fixes

* linting

* resize

* Apply suggestions from code review

Co-authored-by: Nikita Titov <nekit94-08@mail.ru>

* stricter test

Co-authored-by: Nikita Titov <nekit94-08@mail.ru>
This commit is contained in:
James Lamb 2021-05-09 17:28:58 -05:00 коммит произвёл GitHub
Родитель a421217e4e
Коммит c1d2dbe2c5
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
9 изменённых файлов: 195 добавлений и 194 удалений

Просмотреть файл

@ -466,40 +466,14 @@ Booster <- R6::R6Class(
num_iteration <- self$best_iter
}
# Create buffer
buf_len <- as.integer(1024L * 1024L)
act_len <- 0L
buf <- raw(buf_len)
# Call buffer
.Call(
model_str <- .Call(
LGBM_BoosterSaveModelToString_R
, private$handle
, as.integer(num_iteration)
, as.integer(feature_importance_type)
, buf_len
, act_len
, buf
)
# Check for buffer content
if (act_len > buf_len) {
buf_len <- act_len
buf <- raw(buf_len)
.Call(
LGBM_BoosterSaveModelToString_R
, private$handle
, as.integer(num_iteration)
, as.integer(feature_importance_type)
, buf_len
, act_len
, buf
)
}
return(
lgb.encode.char(arr = buf, len = act_len)
)
return(model_str)
},
@ -511,36 +485,14 @@ Booster <- R6::R6Class(
num_iteration <- self$best_iter
}
buf_len <- as.integer(1024L * 1024L)
act_len <- 0L
buf <- raw(buf_len)
.Call(
model_str <- .Call(
LGBM_BoosterDumpModel_R
, private$handle
, as.integer(num_iteration)
, as.integer(feature_importance_type)
, buf_len
, act_len
, buf
)
if (act_len > buf_len) {
buf_len <- act_len
buf <- raw(buf_len)
.Call(
LGBM_BoosterDumpModel_R
, private$handle
, as.integer(num_iteration)
, as.integer(feature_importance_type)
, buf_len
, act_len
, buf
)
}
return(
lgb.encode.char(arr = buf, len = act_len)
)
return(model_str)
},
@ -666,41 +618,20 @@ Booster <- R6::R6Class(
# Check for evaluation names emptiness
if (is.null(private$eval_names)) {
# Get evaluation names
buf_len <- as.integer(1024L * 1024L)
act_len <- 0L
buf <- raw(buf_len)
.Call(
eval_names <- .Call(
LGBM_BoosterGetEvalNames_R
, private$handle
, buf_len
, act_len
, buf
)
if (act_len > buf_len) {
buf_len <- act_len
buf <- raw(buf_len)
.Call(
LGBM_BoosterGetEvalNames_R
, private$handle
, buf_len
, act_len
, buf
)
}
names <- lgb.encode.char(arr = buf, len = act_len)
# Check names' length
if (nchar(names) > 0L) {
if (length(eval_names) > 0L) {
# Parse and store privately names
names <- strsplit(names, "\t")[[1L]]
private$eval_names <- names
private$eval_names <- eval_names
# some metrics don't map cleanly to metric names, for example "ndcg@1" is just the
# ndcg metric evaluated at the first "query result" in learning-to-rank
metric_names <- gsub("@.*", "", names)
metric_names <- gsub("@.*", "", eval_names)
private$higher_better_inner_eval <- .METRICS_HIGHER_BETTER()[metric_names]
}

Просмотреть файл

@ -369,31 +369,10 @@ Dataset <- R6::R6Class(
# Check for handle
if (!lgb.is.null.handle(x = private$handle)) {
# Get feature names and write them
buf_len <- as.integer(1024L * 1024L)
act_len <- 0L
buf <- raw(buf_len)
.Call(
private$colnames <- .Call(
LGBM_DatasetGetFeatureNames_R
, private$handle
, buf_len
, act_len
, buf
)
if (act_len > buf_len) {
buf_len <- act_len
buf <- raw(buf_len)
.Call(
LGBM_DatasetGetFeatureNames_R
, private$handle
, buf_len
, act_len
, buf
)
}
cnames <- lgb.encode.char(arr = buf, len = act_len)
private$colnames <- as.character(base::strsplit(cnames, "\t")[[1L]])
return(private$colnames)
} else if (is.matrix(private$raw_data) || methods::is(private$raw_data, "dgCMatrix")) {

Просмотреть файл

@ -18,13 +18,6 @@ lgb.is.null.handle <- function(x) {
return(is.null(x) || is.na(x))
}
lgb.encode.char <- function(arr, len) {
if (!is.raw(arr)) {
stop("lgb.encode.char: Can only encode from raw type")
}
return(rawToChar(arr[seq_len(len)]))
}
# [description] Get the most recent error stored on the C++ side and raise it
# as an R error.
lgb.last_error <- function() {

Просмотреть файл

@ -96,8 +96,6 @@ typedef union { VECTOR_SER s; double align; } SEXPREC_ALIGN;
#define DATAPTR(x) ((reinterpret_cast<SEXPREC_ALIGN*>(x)) + 1)
#define R_CHAR_PTR(x) (reinterpret_cast<char*>DATAPTR(x))
#define R_IS_NULL(x) ((*reinterpret_cast<LGBM_SE>(x)).sxpinfo.type == 0)
// 64bit pointer

Просмотреть файл

@ -39,23 +39,9 @@
return R_NilValue; \
}
using LightGBM::Common::Join;
using LightGBM::Common::Split;
using LightGBM::Log;
LGBM_SE EncodeChar(LGBM_SE dest, const char* src, SEXP buf_len, SEXP actual_len, size_t str_len) {
if (str_len > INT32_MAX) {
Log::Fatal("Don't support large string in R-package");
}
INTEGER(actual_len)[0] = static_cast<int>(str_len);
if (Rf_asInteger(buf_len) < static_cast<int>(str_len)) {
return dest;
}
auto ptr = R_CHAR_PTR(dest);
std::memcpy(ptr, src, str_len);
return dest;
}
SEXP LGBM_GetLastError_R() {
SEXP out;
out = PROTECT(Rf_allocVector(STRSXP, 1));
@ -153,10 +139,8 @@ SEXP LGBM_DatasetSetFeatureNames_R(LGBM_SE handle,
R_API_END();
}
SEXP LGBM_DatasetGetFeatureNames_R(LGBM_SE handle,
SEXP buf_len,
SEXP actual_len,
LGBM_SE feature_names) {
SEXP LGBM_DatasetGetFeatureNames_R(LGBM_SE handle) {
SEXP feature_names;
R_API_BEGIN();
int len = 0;
CHECK_CALL(LGBM_DatasetGetNumFeature(R_GET_PTR(handle), &len));
@ -175,10 +159,29 @@ SEXP LGBM_DatasetGetFeatureNames_R(LGBM_SE handle,
len, &out_len,
reserved_string_size, &required_string_size,
ptr_names.data()));
// if any feature names were larger than allocated size,
// allow for a larger size and try again
if (required_string_size > reserved_string_size) {
for (int i = 0; i < len; ++i) {
names[i].resize(required_string_size);
ptr_names[i] = names[i].data();
}
CHECK_CALL(
LGBM_DatasetGetFeatureNames(
R_GET_PTR(handle),
len,
&out_len,
required_string_size,
&required_string_size,
ptr_names.data()));
}
CHECK_EQ(len, out_len);
CHECK_GE(reserved_string_size, required_string_size);
auto merge_str = Join<char*>(ptr_names, "\t");
EncodeChar(feature_names, merge_str.c_str(), buf_len, actual_len, merge_str.size() + 1);
feature_names = PROTECT(Rf_allocVector(STRSXP, len));
for (int i = 0; i < len; ++i) {
SET_STRING_ELT(feature_names, i, Rf_mkChar(ptr_names[i]));
}
UNPROTECT(1);
return feature_names;
R_API_END();
}
@ -432,10 +435,8 @@ SEXP LGBM_BoosterGetLowerBoundValue_R(LGBM_SE handle,
R_API_END();
}
SEXP LGBM_BoosterGetEvalNames_R(LGBM_SE handle,
SEXP buf_len,
SEXP actual_len,
LGBM_SE eval_names) {
SEXP LGBM_BoosterGetEvalNames_R(LGBM_SE handle) {
SEXP eval_names;
R_API_BEGIN();
int len;
CHECK_CALL(LGBM_BoosterGetEvalCounts(R_GET_PTR(handle), &len));
@ -456,10 +457,29 @@ SEXP LGBM_BoosterGetEvalNames_R(LGBM_SE handle,
len, &out_len,
reserved_string_size, &required_string_size,
ptr_names.data()));
// if any eval names were larger than allocated size,
// allow for a larger size and try again
if (required_string_size > reserved_string_size) {
for (int i = 0; i < len; ++i) {
names[i].resize(required_string_size);
ptr_names[i] = names[i].data();
}
CHECK_CALL(
LGBM_BoosterGetEvalNames(
R_GET_PTR(handle),
len,
&out_len,
required_string_size,
&required_string_size,
ptr_names.data()));
}
CHECK_EQ(out_len, len);
CHECK_GE(reserved_string_size, required_string_size);
auto merge_names = Join<char*>(ptr_names, "\t");
EncodeChar(eval_names, merge_names.c_str(), buf_len, actual_len, merge_names.size() + 1);
eval_names = PROTECT(Rf_allocVector(STRSXP, len));
for (int i = 0; i < len; ++i) {
SET_STRING_ELT(eval_names, i, Rf_mkChar(ptr_names[i]));
}
UNPROTECT(1);
return eval_names;
R_API_END();
}
@ -616,31 +636,47 @@ SEXP LGBM_BoosterSaveModel_R(LGBM_SE handle,
SEXP LGBM_BoosterSaveModelToString_R(LGBM_SE handle,
SEXP num_iteration,
SEXP feature_importance_type,
SEXP buffer_len,
SEXP actual_len,
LGBM_SE out_str) {
SEXP feature_importance_type) {
SEXP model_str;
R_API_BEGIN();
int64_t out_len = 0;
int64_t buf_len = static_cast<int64_t>(Rf_asInteger(buffer_len));
int64_t buf_len = 1024 * 1024;
int64_t num_iter = Rf_asInteger(num_iteration);
int64_t importance_type = Rf_asInteger(feature_importance_type);
std::vector<char> inner_char_buf(buf_len);
CHECK_CALL(LGBM_BoosterSaveModelToString(R_GET_PTR(handle), 0, Rf_asInteger(num_iteration), Rf_asInteger(feature_importance_type), buf_len, &out_len, inner_char_buf.data()));
EncodeChar(out_str, inner_char_buf.data(), buffer_len, actual_len, static_cast<size_t>(out_len));
CHECK_CALL(LGBM_BoosterSaveModelToString(R_GET_PTR(handle), 0, num_iter, importance_type, buf_len, &out_len, inner_char_buf.data()));
// if the model string was larger than the initial buffer, allocate a bigger buffer and try again
if (out_len > buf_len) {
inner_char_buf.resize(out_len);
CHECK_CALL(LGBM_BoosterSaveModelToString(R_GET_PTR(handle), 0, num_iter, importance_type, out_len, &out_len, inner_char_buf.data()));
}
model_str = PROTECT(Rf_allocVector(STRSXP, 1));
SET_STRING_ELT(model_str, 0, Rf_mkChar(inner_char_buf.data()));
UNPROTECT(1);
return model_str;
R_API_END();
}
SEXP LGBM_BoosterDumpModel_R(LGBM_SE handle,
SEXP num_iteration,
SEXP feature_importance_type,
SEXP buffer_len,
SEXP actual_len,
LGBM_SE out_str) {
SEXP feature_importance_type) {
SEXP model_str;
R_API_BEGIN();
int64_t out_len = 0;
int64_t buf_len = static_cast<int64_t>(Rf_asInteger(buffer_len));
int64_t buf_len = 1024 * 1024;
int64_t num_iter = Rf_asInteger(num_iteration);
int64_t importance_type = Rf_asInteger(feature_importance_type);
std::vector<char> inner_char_buf(buf_len);
CHECK_CALL(LGBM_BoosterDumpModel(R_GET_PTR(handle), 0, Rf_asInteger(num_iteration), Rf_asInteger(feature_importance_type), buf_len, &out_len, inner_char_buf.data()));
EncodeChar(out_str, inner_char_buf.data(), buffer_len, actual_len, static_cast<size_t>(out_len));
CHECK_CALL(LGBM_BoosterDumpModel(R_GET_PTR(handle), 0, num_iter, importance_type, buf_len, &out_len, inner_char_buf.data()));
// if the model string was larger than the initial buffer, allocate a bigger buffer and try again
if (out_len > buf_len) {
inner_char_buf.resize(out_len);
CHECK_CALL(LGBM_BoosterDumpModel(R_GET_PTR(handle), 0, num_iter, importance_type, out_len, &out_len, inner_char_buf.data()));
}
model_str = PROTECT(Rf_allocVector(STRSXP, 1));
SET_STRING_ELT(model_str, 0, Rf_mkChar(inner_char_buf.data()));
UNPROTECT(1);
return model_str;
R_API_END();
}
@ -652,7 +688,7 @@ static const R_CallMethodDef CallEntries[] = {
{"LGBM_DatasetCreateFromMat_R" , (DL_FUNC) &LGBM_DatasetCreateFromMat_R , 6},
{"LGBM_DatasetGetSubset_R" , (DL_FUNC) &LGBM_DatasetGetSubset_R , 5},
{"LGBM_DatasetSetFeatureNames_R" , (DL_FUNC) &LGBM_DatasetSetFeatureNames_R , 2},
{"LGBM_DatasetGetFeatureNames_R" , (DL_FUNC) &LGBM_DatasetGetFeatureNames_R , 4},
{"LGBM_DatasetGetFeatureNames_R" , (DL_FUNC) &LGBM_DatasetGetFeatureNames_R , 1},
{"LGBM_DatasetSaveBinary_R" , (DL_FUNC) &LGBM_DatasetSaveBinary_R , 2},
{"LGBM_DatasetFree_R" , (DL_FUNC) &LGBM_DatasetFree_R , 1},
{"LGBM_DatasetSetField_R" , (DL_FUNC) &LGBM_DatasetSetField_R , 4},
@ -676,7 +712,7 @@ static const R_CallMethodDef CallEntries[] = {
{"LGBM_BoosterGetCurrentIteration_R", (DL_FUNC) &LGBM_BoosterGetCurrentIteration_R, 2},
{"LGBM_BoosterGetUpperBoundValue_R" , (DL_FUNC) &LGBM_BoosterGetUpperBoundValue_R , 2},
{"LGBM_BoosterGetLowerBoundValue_R" , (DL_FUNC) &LGBM_BoosterGetLowerBoundValue_R , 2},
{"LGBM_BoosterGetEvalNames_R" , (DL_FUNC) &LGBM_BoosterGetEvalNames_R , 4},
{"LGBM_BoosterGetEvalNames_R" , (DL_FUNC) &LGBM_BoosterGetEvalNames_R , 1},
{"LGBM_BoosterGetEval_R" , (DL_FUNC) &LGBM_BoosterGetEval_R , 3},
{"LGBM_BoosterGetNumPredict_R" , (DL_FUNC) &LGBM_BoosterGetNumPredict_R , 3},
{"LGBM_BoosterGetPredict_R" , (DL_FUNC) &LGBM_BoosterGetPredict_R , 3},
@ -685,8 +721,8 @@ static const R_CallMethodDef CallEntries[] = {
{"LGBM_BoosterPredictForCSC_R" , (DL_FUNC) &LGBM_BoosterPredictForCSC_R , 14},
{"LGBM_BoosterPredictForMat_R" , (DL_FUNC) &LGBM_BoosterPredictForMat_R , 11},
{"LGBM_BoosterSaveModel_R" , (DL_FUNC) &LGBM_BoosterSaveModel_R , 4},
{"LGBM_BoosterSaveModelToString_R" , (DL_FUNC) &LGBM_BoosterSaveModelToString_R , 6},
{"LGBM_BoosterDumpModel_R" , (DL_FUNC) &LGBM_BoosterDumpModel_R , 6},
{"LGBM_BoosterSaveModelToString_R" , (DL_FUNC) &LGBM_BoosterSaveModelToString_R , 3},
{"LGBM_BoosterDumpModel_R" , (DL_FUNC) &LGBM_BoosterDumpModel_R , 3},
{NULL, NULL, 0}
};

Просмотреть файл

@ -101,7 +101,7 @@ LIGHTGBM_C_EXPORT SEXP LGBM_DatasetGetSubset_R(
* \brief save feature names to Dataset
* \param handle handle
* \param feature_names feature names
* \return R NULL value
* \return R character vector of feature names
*/
LIGHTGBM_C_EXPORT SEXP LGBM_DatasetSetFeatureNames_R(
LGBM_SE handle,
@ -109,16 +109,12 @@ LIGHTGBM_C_EXPORT SEXP LGBM_DatasetSetFeatureNames_R(
);
/*!
* \brief save feature names to Dataset
* \param handle handle
* \param feature_names feature names
* \return 0 when succeed, -1 when failure happens
* \brief get feature names from Dataset
* \param handle Dataset handle
* \return an R character vector with feature names from the Dataset or NULL if no feature names
*/
LIGHTGBM_C_EXPORT SEXP LGBM_DatasetGetFeatureNames_R(
LGBM_SE handle,
SEXP buf_len,
SEXP actual_len,
LGBM_SE feature_names
LGBM_SE handle
);
/*!
@ -387,15 +383,12 @@ LIGHTGBM_C_EXPORT SEXP LGBM_BoosterGetLowerBoundValue_R(
);
/*!
* \brief Get Name of eval
* \param eval_names eval names
* \return 0 when succeed, -1 when failure happens
* \brief Get names of eval metrics
* \param handle Handle of booster
* \return R character vector with names of eval metrics
*/
LIGHTGBM_C_EXPORT SEXP LGBM_BoosterGetEvalNames_R(
LGBM_SE handle,
SEXP buf_len,
SEXP actual_len,
LGBM_SE eval_names
LGBM_SE handle
);
/*!
@ -583,34 +576,28 @@ LIGHTGBM_C_EXPORT SEXP LGBM_BoosterSaveModel_R(
/*!
* \brief create string containing model
* \param handle handle
* \param handle Booster handle
* \param num_iteration, <= 0 means save all
* \param out_str string of model
* \return 0 when succeed, -1 when failure happens
* \param feature_importance_type type of feature importance, 0: split, 1: gain
* \return R character vector (length=1) with model string
*/
LIGHTGBM_C_EXPORT SEXP LGBM_BoosterSaveModelToString_R(
LGBM_SE handle,
SEXP num_iteration,
SEXP feature_importance_type,
SEXP buffer_len,
SEXP actual_len,
LGBM_SE out_str
SEXP feature_importance_type
);
/*!
* \brief dump model to json
* \param handle handle
* \brief dump model to JSON
* \param handle Booster handle
* \param num_iteration, <= 0 means save all
* \param out_str json format string of model
* \return 0 when succeed, -1 when failure happens
* \param feature_importance_type type of feature importance, 0: split, 1: gain
* \return R character vector (length=1) with model JSON
*/
LIGHTGBM_C_EXPORT SEXP LGBM_BoosterDumpModel_R(
LGBM_SE handle,
SEXP num_iteration,
SEXP feature_importance_type,
SEXP buffer_len,
SEXP actual_len,
LGBM_SE out_str
SEXP feature_importance_type
);
#endif // LIGHTGBM_R_H_

Просмотреть файл

@ -278,3 +278,21 @@ test_that("lgb.Dataset: should be able to run lgb.cv() immediately after using l
expect_is(bst, "lgb.CVBooster")
})
test_that("lgb.Dataset: should be able to use and retrieve long feature names", {
# set one feature to a value longer than the default buffer size used
# in LGBM_DatasetGetFeatureNames_R
feature_names <- names(iris)
long_name <- paste0(rep("a", 1000L), collapse = "")
feature_names[1L] <- long_name
names(iris) <- feature_names
# check that feature name survived the trip from R to C++ and back
dtrain <- lgb.Dataset(
data = as.matrix(iris[, -5L])
, label = as.numeric(iris$Species) - 1L
)
dtrain$construct()
col_names <- dtrain$get_colnames()
expect_equal(col_names[1L], long_name)
expect_equal(nchar(col_names[1L]), 1000L)
})

Просмотреть файл

@ -240,6 +240,74 @@ test_that("Loading a Booster from a string works", {
expect_identical(pred, pred2)
})
test_that("Saving a large model to string should work", {
set.seed(708L)
data(agaricus.train, package = "lightgbm")
train <- agaricus.train
bst <- lightgbm(
data = as.matrix(train$data)
, label = train$label
, num_leaves = 100L
, learning_rate = 0.01
, nrounds = 500L
, objective = "binary"
, save_name = tempfile(fileext = ".model")
, verbose = -1L
)
pred <- predict(bst, train$data)
pred_leaf_indx <- predict(bst, train$data, predleaf = TRUE)
pred_raw_score <- predict(bst, train$data, rawscore = TRUE)
model_string <- bst$save_model_to_string()
# make sure this test is still producing a model bigger than the default
# buffer size used in LGBM_BoosterSaveModelToString_R
expect_gt(nchar(model_string), 1024L * 1024L)
# finalize the booster and destroy it so you know we aren't cheating
bst$finalize()
expect_null(bst$.__enclos_env__$private$handle)
rm(bst)
# make sure a new model can be created from this string, and that it
# produces expected results
bst2 <- lgb.load(
model_str = model_string
)
pred2 <- predict(bst2, train$data)
pred2_leaf_indx <- predict(bst2, train$data, predleaf = TRUE)
pred2_raw_score <- predict(bst2, train$data, rawscore = TRUE)
expect_identical(pred, pred2)
expect_identical(pred_leaf_indx, pred2_leaf_indx)
expect_identical(pred_raw_score, pred2_raw_score)
})
test_that("Saving a large model to JSON should work", {
set.seed(708L)
data(agaricus.train, package = "lightgbm")
train <- agaricus.train
bst <- lightgbm(
data = as.matrix(train$data)
, label = train$label
, num_leaves = 100L
, learning_rate = 0.01
, nrounds = 200L
, objective = "binary"
, save_name = tempfile(fileext = ".model")
, verbose = -1L
)
model_json <- bst$dump_model()
# make sure this test is still producing a model bigger than the default
# buffer size used in LGBM_BoosterDumpModel_R
expect_gt(nchar(model_json), 1024L * 1024L)
# check that it is valid JSON that looks like a LightGBM model
model_list <- jsonlite::fromJSON(model_json)
expect_equal(model_list[["objective"]], "binary sigmoid:1")
})
test_that("If a string and a file are both passed to lgb.load() the file is used model_str is totally ignored", {
set.seed(708L)
data(agaricus.train, package = "lightgbm")

Просмотреть файл

@ -1,12 +1,3 @@
context("lgb.encode.char")
test_that("lgb.encode.char throws an informative error if it is passed a non-raw input", {
x <- "some-string"
expect_error({
lgb.encode.char(x)
}, regexp = "Can only encode from raw type")
})
context("lgb.check.r6.class")
test_that("lgb.check.r6.class() should return FALSE for NULL input", {