зеркало из https://github.com/microsoft/LightGBM.git
refactor LGBM_DatasetGetFeatureNames (#3022)
This commit is contained in:
Родитель
b3a84df5af
Коммит
f30e0bb3d9
|
@ -158,16 +158,23 @@ LGBM_SE LGBM_DatasetGetFeatureNames_R(LGBM_SE handle,
|
|||
R_API_BEGIN();
|
||||
int len = 0;
|
||||
CHECK_CALL(LGBM_DatasetGetNumFeature(R_GET_PTR(handle), &len));
|
||||
const size_t reserved_string_size = 256;
|
||||
std::vector<std::vector<char>> names(len);
|
||||
std::vector<char*> ptr_names(len);
|
||||
for (int i = 0; i < len; ++i) {
|
||||
names[i].resize(256);
|
||||
names[i].resize(reserved_string_size);
|
||||
ptr_names[i] = names[i].data();
|
||||
}
|
||||
int out_len;
|
||||
CHECK_CALL(LGBM_DatasetGetFeatureNames(R_GET_PTR(handle),
|
||||
ptr_names.data(), &out_len));
|
||||
size_t required_string_size;
|
||||
CHECK_CALL(
|
||||
LGBM_DatasetGetFeatureNames(
|
||||
R_GET_PTR(handle),
|
||||
len, &out_len,
|
||||
reserved_string_size, &required_string_size,
|
||||
ptr_names.data()));
|
||||
CHECK_EQ(len, out_len);
|
||||
CHECK_GE(reserved_string_size, required_string_size);
|
||||
auto merge_str = Join<char*>(ptr_names, "\t");
|
||||
EncodeChar(feature_names, merge_str.c_str(), buf_len, actual_len, merge_str.size() + 1);
|
||||
R_API_END();
|
||||
|
|
|
@ -280,13 +280,21 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetSetFeatureNames(DatasetHandle handle,
|
|||
/*!
|
||||
* \brief Get feature names of dataset.
|
||||
* \param handle Handle of dataset
|
||||
* \param[out] feature_names Feature names, should pre-allocate memory
|
||||
* \param len Number of ``char*`` pointers stored at ``out_strs``.
|
||||
* If smaller than the max size, only this many strings are copied
|
||||
* \param[out] num_feature_names Number of feature names
|
||||
* \param buffer_len Size of pre-allocated strings.
|
||||
* Content is copied up to ``buffer_len - 1`` and null-terminated
|
||||
* \param[out] out_buffer_len String sizes required to do the full string copies
|
||||
* \param[out] feature_names Feature names, should pre-allocate memory
|
||||
* \return 0 when succeed, -1 when failure happens
|
||||
*/
|
||||
LIGHTGBM_C_EXPORT int LGBM_DatasetGetFeatureNames(DatasetHandle handle,
|
||||
char** feature_names,
|
||||
int* num_feature_names);
|
||||
const int len,
|
||||
int* num_feature_names,
|
||||
const size_t buffer_len,
|
||||
size_t* out_buffer_len,
|
||||
char** feature_names);
|
||||
|
||||
/*!
|
||||
* \brief Free space for dataset.
|
||||
|
|
|
@ -1553,6 +1553,38 @@ class Dataset(object):
|
|||
self.set_field('group', group)
|
||||
return self
|
||||
|
||||
def get_feature_name(self):
|
||||
"""Get the names of columns (features) in the Dataset.
|
||||
|
||||
Returns
|
||||
-------
|
||||
feature_names : list
|
||||
The names of columns (features) in the Dataset.
|
||||
"""
|
||||
if self.handle is None:
|
||||
raise LightGBMError("Cannot get feature_name before construct dataset")
|
||||
num_feature = self.num_feature()
|
||||
tmp_out_len = ctypes.c_int(0)
|
||||
reserved_string_buffer_size = 255
|
||||
required_string_buffer_size = ctypes.c_size_t(0)
|
||||
string_buffers = [ctypes.create_string_buffer(reserved_string_buffer_size) for i in range_(num_feature)]
|
||||
ptr_string_buffers = (ctypes.c_char_p * num_feature)(*map(ctypes.addressof, string_buffers))
|
||||
_safe_call(_LIB.LGBM_DatasetGetFeatureNames(
|
||||
self.handle,
|
||||
num_feature,
|
||||
ctypes.byref(tmp_out_len),
|
||||
reserved_string_buffer_size,
|
||||
ctypes.byref(required_string_buffer_size),
|
||||
ptr_string_buffers))
|
||||
if num_feature != tmp_out_len.value:
|
||||
raise ValueError("Length of feature names doesn't equal with num_feature")
|
||||
if reserved_string_buffer_size < required_string_buffer_size.value:
|
||||
raise BufferError(
|
||||
"Allocated feature name buffer size ({}) was inferior to the needed size ({})."
|
||||
.format(reserved_string_buffer_size, required_string_buffer_size.value)
|
||||
)
|
||||
return [string_buffers[i].value.decode('utf-8') for i in range_(num_feature)]
|
||||
|
||||
def get_label(self):
|
||||
"""Get the label of the Dataset.
|
||||
|
||||
|
|
|
@ -1110,15 +1110,23 @@ int LGBM_DatasetSetFeatureNames(
|
|||
}
|
||||
|
||||
int LGBM_DatasetGetFeatureNames(
|
||||
DatasetHandle handle,
|
||||
char** feature_names,
|
||||
int* num_feature_names) {
|
||||
DatasetHandle handle,
|
||||
const int len,
|
||||
int* num_feature_names,
|
||||
const size_t buffer_len,
|
||||
size_t* out_buffer_len,
|
||||
char** feature_names) {
|
||||
API_BEGIN();
|
||||
*out_buffer_len = 0;
|
||||
auto dataset = reinterpret_cast<Dataset*>(handle);
|
||||
auto inside_feature_name = dataset->feature_names();
|
||||
*num_feature_names = static_cast<int>(inside_feature_name.size());
|
||||
for (int i = 0; i < *num_feature_names; ++i) {
|
||||
std::memcpy(feature_names[i], inside_feature_name[i].c_str(), inside_feature_name[i].size() + 1);
|
||||
if (i < len) {
|
||||
std::memcpy(feature_names[i], inside_feature_name[i].c_str(), std::min(inside_feature_name[i].size() + 1, buffer_len));
|
||||
feature_names[i][buffer_len - 1] = '\0';
|
||||
}
|
||||
*out_buffer_len = std::max(inside_feature_name[i].size() + 1, *out_buffer_len);
|
||||
}
|
||||
API_END();
|
||||
}
|
||||
|
|
|
@ -271,15 +271,20 @@ class TestBasic(unittest.TestCase):
|
|||
self.assertTrue(np.all(np.isclose([data.label[0], data.weight[0], data.init_score[0]],
|
||||
data.label[0])))
|
||||
self.assertAlmostEqual(data.label[1], data.weight[1])
|
||||
self.assertListEqual(data.feature_name, data.get_feature_name())
|
||||
|
||||
X, y = load_breast_cancer(True)
|
||||
sequence = np.ones(y.shape[0])
|
||||
sequence[0] = np.nan
|
||||
sequence[1] = np.inf
|
||||
lgb_data = lgb.Dataset(X, sequence, weight=sequence, init_score=sequence).construct()
|
||||
feature_names = ['f{0}'.format(i) for i in range(X.shape[1])]
|
||||
lgb_data = lgb.Dataset(X, sequence,
|
||||
weight=sequence, init_score=sequence,
|
||||
feature_name=feature_names).construct()
|
||||
check_asserts(lgb_data)
|
||||
lgb_data = lgb.Dataset(X, y).construct()
|
||||
lgb_data.set_label(sequence)
|
||||
lgb_data.set_weight(sequence)
|
||||
lgb_data.set_init_score(sequence)
|
||||
lgb_data.set_feature_name(feature_names)
|
||||
check_asserts(lgb_data)
|
||||
|
|
Загрузка…
Ссылка в новой задаче