зеркало из https://github.com/microsoft/LightGBM.git
ignore unknown parameters when loading from model file (#6126)
This commit is contained in:
Родитель
8f577de01f
Коммит
b793cd821c
|
@ -179,15 +179,20 @@ class GBDT : public GBDTBase {
|
||||||
const auto pair = Common::Split(line.c_str(), ":");
|
const auto pair = Common::Split(line.c_str(), ":");
|
||||||
if (pair[1] == " ]")
|
if (pair[1] == " ]")
|
||||||
continue;
|
continue;
|
||||||
|
const auto param = pair[0].substr(1);
|
||||||
|
const auto value_str = pair[1].substr(1, pair[1].size() - 2);
|
||||||
|
auto iter = param_types.find(param);
|
||||||
|
if (iter == param_types.end()) {
|
||||||
|
Log::Warning("Ignoring unrecognized parameter '%s' found in model string.", param.c_str());
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
std::string param_type = iter->second;
|
||||||
if (first) {
|
if (first) {
|
||||||
first = false;
|
first = false;
|
||||||
str_buf << "\"";
|
str_buf << "\"";
|
||||||
} else {
|
} else {
|
||||||
str_buf << ",\"";
|
str_buf << ",\"";
|
||||||
}
|
}
|
||||||
const auto param = pair[0].substr(1);
|
|
||||||
const auto value_str = pair[1].substr(1, pair[1].size() - 2);
|
|
||||||
const auto param_type = param_types.at(param);
|
|
||||||
str_buf << param << "\": ";
|
str_buf << param << "\": ";
|
||||||
if (param_type == "string") {
|
if (param_type == "string") {
|
||||||
str_buf << "\"" << value_str << "\"";
|
str_buf << "\"" << value_str << "\"";
|
||||||
|
|
|
@ -1470,7 +1470,7 @@ def test_feature_name_with_non_ascii():
|
||||||
assert feature_names == gbm2.feature_name()
|
assert feature_names == gbm2.feature_name()
|
||||||
|
|
||||||
|
|
||||||
def test_parameters_are_loaded_from_model_file(tmp_path):
|
def test_parameters_are_loaded_from_model_file(tmp_path, capsys):
|
||||||
X = np.hstack([np.random.rand(100, 1), np.random.randint(0, 5, (100, 2))])
|
X = np.hstack([np.random.rand(100, 1), np.random.randint(0, 5, (100, 2))])
|
||||||
y = np.random.rand(100)
|
y = np.random.rand(100)
|
||||||
ds = lgb.Dataset(X, y)
|
ds = lgb.Dataset(X, y)
|
||||||
|
@ -1487,8 +1487,18 @@ def test_parameters_are_loaded_from_model_file(tmp_path):
|
||||||
'num_threads': 1,
|
'num_threads': 1,
|
||||||
}
|
}
|
||||||
model_file = tmp_path / 'model.txt'
|
model_file = tmp_path / 'model.txt'
|
||||||
lgb.train(params, ds, num_boost_round=1, categorical_feature=[1, 2]).save_model(model_file)
|
orig_bst = lgb.train(params, ds, num_boost_round=1, categorical_feature=[1, 2])
|
||||||
|
orig_bst.save_model(model_file)
|
||||||
|
with model_file.open('rt') as f:
|
||||||
|
model_contents = f.readlines()
|
||||||
|
params_start = model_contents.index('parameters:\n')
|
||||||
|
model_contents.insert(params_start + 1, '[max_conflict_rate: 0]\n')
|
||||||
|
with model_file.open('wt') as f:
|
||||||
|
f.writelines(model_contents)
|
||||||
bst = lgb.Booster(model_file=model_file)
|
bst = lgb.Booster(model_file=model_file)
|
||||||
|
expected_msg = "[LightGBM] [Warning] Ignoring unrecognized parameter 'max_conflict_rate' found in model string."
|
||||||
|
stdout = capsys.readouterr().out
|
||||||
|
assert expected_msg in stdout
|
||||||
set_params = {k: bst.params[k] for k in params.keys()}
|
set_params = {k: bst.params[k] for k in params.keys()}
|
||||||
assert set_params == params
|
assert set_params == params
|
||||||
assert bst.params['categorical_feature'] == [1, 2]
|
assert bst.params['categorical_feature'] == [1, 2]
|
||||||
|
@ -1498,6 +1508,11 @@ def test_parameters_are_loaded_from_model_file(tmp_path):
|
||||||
bst2 = lgb.Booster(params={'num_leaves': 7}, model_file=model_file)
|
bst2 = lgb.Booster(params={'num_leaves': 7}, model_file=model_file)
|
||||||
assert bst.params == bst2.params
|
assert bst.params == bst2.params
|
||||||
|
|
||||||
|
# check inference isn't affected by unknown parameter
|
||||||
|
orig_preds = orig_bst.predict(X)
|
||||||
|
preds = bst.predict(X)
|
||||||
|
np.testing.assert_allclose(preds, orig_preds)
|
||||||
|
|
||||||
|
|
||||||
def test_save_load_copy_pickle():
|
def test_save_load_copy_pickle():
|
||||||
def train_and_predict(init_model=None, return_model=False):
|
def train_and_predict(init_model=None, return_model=False):
|
||||||
|
|
Загрузка…
Ссылка в новой задаче