diff --git a/R-package/src/lightgbm_R.cpp b/R-package/src/lightgbm_R.cpp index ad144be1d..f3165e1fa 100644 --- a/R-package/src/lightgbm_R.cpp +++ b/R-package/src/lightgbm_R.cpp @@ -158,16 +158,23 @@ LGBM_SE LGBM_DatasetGetFeatureNames_R(LGBM_SE handle, R_API_BEGIN(); int len = 0; CHECK_CALL(LGBM_DatasetGetNumFeature(R_GET_PTR(handle), &len)); + const size_t reserved_string_size = 256; std::vector> names(len); std::vector ptr_names(len); for (int i = 0; i < len; ++i) { - names[i].resize(256); + names[i].resize(reserved_string_size); ptr_names[i] = names[i].data(); } int out_len; - CHECK_CALL(LGBM_DatasetGetFeatureNames(R_GET_PTR(handle), - ptr_names.data(), &out_len)); + size_t required_string_size; + CHECK_CALL( + LGBM_DatasetGetFeatureNames( + R_GET_PTR(handle), + len, &out_len, + reserved_string_size, &required_string_size, + ptr_names.data())); CHECK_EQ(len, out_len); + CHECK_GE(reserved_string_size, required_string_size); auto merge_str = Join(ptr_names, "\t"); EncodeChar(feature_names, merge_str.c_str(), buf_len, actual_len, merge_str.size() + 1); R_API_END(); diff --git a/include/LightGBM/c_api.h b/include/LightGBM/c_api.h index 9a842fcd1..9d7c6e61d 100644 --- a/include/LightGBM/c_api.h +++ b/include/LightGBM/c_api.h @@ -280,13 +280,21 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetSetFeatureNames(DatasetHandle handle, /*! * \brief Get feature names of dataset. * \param handle Handle of dataset - * \param[out] feature_names Feature names, should pre-allocate memory + * \param len Number of ``char*`` pointers stored at ``out_strs``. + * If smaller than the max size, only this many strings are copied * \param[out] num_feature_names Number of feature names + * \param buffer_len Size of pre-allocated strings. + * Content is copied up to ``buffer_len - 1`` and null-terminated + * \param[out] out_buffer_len String sizes required to do the full string copies + * \param[out] feature_names Feature names, should pre-allocate memory * \return 0 when succeed, -1 when failure happens */ LIGHTGBM_C_EXPORT int LGBM_DatasetGetFeatureNames(DatasetHandle handle, - char** feature_names, - int* num_feature_names); + const int len, + int* num_feature_names, + const size_t buffer_len, + size_t* out_buffer_len, + char** feature_names); /*! * \brief Free space for dataset. diff --git a/python-package/lightgbm/basic.py b/python-package/lightgbm/basic.py index 97ead0d61..07b7efd41 100644 --- a/python-package/lightgbm/basic.py +++ b/python-package/lightgbm/basic.py @@ -1553,6 +1553,38 @@ class Dataset(object): self.set_field('group', group) return self + def get_feature_name(self): + """Get the names of columns (features) in the Dataset. + + Returns + ------- + feature_names : list + The names of columns (features) in the Dataset. + """ + if self.handle is None: + raise LightGBMError("Cannot get feature_name before construct dataset") + num_feature = self.num_feature() + tmp_out_len = ctypes.c_int(0) + reserved_string_buffer_size = 255 + required_string_buffer_size = ctypes.c_size_t(0) + string_buffers = [ctypes.create_string_buffer(reserved_string_buffer_size) for i in range_(num_feature)] + ptr_string_buffers = (ctypes.c_char_p * num_feature)(*map(ctypes.addressof, string_buffers)) + _safe_call(_LIB.LGBM_DatasetGetFeatureNames( + self.handle, + num_feature, + ctypes.byref(tmp_out_len), + reserved_string_buffer_size, + ctypes.byref(required_string_buffer_size), + ptr_string_buffers)) + if num_feature != tmp_out_len.value: + raise ValueError("Length of feature names doesn't equal with num_feature") + if reserved_string_buffer_size < required_string_buffer_size.value: + raise BufferError( + "Allocated feature name buffer size ({}) was inferior to the needed size ({})." + .format(reserved_string_buffer_size, required_string_buffer_size.value) + ) + return [string_buffers[i].value.decode('utf-8') for i in range_(num_feature)] + def get_label(self): """Get the label of the Dataset. diff --git a/src/c_api.cpp b/src/c_api.cpp index cbdd98468..290f219fa 100644 --- a/src/c_api.cpp +++ b/src/c_api.cpp @@ -1110,15 +1110,23 @@ int LGBM_DatasetSetFeatureNames( } int LGBM_DatasetGetFeatureNames( - DatasetHandle handle, - char** feature_names, - int* num_feature_names) { + DatasetHandle handle, + const int len, + int* num_feature_names, + const size_t buffer_len, + size_t* out_buffer_len, + char** feature_names) { API_BEGIN(); + *out_buffer_len = 0; auto dataset = reinterpret_cast(handle); auto inside_feature_name = dataset->feature_names(); *num_feature_names = static_cast(inside_feature_name.size()); for (int i = 0; i < *num_feature_names; ++i) { - std::memcpy(feature_names[i], inside_feature_name[i].c_str(), inside_feature_name[i].size() + 1); + if (i < len) { + std::memcpy(feature_names[i], inside_feature_name[i].c_str(), std::min(inside_feature_name[i].size() + 1, buffer_len)); + feature_names[i][buffer_len - 1] = '\0'; + } + *out_buffer_len = std::max(inside_feature_name[i].size() + 1, *out_buffer_len); } API_END(); } diff --git a/tests/python_package_test/test_basic.py b/tests/python_package_test/test_basic.py index 27c3aff13..85e9e728d 100644 --- a/tests/python_package_test/test_basic.py +++ b/tests/python_package_test/test_basic.py @@ -271,15 +271,20 @@ class TestBasic(unittest.TestCase): self.assertTrue(np.all(np.isclose([data.label[0], data.weight[0], data.init_score[0]], data.label[0]))) self.assertAlmostEqual(data.label[1], data.weight[1]) + self.assertListEqual(data.feature_name, data.get_feature_name()) X, y = load_breast_cancer(True) sequence = np.ones(y.shape[0]) sequence[0] = np.nan sequence[1] = np.inf - lgb_data = lgb.Dataset(X, sequence, weight=sequence, init_score=sequence).construct() + feature_names = ['f{0}'.format(i) for i in range(X.shape[1])] + lgb_data = lgb.Dataset(X, sequence, + weight=sequence, init_score=sequence, + feature_name=feature_names).construct() check_asserts(lgb_data) lgb_data = lgb.Dataset(X, y).construct() lgb_data.set_label(sequence) lgb_data.set_weight(sequence) lgb_data.set_init_score(sequence) + lgb_data.set_feature_name(feature_names) check_asserts(lgb_data)