not need for the data_has_label parameters any more.

This commit is contained in:
Guolin Ke 2016-10-18 13:27:01 +08:00
Родитель 728875b46c
Коммит aee30126d8
12 изменённых файлов: 75 добавлений и 45 удалений

Просмотреть файл

@ -4,5 +4,3 @@ task = predict
data = binary.test
input_model= LightGBM_model.txt
data_has_label = true

Просмотреть файл

@ -4,5 +4,3 @@ task = predict
data = rank.test
input_model= LightGBM_model.txt
data_has_label = true

Просмотреть файл

@ -5,4 +5,3 @@ data = binary.test
input_model= LightGBM_model.txt
data_has_label = true

Просмотреть файл

@ -4,5 +4,3 @@ task = predict
data = regression.test
input_model= LightGBM_model.txt
data_has_label = true

Просмотреть файл

@ -88,7 +88,6 @@ public:
int max_bin = 255;
int data_random_seed = 1;
std::string data_filename = "";
bool data_has_label = true;
std::vector<std::string> valid_data_filenames;
std::string output_model = "LightGBM_model.txt";
std::string output_result = "LightGBM_predict_result.txt";
@ -274,8 +273,6 @@ struct ParameterAlias {
{ "app", "objective" },
{ "train_data", "data" },
{ "train", "data" },
{ "has_label", "data_has_label" },
{ "is_data_has_label", "data_has_label" },
{ "model_output", "output_model" },
{ "model_out", "output_model" },
{ "model_input", "input_model" },

Просмотреть файл

@ -208,9 +208,11 @@ public:
/*!
* \brief Create a object of parser, will auto choose the format depend on file
* \param filename One Filename of data
* \param num_features Pass num_features of this data file if you know, <=0 means don't know
* \param has_label output, if num_features > 0, will output this data has label or not
* \return Object of parser
*/
static Parser* CreateParser(const char* filename);
static Parser* CreateParser(const char* filename, int num_features, bool* has_label);
};
using PredictFunction =
@ -299,6 +301,9 @@ public:
/*! \brief Get Number of used features */
inline int num_features() const { return num_features_; }
/*! \brief Get Number of total features */
inline int num_total_features() const { return num_total_features_; }
/*! \brief Get Number of data */
inline data_size_t num_data() const { return num_data_; }
@ -373,6 +378,8 @@ private:
std::vector<int> used_feature_map_;
/*! \brief Number of used features*/
int num_features_;
/*! \brief Number of total features*/
int num_total_features_;
/*! \brief Number of total data*/
data_size_t num_data_;
/*! \brief Store some label level data*/

Просмотреть файл

