From 01ed04dff27d1b98a7e8a64427e909b39fa1f21d Mon Sep 17 00:00:00 2001 From: wxchan Date: Thu, 3 Nov 2016 13:49:45 +0800 Subject: [PATCH] support init_score for multiclass classification (#62) support init_score for multiclass classification (#62) --- examples/multiclass_classification/README.md | 33 +++++++++++++ include/LightGBM/config.h | 1 + include/LightGBM/dataset.h | 18 ++++--- src/application/application.cpp | 15 ++++-- src/application/predictor.hpp | 13 +++--- src/boosting/gbdt.cpp | 1 - src/boosting/score_updater.hpp | 2 +- src/io/config.cpp | 3 +- src/io/dataset.cpp | 32 ++++++++----- src/io/metadata.cpp | 49 ++++++++++++++------ 10 files changed, 121 insertions(+), 46 deletions(-) create mode 100644 examples/multiclass_classification/README.md diff --git a/examples/multiclass_classification/README.md b/examples/multiclass_classification/README.md new file mode 100644 index 000000000..c408b4bc4 --- /dev/null +++ b/examples/multiclass_classification/README.md @@ -0,0 +1,33 @@ +Multiclass Classification Example +===================== +Here is an example of using LightGBM to run a multiclass classification task. + +***You should copy the executable file to this folder first.*** + +#### Training + +For Windows, run the following command in this folder: +``` +lightgbm.exe config=train.conf +``` + + +For Linux, run the following command in this folder: +``` +./lightgbm config=train.conf +``` + +#### Prediction + +You should finish training first. 
+ +For Windows, run the following command in this folder: +``` +lightgbm.exe config=predict.conf +``` + +For Linux, run the following command in this folder: +``` +./lightgbm config=predict.conf +``` + diff --git a/include/LightGBM/config.h b/include/LightGBM/config.h index 750ae0ad7..e3e329538 100644 --- a/include/LightGBM/config.h +++ b/include/LightGBM/config.h @@ -86,6 +86,7 @@ enum TaskType { struct IOConfig: public ConfigBase { public: int max_bin = 256; + int num_class = 1; int data_random_seed = 1; std::string data_filename = ""; std::vector valid_data_filenames; diff --git a/include/LightGBM/dataset.h b/include/LightGBM/dataset.h index 28156e4f2..109abc8b2 100644 --- a/include/LightGBM/dataset.h +++ b/include/LightGBM/dataset.h @@ -41,14 +41,15 @@ public: * \brief Initialization will load qurey level informations, since it is need for sampling data * \param data_filename Filename of data * \param init_score_filename Filename of initial score - * \param is_int_label True if label is int type + * \param num_class Number of classes */ - void Init(const char* data_filename, const char* init_score_filename); + void Init(const char* data_filename, const char* init_score_filename, const int num_class); /*! * \brief Initialize, only load initial score * \param init_score_filename Filename of initial score + * \param num_class Number of classes */ - void Init(const char* init_score_filename); + void Init(const char* init_score_filename, const int num_class); /*! * \brief Initial with binary memory * \param memory Pointer to memory @@ -60,10 +61,11 @@ public: /*! 
* \brief Initial work, will allocate space for label, weight(if exists) and query(if exists) * \param num_data Number of training data + * \param num_class Number of classes * \param weight_idx Index of weight column, < 0 means doesn't exists * \param query_idx Index of query id column, < 0 means doesn't exists */ - void Init(data_size_t num_data, int weight_idx, int query_idx); + void Init(data_size_t num_data, int num_class, int weight_idx, int query_idx); /*! * \brief Partition label by used indices @@ -167,7 +169,7 @@ public: * \return Pointer of initial scores */ inline const float* init_score() const { return init_score_; } - + /*! \brief Load initial scores from file */ void LoadInitialScore(); @@ -184,6 +186,8 @@ private: const char* init_score_filename_; /*! \brief Number of data */ data_size_t num_data_; + /*! \brief Number of classes */ + int num_class_; /*! \brief Number of weights, used to check correct weight file */ data_size_t num_weights_; /*! \brief Label data */ @@ -234,7 +238,7 @@ public: }; using PredictFunction = - std::function>&)>; + std::function(const std::vector>&)>; /*! \brief The main class of data set, * which are used to traning or validation @@ -398,6 +402,8 @@ private: int num_total_features_; /*! \brief Number of total data*/ data_size_t num_data_; + /*! \brief Number of classes*/ + int num_class_; /*! \brief Store some label level data*/ Metadata metadata_; /*! 
\brief Random generator*/ diff --git a/src/application/application.cpp b/src/application/application.cpp index f0a2892f9..7f97bfd9d 100644 --- a/src/application/application.cpp +++ b/src/application/application.cpp @@ -124,10 +124,17 @@ void Application::LoadData() { // need to continue train if (boosting_->NumberOfSubModels() > 0) { predictor = new Predictor(boosting_, config_.io_config.is_sigmoid, config_.predict_leaf_index, -1); - predict_fun = - [&predictor](const std::vector>& features) { - return predictor->PredictRawOneLine(features); - }; + if (config_.io_config.num_class == 1){ + predict_fun = + [&predictor](const std::vector>& features) { + return predictor->PredictRawOneLine(features); + }; + } else { + predict_fun = + [&predictor](const std::vector>& features) { + return predictor->PredictMulticlassOneLine(features); + }; + } } // sync up random seed for data partition if (config_.is_parallel_find_bin) { diff --git a/src/application/predictor.hpp b/src/application/predictor.hpp index 932ab2e19..177db44bf 100644 --- a/src/application/predictor.hpp +++ b/src/application/predictor.hpp @@ -61,10 +61,10 @@ public: * \param features Feature for this record * \return Prediction result */ - double PredictRawOneLine(const std::vector>& features) { + std::vector PredictRawOneLine(const std::vector>& features) { const int tid = PutFeatureValuesToBuffer(features); // get result without sigmoid transformation - return boosting_->PredictRaw(features_[tid], num_used_model_); + return std::vector(1, boosting_->PredictRaw(features_[tid], num_used_model_)); } /*! 
@@ -83,10 +83,10 @@ public: * \param features Feature of this record * \return Prediction result */ - double PredictOneLine(const std::vector>& features) { + std::vector PredictOneLine(const std::vector>& features) { const int tid = PutFeatureValuesToBuffer(features); // get result with sigmoid transform if needed - return boosting_->Predict(features_[tid], num_used_model_); + return std::vector(1, boosting_->Predict(features_[tid], num_used_model_)); } /*! @@ -136,6 +136,7 @@ public: if (num_class_ > 1) { predict_fun = [this](const std::vector>& features){ std::vector prediction = PredictMulticlassOneLine(features); + Common::Softmax(&prediction); std::stringstream result_stream_buf; for (size_t i = 0; i < prediction.size(); ++i){ if (i > 0) { @@ -162,12 +163,12 @@ public: else { if (is_simgoid_) { predict_fun = [this](const std::vector>& features){ - return std::to_string(PredictOneLine(features)); + return std::to_string(PredictOneLine(features)[0]); }; } else { predict_fun = [this](const std::vector>& features){ - return std::to_string(PredictRawOneLine(features)); + return std::to_string(PredictRawOneLine(features)[0]); }; } } diff --git a/src/boosting/gbdt.cpp b/src/boosting/gbdt.cpp index ededd3991..a9d118919 100644 --- a/src/boosting/gbdt.cpp +++ b/src/boosting/gbdt.cpp @@ -503,7 +503,6 @@ std::vector GBDT::PredictMulticlass(const double* value, int num_used_mo ret[j] += models_[i * num_class_ + j] -> Predict(value); } } - Common::Softmax(&ret); return ret; } diff --git a/src/boosting/score_updater.hpp b/src/boosting/score_updater.hpp index edae2d7a4..d7099b9a9 100644 --- a/src/boosting/score_updater.hpp +++ b/src/boosting/score_updater.hpp @@ -27,7 +27,7 @@ public: const float* init_score = data->metadata().init_score(); // if exists initial score, will start from it if (init_score != nullptr) { - for (data_size_t i = 0; i < num_data_; ++i) { + for (data_size_t i = 0; i < num_data_ * num_class; ++i) { score_[i] = init_score[i]; } } diff --git 
a/src/io/config.cpp b/src/io/config.cpp index 84e93ea01..ab9329684 100644 --- a/src/io/config.cpp +++ b/src/io/config.cpp @@ -184,6 +184,7 @@ void OverallConfig::CheckParamConflict() { void IOConfig::Set(const std::unordered_map& params) { GetInt(params, "max_bin", &max_bin); CHECK(max_bin > 0); + GetInt(params, "num_class", &num_class); GetInt(params, "data_random_seed", &data_random_seed); if (!GetString(params, "data", &data_filename)) { @@ -236,7 +237,6 @@ void ObjectiveConfig::Set(const std::unordered_map& pa void MetricConfig::Set(const std::unordered_map& params) { GetDouble(params, "sigmoid", &sigmoid); GetInt(params, "num_class", &num_class); - CHECK(num_class >= 1); std::string tmp_str = ""; if (GetString(params, "label_gain", &tmp_str)) { label_gain = Common::StringToDoubleArray(tmp_str, ','); @@ -294,7 +294,6 @@ void BoostingConfig::Set(const std::unordered_map& par CHECK(output_freq >= 0); GetBool(params, "is_training_metric", &is_provide_training_metric); GetInt(params, "num_class", &num_class); - CHECK(num_class >= 1); } void GBDTConfig::GetTreeLearnerType(const std::unordered_map& params) { diff --git a/src/io/dataset.cpp b/src/io/dataset.cpp index e2200133f..2664576af 100644 --- a/src/io/dataset.cpp +++ b/src/io/dataset.cpp @@ -20,6 +20,8 @@ Dataset::Dataset(const char* data_filename, const char* init_score_filename, :data_filename_(data_filename), random_(io_config.data_random_seed), max_bin_(io_config.max_bin), is_enable_sparse_(io_config.is_enable_sparse), predict_fun_(predict_fun) { + num_class_ = io_config.num_class; + CheckCanLoadFromBin(); if (is_loading_from_binfile_ && predict_fun != nullptr) { Log::Info("Cannot performing initialization of prediction by using binary file, using text file instead"); @@ -28,7 +30,7 @@ Dataset::Dataset(const char* data_filename, const char* init_score_filename, if (!is_loading_from_binfile_) { // load weight, query information and initilize score - metadata_.Init(data_filename, init_score_filename); + 
metadata_.Init(data_filename, init_score_filename, num_class_); // create text reader text_reader_ = new TextReader(data_filename, io_config.has_header); @@ -152,7 +154,7 @@ Dataset::Dataset(const char* data_filename, const char* init_score_filename, } } else { // only need to load initilize score, other meta data will be loaded from bin flie - metadata_.Init(init_score_filename); + metadata_.Init(init_score_filename, num_class_); Log::Info("Loading data set from binary file"); parser_ = nullptr; text_reader_ = nullptr; @@ -436,7 +438,7 @@ void Dataset::LoadTrainData(int rank, int num_machines, bool is_pre_partition, b // construct feature bin mappers ConstructBinMappers(rank, num_machines, sample_data); // initialize label - metadata_.Init(num_data_, weight_idx_, group_idx_); + metadata_.Init(num_data_, num_class_, weight_idx_, group_idx_); // extract features ExtractFeaturesFromMemory(); } else { @@ -446,7 +448,7 @@ void Dataset::LoadTrainData(int rank, int num_machines, bool is_pre_partition, b // construct feature bin mappers ConstructBinMappers(rank, num_machines, sample_data); // initialize label - metadata_.Init(num_data_, weight_idx_, group_idx_); + metadata_.Init(num_data_, num_class_, weight_idx_, group_idx_); // extract features ExtractFeaturesFromFile(); @@ -471,7 +473,7 @@ void Dataset::LoadValidationData(const Dataset* train_set, bool use_two_round_lo // read data in memory LoadDataToMemory(0, 1, false); // initialize label - metadata_.Init(num_data_, weight_idx_, group_idx_); + metadata_.Init(num_data_, num_class_, weight_idx_, group_idx_); features_.clear(); // copy feature bin mapper data for (Feature* feature : train_set->features_) { @@ -487,7 +489,7 @@ void Dataset::LoadValidationData(const Dataset* train_set, bool use_two_round_lo // Get number of lines of data file num_data_ = static_cast(text_reader_->CountLine()); // initialize label - metadata_.Init(num_data_, weight_idx_, group_idx_); + metadata_.Init(num_data_, num_class_, weight_idx_, 
group_idx_); features_.clear(); // copy feature bin mapper data for (Feature* feature : train_set->features_) { @@ -545,7 +547,7 @@ void Dataset::ExtractFeaturesFromMemory() { } } else { // if need to prediction with initial model - float* init_score = new float[num_data_]; + float* init_score = new float[num_data_ * num_class_]; #pragma omp parallel for schedule(guided) private(oneline_features) firstprivate(tmp_label) for (data_size_t i = 0; i < num_data_; ++i) { const int tid = omp_get_thread_num(); @@ -553,7 +555,10 @@ void Dataset::ExtractFeaturesFromMemory() { // parser parser_->ParseOneLine(text_reader_->Lines()[i].c_str(), &oneline_features, &tmp_label); // set initial score - init_score[i] = static_cast(predict_fun_(oneline_features)); + std::vector oneline_init_score = predict_fun_(oneline_features); + for (int k = 0; k < num_class_; ++k){ + init_score[k * num_data_ + i] = static_cast(oneline_init_score[k]); + } // set label metadata_.SetLabelAt(i, static_cast(tmp_label)); // free processed line: @@ -577,7 +582,7 @@ void Dataset::ExtractFeaturesFromMemory() { } } // metadata_ will manage space of init_score - metadata_.SetInitScore(init_score, num_data_); + metadata_.SetInitScore(init_score, num_data_ * num_class_); delete[] init_score; } @@ -593,7 +598,7 @@ void Dataset::ExtractFeaturesFromMemory() { void Dataset::ExtractFeaturesFromFile() { float* init_score = nullptr; if (predict_fun_ != nullptr) { - init_score = new float[num_data_]; + init_score = new float[num_data_ * num_class_]; } std::function&)> process_fun = [this, &init_score] @@ -608,7 +613,10 @@ void Dataset::ExtractFeaturesFromFile() { parser_->ParseOneLine(lines[i].c_str(), &oneline_features, &tmp_label); // set initial score if (init_score != nullptr) { - init_score[start_idx + i] = static_cast(predict_fun_(oneline_features)); + std::vector oneline_init_score = predict_fun_(oneline_features); + for (int k = 0; k < num_class_; ++k){ + init_score[k * num_data_ + start_idx + i] = 
static_cast(oneline_init_score[k]); + } } // set label metadata_.SetLabelAt(start_idx + i, static_cast(tmp_label)); @@ -640,7 +648,7 @@ void Dataset::ExtractFeaturesFromFile() { // metadata_ will manage space of init_score if (init_score != nullptr) { - metadata_.SetInitScore(init_score, num_data_); + metadata_.SetInitScore(init_score, num_data_ * num_class_); delete[] init_score; } diff --git a/src/io/metadata.cpp b/src/io/metadata.cpp index b4fa1cd31..632e71e55 100644 --- a/src/io/metadata.cpp +++ b/src/io/metadata.cpp @@ -14,9 +14,10 @@ Metadata::Metadata() } -void Metadata::Init(const char * data_filename, const char* init_score_filename) { +void Metadata::Init(const char * data_filename, const char* init_score_filename, const int num_class) { data_filename_ = data_filename; init_score_filename_ = init_score_filename; + num_class_ = num_class; // for lambdarank, it needs query data for partition data in parallel learning LoadQueryBoundaries(); LoadWeights(); @@ -24,8 +25,9 @@ void Metadata::Init(const char * data_filename, const char* init_score_filename) LoadInitialScore(); } -void Metadata::Init(const char* init_score_filename) { +void Metadata::Init(const char* init_score_filename, const int num_class) { init_score_filename_ = init_score_filename; + num_class_ = num_class; LoadInitialScore(); } @@ -40,8 +42,9 @@ Metadata::~Metadata() { } -void Metadata::Init(data_size_t num_data, int weight_idx, int query_idx) { +void Metadata::Init(data_size_t num_data, int num_class, int weight_idx, int query_idx) { num_data_ = num_data; + num_class_ = num_class; label_ = new float[num_data_]; if (weight_idx >= 0) { if (weights_ != nullptr) { @@ -200,9 +203,11 @@ void Metadata::CheckOrPartition(data_size_t num_all_data, const std::vector(reader.Lines().size()); - init_score_ = new float[num_init_score_]; + + init_score_ = new float[num_init_score_ * num_class_]; double tmp = 0.0f; - for (data_size_t i = 0; i < num_init_score_; ++i) { - 
Common::Atof(reader.Lines()[i].c_str(), &tmp); - init_score_[i] = static_cast(tmp); + + if (num_class_ == 1){ + for (data_size_t i = 0; i < num_init_score_; ++i) { + Common::Atof(reader.Lines()[i].c_str(), &tmp); + init_score_[i] = static_cast(tmp); + } + } else { + std::vector oneline_init_score; + for (data_size_t i = 0; i < num_init_score_; ++i) { + oneline_init_score = Common::Split(reader.Lines()[i].c_str(), '\t'); + if (static_cast(oneline_init_score.size()) != num_class_){ + Log::Fatal("Invalid initial score file. Redundant or insufficient columns."); + } + for (int k = 0; k < num_class_; ++k) { + Common::Atof(oneline_init_score[k].c_str(), &tmp); + init_score_[k * num_init_score_ + i] = static_cast(tmp); + } + } } }