Fix bug with interaction constraints (#3189)

* Fix bug: crashes when interaction_constraints is nonempty and not all features are used.

* Fix python lint error.
This commit is contained in:
Belinda Trotta 2020-06-28 16:46:06 +10:00 коммит произвёл GitHub
Родитель 7284946614
Коммит d563aff9a6
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
2 изменённых файлов: 11 добавлений и 2 удалений

Просмотреть файл

@ -117,7 +117,9 @@ class ColSampler {
} else {
for (int feat : allowed_features) {
int inner_feat = train_data_->InnerFeatureIndex(feat);
ret[inner_feat] = 1;
if (inner_feat >= 0) {
ret[inner_feat] = 1;
}
}
return ret;
}

Просмотреть файл

@ -2195,7 +2195,7 @@ class TestEngine(unittest.TestCase):
'seed': 0}
est = lgb.train(params, train_data, num_boost_round=10)
pred1 = est.predict(X)
est = lgb.train(dict(params, interation_constraints=[list(range(num_features))]), train_data,
est = lgb.train(dict(params, interaction_constraints=[list(range(num_features))]), train_data,
num_boost_round=10)
pred2 = est.predict(X)
np.testing.assert_allclose(pred1, pred2)
@ -2210,3 +2210,10 @@ class TestEngine(unittest.TestCase):
num_boost_round=10)
pred4 = est.predict(X)
self.assertLess(mean_squared_error(y, pred3), mean_squared_error(y, pred4))
# test that interaction constraints work when not all features are used
X = np.concatenate([np.zeros((X.shape[0], 1)), X], axis=1)
num_features = X.shape[1]
train_data = lgb.Dataset(X, label=y)
est = lgb.train(dict(params, interaction_constraints=[[0] + list(range(2, num_features)),
[1] + list(range(2, num_features))]),
train_data, num_boost_round=10)