@ -253,8 +253,7 @@ void Application::Train() {
void Application::Predict() {
// create predictor
Predictor predictor(boosting_, config_.io_config.is_sigmoid);
predictor.Predict(config_.io_config.data_filename.c_str(),
config_.io_config.data_has_label, config_.io_config.output_result.c_str());
predictor.Predict(config_.io_config.data_filename.c_str(), config_.io_config.output_result.c_str());
Log::Stdout("finish predict");
}

Просмотреть файл

@ -96,7 +96,7 @@ public:
* \param has_label True if this data contains label
* \param result_filename Filename of output result
*/
void Predict(const char* data_filename, bool has_label, const char* result_filename) {
void Predict(const char* data_filename, const char* result_filename) {
FILE* result_file;
#ifdef _MSC_VER
@ -108,8 +108,8 @@ public:
if (result_file == NULL) {
Log::Stderr("predition result file %s doesn't exists", data_filename);
}
Parser* parser = Parser::CreateParser(data_filename);
bool has_label = false;
Parser* parser = Parser::CreateParser(data_filename, num_features_, &has_label);
if (parser == nullptr) {
Log::Stderr("can regonise input data format, filename %s", data_filename);

Просмотреть файл

@ -60,10 +60,7 @@ void GBDT::Init(const Dataset* train_data, const ObjectiveFunction* object_funct
hessians_ = new score_t[num_data_];
// get max feature index
for (int i = 0; i < train_data->num_features(); ++i) {
max_feature_idx_ = Common::Max<int>(max_feature_idx_,
train_data->FeatureAt(i)->feature_index());
}
max_feature_idx_ = train_data_->num_total_features() - 1;
// if need bagging, create buffer
if (gbdt_config_->bagging_fraction < 1.0 && gbdt_config_->bagging_freq > 0) {

Просмотреть файл

@ -15,11 +15,6 @@ void OverallConfig::Set(const std::unordered_map<std::string, std::string>& para
GetInt(params, "num_threads", &num_threads);
GetTaskType(params);
// prediction task, default not has label
if (task_type == TaskType::kPredict) {
io_config.data_has_label = false;
}
GetBoostingType(params);
GetObjectiveType(params);
GetMetricType(params);
@ -125,11 +120,6 @@ void OverallConfig::CheckParamConflict() {
TreeLearnerType::kDataParallelTreeLearner) {
is_parallel_find_bin = true;
}
if (task_type == TaskType::kTrain && io_config.data_has_label == false) {
Log::Stderr("Data should have label in training task");
}
}
void IOConfig::Set(const std::unordered_map<std::string, std::string>& params) {
@ -141,7 +131,6 @@ void IOConfig::Set(const std::unordered_map<std::string, std::string>& params) {
Log::Stderr("No training/prediction data, application quit");
}
GetInt(params, "num_model_predict", &num_model_predict);
GetBool(params, "data_has_label", &data_has_label);
GetBool(params, "is_pre_partition", &is_pre_partition);
GetBool(params, "is_enable_sparse", &is_enable_sparse);
GetBool(params, "use_two_round_loading", &use_two_round_loading);

Просмотреть файл

@ -29,7 +29,7 @@ Dataset::Dataset(const char* data_filename, const char* init_score_filename,
// load weight, query information and initilize score
metadata_.Init(data_filename, init_score_filename);
// create text parser
parser_ = Parser::CreateParser(data_filename_);
parser_ = Parser::CreateParser(data_filename_, 0, nullptr);
if (parser_ == nullptr) {
Log::Stderr("cannot recognise input data format, filename: %s", data_filename_);
}
@ -189,7 +189,7 @@ void Dataset::ConstructBinMappers(int rank, int num_machines, const std::vector<
// -1 means doesn't use this feature
used_feature_map_ = std::vector<int>(sample_values.size(), -1);
num_total_features_ = sample_values.size();
// start find bins
if (num_machines == 1) {
std::vector<BinMapper*> bin_mappers(sample_values.size());
@ -209,6 +209,7 @@ void Dataset::ConstructBinMappers(int rank, int num_machines, const std::vector<
num_data_, is_enable_sparse_));
} else {
// if feature is trival(only 1 bin), free spaces
Log::Stdout("Warning: feture %d only contains one value, will ignore it", i);
delete bin_mappers[i];
}
}

Просмотреть файл

@ -20,7 +20,38 @@ void GetStatistic(const char* str, int* comma_cnt, int* tab_cnt, int* colon_cnt)
}
}
Parser* Parser::CreateParser(const char* filename) {
bool CheckHasLabelForLibsvm(std::string& str) {
str = Common::Trim(str);
auto pos_space = str.find_first_of(" \f\n\r\t\v");
auto pos_colon = str.find_first_of(":");
if (pos_colon == std::string::npos || pos_colon > pos_space) {
return true;
} else {
return false;
}
}
bool CheckHasLabelForTSV(std::string& str, int num_features) {
str = Common::Trim(str);
auto tokens = Common::Split(str.c_str(), '\t');
if (tokens.size() == num_features) {
return false;
} else {
return true;
}
}
bool CheckHasLabelForCSV(std::string& str, int num_features) {
str = Common::Trim(str);
auto tokens = Common::Split(str.c_str(), ',');
if (tokens.size() == num_features) {
return false;
} else {
return true;
}
}
Parser* Parser::CreateParser(const char* filename, int num_features, bool* has_label) {
std::ifstream tmp_file;
tmp_file.open(filename);
if (!tmp_file.is_open()) {
@ -44,29 +75,45 @@ Parser* Parser::CreateParser(const char* filename) {
// Get some statistic from 2 line
GetStatistic(line1.c_str(), &comma_cnt, &tab_cnt, &colon_cnt);
GetStatistic(line2.c_str(), &comma_cnt2, &tab_cnt2, &colon_cnt2);
Parser* ret = nullptr;
if (line2.size() == 0) {
// if only have one line on file
if (colon_cnt > 0) {
return new LibSVMParser();
ret = new LibSVMParser();
if (num_features > 0 && has_label != nullptr) {
*has_label = CheckHasLabelForLibsvm(line1);
}
} else if (tab_cnt > 0) {
return new TSVParser();
ret = new TSVParser();
if (num_features > 0 && has_label != nullptr) {
*has_label = CheckHasLabelForTSV(line1, num_features);
}
} else if (comma_cnt > 0) {
return new CSVParser();
} else {
return nullptr;
}
ret = new CSVParser();
if (num_features > 0 && has_label != nullptr) {
*has_label = CheckHasLabelForCSV(line1, num_features);
}
}
} else {
if (colon_cnt > 0 || colon_cnt2 > 0) {
return new LibSVMParser();
ret = new LibSVMParser();
if (num_features > 0 && has_label != nullptr) {
*has_label = CheckHasLabelForLibsvm(line1);
}
}
else if (tab_cnt == tab_cnt2 && tab_cnt > 0) {
return new TSVParser();
ret = new TSVParser();
if (num_features > 0 && has_label != nullptr) {
*has_label = CheckHasLabelForTSV(line1, num_features);
}
} else if (comma_cnt == comma_cnt2 && comma_cnt > 0) {
return new CSVParser();
} else {
return nullptr;
ret = new CSVParser();
if (num_features > 0 && has_label != nullptr) {
*has_label = CheckHasLabelForCSV(line1, num_features);
}
}
}
return ret;
}
} // namespace LightGBM