зеркало из https://github.com/microsoft/LightGBM.git
[R-package] move creation of character vectors in some methods to C++ side (#4256)
* [R-package] move creation of character vectors in some methods to C++ side
* convert LGBM_BoosterGetEvalNames_R
* convert LGBM_BoosterDumpModel_R and LGBM_BoosterSaveModelToString_R
* remove debugging code
* update docs
* remove comment
* add handling for larger model strings
* handle large strings in feature and eval names
* got long feature names working
* more fixes
* linting
* resize
* Apply suggestions from code review
  Co-authored-by: Nikita Titov <nekit94-08@mail.ru>
* stricter test
  Co-authored-by: Nikita Titov <nekit94-08@mail.ru>
This commit is contained in:
Родитель
a421217e4e
Коммит
c1d2dbe2c5
|
@ -466,40 +466,14 @@ Booster <- R6::R6Class(
|
|||
num_iteration <- self$best_iter
|
||||
}
|
||||
|
||||
# Create buffer
|
||||
buf_len <- as.integer(1024L * 1024L)
|
||||
act_len <- 0L
|
||||
buf <- raw(buf_len)
|
||||
|
||||
# Call buffer
|
||||
.Call(
|
||||
model_str <- .Call(
|
||||
LGBM_BoosterSaveModelToString_R
|
||||
, private$handle
|
||||
, as.integer(num_iteration)
|
||||
, as.integer(feature_importance_type)
|
||||
, buf_len
|
||||
, act_len
|
||||
, buf
|
||||
)
|
||||
|
||||
# Check for buffer content
|
||||
if (act_len > buf_len) {
|
||||
buf_len <- act_len
|
||||
buf <- raw(buf_len)
|
||||
.Call(
|
||||
LGBM_BoosterSaveModelToString_R
|
||||
, private$handle
|
||||
, as.integer(num_iteration)
|
||||
, as.integer(feature_importance_type)
|
||||
, buf_len
|
||||
, act_len
|
||||
, buf
|
||||
)
|
||||
}
|
||||
|
||||
return(
|
||||
lgb.encode.char(arr = buf, len = act_len)
|
||||
)
|
||||
return(model_str)
|
||||
|
||||
},
|
||||
|
||||
|
@ -511,36 +485,14 @@ Booster <- R6::R6Class(
|
|||
num_iteration <- self$best_iter
|
||||
}
|
||||
|
||||
buf_len <- as.integer(1024L * 1024L)
|
||||
act_len <- 0L
|
||||
buf <- raw(buf_len)
|
||||
.Call(
|
||||
model_str <- .Call(
|
||||
LGBM_BoosterDumpModel_R
|
||||
, private$handle
|
||||
, as.integer(num_iteration)
|
||||
, as.integer(feature_importance_type)
|
||||
, buf_len
|
||||
, act_len
|
||||
, buf
|
||||
)
|
||||
|
||||
if (act_len > buf_len) {
|
||||
buf_len <- act_len
|
||||
buf <- raw(buf_len)
|
||||
.Call(
|
||||
LGBM_BoosterDumpModel_R
|
||||
, private$handle
|
||||
, as.integer(num_iteration)
|
||||
, as.integer(feature_importance_type)
|
||||
, buf_len
|
||||
, act_len
|
||||
, buf
|
||||
)
|
||||
}
|
||||
|
||||
return(
|
||||
lgb.encode.char(arr = buf, len = act_len)
|
||||
)
|
||||
return(model_str)
|
||||
|
||||
},
|
||||
|
||||
|
@ -666,41 +618,20 @@ Booster <- R6::R6Class(
|
|||
|
||||
# Check for evaluation names emptiness
|
||||
if (is.null(private$eval_names)) {
|
||||
|
||||
# Get evaluation names
|
||||
buf_len <- as.integer(1024L * 1024L)
|
||||
act_len <- 0L
|
||||
buf <- raw(buf_len)
|
||||
.Call(
|
||||
eval_names <- .Call(
|
||||
LGBM_BoosterGetEvalNames_R
|
||||
, private$handle
|
||||
, buf_len
|
||||
, act_len
|
||||
, buf
|
||||
)
|
||||
if (act_len > buf_len) {
|
||||
buf_len <- act_len
|
||||
buf <- raw(buf_len)
|
||||
.Call(
|
||||
LGBM_BoosterGetEvalNames_R
|
||||
, private$handle
|
||||
, buf_len
|
||||
, act_len
|
||||
, buf
|
||||
)
|
||||
}
|
||||
names <- lgb.encode.char(arr = buf, len = act_len)
|
||||
|
||||
# Check names' length
|
||||
if (nchar(names) > 0L) {
|
||||
if (length(eval_names) > 0L) {
|
||||
|
||||
# Parse and store privately names
|
||||
names <- strsplit(names, "\t")[[1L]]
|
||||
private$eval_names <- names
|
||||
private$eval_names <- eval_names
|
||||
|
||||
# some metrics don't map cleanly to metric names, for example "ndcg@1" is just the
|
||||
# ndcg metric evaluated at the first "query result" in learning-to-rank
|
||||
metric_names <- gsub("@.*", "", names)
|
||||
metric_names <- gsub("@.*", "", eval_names)
|
||||
private$higher_better_inner_eval <- .METRICS_HIGHER_BETTER()[metric_names]
|
||||
|
||||
}
|
||||
|
|
|
@ -369,31 +369,10 @@ Dataset <- R6::R6Class(
|
|||
|
||||
# Check for handle
|
||||
if (!lgb.is.null.handle(x = private$handle)) {
|
||||
|
||||
# Get feature names and write them
|
||||
buf_len <- as.integer(1024L * 1024L)
|
||||
act_len <- 0L
|
||||
buf <- raw(buf_len)
|
||||
.Call(
|
||||
private$colnames <- .Call(
|
||||
LGBM_DatasetGetFeatureNames_R
|
||||
, private$handle
|
||||
, buf_len
|
||||
, act_len
|
||||
, buf
|
||||
)
|
||||
if (act_len > buf_len) {
|
||||
buf_len <- act_len
|
||||
buf <- raw(buf_len)
|
||||
.Call(
|
||||
LGBM_DatasetGetFeatureNames_R
|
||||
, private$handle
|
||||
, buf_len
|
||||
, act_len
|
||||
, buf
|
||||
)
|
||||
}
|
||||
cnames <- lgb.encode.char(arr = buf, len = act_len)
|
||||
private$colnames <- as.character(base::strsplit(cnames, "\t")[[1L]])
|
||||
return(private$colnames)
|
||||
|
||||
} else if (is.matrix(private$raw_data) || methods::is(private$raw_data, "dgCMatrix")) {
|
||||
|
|
|
@ -18,13 +18,6 @@ lgb.is.null.handle <- function(x) {
|
|||
return(is.null(x) || is.na(x))
|
||||
}
|
||||
|
||||
lgb.encode.char <- function(arr, len) {
|
||||
if (!is.raw(arr)) {
|
||||
stop("lgb.encode.char: Can only encode from raw type")
|
||||
}
|
||||
return(rawToChar(arr[seq_len(len)]))
|
||||
}
|
||||
|
||||
# [description] Get the most recent error stored on the C++ side and raise it
|
||||
# as an R error.
|
||||
lgb.last_error <- function() {
|
||||
|
|
|
@ -96,8 +96,6 @@ typedef union { VECTOR_SER s; double align; } SEXPREC_ALIGN;
|
|||
|
||||
#define DATAPTR(x) ((reinterpret_cast<SEXPREC_ALIGN*>(x)) + 1)
|
||||
|
||||
#define R_CHAR_PTR(x) (reinterpret_cast<char*>DATAPTR(x))
|
||||
|
||||
#define R_IS_NULL(x) ((*reinterpret_cast<LGBM_SE>(x)).sxpinfo.type == 0)
|
||||
|
||||
// 64bit pointer
|
||||
|
|
|
@ -39,23 +39,9 @@
|
|||
return R_NilValue; \
|
||||
}
|
||||
|
||||
using LightGBM::Common::Join;
|
||||
using LightGBM::Common::Split;
|
||||
using LightGBM::Log;
|
||||
|
||||
LGBM_SE EncodeChar(LGBM_SE dest, const char* src, SEXP buf_len, SEXP actual_len, size_t str_len) {
|
||||
if (str_len > INT32_MAX) {
|
||||
Log::Fatal("Don't support large string in R-package");
|
||||
}
|
||||
INTEGER(actual_len)[0] = static_cast<int>(str_len);
|
||||
if (Rf_asInteger(buf_len) < static_cast<int>(str_len)) {
|
||||
return dest;
|
||||
}
|
||||
auto ptr = R_CHAR_PTR(dest);
|
||||
std::memcpy(ptr, src, str_len);
|
||||
return dest;
|
||||
}
|
||||
|
||||
SEXP LGBM_GetLastError_R() {
|
||||
SEXP out;
|
||||
out = PROTECT(Rf_allocVector(STRSXP, 1));
|
||||
|
@ -153,10 +139,8 @@ SEXP LGBM_DatasetSetFeatureNames_R(LGBM_SE handle,
|
|||
R_API_END();
|
||||
}
|
||||
|
||||
SEXP LGBM_DatasetGetFeatureNames_R(LGBM_SE handle,
|
||||
SEXP buf_len,
|
||||
SEXP actual_len,
|
||||
LGBM_SE feature_names) {
|
||||
SEXP LGBM_DatasetGetFeatureNames_R(LGBM_SE handle) {
|
||||
SEXP feature_names;
|
||||
R_API_BEGIN();
|
||||
int len = 0;
|
||||
CHECK_CALL(LGBM_DatasetGetNumFeature(R_GET_PTR(handle), &len));
|
||||
|
@ -175,10 +159,29 @@ SEXP LGBM_DatasetGetFeatureNames_R(LGBM_SE handle,
|
|||
len, &out_len,
|
||||
reserved_string_size, &required_string_size,
|
||||
ptr_names.data()));
|
||||
// if any feature names were larger than allocated size,
|
||||
// allow for a larger size and try again
|
||||
if (required_string_size > reserved_string_size) {
|
||||
for (int i = 0; i < len; ++i) {
|
||||
names[i].resize(required_string_size);
|
||||
ptr_names[i] = names[i].data();
|
||||
}
|
||||
CHECK_CALL(
|
||||
LGBM_DatasetGetFeatureNames(
|
||||
R_GET_PTR(handle),
|
||||
len,
|
||||
&out_len,
|
||||
required_string_size,
|
||||
&required_string_size,
|
||||
ptr_names.data()));
|
||||
}
|
||||
CHECK_EQ(len, out_len);
|
||||
CHECK_GE(reserved_string_size, required_string_size);
|
||||
auto merge_str = Join<char*>(ptr_names, "\t");
|
||||
EncodeChar(feature_names, merge_str.c_str(), buf_len, actual_len, merge_str.size() + 1);
|
||||
feature_names = PROTECT(Rf_allocVector(STRSXP, len));
|
||||
for (int i = 0; i < len; ++i) {
|
||||
SET_STRING_ELT(feature_names, i, Rf_mkChar(ptr_names[i]));
|
||||
}
|
||||
UNPROTECT(1);
|
||||
return feature_names;
|
||||
R_API_END();
|
||||
}
|
||||
|
||||
|
@ -432,10 +435,8 @@ SEXP LGBM_BoosterGetLowerBoundValue_R(LGBM_SE handle,
|
|||
R_API_END();
|
||||
}
|
||||
|
||||
SEXP LGBM_BoosterGetEvalNames_R(LGBM_SE handle,
|
||||
SEXP buf_len,
|
||||
SEXP actual_len,
|
||||
LGBM_SE eval_names) {
|
||||
SEXP LGBM_BoosterGetEvalNames_R(LGBM_SE handle) {
|
||||
SEXP eval_names;
|
||||
R_API_BEGIN();
|
||||
int len;
|
||||
CHECK_CALL(LGBM_BoosterGetEvalCounts(R_GET_PTR(handle), &len));
|
||||
|
@ -456,10 +457,29 @@ SEXP LGBM_BoosterGetEvalNames_R(LGBM_SE handle,
|
|||
len, &out_len,
|
||||
reserved_string_size, &required_string_size,
|
||||
ptr_names.data()));
|
||||
// if any eval names were larger than allocated size,
|
||||
// allow for a larger size and try again
|
||||
if (required_string_size > reserved_string_size) {
|
||||
for (int i = 0; i < len; ++i) {
|
||||
names[i].resize(required_string_size);
|
||||
ptr_names[i] = names[i].data();
|
||||
}
|
||||
CHECK_CALL(
|
||||
LGBM_BoosterGetEvalNames(
|
||||
R_GET_PTR(handle),
|
||||
len,
|
||||
&out_len,
|
||||
required_string_size,
|
||||
&required_string_size,
|
||||
ptr_names.data()));
|
||||
}
|
||||
CHECK_EQ(out_len, len);
|
||||
CHECK_GE(reserved_string_size, required_string_size);
|
||||
auto merge_names = Join<char*>(ptr_names, "\t");
|
||||
EncodeChar(eval_names, merge_names.c_str(), buf_len, actual_len, merge_names.size() + 1);
|
||||
eval_names = PROTECT(Rf_allocVector(STRSXP, len));
|
||||
for (int i = 0; i < len; ++i) {
|
||||
SET_STRING_ELT(eval_names, i, Rf_mkChar(ptr_names[i]));
|
||||
}
|
||||
UNPROTECT(1);
|
||||
return eval_names;
|
||||
R_API_END();
|
||||
}
|
||||
|
||||
|
@ -616,31 +636,47 @@ SEXP LGBM_BoosterSaveModel_R(LGBM_SE handle,
|
|||
|
||||
SEXP LGBM_BoosterSaveModelToString_R(LGBM_SE handle,
|
||||
SEXP num_iteration,
|
||||
SEXP feature_importance_type,
|
||||
SEXP buffer_len,
|
||||
SEXP actual_len,
|
||||
LGBM_SE out_str) {
|
||||
SEXP feature_importance_type) {
|
||||
SEXP model_str;
|
||||
R_API_BEGIN();
|
||||
int64_t out_len = 0;
|
||||
int64_t buf_len = static_cast<int64_t>(Rf_asInteger(buffer_len));
|
||||
int64_t buf_len = 1024 * 1024;
|
||||
int64_t num_iter = Rf_asInteger(num_iteration);
|
||||
int64_t importance_type = Rf_asInteger(feature_importance_type);
|
||||
std::vector<char> inner_char_buf(buf_len);
|
||||
CHECK_CALL(LGBM_BoosterSaveModelToString(R_GET_PTR(handle), 0, Rf_asInteger(num_iteration), Rf_asInteger(feature_importance_type), buf_len, &out_len, inner_char_buf.data()));
|
||||
EncodeChar(out_str, inner_char_buf.data(), buffer_len, actual_len, static_cast<size_t>(out_len));
|
||||
CHECK_CALL(LGBM_BoosterSaveModelToString(R_GET_PTR(handle), 0, num_iter, importance_type, buf_len, &out_len, inner_char_buf.data()));
|
||||
// if the model string was larger than the initial buffer, allocate a bigger buffer and try again
|
||||
if (out_len > buf_len) {
|
||||
inner_char_buf.resize(out_len);
|
||||
CHECK_CALL(LGBM_BoosterSaveModelToString(R_GET_PTR(handle), 0, num_iter, importance_type, out_len, &out_len, inner_char_buf.data()));
|
||||
}
|
||||
model_str = PROTECT(Rf_allocVector(STRSXP, 1));
|
||||
SET_STRING_ELT(model_str, 0, Rf_mkChar(inner_char_buf.data()));
|
||||
UNPROTECT(1);
|
||||
return model_str;
|
||||
R_API_END();
|
||||
}
|
||||
|
||||
SEXP LGBM_BoosterDumpModel_R(LGBM_SE handle,
|
||||
SEXP num_iteration,
|
||||
SEXP feature_importance_type,
|
||||
SEXP buffer_len,
|
||||
SEXP actual_len,
|
||||
LGBM_SE out_str) {
|
||||
SEXP feature_importance_type) {
|
||||
SEXP model_str;
|
||||
R_API_BEGIN();
|
||||
int64_t out_len = 0;
|
||||
int64_t buf_len = static_cast<int64_t>(Rf_asInteger(buffer_len));
|
||||
int64_t buf_len = 1024 * 1024;
|
||||
int64_t num_iter = Rf_asInteger(num_iteration);
|
||||
int64_t importance_type = Rf_asInteger(feature_importance_type);
|
||||
std::vector<char> inner_char_buf(buf_len);
|
||||
CHECK_CALL(LGBM_BoosterDumpModel(R_GET_PTR(handle), 0, Rf_asInteger(num_iteration), Rf_asInteger(feature_importance_type), buf_len, &out_len, inner_char_buf.data()));
|
||||
EncodeChar(out_str, inner_char_buf.data(), buffer_len, actual_len, static_cast<size_t>(out_len));
|
||||
CHECK_CALL(LGBM_BoosterDumpModel(R_GET_PTR(handle), 0, num_iter, importance_type, buf_len, &out_len, inner_char_buf.data()));
|
||||
// if the model string was larger than the initial buffer, allocate a bigger buffer and try again
|
||||
if (out_len > buf_len) {
|
||||
inner_char_buf.resize(out_len);
|
||||
CHECK_CALL(LGBM_BoosterDumpModel(R_GET_PTR(handle), 0, num_iter, importance_type, out_len, &out_len, inner_char_buf.data()));
|
||||
}
|
||||
model_str = PROTECT(Rf_allocVector(STRSXP, 1));
|
||||
SET_STRING_ELT(model_str, 0, Rf_mkChar(inner_char_buf.data()));
|
||||
UNPROTECT(1);
|
||||
return model_str;
|
||||
R_API_END();
|
||||
}
|
||||
|
||||
|
@ -652,7 +688,7 @@ static const R_CallMethodDef CallEntries[] = {
|
|||
{"LGBM_DatasetCreateFromMat_R" , (DL_FUNC) &LGBM_DatasetCreateFromMat_R , 6},
|
||||
{"LGBM_DatasetGetSubset_R" , (DL_FUNC) &LGBM_DatasetGetSubset_R , 5},
|
||||
{"LGBM_DatasetSetFeatureNames_R" , (DL_FUNC) &LGBM_DatasetSetFeatureNames_R , 2},
|
||||
{"LGBM_DatasetGetFeatureNames_R" , (DL_FUNC) &LGBM_DatasetGetFeatureNames_R , 4},
|
||||
{"LGBM_DatasetGetFeatureNames_R" , (DL_FUNC) &LGBM_DatasetGetFeatureNames_R , 1},
|
||||
{"LGBM_DatasetSaveBinary_R" , (DL_FUNC) &LGBM_DatasetSaveBinary_R , 2},
|
||||
{"LGBM_DatasetFree_R" , (DL_FUNC) &LGBM_DatasetFree_R , 1},
|
||||
{"LGBM_DatasetSetField_R" , (DL_FUNC) &LGBM_DatasetSetField_R , 4},
|
||||
|
@ -676,7 +712,7 @@ static const R_CallMethodDef CallEntries[] = {
|
|||
{"LGBM_BoosterGetCurrentIteration_R", (DL_FUNC) &LGBM_BoosterGetCurrentIteration_R, 2},
|
||||
{"LGBM_BoosterGetUpperBoundValue_R" , (DL_FUNC) &LGBM_BoosterGetUpperBoundValue_R , 2},
|
||||
{"LGBM_BoosterGetLowerBoundValue_R" , (DL_FUNC) &LGBM_BoosterGetLowerBoundValue_R , 2},
|
||||
{"LGBM_BoosterGetEvalNames_R" , (DL_FUNC) &LGBM_BoosterGetEvalNames_R , 4},
|
||||
{"LGBM_BoosterGetEvalNames_R" , (DL_FUNC) &LGBM_BoosterGetEvalNames_R , 1},
|
||||
{"LGBM_BoosterGetEval_R" , (DL_FUNC) &LGBM_BoosterGetEval_R , 3},
|
||||
{"LGBM_BoosterGetNumPredict_R" , (DL_FUNC) &LGBM_BoosterGetNumPredict_R , 3},
|
||||
{"LGBM_BoosterGetPredict_R" , (DL_FUNC) &LGBM_BoosterGetPredict_R , 3},
|
||||
|
@ -685,8 +721,8 @@ static const R_CallMethodDef CallEntries[] = {
|
|||
{"LGBM_BoosterPredictForCSC_R" , (DL_FUNC) &LGBM_BoosterPredictForCSC_R , 14},
|
||||
{"LGBM_BoosterPredictForMat_R" , (DL_FUNC) &LGBM_BoosterPredictForMat_R , 11},
|
||||
{"LGBM_BoosterSaveModel_R" , (DL_FUNC) &LGBM_BoosterSaveModel_R , 4},
|
||||
{"LGBM_BoosterSaveModelToString_R" , (DL_FUNC) &LGBM_BoosterSaveModelToString_R , 6},
|
||||
{"LGBM_BoosterDumpModel_R" , (DL_FUNC) &LGBM_BoosterDumpModel_R , 6},
|
||||
{"LGBM_BoosterSaveModelToString_R" , (DL_FUNC) &LGBM_BoosterSaveModelToString_R , 3},
|
||||
{"LGBM_BoosterDumpModel_R" , (DL_FUNC) &LGBM_BoosterDumpModel_R , 3},
|
||||
{NULL, NULL, 0}
|
||||
};
|
||||
|
||||
|
|
|
@ -101,7 +101,7 @@ LIGHTGBM_C_EXPORT SEXP LGBM_DatasetGetSubset_R(
|
|||
* \brief save feature names to Dataset
|
||||
* \param handle handle
|
||||
* \param feature_names feature names
|
||||
* \return R NULL value
|
||||
* \return R character vector of feature names
|
||||
*/
|
||||
LIGHTGBM_C_EXPORT SEXP LGBM_DatasetSetFeatureNames_R(
|
||||
LGBM_SE handle,
|
||||
|
@ -109,16 +109,12 @@ LIGHTGBM_C_EXPORT SEXP LGBM_DatasetSetFeatureNames_R(
|
|||
);
|
||||
|
||||
/*!
|
||||
* \brief save feature names to Dataset
|
||||
* \param handle handle
|
||||
* \param feature_names feature names
|
||||
* \return 0 when succeed, -1 when failure happens
|
||||
* \brief get feature names from Dataset
|
||||
* \param handle Dataset handle
|
||||
* \return an R character vector with feature names from the Dataset or NULL if no feature names
|
||||
*/
|
||||
LIGHTGBM_C_EXPORT SEXP LGBM_DatasetGetFeatureNames_R(
|
||||
LGBM_SE handle,
|
||||
SEXP buf_len,
|
||||
SEXP actual_len,
|
||||
LGBM_SE feature_names
|
||||
LGBM_SE handle
|
||||
);
|
||||
|
||||
/*!
|
||||
|
@ -387,15 +383,12 @@ LIGHTGBM_C_EXPORT SEXP LGBM_BoosterGetLowerBoundValue_R(
|
|||
);
|
||||
|
||||
/*!
|
||||
* \brief Get Name of eval
|
||||
* \param eval_names eval names
|
||||
* \return 0 when succeed, -1 when failure happens
|
||||
* \brief Get names of eval metrics
|
||||
* \param handle Handle of booster
|
||||
* \return R character vector with names of eval metrics
|
||||
*/
|
||||
LIGHTGBM_C_EXPORT SEXP LGBM_BoosterGetEvalNames_R(
|
||||
LGBM_SE handle,
|
||||
SEXP buf_len,
|
||||
SEXP actual_len,
|
||||
LGBM_SE eval_names
|
||||
LGBM_SE handle
|
||||
);
|
||||
|
||||
/*!
|
||||
|
@ -583,34 +576,28 @@ LIGHTGBM_C_EXPORT SEXP LGBM_BoosterSaveModel_R(
|
|||
|
||||
/*!
|
||||
* \brief create string containing model
|
||||
* \param handle handle
|
||||
* \param handle Booster handle
|
||||
* \param num_iteration, <= 0 means save all
|
||||
* \param out_str string of model
|
||||
* \return 0 when succeed, -1 when failure happens
|
||||
* \param feature_importance_type type of feature importance, 0: split, 1: gain
|
||||
* \return R character vector (length=1) with model string
|
||||
*/
|
||||
LIGHTGBM_C_EXPORT SEXP LGBM_BoosterSaveModelToString_R(
|
||||
LGBM_SE handle,
|
||||
SEXP num_iteration,
|
||||
SEXP feature_importance_type,
|
||||
SEXP buffer_len,
|
||||
SEXP actual_len,
|
||||
LGBM_SE out_str
|
||||
SEXP feature_importance_type
|
||||
);
|
||||
|
||||
/*!
|
||||
* \brief dump model to json
|
||||
* \param handle handle
|
||||
* \brief dump model to JSON
|
||||
* \param handle Booster handle
|
||||
* \param num_iteration, <= 0 means save all
|
||||
* \param out_str json format string of model
|
||||
* \return 0 when succeed, -1 when failure happens
|
||||
* \param feature_importance_type type of feature importance, 0: split, 1: gain
|
||||
* \return R character vector (length=1) with model JSON
|
||||
*/
|
||||
LIGHTGBM_C_EXPORT SEXP LGBM_BoosterDumpModel_R(
|
||||
LGBM_SE handle,
|
||||
SEXP num_iteration,
|
||||
SEXP feature_importance_type,
|
||||
SEXP buffer_len,
|
||||
SEXP actual_len,
|
||||
LGBM_SE out_str
|
||||
SEXP feature_importance_type
|
||||
);
|
||||
|
||||
#endif // LIGHTGBM_R_H_
|
||||
|
|
|
@ -278,3 +278,21 @@ test_that("lgb.Dataset: should be able to run lgb.cv() immediately after using l
|
|||
|
||||
expect_is(bst, "lgb.CVBooster")
|
||||
})
|
||||
|
||||
test_that("lgb.Dataset: should be able to use and retrieve long feature names", {
|
||||
# set one feature to a value longer than the default buffer size used
|
||||
# in LGBM_DatasetGetFeatureNames_R
|
||||
feature_names <- names(iris)
|
||||
long_name <- paste0(rep("a", 1000L), collapse = "")
|
||||
feature_names[1L] <- long_name
|
||||
names(iris) <- feature_names
|
||||
# check that feature name survived the trip from R to C++ and back
|
||||
dtrain <- lgb.Dataset(
|
||||
data = as.matrix(iris[, -5L])
|
||||
, label = as.numeric(iris$Species) - 1L
|
||||
)
|
||||
dtrain$construct()
|
||||
col_names <- dtrain$get_colnames()
|
||||
expect_equal(col_names[1L], long_name)
|
||||
expect_equal(nchar(col_names[1L]), 1000L)
|
||||
})
|
||||
|
|
|
@ -240,6 +240,74 @@ test_that("Loading a Booster from a string works", {
|
|||
expect_identical(pred, pred2)
|
||||
})
|
||||
|
||||
test_that("Saving a large model to string should work", {
|
||||
set.seed(708L)
|
||||
data(agaricus.train, package = "lightgbm")
|
||||
train <- agaricus.train
|
||||
bst <- lightgbm(
|
||||
data = as.matrix(train$data)
|
||||
, label = train$label
|
||||
, num_leaves = 100L
|
||||
, learning_rate = 0.01
|
||||
, nrounds = 500L
|
||||
, objective = "binary"
|
||||
, save_name = tempfile(fileext = ".model")
|
||||
, verbose = -1L
|
||||
)
|
||||
|
||||
pred <- predict(bst, train$data)
|
||||
pred_leaf_indx <- predict(bst, train$data, predleaf = TRUE)
|
||||
pred_raw_score <- predict(bst, train$data, rawscore = TRUE)
|
||||
model_string <- bst$save_model_to_string()
|
||||
|
||||
# make sure this test is still producing a model bigger than the default
|
||||
# buffer size used in LGBM_BoosterSaveModelToString_R
|
||||
expect_gt(nchar(model_string), 1024L * 1024L)
|
||||
|
||||
# finalize the booster and destroy it so you know we aren't cheating
|
||||
bst$finalize()
|
||||
expect_null(bst$.__enclos_env__$private$handle)
|
||||
rm(bst)
|
||||
|
||||
# make sure a new model can be created from this string, and that it
|
||||
# produces expected results
|
||||
bst2 <- lgb.load(
|
||||
model_str = model_string
|
||||
)
|
||||
pred2 <- predict(bst2, train$data)
|
||||
pred2_leaf_indx <- predict(bst2, train$data, predleaf = TRUE)
|
||||
pred2_raw_score <- predict(bst2, train$data, rawscore = TRUE)
|
||||
expect_identical(pred, pred2)
|
||||
expect_identical(pred_leaf_indx, pred2_leaf_indx)
|
||||
expect_identical(pred_raw_score, pred2_raw_score)
|
||||
})
|
||||
|
||||
test_that("Saving a large model to JSON should work", {
|
||||
set.seed(708L)
|
||||
data(agaricus.train, package = "lightgbm")
|
||||
train <- agaricus.train
|
||||
bst <- lightgbm(
|
||||
data = as.matrix(train$data)
|
||||
, label = train$label
|
||||
, num_leaves = 100L
|
||||
, learning_rate = 0.01
|
||||
, nrounds = 200L
|
||||
, objective = "binary"
|
||||
, save_name = tempfile(fileext = ".model")
|
||||
, verbose = -1L
|
||||
)
|
||||
|
||||
model_json <- bst$dump_model()
|
||||
|
||||
# make sure this test is still producing a model bigger than the default
|
||||
# buffer size used in LGBM_BoosterDumpModel_R
|
||||
expect_gt(nchar(model_json), 1024L * 1024L)
|
||||
|
||||
# check that it is valid JSON that looks like a LightGBM model
|
||||
model_list <- jsonlite::fromJSON(model_json)
|
||||
expect_equal(model_list[["objective"]], "binary sigmoid:1")
|
||||
})
|
||||
|
||||
test_that("If a string and a file are both passed to lgb.load() the file is used model_str is totally ignored", {
|
||||
set.seed(708L)
|
||||
data(agaricus.train, package = "lightgbm")
|
||||
|
|
|
@ -1,12 +1,3 @@
|
|||
context("lgb.encode.char")
|
||||
|
||||
test_that("lgb.encode.char throws an informative error if it is passed a non-raw input", {
|
||||
x <- "some-string"
|
||||
expect_error({
|
||||
lgb.encode.char(x)
|
||||
}, regexp = "Can only encode from raw type")
|
||||
})
|
||||
|
||||
context("lgb.check.r6.class")
|
||||
|
||||
test_that("lgb.check.r6.class() should return FALSE for NULL input", {
|
||||
|
|
Загрузка…
Ссылка в новой задаче