зеркало из https://github.com/microsoft/LightGBM.git
refine predictor logic in c_api
This commit is contained in:
Родитель
728e50a9cf
Коммит
45e0da2cbe
|
@ -154,7 +154,7 @@ public:
|
|||
boosting_->RollbackOneIter();
|
||||
}
|
||||
|
||||
void PrepareForPrediction(int num_iteration, int predict_type) {
|
||||
Predictor NewPredictor(int num_iteration, int predict_type) {
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
boosting_->SetNumIterationForPred(num_iteration);
|
||||
bool is_predict_leaf = false;
|
||||
|
@ -166,22 +166,15 @@ public:
|
|||
} else {
|
||||
is_raw_score = false;
|
||||
}
|
||||
predictor_.reset(new Predictor(boosting_.get(), is_raw_score, is_predict_leaf));
|
||||
// not threading safe now
|
||||
// boosting_->SetNumIterationForPred may be set by other thread during prediction.
|
||||
return Predictor(boosting_.get(), is_raw_score, is_predict_leaf);
|
||||
}
|
||||
|
||||
void GetPredictAt(int data_idx, score_t* out_result, int64_t* out_len) {
|
||||
boosting_->GetPredictAt(data_idx, out_result, out_len);
|
||||
}
|
||||
|
||||
std::vector<double> Predict(const std::vector<std::pair<int, double>>& features) {
|
||||
return predictor_->GetPredictFunction()(features);
|
||||
}
|
||||
|
||||
void PredictForFile(const char* data_filename, const char* result_filename, bool data_has_header) {
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
predictor_->Predict(data_filename, result_filename, data_has_header);
|
||||
}
|
||||
|
||||
void SaveModelToFile(int num_iteration, const char* filename) {
|
||||
boosting_->SaveModelToFile(num_iteration, filename);
|
||||
}
|
||||
|
@ -232,8 +225,6 @@ private:
|
|||
std::vector<std::vector<std::unique_ptr<Metric>>> valid_metrics_;
|
||||
/*! \brief Training objective function */
|
||||
std::unique_ptr<ObjectiveFunction> objective_fun_;
|
||||
/*! \brief Using predictor for prediction task */
|
||||
std::unique_ptr<Predictor> predictor_;
|
||||
/*! \brief mutex for threading safe call */
|
||||
std::mutex mutex_;
|
||||
};
|
||||
|
@ -692,9 +683,9 @@ DllExport int LGBM_BoosterPredictForFile(BoosterHandle handle,
|
|||
const char* result_filename) {
|
||||
API_BEGIN();
|
||||
Booster* ref_booster = reinterpret_cast<Booster*>(handle);
|
||||
ref_booster->PrepareForPrediction(static_cast<int>(num_iteration), predict_type);
|
||||
auto predictor = ref_booster->NewPredictor(static_cast<int>(num_iteration), predict_type);
|
||||
bool bool_data_has_header = data_has_header > 0 ? true : false;
|
||||
ref_booster->PredictForFile(data_filename, result_filename, bool_data_has_header);
|
||||
predictor.Predict(data_filename, result_filename, bool_data_has_header);
|
||||
API_END();
|
||||
}
|
||||
|
||||
|
@ -713,8 +704,7 @@ DllExport int LGBM_BoosterPredictForCSR(BoosterHandle handle,
|
|||
float* out_result) {
|
||||
API_BEGIN();
|
||||
Booster* ref_booster = reinterpret_cast<Booster*>(handle);
|
||||
ref_booster->PrepareForPrediction(static_cast<int>(num_iteration), predict_type);
|
||||
|
||||
auto predictor = ref_booster->NewPredictor(static_cast<int>(num_iteration), predict_type);
|
||||
auto get_row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
|
||||
int num_preb_in_one_row = ref_booster->GetBoosting()->NumberOfClasses();
|
||||
if (predict_type == C_API_PREDICT_LEAF_INDEX) {
|
||||
|
@ -728,7 +718,7 @@ DllExport int LGBM_BoosterPredictForCSR(BoosterHandle handle,
|
|||
#pragma omp parallel for schedule(guided)
|
||||
for (int i = 0; i < nrow; ++i) {
|
||||
auto one_row = get_row_fun(i);
|
||||
auto predicton_result = ref_booster->Predict(one_row);
|
||||
auto predicton_result = predictor.GetPredictFunction()(one_row);
|
||||
for (int j = 0; j < static_cast<int>(predicton_result.size()); ++j) {
|
||||
out_result[i * num_preb_in_one_row + j] = static_cast<float>(predicton_result[j]);
|
||||
}
|
||||
|
@ -749,8 +739,7 @@ DllExport int LGBM_BoosterPredictForMat(BoosterHandle handle,
|
|||
float* out_result) {
|
||||
API_BEGIN();
|
||||
Booster* ref_booster = reinterpret_cast<Booster*>(handle);
|
||||
ref_booster->PrepareForPrediction(static_cast<int>(num_iteration), predict_type);
|
||||
|
||||
auto predictor = ref_booster->NewPredictor(static_cast<int>(num_iteration), predict_type);
|
||||
auto get_row_fun = RowPairFunctionFromDenseMatric(data, nrow, ncol, data_type, is_row_major);
|
||||
int num_preb_in_one_row = ref_booster->GetBoosting()->NumberOfClasses();
|
||||
if (predict_type == C_API_PREDICT_LEAF_INDEX) {
|
||||
|
@ -763,7 +752,7 @@ DllExport int LGBM_BoosterPredictForMat(BoosterHandle handle,
|
|||
#pragma omp parallel for schedule(guided)
|
||||
for (int i = 0; i < nrow; ++i) {
|
||||
auto one_row = get_row_fun(i);
|
||||
auto predicton_result = ref_booster->Predict(one_row);
|
||||
auto predicton_result = predictor.GetPredictFunction()(one_row);
|
||||
for (int j = 0; j < static_cast<int>(predicton_result.size()); ++j) {
|
||||
out_result[i * num_preb_in_one_row + j] = static_cast<float>(predicton_result[j]);
|
||||
}
|
||||
|
|
Загрузка…
Ссылка в новой задаче