Add capability to get possible max and min values for a model (#2737)

* Add capability to get possible max and min values for a model

* Change implementation to have return value in tree.cpp, change naming to upper and lower bound, move implementation to gbdt.cpp

* Update include/LightGBM/c_api.h

Co-Authored-By: Nikita Titov <nekit94-08@mail.ru>

* Change iteration to avoid potential overflow, add bindings to R and Python and a basic test

* Adjust test values

* Consider const correctness and multithreading protection

* Update test values

* Update test values

* Add test to check that model is exactly the same in all platforms

* Try to parse the model to get the expected values

* Try to parse the model to get the expected values

* Fix implementation, num_leaves can be lower than the leaf_value_ size

* Do not check for num_leaves to be smaller than actual size and get back to test with hardcoded value

* Change test order

* Add gpu_use_dp option in test

* Remove helper test method

* Update src/c_api.cpp

Co-Authored-By: Nikita Titov <nekit94-08@mail.ru>

* Update src/io/tree.cpp

Co-Authored-By: Nikita Titov <nekit94-08@mail.ru>

* Update src/io/tree.cpp

Co-Authored-By: Nikita Titov <nekit94-08@mail.ru>

* Update tests/python_package_test/test_basic.py

Co-Authored-By: Nikita Titov <nekit94-08@mail.ru>

* Remove imports

Co-authored-by: Nikita Titov <nekit94-08@mail.ru>
This commit is contained in:
Joan Fontanals 2020-02-20 02:42:36 +01:00 коммит произвёл GitHub
Родитель 25d149d8ce
Коммит 18e7de4f5d
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
11 изменённых файлов: 191 добавлений и 2 удалений

Просмотреть файл

@ -321,6 +321,30 @@ Booster <- R6::R6Class(
},
# Get upper bound
upper_bound_ = function() {
  # Output buffer handed to the C wrapper. It has to be a double (0.0), not
  # an integer (0L): LGBM_BoosterGetUpperBoundValue_R accesses this argument
  # through a double pointer, so an integer buffer would be misinterpreted.
  upper_bound <- 0.0
  # The call is the last expression, so whatever lgb.call returns is the
  # result of this method (presumably the filled `ret` value; confirm
  # against lgb.call).
  lgb.call(
    "LGBM_BoosterGetUpperBoundValue_R"
    , ret = upper_bound
    , private$handle
  )
},
# Get lower bound
lower_bound_ = function() {
  # Output buffer handed to the C wrapper. It has to be a double (0.0), not
  # an integer (0L): LGBM_BoosterGetLowerBoundValue_R accesses this argument
  # through a double pointer, so an integer buffer would be misinterpreted.
  lower_bound <- 0.0
  # Pass the buffer declared above. The previous code passed `upper_bound`,
  # which is not defined inside this function.
  # The call is the last expression, so whatever lgb.call returns is the
  # result of this method (presumably the filled `ret` value; confirm
  # against lgb.call).
  lgb.call(
    "LGBM_BoosterGetLowerBoundValue_R"
    , ret = lower_bound
    , private$handle
  )
},
# Evaluate data on metrics
eval = function(data, name, feval = NULL) {

Просмотреть файл

@ -228,6 +228,18 @@ class LIGHTGBM_EXPORT Boosting {
*/
virtual std::vector<double> FeatureImportance(int num_iteration, int importance_type) const = 0;
/*!
* \brief Calculate upper bound value
* \note Pure virtual: the actual computation is supplied by the concrete boosting class
* \return max possible value
*/
virtual double GetUpperBoundValue() const = 0;
/*!
* \brief Calculate lower bound value
* \note Pure virtual: the actual computation is supplied by the concrete boosting class
* \return min possible value
*/
virtual double GetLowerBoundValue() const = 0;
/*!
* \brief Get max feature index of this model
* \return Max feature index of this model

Просмотреть файл

@ -988,6 +988,24 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterFeatureImportance(BoosterHandle handle,
int importance_type,
double* out_results);
/*!
* \brief Get model upper bound value.
* \param handle Handle of booster
* \param[out] out_results Pointer to a single double that receives the upper bound value
* \return 0 when succeed, -1 when failure happens
*/
LIGHTGBM_C_EXPORT int LGBM_BoosterGetUpperBoundValue(BoosterHandle handle,
double* out_results);
/*!
* \brief Get model lower bound value.
* \param handle Handle of booster
* \param[out] out_results Pointer to a single double that receives the lower bound value
* \return 0 when succeed, -1 when failure happens
*/
LIGHTGBM_C_EXPORT int LGBM_BoosterGetLowerBoundValue(BoosterHandle handle,
double* out_results);
/*!
* \brief Initialize the network.
* \param machines List of machines in format 'ip1:port1,ip2:port2'

Просмотреть файл

@ -118,6 +118,16 @@ class Tree {
const data_size_t* used_data_indices,
data_size_t num_data, double* score) const;
/*!
* \brief Get upper bound leaf value of this tree model
* \return Largest leaf value among the leaves of this tree
*/
double GetUpperBoundValue() const;
/*!
* \brief Get lower bound leaf value of this tree model
* \return Smallest leaf value among the leaves of this tree
*/
double GetLowerBoundValue() const;
/*!
* \brief Prediction on one record
* \param feature_values Feature value of this record

Просмотреть файл

@ -2243,6 +2243,34 @@ class Booster(object):
ctypes.byref(num_trees)))
return num_trees.value
def upper_bound(self):
    """Get upper bound value of a model.

    Returns
    -------
    upper_bound : double
        Upper bound value of the model.
    """
    # C double that the native call fills in through a by-reference argument.
    bound = ctypes.c_double(0)
    status = _LIB.LGBM_BoosterGetUpperBoundValue(
        self.handle,
        ctypes.byref(bound))
    # Hand the native return code to the shared checker.
    _safe_call(status)
    return bound.value
def lower_bound(self):
    """Get lower bound value of a model.

    Returns
    -------
    lower_bound : double
        Lower bound value of the model.
    """
    # C double that the native call fills in through a by-reference argument.
    bound = ctypes.c_double(0)
    status = _LIB.LGBM_BoosterGetLowerBoundValue(
        self.handle,
        ctypes.byref(bound))
    # Hand the native return code to the shared checker.
    _safe_call(status)
    return bound.value
def eval(self, data, name, feval=None):
"""Evaluate for data.

Просмотреть файл

@ -667,6 +667,22 @@ void GBDT::GetPredictAt(int data_idx, double* out_result, int64_t* out_len) {
}
}
double GBDT::GetUpperBoundValue() const {
  // Accumulate, in model order, the upper bound leaf value of every tree
  // held in models_; with no trees the result is 0.0.
  double bound_sum = 0.0;
  for (auto it = models_.begin(); it != models_.end(); ++it) {
    bound_sum += (*it)->GetUpperBoundValue();
  }
  return bound_sum;
}
double GBDT::GetLowerBoundValue() const {
  // Accumulate, in model order, the lower bound leaf value of every tree
  // held in models_; with no trees the result is 0.0.
  double bound_sum = 0.0;
  for (auto it = models_.begin(); it != models_.end(); ++it) {
    bound_sum += (*it)->GetLowerBoundValue();
  }
  return bound_sum;
}
void GBDT::ResetTrainingData(const Dataset* train_data, const ObjectiveFunction* objective_function,
const std::vector<const Metric*>& training_metrics) {
if (train_data != train_data_ && !train_data_->CheckAlign(*train_data)) {

Просмотреть файл

@ -295,6 +295,18 @@ class GBDT : public GBDTBase {
*/
std::vector<double> FeatureImportance(int num_iteration, int importance_type) const override;
/*!
* \brief Calculate upper bound value
* \return upper bound value: the sum, over all trees of this model, of each tree's upper bound leaf value
*/
double GetUpperBoundValue() const override;
/*!
* \brief Calculate lower bound value
* \return lower bound value: the sum, over all trees of this model, of each tree's lower bound leaf value
*/
double GetLowerBoundValue() const override;
/*!
* \brief Get max feature index of this model
* \return Max feature index of this model

Просмотреть файл

@ -468,6 +468,16 @@ class Booster {
return boosting_->FeatureImportance(num_iteration, importance_type);
}
/*! \brief Query the underlying boosting object for its upper bound value while holding mutex_ */
double UpperBoundValue() const {
  std::lock_guard<std::mutex> guard(mutex_);
  const double bound = boosting_->GetUpperBoundValue();
  return bound;
}
/*! \brief Query the underlying boosting object for its lower bound value while holding mutex_ */
double LowerBoundValue() const {
  std::lock_guard<std::mutex> guard(mutex_);
  const double bound = boosting_->GetLowerBoundValue();
  return bound;
}
double GetLeafValue(int tree_idx, int leaf_idx) const {
return dynamic_cast<GBDTBase*>(boosting_.get())->GetLeafValue(tree_idx, leaf_idx);
}
@ -526,7 +536,7 @@ class Booster {
/*! \brief Training objective function */
std::unique_ptr<ObjectiveFunction> objective_fun_;
/*! \brief mutex for threading safe call */
std::mutex mutex_;
mutable std::mutex mutex_;
};
} // namespace LightGBM
@ -1694,6 +1704,24 @@ int LGBM_BoosterFeatureImportance(BoosterHandle handle,
API_END();
}
int LGBM_BoosterGetUpperBoundValue(BoosterHandle handle,
                                   double* out_results) {
  API_BEGIN();
  // The opaque handle is reinterpreted as the Booster it refers to.
  const Booster* booster = reinterpret_cast<Booster*>(handle);
  // Store the booster's upper bound value in the caller-provided slot.
  *out_results = booster->UpperBoundValue();
  API_END();
}
int LGBM_BoosterGetLowerBoundValue(BoosterHandle handle,
                                   double* out_results) {
  API_BEGIN();
  // The opaque handle is reinterpreted as the Booster it refers to.
  const Booster* booster = reinterpret_cast<Booster*>(handle);
  // Store the booster's lower bound value in the caller-provided slot.
  *out_results = booster->LowerBoundValue();
  API_END();
}
int LGBM_NetworkInit(const char* machines,
int local_listen_port,
int listen_time_out,

Просмотреть файл

@ -206,6 +206,26 @@ void Tree::AddPredictionToScore(const Dataset* data,
#undef PredictionFun
double Tree::GetUpperBoundValue() const {
  // Start from the first leaf and keep the largest value seen while
  // scanning leaves 1 .. num_leaves_ - 1.
  double largest = leaf_value_[0];
  for (int leaf = 1; leaf < num_leaves_; ++leaf) {
    const double candidate = leaf_value_[leaf];
    largest = (candidate > largest) ? candidate : largest;
  }
  return largest;
}
double Tree::GetLowerBoundValue() const {
  // Start from the first leaf and keep the smallest value seen while
  // scanning leaves 1 .. num_leaves_ - 1.
  double smallest = leaf_value_[0];
  for (int leaf = 1; leaf < num_leaves_; ++leaf) {
    const double candidate = leaf_value_[leaf];
    smallest = (candidate < smallest) ? candidate : smallest;
  }
  return smallest;
}
std::string Tree::ToString() const {
std::stringstream str_buf;
str_buf << "num_leaves=" << num_leaves_ << '\n';

Просмотреть файл

@ -422,6 +422,24 @@ LGBM_SE LGBM_BoosterGetCurrentIteration_R(LGBM_SE handle,
R_API_END();
}
LGBM_SE LGBM_BoosterGetUpperBoundValue_R(LGBM_SE handle,
                                         LGBM_SE out_result,
                                         LGBM_SE call_state) {
  R_API_BEGIN();
  // Access the R-side result object as a double buffer for the C API to fill.
  double* bound_out = R_REAL_PTR(out_result);
  CHECK_CALL(LGBM_BoosterGetUpperBoundValue(R_GET_PTR(handle), bound_out));
  R_API_END();
}
LGBM_SE LGBM_BoosterGetLowerBoundValue_R(LGBM_SE handle,
                                         LGBM_SE out_result,
                                         LGBM_SE call_state) {
  R_API_BEGIN();
  // Access the R-side result object as a double buffer for the C API to fill.
  double* bound_out = R_REAL_PTR(out_result);
  CHECK_CALL(LGBM_BoosterGetLowerBoundValue(R_GET_PTR(handle), bound_out));
  R_API_END();
}
LGBM_SE LGBM_BoosterGetEvalNames_R(LGBM_SE handle,
LGBM_SE buf_len,
LGBM_SE actual_len,

Просмотреть файл

@ -26,7 +26,8 @@ class TestBasic(unittest.TestCase):
"num_leaves": 15,
"verbose": -1,
"num_threads": 1,
"max_bin": 255
"max_bin": 255,
"gpu_use_dp": True
}
bst = lgb.Booster(params, train_data)
bst.add_valid(valid_data, "valid_1")
@ -39,6 +40,8 @@ class TestBasic(unittest.TestCase):
self.assertEqual(bst.current_iteration(), 20)
self.assertEqual(bst.num_trees(), 20)
self.assertEqual(bst.num_model_per_iteration(), 1)
self.assertAlmostEqual(bst.lower_bound(), -2.9040190126976606)
self.assertAlmostEqual(bst.upper_bound(), 3.3182142872462883)
bst.save_model("model.txt")
pred_from_matr = bst.predict(X_test)