Mirror of https://github.com/microsoft/LightGBM.git
[python] remove `verbose_eval` argument of `train()` and `cv()` functions (#4878)
* remove `verbose_eval` argument
* update example Notebook
This commit is contained in:
Parent: 8066261899
Commit: 9f13a9c897
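For context, a minimal before/after sketch of the migration this commit requires (the data and parameters below are illustrative placeholders, not taken from the diff):

    import numpy as np
    import lightgbm as lgb

    # toy data, purely illustrative
    X = np.random.rand(100, 4)
    y = np.random.rand(100)
    train_set = lgb.Dataset(X, y)
    valid_set = lgb.Dataset(X, y, reference=train_set)
    params = {'objective': 'regression', 'verbose': -1}

    # before: lgb.train(params, train_set, valid_sets=[valid_set], verbose_eval=10)
    # after: the same logging is requested through a callback
    booster = lgb.train(params, train_set,
                        num_boost_round=20,
                        valid_sets=[valid_set],
                        callbacks=[lgb.log_evaluation(period=10)])  # log every 10 rounds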
@@ -149,7 +149,7 @@
     "                feature_name=[f'f{i + 1}' for i in range(X_train.shape[-1])],\n",
     "                categorical_feature=[21],\n",
     "                evals_result=evals_result,\n",
-    "                verbose_eval=10)"
+    "                callbacks=[lgb.log_evaluation(10)])"
    ]
   },
   {

@@ -43,7 +43,7 @@ gbm = lgb.train(params,
                 feature_name=[f'f{i + 1}' for i in range(X_train.shape[-1])],
                 categorical_feature=[21],
                 evals_result=evals_result,
-                verbose_eval=10)
+                callbacks=[lgb.log_evaluation(10)])
 
 print('Plotting metrics recorded during training...')
 ax = lgb.plot_metric(evals_result, metric='l1')

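The example above still collects metrics through `evals_result=`, which a later hunk in this commit's engine.py context marks as deprecated. The callback-based equivalent for feeding `plot_metric()` would use `record_evaluation()`; a sketch with the toy setup from the first example (names assumed, not taken from the example file):

    evals_result = {}  # filled in by the callback during training
    gbm = lgb.train(params, train_set,
                    num_boost_round=20,
                    valid_sets=[valid_set],
                    callbacks=[lgb.log_evaluation(10),
                               lgb.record_evaluation(evals_result)])
    ax = lgb.plot_metric(evals_result, metric='l2')  # default metric for regression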
@@ -35,7 +35,6 @@ def train(
     categorical_feature: Union[List[str], List[int], str] = 'auto',
     early_stopping_rounds: Optional[int] = None,
     evals_result: Optional[Dict[str, Any]] = None,
-    verbose_eval: Union[bool, int, str] = 'warn',
     keep_training_booster: bool = False,
     callbacks: Optional[List[Callable]] = None
 ) -> Booster:

@@ -133,17 +132,6 @@ def train(
         returns {'train': {'logloss': ['0.48253', '0.35953', ...]},
         'eval': {'logloss': ['0.480385', '0.357756', ...]}}.
 
-    verbose_eval : bool or int, optional (default=True)
-        Requires at least one validation data.
-        If True, the eval metric on the valid set is printed at each boosting stage.
-        If int, the eval metric on the valid set is printed at every ``verbose_eval`` boosting stage.
-        The last boosting stage or the boosting stage found by using ``early_stopping_rounds`` is also printed.
-
-        .. rubric:: Example
-
-        With ``verbose_eval`` = 4 and at least one item in ``valid_sets``,
-        an evaluation metric is printed every 4 (instead of 1) boosting stages.
-
     keep_training_booster : bool, optional (default=False)
         Whether the returned Booster will be used to keep training.
         If False, the returned value will be converted into _InnerPredictor before returning.

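The behaviour documented in the removed `verbose_eval` docstring carries over to the `log_evaluation()` callback; a sketch reusing the toy setup above:

    # equivalent of the removed ``verbose_eval = 4``: the eval metric on each
    # validation set is printed every 4 boosting stages
    booster = lgb.train(params, train_set,
                        num_boost_round=20,
                        valid_sets=[valid_set],
                        callbacks=[lgb.log_evaluation(period=4)])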
@@ -230,21 +218,8 @@ def train(
     callbacks_set = set(callbacks)
 
     # Most of legacy advanced options becomes callbacks
-    if verbose_eval != "warn":
-        _log_warning("'verbose_eval' argument is deprecated and will be removed in a future release of LightGBM. "
-                     "Pass 'log_evaluation()' callback via 'callbacks' argument instead.")
-    else:
-        if callbacks_set:  # assume user has already specified log_evaluation callback
-            verbose_eval = False
-        else:
-            verbose_eval = True
-    if verbose_eval is True:
-        callbacks_set.add(callback.log_evaluation())
-    elif isinstance(verbose_eval, int):
-        callbacks_set.add(callback.log_evaluation(verbose_eval))
-
     if early_stopping_rounds is not None and early_stopping_rounds > 0:
-        callbacks_set.add(callback.early_stopping(early_stopping_rounds, first_metric_only, verbose=bool(verbose_eval)))
+        callbacks_set.add(callback.early_stopping(early_stopping_rounds, first_metric_only))
 
     if evals_result is not None:
         _log_warning("'evals_result' argument is deprecated and will be removed in a future release of LightGBM. "

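Because `train()` no longer derives the early-stopping callback's verbosity from `verbose_eval`, callers who want silent early stopping can construct the callback themselves; a sketch, again with the toy setup from above:

    # early stopping configured via callback; verbose=False suppresses the
    # "best iteration" message that bool(verbose_eval) used to control
    booster = lgb.train(params, train_set,
                        num_boost_round=100,
                        valid_sets=[valid_set],
                        callbacks=[lgb.early_stopping(stopping_rounds=5, verbose=False)])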
@@ -426,8 +401,7 @@ def cv(params, train_set, num_boost_round=100,
        metrics=None, fobj=None, feval=None, init_model=None,
        feature_name='auto', categorical_feature='auto',
        early_stopping_rounds=None, fpreproc=None,
-       verbose_eval=None, show_stdv=True, seed=0,
-       callbacks=None, eval_train_metric=False,
+       seed=0, callbacks=None, eval_train_metric=False,
        return_cvbooster=False):
     """Perform the cross-validation with given parameters.
 

@@ -522,13 +496,6 @@ def cv(params, train_set, num_boost_round=100,
     fpreproc : callable or None, optional (default=None)
         Preprocessing function that takes (dtrain, dtest, params)
         and returns transformed versions of those.
-    verbose_eval : bool, int, or None, optional (default=None)
-        Whether to display the progress.
-        If True, progress will be displayed at every boosting stage.
-        If int, progress will be displayed at every given ``verbose_eval`` boosting stage.
-    show_stdv : bool, optional (default=True)
-        Whether to display the standard deviation in progress.
-        Results are not affected by this parameter, and always contain std.
     seed : int, optional (default=0)
         Seed used to generate the folds (passed to numpy.random.seed).
     callbacks : list of callable, or None, optional (default=None)

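Both removed `cv()` parameters map onto `log_evaluation()`, which accepts its own `show_stdv` argument; a sketch with the toy data from above:

    # replaces ``verbose_eval=5, show_stdv=True`` from the old signature
    cv_results = lgb.cv(params, train_set,
                        num_boost_round=20,
                        nfold=3,
                        callbacks=[lgb.log_evaluation(period=5, show_stdv=True)])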
@@ -606,13 +573,6 @@ def cv(params, train_set, num_boost_round=100,
     callbacks = set(callbacks)
     if early_stopping_rounds is not None and early_stopping_rounds > 0:
         callbacks.add(callback.early_stopping(early_stopping_rounds, first_metric_only, verbose=False))
-    if verbose_eval is not None:
-        _log_warning("'verbose_eval' argument is deprecated and will be removed in a future release of LightGBM. "
-                     "Pass 'log_evaluation()' callback via 'callbacks' argument instead.")
-    if verbose_eval is True:
-        callbacks.add(callback.log_evaluation(show_stdv=show_stdv))
-    elif isinstance(verbose_eval, int):
-        callbacks.add(callback.log_evaluation(verbose_eval, show_stdv=show_stdv))
 
 callbacks_before_iter = {cb for cb in callbacks if getattr(cb, 'before_iteration', False)}
 callbacks_after_iter = callbacks - callbacks_before_iter

@@ -65,7 +65,6 @@ def test_binary():
     gbm = lgb.train(params, lgb_train,
                     num_boost_round=20,
                     valid_sets=lgb_eval,
-                    verbose_eval=False,
                     evals_result=evals_result)
     ret = log_loss(y_test, gbm.predict(X_test))
     assert ret < 0.14

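The test hunks below all simply drop `verbose_eval=False`: with the argument removed, `train()` and `cv()` no longer install a default `log_evaluation()` callback, so runs stay silent unless a callback is passed explicitly.
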
@@ -92,7 +91,6 @@ def test_rf():
     gbm = lgb.train(params, lgb_train,
                     num_boost_round=50,
                     valid_sets=lgb_eval,
-                    verbose_eval=False,
                     evals_result=evals_result)
     ret = log_loss(y_test, gbm.predict(X_test))
     assert ret < 0.19

@@ -112,7 +110,6 @@ def test_regression():
     gbm = lgb.train(params, lgb_train,
                     num_boost_round=50,
                     valid_sets=lgb_eval,
-                    verbose_eval=False,
                     evals_result=evals_result)
     ret = mean_squared_error(y_test, gbm.predict(X_test))
     assert ret < 7

@@ -138,7 +135,6 @@ def test_missing_value_handle():
     gbm = lgb.train(params, lgb_train,
                     num_boost_round=20,
                     valid_sets=lgb_eval,
-                    verbose_eval=False,
                     evals_result=evals_result)
     ret = mean_squared_error(y_train, gbm.predict(X_train))
     assert ret < 0.005

@@ -164,7 +160,6 @@ def test_missing_value_handle_more_na():
     gbm = lgb.train(params, lgb_train,
                     num_boost_round=20,
                     valid_sets=lgb_eval,
-                    verbose_eval=False,
                     evals_result=evals_result)
     ret = mean_squared_error(y_train, gbm.predict(X_train))
     assert ret < 0.005

@@ -195,7 +190,6 @@ def test_missing_value_handle_na():
     gbm = lgb.train(params, lgb_train,
                     num_boost_round=1,
                     valid_sets=lgb_eval,
-                    verbose_eval=False,
                     evals_result=evals_result)
     pred = gbm.predict(X_train)
     np.testing.assert_allclose(pred, y)

@@ -228,7 +222,6 @@ def test_missing_value_handle_zero():
     gbm = lgb.train(params, lgb_train,
                     num_boost_round=1,
                     valid_sets=lgb_eval,
-                    verbose_eval=False,
                     evals_result=evals_result)
     pred = gbm.predict(X_train)
     np.testing.assert_allclose(pred, y)

@@ -261,7 +254,6 @@ def test_missing_value_handle_none():
     gbm = lgb.train(params, lgb_train,
                     num_boost_round=1,
                     valid_sets=lgb_eval,
-                    verbose_eval=False,
                     evals_result=evals_result)
     pred = gbm.predict(X_train)
     assert pred[0] == pytest.approx(pred[1])

@@ -300,7 +292,6 @@ def test_categorical_handle():
     gbm = lgb.train(params, lgb_train,
                     num_boost_round=1,
                     valid_sets=lgb_eval,
-                    verbose_eval=False,
                     evals_result=evals_result)
     pred = gbm.predict(X_train)
     np.testing.assert_allclose(pred, y)

@@ -338,7 +329,6 @@ def test_categorical_handle_na():
     gbm = lgb.train(params, lgb_train,
                     num_boost_round=1,
                     valid_sets=lgb_eval,
-                    verbose_eval=False,
                     evals_result=evals_result)
     pred = gbm.predict(X_train)
     np.testing.assert_allclose(pred, y)

@@ -376,7 +366,6 @@ def test_categorical_non_zero_inputs():
     gbm = lgb.train(params, lgb_train,
                     num_boost_round=1,
                     valid_sets=lgb_eval,
-                    verbose_eval=False,
                     evals_result=evals_result)
     pred = gbm.predict(X_train)
     np.testing.assert_allclose(pred, y)

@@ -400,7 +389,6 @@ def test_multiclass():
     gbm = lgb.train(params, lgb_train,
                     num_boost_round=50,
                     valid_sets=lgb_eval,
-                    verbose_eval=False,
                     evals_result=evals_result)
     ret = multi_logloss(y_test, gbm.predict(X_test))
     assert ret < 0.16

@@ -429,7 +417,6 @@ def test_multiclass_rf():
     gbm = lgb.train(params, lgb_train,
                     num_boost_round=50,
                     valid_sets=lgb_eval,
-                    verbose_eval=False,
                     evals_result=evals_result)
     ret = multi_logloss(y_test, gbm.predict(X_test))
     assert ret < 0.23

@@ -470,7 +457,7 @@ def test_multi_class_error():
     predict_default = est.predict(X)
     results = {}
     est = lgb.train(dict(params, multi_error_top_k=1), lgb_data, num_boost_round=10,
-                    valid_sets=[lgb_data], evals_result=results, verbose_eval=False)
+                    valid_sets=[lgb_data], evals_result=results)
     predict_1 = est.predict(X)
     # check that default gives same result as k = 1
     np.testing.assert_allclose(predict_1, predict_default)

@@ -480,14 +467,14 @@ def test_multi_class_error():
     # check against independent calculation for k = 2
     results = {}
     est = lgb.train(dict(params, multi_error_top_k=2), lgb_data, num_boost_round=10,
-                    valid_sets=[lgb_data], evals_result=results, verbose_eval=False)
+                    valid_sets=[lgb_data], evals_result=results)
     predict_2 = est.predict(X)
     err = top_k_error(y, predict_2, 2)
     assert results['training']['multi_error@2'][-1] == pytest.approx(err)
     # check against independent calculation for k = 10
     results = {}
     est = lgb.train(dict(params, multi_error_top_k=10), lgb_data, num_boost_round=10,
-                    valid_sets=[lgb_data], evals_result=results, verbose_eval=False)
+                    valid_sets=[lgb_data], evals_result=results)
     predict_3 = est.predict(X)
     err = top_k_error(y, predict_3, 10)
     assert results['training']['multi_error@10'][-1] == pytest.approx(err)

@@ -498,11 +485,11 @@ def test_multi_class_error():
     params['num_classes'] = 2
     results = {}
     lgb.train(params, lgb_data, num_boost_round=10,
-              valid_sets=[lgb_data], evals_result=results, verbose_eval=False)
+              valid_sets=[lgb_data], evals_result=results)
     assert results['training']['multi_error'][-1] == pytest.approx(1)
     results = {}
     lgb.train(dict(params, multi_error_top_k=2), lgb_data, num_boost_round=10,
-              valid_sets=[lgb_data], evals_result=results, verbose_eval=False)
+              valid_sets=[lgb_data], evals_result=results)
     assert results['training']['multi_error@2'][-1] == pytest.approx(0)
 
 

@@ -626,7 +613,6 @@ def test_early_stopping():
                     num_boost_round=10,
                     valid_sets=lgb_eval,
                     valid_names=valid_set_name,
-                    verbose_eval=False,
                     early_stopping_rounds=5)
     assert gbm.best_iteration == 10
     assert valid_set_name in gbm.best_score

@@ -636,7 +622,6 @@ def test_early_stopping():
                     num_boost_round=40,
                     valid_sets=lgb_eval,
                     valid_names=valid_set_name,
-                    verbose_eval=False,
                     early_stopping_rounds=5)
     assert gbm.best_iteration <= 39
     assert valid_set_name in gbm.best_score

@@ -735,7 +720,6 @@ def test_continue_train():
     gbm = lgb.train(params, lgb_train,
                     num_boost_round=30,
                     valid_sets=lgb_eval,
-                    verbose_eval=False,
                     # test custom eval metrics
                     feval=(lambda p, d: ('custom_mae', mean_absolute_error(p, d.get_label()), False)),
                     evals_result=evals_result,

@@ -776,7 +760,6 @@ def test_continue_train_dart():
     gbm = lgb.train(params, lgb_train,
                     num_boost_round=50,
                     valid_sets=lgb_eval,
-                    verbose_eval=False,
                     evals_result=evals_result,
                     init_model=init_gbm)
     ret = mean_absolute_error(y_test, gbm.predict(X_test))

@@ -800,7 +783,6 @@ def test_continue_train_multiclass():
     gbm = lgb.train(params, lgb_train,
                     num_boost_round=30,
                     valid_sets=lgb_eval,
-                    verbose_eval=False,
                     evals_result=evals_result,
                     init_model=init_gbm)
     ret = multi_logloss(y_test, gbm.predict(X_test))

@@ -815,21 +797,20 @@ def test_cv():
     # shuffle = False, override metric in params
     params_with_metric = {'metric': 'l2', 'verbose': -1}
     cv_res = lgb.cv(params_with_metric, lgb_train, num_boost_round=10,
-                    nfold=3, stratified=False, shuffle=False,
-                    metrics='l1', verbose_eval=False)
+                    nfold=3, stratified=False, shuffle=False, metrics='l1')
     assert 'l1-mean' in cv_res
     assert 'l2-mean' not in cv_res
     assert len(cv_res['l1-mean']) == 10
     # shuffle = True, callbacks
-    cv_res = lgb.cv(params, lgb_train, num_boost_round=10, nfold=3, stratified=False, shuffle=True,
-                    metrics='l1', verbose_eval=False,
+    cv_res = lgb.cv(params, lgb_train, num_boost_round=10, nfold=3,
+                    stratified=False, shuffle=True, metrics='l1',
                     callbacks=[lgb.reset_parameter(learning_rate=lambda i: 0.1 - 0.001 * i)])
     assert 'l1-mean' in cv_res
     assert len(cv_res['l1-mean']) == 10
     # enable display training loss
     cv_res = lgb.cv(params_with_metric, lgb_train, num_boost_round=10,
                     nfold=3, stratified=False, shuffle=False,
-                    metrics='l1', verbose_eval=False, eval_train_metric=True)
+                    metrics='l1', eval_train_metric=True)
     assert 'train l1-mean' in cv_res
     assert 'valid l1-mean' in cv_res
     assert 'train l2-mean' not in cv_res

@@ -839,10 +820,8 @@ def test_cv():
     # self defined folds
     tss = TimeSeriesSplit(3)
     folds = tss.split(X_train)
-    cv_res_gen = lgb.cv(params_with_metric, lgb_train, num_boost_round=10, folds=folds,
-                        verbose_eval=False)
-    cv_res_obj = lgb.cv(params_with_metric, lgb_train, num_boost_round=10, folds=tss,
-                        verbose_eval=False)
+    cv_res_gen = lgb.cv(params_with_metric, lgb_train, num_boost_round=10, folds=folds)
+    cv_res_obj = lgb.cv(params_with_metric, lgb_train, num_boost_round=10, folds=tss)
     np.testing.assert_allclose(cv_res_gen['l2-mean'], cv_res_obj['l2-mean'])
     # LambdaRank
     rank_example_dir = Path(__file__).absolute().parents[2] / 'examples' / 'lambdarank'

@@ -851,19 +830,16 @@ def test_cv():
     params_lambdarank = {'objective': 'lambdarank', 'verbose': -1, 'eval_at': 3}
     lgb_train = lgb.Dataset(X_train, y_train, group=q_train)
     # ... with l2 metric
-    cv_res_lambda = lgb.cv(params_lambdarank, lgb_train, num_boost_round=10, nfold=3,
-                           metrics='l2', verbose_eval=False)
+    cv_res_lambda = lgb.cv(params_lambdarank, lgb_train, num_boost_round=10, nfold=3, metrics='l2')
     assert len(cv_res_lambda) == 2
     assert not np.isnan(cv_res_lambda['l2-mean']).any()
     # ... with NDCG (default) metric
-    cv_res_lambda = lgb.cv(params_lambdarank, lgb_train, num_boost_round=10, nfold=3,
-                           verbose_eval=False)
+    cv_res_lambda = lgb.cv(params_lambdarank, lgb_train, num_boost_round=10, nfold=3)
     assert len(cv_res_lambda) == 2
     assert not np.isnan(cv_res_lambda['ndcg@3-mean']).any()
     # self defined folds with lambdarank
     cv_res_lambda_obj = lgb.cv(params_lambdarank, lgb_train, num_boost_round=10,
-                               folds=GroupKFold(n_splits=3),
-                               verbose_eval=False)
+                               folds=GroupKFold(n_splits=3))
     np.testing.assert_allclose(cv_res_lambda['ndcg@3-mean'], cv_res_lambda_obj['ndcg@3-mean'])
 
 

@@ -880,7 +856,6 @@ def test_cvbooster():
     cv_res = lgb.cv(params, lgb_train,
                     num_boost_round=25,
                     early_stopping_rounds=5,
-                    verbose_eval=False,
                     nfold=3,
                     return_cvbooster=True)
     assert 'cvbooster' in cv_res

@@ -901,7 +876,6 @@ def test_cvbooster():
     # without early stopping
     cv_res = lgb.cv(params, lgb_train,
                     num_boost_round=20,
-                    verbose_eval=False,
                     nfold=3,
                     return_cvbooster=True)
     cvb = cv_res['cvbooster']

@@ -1099,7 +1073,7 @@ def test_reference_chain():
     evals_result = {}
     lgb.train(params, tmp_dat_train, num_boost_round=20,
               valid_sets=[tmp_dat_train, tmp_dat_val],
-              verbose_eval=False, evals_result=evals_result)
+              evals_result=evals_result)
     assert len(evals_result['training']['rmse']) == 20
     assert len(evals_result['valid_1']['rmse']) == 20
 

@@ -1706,14 +1680,14 @@ def test_metrics():
     params_metric_none_verbose = {'metric': 'None', 'verbose': -1}
 
     def get_cv_result(params=params_obj_verbose, **kwargs):
-        return lgb.cv(params, lgb_train, num_boost_round=2, verbose_eval=False, **kwargs)
+        return lgb.cv(params, lgb_train, num_boost_round=2, **kwargs)
 
     def train_booster(params=params_obj_verbose, **kwargs):
         lgb.train(params, lgb_train,
                   num_boost_round=2,
                   valid_sets=[lgb_valid],
                   evals_result=evals_result,
-                  verbose_eval=False, **kwargs)
+                  **kwargs)
 
     # no fobj, no feval
     # default metric

@@ -2248,7 +2222,7 @@ def test_early_stopping_for_only_first_metric():
         }
         gbm = lgb.train(dict(params, first_metric_only=first_metric_only), lgb_train,
                         num_boost_round=25, valid_sets=valid_sets, feval=feval,
-                        early_stopping_rounds=5, verbose_eval=False)
+                        early_stopping_rounds=5)
         assert assumed_iteration == gbm.best_iteration
 
     def metrics_combination_cv_regression(metric_list, assumed_iteration,

@@ -2265,7 +2239,7 @@ def test_early_stopping_for_only_first_metric():
         ret = lgb.cv(dict(params, first_metric_only=first_metric_only),
                      train_set=lgb_train, num_boost_round=25,
                      stratified=False, feval=feval,
-                     early_stopping_rounds=5, verbose_eval=False,
+                     early_stopping_rounds=5,
                      eval_train_metric=eval_train_metric)
         assert assumed_iteration == len(ret[list(ret.keys())[0]])
 

@@ -2363,7 +2337,6 @@ def test_node_level_subcol():
     gbm = lgb.train(params, lgb_train,
                     num_boost_round=25,
                     valid_sets=lgb_eval,
-                    verbose_eval=False,
                     evals_result=evals_result)
     ret = log_loss(y_test, gbm.predict(X_test))
     assert ret < 0.14

@@ -198,8 +198,7 @@ def test_plot_metrics(params, breast_cancer_split, train_data):
                       valid_sets=[train_data, test_data],
                       valid_names=['v1', 'v2'],
                       num_boost_round=10,
-                      evals_result=evals_result0,
-                      verbose_eval=False)
+                      evals_result=evals_result0)
     with pytest.warns(UserWarning, match="More than one metric available, picking one to plot."):
         ax0 = lgb.plot_metric(evals_result0)
     assert isinstance(ax0, matplotlib.axes.Axes)

@@ -259,8 +258,7 @@ def test_plot_metrics(params, breast_cancer_split, train_data):
     evals_result1 = {}
     lgb.train(params, train_data,
               num_boost_round=10,
-              evals_result=evals_result1,
-              verbose_eval=False)
+              evals_result=evals_result1)
     with pytest.raises(ValueError, match="eval results cannot be empty."):
         lgb.plot_metric(evals_result1)
 