зеркало из https://github.com/microsoft/LightGBM.git
[optional] support protobuf (#908)
This commit is contained in:
Родитель
fa45a97b82
Коммит
53b99854aa
|
@ -23,6 +23,7 @@ env:
|
|||
- TASK=if-else
|
||||
- TASK=sdist PYTHON_VERSION=3.4
|
||||
- TASK=bdist PYTHON_VERSION=3.5
|
||||
- TASK=proto
|
||||
- TASK=gpu METHOD=source
|
||||
- TASK=gpu METHOD=pip
|
||||
|
||||
|
@ -38,6 +39,8 @@ matrix:
|
|||
env: TASK=pylint
|
||||
- os: osx
|
||||
env: TASK=check-docs
|
||||
- os: osx
|
||||
env: TASK=proto
|
||||
|
||||
before_install:
|
||||
- test -n $CC && unset CC
|
||||
|
|
|
@ -50,12 +50,24 @@ if [[ ${TASK} == "if-else" ]]; then
|
|||
conda create -q -n test-env python=$PYTHON_VERSION numpy
|
||||
source activate test-env
|
||||
mkdir build && cd build && cmake .. && make lightgbm || exit -1
|
||||
cd $TRAVIS_BUILD_DIR/tests/cpp_test && ../../lightgbm config=train.conf && ../../lightgbm config=predict.conf output_result=origin.pred || exit -1
|
||||
cd $TRAVIS_BUILD_DIR/tests/cpp_test && ../../lightgbm config=train.conf convert_model_language=cpp convert_model=../../src/boosting/gbdt_prediction.cpp && ../../lightgbm config=predict.conf output_result=origin.pred || exit -1
|
||||
cd $TRAVIS_BUILD_DIR/build && make lightgbm || exit -1
|
||||
cd $TRAVIS_BUILD_DIR/tests/cpp_test && ../../lightgbm config=predict.conf output_result=ifelse.pred && python test.py || exit -1
|
||||
exit 0
|
||||
fi
|
||||
|
||||
if [[ ${TASK} == "proto" ]]; then
|
||||
conda create -q -n test-env python=$PYTHON_VERSION numpy
|
||||
source activate test-env
|
||||
mkdir build && cd build && cmake .. && make lightgbm || exit -1
|
||||
cd $TRAVIS_BUILD_DIR/tests/cpp_test && ../../lightgbm config=train.conf && ../../lightgbm config=predict.conf output_result=origin.pred || exit -1
|
||||
cd $TRAVIS_BUILD_DIR && git clone https://github.com/google/protobuf && cd protobuf && ./autogen.sh && ./configure && make && sudo make install && sudo ldconfig
|
||||
cd $TRAVIS_BUILD_DIR/build && rm -rf * && cmake -DUSE_PROTO=ON .. && make lightgbm || exit -1
|
||||
cd $TRAVIS_BUILD_DIR/tests/cpp_test && ../../lightgbm config=train.conf model_format=proto && ../../lightgbm config=predict.conf output_result=proto.pred model_format=proto || exit -1
|
||||
cd $TRAVIS_BUILD_DIR/tests/cpp_test && python test.py || exit -1
|
||||
exit 0
|
||||
fi
|
||||
|
||||
conda create -q -n test-env python=$PYTHON_VERSION numpy nose scipy scikit-learn pandas matplotlib pytest
|
||||
source activate test-env
|
||||
|
||||
|
|
|
@ -124,8 +124,24 @@ file(GLOB SOURCES
|
|||
src/treelearner/*.cpp
|
||||
)
|
||||
|
||||
add_executable(lightgbm src/main.cpp ${SOURCES})
|
||||
add_library(_lightgbm SHARED src/c_api.cpp src/lightgbm_R.cpp ${SOURCES})
|
||||
if (USE_PROTO)
|
||||
find_package(Protobuf REQUIRED)
|
||||
PROTOBUF_GENERATE_CPP(PROTO_SRCS PROTO_HDRS proto/model.proto)
|
||||
include_directories(${PROTOBUF_INCLUDE_DIRS})
|
||||
include_directories(${CMAKE_CURRENT_BINARY_DIR})
|
||||
SET(PROTO_FILES src/proto/gbdt_model_proto.cpp ${PROTO_HDRS} ${PROTO_SRCS})
|
||||
else()
|
||||
include_directories(src/proto/not_implemented)
|
||||
SET(PROTO_FILES src/proto/not_implemented/gbdt_model_proto.cpp)
|
||||
endif(USE_PROTO)
|
||||
|
||||
add_executable(lightgbm src/main.cpp ${SOURCES} ${PROTO_FILES})
|
||||
add_library(_lightgbm SHARED src/c_api.cpp src/lightgbm_R.cpp ${SOURCES} ${PROTO_FILES})
|
||||
|
||||
if (USE_PROTO)
|
||||
TARGET_LINK_LIBRARIES(lightgbm ${PROTOBUF_LIBRARIES})
|
||||
TARGET_LINK_LIBRARIES(_lightgbm ${PROTOBUF_LIBRARIES})
|
||||
endif(USE_PROTO)
|
||||
|
||||
if(MSVC)
|
||||
set_target_properties(_lightgbm PROPERTIES OUTPUT_NAME "lib_lightgbm")
|
||||
|
|
|
@ -309,6 +309,20 @@ IO Parameters
|
|||
|
||||
- file name of prediction result in ``prediction`` task
|
||||
|
||||
- ``model_format``, default=\ ``text``, type=string
|
||||
|
||||
- format to save and load model.
|
||||
|
||||
- ``text``, use text string.
|
||||
|
||||
- ``proto``, use protocol buffer binary format.
|
||||
|
||||
- save multiple formats by joining them with comma, like ``text,proto``, in this case, ``model_format`` will be add as suffix after ``output_model``.
|
||||
|
||||
- not support loading with multiple formats.
|
||||
|
||||
- Note: you need to cmake with -DUSE_PROTO=ON to use this parameter.
|
||||
|
||||
- ``is_pre_partition``, default=\ ``false``, type=bool
|
||||
|
||||
- used for parallel learning (not include feature parallel)
|
||||
|
|
|
@ -3,6 +3,7 @@
|
|||
|
||||
#include <LightGBM/meta.h>
|
||||
#include <LightGBM/config.h>
|
||||
#include "model.pb.h"
|
||||
|
||||
#include <vector>
|
||||
#include <string>
|
||||
|
@ -166,7 +167,7 @@ public:
|
|||
|
||||
/*!
|
||||
* \brief Save model to file
|
||||
* \param num_used_model Number of model that want to save, -1 means save all
|
||||
* \param num_iterations Number of model that want to save, -1 means save all
|
||||
* \param is_finish Is training finished or not
|
||||
* \param filename Filename that want to save to
|
||||
* \return true if succeeded
|
||||
|
@ -175,7 +176,7 @@ public:
|
|||
|
||||
/*!
|
||||
* \brief Save model to string
|
||||
* \param num_used_model Number of model that want to save, -1 means save all
|
||||
* \param num_iterations Number of model that want to save, -1 means save all
|
||||
* \return Non-empty string if succeeded
|
||||
*/
|
||||
virtual std::string SaveModelToString(int num_iterations) const = 0;
|
||||
|
@ -187,6 +188,20 @@ public:
|
|||
*/
|
||||
virtual bool LoadModelFromString(const std::string& model_str) = 0;
|
||||
|
||||
/*!
|
||||
* \brief Save model with protobuf
|
||||
* \param num_iterations Number of model that want to save, -1 means save all
|
||||
* \param filename Filename that want to save to
|
||||
*/
|
||||
virtual void SaveModelToProto(int num_iteration, const char* filename) const = 0;
|
||||
|
||||
/*!
|
||||
* \brief Restore from a serialized protobuf file
|
||||
* \param filename Filename that want to restore from
|
||||
* \return true if succeeded
|
||||
*/
|
||||
virtual bool LoadModelFromProto(const char* filename) = 0;
|
||||
|
||||
/*!
|
||||
* \brief Calculate feature importances
|
||||
* \param num_iteration Number of model that want to use for feature importance, -1 means use all
|
||||
|
@ -251,23 +266,17 @@ public:
|
|||
/*! \brief Disable copy */
|
||||
Boosting(const Boosting&) = delete;
|
||||
|
||||
static bool LoadFileToBoosting(Boosting* boosting, const char* filename);
|
||||
static bool LoadFileToBoosting(Boosting* boosting, const std::string& format, const char* filename);
|
||||
|
||||
/*!
|
||||
* \brief Create boosting object
|
||||
* \param type Type of boosting
|
||||
* \param format Format of model
|
||||
* \param config config for boosting
|
||||
* \param filename name of model file, if existing will continue to train from this model
|
||||
* \return The boosting object
|
||||
*/
|
||||
static Boosting* CreateBoosting(const std::string& type, const char* filename);
|
||||
|
||||
/*!
|
||||
* \brief Create boosting object from model file
|
||||
* \param filename name of model file
|
||||
* \return The boosting object
|
||||
*/
|
||||
static Boosting* CreateBoosting(const char* filename);
|
||||
static Boosting* CreateBoosting(const std::string& type, const std::string& format, const char* filename);
|
||||
|
||||
};
|
||||
|
||||
|
|
|
@ -105,6 +105,7 @@ public:
|
|||
std::string output_result = "LightGBM_predict_result.txt";
|
||||
std::string convert_model = "gbdt_prediction.cpp";
|
||||
std::string input_model = "";
|
||||
std::string model_format = "text";
|
||||
int verbosity = 1;
|
||||
int num_iteration_predict = -1;
|
||||
bool is_pre_partition = false;
|
||||
|
@ -445,7 +446,7 @@ struct ParameterAlias {
|
|||
const std::unordered_set<std::string> parameter_set({
|
||||
"config", "config_file", "task", "device",
|
||||
"num_threads", "seed", "boosting_type", "objective", "data",
|
||||
"output_model", "input_model", "output_result", "valid_data",
|
||||
"output_model", "input_model", "output_result", "model_format", "valid_data",
|
||||
"is_enable_sparse", "is_pre_partition", "is_training_metric",
|
||||
"ndcg_eval_at", "min_data_in_leaf", "min_sum_hessian_in_leaf",
|
||||
"num_leaves", "feature_fraction", "num_iterations",
|
||||
|
|
|
@ -3,6 +3,7 @@
|
|||
|
||||
#include <LightGBM/meta.h>
|
||||
#include <LightGBM/dataset.h>
|
||||
#include "model.pb.h"
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
@ -31,6 +32,12 @@ public:
|
|||
*/
|
||||
explicit Tree(const std::string& str);
|
||||
|
||||
/*!
|
||||
* \brief Construtor, from a protobuf object
|
||||
* \param model_tree Model protobuf object
|
||||
*/
|
||||
explicit Tree(const LightGBM::Model_Tree& model_tree);
|
||||
|
||||
~Tree();
|
||||
|
||||
/*!
|
||||
|
@ -165,6 +172,9 @@ public:
|
|||
/*! \brief Serialize this object to if-else statement*/
|
||||
std::string ToIfElse(int index, bool is_predict_leaf_index) const;
|
||||
|
||||
/*! \brief Serialize this object to protobuf object*/
|
||||
void ToProto(Model_Tree& model_tree) const;
|
||||
|
||||
inline static bool IsZero(double fval) {
|
||||
if (fval > -kZeroAsMissingValueRange && fval <= kZeroAsMissingValueRange) {
|
||||
return true;
|
||||
|
|
|
@ -0,0 +1,33 @@
|
|||
syntax = "proto3";
|
||||
|
||||
package LightGBM;
|
||||
|
||||
message Model {
|
||||
string name = 1;
|
||||
uint32 num_class = 2;
|
||||
uint32 num_tree_per_iteration = 3;
|
||||
uint32 label_index = 4;
|
||||
uint32 max_feature_idx = 5;
|
||||
string objective = 6;
|
||||
bool average_output = 7;
|
||||
repeated string feature_names = 8;
|
||||
repeated string feature_infos = 9;
|
||||
message Tree {
|
||||
uint32 num_leaves = 1;
|
||||
uint32 num_cat = 2;
|
||||
repeated uint32 split_feature = 3;
|
||||
repeated double split_gain = 4;
|
||||
repeated double threshold = 5;
|
||||
repeated uint32 decision_type = 6;
|
||||
repeated sint32 left_child = 7;
|
||||
repeated sint32 right_child = 8;
|
||||
repeated double leaf_value = 9;
|
||||
repeated uint32 leaf_count = 10;
|
||||
repeated double internal_value = 11;
|
||||
repeated double internal_count = 12;
|
||||
repeated sint32 cat_boundaries = 13;
|
||||
repeated uint32 cat_threshold = 14;
|
||||
double shrinkage = 15;
|
||||
}
|
||||
repeated Tree trees = 10;
|
||||
}
|
|
@ -180,6 +180,7 @@ void Application::InitTrain() {
|
|||
// create boosting
|
||||
boosting_.reset(
|
||||
Boosting::CreateBoosting(config_.boosting_type,
|
||||
config_.io_config.model_format.c_str(),
|
||||
config_.io_config.input_model.c_str()));
|
||||
// create objective function
|
||||
objective_fun_.reset(
|
||||
|
@ -203,6 +204,22 @@ void Application::InitTrain() {
|
|||
void Application::Train() {
|
||||
Log::Info("Started training...");
|
||||
boosting_->Train(config_.io_config.snapshot_freq, config_.io_config.output_model);
|
||||
std::vector<std::string> model_formats = Common::Split(config_.io_config.model_format.c_str(), ',');
|
||||
bool save_with_multiple_format = (model_formats.size() > 1);
|
||||
for (auto model_format: model_formats) {
|
||||
std::string save_file_name = config_.io_config.output_model;
|
||||
if (save_with_multiple_format) {
|
||||
// use suffix to distinguish different model format
|
||||
save_file_name += "." + model_format;
|
||||
}
|
||||
if (model_format == std::string("text")) {
|
||||
boosting_->SaveModelToFile(-1, save_file_name.c_str());
|
||||
} else if (model_format == std::string("proto")) {
|
||||
boosting_->SaveModelToProto(-1, save_file_name.c_str());
|
||||
} else {
|
||||
Log::Fatal("Unknown model format during saving: %s", model_format.c_str());
|
||||
}
|
||||
}
|
||||
// convert model to if-else statement code
|
||||
if (config_.convert_model_language == std::string("cpp")) {
|
||||
boosting_->SaveModelToIfElse(-1, config_.io_config.convert_model.c_str());
|
||||
|
@ -223,13 +240,15 @@ void Application::Predict() {
|
|||
|
||||
void Application::InitPredict() {
|
||||
boosting_.reset(
|
||||
Boosting::CreateBoosting(config_.io_config.input_model.c_str()));
|
||||
Boosting::CreateBoosting("gbdt", config_.io_config.model_format.c_str(),
|
||||
config_.io_config.input_model.c_str()));
|
||||
Log::Info("Finished initializing prediction");
|
||||
}
|
||||
|
||||
void Application::ConvertModel() {
|
||||
boosting_.reset(
|
||||
Boosting::CreateBoosting(config_.boosting_type,
|
||||
config_.io_config.model_format.c_str(),
|
||||
config_.io_config.input_model.c_str()));
|
||||
boosting_->SaveModelToIfElse(-1, config_.io_config.convert_model.c_str());
|
||||
}
|
||||
|
|
|
@ -12,21 +12,30 @@ std::string GetBoostingTypeFromModelFile(const char* filename) {
|
|||
return type;
|
||||
}
|
||||
|
||||
bool Boosting::LoadFileToBoosting(Boosting* boosting, const char* filename) {
|
||||
bool Boosting::LoadFileToBoosting(Boosting* boosting, const std::string& format, const char* filename) {
|
||||
if (boosting != nullptr) {
|
||||
TextReader<size_t> model_reader(filename, true);
|
||||
model_reader.ReadAllLines();
|
||||
std::stringstream str_buf;
|
||||
for (auto& line : model_reader.Lines()) {
|
||||
str_buf << line << '\n';
|
||||
if (format == std::string("text")) {
|
||||
TextReader<size_t> model_reader(filename, true);
|
||||
model_reader.ReadAllLines();
|
||||
std::stringstream str_buf;
|
||||
for (auto& line : model_reader.Lines()) {
|
||||
str_buf << line << '\n';
|
||||
}
|
||||
if (!boosting->LoadModelFromString(str_buf.str())) {
|
||||
return false;
|
||||
}
|
||||
} else if (format == std::string("proto")) {
|
||||
if (!boosting->LoadModelFromProto(filename)) {
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
Log::Fatal("Unknown model format during loading: %s", format.c_str());
|
||||
}
|
||||
if (!boosting->LoadModelFromString(str_buf.str()))
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
Boosting* Boosting::CreateBoosting(const std::string& type, const char* filename) {
|
||||
Boosting* Boosting::CreateBoosting(const std::string& type, const std::string& format, const char* filename) {
|
||||
if (filename == nullptr || filename[0] == '\0') {
|
||||
if (type == std::string("gbdt")) {
|
||||
return new GBDT();
|
||||
|
@ -41,8 +50,7 @@ Boosting* Boosting::CreateBoosting(const std::string& type, const char* filename
|
|||
}
|
||||
} else {
|
||||
std::unique_ptr<Boosting> ret;
|
||||
auto type_in_file = GetBoostingTypeFromModelFile(filename);
|
||||
if (type_in_file == std::string("tree")) {
|
||||
if (format == std::string("proto") || GetBoostingTypeFromModelFile(filename) == std::string("tree")) {
|
||||
if (type == std::string("gbdt")) {
|
||||
ret.reset(new GBDT());
|
||||
} else if (type == std::string("dart")) {
|
||||
|
@ -54,24 +62,12 @@ Boosting* Boosting::CreateBoosting(const std::string& type, const char* filename
|
|||
} else {
|
||||
Log::Fatal("unknown boosting type %s", type.c_str());
|
||||
}
|
||||
LoadFileToBoosting(ret.get(), filename);
|
||||
LoadFileToBoosting(ret.get(), format, filename);
|
||||
} else {
|
||||
Log::Fatal("unknown submodel type in model file %s", filename);
|
||||
Log::Fatal("unknown model format or submodel type in model file %s", filename);
|
||||
}
|
||||
return ret.release();
|
||||
}
|
||||
}
|
||||
|
||||
Boosting* Boosting::CreateBoosting(const char* filename) {
|
||||
auto type = GetBoostingTypeFromModelFile(filename);
|
||||
std::unique_ptr<Boosting> ret;
|
||||
if (type == std::string("tree")) {
|
||||
ret.reset(new GBDT());
|
||||
} else {
|
||||
Log::Fatal("unknown submodel type in model file %s", filename);
|
||||
}
|
||||
LoadFileToBoosting(ret.get(), filename);
|
||||
return ret.release();
|
||||
}
|
||||
|
||||
} // namespace LightGBM
|
||||
|
|
|
@ -352,7 +352,6 @@ void GBDT::Train(int snapshot_freq, const std::string& model_output_path) {
|
|||
SaveModelToFile(-1, snapshot_out.c_str());
|
||||
}
|
||||
}
|
||||
SaveModelToFile(-1, model_output_path.c_str());
|
||||
}
|
||||
|
||||
double GBDT::BoostFromAverage() {
|
||||
|
|
|
@ -236,6 +236,20 @@ public:
|
|||
*/
|
||||
bool LoadModelFromString(const std::string& model_str) override;
|
||||
|
||||
/*!
|
||||
* \brief Save model with protobuf
|
||||
* \param num_iterations Number of model that want to save, -1 means save all
|
||||
* \param filename Filename that want to save to
|
||||
*/
|
||||
void SaveModelToProto(int num_iteration, const char* filename) const override;
|
||||
|
||||
/*!
|
||||
* \brief Restore from a serialized protobuf file
|
||||
* \param filename Filename that want to restore from
|
||||
* \return true if succeeded
|
||||
*/
|
||||
bool LoadModelFromProto(const char* filename) override;
|
||||
|
||||
/*!
|
||||
* \brief Calculate feature importances
|
||||
* \param num_iteration Number of model that want to use for feature importance, -1 means use all
|
||||
|
|
|
@ -29,11 +29,7 @@ namespace LightGBM {
|
|||
class Booster {
|
||||
public:
|
||||
explicit Booster(const char* filename) {
|
||||
boosting_.reset(Boosting::CreateBoosting(filename));
|
||||
}
|
||||
|
||||
Booster() {
|
||||
boosting_.reset(Boosting::CreateBoosting("gbdt", nullptr));
|
||||
boosting_.reset(Boosting::CreateBoosting("gbdt", "text", filename));
|
||||
}
|
||||
|
||||
Booster(const Dataset* train_data,
|
||||
|
@ -50,7 +46,7 @@ public:
|
|||
please use continued train with input score");
|
||||
}
|
||||
|
||||
boosting_.reset(Boosting::CreateBoosting(config_.boosting_type, nullptr));
|
||||
boosting_.reset(Boosting::CreateBoosting(config_.boosting_type, "text", nullptr));
|
||||
|
||||
train_data_ = train_data;
|
||||
CreateObjectiveAndMetrics();
|
||||
|
@ -838,7 +834,7 @@ int LGBM_BoosterLoadModelFromString(
|
|||
int* out_num_iterations,
|
||||
BoosterHandle* out) {
|
||||
API_BEGIN();
|
||||
auto ret = std::unique_ptr<Booster>(new Booster());
|
||||
auto ret = std::unique_ptr<Booster>(new Booster(nullptr));
|
||||
ret->LoadModelFromString(model_str);
|
||||
*out_num_iterations = ret->GetBoosting()->GetCurrentIteration();
|
||||
*out = ret.release();
|
||||
|
|
|
@ -269,6 +269,7 @@ void IOConfig::Set(const std::unordered_map<std::string, std::string>& params) {
|
|||
GetString(params, "input_model", &input_model);
|
||||
GetString(params, "convert_model", &convert_model);
|
||||
GetString(params, "output_result", &output_result);
|
||||
GetString(params, "model_format", &model_format);
|
||||
std::string tmp_str = "";
|
||||
if (GetString(params, "valid_data", &tmp_str)) {
|
||||
valid_data_filenames = Common::Split(tmp_str.c_str(), ',');
|
||||
|
|
|
@ -0,0 +1,191 @@
|
|||
#include "../boosting/gbdt.h"
|
||||
|
||||
#include <LightGBM/tree.h>
|
||||
#include <LightGBM/utils/common.h>
|
||||
#include <LightGBM/objective_function.h>
|
||||
#include <iostream>
|
||||
#include <fstream>
|
||||
|
||||
namespace LightGBM {
|
||||
|
||||
void GBDT::SaveModelToProto(int num_iteration, const char* filename) const {
|
||||
LightGBM::Model model;
|
||||
|
||||
model.set_name(SubModelName());
|
||||
model.set_num_class(num_class_);
|
||||
model.set_num_tree_per_iteration(num_tree_per_iteration_);
|
||||
model.set_label_index(label_idx_);
|
||||
model.set_max_feature_idx(max_feature_idx_);
|
||||
if (objective_function_ != nullptr) {
|
||||
model.set_objective(objective_function_->ToString());
|
||||
}
|
||||
model.set_average_output(average_output_);
|
||||
for(auto feature_name: feature_names_) {
|
||||
model.add_feature_names(feature_name);
|
||||
}
|
||||
for(auto feature_info: feature_infos_) {
|
||||
model.add_feature_infos(feature_info);
|
||||
}
|
||||
|
||||
int num_used_model = static_cast<int>(models_.size());
|
||||
if (num_iteration > 0) {
|
||||
num_used_model = std::min(num_iteration * num_tree_per_iteration_, num_used_model);
|
||||
}
|
||||
for (int i = 0; i < num_used_model; ++i) {
|
||||
models_[i]->ToProto(*model.add_trees());
|
||||
}
|
||||
|
||||
std::filebuf fb;
|
||||
fb.open(filename, std::ios::out | std::ios::binary);
|
||||
std::ostream os(&fb);
|
||||
if (!model.SerializeToOstream(&os)) {
|
||||
Log::Fatal("Cannot serialize model to binary file.");
|
||||
}
|
||||
fb.close();
|
||||
}
|
||||
|
||||
bool GBDT::LoadModelFromProto(const char* filename) {
|
||||
models_.clear();
|
||||
LightGBM::Model model;
|
||||
std::filebuf fb;
|
||||
if (fb.open(filename, std::ios::in | std::ios::binary))
|
||||
{
|
||||
std::istream is(&fb);
|
||||
if (!model.ParseFromIstream(&is)) {
|
||||
Log::Fatal("Cannot parse model from binary file.");
|
||||
}
|
||||
fb.close();
|
||||
} else {
|
||||
Log::Fatal("Cannot open file: %s.", filename);
|
||||
}
|
||||
|
||||
num_class_ = model.num_class();
|
||||
num_tree_per_iteration_ = model.num_tree_per_iteration();
|
||||
label_idx_ = model.label_index();
|
||||
max_feature_idx_ = model.max_feature_idx();
|
||||
average_output_ = model.average_output();
|
||||
feature_names_.reserve(model.feature_names_size());
|
||||
for (auto feature_name: model.feature_names()) {
|
||||
feature_names_.push_back(feature_name);
|
||||
}
|
||||
feature_infos_.reserve(model.feature_infos_size());
|
||||
for (auto feature_info: model.feature_infos()) {
|
||||
feature_infos_.push_back(feature_info);
|
||||
}
|
||||
loaded_objective_.reset(ObjectiveFunction::CreateObjectiveFunction(model.objective()));
|
||||
objective_function_ = loaded_objective_.get();
|
||||
|
||||
for (auto tree: model.trees()) {
|
||||
models_.emplace_back(new Tree(tree));
|
||||
}
|
||||
Log::Info("Finished loading %d models", models_.size());
|
||||
num_iteration_for_pred_ = static_cast<int>(models_.size()) / num_tree_per_iteration_;
|
||||
num_init_iteration_ = num_iteration_for_pred_;
|
||||
iter_ = 0;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
void Tree::ToProto(LightGBM::Model_Tree& model_tree) const {
|
||||
|
||||
model_tree.set_num_leaves(num_leaves_);
|
||||
model_tree.set_num_cat(num_cat_);
|
||||
for (int i = 0; i < num_leaves_ - 1; ++i) {
|
||||
model_tree.add_split_feature(split_feature_[i]);
|
||||
model_tree.add_split_gain(split_gain_[i]);
|
||||
model_tree.add_threshold(threshold_[i]);
|
||||
model_tree.add_decision_type(decision_type_[i]);
|
||||
model_tree.add_left_child(left_child_[i]);
|
||||
model_tree.add_right_child(right_child_[i]);
|
||||
model_tree.add_internal_value(internal_value_[i]);
|
||||
model_tree.add_internal_count(internal_count_[i]);
|
||||
}
|
||||
|
||||
for (int i = 0; i < num_leaves_; ++i) {
|
||||
model_tree.add_leaf_value(leaf_value_[i]);
|
||||
model_tree.add_leaf_count(leaf_count_[i]);
|
||||
}
|
||||
|
||||
if (num_cat_ > 0) {
|
||||
for (int i = 0; i < num_cat_ + 1; ++i) {
|
||||
model_tree.add_cat_boundaries(cat_boundaries_[i]);
|
||||
}
|
||||
for (size_t i = 0; i < cat_threshold_.size(); ++i) {
|
||||
model_tree.add_cat_threshold(cat_threshold_[i]);
|
||||
}
|
||||
}
|
||||
model_tree.set_shrinkage(shrinkage_);
|
||||
}
|
||||
|
||||
Tree::Tree(const LightGBM::Model_Tree& model_tree) {
|
||||
|
||||
num_leaves_ = model_tree.num_leaves();
|
||||
if (num_leaves_ <= 1) { return; }
|
||||
num_cat_ = model_tree.num_cat();
|
||||
|
||||
leaf_value_.reserve(model_tree.leaf_value_size());
|
||||
for(auto leaf_value: model_tree.leaf_value()) {
|
||||
leaf_value_.push_back(leaf_value);
|
||||
}
|
||||
|
||||
left_child_.reserve(model_tree.left_child_size());
|
||||
for(auto left_child: model_tree.left_child()) {
|
||||
left_child_.push_back(left_child);
|
||||
}
|
||||
|
||||
right_child_.reserve(model_tree.right_child_size());
|
||||
for(auto right_child: model_tree.right_child()) {
|
||||
right_child_.push_back(right_child);
|
||||
}
|
||||
|
||||
split_feature_.reserve(model_tree.split_feature_size());
|
||||
for(auto split_feature: model_tree.split_feature()) {
|
||||
split_feature_.push_back(split_feature);
|
||||
}
|
||||
|
||||
threshold_.reserve(model_tree.threshold_size());
|
||||
for(auto threshold: model_tree.threshold()) {
|
||||
threshold_.push_back(threshold);
|
||||
}
|
||||
|
||||
split_gain_.reserve(model_tree.split_gain_size());
|
||||
for(auto split_gain: model_tree.split_gain()) {
|
||||
split_gain_.push_back(split_gain);
|
||||
}
|
||||
|
||||
internal_count_.reserve(model_tree.internal_count_size());
|
||||
for(auto internal_count: model_tree.internal_count()) {
|
||||
internal_count_.push_back(internal_count);
|
||||
}
|
||||
|
||||
internal_value_.reserve(model_tree.internal_value_size());
|
||||
for(auto internal_value: model_tree.internal_value()) {
|
||||
internal_value_.push_back(internal_value);
|
||||
}
|
||||
|
||||
leaf_count_.reserve(model_tree.leaf_count_size());
|
||||
for(auto leaf_count: model_tree.leaf_count()) {
|
||||
leaf_count_.push_back(leaf_count);
|
||||
}
|
||||
|
||||
decision_type_.reserve(model_tree.decision_type_size());
|
||||
for(auto decision_type: model_tree.decision_type()) {
|
||||
decision_type_.push_back(decision_type);
|
||||
}
|
||||
|
||||
if (num_cat_ > 0) {
|
||||
cat_boundaries_.reserve(model_tree.cat_boundaries_size());
|
||||
for(auto cat_boundaries: model_tree.cat_boundaries()) {
|
||||
cat_boundaries_.push_back(cat_boundaries);
|
||||
}
|
||||
|
||||
cat_threshold_.reserve(model_tree.cat_threshold_size());
|
||||
for(auto cat_threshold: model_tree.cat_threshold()) {
|
||||
cat_threshold_.push_back(cat_threshold);
|
||||
}
|
||||
}
|
||||
|
||||
shrinkage_ = model_tree.shrinkage();
|
||||
}
|
||||
|
||||
} // namespace LightGBM
|
|
@ -0,0 +1,22 @@
|
|||
#include "../../boosting/gbdt.h"
|
||||
|
||||
namespace LightGBM {
|
||||
|
||||
void GBDT::SaveModelToProto(int, const char*) const {
|
||||
Log::Fatal("Please cmake with -DUSE_PROTO=ON to use protobuf.");
|
||||
}
|
||||
|
||||
bool GBDT::LoadModelFromProto(const char*) {
|
||||
Log::Fatal("Please cmake with -DUSE_PROTO=ON to use protobuf.");
|
||||
return false;
|
||||
}
|
||||
|
||||
void Tree::ToProto(LightGBM::Model_Tree&) const {
|
||||
Log::Fatal("Please cmake with -DUSE_PROTO=ON to use protobuf.");
|
||||
}
|
||||
|
||||
Tree::Tree(const LightGBM::Model_Tree&) {
|
||||
Log::Fatal("Please cmake with -DUSE_PROTO=ON to use protobuf.");
|
||||
}
|
||||
|
||||
} // namespace LightGBM
|
|
@ -0,0 +1,9 @@
|
|||
#ifndef PROTOBUF_model_2eproto__INCLUDED
|
||||
#define PROTOBUF_model_2eproto__INCLUDED
|
||||
|
||||
namespace LightGBM {
|
||||
class Model;
|
||||
class Model_Tree;
|
||||
} // namespace LightGBM
|
||||
|
||||
#endif // PROTOBUF_model_2eproto__INCLUDED
|
|
@ -3,7 +3,3 @@ data=../data/categorical.data
|
|||
app=binary
|
||||
|
||||
num_trees=10
|
||||
|
||||
convert_model=../../src/boosting/gbdt_prediction.cpp
|
||||
|
||||
convert_model_language=cpp
|
||||
|
|
|
@ -247,7 +247,8 @@
|
|||
<ClCompile Include="..\src\application\application.cpp" />
|
||||
<ClCompile Include="..\src\boosting\boosting.cpp" />
|
||||
<ClCompile Include="..\src\boosting\gbdt.cpp" />
|
||||
<ClCompile Include="..\src\boosting\gbdt_model.cpp" />
|
||||
<ClCompile Include="..\src\boosting\gbdt_model_text.cpp" />
|
||||
<ClCompile Include="..\src\boosting\gbdt_model_proto.cpp" />
|
||||
<ClCompile Include="..\src\boosting\gbdt_prediction.cpp" />
|
||||
<ClCompile Include="..\src\boosting\prediction_early_stop.cpp" />
|
||||
<ClCompile Include="..\src\c_api.cpp" />
|
||||
|
|
|
@ -278,7 +278,10 @@
|
|||
<ClCompile Include="..\src\lightgbm_R.cpp">
|
||||
<Filter>src</Filter>
|
||||
</ClCompile>
|
||||
<ClCompile Include="..\src\boosting\gbdt_model.cpp">
|
||||
<ClCompile Include="..\src\boosting\gbdt_model_text.cpp">
|
||||
<Filter>src\boosting</Filter>
|
||||
</ClCompile>
|
||||
<ClCompile Include="..\src\boosting\gbdt_model_proto.cpp">
|
||||
<Filter>src\boosting</Filter>
|
||||
</ClCompile>
|
||||
</ItemGroup>
|
||||
|
|
Загрузка…
Ссылка в новой задаче