[R-package] fix protection stack imbalance and unprotected objects (fixes #4390) (#4391)

* [R-package] fix protection stack imbalance and unprotected objects issues

* [R-package] fix minor linting issues

* [ci][R-package] change timeout-minutes in valgrind test

* [R-package] remove extra space

Co-authored-by: James Lamb <jaylamb20@gmail.com>

* [R-package] remove counter for number of protected objects

* Update .github/workflows/r_valgrind.yml

Co-authored-by: James Lamb <jaylamb20@gmail.com>
This commit is contained in:
Fabio Sigrist 2021-06-27 07:08:49 +02:00 коммит произвёл GitHub
Родитель f62c490474
Коммит aacb4c8fd9
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
1 изменённых файлов: 129 добавлений и 70 удалений

Просмотреть файл

@ -28,15 +28,13 @@
// R_API_BEGIN() opens a try block; it must be paired with R_API_END().
#define R_API_BEGIN() \
  try {
// R_API_END() closes the try block opened by R_API_BEGIN(). Any C++ exception
// is caught and its message is stored with LGBM_SetLastError(). The macro does
// not return: control continues with the statement that follows it, so the
// caller is responsible for its own UNPROTECT() and return statement.
#define R_API_END() } \
  catch(const std::exception& ex) { LGBM_SetLastError(ex.what()); } \
  catch(const std::string& ex) { LGBM_SetLastError(ex.c_str()); } \
  catch(...) { LGBM_SetLastError("unknown exception"); }
// CHECK_CALL(x) raises an R error carrying LGBM_GetLastError() when x is
// non-zero. The message is passed as an argument to a "%s" format so that any
// '%' characters it contains are not interpreted as format directives.
#define CHECK_CALL(x) \
  if ((x) != 0) { \
    Rf_error("%s", LGBM_GetLastError()); \
  }
using LightGBM::Common::Split;
@ -54,19 +52,20 @@ SEXP LGBM_DatasetCreateFromFile_R(SEXP filename,
SEXP parameters,
SEXP reference) {
SEXP ret;
R_API_BEGIN();
DatasetHandle handle = nullptr;
DatasetHandle ref = nullptr;
if (!Rf_isNull(reference)) {
ref = R_ExternalPtrAddr(reference);
}
CHECK_CALL(LGBM_DatasetCreateFromFile(CHAR(Rf_asChar(filename)), CHAR(Rf_asChar(parameters)),
ref, &handle));
const char* filename_ptr = CHAR(PROTECT(Rf_asChar(filename)));
const char* parameters_ptr = CHAR(PROTECT(Rf_asChar(parameters)));
R_API_BEGIN();
CHECK_CALL(LGBM_DatasetCreateFromFile(filename_ptr, parameters_ptr, ref, &handle));
R_API_END();
ret = PROTECT(R_MakeExternalPtr(handle, R_NilValue, R_NilValue));
R_RegisterCFinalizerEx(ret, _DatasetFinalizer, TRUE);
UNPROTECT(1);
UNPROTECT(3);
return ret;
R_API_END();
}
SEXP LGBM_DatasetCreateFromCSC_R(SEXP indptr,
@ -78,27 +77,27 @@ SEXP LGBM_DatasetCreateFromCSC_R(SEXP indptr,
SEXP parameters,
SEXP reference) {
SEXP ret;
R_API_BEGIN();
const int* p_indptr = INTEGER(indptr);
const int* p_indices = INTEGER(indices);
const double* p_data = REAL(data);
int64_t nindptr = static_cast<int64_t>(Rf_asInteger(num_indptr));
int64_t ndata = static_cast<int64_t>(Rf_asInteger(nelem));
int64_t nrow = static_cast<int64_t>(Rf_asInteger(num_row));
const char* parameters_ptr = CHAR(PROTECT(Rf_asChar(parameters)));
DatasetHandle handle = nullptr;
DatasetHandle ref = nullptr;
if (!Rf_isNull(reference)) {
ref = R_ExternalPtrAddr(reference);
}
R_API_BEGIN();
CHECK_CALL(LGBM_DatasetCreateFromCSC(p_indptr, C_API_DTYPE_INT32, p_indices,
p_data, C_API_DTYPE_FLOAT64, nindptr, ndata,
nrow, CHAR(Rf_asChar(parameters)), ref, &handle));
nrow, parameters_ptr, ref, &handle));
R_API_END();
ret = PROTECT(R_MakeExternalPtr(handle, R_NilValue, R_NilValue));
R_RegisterCFinalizerEx(ret, _DatasetFinalizer, TRUE);
UNPROTECT(1);
UNPROTECT(2);
return ret;
R_API_END();
}
SEXP LGBM_DatasetCreateFromMat_R(SEXP data,
@ -107,22 +106,23 @@ SEXP LGBM_DatasetCreateFromMat_R(SEXP data,
SEXP parameters,
SEXP reference) {
SEXP ret;
R_API_BEGIN();
int32_t nrow = static_cast<int32_t>(Rf_asInteger(num_row));
int32_t ncol = static_cast<int32_t>(Rf_asInteger(num_col));
double* p_mat = REAL(data);
const char* parameters_ptr = CHAR(PROTECT(Rf_asChar(parameters)));
DatasetHandle handle = nullptr;
DatasetHandle ref = nullptr;
if (!Rf_isNull(reference)) {
ref = R_ExternalPtrAddr(reference);
}
R_API_BEGIN();
CHECK_CALL(LGBM_DatasetCreateFromMat(p_mat, C_API_DTYPE_FLOAT64, nrow, ncol, COL_MAJOR,
CHAR(Rf_asChar(parameters)), ref, &handle));
parameters_ptr, ref, &handle));
R_API_END();
ret = PROTECT(R_MakeExternalPtr(handle, R_NilValue, R_NilValue));
R_RegisterCFinalizerEx(ret, _DatasetFinalizer, TRUE);
UNPROTECT(1);
UNPROTECT(2);
return ret;
R_API_END();
}
SEXP LGBM_DatasetGetSubset_R(SEXP handle,
@ -130,7 +130,6 @@ SEXP LGBM_DatasetGetSubset_R(SEXP handle,
SEXP len_used_row_indices,
SEXP parameters) {
SEXP ret;
R_API_BEGIN();
int32_t len = static_cast<int32_t>(Rf_asInteger(len_used_row_indices));
std::vector<int32_t> idxvec(len);
// convert from one-based to zero-based index
@ -138,36 +137,41 @@ SEXP LGBM_DatasetGetSubset_R(SEXP handle,
for (int32_t i = 0; i < len; ++i) {
idxvec[i] = static_cast<int32_t>(INTEGER(used_row_indices)[i] - 1);
}
const char* parameters_ptr = CHAR(PROTECT(Rf_asChar(parameters)));
DatasetHandle res = nullptr;
R_API_BEGIN();
CHECK_CALL(LGBM_DatasetGetSubset(R_ExternalPtrAddr(handle),
idxvec.data(), len, CHAR(Rf_asChar(parameters)),
idxvec.data(), len, parameters_ptr,
&res));
R_API_END();
ret = PROTECT(R_MakeExternalPtr(res, R_NilValue, R_NilValue));
R_RegisterCFinalizerEx(ret, _DatasetFinalizer, TRUE);
UNPROTECT(1);
UNPROTECT(2);
return ret;
R_API_END();
}
// Sets the feature names of the dataset behind `handle`.
//
// handle:        external pointer whose address is passed to
//                LGBM_DatasetSetFeatureNames.
// feature_names: value coerced with Rf_asChar(); the resulting string is split
//                on '\t' into individual names.
//
// Always returns R_NilValue.
SEXP LGBM_DatasetSetFeatureNames_R(SEXP handle,
  SEXP feature_names) {
  // Rf_asChar() may allocate, so the CHARSXP stays protected until the end.
  auto vec_names = Split(CHAR(PROTECT(Rf_asChar(feature_names))), '\t');
  const int len = static_cast<int>(vec_names.size());
  // The C API expects an array of C strings; the pointers stay valid because
  // vec_names outlives the call below.
  std::vector<const char*> vec_sptr;
  vec_sptr.reserve(vec_names.size());
  for (int i = 0; i < len; ++i) {
    vec_sptr.push_back(vec_names[i].c_str());
  }
  R_API_BEGIN();
  CHECK_CALL(LGBM_DatasetSetFeatureNames(R_ExternalPtrAddr(handle),
    vec_sptr.data(), len));
  R_API_END();
  UNPROTECT(1);
  return R_NilValue;
}
SEXP LGBM_DatasetGetFeatureNames_R(SEXP handle) {
SEXP feature_names;
R_API_BEGIN();
int len = 0;
R_API_BEGIN();
CHECK_CALL(LGBM_DatasetGetNumFeature(R_ExternalPtrAddr(handle), &len));
R_API_END();
const size_t reserved_string_size = 256;
std::vector<std::vector<char>> names(len);
std::vector<char*> ptr_names(len);
@ -177,12 +181,14 @@ SEXP LGBM_DatasetGetFeatureNames_R(SEXP handle) {
}
int out_len;
size_t required_string_size;
R_API_BEGIN();
CHECK_CALL(
LGBM_DatasetGetFeatureNames(
R_ExternalPtrAddr(handle),
len, &out_len,
reserved_string_size, &required_string_size,
ptr_names.data()));
R_API_END();
// if any feature names were larger than allocated size,
// allow for a larger size and try again
if (required_string_size > reserved_string_size) {
@ -190,6 +196,7 @@ SEXP LGBM_DatasetGetFeatureNames_R(SEXP handle) {
names[i].resize(required_string_size);
ptr_names[i] = names[i].data();
}
R_API_BEGIN();
CHECK_CALL(
LGBM_DatasetGetFeatureNames(
R_ExternalPtrAddr(handle),
@ -198,6 +205,7 @@ SEXP LGBM_DatasetGetFeatureNames_R(SEXP handle) {
required_string_size,
&required_string_size,
ptr_names.data()));
R_API_END();
}
CHECK_EQ(len, out_len);
feature_names = PROTECT(Rf_allocVector(STRSXP, len));
@ -206,15 +214,17 @@ SEXP LGBM_DatasetGetFeatureNames_R(SEXP handle) {
}
UNPROTECT(1);
return feature_names;
R_API_END();
}
// Passes the dataset behind `handle` and the file name to
// LGBM_DatasetSaveBinary.
//
// handle:   external pointer to the dataset.
// filename: value coerced to a string with Rf_asChar().
//
// Always returns R_NilValue.
SEXP LGBM_DatasetSaveBinary_R(SEXP handle,
  SEXP filename) {
  // Keep the CHARSXP produced by Rf_asChar() protected while its character
  // data is in use.
  const char* filename_ptr = CHAR(PROTECT(Rf_asChar(filename)));
  R_API_BEGIN();
  CHECK_CALL(LGBM_DatasetSaveBinary(R_ExternalPtrAddr(handle),
    filename_ptr));
  R_API_END();
  UNPROTECT(1);
  return R_NilValue;
}
SEXP LGBM_DatasetFree_R(SEXP handle) {
@ -224,15 +234,16 @@ SEXP LGBM_DatasetFree_R(SEXP handle) {
R_ClearExternalPtr(handle);
}
R_API_END();
return R_NilValue;
}
SEXP LGBM_DatasetSetField_R(SEXP handle,
SEXP field_name,
SEXP field_data,
SEXP num_element) {
R_API_BEGIN();
int len = Rf_asInteger(num_element);
const char* name = CHAR(Rf_asChar(field_name));
const char* name = CHAR(PROTECT(Rf_asChar(field_name)));
R_API_BEGIN();
if (!strcmp("group", name) || !strcmp("query", name)) {
std::vector<int32_t> vec(len);
#pragma omp parallel for schedule(static, 512) if (len >= 1024)
@ -251,18 +262,19 @@ SEXP LGBM_DatasetSetField_R(SEXP handle,
CHECK_CALL(LGBM_DatasetSetField(R_ExternalPtrAddr(handle), name, vec.data(), len, C_API_DTYPE_FLOAT32));
}
R_API_END();
UNPROTECT(1);
return R_NilValue;
}
SEXP LGBM_DatasetGetField_R(SEXP handle,
SEXP field_name,
SEXP field_data) {
R_API_BEGIN();
const char* name = CHAR(Rf_asChar(field_name));
const char* name = CHAR(PROTECT(Rf_asChar(field_name)));
int out_len = 0;
int out_type = 0;
const void* res;
R_API_BEGIN();
CHECK_CALL(LGBM_DatasetGetField(R_ExternalPtrAddr(handle), name, &out_len, &res, &out_type));
if (!strcmp("group", name) || !strcmp("query", name)) {
auto p_data = reinterpret_cast<const int32_t*>(res);
// convert from boundaries to size
@ -284,29 +296,37 @@ SEXP LGBM_DatasetGetField_R(SEXP handle,
}
}
R_API_END();
UNPROTECT(1);
return R_NilValue;
}
// Writes the length of a dataset field into out[0].
//
// handle:     external pointer to the dataset.
// field_name: value coerced to a string with Rf_asChar().
// out:        integer vector; its first element receives the length.
//
// For the "group" and "query" fields the length reported by
// LGBM_DatasetGetField is reduced by one before it is stored.
// Always returns R_NilValue.
SEXP LGBM_DatasetGetFieldSize_R(SEXP handle,
  SEXP field_name,
  SEXP out) {
  // Keep the CHARSXP produced by Rf_asChar() protected while `name` is used.
  const char* name = CHAR(PROTECT(Rf_asChar(field_name)));
  int out_len = 0;
  int out_type = 0;
  const void* res = nullptr;
  R_API_BEGIN();
  CHECK_CALL(LGBM_DatasetGetField(R_ExternalPtrAddr(handle), name, &out_len, &res, &out_type));
  if (!strcmp("group", name) || !strcmp("query", name)) {
    out_len -= 1;
  }
  INTEGER(out)[0] = out_len;
  R_API_END();
  UNPROTECT(1);
  return R_NilValue;
}
// Forwards two parameter strings to LGBM_DatasetUpdateParamChecking.
//
// old_params, new_params: values coerced to strings with Rf_asChar().
//
// Always returns R_NilValue; a non-zero status from the C API raises an R
// error through CHECK_CALL.
SEXP LGBM_DatasetUpdateParamChecking_R(SEXP old_params,
  SEXP new_params) {
  // Both CHARSXPs stay protected until the call has finished; the second
  // Rf_asChar() may allocate and must not invalidate the first string.
  const char* old_params_ptr = CHAR(PROTECT(Rf_asChar(old_params)));
  const char* new_params_ptr = CHAR(PROTECT(Rf_asChar(new_params)));
  R_API_BEGIN();
  CHECK_CALL(LGBM_DatasetUpdateParamChecking(old_params_ptr, new_params_ptr));
  R_API_END();
  UNPROTECT(2);
  return R_NilValue;
}
SEXP LGBM_DatasetGetNumData_R(SEXP handle, SEXP out) {
@ -315,6 +335,7 @@ SEXP LGBM_DatasetGetNumData_R(SEXP handle, SEXP out) {
CHECK_CALL(LGBM_DatasetGetNumData(R_ExternalPtrAddr(handle), &nrow));
INTEGER(out)[0] = nrow;
R_API_END();
return R_NilValue;
}
SEXP LGBM_DatasetGetNumFeature_R(SEXP handle,
@ -324,6 +345,7 @@ SEXP LGBM_DatasetGetNumFeature_R(SEXP handle,
CHECK_CALL(LGBM_DatasetGetNumFeature(R_ExternalPtrAddr(handle), &nfeature));
INTEGER(out)[0] = nfeature;
R_API_END();
return R_NilValue;
}
// --- start Booster interfaces
@ -339,45 +361,49 @@ SEXP LGBM_BoosterFree_R(SEXP handle) {
R_ClearExternalPtr(handle);
}
R_API_END();
return R_NilValue;
}
// Creates a booster with LGBM_BoosterCreate and returns it as an external
// pointer with _BoosterFinalizer registered on it.
//
// train_data: external pointer whose address is passed as the training data.
// parameters: value coerced to a string with Rf_asChar().
SEXP LGBM_BoosterCreate_R(SEXP train_data,
  SEXP parameters) {
  SEXP ret;
  // Protected entry 1: the CHARSXP backing parameters_ptr.
  const char* parameters_ptr = CHAR(PROTECT(Rf_asChar(parameters)));
  BoosterHandle handle = nullptr;
  R_API_BEGIN();
  CHECK_CALL(LGBM_BoosterCreate(R_ExternalPtrAddr(train_data), parameters_ptr, &handle));
  R_API_END();
  // Protected entry 2: the external pointer, kept safe while the finalizer is
  // registered.
  ret = PROTECT(R_MakeExternalPtr(handle, R_NilValue, R_NilValue));
  R_RegisterCFinalizerEx(ret, _BoosterFinalizer, TRUE);
  UNPROTECT(2);
  return ret;
}
// Creates a booster with LGBM_BoosterCreateFromModelfile and returns it as an
// external pointer with _BoosterFinalizer registered on it.
//
// filename: value coerced to a string with Rf_asChar().
//
// The iteration count reported by the C API is received but not returned.
SEXP LGBM_BoosterCreateFromModelfile_R(SEXP filename) {
  SEXP ret;
  int out_num_iterations = 0;
  // Protected entry 1: the CHARSXP backing filename_ptr.
  const char* filename_ptr = CHAR(PROTECT(Rf_asChar(filename)));
  BoosterHandle handle = nullptr;
  R_API_BEGIN();
  CHECK_CALL(LGBM_BoosterCreateFromModelfile(filename_ptr, &out_num_iterations, &handle));
  R_API_END();
  // Protected entry 2: the external pointer, kept safe while the finalizer is
  // registered.
  ret = PROTECT(R_MakeExternalPtr(handle, R_NilValue, R_NilValue));
  R_RegisterCFinalizerEx(ret, _BoosterFinalizer, TRUE);
  UNPROTECT(2);
  return ret;
}
// Creates a booster with LGBM_BoosterLoadModelFromString and returns it as an
// external pointer with _BoosterFinalizer registered on it.
//
// model_str: value coerced to a string with Rf_asChar().
//
// The iteration count reported by the C API is received but not returned.
SEXP LGBM_BoosterLoadModelFromString_R(SEXP model_str) {
  SEXP ret;
  int out_num_iterations = 0;
  // Protected entry 1: the CHARSXP backing model_str_ptr.
  const char* model_str_ptr = CHAR(PROTECT(Rf_asChar(model_str)));
  BoosterHandle handle = nullptr;
  R_API_BEGIN();
  CHECK_CALL(LGBM_BoosterLoadModelFromString(model_str_ptr, &out_num_iterations, &handle));
  R_API_END();
  // Protected entry 2: the external pointer, kept safe while the finalizer is
  // registered.
  ret = PROTECT(R_MakeExternalPtr(handle, R_NilValue, R_NilValue));
  R_RegisterCFinalizerEx(ret, _BoosterFinalizer, TRUE);
  UNPROTECT(2);
  return ret;
}
SEXP LGBM_BoosterMerge_R(SEXP handle,
@ -385,6 +411,7 @@ SEXP LGBM_BoosterMerge_R(SEXP handle,
R_API_BEGIN();
CHECK_CALL(LGBM_BoosterMerge(R_ExternalPtrAddr(handle), R_ExternalPtrAddr(other_handle)));
R_API_END();
return R_NilValue;
}
SEXP LGBM_BoosterAddValidData_R(SEXP handle,
@ -392,6 +419,7 @@ SEXP LGBM_BoosterAddValidData_R(SEXP handle,
R_API_BEGIN();
CHECK_CALL(LGBM_BoosterAddValidData(R_ExternalPtrAddr(handle), R_ExternalPtrAddr(valid_data)));
R_API_END();
return R_NilValue;
}
SEXP LGBM_BoosterResetTrainingData_R(SEXP handle,
@ -399,13 +427,17 @@ SEXP LGBM_BoosterResetTrainingData_R(SEXP handle,
R_API_BEGIN();
CHECK_CALL(LGBM_BoosterResetTrainingData(R_ExternalPtrAddr(handle), R_ExternalPtrAddr(train_data)));
R_API_END();
return R_NilValue;
}
// Forwards a parameter string to LGBM_BoosterResetParameter for the booster
// behind `handle`.
//
// handle:     external pointer to the booster.
// parameters: value coerced to a string with Rf_asChar().
//
// Always returns R_NilValue.
SEXP LGBM_BoosterResetParameter_R(SEXP handle,
  SEXP parameters) {
  // Keep the CHARSXP produced by Rf_asChar() protected while its character
  // data is in use.
  const char* parameters_ptr = CHAR(PROTECT(Rf_asChar(parameters)));
  R_API_BEGIN();
  CHECK_CALL(LGBM_BoosterResetParameter(R_ExternalPtrAddr(handle), parameters_ptr));
  R_API_END();
  UNPROTECT(1);
  return R_NilValue;
}
SEXP LGBM_BoosterGetNumClasses_R(SEXP handle,
@ -415,6 +447,7 @@ SEXP LGBM_BoosterGetNumClasses_R(SEXP handle,
CHECK_CALL(LGBM_BoosterGetNumClasses(R_ExternalPtrAddr(handle), &num_class));
INTEGER(out)[0] = num_class;
R_API_END();
return R_NilValue;
}
SEXP LGBM_BoosterUpdateOneIter_R(SEXP handle) {
@ -422,6 +455,7 @@ SEXP LGBM_BoosterUpdateOneIter_R(SEXP handle) {
R_API_BEGIN();
CHECK_CALL(LGBM_BoosterUpdateOneIter(R_ExternalPtrAddr(handle), &is_finished));
R_API_END();
return R_NilValue;
}
SEXP LGBM_BoosterUpdateOneIterCustom_R(SEXP handle,
@ -439,12 +473,14 @@ SEXP LGBM_BoosterUpdateOneIterCustom_R(SEXP handle,
}
CHECK_CALL(LGBM_BoosterUpdateOneIterCustom(R_ExternalPtrAddr(handle), tgrad.data(), thess.data(), &is_finished));
R_API_END();
return R_NilValue;
}
// Calls LGBM_BoosterRollbackOneIter on the booster stored in the external
// pointer `handle`. Always returns R_NilValue.
SEXP LGBM_BoosterRollbackOneIter_R(SEXP handle) {
  R_API_BEGIN();
  BoosterHandle booster = R_ExternalPtrAddr(handle);
  CHECK_CALL(LGBM_BoosterRollbackOneIter(booster));
  R_API_END();
  return R_NilValue;
}
SEXP LGBM_BoosterGetCurrentIteration_R(SEXP handle,
@ -454,6 +490,7 @@ SEXP LGBM_BoosterGetCurrentIteration_R(SEXP handle,
CHECK_CALL(LGBM_BoosterGetCurrentIteration(R_ExternalPtrAddr(handle), &out_iteration));
INTEGER(out)[0] = out_iteration;
R_API_END();
return R_NilValue;
}
SEXP LGBM_BoosterGetUpperBoundValue_R(SEXP handle,
@ -462,6 +499,7 @@ SEXP LGBM_BoosterGetUpperBoundValue_R(SEXP handle,
double* ptr_ret = REAL(out_result);
CHECK_CALL(LGBM_BoosterGetUpperBoundValue(R_ExternalPtrAddr(handle), ptr_ret));
R_API_END();
return R_NilValue;
}
SEXP LGBM_BoosterGetLowerBoundValue_R(SEXP handle,
@ -470,14 +508,15 @@ SEXP LGBM_BoosterGetLowerBoundValue_R(SEXP handle,
double* ptr_ret = REAL(out_result);
CHECK_CALL(LGBM_BoosterGetLowerBoundValue(R_ExternalPtrAddr(handle), ptr_ret));
R_API_END();
return R_NilValue;
}
SEXP LGBM_BoosterGetEvalNames_R(SEXP handle) {
SEXP eval_names;
R_API_BEGIN();
int len;
R_API_BEGIN();
CHECK_CALL(LGBM_BoosterGetEvalCounts(R_ExternalPtrAddr(handle), &len));
R_API_END();
const size_t reserved_string_size = 128;
std::vector<std::vector<char>> names(len);
std::vector<char*> ptr_names(len);
@ -488,12 +527,14 @@ SEXP LGBM_BoosterGetEvalNames_R(SEXP handle) {
int out_len;
size_t required_string_size;
R_API_BEGIN();
CHECK_CALL(
LGBM_BoosterGetEvalNames(
R_ExternalPtrAddr(handle),
len, &out_len,
reserved_string_size, &required_string_size,
ptr_names.data()));
R_API_END();
// if any eval names were larger than allocated size,
// allow for a larger size and try again
if (required_string_size > reserved_string_size) {
@ -501,6 +542,7 @@ SEXP LGBM_BoosterGetEvalNames_R(SEXP handle) {
names[i].resize(required_string_size);
ptr_names[i] = names[i].data();
}
R_API_BEGIN();
CHECK_CALL(
LGBM_BoosterGetEvalNames(
R_ExternalPtrAddr(handle),
@ -509,6 +551,7 @@ SEXP LGBM_BoosterGetEvalNames_R(SEXP handle) {
required_string_size,
&required_string_size,
ptr_names.data()));
R_API_END();
}
CHECK_EQ(out_len, len);
eval_names = PROTECT(Rf_allocVector(STRSXP, len));
@ -517,7 +560,6 @@ SEXP LGBM_BoosterGetEvalNames_R(SEXP handle) {
}
UNPROTECT(1);
return eval_names;
R_API_END();
}
SEXP LGBM_BoosterGetEval_R(SEXP handle,
@ -531,6 +573,7 @@ SEXP LGBM_BoosterGetEval_R(SEXP handle,
CHECK_CALL(LGBM_BoosterGetEval(R_ExternalPtrAddr(handle), Rf_asInteger(data_idx), &out_len, ptr_ret));
CHECK_EQ(out_len, len);
R_API_END();
return R_NilValue;
}
SEXP LGBM_BoosterGetNumPredict_R(SEXP handle,
@ -541,6 +584,7 @@ SEXP LGBM_BoosterGetNumPredict_R(SEXP handle,
CHECK_CALL(LGBM_BoosterGetNumPredict(R_ExternalPtrAddr(handle), Rf_asInteger(data_idx), &len));
INTEGER(out)[0] = static_cast<int>(len);
R_API_END();
return R_NilValue;
}
SEXP LGBM_BoosterGetPredict_R(SEXP handle,
@ -551,6 +595,7 @@ SEXP LGBM_BoosterGetPredict_R(SEXP handle,
int64_t out_len;
CHECK_CALL(LGBM_BoosterGetPredict(R_ExternalPtrAddr(handle), Rf_asInteger(data_idx), &out_len, ptr_ret));
R_API_END();
return R_NilValue;
}
int GetPredictType(SEXP is_rawscore, SEXP is_leafidx, SEXP is_predcontrib) {
@ -577,12 +622,17 @@ SEXP LGBM_BoosterPredictForFile_R(SEXP handle,
SEXP num_iteration,
SEXP parameter,
SEXP result_filename) {
R_API_BEGIN();
const char* data_filename_ptr = CHAR(PROTECT(Rf_asChar(data_filename)));
const char* parameter_ptr = CHAR(PROTECT(Rf_asChar(parameter)));
const char* result_filename_ptr = CHAR(PROTECT(Rf_asChar(result_filename)));
int pred_type = GetPredictType(is_rawscore, is_leafidx, is_predcontrib);
CHECK_CALL(LGBM_BoosterPredictForFile(R_ExternalPtrAddr(handle), CHAR(Rf_asChar(data_filename)),
Rf_asInteger(data_has_header), pred_type, Rf_asInteger(start_iteration), Rf_asInteger(num_iteration), CHAR(Rf_asChar(parameter)),
CHAR(Rf_asChar(result_filename))));
R_API_BEGIN();
CHECK_CALL(LGBM_BoosterPredictForFile(R_ExternalPtrAddr(handle), data_filename_ptr,
Rf_asInteger(data_has_header), pred_type, Rf_asInteger(start_iteration), Rf_asInteger(num_iteration), parameter_ptr,
result_filename_ptr));
R_API_END();
UNPROTECT(3);
return R_NilValue;
}
SEXP LGBM_BoosterCalcNumPredict_R(SEXP handle,
@ -600,6 +650,7 @@ SEXP LGBM_BoosterCalcNumPredict_R(SEXP handle,
pred_type, Rf_asInteger(start_iteration), Rf_asInteger(num_iteration), &len));
INTEGER(out_len)[0] = static_cast<int>(len);
R_API_END();
return R_NilValue;
}
SEXP LGBM_BoosterPredictForCSC_R(SEXP handle,
@ -616,23 +667,24 @@ SEXP LGBM_BoosterPredictForCSC_R(SEXP handle,
SEXP num_iteration,
SEXP parameter,
SEXP out_result) {
R_API_BEGIN();
int pred_type = GetPredictType(is_rawscore, is_leafidx, is_predcontrib);
const int* p_indptr = INTEGER(indptr);
const int32_t* p_indices = reinterpret_cast<const int32_t*>(INTEGER(indices));
const double* p_data = REAL(data);
int64_t nindptr = static_cast<int64_t>(Rf_asInteger(num_indptr));
int64_t ndata = static_cast<int64_t>(Rf_asInteger(nelem));
int64_t nrow = static_cast<int64_t>(Rf_asInteger(num_row));
double* ptr_ret = REAL(out_result);
int64_t out_len;
const char* parameter_ptr = CHAR(PROTECT(Rf_asChar(parameter)));
R_API_BEGIN();
CHECK_CALL(LGBM_BoosterPredictForCSC(R_ExternalPtrAddr(handle),
p_indptr, C_API_DTYPE_INT32, p_indices,
p_data, C_API_DTYPE_FLOAT64, nindptr, ndata,
nrow, pred_type, Rf_asInteger(start_iteration), Rf_asInteger(num_iteration), CHAR(Rf_asChar(parameter)), &out_len, ptr_ret));
nrow, pred_type, Rf_asInteger(start_iteration), Rf_asInteger(num_iteration), parameter_ptr, &out_len, ptr_ret));
R_API_END();
UNPROTECT(1);
return R_NilValue;
}
SEXP LGBM_BoosterPredictForMat_R(SEXP handle,
@ -646,75 +698,82 @@ SEXP LGBM_BoosterPredictForMat_R(SEXP handle,
SEXP num_iteration,
SEXP parameter,
SEXP out_result) {
R_API_BEGIN();
int pred_type = GetPredictType(is_rawscore, is_leafidx, is_predcontrib);
int32_t nrow = static_cast<int32_t>(Rf_asInteger(num_row));
int32_t ncol = static_cast<int32_t>(Rf_asInteger(num_col));
const double* p_mat = REAL(data);
double* ptr_ret = REAL(out_result);
const char* parameter_ptr = CHAR(PROTECT(Rf_asChar(parameter)));
int64_t out_len;
R_API_BEGIN();
CHECK_CALL(LGBM_BoosterPredictForMat(R_ExternalPtrAddr(handle),
p_mat, C_API_DTYPE_FLOAT64, nrow, ncol, COL_MAJOR,
pred_type, Rf_asInteger(start_iteration), Rf_asInteger(num_iteration), CHAR(Rf_asChar(parameter)), &out_len, ptr_ret));
pred_type, Rf_asInteger(start_iteration), Rf_asInteger(num_iteration), parameter_ptr, &out_len, ptr_ret));
R_API_END();
UNPROTECT(1);
return R_NilValue;
}
// Passes the booster behind `handle` to LGBM_BoosterSaveModel together with
// the target file name.
//
// handle:                  external pointer to the booster.
// num_iteration:           converted with Rf_asInteger() and forwarded.
// feature_importance_type: converted with Rf_asInteger() and forwarded.
// filename:                value coerced to a string with Rf_asChar().
//
// The start iteration passed to the C API is fixed at 0.
// Always returns R_NilValue.
SEXP LGBM_BoosterSaveModel_R(SEXP handle,
  SEXP num_iteration,
  SEXP feature_importance_type,
  SEXP filename) {
  // Keep the CHARSXP produced by Rf_asChar() protected while its character
  // data is in use.
  const char* filename_ptr = CHAR(PROTECT(Rf_asChar(filename)));
  R_API_BEGIN();
  CHECK_CALL(LGBM_BoosterSaveModel(R_ExternalPtrAddr(handle), 0, Rf_asInteger(num_iteration), Rf_asInteger(feature_importance_type), filename_ptr));
  R_API_END();
  UNPROTECT(1);
  return R_NilValue;
}
// Returns the output of LGBM_BoosterSaveModelToString as a length-1 character
// vector.
//
// handle:                  external pointer to the booster.
// num_iteration:           converted with Rf_asInteger() and forwarded.
// feature_importance_type: converted with Rf_asInteger() and forwarded.
//
// The start iteration passed to the C API is fixed at 0.
SEXP LGBM_BoosterSaveModelToString_R(SEXP handle,
  SEXP num_iteration,
  SEXP feature_importance_type) {
  SEXP model_str;
  int64_t out_len = 0;
  // First attempt uses a 1 MiB buffer.
  int64_t buf_len = 1024 * 1024;
  int num_iter = Rf_asInteger(num_iteration);
  int importance_type = Rf_asInteger(feature_importance_type);
  std::vector<char> inner_char_buf(buf_len);
  R_API_BEGIN();
  CHECK_CALL(LGBM_BoosterSaveModelToString(R_ExternalPtrAddr(handle), 0, num_iter, importance_type, buf_len, &out_len, inner_char_buf.data()));
  R_API_END();
  // if the model string was larger than the initial buffer, allocate a bigger buffer and try again
  if (out_len > buf_len) {
    inner_char_buf.resize(out_len);
    R_API_BEGIN();
    CHECK_CALL(LGBM_BoosterSaveModelToString(R_ExternalPtrAddr(handle), 0, num_iter, importance_type, out_len, &out_len, inner_char_buf.data()));
    R_API_END();
  }
  // The result vector is protected while Rf_mkChar() allocates its element.
  model_str = PROTECT(Rf_allocVector(STRSXP, 1));
  SET_STRING_ELT(model_str, 0, Rf_mkChar(inner_char_buf.data()));
  UNPROTECT(1);
  return model_str;
}
// Returns the output of LGBM_BoosterDumpModel as a length-1 character vector.
//
// handle:                  external pointer to the booster.
// num_iteration:           converted with Rf_asInteger() and forwarded.
// feature_importance_type: converted with Rf_asInteger() and forwarded.
//
// The start iteration passed to the C API is fixed at 0.
SEXP LGBM_BoosterDumpModel_R(SEXP handle,
  SEXP num_iteration,
  SEXP feature_importance_type) {
  SEXP model_str;
  int64_t out_len = 0;
  // First attempt uses a 1 MiB buffer.
  int64_t buf_len = 1024 * 1024;
  int num_iter = Rf_asInteger(num_iteration);
  int importance_type = Rf_asInteger(feature_importance_type);
  std::vector<char> inner_char_buf(buf_len);
  R_API_BEGIN();
  CHECK_CALL(LGBM_BoosterDumpModel(R_ExternalPtrAddr(handle), 0, num_iter, importance_type, buf_len, &out_len, inner_char_buf.data()));
  R_API_END();
  // if the model string was larger than the initial buffer, allocate a bigger buffer and try again
  if (out_len > buf_len) {
    inner_char_buf.resize(out_len);
    R_API_BEGIN();
    CHECK_CALL(LGBM_BoosterDumpModel(R_ExternalPtrAddr(handle), 0, num_iter, importance_type, out_len, &out_len, inner_char_buf.data()));
    R_API_END();
  }
  // The result vector is protected while Rf_mkChar() allocates its element.
  model_str = PROTECT(Rf_allocVector(STRSXP, 1));
  SET_STRING_ELT(model_str, 0, Rf_mkChar(inner_char_buf.data()));
  UNPROTECT(1);
  return model_str;
}
// .Call() calls