зеркало из https://github.com/microsoft/LightGBM.git
Fix the tests for continued training: the number of models saved by default is now best_iteration
This commit is contained in:
Родитель
28972b8667
Коммит
f6024c8bd3
|
@ -17,9 +17,9 @@
|
|||
namespace LightGBM {
|
||||
|
||||
GBDT::GBDT()
|
||||
:num_iteration_for_pred_(0),
|
||||
:iter_(0),
|
||||
num_iteration_for_pred_(0),
|
||||
num_init_iteration_(0) {
|
||||
|
||||
}
|
||||
|
||||
GBDT::~GBDT() {
|
||||
|
@ -581,6 +581,7 @@ void GBDT::LoadModelFromString(const std::string& model_str) {
|
|||
Log::Info("Finished loading %d models", models_.size());
|
||||
num_iteration_for_pred_ = static_cast<int>(models_.size()) / num_class_;
|
||||
num_init_iteration_ = num_iteration_for_pred_;
|
||||
iter_ = 0;
|
||||
}
|
||||
|
||||
std::vector<std::pair<size_t, std::string>> GBDT::FeatureImportance() const {
|
||||
|
|
|
@ -9,8 +9,7 @@ import lightgbm as lgb
|
|||
class TestBasic(unittest.TestCase):
|
||||
|
||||
def test(self):
|
||||
X_train, X_test, y_train, y_test = train_test_split(*load_breast_cancer(True), test_size=0.1)
|
||||
|
||||
X_train, X_test, y_train, y_test = train_test_split(*load_breast_cancer(True), test_size=0.1, random_state=1)
|
||||
train_data = lgb.Dataset(X_train, max_bin=255, label=y_train)
|
||||
valid_data = train_data.create_valid(X_test, label=y_test)
|
||||
|
||||
|
|
|
@ -17,7 +17,7 @@ def multi_logloss(y_true, y_pred):
|
|||
def test_template(params = {'objective' : 'regression', 'metric' : 'l2'},
|
||||
X_y=load_boston(True), feval=mean_squared_error,
|
||||
num_round=100, init_model=None, custom_eval=None,
|
||||
return_data=False, return_model=False):
|
||||
return_data=False, return_model=False, early_stopping_rounds=10):
|
||||
X_train, X_test, y_train, y_test = train_test_split(*X_y, test_size=0.1, random_state=42)
|
||||
lgb_train = lgb.Dataset(X_train, y_train, params=params)
|
||||
lgb_eval = lgb.Dataset(X_test, y_test, reference=lgb_train, params=params)
|
||||
|
@ -31,7 +31,7 @@ def test_template(params = {'objective' : 'regression', 'metric' : 'l2'},
|
|||
verbose_eval=False,
|
||||
feval=custom_eval,
|
||||
evals_result=evals_result,
|
||||
early_stopping_rounds=10,
|
||||
early_stopping_rounds=early_stopping_rounds,
|
||||
init_model=init_model)
|
||||
if return_model: return gbm
|
||||
else: return evals_result, feval(y_test, gbm.predict(X_test, gbm.best_iteration))
|
||||
|
@ -71,7 +71,7 @@ class TestEngine(unittest.TestCase):
|
|||
'metric' : 'l1'
|
||||
}
|
||||
model_name = 'model.txt'
|
||||
gbm = test_template(params, num_round=20, return_model=True)
|
||||
gbm = test_template(params, num_round=20, return_model=True, early_stopping_rounds=-1)
|
||||
gbm.save_model(model_name)
|
||||
evals_result, ret = test_template(params, feval=mean_absolute_error,
|
||||
num_round=80, init_model=model_name,
|
||||
|
@ -91,7 +91,7 @@ class TestEngine(unittest.TestCase):
|
|||
'metric' : 'multi_logloss',
|
||||
'num_class' : 3
|
||||
}
|
||||
gbm = test_template(params, X_y, num_round=20, return_model=True)
|
||||
gbm = test_template(params, X_y, num_round=20, return_model=True, early_stopping_rounds=-1)
|
||||
evals_result, ret = test_template(params, X_y, feval=multi_logloss,
|
||||
num_round=80, init_model=gbm)
|
||||
self.assertLess(ret, 1.5)
|
||||
|
|
Загрузка…
Ссылка в новой задаче