feature importance type in saved model file (#3220)

* feature importance type in saved model file

* fix nullptr

* fixed formatting

* fix python/R

* Update src/c_api.cpp

* Apply suggestions from code review

Co-authored-by: James Lamb <jaylamb20@gmail.com>

* fix c_api test

* fix swig

* minor docs improvements and added defines for importance types

Co-authored-by: StrikerRUS <nekit94-12@hotmail.com>
Co-authored-by: James Lamb <jaylamb20@gmail.com>
This commit is contained in:
Guolin Ke 2020-07-16 03:18:53 +08:00 коммит произвёл GitHub
Родитель 7b8b51518c
Коммит 87d46489f3
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
16 изменённых файлов: 132 добавлений и 51 удалений

Просмотреть файл

@ -424,7 +424,7 @@ Booster <- R6::R6Class(
},
# Save model
save_model = function(filename, num_iteration = NULL) {
save_model = function(filename, num_iteration = NULL, feature_importance_type = 0L) {
# Check if number of iteration is non existent
if (is.null(num_iteration)) {
@ -437,6 +437,7 @@ Booster <- R6::R6Class(
, ret = NULL
, private$handle
, as.integer(num_iteration)
, as.integer(feature_importance_type)
, lgb.c_str(filename)
)
@ -445,7 +446,7 @@ Booster <- R6::R6Class(
},
# Save model to string
save_model_to_string = function(num_iteration = NULL) {
save_model_to_string = function(num_iteration = NULL, feature_importance_type = 0L) {
# Check if number of iteration is non existent
if (is.null(num_iteration)) {
@ -457,12 +458,13 @@ Booster <- R6::R6Class(
"LGBM_BoosterSaveModelToString_R"
, private$handle
, as.integer(num_iteration)
, as.integer(feature_importance_type)
))
},
# Dump model in memory
dump_model = function(num_iteration = NULL) {
dump_model = function(num_iteration = NULL, feature_importance_type = 0L) {
# Check if number of iteration is non existent
if (is.null(num_iteration)) {
@ -474,6 +476,7 @@ Booster <- R6::R6Class(
"LGBM_BoosterDumpModel_R"
, private$handle
, as.integer(num_iteration)
, as.integer(feature_importance_type)
)
},

Просмотреть файл

@ -632,15 +632,17 @@ LGBM_SE LGBM_BoosterPredictForMat_R(LGBM_SE handle,
LGBM_SE LGBM_BoosterSaveModel_R(LGBM_SE handle,
LGBM_SE num_iteration,
LGBM_SE feature_importance_type,
LGBM_SE filename,
LGBM_SE call_state) {
R_API_BEGIN();
CHECK_CALL(LGBM_BoosterSaveModel(R_GET_PTR(handle), 0, R_AS_INT(num_iteration), R_CHAR_PTR(filename)));
CHECK_CALL(LGBM_BoosterSaveModel(R_GET_PTR(handle), 0, R_AS_INT(num_iteration), R_AS_INT(feature_importance_type), R_CHAR_PTR(filename)));
R_API_END();
}
LGBM_SE LGBM_BoosterSaveModelToString_R(LGBM_SE handle,
LGBM_SE num_iteration,
LGBM_SE feature_importance_type,
LGBM_SE buffer_len,
LGBM_SE actual_len,
LGBM_SE out_str,
@ -648,13 +650,14 @@ LGBM_SE LGBM_BoosterSaveModelToString_R(LGBM_SE handle,
R_API_BEGIN();
int64_t out_len = 0;
std::vector<char> inner_char_buf(R_AS_INT(buffer_len));
CHECK_CALL(LGBM_BoosterSaveModelToString(R_GET_PTR(handle), 0, R_AS_INT(num_iteration), R_AS_INT(buffer_len), &out_len, inner_char_buf.data()));
CHECK_CALL(LGBM_BoosterSaveModelToString(R_GET_PTR(handle), 0, R_AS_INT(num_iteration), R_AS_INT(feature_importance_type), R_AS_INT(buffer_len), &out_len, inner_char_buf.data()));
EncodeChar(out_str, inner_char_buf.data(), buffer_len, actual_len, static_cast<size_t>(out_len));
R_API_END();
}
LGBM_SE LGBM_BoosterDumpModel_R(LGBM_SE handle,
LGBM_SE num_iteration,
LGBM_SE feature_importance_type,
LGBM_SE buffer_len,
LGBM_SE actual_len,
LGBM_SE out_str,
@ -662,7 +665,7 @@ LGBM_SE LGBM_BoosterDumpModel_R(LGBM_SE handle,
R_API_BEGIN();
int64_t out_len = 0;
std::vector<char> inner_char_buf(R_AS_INT(buffer_len));
CHECK_CALL(LGBM_BoosterDumpModel(R_GET_PTR(handle), 0, R_AS_INT(num_iteration), R_AS_INT(buffer_len), &out_len, inner_char_buf.data()));
CHECK_CALL(LGBM_BoosterDumpModel(R_GET_PTR(handle), 0, R_AS_INT(num_iteration), R_AS_INT(feature_importance_type), R_AS_INT(buffer_len), &out_len, inner_char_buf.data()));
EncodeChar(out_str, inner_char_buf.data(), buffer_len, actual_len, static_cast<size_t>(out_len));
R_API_END();
}
@ -707,9 +710,9 @@ static const R_CallMethodDef CallEntries[] = {
{"LGBM_BoosterCalcNumPredict_R" , (DL_FUNC) &LGBM_BoosterCalcNumPredict_R , 8},
{"LGBM_BoosterPredictForCSC_R" , (DL_FUNC) &LGBM_BoosterPredictForCSC_R , 14},
{"LGBM_BoosterPredictForMat_R" , (DL_FUNC) &LGBM_BoosterPredictForMat_R , 11},
{"LGBM_BoosterSaveModel_R" , (DL_FUNC) &LGBM_BoosterSaveModel_R , 4},
{"LGBM_BoosterSaveModelToString_R" , (DL_FUNC) &LGBM_BoosterSaveModelToString_R , 6},
{"LGBM_BoosterDumpModel_R" , (DL_FUNC) &LGBM_BoosterDumpModel_R , 6},
{"LGBM_BoosterSaveModel_R" , (DL_FUNC) &LGBM_BoosterSaveModel_R , 5},
{"LGBM_BoosterSaveModelToString_R" , (DL_FUNC) &LGBM_BoosterSaveModelToString_R , 7},
{"LGBM_BoosterDumpModel_R" , (DL_FUNC) &LGBM_BoosterDumpModel_R , 7},
{NULL, NULL, 0}
};

Просмотреть файл

@ -590,6 +590,7 @@ LIGHTGBM_C_EXPORT LGBM_SE LGBM_BoosterPredictForMat_R(
LIGHTGBM_C_EXPORT LGBM_SE LGBM_BoosterSaveModel_R(
LGBM_SE handle,
LGBM_SE num_iteration,
LGBM_SE feature_importance_type,
LGBM_SE filename,
LGBM_SE call_state
);
@ -604,6 +605,7 @@ LIGHTGBM_C_EXPORT LGBM_SE LGBM_BoosterSaveModel_R(
LIGHTGBM_C_EXPORT LGBM_SE LGBM_BoosterSaveModelToString_R(
LGBM_SE handle,
LGBM_SE num_iteration,
LGBM_SE feature_importance_type,
LGBM_SE buffer_len,
LGBM_SE actual_len,
LGBM_SE out_str,
@ -620,6 +622,7 @@ LIGHTGBM_C_EXPORT LGBM_SE LGBM_BoosterSaveModelToString_R(
LIGHTGBM_C_EXPORT LGBM_SE LGBM_BoosterDumpModel_R(
LGBM_SE handle,
LGBM_SE num_iteration,
LGBM_SE feature_importance_type,
LGBM_SE buffer_len,
LGBM_SE actual_len,
LGBM_SE out_str,

Просмотреть файл

@ -574,6 +574,14 @@ Learning Control Parameters
- **Note**: can be used only in CLI version
- ``saved_feature_importance_type`` :raw-html:`<a id="saved_feature_importance_type" title="Permalink to this parameter" href="#saved_feature_importance_type">&#x1F517;&#xFE0E;</a>`, default = ``0``, type = int
- the feature importance type in the saved model file
- ``0``: count-based feature importance (numbers of splits are counted); ``1``: gain-based feature importance (values of gain are counted)
- **Note**: can be used only in CLI version
- ``snapshot_freq`` :raw-html:`<a id="snapshot_freq" title="Permalink to this parameter" href="#snapshot_freq">&#x1F517;&#xFE0E;</a>`, default = ``-1``, type = int, aliases: ``save_period``
- frequency of saving model file snapshot

Просмотреть файл

@ -176,9 +176,10 @@ class LIGHTGBM_EXPORT Boosting {
* \brief Dump model to json format string
* \param start_iteration The model will be saved start from
* \param num_iteration Number of iterations that want to dump, -1 means dump all
* \param feature_importance_type Type of feature importance, 0: split, 1: gain
* \return Json format string of model
*/
virtual std::string DumpModel(int start_iteration, int num_iteration) const = 0;
virtual std::string DumpModel(int start_iteration, int num_iteration, int feature_importance_type) const = 0;
/*!
* \brief Translate model to if-else statement
@ -199,19 +200,20 @@ class LIGHTGBM_EXPORT Boosting {
* \brief Save model to file
* \param start_iteration The model will be saved start from
* \param num_iterations Number of model that want to save, -1 means save all
* \param is_finish Is training finished or not
* \param feature_importance_type Type of feature importance, 0: split, 1: gain
* \param filename Filename that want to save to
* \return true if succeeded
*/
virtual bool SaveModelToFile(int start_iteration, int num_iterations, const char* filename) const = 0;
virtual bool SaveModelToFile(int start_iteration, int num_iterations, int feature_importance_type, const char* filename) const = 0;
/*!
* \brief Save model to string
* \param start_iteration The model will be saved start from
* \param num_iterations Number of model that want to save, -1 means save all
* \param feature_importance_type Type of feature importance, 0: split, 1: gain
* \return Non-empty string if succeeded
*/
virtual std::string SaveModelToString(int start_iteration, int num_iterations) const = 0;
virtual std::string SaveModelToString(int start_iteration, int num_iterations, int feature_importance_type) const = 0;
/*!
* \brief Restore from a serialized string

Просмотреть файл

@ -36,6 +36,9 @@ typedef void* BoosterHandle; /*!< \brief Handle of booster. */
#define C_API_MATRIX_TYPE_CSR (0) /*!< \brief CSR sparse matrix type. */
#define C_API_MATRIX_TYPE_CSC (1) /*!< \brief CSC sparse matrix type. */
#define C_API_FEATURE_IMPORTANCE_SPLIT (0) /*!< \brief Split type of feature importance. */
#define C_API_FEATURE_IMPORTANCE_GAIN (1) /*!< \brief Gain type of feature importance. */
/*!
* \brief Get string message of the last error.
* \return Error information
@ -996,12 +999,14 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForMats(BoosterHandle handle,
* \param handle Handle of booster
* \param start_iteration Start index of the iteration that should be saved
* \param num_iteration Index of the iteration that should be saved, <= 0 means save all
* \param feature_importance_type Type of feature importance, can be ``C_API_FEATURE_IMPORTANCE_SPLIT`` or ``C_API_FEATURE_IMPORTANCE_GAIN``
* \param filename The name of the file
* \return 0 when succeed, -1 when failure happens
*/
LIGHTGBM_C_EXPORT int LGBM_BoosterSaveModel(BoosterHandle handle,
int start_iteration,
int num_iteration,
int feature_importance_type,
const char* filename);
/*!
@ -1009,6 +1014,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterSaveModel(BoosterHandle handle,
* \param handle Handle of booster
* \param start_iteration Start index of the iteration that should be saved
* \param num_iteration Index of the iteration that should be saved, <= 0 means save all
* \param feature_importance_type Type of feature importance, can be ``C_API_FEATURE_IMPORTANCE_SPLIT`` or ``C_API_FEATURE_IMPORTANCE_GAIN``
* \param buffer_len String buffer length, if ``buffer_len < out_len``, you should re-allocate buffer
* \param[out] out_len Actual output length
* \param[out] out_str String of model, should pre-allocate memory
@ -1017,6 +1023,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterSaveModel(BoosterHandle handle,
LIGHTGBM_C_EXPORT int LGBM_BoosterSaveModelToString(BoosterHandle handle,
int start_iteration,
int num_iteration,
int feature_importance_type,
int64_t buffer_len,
int64_t* out_len,
char* out_str);
@ -1026,6 +1033,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterSaveModelToString(BoosterHandle handle,
* \param handle Handle of booster
* \param start_iteration Start index of the iteration that should be dumped
* \param num_iteration Index of the iteration that should be dumped, <= 0 means dump all
* \param feature_importance_type Type of feature importance, can be ``C_API_FEATURE_IMPORTANCE_SPLIT`` or ``C_API_FEATURE_IMPORTANCE_GAIN``
* \param buffer_len String buffer length, if ``buffer_len < out_len``, you should re-allocate buffer
* \param[out] out_len Actual output length
* \param[out] out_str JSON format string of model, should pre-allocate memory
@ -1034,6 +1042,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterSaveModelToString(BoosterHandle handle,
LIGHTGBM_C_EXPORT int LGBM_BoosterDumpModel(BoosterHandle handle,
int start_iteration,
int num_iteration,
int feature_importance_type,
int64_t buffer_len,
int64_t* out_len,
char* out_str);
@ -1069,8 +1078,8 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterSetLeafValue(BoosterHandle handle,
* \param handle Handle of booster
* \param num_iteration Number of iterations for which feature importance is calculated, <= 0 means use all
* \param importance_type Method of importance calculation:
* - 0 for split, result contains numbers of times the feature is used in a model;
* - 1 for gain, result contains total gains of splits which use the feature
* - ``C_API_FEATURE_IMPORTANCE_SPLIT``: result contains numbers of times the feature is used in a model;
* - ``C_API_FEATURE_IMPORTANCE_GAIN``: result contains total gains of splits which use the feature
* \param[out] out_results Result array with feature importance
* \return 0 when succeed, -1 when failure happens
*/

Просмотреть файл

@ -532,6 +532,11 @@ struct Config {
// desc = **Note**: can be used only in CLI version
std::string output_model = "LightGBM_model.txt";
// desc = the feature importance type in the saved model file
// desc = ``0``: count-based feature importance (numbers of splits are counted); ``1``: gain-based feature importance (values of gain are counted)
// desc = **Note**: can be used only in CLI version
int saved_feature_importance_type = 0;
// [no-save]
// alias = save_period
// desc = frequency of saving model file snapshot

Просмотреть файл

@ -284,12 +284,20 @@ C_API_PREDICT_CONTRIB = 3
C_API_MATRIX_TYPE_CSR = 0
C_API_MATRIX_TYPE_CSC = 1
"""Macro definition of feature importance type"""
C_API_FEATURE_IMPORTANCE_SPLIT = 0
C_API_FEATURE_IMPORTANCE_GAIN = 1
"""Data type of data field"""
FIELD_TYPE_MAPPER = {"label": C_API_DTYPE_FLOAT32,
"weight": C_API_DTYPE_FLOAT32,
"init_score": C_API_DTYPE_FLOAT64,
"group": C_API_DTYPE_INT32}
"""String name to int feature importance type mapper"""
FEATURE_IMPORTANCE_TYPE_MAPPER = {"split": C_API_FEATURE_IMPORTANCE_SPLIT,
"gain": C_API_FEATURE_IMPORTANCE_GAIN}
def convert_from_sliced_object(data):
"""Fix the memory of multi-dimensional sliced object."""
@ -2600,7 +2608,7 @@ class Booster(object):
return [item for i in range_(1, self.__num_dataset)
for item in self.__inner_eval(self.name_valid_sets[i - 1], i, feval)]
def save_model(self, filename, num_iteration=None, start_iteration=0):
def save_model(self, filename, num_iteration=None, start_iteration=0, importance_type='split'):
"""Save Booster to file.
Parameters
@ -2613,6 +2621,10 @@ class Booster(object):
If <= 0, all iterations are saved.
start_iteration : int, optional (default=0)
Start index of the iteration that should be saved.
importance_type : string, optional (default="split")
What type of feature importance should be saved.
If "split", result contains numbers of times the feature is used in a model.
If "gain", result contains total gains of splits which use the feature.
Returns
-------
@ -2621,10 +2633,12 @@ class Booster(object):
"""
if num_iteration is None:
num_iteration = self.best_iteration
importance_type_int = FEATURE_IMPORTANCE_TYPE_MAPPER[importance_type]
_safe_call(_LIB.LGBM_BoosterSaveModel(
self.handle,
ctypes.c_int(start_iteration),
ctypes.c_int(num_iteration),
ctypes.c_int(importance_type_int),
c_str(filename)))
_dump_pandas_categorical(self.pandas_categorical, filename)
return self
@ -2685,7 +2699,7 @@ class Booster(object):
self.pandas_categorical = _load_pandas_categorical(model_str=model_str)
return self
def model_to_string(self, num_iteration=None, start_iteration=0):
def model_to_string(self, num_iteration=None, start_iteration=0, importance_type='split'):
"""Save Booster to string.
Parameters
@ -2696,6 +2710,10 @@ class Booster(object):
If <= 0, all iterations are saved.
start_iteration : int, optional (default=0)
Start index of the iteration that should be saved.
importance_type : string, optional (default="split")
What type of feature importance should be saved.
If "split", result contains numbers of times the feature is used in a model.
If "gain", result contains total gains of splits which use the feature.
Returns
-------
@ -2704,6 +2722,7 @@ class Booster(object):
"""
if num_iteration is None:
num_iteration = self.best_iteration
importance_type_int = FEATURE_IMPORTANCE_TYPE_MAPPER[importance_type]
buffer_len = 1 << 20
tmp_out_len = ctypes.c_int64(0)
string_buffer = ctypes.create_string_buffer(buffer_len)
@ -2712,6 +2731,7 @@ class Booster(object):
self.handle,
ctypes.c_int(start_iteration),
ctypes.c_int(num_iteration),
ctypes.c_int(importance_type_int),
ctypes.c_int64(buffer_len),
ctypes.byref(tmp_out_len),
ptr_string_buffer))
@ -2724,6 +2744,7 @@ class Booster(object):
self.handle,
ctypes.c_int(start_iteration),
ctypes.c_int(num_iteration),
ctypes.c_int(importance_type_int),
ctypes.c_int64(actual_len),
ctypes.byref(tmp_out_len),
ptr_string_buffer))
@ -2731,7 +2752,7 @@ class Booster(object):
ret += _dump_pandas_categorical(self.pandas_categorical)
return ret
def dump_model(self, num_iteration=None, start_iteration=0):
def dump_model(self, num_iteration=None, start_iteration=0, importance_type='split'):
"""Dump Booster to JSON format.
Parameters
@ -2742,6 +2763,10 @@ class Booster(object):
If <= 0, all iterations are dumped.
start_iteration : int, optional (default=0)
Start index of the iteration that should be dumped.
importance_type : string, optional (default="split")
What type of feature importance should be dumped.
If "split", result contains numbers of times the feature is used in a model.
If "gain", result contains total gains of splits which use the feature.
Returns
-------
@ -2750,6 +2775,7 @@ class Booster(object):
"""
if num_iteration is None:
num_iteration = self.best_iteration
importance_type_int = FEATURE_IMPORTANCE_TYPE_MAPPER[importance_type]
buffer_len = 1 << 20
tmp_out_len = ctypes.c_int64(0)
string_buffer = ctypes.create_string_buffer(buffer_len)
@ -2758,6 +2784,7 @@ class Booster(object):
self.handle,
ctypes.c_int(start_iteration),
ctypes.c_int(num_iteration),
ctypes.c_int(importance_type_int),
ctypes.c_int64(buffer_len),
ctypes.byref(tmp_out_len),
ptr_string_buffer))
@ -2770,6 +2797,7 @@ class Booster(object):
self.handle,
ctypes.c_int(start_iteration),
ctypes.c_int(num_iteration),
ctypes.c_int(importance_type_int),
ctypes.c_int64(actual_len),
ctypes.byref(tmp_out_len),
ptr_string_buffer))
@ -2969,12 +2997,7 @@ class Booster(object):
"""
if iteration is None:
iteration = self.best_iteration
if importance_type == "split":
importance_type_int = 0
elif importance_type == "gain":
importance_type_int = 1
else:
importance_type_int = -1
importance_type_int = FEATURE_IMPORTANCE_TYPE_MAPPER[importance_type]
result = np.zeros(self.num_feature(), dtype=np.float64)
_safe_call(_LIB.LGBM_BoosterFeatureImportance(
self.handle,

Просмотреть файл

@ -201,7 +201,8 @@ void Application::InitTrain() {
void Application::Train() {
Log::Info("Started training...");
boosting_->Train(config_.snapshot_freq, config_.output_model);
boosting_->SaveModelToFile(0, -1, config_.output_model.c_str());
boosting_->SaveModelToFile(0, -1, config_.saved_feature_importance_type,
config_.output_model.c_str());
// convert model to if-else statement code
if (config_.convert_model_language == std::string("cpp")) {
boosting_->SaveModelToIfElse(-1, config_.convert_model.c_str());
@ -233,7 +234,8 @@ void Application::Predict() {
boosting_->Init(&config_, train_data_.get(), objective_fun_.get(),
Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
boosting_->RefitTree(pred_leaf);
boosting_->SaveModelToFile(0, -1, config_.output_model.c_str());
boosting_->SaveModelToFile(0, -1, config_.saved_feature_importance_type,
config_.output_model.c_str());
Log::Info("Finished RefitTree");
} else {
// create predictor

Просмотреть файл

@ -258,7 +258,7 @@ void GBDT::Train(int snapshot_freq, const std::string& model_output_path) {
if (snapshot_freq > 0
&& (iter + 1) % snapshot_freq == 0) {
std::string snapshot_out = model_output_path + ".snapshot_iter_" + std::to_string(iter + 1);
SaveModelToFile(0, -1, snapshot_out.c_str());
SaveModelToFile(0, -1, config_->saved_feature_importance_type, snapshot_out.c_str());
}
}
}

Просмотреть файл

@ -249,9 +249,11 @@ class GBDT : public GBDTBase {
* \brief Dump model to json format string
* \param start_iteration The model will be saved start from
* \param num_iteration Number of iterations that want to dump, -1 means dump all
* \param feature_importance_type Type of feature importance, 0: split, 1: gain
* \return Json format string of model
*/
std::string DumpModel(int start_iteration, int num_iteration) const override;
std::string DumpModel(int start_iteration, int num_iteration,
int feature_importance_type) const override;
/*!
* \brief Translate model to if-else statement
@ -272,18 +274,22 @@ class GBDT : public GBDTBase {
* \brief Save model to file
* \param start_iteration The model will be saved start from
* \param num_iterations Number of model that want to save, -1 means save all
* \param feature_importance_type Type of feature importance, 0: split, 1: gain
* \param filename Filename that want to save to
* \return is_finish Is training finished or not
*/
bool SaveModelToFile(int start_iteration, int num_iterations, const char* filename) const override;
bool SaveModelToFile(int start_iteration, int num_iterations,
int feature_importance_type,
const char* filename) const override;
/*!
* \brief Save model to string
* \param start_iteration The model will be saved start from
* \param num_iterations Number of model that want to save, -1 means save all
* \param feature_importance_type Type of feature importance, 0: split, 1: gain
* \return Non-empty string if succeeded
*/
std::string SaveModelToString(int start_iteration, int num_iterations) const override;
std::string SaveModelToString(int start_iteration, int num_iterations, int feature_importance_type) const override;
/*!
* \brief Restore from a serialized buffer

Просмотреть файл

@ -18,7 +18,7 @@ namespace LightGBM {
const char* kModelVersion = "v3";
std::string GBDT::DumpModel(int start_iteration, int num_iteration) const {
std::string GBDT::DumpModel(int start_iteration, int num_iteration, int feature_importance_type) const {
std::stringstream str_buf;
str_buf << "{";
@ -95,7 +95,8 @@ std::string GBDT::DumpModel(int start_iteration, int num_iteration) const {
}
str_buf << "]," << '\n';
std::vector<double> feature_importances = FeatureImportance(num_iteration, 0);
std::vector<double> feature_importances = FeatureImportance(
num_iteration, feature_importance_type);
// store the importance first
std::vector<std::pair<size_t, std::string>> pairs;
for (size_t i = 0; i < feature_importances.size(); ++i) {
@ -302,7 +303,7 @@ bool GBDT::SaveModelToIfElse(int num_iteration, const char* filename) const {
return static_cast<bool>(output_file);
}
std::string GBDT::SaveModelToString(int start_iteration, int num_iteration) const {
std::string GBDT::SaveModelToString(int start_iteration, int num_iteration, int feature_importance_type) const {
std::stringstream ss;
// output model type
@ -363,8 +364,8 @@ std::string GBDT::SaveModelToString(int start_iteration, int num_iteration) cons
tree_strs[i].clear();
}
ss << "end of trees" << "\n";
std::vector<double> feature_importances = FeatureImportance(num_iteration, 0);
std::vector<double> feature_importances = FeatureImportance(
num_iteration, feature_importance_type);
// store the importance first
std::vector<std::pair<size_t, std::string>> pairs;
for (size_t i = 0; i < feature_importances.size(); ++i) {
@ -395,11 +396,11 @@ std::string GBDT::SaveModelToString(int start_iteration, int num_iteration) cons
return ss.str();
}
bool GBDT::SaveModelToFile(int start_iteration, int num_iteration, const char* filename) const {
bool GBDT::SaveModelToFile(int start_iteration, int num_iteration, int feature_importance_type, const char* filename) const {
/*! \brief File to write models */
std::ofstream output_file;
output_file.open(filename, std::ios::out | std::ios::binary);
std::string str_to_write = SaveModelToString(start_iteration, num_iteration);
std::string str_to_write = SaveModelToString(start_iteration, num_iteration, feature_importance_type);
output_file.write(str_to_write.c_str(), str_to_write.size());
output_file.close();

Просмотреть файл

@ -689,8 +689,8 @@ class Booster {
boosting_->GetPredictAt(data_idx, out_result, out_len);
}
void SaveModelToFile(int start_iteration, int num_iteration, const char* filename) {
boosting_->SaveModelToFile(start_iteration, num_iteration, filename);
void SaveModelToFile(int start_iteration, int num_iteration, int feature_importance_type, const char* filename) {
boosting_->SaveModelToFile(start_iteration, num_iteration, feature_importance_type, filename);
}
void LoadModelFromString(const char* model_str) {
@ -698,12 +698,16 @@ class Booster {
boosting_->LoadModelFromString(model_str, len);
}
std::string SaveModelToString(int start_iteration, int num_iteration) {
return boosting_->SaveModelToString(start_iteration, num_iteration);
std::string SaveModelToString(int start_iteration, int num_iteration,
int feature_importance_type) {
return boosting_->SaveModelToString(start_iteration,
num_iteration, feature_importance_type);
}
std::string DumpModel(int start_iteration, int num_iteration) {
return boosting_->DumpModel(start_iteration, num_iteration);
std::string DumpModel(int start_iteration, int num_iteration,
int feature_importance_type) {
return boosting_->DumpModel(start_iteration, num_iteration,
feature_importance_type);
}
std::vector<double> FeatureImportance(int num_iteration, int importance_type) {
@ -2010,22 +2014,26 @@ int LGBM_BoosterPredictForMats(BoosterHandle handle,
int LGBM_BoosterSaveModel(BoosterHandle handle,
int start_iteration,
int num_iteration,
int feature_importance_type,
const char* filename) {
API_BEGIN();
Booster* ref_booster = reinterpret_cast<Booster*>(handle);
ref_booster->SaveModelToFile(start_iteration, num_iteration, filename);
ref_booster->SaveModelToFile(start_iteration, num_iteration,
feature_importance_type, filename);
API_END();
}
int LGBM_BoosterSaveModelToString(BoosterHandle handle,
int start_iteration,
int num_iteration,
int feature_importance_type,
int64_t buffer_len,
int64_t* out_len,
char* out_str) {
API_BEGIN();
Booster* ref_booster = reinterpret_cast<Booster*>(handle);
std::string model = ref_booster->SaveModelToString(start_iteration, num_iteration);
std::string model = ref_booster->SaveModelToString(
start_iteration, num_iteration, feature_importance_type);
*out_len = static_cast<int64_t>(model.size()) + 1;
if (*out_len <= buffer_len) {
std::memcpy(out_str, model.c_str(), *out_len);
@ -2036,12 +2044,14 @@ int LGBM_BoosterSaveModelToString(BoosterHandle handle,
int LGBM_BoosterDumpModel(BoosterHandle handle,
int start_iteration,
int num_iteration,
int feature_importance_type,
int64_t buffer_len,
int64_t* out_len,
char* out_str) {
API_BEGIN();
Booster* ref_booster = reinterpret_cast<Booster*>(handle);
std::string model = ref_booster->DumpModel(start_iteration, num_iteration);
std::string model = ref_booster->DumpModel(start_iteration, num_iteration,
feature_importance_type);
*out_len = static_cast<int64_t>(model.size()) + 1;
if (*out_len <= buffer_len) {
std::memcpy(out_str, model.c_str(), *out_len);

Просмотреть файл

@ -234,6 +234,7 @@ const std::unordered_set<std::string>& Config::parameter_set() {
"verbosity",
"input_model",
"output_model",
"saved_feature_importance_type",
"snapshot_freq",
"max_bin",
"max_bin_by_feature",
@ -463,6 +464,8 @@ void Config::GetMembersFromString(const std::unordered_map<std::string, std::str
GetString(params, "output_model", &output_model);
GetInt(params, "saved_feature_importance_type", &saved_feature_importance_type);
GetInt(params, "snapshot_freq", &snapshot_freq);
GetInt(params, "max_bin", &max_bin);
@ -664,6 +667,7 @@ std::string Config::SaveMembersToString() const {
str_buf << "[path_smooth: " << path_smooth << "]\n";
str_buf << "[interaction_constraints: " << interaction_constraints << "]\n";
str_buf << "[verbosity: " << verbosity << "]\n";
str_buf << "[saved_feature_importance_type: " << saved_feature_importance_type << "]\n";
str_buf << "[max_bin: " << max_bin << "]\n";
str_buf << "[max_bin_by_feature: " << Common::Join(max_bin_by_feature, ",") << "]\n";
str_buf << "[min_data_in_bin: " << min_data_in_bin << "]\n";

Просмотреть файл

@ -37,16 +37,17 @@
char * LGBM_BoosterSaveModelToStringSWIG(BoosterHandle handle,
int start_iteration,
int num_iteration,
int feature_importance_type,
int64_t buffer_len,
int64_t* out_len) {
char* dst = new char[buffer_len];
int result = LGBM_BoosterSaveModelToString(handle, start_iteration, num_iteration, buffer_len, out_len, dst);
int result = LGBM_BoosterSaveModelToString(handle, start_iteration, num_iteration, feature_importance_type, buffer_len, out_len, dst);
// Reallocate to use larger length
if (*out_len > buffer_len) {
delete [] dst;
int64_t realloc_len = *out_len;
dst = new char[realloc_len];
result = LGBM_BoosterSaveModelToString(handle, start_iteration, num_iteration, realloc_len, out_len, dst);
result = LGBM_BoosterSaveModelToString(handle, start_iteration, num_iteration, feature_importance_type, realloc_len, out_len, dst);
}
if (result != 0) {
return nullptr;
@ -57,16 +58,17 @@
char * LGBM_BoosterDumpModelSWIG(BoosterHandle handle,
int start_iteration,
int num_iteration,
int feature_importance_type,
int64_t buffer_len,
int64_t* out_len) {
char* dst = new char[buffer_len];
int result = LGBM_BoosterDumpModel(handle, start_iteration, num_iteration, buffer_len, out_len, dst);
int result = LGBM_BoosterDumpModel(handle, start_iteration, num_iteration, feature_importance_type, buffer_len, out_len, dst);
// Reallocate to use larger length
if (*out_len > buffer_len) {
delete [] dst;
int64_t realloc_len = *out_len;
dst = new char[realloc_len];
result = LGBM_BoosterDumpModel(handle, start_iteration, num_iteration, realloc_len, out_len, dst);
result = LGBM_BoosterDumpModel(handle, start_iteration, num_iteration, feature_importance_type, realloc_len, out_len, dst);
}
if (result != 0) {
return nullptr;

Просмотреть файл

@ -236,7 +236,7 @@ def test_booster():
result.ctypes.data_as(ctypes.POINTER(ctypes.c_double)))
if i % 10 == 0:
print('%d iteration test AUC %f' % (i, result[0]))
LIB.LGBM_BoosterSaveModel(booster, 0, -1, c_str('model.txt'))
LIB.LGBM_BoosterSaveModel(booster, 0, -1, 0, c_str('model.txt'))
LIB.LGBM_BoosterFree(booster)
free_dataset(train)
free_dataset(test)