refine predictor logic in c_api

This commit is contained in:
Guolin Ke 2016-12-28 15:58:50 +08:00
Родитель 728e50a9cf
Коммит 45e0da2cbe
1 изменённых файлов: 10 добавлений и 21 удалений

Просмотреть файл

@ -154,7 +154,7 @@ public:
boosting_->RollbackOneIter();
}
void PrepareForPrediction(int num_iteration, int predict_type) {
Predictor NewPredictor(int num_iteration, int predict_type) {
std::lock_guard<std::mutex> lock(mutex_);
boosting_->SetNumIterationForPred(num_iteration);
bool is_predict_leaf = false;
@ -166,22 +166,15 @@ public:
} else {
is_raw_score = false;
}
predictor_.reset(new Predictor(boosting_.get(), is_raw_score, is_predict_leaf));
// not threading safe now
// boosting_->SetNumIterationForPred may be set by other thread during prediction.
return Predictor(boosting_.get(), is_raw_score, is_predict_leaf);
}
void GetPredictAt(int data_idx, score_t* out_result, int64_t* out_len) {
boosting_->GetPredictAt(data_idx, out_result, out_len);
}
std::vector<double> Predict(const std::vector<std::pair<int, double>>& features) {
return predictor_->GetPredictFunction()(features);
}
void PredictForFile(const char* data_filename, const char* result_filename, bool data_has_header) {
std::lock_guard<std::mutex> lock(mutex_);
predictor_->Predict(data_filename, result_filename, data_has_header);
}
void SaveModelToFile(int num_iteration, const char* filename) {
boosting_->SaveModelToFile(num_iteration, filename);
}
@ -232,8 +225,6 @@ private:
std::vector<std::vector<std::unique_ptr<Metric>>> valid_metrics_;
/*! \brief Training objective function */
std::unique_ptr<ObjectiveFunction> objective_fun_;
/*! \brief Using predictor for prediction task */
std::unique_ptr<Predictor> predictor_;
/*! \brief mutex for threading safe call */
std::mutex mutex_;
};
@ -692,9 +683,9 @@ DllExport int LGBM_BoosterPredictForFile(BoosterHandle handle,
const char* result_filename) {
API_BEGIN();
Booster* ref_booster = reinterpret_cast<Booster*>(handle);
ref_booster->PrepareForPrediction(static_cast<int>(num_iteration), predict_type);
auto predictor = ref_booster->NewPredictor(static_cast<int>(num_iteration), predict_type);
bool bool_data_has_header = data_has_header > 0 ? true : false;
ref_booster->PredictForFile(data_filename, result_filename, bool_data_has_header);
predictor.Predict(data_filename, result_filename, bool_data_has_header);
API_END();
}
@ -713,8 +704,7 @@ DllExport int LGBM_BoosterPredictForCSR(BoosterHandle handle,
float* out_result) {
API_BEGIN();
Booster* ref_booster = reinterpret_cast<Booster*>(handle);
ref_booster->PrepareForPrediction(static_cast<int>(num_iteration), predict_type);
auto predictor = ref_booster->NewPredictor(static_cast<int>(num_iteration), predict_type);
auto get_row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
int num_preb_in_one_row = ref_booster->GetBoosting()->NumberOfClasses();
if (predict_type == C_API_PREDICT_LEAF_INDEX) {
@ -728,7 +718,7 @@ DllExport int LGBM_BoosterPredictForCSR(BoosterHandle handle,
#pragma omp parallel for schedule(guided)
for (int i = 0; i < nrow; ++i) {
auto one_row = get_row_fun(i);
auto predicton_result = ref_booster->Predict(one_row);
auto predicton_result = predictor.GetPredictFunction()(one_row);
for (int j = 0; j < static_cast<int>(predicton_result.size()); ++j) {
out_result[i * num_preb_in_one_row + j] = static_cast<float>(predicton_result[j]);
}
@ -749,8 +739,7 @@ DllExport int LGBM_BoosterPredictForMat(BoosterHandle handle,
float* out_result) {
API_BEGIN();
Booster* ref_booster = reinterpret_cast<Booster*>(handle);
ref_booster->PrepareForPrediction(static_cast<int>(num_iteration), predict_type);
auto predictor = ref_booster->NewPredictor(static_cast<int>(num_iteration), predict_type);
auto get_row_fun = RowPairFunctionFromDenseMatric(data, nrow, ncol, data_type, is_row_major);
int num_preb_in_one_row = ref_booster->GetBoosting()->NumberOfClasses();
if (predict_type == C_API_PREDICT_LEAF_INDEX) {
@ -763,7 +752,7 @@ DllExport int LGBM_BoosterPredictForMat(BoosterHandle handle,
#pragma omp parallel for schedule(guided)
for (int i = 0; i < nrow; ++i) {
auto one_row = get_row_fun(i);
auto predicton_result = ref_booster->Predict(one_row);
auto predicton_result = predictor.GetPredictFunction()(one_row);
for (int j = 0; j < static_cast<int>(predicton_result.size()); ++j) {
out_result[i * num_preb_in_one_row + j] = static_cast<float>(predicton_result[j]);
}