зеркало из https://github.com/microsoft/LightGBM.git
Add capability to get possible max and min values for a model (#2737)
* Add capability to get possible max and min values for a model * Change implementation to have return value in tree.cpp, change naming to upper and lower bound, move implementation to gdbt.cpp * Update include/LightGBM/c_api.h Co-Authored-By: Nikita Titov <nekit94-08@mail.ru> * Change iteration to avoid potential overflow, add bindings to R and Python and a basic test * Adjust test values * Consider const correctness and multithreading protection * Update test values * Update test values * Add test to check that model is exactly the same in all platforms * Try to parse the model to get the expected values * Try to parse the model to get the expected values * Fix implementation, num_leaves can be lower than the leaf_value_ size * Do not check for num_leaves to be smaller than actual size and get back to test with hardcoded value * Change test order * Add gpu_use_dp option in test * Remove helper test method * Update src/c_api.cpp Co-Authored-By: Nikita Titov <nekit94-08@mail.ru> * Update src/io/tree.cpp Co-Authored-By: Nikita Titov <nekit94-08@mail.ru> * Update src/io/tree.cpp Co-Authored-By: Nikita Titov <nekit94-08@mail.ru> * Update tests/python_package_test/test_basic.py Co-Authored-By: Nikita Titov <nekit94-08@mail.ru> * Remoove imports Co-authored-by: Nikita Titov <nekit94-08@mail.ru>
This commit is contained in:
Родитель
25d149d8ce
Коммит
18e7de4f5d
|
@ -321,6 +321,30 @@ Booster <- R6::R6Class(
|
|||
|
||||
},
|
||||
|
||||
# Get upper bound value of the model
upper_bound_ = function() {

  # The C wrapper LGBM_BoosterGetUpperBoundValue_R writes a C double into
  # this buffer via R_REAL_PTR, so it has to be a double vector (0.0),
  # not an integer vector (0L)
  upper_bound <- 0.0

  # The value of this call is the value of the method
  lgb.call(
    "LGBM_BoosterGetUpperBoundValue_R"
    , ret = upper_bound
    , private$handle
  )

},
|
||||
|
||||
# Get lower bound value of the model
lower_bound_ = function() {

  # The C wrapper LGBM_BoosterGetLowerBoundValue_R writes a C double into
  # this buffer via R_REAL_PTR, so it has to be a double vector (0.0),
  # not an integer vector (0L)
  lower_bound <- 0.0

  # Pass the buffer declared above; the previous code referenced
  # `upper_bound`, which is not defined inside this function
  lgb.call(
    "LGBM_BoosterGetLowerBoundValue_R"
    , ret = lower_bound
    , private$handle
  )

},
|
||||
|
||||
# Evaluate data on metrics
|
||||
eval = function(data, name, feval = NULL) {
|
||||
|
||||
|
|
|
@ -228,6 +228,18 @@ class LIGHTGBM_EXPORT Boosting {
|
|||
*/
|
||||
virtual std::vector<double> FeatureImportance(int num_iteration, int importance_type) const = 0;
|
||||
|
||||
/*!
|
||||
* \brief Calculate an upper bound on the value this model can output
|
||||
* \return Upper bound value; how it is computed is defined by the implementing class
|
||||
*/
|
||||
virtual double GetUpperBoundValue() const = 0;
|
||||
|
||||
/*!
|
||||
* \brief Calculate a lower bound on the value this model can output
|
||||
* \return Lower bound value; how it is computed is defined by the implementing class
|
||||
*/
|
||||
virtual double GetLowerBoundValue() const = 0;
|
||||
|
||||
/*!
|
||||
* \brief Get max feature index of this model
|
||||
* \return Max feature index of this model
|
||||
|
|
|
@ -988,6 +988,24 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterFeatureImportance(BoosterHandle handle,
|
|||
int importance_type,
|
||||
double* out_results);
|
||||
|
||||
/*!
|
||||
* \brief Get model upper bound value.
|
||||
* \param handle Handle of booster
|
||||
* \param[out] out_results Pointer to a single double that receives the upper bound value
|
||||
* \return 0 when succeed, -1 when failure happens
|
||||
*/
|
||||
LIGHTGBM_C_EXPORT int LGBM_BoosterGetUpperBoundValue(BoosterHandle handle,
|
||||
double* out_results);
|
||||
|
||||
/*!
|
||||
* \brief Get model lower bound value.
|
||||
* \param handle Handle of booster
|
||||
* \param[out] out_results Pointer to a single double that receives the lower bound value
|
||||
* \return 0 when succeed, -1 when failure happens
|
||||
*/
|
||||
LIGHTGBM_C_EXPORT int LGBM_BoosterGetLowerBoundValue(BoosterHandle handle,
|
||||
double* out_results);
|
||||
|
||||
/*!
|
||||
* \brief Initialize the network.
|
||||
* \param machines List of machines in format 'ip1:port1,ip2:port2'
|
||||
|
|
|
@ -118,6 +118,16 @@ class Tree {
|
|||
const data_size_t* used_data_indices,
|
||||
data_size_t num_data, double* score) const;
|
||||
|
||||
/*!
|
||||
* \brief Get upper bound leaf value of this tree model
* \return Largest value among the leaves of this tree
|
||||
*/
|
||||
double GetUpperBoundValue() const;
|
||||
|
||||
/*!
|
||||
* \brief Get lower bound leaf value of this tree model
* \return Smallest value among the leaves of this tree
|
||||
*/
|
||||
double GetLowerBoundValue() const;
|
||||
|
||||
/*!
|
||||
* \brief Prediction on one record
|
||||
* \param feature_values Feature value of this record
|
||||
|
|
|
@ -2243,6 +2243,34 @@ class Booster(object):
|
|||
ctypes.byref(num_trees)))
|
||||
return num_trees.value
|
||||
|
||||
def upper_bound(self):
    """Return the upper bound value of the model.

    The value is read from the native library through
    ``LGBM_BoosterGetUpperBoundValue``.

    Returns
    -------
    upper_bound : float
        Upper bound value of the model.
    """
    # Output slot that the native call fills in by reference
    bound = ctypes.c_double(0)
    status = _LIB.LGBM_BoosterGetUpperBoundValue(
        self.handle,
        ctypes.byref(bound))
    _safe_call(status)
    return bound.value
|
||||
|
||||
def lower_bound(self):
    """Return the lower bound value of the model.

    The value is read from the native library through
    ``LGBM_BoosterGetLowerBoundValue``.

    Returns
    -------
    lower_bound : float
        Lower bound value of the model.
    """
    # Output slot that the native call fills in by reference
    bound = ctypes.c_double(0)
    status = _LIB.LGBM_BoosterGetLowerBoundValue(
        self.handle,
        ctypes.byref(bound))
    _safe_call(status)
    return bound.value
|
||||
|
||||
def eval(self, data, name, feval=None):
|
||||
"""Evaluate for data.
|
||||
|
||||
|
|
|
@ -667,6 +667,22 @@ void GBDT::GetPredictAt(int data_idx, double* out_result, int64_t* out_len) {
|
|||
}
|
||||
}
|
||||
|
||||
double GBDT::GetUpperBoundValue() const {
|
||||
double max_value = 0.0;
|
||||
for (const auto &tree: models_) {
|
||||
max_value += tree->GetUpperBoundValue();
|
||||
}
|
||||
return max_value;
|
||||
}
|
||||
|
||||
double GBDT::GetLowerBoundValue() const {
|
||||
double min_value = 0.0;
|
||||
for (const auto &tree: models_) {
|
||||
min_value += tree->GetLowerBoundValue();
|
||||
}
|
||||
return min_value;
|
||||
}
|
||||
|
||||
void GBDT::ResetTrainingData(const Dataset* train_data, const ObjectiveFunction* objective_function,
|
||||
const std::vector<const Metric*>& training_metrics) {
|
||||
if (train_data != train_data_ && !train_data_->CheckAlign(*train_data)) {
|
||||
|
|
|
@ -295,6 +295,18 @@ class GBDT : public GBDTBase {
|
|||
*/
|
||||
std::vector<double> FeatureImportance(int num_iteration, int importance_type) const override;
|
||||
|
||||
/*!
|
||||
* \brief Calculate upper bound value as the sum of every tree's upper bound leaf value
|
||||
* \return upper bound value
|
||||
*/
|
||||
double GetUpperBoundValue() const override;
|
||||
|
||||
/*!
|
||||
* \brief Calculate lower bound value as the sum of every tree's lower bound leaf value
|
||||
* \return lower bound value
|
||||
*/
|
||||
double GetLowerBoundValue() const override;
|
||||
|
||||
/*!
|
||||
* \brief Get max feature index of this model
|
||||
* \return Max feature index of this model
|
||||
|
|
|
@ -468,6 +468,16 @@ class Booster {
|
|||
return boosting_->FeatureImportance(num_iteration, importance_type);
|
||||
}
|
||||
|
||||
double UpperBoundValue() const {
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
return boosting_->GetUpperBoundValue();
|
||||
}
|
||||
|
||||
double LowerBoundValue() const {
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
return boosting_->GetLowerBoundValue();
|
||||
}
|
||||
|
||||
double GetLeafValue(int tree_idx, int leaf_idx) const {
|
||||
return dynamic_cast<GBDTBase*>(boosting_.get())->GetLeafValue(tree_idx, leaf_idx);
|
||||
}
|
||||
|
@ -526,7 +536,7 @@ class Booster {
|
|||
/*! \brief Training objective function */
|
||||
std::unique_ptr<ObjectiveFunction> objective_fun_;
|
||||
/*! \brief mutex for threading safe call */
|
||||
std::mutex mutex_;
|
||||
mutable std::mutex mutex_;
|
||||
};
|
||||
|
||||
} // namespace LightGBM
|
||||
|
@ -1694,6 +1704,24 @@ int LGBM_BoosterFeatureImportance(BoosterHandle handle,
|
|||
API_END();
|
||||
}
|
||||
|
||||
// Writes the booster's upper bound value to *out_results.
// API_BEGIN/API_END supply the function's return value.
int LGBM_BoosterGetUpperBoundValue(BoosterHandle handle,
                                   double* out_results) {
  API_BEGIN();
  // The handle is treated as a pointer to a Booster
  Booster* booster = reinterpret_cast<Booster*>(handle);
  *out_results = booster->UpperBoundValue();
  API_END();
}
|
||||
|
||||
// Writes the booster's lower bound value to *out_results.
// API_BEGIN/API_END supply the function's return value.
int LGBM_BoosterGetLowerBoundValue(BoosterHandle handle,
                                   double* out_results) {
  API_BEGIN();
  // The handle is treated as a pointer to a Booster
  Booster* booster = reinterpret_cast<Booster*>(handle);
  *out_results = booster->LowerBoundValue();
  API_END();
}
|
||||
|
||||
int LGBM_NetworkInit(const char* machines,
|
||||
int local_listen_port,
|
||||
int listen_time_out,
|
||||
|
|
|
@ -206,6 +206,26 @@ void Tree::AddPredictionToScore(const Dataset* data,
|
|||
|
||||
#undef PredictionFun
|
||||
|
||||
double Tree::GetUpperBoundValue() const {
|
||||
double upper_bound = leaf_value_[0];
|
||||
for (int i = 1; i < num_leaves_; ++i) {
|
||||
if (leaf_value_[i] > upper_bound) {
|
||||
upper_bound = leaf_value_[i];
|
||||
}
|
||||
}
|
||||
return upper_bound;
|
||||
}
|
||||
|
||||
double Tree::GetLowerBoundValue() const {
|
||||
double lower_bound = leaf_value_[0];
|
||||
for (int i = 1; i < num_leaves_; ++i) {
|
||||
if (leaf_value_[i] < lower_bound) {
|
||||
lower_bound = leaf_value_[i];
|
||||
}
|
||||
}
|
||||
return lower_bound;
|
||||
}
|
||||
|
||||
std::string Tree::ToString() const {
|
||||
std::stringstream str_buf;
|
||||
str_buf << "num_leaves=" << num_leaves_ << '\n';
|
||||
|
|
|
@ -422,6 +422,24 @@ LGBM_SE LGBM_BoosterGetCurrentIteration_R(LGBM_SE handle,
|
|||
R_API_END();
|
||||
}
|
||||
|
||||
// R entry point: forwards to LGBM_BoosterGetUpperBoundValue and stores the
// result in out_result, which is accessed as a buffer of doubles.
LGBM_SE LGBM_BoosterGetUpperBoundValue_R(LGBM_SE handle,
  LGBM_SE out_result,
  LGBM_SE call_state) {
  R_API_BEGIN();
  double* bound_out = R_REAL_PTR(out_result);
  CHECK_CALL(LGBM_BoosterGetUpperBoundValue(R_GET_PTR(handle), bound_out));
  R_API_END();
}
|
||||
|
||||
// R entry point: forwards to LGBM_BoosterGetLowerBoundValue and stores the
// result in out_result, which is accessed as a buffer of doubles.
LGBM_SE LGBM_BoosterGetLowerBoundValue_R(LGBM_SE handle,
  LGBM_SE out_result,
  LGBM_SE call_state) {
  R_API_BEGIN();
  double* bound_out = R_REAL_PTR(out_result);
  CHECK_CALL(LGBM_BoosterGetLowerBoundValue(R_GET_PTR(handle), bound_out));
  R_API_END();
}
|
||||
|
||||
LGBM_SE LGBM_BoosterGetEvalNames_R(LGBM_SE handle,
|
||||
LGBM_SE buf_len,
|
||||
LGBM_SE actual_len,
|
||||
|
|
|
@ -26,7 +26,8 @@ class TestBasic(unittest.TestCase):
|
|||
"num_leaves": 15,
|
||||
"verbose": -1,
|
||||
"num_threads": 1,
|
||||
"max_bin": 255
|
||||
"max_bin": 255,
|
||||
"gpu_use_dp": True
|
||||
}
|
||||
bst = lgb.Booster(params, train_data)
|
||||
bst.add_valid(valid_data, "valid_1")
|
||||
|
@ -39,6 +40,8 @@ class TestBasic(unittest.TestCase):
|
|||
self.assertEqual(bst.current_iteration(), 20)
|
||||
self.assertEqual(bst.num_trees(), 20)
|
||||
self.assertEqual(bst.num_model_per_iteration(), 1)
|
||||
self.assertAlmostEqual(bst.lower_bound(), -2.9040190126976606)
|
||||
self.assertAlmostEqual(bst.upper_bound(), 3.3182142872462883)
|
||||
|
||||
bst.save_model("model.txt")
|
||||
pred_from_matr = bst.predict(X_test)
|
||||
|
|
Загрузка…
Ссылка в новой задаче