зеркало из https://github.com/microsoft/LightGBM.git
Support early stopping of prediction in CLI (#565)
* fix multi-threading. * fix name style. * support in CLI version. * remove warnings. * Not default parameters. * fix if...else... . * fix bug. * fix warning. * refine c_api. * fix R-package. * fix R's warning. * fix tests. * fix pep8 .
This commit is contained in:
Родитель
e04a8bb4de
Коммит
6d4c7b03b7
|
@ -54,7 +54,7 @@ if(USE_GPU)
|
|||
endif(USE_GPU)
|
||||
|
||||
if(UNIX OR MINGW OR CYGWIN)
|
||||
SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -pthread -O3 -Wextra -Wall -std=c++11 -Wno-ignored-attributes")
|
||||
SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11 -pthread -O3 -Wextra -Wall -Wno-ignored-attributes -Wno-unknown-pragmas")
|
||||
endif()
|
||||
|
||||
if(MSVC)
|
||||
|
|
|
@ -391,7 +391,7 @@ Booster <- R6Class(
|
|||
rawscore = FALSE,
|
||||
predleaf = FALSE,
|
||||
header = FALSE,
|
||||
reshape = FALSE) {
|
||||
reshape = FALSE, ...) {
|
||||
|
||||
# Check if number of iteration is non existent
|
||||
if (is.null(num_iteration)) {
|
||||
|
@ -399,7 +399,7 @@ Booster <- R6Class(
|
|||
}
|
||||
|
||||
# Predict on new data
|
||||
predictor <- Predictor$new(private$handle)
|
||||
predictor <- Predictor$new(private$handle, ...)
|
||||
predictor$predict(data, num_iteration, rawscore, predleaf, header, reshape)
|
||||
|
||||
},
|
||||
|
@ -645,7 +645,7 @@ predict.lgb.Booster <- function(object, data,
|
|||
rawscore = FALSE,
|
||||
predleaf = FALSE,
|
||||
header = FALSE,
|
||||
reshape = FALSE) {
|
||||
reshape = FALSE, ...) {
|
||||
|
||||
# Check booster existence
|
||||
if (!lgb.is.Booster(object)) {
|
||||
|
@ -658,7 +658,7 @@ predict.lgb.Booster <- function(object, data,
|
|||
rawscore,
|
||||
predleaf,
|
||||
header,
|
||||
reshape)
|
||||
reshape, ...)
|
||||
}
|
||||
|
||||
#' Load LightGBM model
|
||||
|
|
|
@ -18,8 +18,9 @@ Predictor <- R6Class(
|
|||
},
|
||||
|
||||
# Initialize will create a starter model
|
||||
initialize = function(modelfile) {
|
||||
|
||||
initialize = function(modelfile, ...) {
|
||||
params <- list(...)
|
||||
private$params <- lgb.params2str(params)
|
||||
# Create new lgb handle
|
||||
handle <- lgb.new.handle()
|
||||
|
||||
|
@ -86,6 +87,7 @@ Predictor <- R6Class(
|
|||
as.integer(rawscore),
|
||||
as.integer(predleaf),
|
||||
as.integer(num_iteration),
|
||||
private$params,
|
||||
lgb.c_str(tmp_filename))
|
||||
|
||||
# Get predictions from file
|
||||
|
@ -121,7 +123,8 @@ Predictor <- R6Class(
|
|||
as.integer(ncol(data)),
|
||||
as.integer(rawscore),
|
||||
as.integer(predleaf),
|
||||
as.integer(num_iteration))
|
||||
as.integer(num_iteration),
|
||||
private$params)
|
||||
|
||||
} else if (is(data, "dgCMatrix")) {
|
||||
|
||||
|
@ -137,7 +140,8 @@ Predictor <- R6Class(
|
|||
nrow(data),
|
||||
as.integer(rawscore),
|
||||
as.integer(predleaf),
|
||||
as.integer(num_iteration))
|
||||
as.integer(num_iteration),
|
||||
private$params)
|
||||
|
||||
} else {
|
||||
|
||||
|
@ -178,5 +182,6 @@ Predictor <- R6Class(
|
|||
|
||||
),
|
||||
private = list(handle = NULL,
|
||||
need_free_handle = FALSE)
|
||||
need_free_handle = FALSE,
|
||||
params = "")
|
||||
)
|
||||
|
|
|
@ -5,6 +5,7 @@
|
|||
#include "../../src/boosting/boosting.cpp"
|
||||
#include "../../src/boosting/gbdt.cpp"
|
||||
#include "../../src/boosting/gbdt_prediction.cpp"
|
||||
#include "../../src/boosting/prediction_early_stop.cpp"
|
||||
|
||||
// io
|
||||
#include "../../src/io/bin.cpp"
|
||||
|
|
|
@ -5,6 +5,7 @@
|
|||
#include "./src/boosting/boosting.cpp"
|
||||
#include "./src/boosting/gbdt.cpp"
|
||||
#include "./src/boosting/gbdt_prediction.cpp"
|
||||
#include "./src/boosting/prediction_early_stop.cpp"
|
||||
|
||||
// io
|
||||
#include "./src/io/bin.cpp"
|
||||
|
|
|
@ -498,12 +498,13 @@ LGBM_SE LGBM_BoosterPredictForFile_R(LGBM_SE handle,
|
|||
LGBM_SE is_rawscore,
|
||||
LGBM_SE is_leafidx,
|
||||
LGBM_SE num_iteration,
|
||||
LGBM_SE parameter,
|
||||
LGBM_SE result_filename,
|
||||
LGBM_SE call_state) {
|
||||
R_API_BEGIN();
|
||||
int pred_type = GetPredictType(is_rawscore, is_leafidx);
|
||||
CHECK_CALL(LGBM_BoosterPredictForFile(R_GET_PTR(handle), R_CHAR_PTR(data_filename),
|
||||
R_AS_INT(data_has_header), pred_type, R_AS_INT(num_iteration),
|
||||
R_AS_INT(data_has_header), pred_type, R_AS_INT(num_iteration), R_CHAR_PTR(parameter),
|
||||
R_CHAR_PTR(result_filename)));
|
||||
R_API_END();
|
||||
}
|
||||
|
@ -534,6 +535,7 @@ LGBM_SE LGBM_BoosterPredictForCSC_R(LGBM_SE handle,
|
|||
LGBM_SE is_rawscore,
|
||||
LGBM_SE is_leafidx,
|
||||
LGBM_SE num_iteration,
|
||||
LGBM_SE parameter,
|
||||
LGBM_SE out_result,
|
||||
LGBM_SE call_state) {
|
||||
|
||||
|
@ -552,7 +554,7 @@ LGBM_SE LGBM_BoosterPredictForCSC_R(LGBM_SE handle,
|
|||
CHECK_CALL(LGBM_BoosterPredictForCSC(R_GET_PTR(handle),
|
||||
p_indptr, C_API_DTYPE_INT32, p_indices,
|
||||
p_data, C_API_DTYPE_FLOAT64, nindptr, ndata,
|
||||
nrow, pred_type, R_AS_INT(num_iteration), &out_len, ptr_ret));
|
||||
nrow, pred_type, R_AS_INT(num_iteration), R_CHAR_PTR(parameter), &out_len, ptr_ret));
|
||||
R_API_END();
|
||||
}
|
||||
|
||||
|
@ -563,6 +565,7 @@ LGBM_SE LGBM_BoosterPredictForMat_R(LGBM_SE handle,
|
|||
LGBM_SE is_rawscore,
|
||||
LGBM_SE is_leafidx,
|
||||
LGBM_SE num_iteration,
|
||||
LGBM_SE parameter,
|
||||
LGBM_SE out_result,
|
||||
LGBM_SE call_state) {
|
||||
|
||||
|
@ -577,7 +580,7 @@ LGBM_SE LGBM_BoosterPredictForMat_R(LGBM_SE handle,
|
|||
int64_t out_len;
|
||||
CHECK_CALL(LGBM_BoosterPredictForMat(R_GET_PTR(handle),
|
||||
p_mat, C_API_DTYPE_FLOAT64, nrow, ncol, COL_MAJOR,
|
||||
pred_type, R_AS_INT(num_iteration), &out_len, ptr_ret));
|
||||
pred_type, R_AS_INT(num_iteration), R_CHAR_PTR(parameter), &out_len, ptr_ret));
|
||||
|
||||
R_API_END();
|
||||
}
|
||||
|
|
|
@ -389,6 +389,7 @@ LIGHTGBM_C_EXPORT LGBM_SE LGBM_BoosterPredictForFile_R(LGBM_SE handle,
|
|||
LGBM_SE is_rawscore,
|
||||
LGBM_SE is_leafidx,
|
||||
LGBM_SE num_iteration,
|
||||
LGBM_SE parameter,
|
||||
LGBM_SE result_filename,
|
||||
LGBM_SE call_state);
|
||||
|
||||
|
@ -438,6 +439,7 @@ LIGHTGBM_C_EXPORT LGBM_SE LGBM_BoosterPredictForCSC_R(LGBM_SE handle,
|
|||
LGBM_SE is_rawscore,
|
||||
LGBM_SE is_leafidx,
|
||||
LGBM_SE num_iteration,
|
||||
LGBM_SE parameter,
|
||||
LGBM_SE out_result,
|
||||
LGBM_SE call_state);
|
||||
|
||||
|
@ -463,6 +465,7 @@ LIGHTGBM_C_EXPORT LGBM_SE LGBM_BoosterPredictForMat_R(LGBM_SE handle,
|
|||
LGBM_SE is_rawscore,
|
||||
LGBM_SE is_leafidx,
|
||||
LGBM_SE num_iteration,
|
||||
LGBM_SE parameter,
|
||||
LGBM_SE out_result,
|
||||
LGBM_SE call_state);
|
||||
|
||||
|
|
|
@ -192,6 +192,12 @@ The parameter format is `key1=value1 key2=value2 ... ` . And parameters can be s
|
|||
* `num_iteration_predict`, default=`-1`, type=int
|
||||
* only used in prediction task, used to how many trained iterations will be used in prediction.
|
||||
* `<= 0` means no limit
|
||||
* `pred_early_stop`, default=`false`, type=bool
|
||||
* Set to `true` will use early-stopping to speed up the prediction. May affect the accuracy.
|
||||
* `pred_early_stop_freq`, default=`10`, type=int
|
||||
* The frequency of checking early-stopping prediction.
|
||||
* `pred_early_stop_margin`, default=`10.0`, type=double
|
||||
* The Threshold of margin in early-stopping prediction.
|
||||
* `use_missing`, default=`true`, type=bool
|
||||
* Set to `false` will disbale the special handle of missing value.
|
||||
|
||||
|
|
|
@ -117,19 +117,19 @@ public:
|
|||
* \brief Prediction for one record, not sigmoid transform
|
||||
* \param feature_values Feature value on this record
|
||||
* \param output Prediction result for this record
|
||||
* \param earlyStop Early stopping instance. If nullptr, no early stopping is applied and all trees are evaluated.
|
||||
* \param early_stop Early stopping instance. If nullptr, no early stopping is applied and all trees are evaluated.
|
||||
*/
|
||||
virtual void PredictRaw(const double* features, double* output,
|
||||
const PredictionEarlyStopInstance* earlyStop = nullptr) const = 0;
|
||||
const PredictionEarlyStopInstance* early_stop) const = 0;
|
||||
|
||||
/*!
|
||||
* \brief Prediction for one record, sigmoid transformation will be used if needed
|
||||
* \param feature_values Feature value on this record
|
||||
* \param output Prediction result for this record
|
||||
* \param earlyStop Early stopping instance. If nullptr, no early stopping is applied and all trees are evaluated.
|
||||
* \param early_stop Early stopping instance. If nullptr, no early stopping is applied and all trees are evaluated.
|
||||
*/
|
||||
virtual void Predict(const double* features, double* output,
|
||||
const PredictionEarlyStopInstance* earlyStop = nullptr) const = 0;
|
||||
const PredictionEarlyStopInstance* early_stop) const = 0;
|
||||
|
||||
/*!
|
||||
* \brief Prediction for one record with leaf index
|
||||
|
@ -220,6 +220,9 @@ public:
|
|||
*/
|
||||
virtual int NumberOfClasses() const = 0;
|
||||
|
||||
/*! \brief The prediction should be accurate or not. True will disable early stopping for prediction. */
|
||||
virtual bool NeedAccuratePrediction() const = 0;
|
||||
|
||||
/*!
|
||||
* \brief Initial work for the prediction
|
||||
* \param num_iteration number of used iteration
|
||||
|
|
|
@ -18,7 +18,6 @@
|
|||
|
||||
typedef void* DatasetHandle;
|
||||
typedef void* BoosterHandle;
|
||||
typedef void* PredictionEarlyStoppingHandle;
|
||||
|
||||
#define C_API_DTYPE_FLOAT32 (0)
|
||||
#define C_API_DTYPE_FLOAT64 (1)
|
||||
|
@ -522,7 +521,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForFile(BoosterHandle handle,
|
|||
int data_has_header,
|
||||
int predict_type,
|
||||
int num_iteration,
|
||||
const PredictionEarlyStoppingHandle early_stop_handle,
|
||||
const char* parameter,
|
||||
const char* result_filename);
|
||||
|
||||
/*!
|
||||
|
@ -578,7 +577,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForCSR(BoosterHandle handle,
|
|||
int64_t num_col,
|
||||
int predict_type,
|
||||
int num_iteration,
|
||||
const PredictionEarlyStoppingHandle early_stop_handle,
|
||||
const char* parameter,
|
||||
int64_t* out_len,
|
||||
double* out_result);
|
||||
|
||||
|
@ -617,7 +616,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForCSC(BoosterHandle handle,
|
|||
int64_t num_row,
|
||||
int predict_type,
|
||||
int num_iteration,
|
||||
const PredictionEarlyStoppingHandle early_stop_handle,
|
||||
const char* parameter,
|
||||
int64_t* out_len,
|
||||
double* out_result);
|
||||
|
||||
|
@ -650,7 +649,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForMat(BoosterHandle handle,
|
|||
int is_row_major,
|
||||
int predict_type,
|
||||
int num_iteration,
|
||||
const PredictionEarlyStoppingHandle early_stop_handle,
|
||||
const char* parameter,
|
||||
int64_t* out_len,
|
||||
double* out_result);
|
||||
|
||||
|
@ -721,25 +720,6 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterSetLeafValue(BoosterHandle handle,
|
|||
int leaf_idx,
|
||||
double val);
|
||||
|
||||
|
||||
/*!
|
||||
* \brief create an new prediction early stopping instance that can be used to speed up prediction
|
||||
* \param type early stopping type: "none", "multiclass" or "binary"
|
||||
* \param round_period how often the classifier score is checked for the early stopping condition
|
||||
* \param margin_threshold when the margin exceeds this value, early stopping kicks in and no more trees are evaluated
|
||||
* \param out handle of created instance
|
||||
* \return 0 when succeed, -1 when failure happens
|
||||
*/
|
||||
LIGHTGBM_C_EXPORT int LGBM_PredictionEarlyStopInstanceCreate(const char* type,
|
||||
int round_period,
|
||||
double margin_threshold,
|
||||
PredictionEarlyStoppingHandle* out);
|
||||
/*!
|
||||
\brief free prediction early stop instance
|
||||
\return 0 when succeed
|
||||
*/
|
||||
LIGHTGBM_C_EXPORT int LGBM_PredictionEarlyStopInstanceFree(const PredictionEarlyStoppingHandle handle);
|
||||
|
||||
#if defined(_MSC_VER)
|
||||
// exception handle and error msg
|
||||
static char* LastErrorMsg() { static __declspec(thread) char err_msg[512] = "Everything is fine"; return err_msg; }
|
||||
|
@ -747,6 +727,7 @@ static char* LastErrorMsg() { static __declspec(thread) char err_msg[512] = "Eve
|
|||
static char* LastErrorMsg() { static thread_local char err_msg[512] = "Everything is fine"; return err_msg; }
|
||||
#endif
|
||||
|
||||
#pragma warning(disable : 4996)
|
||||
inline void LGBM_SetLastError(const char* msg) {
|
||||
std::strcpy(LastErrorMsg(), msg);
|
||||
}
|
||||
|
|
|
@ -135,6 +135,14 @@ public:
|
|||
* Note: when using Index, it doesn't count the label index */
|
||||
std::string categorical_column = "";
|
||||
std::string device_type = "cpu";
|
||||
|
||||
/*! \brief Set to true if want to use early stop for the prediction */
|
||||
bool pred_early_stop = false;
|
||||
/*! \brief Frequency of checking the pred_early_stop */
|
||||
int pred_early_stop_freq = 10;
|
||||
/*! \brief Threshold of margin of pred_early_stop */
|
||||
double pred_early_stop_margin = 10.0f;
|
||||
|
||||
LIGHTGBM_EXPORT void Set(const std::unordered_map<std::string, std::string>& params) override;
|
||||
private:
|
||||
void GetDeviceType(const std::unordered_map<std::string,
|
||||
|
|
|
@ -43,6 +43,9 @@ public:
|
|||
|
||||
virtual int NumPredictOneRow() const { return 1; }
|
||||
|
||||
/*! \brief The prediction should be accurate or not. True will disable early stopping for prediction. */
|
||||
virtual bool NeedAccuratePrediction() const { return true; }
|
||||
|
||||
virtual void ConvertOutput(const double* input, double* output) const {
|
||||
output[0] = input[0];
|
||||
}
|
||||
|
|
|
@ -6,28 +6,28 @@
|
|||
|
||||
#include <LightGBM/export.h>
|
||||
|
||||
namespace LightGBM
|
||||
{
|
||||
struct PredictionEarlyStopInstance
|
||||
{
|
||||
/// Callback function type for early stopping.
|
||||
/// Takes current prediction and number of elements in prediction
|
||||
/// @returns true if prediction should stop according to criterion
|
||||
using FunctionType = std::function<bool(const double*, int)>;
|
||||
namespace LightGBM {
|
||||
|
||||
FunctionType callbackFunction; // callback function itself
|
||||
int roundPeriod; // call callbackFunction every `runPeriod` iterations
|
||||
};
|
||||
#pragma warning(disable : 4099)
|
||||
struct PredictionEarlyStopInstance {
|
||||
/// Callback function type for early stopping.
|
||||
/// Takes current prediction and number of elements in prediction
|
||||
/// @returns true if prediction should stop according to criterion
|
||||
using FunctionType = std::function<bool(const double*, int)>;
|
||||
|
||||
struct PredictionEarlyStopConfig
|
||||
{
|
||||
int roundPeriod;
|
||||
double marginThreshold;
|
||||
};
|
||||
FunctionType callback_function; // callback function itself
|
||||
int round_period; // call callback_function every `runPeriod` iterations
|
||||
};
|
||||
|
||||
/// Create an early stopping algorithm of type `type`, with given roundPeriod and margin threshold
|
||||
LIGHTGBM_EXPORT PredictionEarlyStopInstance createPredictionEarlyStopInstance(const std::string& type,
|
||||
const PredictionEarlyStopConfig& config);
|
||||
#pragma warning(disable : 4099)
|
||||
struct PredictionEarlyStopConfig {
|
||||
int round_period;
|
||||
double margin_threshold;
|
||||
};
|
||||
|
||||
/// Create an early stopping algorithm of type `type`, with given round_period and margin threshold
|
||||
LIGHTGBM_EXPORT PredictionEarlyStopInstance CreatePredictionEarlyStopInstance(const std::string& type,
|
||||
const PredictionEarlyStopConfig& config);
|
||||
|
||||
} // namespace LightGBM
|
||||
|
||||
|
|
|
@ -508,7 +508,7 @@ static void ParallelSort(_RanIt _First, _RanIt _Last, _Pr _Pred, _VTRanIt*) {
|
|||
// Buffer for merge.
|
||||
std::vector<_VTRanIt> temp_buf(len);
|
||||
_RanIt buf = temp_buf.begin();
|
||||
int s = inner_size;
|
||||
size_t s = inner_size;
|
||||
// Recursive merge
|
||||
while (s < len) {
|
||||
int loop_size = static_cast<int>((len + s * 2 - 1) / (s * 2));
|
||||
|
|
|
@ -6,7 +6,7 @@ Contributors: https://github.com/Microsoft/LightGBM/graphs/contributors
|
|||
|
||||
from __future__ import absolute_import
|
||||
|
||||
from .basic import Booster, Dataset, PredictionEarlyStopInstance
|
||||
from .basic import Booster, Dataset
|
||||
from .callback import (early_stopping, print_evaluation, record_evaluation,
|
||||
reset_parameter)
|
||||
from .engine import cv, train
|
||||
|
@ -23,7 +23,7 @@ except ImportError:
|
|||
|
||||
__version__ = 0.2
|
||||
|
||||
__all__ = ['Dataset', 'Booster', 'PredictionEarlyStopInstance',
|
||||
__all__ = ['Dataset', 'Booster',
|
||||
'train', 'cv',
|
||||
'LGBMModel', 'LGBMRegressor', 'LGBMClassifier', 'LGBMRanker',
|
||||
'print_evaluation', 'record_evaluation', 'reset_parameter', 'early_stopping',
|
||||
|
|
|
@ -296,7 +296,7 @@ class _InnerPredictor(object):
|
|||
Only used for prediction, usually used for continued-train
|
||||
Note: Can convert from Booster, but cannot convert to Booster
|
||||
"""
|
||||
def __init__(self, model_file=None, booster_handle=None, early_stop_instance=None):
|
||||
def __init__(self, model_file=None, booster_handle=None, pred_parameter=None):
|
||||
"""Initialize the _InnerPredictor. Not expose to user
|
||||
|
||||
Parameters
|
||||
|
@ -305,8 +305,8 @@ class _InnerPredictor(object):
|
|||
Path to the model file.
|
||||
booster_handle : Handle of Booster
|
||||
use handle to init
|
||||
early_stop_instance: object of type PredictionEarlyStopInstance
|
||||
If None, no early stopping is applied
|
||||
pred_parameter: dict
|
||||
Other parameters for the prediciton
|
||||
"""
|
||||
self.handle = ctypes.c_void_p()
|
||||
self.__is_manage_handle = True
|
||||
|
@ -341,10 +341,8 @@ class _InnerPredictor(object):
|
|||
else:
|
||||
raise TypeError('Need Model file or Booster handle to create a predictor')
|
||||
|
||||
if early_stop_instance is None:
|
||||
self.early_stop_instance = PredictionEarlyStopInstance("none")
|
||||
else:
|
||||
self.early_stop_instance = early_stop_instance
|
||||
pred_parameter = {} if pred_parameter is None else pred_parameter
|
||||
self.pred_parameter = param_dict_to_str(pred_parameter)
|
||||
|
||||
def __del__(self):
|
||||
if self.__is_manage_handle:
|
||||
|
@ -401,7 +399,7 @@ class _InnerPredictor(object):
|
|||
ctypes.c_int(int_data_has_header),
|
||||
ctypes.c_int(predict_type),
|
||||
ctypes.c_int(num_iteration),
|
||||
self.early_stop_instance.handle,
|
||||
c_str(self.pred_parameter),
|
||||
c_str(f.name)))
|
||||
lines = f.readlines()
|
||||
nrow = len(lines)
|
||||
|
@ -475,7 +473,7 @@ class _InnerPredictor(object):
|
|||
ctypes.c_int(C_API_IS_ROW_MAJOR),
|
||||
ctypes.c_int(predict_type),
|
||||
ctypes.c_int(num_iteration),
|
||||
self.early_stop_instance.handle,
|
||||
c_str(self.pred_parameter),
|
||||
ctypes.byref(out_num_preds),
|
||||
preds.ctypes.data_as(ctypes.POINTER(ctypes.c_double))))
|
||||
if n_preds != out_num_preds.value:
|
||||
|
@ -506,7 +504,7 @@ class _InnerPredictor(object):
|
|||
ctypes.c_int64(csr.shape[1]),
|
||||
ctypes.c_int(predict_type),
|
||||
ctypes.c_int(num_iteration),
|
||||
self.early_stop_instance.handle,
|
||||
c_str(self.pred_parameter),
|
||||
ctypes.byref(out_num_preds),
|
||||
preds.ctypes.data_as(ctypes.POINTER(ctypes.c_double))))
|
||||
if n_preds != out_num_preds.value:
|
||||
|
@ -537,7 +535,7 @@ class _InnerPredictor(object):
|
|||
ctypes.c_int64(csc.shape[0]),
|
||||
ctypes.c_int(predict_type),
|
||||
ctypes.c_int(num_iteration),
|
||||
self.early_stop_instance.handle,
|
||||
c_str(self.pred_parameter),
|
||||
ctypes.byref(out_num_preds),
|
||||
preds.ctypes.data_as(ctypes.POINTER(ctypes.c_double))))
|
||||
if n_preds != out_num_preds.value:
|
||||
|
@ -1581,7 +1579,7 @@ class Booster(object):
|
|||
return json.loads(string_buffer.value.decode())
|
||||
|
||||
def predict(self, data, num_iteration=-1, raw_score=False, pred_leaf=False, data_has_header=False, is_reshape=True,
|
||||
early_stop_instance=None):
|
||||
pred_parameter=None):
|
||||
"""
|
||||
Predict logic
|
||||
|
||||
|
@ -1600,21 +1598,21 @@ class Booster(object):
|
|||
Used for txt data
|
||||
is_reshape : bool
|
||||
Reshape to (nrow, ncol) if true
|
||||
early_stop_instance: object of type PredictionEarlyStopInstance.
|
||||
If None, no early stopping is applied
|
||||
pred_parameter: dict
|
||||
Other parameters for the prediction
|
||||
|
||||
Returns
|
||||
-------
|
||||
Prediction result
|
||||
"""
|
||||
predictor = self._to_predictor(early_stop_instance)
|
||||
predictor = self._to_predictor(pred_parameter)
|
||||
if num_iteration <= 0:
|
||||
num_iteration = self.best_iteration
|
||||
return predictor.predict(data, num_iteration, raw_score, pred_leaf, data_has_header, is_reshape)
|
||||
|
||||
def _to_predictor(self, early_stop_instance=None):
|
||||
def _to_predictor(self, pred_parameter=None):
|
||||
"""Convert to predictor"""
|
||||
predictor = _InnerPredictor(booster_handle=self.handle, early_stop_instance=early_stop_instance)
|
||||
predictor = _InnerPredictor(booster_handle=self.handle, pred_parameter=pred_parameter)
|
||||
predictor.pandas_categorical = self.pandas_categorical
|
||||
return predictor
|
||||
|
||||
|
@ -1800,35 +1798,3 @@ class Booster(object):
|
|||
self.__attr[key] = value
|
||||
else:
|
||||
self.__attr.pop(key, None)
|
||||
|
||||
|
||||
class PredictionEarlyStopInstance(object):
|
||||
""""PredictionEarlyStopInstance in LightGBM."""
|
||||
def __init__(self, early_stop_type="none", round_period=20, margin_threshold=1.5):
|
||||
"""
|
||||
Create an early stopping object
|
||||
|
||||
Parameters
|
||||
----------
|
||||
early_stop_type: string
|
||||
None, "none", "binary" or "multiclass". Regression is not supported.
|
||||
round_period : int
|
||||
The score will be checked every round_period to check if the early stopping criteria is met
|
||||
margin_threshold : double
|
||||
Early stopping will kick in when the margin is greater than margin_threshold
|
||||
"""
|
||||
self.handle = ctypes.c_void_p(0)
|
||||
self.__attr = {}
|
||||
|
||||
if early_stop_type is None:
|
||||
early_stop_type = "none"
|
||||
|
||||
_safe_call(_LIB.LGBM_PredictionEarlyStopInstanceCreate(
|
||||
c_str(early_stop_type),
|
||||
ctypes.c_int(round_period),
|
||||
ctypes.c_double(margin_threshold),
|
||||
ctypes.byref(self.handle)))
|
||||
|
||||
def __del__(self):
|
||||
if self.handle is not None:
|
||||
_safe_call(_LIB.LGBM_PredictionEarlyStopInstanceFree(self.handle))
|
||||
|
|
|
@ -8,6 +8,7 @@
|
|||
#include <LightGBM/dataset_loader.h>
|
||||
#include <LightGBM/boosting.h>
|
||||
#include <LightGBM/objective_function.h>
|
||||
#include <LightGBM/prediction_early_stop.h>
|
||||
#include <LightGBM/metric.h>
|
||||
|
||||
#include "predictor.hpp"
|
||||
|
@ -107,9 +108,10 @@ void Application::LoadData() {
|
|||
std::unique_ptr<Predictor> predictor;
|
||||
// prediction is needed if using input initial model(continued train)
|
||||
PredictFunction predict_fun = nullptr;
|
||||
PredictionEarlyStopInstance pred_early_stop = CreatePredictionEarlyStopInstance("none", LightGBM::PredictionEarlyStopConfig());
|
||||
// need to continue training
|
||||
if (boosting_->NumberOfTotalModel() > 0) {
|
||||
predictor.reset(new Predictor(boosting_.get(), -1, true, false));
|
||||
predictor.reset(new Predictor(boosting_.get(), -1, true, false, false, -1, -1));
|
||||
predict_fun = predictor->GetPredictFunction();
|
||||
}
|
||||
|
||||
|
@ -250,7 +252,8 @@ void Application::Train() {
|
|||
void Application::Predict() {
|
||||
// create predictor
|
||||
Predictor predictor(boosting_.get(), config_.io_config.num_iteration_predict, config_.io_config.is_predict_raw_score,
|
||||
config_.io_config.is_predict_leaf_index);
|
||||
config_.io_config.is_predict_leaf_index, config_.io_config.pred_early_stop,
|
||||
config_.io_config.pred_early_stop_freq, config_.io_config.pred_early_stop_margin);
|
||||
predictor.Predict(config_.io_config.data_filename.c_str(),
|
||||
config_.io_config.output_result.c_str(), config_.io_config.has_header);
|
||||
Log::Info("Finished prediction");
|
||||
|
|
|
@ -32,7 +32,20 @@ public:
|
|||
*/
|
||||
Predictor(Boosting* boosting, int num_iteration,
|
||||
bool is_raw_score, bool is_predict_leaf_index,
|
||||
const PredictionEarlyStopInstance* earlyStop = nullptr) {
|
||||
bool early_stop, int early_stop_freq, double early_stop_margin) {
|
||||
|
||||
early_stop_ = CreatePredictionEarlyStopInstance("none", LightGBM::PredictionEarlyStopConfig());
|
||||
if (early_stop && !boosting->NeedAccuratePrediction()) {
|
||||
PredictionEarlyStopConfig pred_early_stop_config;
|
||||
pred_early_stop_config.margin_threshold = early_stop_margin;
|
||||
pred_early_stop_config.round_period = early_stop_freq;
|
||||
if (boosting->NumberOfClasses() == 1) {
|
||||
early_stop_ = CreatePredictionEarlyStopInstance("binary", pred_early_stop_config);
|
||||
} else {
|
||||
early_stop_ = CreatePredictionEarlyStopInstance("multiclass", pred_early_stop_config);
|
||||
}
|
||||
}
|
||||
|
||||
#pragma omp parallel
|
||||
#pragma omp master
|
||||
{
|
||||
|
@ -55,17 +68,17 @@ public:
|
|||
|
||||
} else {
|
||||
if (is_raw_score) {
|
||||
predict_fun_ = [this, earlyStop](const std::vector<std::pair<int, double>>& features, double* output) {
|
||||
predict_fun_ = [this](const std::vector<std::pair<int, double>>& features, double* output) {
|
||||
int tid = omp_get_thread_num();
|
||||
CopyToPredictBuffer(predict_buf_[tid].data(), features);
|
||||
boosting_->PredictRaw(predict_buf_[tid].data(), output, earlyStop);
|
||||
boosting_->PredictRaw(predict_buf_[tid].data(), output, &early_stop_);
|
||||
ClearPredictBuffer(predict_buf_[tid].data(), predict_buf_[tid].size(), features);
|
||||
};
|
||||
} else {
|
||||
predict_fun_ = [this, earlyStop](const std::vector<std::pair<int, double>>& features, double* output) {
|
||||
predict_fun_ = [this](const std::vector<std::pair<int, double>>& features, double* output) {
|
||||
int tid = omp_get_thread_num();
|
||||
CopyToPredictBuffer(predict_buf_[tid].data(), features);
|
||||
boosting_->Predict(predict_buf_[tid].data(), output, earlyStop);
|
||||
boosting_->Predict(predict_buf_[tid].data(), output, &early_stop_);
|
||||
ClearPredictBuffer(predict_buf_[tid].data(), predict_buf_[tid].size(), features);
|
||||
};
|
||||
}
|
||||
|
@ -117,7 +130,11 @@ public:
|
|||
[this, &parser_fun, &result_file]
|
||||
(data_size_t, const std::vector<std::string>& lines) {
|
||||
std::vector<std::pair<int, double>> oneline_features;
|
||||
std::vector<std::string> result_to_write(lines.size());
|
||||
OMP_INIT_EX();
|
||||
#pragma omp parallel for schedule(static) firstprivate(oneline_features)
|
||||
for (data_size_t i = 0; i < static_cast<data_size_t>(lines.size()); ++i) {
|
||||
OMP_LOOP_EX_BEGIN();
|
||||
oneline_features.clear();
|
||||
// parser
|
||||
parser_fun(lines[i].c_str(), &oneline_features);
|
||||
|
@ -125,7 +142,12 @@ public:
|
|||
std::vector<double> result(num_pred_one_row_);
|
||||
predict_fun_(oneline_features, result.data());
|
||||
auto str_result = Common::Join<double>(result, "\t");
|
||||
fprintf(result_file, "%s\n", str_result.c_str());
|
||||
result_to_write[i] = str_result;
|
||||
OMP_LOOP_EX_END();
|
||||
}
|
||||
OMP_THROW_EX();
|
||||
for (data_size_t i = 0; i < static_cast<data_size_t>(result_to_write.size()); ++i) {
|
||||
fprintf(result_file, "%s\n", result_to_write[i].c_str());
|
||||
}
|
||||
};
|
||||
TextReader<data_size_t> predict_data_reader(data_filename, has_header);
|
||||
|
@ -137,7 +159,6 @@ private:
|
|||
|
||||
void CopyToPredictBuffer(double* pred_buf, const std::vector<std::pair<int, double>>& features) {
|
||||
int loop_size = static_cast<int>(features.size());
|
||||
#pragma omp parallel for schedule(static,128) if (loop_size >= 256)
|
||||
for (int i = 0; i < loop_size; ++i) {
|
||||
if (features[i].first < num_feature_) {
|
||||
pred_buf[features[i].first] = features[i].second;
|
||||
|
@ -150,7 +171,6 @@ private:
|
|||
std::memset(pred_buf, 0, sizeof(double)*(buf_size));
|
||||
} else {
|
||||
int loop_size = static_cast<int>(features.size());
|
||||
#pragma omp parallel for schedule(static,128) if (loop_size >= 256)
|
||||
for (int i = 0; i < loop_size; ++i) {
|
||||
pred_buf[features[i].first] = 0.0f;
|
||||
}
|
||||
|
@ -161,6 +181,7 @@ private:
|
|||
const Boosting* boosting_;
|
||||
/*! \brief function for prediction */
|
||||
PredictFunction predict_fun_;
|
||||
PredictionEarlyStopInstance early_stop_;
|
||||
int num_feature_;
|
||||
int num_pred_one_row_;
|
||||
int num_threads_;
|
||||
|
|
|
@ -43,9 +43,9 @@ GBDT::GBDT()
|
|||
boost_from_average_(false) {
|
||||
#pragma omp parallel
|
||||
#pragma omp master
|
||||
{
|
||||
num_threads_ = omp_get_num_threads();
|
||||
}
|
||||
{
|
||||
num_threads_ = omp_get_num_threads();
|
||||
}
|
||||
}
|
||||
|
||||
GBDT::~GBDT() {
|
||||
|
@ -262,8 +262,6 @@ data_size_t GBDT::BaggingHelper(Random& cur_rand, data_size_t start, data_size_t
|
|||
return cur_left_cnt;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void GBDT::Bagging(int iter) {
|
||||
// if need bagging
|
||||
if (bag_data_cnt_ < num_data_ && iter % gbdt_config_->bagging_freq == 0) {
|
||||
|
@ -738,32 +736,27 @@ std::string GBDT::ModelToIfElse(int num_iteration) const {
|
|||
|
||||
std::stringstream pred_str_buf;
|
||||
|
||||
pred_str_buf << "\t" << "const auto noEarlyStop = createPredictionEarlyStopInstance(\"none\", PredictionEarlyStopConfig());" << std::endl;
|
||||
pred_str_buf << "\t" << "if (earlyStop == nullptr) {" << std::endl;
|
||||
pred_str_buf << "\t\t" << "earlyStop = &noEarlyStop;" << std::endl;
|
||||
pred_str_buf << "\t" << "}" << std::endl;
|
||||
|
||||
pred_str_buf << "\t" << "int earlyStopRoundCounter = 0;" << std::endl;
|
||||
pred_str_buf << "\t" << "int early_stop_round_counter = 0;" << std::endl;
|
||||
pred_str_buf << "\t" << "for (int i = 0; i < num_iteration_for_pred_; ++i) {" << std::endl;
|
||||
pred_str_buf << "\t\t" << "for (int k = 0; k < num_tree_per_iteration_; ++k) {" << std::endl;
|
||||
pred_str_buf << "\t\t\t" << "output[k] += (*PredictTreePtr[i * num_tree_per_iteration_ + k])(features);" << std::endl;
|
||||
pred_str_buf << "\t\t" << "}" << std::endl;
|
||||
pred_str_buf << "\t\t" << "++earlyStopRoundCounter;" << std::endl;
|
||||
pred_str_buf << "\t\t" << "if (earlyStop->roundPeriod == earlyStopRoundCounter) {" << std::endl;
|
||||
pred_str_buf << "\t\t\t" << "if (earlyStop->callbackFunction(output, num_tree_per_iteration_))" << std::endl;
|
||||
pred_str_buf << "\t\t" << "++early_stop_round_counter;" << std::endl;
|
||||
pred_str_buf << "\t\t" << "if (early_stop->round_period == early_stop_round_counter) {" << std::endl;
|
||||
pred_str_buf << "\t\t\t" << "if (early_stop->callback_function(output, num_tree_per_iteration_))" << std::endl;
|
||||
pred_str_buf << "\t\t\t\t" << "return;" << std::endl;
|
||||
pred_str_buf << "\t\t\t" << "earlyStopRoundCounter = 0;" << std::endl;
|
||||
pred_str_buf << "\t\t\t" << "early_stop_round_counter = 0;" << std::endl;
|
||||
pred_str_buf << "\t\t" << "}" << std::endl;
|
||||
pred_str_buf << "\t" << "}" << std::endl;
|
||||
|
||||
str_buf << "void GBDT::PredictRaw(const double* features, double *output, const PredictionEarlyStopInstance* earlyStop) const {" << std::endl;
|
||||
str_buf << "void GBDT::PredictRaw(const double* features, double *output, const PredictionEarlyStopInstance* early_stop) const {" << std::endl;
|
||||
str_buf << pred_str_buf.str();
|
||||
str_buf << "}" << std::endl;
|
||||
str_buf << std::endl;
|
||||
|
||||
// Predict
|
||||
str_buf << "void GBDT::Predict(const double* features, double *output, const PredictionEarlyStopInstance* earlyStop) const {" << std::endl;
|
||||
str_buf << "\t" << "PredictRaw(features, output, earlyStop);" << std::endl;
|
||||
str_buf << "void GBDT::Predict(const double* features, double *output, const PredictionEarlyStopInstance* early_stop) const {" << std::endl;
|
||||
str_buf << "\t" << "PredictRaw(features, output, early_stop);" << std::endl;
|
||||
str_buf << "\t" << "if (objective_function_ != nullptr) {" << std::endl;
|
||||
str_buf << "\t\t" << "objective_function_->ConvertOutput(output, output);" << std::endl;
|
||||
str_buf << "\t" << "}" << std::endl;
|
||||
|
@ -786,7 +779,6 @@ std::string GBDT::ModelToIfElse(int num_iteration) const {
|
|||
|
||||
str_buf << "void GBDT::PredictLeafIndex(const double* features, double *output) const {" << std::endl;
|
||||
str_buf << "\t" << "int total_tree = num_iteration_for_pred_ * num_tree_per_iteration_;" << std::endl;
|
||||
str_buf << "\t" << "#pragma omp parallel for schedule(static)" << std::endl;
|
||||
str_buf << "\t" << "for (int i = 0; i < total_tree; ++i) {" << std::endl;
|
||||
str_buf << "\t\t" << "output[i] = (*PredictTreeLeafPtr[i])(features);" << std::endl;
|
||||
str_buf << "\t" << "}" << std::endl;
|
||||
|
|
|
@ -2,6 +2,8 @@
|
|||
#define LIGHTGBM_BOOSTING_GBDT_H_
|
||||
|
||||
#include <LightGBM/boosting.h>
|
||||
#include <LightGBM/objective_function.h>
|
||||
|
||||
#include "score_updater.hpp"
|
||||
|
||||
#include <cstdio>
|
||||
|
@ -93,6 +95,14 @@ public:
|
|||
|
||||
bool EvalAndCheckEarlyStopping() override;
|
||||
|
||||
bool NeedAccuratePrediction() const override {
|
||||
if (objective_function_ == nullptr) {
|
||||
return true;
|
||||
} else {
|
||||
return objective_function_->NeedAccuratePrediction();
|
||||
}
|
||||
}
|
||||
|
||||
/*!
|
||||
* \brief Get evaluation result at data_idx data
|
||||
* \param data_idx 0: training data, 1: 1st validation data
|
||||
|
@ -137,7 +147,7 @@ public:
|
|||
}
|
||||
|
||||
void PredictRaw(const double* features, double* output,
|
||||
const PredictionEarlyStopInstance* earlyStop = nullptr) const override;
|
||||
const PredictionEarlyStopInstance* earlyStop) const override;
|
||||
|
||||
void Predict(const double* features, double* output,
|
||||
const PredictionEarlyStopInstance* earlyStop) const override;
|
||||
|
@ -365,7 +375,6 @@ protected:
|
|||
std::vector<double> class_default_output_;
|
||||
bool is_constant_hessian_;
|
||||
std::unique_ptr<ObjectiveFunction> loaded_objective_;
|
||||
|
||||
};
|
||||
|
||||
} // namespace LightGBM
|
||||
|
|
|
@ -16,21 +16,10 @@
|
|||
#include <vector>
|
||||
#include <utility>
|
||||
|
||||
namespace
|
||||
{
|
||||
/// Singleton used when earlyStop is nullptr in PredictRaw()
|
||||
const auto noEarlyStop = LightGBM::createPredictionEarlyStopInstance("none", LightGBM::PredictionEarlyStopConfig());
|
||||
}
|
||||
|
||||
namespace LightGBM {
|
||||
|
||||
void GBDT::PredictRaw(const double* features, double* output, const PredictionEarlyStopInstance* earlyStop) const {
|
||||
if (earlyStop == nullptr)
|
||||
{
|
||||
earlyStop = &noEarlyStop;
|
||||
}
|
||||
|
||||
int earlyStopRoundCounter = 0;
|
||||
void GBDT::PredictRaw(const double* features, double* output, const PredictionEarlyStopInstance* early_stop) const {
|
||||
int early_stop_round_counter = 0;
|
||||
for (int i = 0; i < num_iteration_for_pred_; ++i) {
|
||||
// predict all the trees for one iteration
|
||||
for (int k = 0; k < num_tree_per_iteration_; ++k) {
|
||||
|
@ -38,18 +27,18 @@ void GBDT::PredictRaw(const double* features, double* output, const PredictionEa
|
|||
}
|
||||
|
||||
// check early stopping
|
||||
++earlyStopRoundCounter;
|
||||
if (earlyStop->roundPeriod == earlyStopRoundCounter) {
|
||||
if (earlyStop->callbackFunction(output, num_tree_per_iteration_)) {
|
||||
++early_stop_round_counter;
|
||||
if (early_stop->round_period == early_stop_round_counter) {
|
||||
if (early_stop->callback_function(output, num_tree_per_iteration_)) {
|
||||
return;
|
||||
}
|
||||
earlyStopRoundCounter = 0;
|
||||
early_stop_round_counter = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void GBDT::Predict(const double* features, double* output, const PredictionEarlyStopInstance* earlyStop) const {
|
||||
PredictRaw(features, output, earlyStop);
|
||||
void GBDT::Predict(const double* features, double* output, const PredictionEarlyStopInstance* early_stop) const {
|
||||
PredictRaw(features, output, early_stop);
|
||||
|
||||
if (objective_function_ != nullptr) {
|
||||
objective_function_->ConvertOutput(output, output);
|
||||
|
@ -58,7 +47,6 @@ void GBDT::Predict(const double* features, double* output, const PredictionEarly
|
|||
|
||||
void GBDT::PredictLeafIndex(const double* features, double* output) const {
|
||||
int total_tree = num_iteration_for_pred_ * num_tree_per_iteration_;
|
||||
#pragma omp parallel for schedule(static)
|
||||
for (int i = 0; i < total_tree; ++i) {
|
||||
output[i] = models_[i]->PredictLeafIndex(features);
|
||||
}
|
||||
|
|
|
@ -1,101 +1,90 @@
|
|||
#include <LightGBM/prediction_early_stop.h>
|
||||
#include <LightGBM/utils/log.h>
|
||||
|
||||
using namespace LightGBM;
|
||||
|
||||
#include <algorithm>
|
||||
#include <vector>
|
||||
#include <cmath>
|
||||
#include <limits>
|
||||
|
||||
namespace
|
||||
{
|
||||
PredictionEarlyStopInstance createNone(const PredictionEarlyStopConfig&)
|
||||
{
|
||||
return PredictionEarlyStopInstance{
|
||||
[](const double*, int)
|
||||
{
|
||||
return false;
|
||||
},
|
||||
std::numeric_limits<int>::max() // make sure the lambda is almost never called
|
||||
};
|
||||
}
|
||||
namespace {
|
||||
|
||||
PredictionEarlyStopInstance createMulticlass(const PredictionEarlyStopConfig& config)
|
||||
{
|
||||
// marginThreshold will be captured by value
|
||||
const double marginThreshold = config.marginThreshold;
|
||||
using namespace LightGBM;
|
||||
|
||||
return PredictionEarlyStopInstance{
|
||||
[marginThreshold](const double* pred, int sz)
|
||||
{
|
||||
if(sz < 2) {
|
||||
Log::Fatal("Multiclass early stopping needs predictions to be of length two or larger");
|
||||
}
|
||||
PredictionEarlyStopInstance CreateNone(const PredictionEarlyStopConfig&) {
|
||||
return PredictionEarlyStopInstance{
|
||||
[](const double*, int) {
|
||||
return false;
|
||||
},
|
||||
std::numeric_limits<int>::max() // make sure the lambda is almost never called
|
||||
};
|
||||
}
|
||||
|
||||
// copy and sort
|
||||
std::vector<double> votes(static_cast<size_t>(sz));
|
||||
for (int i=0; i < sz; ++i) {
|
||||
votes[i] = pred[i];
|
||||
}
|
||||
std::partial_sort(votes.begin(), votes.begin() + 2, votes.end(), std::greater<double>());
|
||||
PredictionEarlyStopInstance CreateMulticlass(const PredictionEarlyStopConfig& config) {
|
||||
// margin_threshold will be captured by value
|
||||
const double margin_threshold = config.margin_threshold;
|
||||
|
||||
const auto margin = votes[0] - votes[1];
|
||||
return PredictionEarlyStopInstance{
|
||||
[margin_threshold](const double* pred, int sz) {
|
||||
if (sz < 2) {
|
||||
Log::Fatal("Multiclass early stopping needs predictions to be of length two or larger");
|
||||
}
|
||||
|
||||
if (margin > marginThreshold) {
|
||||
return true;
|
||||
}
|
||||
// copy and sort
|
||||
std::vector<double> votes(static_cast<size_t>(sz));
|
||||
for (int i = 0; i < sz; ++i) {
|
||||
votes[i] = pred[i];
|
||||
}
|
||||
std::partial_sort(votes.begin(), votes.begin() + 2, votes.end(), std::greater<double>());
|
||||
|
||||
return false;
|
||||
},
|
||||
config.roundPeriod
|
||||
};
|
||||
}
|
||||
const auto margin = votes[0] - votes[1];
|
||||
|
||||
PredictionEarlyStopInstance createBinary(const PredictionEarlyStopConfig& config)
|
||||
{
|
||||
// marginThreshold will be captured by value
|
||||
const double marginThreshold = config.marginThreshold;
|
||||
if (margin > margin_threshold) {
|
||||
return true;
|
||||
}
|
||||
|
||||
return PredictionEarlyStopInstance{
|
||||
[marginThreshold](const double* pred, int sz)
|
||||
{
|
||||
if(sz != 1) {
|
||||
Log::Fatal("Binary early stopping needs predictions to be of length one");
|
||||
}
|
||||
const auto margin = 2.0 * fabs(pred[0]);
|
||||
return false;
|
||||
},
|
||||
config.round_period
|
||||
};
|
||||
}
|
||||
|
||||
if (margin > marginThreshold) {
|
||||
return true;
|
||||
}
|
||||
PredictionEarlyStopInstance CreateBinary(const PredictionEarlyStopConfig& config) {
|
||||
// margin_threshold will be captured by value
|
||||
const double margin_threshold = config.margin_threshold;
|
||||
|
||||
return false;
|
||||
},
|
||||
config.roundPeriod
|
||||
};
|
||||
return PredictionEarlyStopInstance{
|
||||
[margin_threshold](const double* pred, int sz) {
|
||||
if (sz != 1) {
|
||||
Log::Fatal("Binary early stopping needs predictions to be of length one");
|
||||
}
|
||||
const auto margin = 2.0 * fabs(pred[0]);
|
||||
|
||||
if (margin > margin_threshold) {
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
},
|
||||
config.round_period
|
||||
};
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
namespace LightGBM {
|
||||
|
||||
PredictionEarlyStopInstance CreatePredictionEarlyStopInstance(const std::string& type,
|
||||
const PredictionEarlyStopConfig& config) {
|
||||
if (type == "none") {
|
||||
return CreateNone(config);
|
||||
} else if (type == "multiclass") {
|
||||
return CreateMulticlass(config);
|
||||
} else if (type == "binary") {
|
||||
return CreateBinary(config);
|
||||
} else {
|
||||
throw std::runtime_error("Unknown early stopping type: " + type);
|
||||
}
|
||||
}
|
||||
|
||||
namespace LightGBM
|
||||
{
|
||||
PredictionEarlyStopInstance createPredictionEarlyStopInstance(const std::string& type,
|
||||
const PredictionEarlyStopConfig& config)
|
||||
{
|
||||
if (type == "none")
|
||||
{
|
||||
return createNone(config);
|
||||
}
|
||||
else if (type == "multiclass")
|
||||
{
|
||||
return createMulticlass(config);
|
||||
}
|
||||
else if (type == "binary")
|
||||
{
|
||||
return createBinary(config);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unknown early stopping type: " + type);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -163,7 +163,7 @@ public:
|
|||
|
||||
void Predict(int num_iteration, int predict_type, int nrow,
|
||||
std::function<std::vector<std::pair<int, double>>(int row_idx)> get_row_fun,
|
||||
const PredictionEarlyStoppingHandle early_stop_handle,
|
||||
const char* parameter,
|
||||
double* out_result, int64_t* out_len) {
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
bool is_predict_leaf = false;
|
||||
|
@ -175,21 +175,28 @@ public:
|
|||
} else {
|
||||
is_raw_score = false;
|
||||
}
|
||||
auto param = ConfigBase::Str2Map(parameter);
|
||||
IOConfig config;
|
||||
config.Set(param);
|
||||
Predictor predictor(boosting_.get(), num_iteration, is_raw_score, is_predict_leaf,
|
||||
reinterpret_cast<const PredictionEarlyStopInstance*>(early_stop_handle));
|
||||
config.pred_early_stop, config.pred_early_stop_freq, config.pred_early_stop_margin);
|
||||
int64_t num_preb_in_one_row = boosting_->NumPredictOneRow(num_iteration, is_predict_leaf);
|
||||
auto pred_fun = predictor.GetPredictFunction();
|
||||
auto pred_wrt_ptr = out_result;
|
||||
OMP_INIT_EX();
|
||||
#pragma omp parallel for schedule(static)
|
||||
for (int i = 0; i < nrow; ++i) {
|
||||
OMP_LOOP_EX_BEGIN();
|
||||
auto one_row = get_row_fun(i);
|
||||
auto pred_wrt_ptr = out_result + static_cast<size_t>(num_preb_in_one_row) * i;
|
||||
pred_fun(one_row, pred_wrt_ptr);
|
||||
pred_wrt_ptr += num_preb_in_one_row;
|
||||
OMP_LOOP_EX_END();
|
||||
}
|
||||
OMP_THROW_EX();
|
||||
*out_len = nrow * num_preb_in_one_row;
|
||||
}
|
||||
|
||||
void Predict(int num_iteration, int predict_type, const char* data_filename,
|
||||
int data_has_header, const PredictionEarlyStoppingHandle early_stop_handle,
|
||||
int data_has_header, const char* parameter,
|
||||
const char* result_filename) {
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
bool is_predict_leaf = false;
|
||||
|
@ -201,8 +208,11 @@ public:
|
|||
} else {
|
||||
is_raw_score = false;
|
||||
}
|
||||
auto param = ConfigBase::Str2Map(parameter);
|
||||
IOConfig config;
|
||||
config.Set(param);
|
||||
Predictor predictor(boosting_.get(), num_iteration, is_raw_score, is_predict_leaf,
|
||||
reinterpret_cast<const PredictionEarlyStopInstance*>(early_stop_handle));
|
||||
config.pred_early_stop, config.pred_early_stop_freq, config.pred_early_stop_margin);
|
||||
bool bool_data_has_header = data_has_header > 0 ? true : false;
|
||||
predictor.Predict(data_filename, result_filename, bool_data_has_header);
|
||||
}
|
||||
|
@ -244,6 +254,7 @@ public:
|
|||
return ret;
|
||||
}
|
||||
|
||||
#pragma warning(disable : 4996)
|
||||
int GetEvalNames(char** out_strs) const {
|
||||
int idx = 0;
|
||||
for (const auto& metric : train_metric_) {
|
||||
|
@ -255,6 +266,7 @@ public:
|
|||
return idx;
|
||||
}
|
||||
|
||||
#pragma warning(disable : 4996)
|
||||
int GetFeatureNames(char** out_strs) const {
|
||||
int idx = 0;
|
||||
for (const auto& name : boosting_->FeatureNames()) {
|
||||
|
@ -681,6 +693,7 @@ int LGBM_DatasetSetFeatureNames(
|
|||
API_END();
|
||||
}
|
||||
|
||||
#pragma warning(disable : 4996)
|
||||
int LGBM_DatasetGetFeatureNames(
|
||||
DatasetHandle handle,
|
||||
char** feature_names,
|
||||
|
@ -695,6 +708,7 @@ int LGBM_DatasetGetFeatureNames(
|
|||
API_END();
|
||||
}
|
||||
|
||||
#pragma warning(disable : 4702)
|
||||
int LGBM_DatasetFree(DatasetHandle handle) {
|
||||
API_BEGIN();
|
||||
delete reinterpret_cast<Dataset*>(handle);
|
||||
|
@ -802,6 +816,7 @@ int LGBM_BoosterLoadModelFromString(
|
|||
API_END();
|
||||
}
|
||||
|
||||
#pragma warning(disable : 4702)
|
||||
int LGBM_BoosterFree(BoosterHandle handle) {
|
||||
API_BEGIN();
|
||||
delete reinterpret_cast<Booster*>(handle);
|
||||
|
@ -955,12 +970,12 @@ int LGBM_BoosterPredictForFile(BoosterHandle handle,
|
|||
int data_has_header,
|
||||
int predict_type,
|
||||
int num_iteration,
|
||||
const PredictionEarlyStoppingHandle early_stop_handle,
|
||||
const char* parameter,
|
||||
const char* result_filename) {
|
||||
API_BEGIN();
|
||||
Booster* ref_booster = reinterpret_cast<Booster*>(handle);
|
||||
ref_booster->Predict(num_iteration, predict_type, data_filename, data_has_header,
|
||||
early_stop_handle, result_filename);
|
||||
parameter, result_filename);
|
||||
API_END();
|
||||
}
|
||||
|
||||
|
@ -987,7 +1002,7 @@ int LGBM_BoosterPredictForCSR(BoosterHandle handle,
|
|||
int64_t,
|
||||
int predict_type,
|
||||
int num_iteration,
|
||||
const PredictionEarlyStoppingHandle early_stop_handle,
|
||||
const char* parameter,
|
||||
int64_t* out_len,
|
||||
double* out_result) {
|
||||
API_BEGIN();
|
||||
|
@ -995,7 +1010,7 @@ int LGBM_BoosterPredictForCSR(BoosterHandle handle,
|
|||
auto get_row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
|
||||
int nrow = static_cast<int>(nindptr - 1);
|
||||
ref_booster->Predict(num_iteration, predict_type, nrow, get_row_fun,
|
||||
early_stop_handle, out_result, out_len);
|
||||
parameter, out_result, out_len);
|
||||
API_END();
|
||||
}
|
||||
|
||||
|
@ -1010,7 +1025,7 @@ int LGBM_BoosterPredictForCSC(BoosterHandle handle,
|
|||
int64_t num_row,
|
||||
int predict_type,
|
||||
int num_iteration,
|
||||
const PredictionEarlyStoppingHandle early_stop_handle,
|
||||
const char* parameter,
|
||||
int64_t* out_len,
|
||||
double* out_result) {
|
||||
API_BEGIN();
|
||||
|
@ -1021,7 +1036,7 @@ int LGBM_BoosterPredictForCSC(BoosterHandle handle,
|
|||
iterators.emplace_back(col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem, j);
|
||||
}
|
||||
std::function<std::vector<std::pair<int, double>>(int row_idx)> get_row_fun =
|
||||
[&iterators, ncol](int i) {
|
||||
[&iterators, ncol] (int i) {
|
||||
std::vector<std::pair<int, double>> one_row;
|
||||
for (int j = 0; j < ncol; ++j) {
|
||||
auto val = iterators[j].Get(i);
|
||||
|
@ -1031,7 +1046,7 @@ int LGBM_BoosterPredictForCSC(BoosterHandle handle,
|
|||
}
|
||||
return one_row;
|
||||
};
|
||||
ref_booster->Predict(num_iteration, predict_type, static_cast<int>(num_row), get_row_fun, early_stop_handle,
|
||||
ref_booster->Predict(num_iteration, predict_type, static_cast<int>(num_row), get_row_fun, parameter,
|
||||
out_result, out_len);
|
||||
API_END();
|
||||
}
|
||||
|
@ -1044,14 +1059,14 @@ int LGBM_BoosterPredictForMat(BoosterHandle handle,
|
|||
int is_row_major,
|
||||
int predict_type,
|
||||
int num_iteration,
|
||||
const PredictionEarlyStoppingHandle early_stop_handle,
|
||||
const char* parameter,
|
||||
int64_t* out_len,
|
||||
double* out_result) {
|
||||
API_BEGIN();
|
||||
Booster* ref_booster = reinterpret_cast<Booster*>(handle);
|
||||
auto get_row_fun = RowPairFunctionFromDenseMatric(data, nrow, ncol, data_type, is_row_major);
|
||||
ref_booster->Predict(num_iteration, predict_type, nrow, get_row_fun,
|
||||
early_stop_handle, out_result, out_len);
|
||||
parameter, out_result, out_len);
|
||||
API_END();
|
||||
}
|
||||
|
||||
|
@ -1064,6 +1079,7 @@ int LGBM_BoosterSaveModel(BoosterHandle handle,
|
|||
API_END();
|
||||
}
|
||||
|
||||
#pragma warning(disable : 4996)
|
||||
int LGBM_BoosterSaveModelToString(BoosterHandle handle,
|
||||
int num_iteration,
|
||||
int buffer_len,
|
||||
|
@ -1079,6 +1095,7 @@ int LGBM_BoosterSaveModelToString(BoosterHandle handle,
|
|||
API_END();
|
||||
}
|
||||
|
||||
#pragma warning(disable : 4996)
|
||||
int LGBM_BoosterDumpModel(BoosterHandle handle,
|
||||
int num_iteration,
|
||||
int buffer_len,
|
||||
|
@ -1114,31 +1131,6 @@ int LGBM_BoosterSetLeafValue(BoosterHandle handle,
|
|||
API_END();
|
||||
}
|
||||
|
||||
|
||||
int LGBM_PredictionEarlyStopInstanceCreate(const char* type,
|
||||
int round_period,
|
||||
double margin_threshold,
|
||||
PredictionEarlyStoppingHandle* out)
|
||||
{
|
||||
API_BEGIN();
|
||||
PredictionEarlyStopConfig config;
|
||||
config.marginThreshold = margin_threshold;
|
||||
config.roundPeriod = round_period;
|
||||
|
||||
auto earlyStop = createPredictionEarlyStopInstance(type, config);
|
||||
|
||||
// create new by copying
|
||||
*out = new PredictionEarlyStopInstance(earlyStop);
|
||||
API_END();
|
||||
}
|
||||
|
||||
int LGBM_PredictionEarlyStopInstanceFree(const PredictionEarlyStoppingHandle handle)
|
||||
{
|
||||
API_BEGIN();
|
||||
delete reinterpret_cast<const PredictionEarlyStopInstance*>(handle);
|
||||
API_END();
|
||||
}
|
||||
|
||||
// ---- start of some help functions
|
||||
|
||||
std::function<std::vector<double>(int row_idx)>
|
||||
|
@ -1146,7 +1138,7 @@ RowFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_
|
|||
if (data_type == C_API_DTYPE_FLOAT32) {
|
||||
const float* data_ptr = reinterpret_cast<const float*>(data);
|
||||
if (is_row_major) {
|
||||
return [data_ptr, num_col, num_row](int row_idx) {
|
||||
return [data_ptr, num_col, num_row] (int row_idx) {
|
||||
std::vector<double> ret(num_col);
|
||||
auto tmp_ptr = data_ptr + static_cast<size_t>(num_col) * row_idx;
|
||||
for (int i = 0; i < num_col; ++i) {
|
||||
|
@ -1158,7 +1150,7 @@ RowFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_
|
|||
return ret;
|
||||
};
|
||||
} else {
|
||||
return [data_ptr, num_col, num_row](int row_idx) {
|
||||
return [data_ptr, num_col, num_row] (int row_idx) {
|
||||
std::vector<double> ret(num_col);
|
||||
for (int i = 0; i < num_col; ++i) {
|
||||
ret[i] = static_cast<double>(*(data_ptr + static_cast<size_t>(num_row) * i + row_idx));
|
||||
|
@ -1172,7 +1164,7 @@ RowFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_
|
|||
} else if (data_type == C_API_DTYPE_FLOAT64) {
|
||||
const double* data_ptr = reinterpret_cast<const double*>(data);
|
||||
if (is_row_major) {
|
||||
return [data_ptr, num_col, num_row](int row_idx) {
|
||||
return [data_ptr, num_col, num_row] (int row_idx) {
|
||||
std::vector<double> ret(num_col);
|
||||
auto tmp_ptr = data_ptr + static_cast<size_t>(num_col) * row_idx;
|
||||
for (int i = 0; i < num_col; ++i) {
|
||||
|
@ -1184,7 +1176,7 @@ RowFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_
|
|||
return ret;
|
||||
};
|
||||
} else {
|
||||
return [data_ptr, num_col, num_row](int row_idx) {
|
||||
return [data_ptr, num_col, num_row] (int row_idx) {
|
||||
std::vector<double> ret(num_col);
|
||||
for (int i = 0; i < num_col; ++i) {
|
||||
ret[i] = static_cast<double>(*(data_ptr + static_cast<size_t>(num_row) * i + row_idx));
|
||||
|
@ -1203,7 +1195,7 @@ std::function<std::vector<std::pair<int, double>>(int row_idx)>
|
|||
RowPairFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major) {
|
||||
auto inner_function = RowFunctionFromDenseMatric(data, num_row, num_col, data_type, is_row_major);
|
||||
if (inner_function != nullptr) {
|
||||
return [inner_function](int row_idx) {
|
||||
return [inner_function] (int row_idx) {
|
||||
auto raw_values = inner_function(row_idx);
|
||||
std::vector<std::pair<int, double>> ret;
|
||||
for (int i = 0; i < static_cast<int>(raw_values.size()); ++i) {
|
||||
|
@ -1223,7 +1215,7 @@ RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices,
|
|||
const float* data_ptr = reinterpret_cast<const float*>(data);
|
||||
if (indptr_type == C_API_DTYPE_INT32) {
|
||||
const int32_t* ptr_indptr = reinterpret_cast<const int32_t*>(indptr);
|
||||
return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) {
|
||||
return [ptr_indptr, indices, data_ptr, nindptr, nelem] (int idx) {
|
||||
std::vector<std::pair<int, double>> ret;
|
||||
int64_t start = ptr_indptr[idx];
|
||||
int64_t end = ptr_indptr[idx + 1];
|
||||
|
@ -1236,7 +1228,7 @@ RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices,
|
|||
};
|
||||
} else if (indptr_type == C_API_DTYPE_INT64) {
|
||||
const int64_t* ptr_indptr = reinterpret_cast<const int64_t*>(indptr);
|
||||
return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) {
|
||||
return [ptr_indptr, indices, data_ptr, nindptr, nelem] (int idx) {
|
||||
std::vector<std::pair<int, double>> ret;
|
||||
int64_t start = ptr_indptr[idx];
|
||||
int64_t end = ptr_indptr[idx + 1];
|
||||
|
@ -1252,7 +1244,7 @@ RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices,
|
|||
const double* data_ptr = reinterpret_cast<const double*>(data);
|
||||
if (indptr_type == C_API_DTYPE_INT32) {
|
||||
const int32_t* ptr_indptr = reinterpret_cast<const int32_t*>(indptr);
|
||||
return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) {
|
||||
return [ptr_indptr, indices, data_ptr, nindptr, nelem] (int idx) {
|
||||
std::vector<std::pair<int, double>> ret;
|
||||
int64_t start = ptr_indptr[idx];
|
||||
int64_t end = ptr_indptr[idx + 1];
|
||||
|
@ -1265,7 +1257,7 @@ RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices,
|
|||
};
|
||||
} else if (indptr_type == C_API_DTYPE_INT64) {
|
||||
const int64_t* ptr_indptr = reinterpret_cast<const int64_t*>(indptr);
|
||||
return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) {
|
||||
return [ptr_indptr, indices, data_ptr, nindptr, nelem] (int idx) {
|
||||
std::vector<std::pair<int, double>> ret;
|
||||
int64_t start = ptr_indptr[idx];
|
||||
int64_t end = ptr_indptr[idx + 1];
|
||||
|
@ -1290,7 +1282,7 @@ IterateFunctionFromCSC(const void* col_ptr, int col_ptr_type, const int32_t* ind
|
|||
const int32_t* ptr_col_ptr = reinterpret_cast<const int32_t*>(col_ptr);
|
||||
int64_t start = ptr_col_ptr[col_idx];
|
||||
int64_t end = ptr_col_ptr[col_idx + 1];
|
||||
return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem, start, end](int bias) {
|
||||
return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem, start, end] (int bias) {
|
||||
int64_t i = static_cast<int64_t>(start + bias);
|
||||
if (i >= end) {
|
||||
return std::make_pair(-1, 0.0);
|
||||
|
@ -1304,7 +1296,7 @@ IterateFunctionFromCSC(const void* col_ptr, int col_ptr_type, const int32_t* ind
|
|||
const int64_t* ptr_col_ptr = reinterpret_cast<const int64_t*>(col_ptr);
|
||||
int64_t start = ptr_col_ptr[col_idx];
|
||||
int64_t end = ptr_col_ptr[col_idx + 1];
|
||||
return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem, start, end](int bias) {
|
||||
return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem, start, end] (int bias) {
|
||||
int64_t i = static_cast<int64_t>(start + bias);
|
||||
if (i >= end) {
|
||||
return std::make_pair(-1, 0.0);
|
||||
|
@ -1321,7 +1313,7 @@ IterateFunctionFromCSC(const void* col_ptr, int col_ptr_type, const int32_t* ind
|
|||
const int32_t* ptr_col_ptr = reinterpret_cast<const int32_t*>(col_ptr);
|
||||
int64_t start = ptr_col_ptr[col_idx];
|
||||
int64_t end = ptr_col_ptr[col_idx + 1];
|
||||
return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem, start, end](int bias) {
|
||||
return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem, start, end] (int bias) {
|
||||
int64_t i = static_cast<int64_t>(start + bias);
|
||||
if (i >= end) {
|
||||
return std::make_pair(-1, 0.0);
|
||||
|
@ -1335,7 +1327,7 @@ IterateFunctionFromCSC(const void* col_ptr, int col_ptr_type, const int32_t* ind
|
|||
const int64_t* ptr_col_ptr = reinterpret_cast<const int64_t*>(col_ptr);
|
||||
int64_t start = ptr_col_ptr[col_idx];
|
||||
int64_t end = ptr_col_ptr[col_idx + 1];
|
||||
return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem, start, end](int bias) {
|
||||
return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem, start, end] (int bias) {
|
||||
int64_t i = static_cast<int64_t>(start + bias);
|
||||
if (i >= end) {
|
||||
return std::make_pair(-1, 0.0);
|
||||
|
|
|
@ -230,6 +230,11 @@ void IOConfig::Set(const std::unordered_map<std::string, std::string>& params) {
|
|||
CHECK(min_data_in_bin > 0);
|
||||
GetDouble(params, "max_conflict_rate", &max_conflict_rate);
|
||||
GetBool(params, "enable_bundle", &enable_bundle);
|
||||
|
||||
GetBool(params, "pred_early_stop", &pred_early_stop);
|
||||
GetInt(params, "pred_early_stop_freq", &pred_early_stop_freq);
|
||||
GetDouble(params, "pred_early_stop_margin", &pred_early_stop_margin);
|
||||
|
||||
GetDeviceType(params);
|
||||
}
|
||||
|
||||
|
|
|
@ -129,6 +129,8 @@ public:
|
|||
|
||||
bool SkipEmptyClass() const override { return true; }
|
||||
|
||||
bool NeedAccuratePrediction() const override { return false; }
|
||||
|
||||
private:
|
||||
/*! \brief Number of data */
|
||||
data_size_t num_data_;
|
||||
|
|
|
@ -118,6 +118,8 @@ public:
|
|||
|
||||
int NumPredictOneRow() const override { return num_class_; }
|
||||
|
||||
bool NeedAccuratePrediction() const override { return false; }
|
||||
|
||||
private:
|
||||
/*! \brief Number of data */
|
||||
data_size_t num_data_;
|
||||
|
@ -208,6 +210,8 @@ public:
|
|||
|
||||
int NumPredictOneRow() const override { return num_class_; }
|
||||
|
||||
bool NeedAccuratePrediction() const override { return false; }
|
||||
|
||||
private:
|
||||
/*! \brief Number of data */
|
||||
data_size_t num_data_;
|
||||
|
|
|
@ -207,6 +207,8 @@ public:
|
|||
return str_buf.str();
|
||||
}
|
||||
|
||||
bool NeedAccuratePrediction() const override { return false; }
|
||||
|
||||
private:
|
||||
/*! \brief Gains for labels */
|
||||
std::vector<double> label_gain_;
|
||||
|
|
|
@ -228,8 +228,8 @@ def test_booster():
|
|||
1,
|
||||
1,
|
||||
50,
|
||||
ctypes.c_void_p(),
|
||||
c_str(''),
|
||||
ctypes.byref(num_preb),
|
||||
preb.ctypes.data_as(ctypes.POINTER(ctypes.c_double)))
|
||||
LIB.LGBM_BoosterPredictForFile(booster2, c_str(os.path.join(os.path.dirname(os.path.realpath(__file__)), '../../examples/binary_classification/binary.test')), 0, 0, 50, ctypes.c_void_p(), c_str('preb.txt'))
|
||||
LIB.LGBM_BoosterPredictForFile(booster2, c_str(os.path.join(os.path.dirname(os.path.realpath(__file__)), '../../examples/binary_classification/binary.test')), 0, 0, 50, c_str(''), c_str('preb.txt'))
|
||||
LIB.LGBM_BoosterFree(booster2)
|
||||
|
|
|
@ -54,8 +54,8 @@ class TestBasic(unittest.TestCase):
|
|||
self.assertEqual(*preds)
|
||||
|
||||
# check early stopping is working. Make it stop very early, so the scores should be very close to zero
|
||||
estop = lgb.PredictionEarlyStopInstance("binary", round_period=5, margin_threshold=1.5)
|
||||
pred_early_stopping = bst.predict(X_test, early_stop_instance=estop)
|
||||
pred_parameter = {"pred_early_stop": True, "pred_early_stop_freq": 5, "pred_early_stop_margin": 1.5}
|
||||
pred_early_stopping = bst.predict(X_test, pred_parameter=pred_parameter)
|
||||
self.assertEqual(len(pred_from_matr), len(pred_early_stopping))
|
||||
for preds in zip(pred_early_stopping, pred_from_matr):
|
||||
# scores likely to be different, but prediction should still be the same
|
||||
|
|
|
@ -108,13 +108,13 @@ class TestEngine(unittest.TestCase):
|
|||
verbose_eval=False,
|
||||
evals_result=evals_result)
|
||||
|
||||
estop = lgb.PredictionEarlyStopInstance("multiclass", round_period=5, margin_threshold=1.5)
|
||||
ret = multi_logloss(y_test, gbm.predict(X_test, early_stop_instance=estop))
|
||||
pred_parameter = {"pred_early_stop": True, "pred_early_stop_freq": 5, "pred_early_stop_margin": 1.5}
|
||||
ret = multi_logloss(y_test, gbm.predict(X_test, pred_parameter=pred_parameter))
|
||||
self.assertLess(ret, 0.8)
|
||||
self.assertGreater(ret, 0.5) # loss will be higher than when evaluating the full model
|
||||
|
||||
estop = lgb.PredictionEarlyStopInstance("multiclass", round_period=5, margin_threshold=5.5)
|
||||
ret = multi_logloss(y_test, gbm.predict(X_test, early_stop_instance=estop))
|
||||
pred_parameter = {"pred_early_stop": True, "pred_early_stop_freq": 5, "pred_early_stop_margin": 5.5}
|
||||
ret = multi_logloss(y_test, gbm.predict(X_test, pred_parameter=pred_parameter))
|
||||
self.assertLess(ret, 0.2)
|
||||
|
||||
def test_early_stopping(self):
|
||||
|
|
Загрузка…
Ссылка в новой задаче