From b793cd821c1d887a74b2d680d2d299f0275fb7a6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Morales?= Date: Fri, 6 Oct 2023 21:43:40 -0600 Subject: [PATCH] ignore unknown parameters when loading from model file (#6126) --- src/boosting/gbdt.h | 11 ++++++++--- tests/python_package_test/test_engine.py | 19 +++++++++++++++++-- 2 files changed, 25 insertions(+), 5 deletions(-) diff --git a/src/boosting/gbdt.h b/src/boosting/gbdt.h index e38b26be3..9ddd9c313 100644 --- a/src/boosting/gbdt.h +++ b/src/boosting/gbdt.h @@ -179,15 +179,20 @@ class GBDT : public GBDTBase { const auto pair = Common::Split(line.c_str(), ":"); if (pair[1] == " ]") continue; + const auto param = pair[0].substr(1); + const auto value_str = pair[1].substr(1, pair[1].size() - 2); + auto iter = param_types.find(param); + if (iter == param_types.end()) { + Log::Warning("Ignoring unrecognized parameter '%s' found in model string.", param.c_str()); + continue; + } + std::string param_type = iter->second; if (first) { first = false; str_buf << "\""; } else { str_buf << ",\""; } - const auto param = pair[0].substr(1); - const auto value_str = pair[1].substr(1, pair[1].size() - 2); - const auto param_type = param_types.at(param); str_buf << param << "\": "; if (param_type == "string") { str_buf << "\"" << value_str << "\""; diff --git a/tests/python_package_test/test_engine.py b/tests/python_package_test/test_engine.py index 2f592d43b..9a6341650 100644 --- a/tests/python_package_test/test_engine.py +++ b/tests/python_package_test/test_engine.py @@ -1470,7 +1470,7 @@ def test_feature_name_with_non_ascii(): assert feature_names == gbm2.feature_name() -def test_parameters_are_loaded_from_model_file(tmp_path): +def test_parameters_are_loaded_from_model_file(tmp_path, capsys): X = np.hstack([np.random.rand(100, 1), np.random.randint(0, 5, (100, 2))]) y = np.random.rand(100) ds = lgb.Dataset(X, y) @@ -1487,8 +1487,18 @@ def test_parameters_are_loaded_from_model_file(tmp_path): 'num_threads': 1, } model_file = tmp_path / 'model.txt' - lgb.train(params, ds, num_boost_round=1, categorical_feature=[1, 2]).save_model(model_file) + orig_bst = lgb.train(params, ds, num_boost_round=1, categorical_feature=[1, 2]) + orig_bst.save_model(model_file) + with model_file.open('rt') as f: + model_contents = f.readlines() + params_start = model_contents.index('parameters:\n') + model_contents.insert(params_start + 1, '[max_conflict_rate: 0]\n') + with model_file.open('wt') as f: + f.writelines(model_contents) bst = lgb.Booster(model_file=model_file) + expected_msg = "[LightGBM] [Warning] Ignoring unrecognized parameter 'max_conflict_rate' found in model string." + stdout = capsys.readouterr().out + assert expected_msg in stdout set_params = {k: bst.params[k] for k in params.keys()} assert set_params == params assert bst.params['categorical_feature'] == [1, 2] @@ -1498,6 +1508,11 @@ def test_parameters_are_loaded_from_model_file(tmp_path): bst2 = lgb.Booster(params={'num_leaves': 7}, model_file=model_file) assert bst.params == bst2.params + # check inference isn't affected by unknown parameter + orig_preds = orig_bst.predict(X) + preds = bst.predict(X) + np.testing.assert_allclose(preds, orig_preds) + def test_save_load_copy_pickle(): def train_and_predict(init_model=None, return_model=False):