ignore unknown parameters when loading from model file (#6126)

This commit is contained in:
José Morales 2023-10-06 21:43:40 -06:00 коммит произвёл GitHub
Родитель 8f577de01f
Коммит b793cd821c
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
2 изменённых файла: 25 добавлений и 5 удалений

Просмотреть файл

@@ -179,15 +179,20 @@ class GBDT : public GBDTBase {
const auto pair = Common::Split(line.c_str(), ":");
if (pair[1] == " ]")
continue;
const auto param = pair[0].substr(1);
const auto value_str = pair[1].substr(1, pair[1].size() - 2);
auto iter = param_types.find(param);
if (iter == param_types.end()) {
Log::Warning("Ignoring unrecognized parameter '%s' found in model string.", param.c_str());
continue;
}
std::string param_type = iter->second;
if (first) {
first = false;
str_buf << "\"";
} else {
str_buf << ",\"";
}
const auto param = pair[0].substr(1);
const auto value_str = pair[1].substr(1, pair[1].size() - 2);
const auto param_type = param_types.at(param);
str_buf << param << "\": ";
if (param_type == "string") {
str_buf << "\"" << value_str << "\"";

Просмотреть файл

@@ -1470,7 +1470,7 @@ def test_feature_name_with_non_ascii():
assert feature_names == gbm2.feature_name()
def test_parameters_are_loaded_from_model_file(tmp_path):
def test_parameters_are_loaded_from_model_file(tmp_path, capsys):
X = np.hstack([np.random.rand(100, 1), np.random.randint(0, 5, (100, 2))])
y = np.random.rand(100)
ds = lgb.Dataset(X, y)
@@ -1487,8 +1487,18 @@ def test_parameters_are_loaded_from_model_file(tmp_path):
'num_threads': 1,
}
model_file = tmp_path / 'model.txt'
lgb.train(params, ds, num_boost_round=1, categorical_feature=[1, 2]).save_model(model_file)
orig_bst = lgb.train(params, ds, num_boost_round=1, categorical_feature=[1, 2])
orig_bst.save_model(model_file)
with model_file.open('rt') as f:
model_contents = f.readlines()
params_start = model_contents.index('parameters:\n')
model_contents.insert(params_start + 1, '[max_conflict_rate: 0]\n')
with model_file.open('wt') as f:
f.writelines(model_contents)
bst = lgb.Booster(model_file=model_file)
expected_msg = "[LightGBM] [Warning] Ignoring unrecognized parameter 'max_conflict_rate' found in model string."
stdout = capsys.readouterr().out
assert expected_msg in stdout
set_params = {k: bst.params[k] for k in params.keys()}
assert set_params == params
assert bst.params['categorical_feature'] == [1, 2]
@@ -1498,6 +1508,11 @@ def test_parameters_are_loaded_from_model_file(tmp_path):
bst2 = lgb.Booster(params={'num_leaves': 7}, model_file=model_file)
assert bst.params == bst2.params
# check inference isn't affected by unknown parameter
orig_preds = orig_bst.predict(X)
preds = bst.predict(X)
np.testing.assert_allclose(preds, orig_preds)
def test_save_load_copy_pickle():
def train_and_predict(init_model=None, return_model=False):