зеркало из https://github.com/microsoft/LightGBM.git
ignore unknown parameters when loading from model file (#6126)
This commit is contained in:
Родитель
8f577de01f
Коммит
b793cd821c
|
@ -179,15 +179,20 @@ class GBDT : public GBDTBase {
|
|||
const auto pair = Common::Split(line.c_str(), ":");
|
||||
if (pair[1] == " ]")
|
||||
continue;
|
||||
const auto param = pair[0].substr(1);
|
||||
const auto value_str = pair[1].substr(1, pair[1].size() - 2);
|
||||
auto iter = param_types.find(param);
|
||||
if (iter == param_types.end()) {
|
||||
Log::Warning("Ignoring unrecognized parameter '%s' found in model string.", param.c_str());
|
||||
continue;
|
||||
}
|
||||
std::string param_type = iter->second;
|
||||
if (first) {
|
||||
first = false;
|
||||
str_buf << "\"";
|
||||
} else {
|
||||
str_buf << ",\"";
|
||||
}
|
||||
const auto param = pair[0].substr(1);
|
||||
const auto value_str = pair[1].substr(1, pair[1].size() - 2);
|
||||
const auto param_type = param_types.at(param);
|
||||
str_buf << param << "\": ";
|
||||
if (param_type == "string") {
|
||||
str_buf << "\"" << value_str << "\"";
|
||||
|
|
|
@ -1470,7 +1470,7 @@ def test_feature_name_with_non_ascii():
|
|||
assert feature_names == gbm2.feature_name()
|
||||
|
||||
|
||||
def test_parameters_are_loaded_from_model_file(tmp_path):
|
||||
def test_parameters_are_loaded_from_model_file(tmp_path, capsys):
|
||||
X = np.hstack([np.random.rand(100, 1), np.random.randint(0, 5, (100, 2))])
|
||||
y = np.random.rand(100)
|
||||
ds = lgb.Dataset(X, y)
|
||||
|
@ -1487,8 +1487,18 @@ def test_parameters_are_loaded_from_model_file(tmp_path):
|
|||
'num_threads': 1,
|
||||
}
|
||||
model_file = tmp_path / 'model.txt'
|
||||
lgb.train(params, ds, num_boost_round=1, categorical_feature=[1, 2]).save_model(model_file)
|
||||
orig_bst = lgb.train(params, ds, num_boost_round=1, categorical_feature=[1, 2])
|
||||
orig_bst.save_model(model_file)
|
||||
with model_file.open('rt') as f:
|
||||
model_contents = f.readlines()
|
||||
params_start = model_contents.index('parameters:\n')
|
||||
model_contents.insert(params_start + 1, '[max_conflict_rate: 0]\n')
|
||||
with model_file.open('wt') as f:
|
||||
f.writelines(model_contents)
|
||||
bst = lgb.Booster(model_file=model_file)
|
||||
expected_msg = "[LightGBM] [Warning] Ignoring unrecognized parameter 'max_conflict_rate' found in model string."
|
||||
stdout = capsys.readouterr().out
|
||||
assert expected_msg in stdout
|
||||
set_params = {k: bst.params[k] for k in params.keys()}
|
||||
assert set_params == params
|
||||
assert bst.params['categorical_feature'] == [1, 2]
|
||||
|
@ -1498,6 +1508,11 @@ def test_parameters_are_loaded_from_model_file(tmp_path):
|
|||
bst2 = lgb.Booster(params={'num_leaves': 7}, model_file=model_file)
|
||||
assert bst.params == bst2.params
|
||||
|
||||
# check inference isn't affected by unknown parameter
|
||||
orig_preds = orig_bst.predict(X)
|
||||
preds = bst.predict(X)
|
||||
np.testing.assert_allclose(preds, orig_preds)
|
||||
|
||||
|
||||
def test_save_load_copy_pickle():
|
||||
def train_and_predict(init_model=None, return_model=False):
|
||||
|
|
Загрузка…
Ссылка в новой задаче