[R-package] Use ALTREP system to return C++-allocated arrays (#6213)

This commit is contained in:
david-cortes 2024-05-29 06:19:38 +02:00 коммит произвёл GitHub
Родитель f6c8f5d8a1
Коммит dee8a18889
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: B5690EEEBB952194
1 изменённых файлов: 211 добавлений и 11 удалений

Просмотреть файл

@ -11,6 +11,7 @@
#include <LightGBM/utils/text_reader.h>
#include <R_ext/Rdynload.h>
#include <R_ext/Altrep.h>
#define R_NO_REMAP
#define R_USE_C99_IN_CXX
@ -24,6 +25,150 @@
#include <utility>
#include <vector>
#include <algorithm>
#include <type_traits>
// Handles for the ALTREP classes used to expose C++-allocated memory to R
// without copying. They are assigned in R_init_lightgbm() when ALTREP
// support is compiled in (i.e. LGB_NO_ALTREP is not defined).
// Raw vector backed by a heap-allocated std::vector<char>.
R_altrep_class_t lgb_altrepped_char_vec;
// Integer vector backed by a C++ array released with delete[].
R_altrep_class_t lgb_altrepped_int_arr;
// Real (double) vector backed by a C++ array released with delete[].
R_altrep_class_t lgb_altrepped_dbl_arr;
// Finalizer for R external pointers that wrap an array of T allocated with
// 'new T[]'. Frees the array and then clears the address stored in the
// external pointer so that it no longer refers to the freed memory.
template <class T>
void delete_cpp_array(SEXP R_ptr) {
  delete[] static_cast<T*>(R_ExternalPtrAddr(R_ptr));
  R_ClearExternalPtr(R_ptr);
}
// Finalizer for R external pointers that wrap a heap-allocated
// std::vector<char>. Destroys the vector and then clears the address stored
// in the external pointer.
void delete_cpp_char_vec(SEXP R_ptr) {
  using char_vec = std::vector<char>;
  char_vec *owned_vec = static_cast<char_vec*>(R_ExternalPtrAddr(R_ptr));
  delete owned_vec;
  R_ClearExternalPtr(R_ptr);
}
// Note: MSVC has issues with ALTREP classes, so ALTREP support is disabled
// when compiling with it. Defining LGB_NO_ALTREP selects the copy-based
// fallback functions instead of the ALTREP-based ones in this file.
// See: https://github.com/microsoft/LightGBM/pull/6213#issuecomment-2111025768
#ifdef _MSC_VER
# define LGB_NO_ALTREP
#endif
#ifndef LGB_NO_ALTREP
// Wraps a C++ character buffer in an ALTREP raw vector without copying it.
//
// 'void_ptr' must point to a std::unique_ptr<std::vector<char>>. The vector's
// address is stored in an R external pointer whose finalizer
// (delete_cpp_char_vec) destroys it; only after the finalizer has been
// registered is ownership released from the unique_ptr, so the vector always
// has exactly one owner. The signature takes a single void* so the function
// can be passed as the callback of R_UnwindProtect.
SEXP make_altrepped_raw_vec(void *void_ptr) {
  auto *owner = static_cast<std::unique_ptr<std::vector<char>>*>(void_ptr);

  SEXP ext_ptr = PROTECT(R_MakeExternalPtr(nullptr, R_NilValue, R_NilValue));
  SEXP altrep_raw = PROTECT(R_new_altrep(lgb_altrepped_char_vec, R_NilValue, R_NilValue));

  // Hand the vector over to R: from here on the finalizer is responsible
  // for destroying it.
  R_SetExternalPtrAddr(ext_ptr, owner->get());
  R_RegisterCFinalizerEx(ext_ptr, delete_cpp_char_vec, TRUE);
  owner->release();

  R_set_altrep_data1(altrep_raw, ext_ptr);
  UNPROTECT(2);
  return altrep_raw;
}
#else
// Fallback used when ALTREP is disabled: copies the contents of a C++
// character buffer into a freshly allocated regular R raw vector.
//
// 'void_ptr' must point to a std::unique_ptr<std::vector<char>>. The
// unique_ptr is left untouched, so the caller keeps ownership of the vector.
SEXP make_r_raw_vec(void *void_ptr) {
  auto *owner = static_cast<std::unique_ptr<std::vector<char>>*>(void_ptr);
  std::vector<char> *bytes = owner->get();

  R_xlen_t n_bytes = bytes->size();
  SEXP result = PROTECT(Rf_allocVector(RAWSXP, n_bytes));
  char *dest = reinterpret_cast<char*>(RAW(result));
  std::copy(bytes->begin(), bytes->end(), dest);
  UNPROTECT(1);
  return result;
}
#define make_altrepped_raw_vec make_r_raw_vec
#endif
// Returns the std::vector<char> whose address is stored in the external
// pointer held in the 'data1' slot of the given ALTREP raw vector.
std::vector<char>* get_ptr_from_altrepped_raw(SEXP R_raw) {
  SEXP ext_ptr = R_altrep_data1(R_raw);
  void *address = R_ExternalPtrAddr(ext_ptr);
  return static_cast<std::vector<char>*>(address);
}
// ALTREP 'Length' method for the raw class: the length is the size of the
// backing std::vector<char>.
R_xlen_t get_altrepped_raw_len(SEXP R_raw) {
  const std::vector<char> *backing_vec = get_ptr_from_altrepped_raw(R_raw);
  return static_cast<R_xlen_t>(backing_vec->size());
}
// ALTREP 'Dataptr_or_null' method for the raw class: returns a read-only
// pointer to the contents of the backing std::vector<char>.
const void* get_altrepped_raw_dataptr_or_null(SEXP R_raw) {
  const std::vector<char> *backing_vec = get_ptr_from_altrepped_raw(R_raw);
  return backing_vec->data();
}
// ALTREP 'Dataptr' method for the raw class: returns a pointer to the
// contents of the backing std::vector<char>. The 'writeable' flag is not
// consulted; the same pointer is returned either way.
void* get_altrepped_raw_dataptr(SEXP R_raw, Rboolean writeable) {
  std::vector<char> *backing_vec = get_ptr_from_altrepped_raw(R_raw);
  char *contents = backing_vec->data();
  return contents;
}
#ifndef LGB_NO_ALTREP
// Selects the ALTREP class matching the element type T: the real-vector
// class when T is double, the integer-vector class for any other T.
template <class T>
R_altrep_class_t get_altrep_class_for_type() {
  return std::is_same<T, double>::value ? lgb_altrepped_dbl_arr
                                        : lgb_altrepped_int_arr;
}
#else
// Selects the R vector type matching the element type T: REALSXP when T is
// double, INTSXP for any other T.
template <class T>
SEXPTYPE get_sexptype_class_for_type() {
  return std::is_same<T, double>::value ? REALSXP : INTSXP;
}
// Returns the data pointer of the R vector 'x' typed as T*: REAL(x) is used
// when T is double and INTEGER(x) otherwise. The intermediate void* lets the
// single template body compile for both element types.
template <class T>
T* get_r_vec_ptr(SEXP x) {
  void *untyped = std::is_same<T, double>::value
                    ? static_cast<void*>(REAL(x))
                    : static_cast<void*>(INTEGER(x));
  return static_cast<T*>(untyped);
}
#endif
// Bundles a pointer to a C++ array with its element count, so that both can
// be handed through the single void* argument accepted by the functions that
// build R vectors from C++ arrays.
template <class T>
struct arr_and_len {
// Pointer to the first element of the array.
T *arr;
// Number of elements in 'arr'.
int64_t len;
};
#ifndef LGB_NO_ALTREP
// Wraps a C++ array of T in an ALTREP vector without copying it.
//
// 'void_ptr' must point to an arr_and_len<T>. The array address is stored in
// an R external pointer (ALTREP slot 'data1') whose finalizer,
// delete_cpp_array<T>, frees it with delete[]; the element count is stored
// as a length-1 real vector (ALTREP slot 'data2'). The signature takes a
// single void* so the function can be passed as the callback of
// R_UnwindProtect.
template <class T>
SEXP make_altrepped_vec_from_arr(void *void_ptr) {
  const arr_and_len<T> *input = static_cast<arr_and_len<T>*>(void_ptr);
  T *cpp_array = input->arr;
  const uint64_t n_elements = input->len;

  SEXP ext_ptr = PROTECT(R_MakeExternalPtr(nullptr, R_NilValue, R_NilValue));
  SEXP len_holder = PROTECT(Rf_allocVector(REALSXP, 1));
  SEXP altrep_vec = PROTECT(R_new_altrep(get_altrep_class_for_type<T>(), R_NilValue, R_NilValue));

  REAL(len_holder)[0] = static_cast<double>(n_elements);

  // From here on the finalizer is responsible for freeing the array.
  R_SetExternalPtrAddr(ext_ptr, cpp_array);
  R_RegisterCFinalizerEx(ext_ptr, delete_cpp_array<T>, TRUE);

  R_set_altrep_data1(altrep_vec, ext_ptr);
  R_set_altrep_data2(altrep_vec, len_holder);
  UNPROTECT(3);
  return altrep_vec;
}
#else
// Fallback used when ALTREP is disabled: copies a C++ array of T into a
// freshly allocated regular R vector (REALSXP for double, INTSXP otherwise).
//
// 'void_ptr' must point to an arr_and_len<T>. This function is substituted
// for make_altrepped_vec_from_arr, which takes ownership of the array (it
// registers delete_cpp_array<T> as finalizer), and callers clear their own
// pointer to the array once the call returns. To honour the same ownership
// contract - and avoid leaking the array - the source array is released with
// delete[] once its contents have been copied.
//
// If the R allocation fails, control leaves this function before the
// delete[], so the array is still owned by the caller in that case.
template <class T>
SEXP make_R_vec_from_arr(void *void_ptr) {
  arr_and_len<T> *input = static_cast<arr_and_len<T>*>(void_ptr);
  T *arr = input->arr;
  uint64_t len = input->len;

  SEXP out = PROTECT(Rf_allocVector(get_sexptype_class_for_type<T>(), len));
  std::copy(arr, arr + len, get_r_vec_ptr<T>(out));

  // The data now lives in the R vector; free the C++ array (allocated with
  // new[], as assumed by delete_cpp_array<T> in the ALTREP path) and drop the
  // dangling address from the descriptor.
  delete[] arr;
  input->arr = nullptr;

  UNPROTECT(1);
  return out;
}
#define make_altrepped_vec_from_arr make_R_vec_from_arr
#endif
// ALTREP 'Length' method for the integer and real array classes: the length
// is read back from the numeric value stored in the 'data2' slot.
R_xlen_t get_altrepped_vec_len(SEXP R_vec) {
  SEXP len_holder = R_altrep_data2(R_vec);
  const double stored_len = Rf_asReal(len_holder);
  return static_cast<R_xlen_t>(stored_len);
}
// ALTREP 'Dataptr_or_null' method for the integer and real array classes:
// returns the address held by the external pointer in the 'data1' slot.
const void* get_altrepped_vec_dataptr_or_null(SEXP R_vec) {
  SEXP ext_ptr = R_altrep_data1(R_vec);
  return R_ExternalPtrAddr(ext_ptr);
}
// ALTREP 'Dataptr' method for the integer and real array classes: returns
// the address held by the external pointer in the 'data1' slot. The
// 'writeable' flag is not consulted; the same pointer is returned either way.
void* get_altrepped_vec_dataptr(SEXP R_vec, Rboolean writeable) {
  SEXP ext_ptr = R_altrep_data1(R_vec);
  void *array_address = R_ExternalPtrAddr(ext_ptr);
  return array_address;
}
#define COL_MAJOR (0)
@ -964,8 +1109,6 @@ struct SparseOutputPointers {
void* indptr;
int32_t* indices;
void* data;
int indptr_type;
int data_type;
SparseOutputPointers(void* indptr, int32_t* indices, void* data)
: indptr(indptr), indices(indices), data(data) {}
};
@ -1015,15 +1158,26 @@ SEXP LGBM_BoosterPredictSparseOutput_R(SEXP handle,
&delete_SparseOutputPointers
};
SEXP out_indptr_R = safe_R_int(out_len[1], &cont_token);
SET_VECTOR_ELT(out, 0, out_indptr_R);
SEXP out_indices_R = safe_R_int(out_len[0], &cont_token);
SET_VECTOR_ELT(out, 1, out_indices_R);
SEXP out_data_R = safe_R_real(out_len[0], &cont_token);
SET_VECTOR_ELT(out, 2, out_data_R);
std::memcpy(INTEGER(out_indptr_R), out_indptr, out_len[1]*sizeof(int));
std::memcpy(INTEGER(out_indices_R), out_indices, out_len[0]*sizeof(int));
std::memcpy(REAL(out_data_R), out_data, out_len[0]*sizeof(double));
arr_and_len<int> indptr_str{static_cast<int*>(out_indptr), out_len[1]};
SET_VECTOR_ELT(
out, 0,
R_UnwindProtect(make_altrepped_vec_from_arr<int>,
static_cast<void*>(&indptr_str), throw_R_memerr, &cont_token, cont_token));
pointers_struct->indptr = nullptr;
arr_and_len<int> indices_str{static_cast<int*>(out_indices), out_len[0]};
SET_VECTOR_ELT(
out, 1,
R_UnwindProtect(make_altrepped_vec_from_arr<int>,
static_cast<void*>(&indices_str), throw_R_memerr, &cont_token, cont_token));
pointers_struct->indices = nullptr;
arr_and_len<double> data_str{static_cast<double*>(out_data), out_len[0]};
SET_VECTOR_ELT(
out, 2,
R_UnwindProtect(make_altrepped_vec_from_arr<double>,
static_cast<void*>(&data_str), throw_R_memerr, &cont_token, cont_token));
pointers_struct->data = nullptr;
UNPROTECT(3);
return out;
@ -1104,6 +1258,34 @@ SEXP LGBM_BoosterSaveModel_R(SEXP handle,
R_API_END();
}
// Note: for some reason, MSVC crashes when an error is thrown here
// if the buffer variable is defined as 'std::unique_ptr<std::vector<char>>',
// but not if it is defined as 'std::vector<char>'.
#ifndef _MSC_VER
// Serializes a booster into a character buffer and returns it to R as the
// object produced by make_altrepped_raw_vec (non-MSVC implementation).
//
// Arguments are R objects: 'handle' is an external pointer to the booster,
// while 'num_iteration', 'feature_importance_type' and 'start_iteration' are
// read with Rf_asInteger and forwarded to LGBM_BoosterSaveModelToString.
//
// NOTE(review): R_API_BEGIN / R_API_END / CHECK_CALL are macros defined
// outside this view; presumably they turn C++ exceptions and non-zero
// C API return codes into R errors - confirm against their definitions.
SEXP LGBM_BoosterSaveModelToString_R(SEXP handle,
SEXP num_iteration,
SEXP feature_importance_type,
SEXP start_iteration) {
// Continuation token required by R_UnwindProtect below.
SEXP cont_token = PROTECT(R_MakeUnwindCont());
R_API_BEGIN();
_AssertBoosterHandleNotNull(handle);
int64_t out_len = 0;
// Initial guess for the buffer size: 1024 * 1024 chars.
int64_t buf_len = 1024 * 1024;
int num_iter = Rf_asInteger(num_iteration);
int start_iter = Rf_asInteger(start_iteration);
int importance_type = Rf_asInteger(feature_importance_type);
// Heap-allocated so that ownership can later be handed over through
// make_altrepped_raw_vec.
std::unique_ptr<std::vector<char>> inner_char_buf(new std::vector<char>(buf_len));
// First attempt; 'out_len' receives the length reported by the C API.
CHECK_CALL(LGBM_BoosterSaveModelToString(R_ExternalPtrAddr(handle), start_iter, num_iter, importance_type, buf_len, &out_len, inner_char_buf->data()));
// Shrink or grow the buffer to the reported length.
inner_char_buf->resize(out_len);
// The initial buffer was too small: repeat the call with the resized buffer.
if (out_len > buf_len) {
CHECK_CALL(LGBM_BoosterSaveModelToString(R_ExternalPtrAddr(handle), start_iter, num_iter, importance_type, out_len, &out_len, inner_char_buf->data()));
}
// Build the R object under R_UnwindProtect, with throw_R_memerr as the
// cleanup callback.
SEXP out = R_UnwindProtect(make_altrepped_raw_vec, &inner_char_buf, throw_R_memerr, &cont_token, cont_token);
// Balances the PROTECT of 'cont_token'.
UNPROTECT(1);
return out;
R_API_END();
}
#else
SEXP LGBM_BoosterSaveModelToString_R(SEXP handle,
SEXP num_iteration,
SEXP feature_importance_type,
@ -1129,6 +1311,7 @@ SEXP LGBM_BoosterSaveModelToString_R(SEXP handle,
return model_str;
R_API_END();
}
#endif
SEXP LGBM_BoosterDumpModel_R(SEXP handle,
SEXP num_iteration,
@ -1281,4 +1464,21 @@ LIGHTGBM_C_EXPORT void R_init_lightgbm(DllInfo *dll);
// Package initialization routine: registers the .Call entry points listed in
// CallEntries, disables dynamic symbol lookup and, when ALTREP support is
// compiled in, creates the ALTREP classes and attaches their methods.
void R_init_lightgbm(DllInfo *dll) {
  R_registerRoutines(dll, NULL, CallEntries, NULL, NULL);
  R_useDynamicSymbols(dll, FALSE);

#ifndef LGB_NO_ALTREP
  // Raw vector class backed by a std::vector<char>.
  lgb_altrepped_char_vec = R_make_altraw_class("lgb_altrepped_char_vec", "lightgbm", dll);
  R_set_altrep_Length_method(lgb_altrepped_char_vec, get_altrepped_raw_len);
  R_set_altvec_Dataptr_method(lgb_altrepped_char_vec, get_altrepped_raw_dataptr);
  R_set_altvec_Dataptr_or_null_method(lgb_altrepped_char_vec, get_altrepped_raw_dataptr_or_null);

  // The integer and real array classes share the same accessor functions.
  const auto attach_array_methods = [](R_altrep_class_t array_class) {
    R_set_altrep_Length_method(array_class, get_altrepped_vec_len);
    R_set_altvec_Dataptr_method(array_class, get_altrepped_vec_dataptr);
    R_set_altvec_Dataptr_or_null_method(array_class, get_altrepped_vec_dataptr_or_null);
  };

  lgb_altrepped_int_arr = R_make_altinteger_class("lgb_altrepped_int_arr", "lightgbm", dll);
  attach_array_methods(lgb_altrepped_int_arr);

  lgb_altrepped_dbl_arr = R_make_altreal_class("lgb_altrepped_dbl_arr", "lightgbm", dll);
  attach_array_methods(lgb_altrepped_dbl_arr);
#endif
}