зеркало из https://github.com/microsoft/LightGBM.git
not need for the data_has_label parameters any more.
This commit is contained in:
Родитель
728875b46c
Коммит
aee30126d8
|
@ -4,5 +4,3 @@ task = predict
|
|||
data = binary.test
|
||||
|
||||
input_model= LightGBM_model.txt
|
||||
|
||||
data_has_label = true
|
||||
|
|
|
@ -4,5 +4,3 @@ task = predict
|
|||
data = rank.test
|
||||
|
||||
input_model= LightGBM_model.txt
|
||||
|
||||
data_has_label = true
|
||||
|
|
|
@ -5,4 +5,3 @@ data = binary.test
|
|||
|
||||
input_model= LightGBM_model.txt
|
||||
|
||||
data_has_label = true
|
||||
|
|
|
@ -4,5 +4,3 @@ task = predict
|
|||
data = regression.test
|
||||
|
||||
input_model= LightGBM_model.txt
|
||||
|
||||
data_has_label = true
|
||||
|
|
|
@ -88,7 +88,6 @@ public:
|
|||
int max_bin = 255;
|
||||
int data_random_seed = 1;
|
||||
std::string data_filename = "";
|
||||
bool data_has_label = true;
|
||||
std::vector<std::string> valid_data_filenames;
|
||||
std::string output_model = "LightGBM_model.txt";
|
||||
std::string output_result = "LightGBM_predict_result.txt";
|
||||
|
@ -274,8 +273,6 @@ struct ParameterAlias {
|
|||
{ "app", "objective" },
|
||||
{ "train_data", "data" },
|
||||
{ "train", "data" },
|
||||
{ "has_label", "data_has_label" },
|
||||
{ "is_data_has_label", "data_has_label" },
|
||||
{ "model_output", "output_model" },
|
||||
{ "model_out", "output_model" },
|
||||
{ "model_input", "input_model" },
|
||||
|
|
|
@ -208,9 +208,11 @@ public:
|
|||
/*!
|
||||
* \brief Create a object of parser, will auto choose the format depend on file
|
||||
* \param filename One Filename of data
|
||||
* \param num_features Pass num_features of this data file if you know, <=0 means don't know
|
||||
* \param has_label output, if num_features > 0, will output this data has label or not
|
||||
* \return Object of parser
|
||||
*/
|
||||
static Parser* CreateParser(const char* filename);
|
||||
static Parser* CreateParser(const char* filename, int num_features, bool* has_label);
|
||||
};
|
||||
|
||||
using PredictFunction =
|
||||
|
@ -299,6 +301,9 @@ public:
|
|||
/*! \brief Get Number of used features */
|
||||
inline int num_features() const { return num_features_; }
|
||||
|
||||
/*! \brief Get Number of total features */
|
||||
inline int num_total_features() const { return num_total_features_; }
|
||||
|
||||
/*! \brief Get Number of data */
|
||||
inline data_size_t num_data() const { return num_data_; }
|
||||
|
||||
|
@ -373,6 +378,8 @@ private:
|
|||
std::vector<int> used_feature_map_;
|
||||
/*! \brief Number of used features*/
|
||||
int num_features_;
|
||||
/*! \brief Number of total features*/
|
||||
int num_total_features_;
|
||||
/*! \brief Number of total data*/
|
||||
data_size_t num_data_;
|
||||
/*! \brief Store some label level data*/
|
||||
|
|
|
@ -253,8 +253,7 @@ void Application::Train() {
|
|||
void Application::Predict() {
|
||||
// create predictor
|
||||
Predictor predictor(boosting_, config_.io_config.is_sigmoid);
|
||||
predictor.Predict(config_.io_config.data_filename.c_str(),
|
||||
config_.io_config.data_has_label, config_.io_config.output_result.c_str());
|
||||
predictor.Predict(config_.io_config.data_filename.c_str(), config_.io_config.output_result.c_str());
|
||||
Log::Stdout("finish predict");
|
||||
}
|
||||
|
||||
|
|
|
@ -96,7 +96,7 @@ public:
|
|||
* \param has_label True if this data contains label
|
||||
* \param result_filename Filename of output result
|
||||
*/
|
||||
void Predict(const char* data_filename, bool has_label, const char* result_filename) {
|
||||
void Predict(const char* data_filename, const char* result_filename) {
|
||||
FILE* result_file;
|
||||
|
||||
#ifdef _MSC_VER
|
||||
|
@ -108,8 +108,8 @@ public:
|
|||
if (result_file == NULL) {
|
||||
Log::Stderr("predition result file %s doesn't exists", data_filename);
|
||||
}
|
||||
|
||||
Parser* parser = Parser::CreateParser(data_filename);
|
||||
bool has_label = false;
|
||||
Parser* parser = Parser::CreateParser(data_filename, num_features_, &has_label);
|
||||
|
||||
if (parser == nullptr) {
|
||||
Log::Stderr("can regonise input data format, filename %s", data_filename);
|
||||
|
|
|
@ -60,10 +60,7 @@ void GBDT::Init(const Dataset* train_data, const ObjectiveFunction* object_funct
|
|||
hessians_ = new score_t[num_data_];
|
||||
|
||||
// get max feature index
|
||||
for (int i = 0; i < train_data->num_features(); ++i) {
|
||||
max_feature_idx_ = Common::Max<int>(max_feature_idx_,
|
||||
train_data->FeatureAt(i)->feature_index());
|
||||
}
|
||||
max_feature_idx_ = train_data_->num_total_features() - 1;
|
||||
|
||||
// if need bagging, create buffer
|
||||
if (gbdt_config_->bagging_fraction < 1.0 && gbdt_config_->bagging_freq > 0) {
|
||||
|
|
|
@ -15,11 +15,6 @@ void OverallConfig::Set(const std::unordered_map<std::string, std::string>& para
|
|||
GetInt(params, "num_threads", &num_threads);
|
||||
GetTaskType(params);
|
||||
|
||||
// prediction task, default not has label
|
||||
if (task_type == TaskType::kPredict) {
|
||||
io_config.data_has_label = false;
|
||||
}
|
||||
|
||||
GetBoostingType(params);
|
||||
GetObjectiveType(params);
|
||||
GetMetricType(params);
|
||||
|
@ -125,11 +120,6 @@ void OverallConfig::CheckParamConflict() {
|
|||
TreeLearnerType::kDataParallelTreeLearner) {
|
||||
is_parallel_find_bin = true;
|
||||
}
|
||||
|
||||
if (task_type == TaskType::kTrain && io_config.data_has_label == false) {
|
||||
Log::Stderr("Data should have label in training task");
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
void IOConfig::Set(const std::unordered_map<std::string, std::string>& params) {
|
||||
|
@ -141,7 +131,6 @@ void IOConfig::Set(const std::unordered_map<std::string, std::string>& params) {
|
|||
Log::Stderr("No training/prediction data, application quit");
|
||||
}
|
||||
GetInt(params, "num_model_predict", &num_model_predict);
|
||||
GetBool(params, "data_has_label", &data_has_label);
|
||||
GetBool(params, "is_pre_partition", &is_pre_partition);
|
||||
GetBool(params, "is_enable_sparse", &is_enable_sparse);
|
||||
GetBool(params, "use_two_round_loading", &use_two_round_loading);
|
||||
|
|
|
@ -29,7 +29,7 @@ Dataset::Dataset(const char* data_filename, const char* init_score_filename,
|
|||
// load weight, query information and initilize score
|
||||
metadata_.Init(data_filename, init_score_filename);
|
||||
// create text parser
|
||||
parser_ = Parser::CreateParser(data_filename_);
|
||||
parser_ = Parser::CreateParser(data_filename_, 0, nullptr);
|
||||
if (parser_ == nullptr) {
|
||||
Log::Stderr("cannot recognise input data format, filename: %s", data_filename_);
|
||||
}
|
||||
|
@ -189,7 +189,7 @@ void Dataset::ConstructBinMappers(int rank, int num_machines, const std::vector<
|
|||
|
||||
// -1 means doesn't use this feature
|
||||
used_feature_map_ = std::vector<int>(sample_values.size(), -1);
|
||||
|
||||
num_total_features_ = sample_values.size();
|
||||
// start find bins
|
||||
if (num_machines == 1) {
|
||||
std::vector<BinMapper*> bin_mappers(sample_values.size());
|
||||
|
@ -209,6 +209,7 @@ void Dataset::ConstructBinMappers(int rank, int num_machines, const std::vector<
|
|||
num_data_, is_enable_sparse_));
|
||||
} else {
|
||||
// if feature is trival(only 1 bin), free spaces
|
||||
Log::Stdout("Warning: feture %d only contains one value, will ignore it", i);
|
||||
delete bin_mappers[i];
|
||||
}
|
||||
}
|
||||
|
|
|
@ -20,7 +20,38 @@ void GetStatistic(const char* str, int* comma_cnt, int* tab_cnt, int* colon_cnt)
|
|||
}
|
||||
}
|
||||
|
||||
Parser* Parser::CreateParser(const char* filename) {
|
||||
bool CheckHasLabelForLibsvm(std::string& str) {
|
||||
str = Common::Trim(str);
|
||||
auto pos_space = str.find_first_of(" \f\n\r\t\v");
|
||||
auto pos_colon = str.find_first_of(":");
|
||||
if (pos_colon == std::string::npos || pos_colon > pos_space) {
|
||||
return true;
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
bool CheckHasLabelForTSV(std::string& str, int num_features) {
|
||||
str = Common::Trim(str);
|
||||
auto tokens = Common::Split(str.c_str(), '\t');
|
||||
if (tokens.size() == num_features) {
|
||||
return false;
|
||||
} else {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
bool CheckHasLabelForCSV(std::string& str, int num_features) {
|
||||
str = Common::Trim(str);
|
||||
auto tokens = Common::Split(str.c_str(), ',');
|
||||
if (tokens.size() == num_features) {
|
||||
return false;
|
||||
} else {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
Parser* Parser::CreateParser(const char* filename, int num_features, bool* has_label) {
|
||||
std::ifstream tmp_file;
|
||||
tmp_file.open(filename);
|
||||
if (!tmp_file.is_open()) {
|
||||
|
@ -44,29 +75,45 @@ Parser* Parser::CreateParser(const char* filename) {
|
|||
// Get some statistic from 2 line
|
||||
GetStatistic(line1.c_str(), &comma_cnt, &tab_cnt, &colon_cnt);
|
||||
GetStatistic(line2.c_str(), &comma_cnt2, &tab_cnt2, &colon_cnt2);
|
||||
Parser* ret = nullptr;
|
||||
if (line2.size() == 0) {
|
||||
// if only have one line on file
|
||||
if (colon_cnt > 0) {
|
||||
return new LibSVMParser();
|
||||
ret = new LibSVMParser();
|
||||
if (num_features > 0 && has_label != nullptr) {
|
||||
*has_label = CheckHasLabelForLibsvm(line1);
|
||||
}
|
||||
} else if (tab_cnt > 0) {
|
||||
return new TSVParser();
|
||||
ret = new TSVParser();
|
||||
if (num_features > 0 && has_label != nullptr) {
|
||||
*has_label = CheckHasLabelForTSV(line1, num_features);
|
||||
}
|
||||
} else if (comma_cnt > 0) {
|
||||
return new CSVParser();
|
||||
} else {
|
||||
return nullptr;
|
||||
}
|
||||
ret = new CSVParser();
|
||||
if (num_features > 0 && has_label != nullptr) {
|
||||
*has_label = CheckHasLabelForCSV(line1, num_features);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if (colon_cnt > 0 || colon_cnt2 > 0) {
|
||||
return new LibSVMParser();
|
||||
ret = new LibSVMParser();
|
||||
if (num_features > 0 && has_label != nullptr) {
|
||||
*has_label = CheckHasLabelForLibsvm(line1);
|
||||
}
|
||||
}
|
||||
else if (tab_cnt == tab_cnt2 && tab_cnt > 0) {
|
||||
return new TSVParser();
|
||||
ret = new TSVParser();
|
||||
if (num_features > 0 && has_label != nullptr) {
|
||||
*has_label = CheckHasLabelForTSV(line1, num_features);
|
||||
}
|
||||
} else if (comma_cnt == comma_cnt2 && comma_cnt > 0) {
|
||||
return new CSVParser();
|
||||
} else {
|
||||
return nullptr;
|
||||
ret = new CSVParser();
|
||||
if (num_features > 0 && has_label != nullptr) {
|
||||
*has_label = CheckHasLabelForCSV(line1, num_features);
|
||||
}
|
||||
}
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
} // namespace LightGBM
|
||||
|
|
Загрузка…
Ссылка в новой задаче