Mirror of https://github.com/microsoft/LightGBM.git
simplify start_iteration param for predict in Python and some code cleanup for start_iteration (#3288)
* simplify start_iteration param for predict in Python and some code cleanup for start_iteration
* revert docs changes about the prediction result shape
This commit is contained in:
Parent: 97d5758fbb
Commit: 877d58fac7
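For context, a minimal sketch of the contract this commit settles on, on synthetic data (the training arguments are illustrative, not taken from the diff): start_iteration now defaults to 0 instead of None, and raw scores from disjoint iteration ranges that cover the whole model sum to the full raw prediction.

import lightgbm as lgb
import numpy as np

X = np.random.rand(500, 10)
y = np.random.rand(500)
bst = lgb.train({"objective": "regression", "verbose": -1},
                lgb.Dataset(X, label=y), num_boost_round=20)

full = bst.predict(X, raw_score=True)

# Raw scores are additive across trees, so two disjoint iteration
# ranges covering all 20 rounds reproduce the full prediction.
head = bst.predict(X, start_iteration=0, num_iteration=10, raw_score=True)
tail = bst.predict(X, start_iteration=10, num_iteration=10, raw_score=True)
np.testing.assert_allclose(full, head + tail)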
@@ -491,11 +491,11 @@ Booster <- R6::R6Class(
                        header = FALSE,
                        reshape = FALSE, ...) {

-      # Check if number of iteration is non existent
+      # Check if number of iteration is non existent
       if (is.null(num_iteration)) {
         num_iteration <- self$best_iter
       }
-      # Check if start iteration is non existent
+      # Check if start iteration is non existent
       if (is.null(start_iteration)) {
         start_iteration <- 0L
       }
@@ -38,9 +38,8 @@ test_that("start_iteration works correctly", {
     , label = train$label
     , num_leaves = 4L
     , learning_rate = 0.6
-    , nrounds = 100L
+    , nrounds = 50L
     , objective = "binary"
     , save_name = tempfile(fileext = ".model")
     , valids = list("test" = dtest)
     , early_stopping_rounds = 2L
   )
@@ -50,7 +49,7 @@ test_that("start_iteration works correctly", {
   pred2 <- rep(0.0, length(pred1))
   pred_contrib2 <- rep(0.0, length(pred2))
   step <- 11L
-  end_iter <- 99L
+  end_iter <- 49L
   if (bst$best_iter != -1L) {
     end_iter <- bst$best_iter - 1L
   }
@@ -2813,7 +2813,7 @@ class Booster(object):
                               default=json_default_with_numpy))
         return ret

-    def predict(self, data, start_iteration=None, num_iteration=None,
+    def predict(self, data, start_iteration=0, num_iteration=None,
                 raw_score=False, pred_leaf=False, pred_contrib=False,
                 data_has_header=False, is_reshape=True, **kwargs):
         """Make a prediction.
@@ -2823,14 +2823,14 @@ class Booster(object):
         data : string, numpy array, pandas DataFrame, H2O DataTable's Frame or scipy.sparse
             Data source for prediction.
             If string, it represents the path to txt file.
-        start_iteration : int or None, optional (default=None)
+        start_iteration : int, optional (default=0)
             Start index of the iteration to predict.
-            If None or <= 0, starts from the first iteration.
+            If <= 0, starts from the first iteration.
         num_iteration : int or None, optional (default=None)
-            Limit number of iterations in the prediction.
-            If None, if the best iteration exists and start_iteration is None or <= 0, the best iteration is used;
-            otherwise, all iterations from start_iteration are used.
-            If <= 0, all iterations from start_iteration are used (no limits).
+            Total number of iterations used in the prediction.
+            If None, if the best iteration exists and start_iteration <= 0, the best iteration is used;
+            otherwise, all iterations from ``start_iteration`` are used (no limits).
+            If <= 0, all iterations from ``start_iteration`` are used (no limits).
         raw_score : bool, optional (default=False)
             Whether to predict raw scores.
         pred_leaf : bool, optional (default=False)
@@ -2861,10 +2861,8 @@ class Booster(object):
             Can be sparse or a list of sparse objects (each element represents predictions for one class) for feature contributions (when ``pred_contrib=True``).
         """
         predictor = self._to_predictor(copy.deepcopy(kwargs))
-        if start_iteration is None or start_iteration < 0:
-            start_iteration = 0
         if num_iteration is None:
-            if start_iteration == 0:
+            if start_iteration <= 0:
                 num_iteration = self.best_iteration
             else:
                 num_iteration = -1
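The resolution above reads more directly as a standalone rule. A sketch (the helper name is hypothetical, and it assumes best_iteration is -1 when early stopping was not used, so -1 means "no limit"):

def resolve_num_iteration(start_iteration, num_iteration, best_iteration):
    # Hypothetical helper mirroring the logic above: a default predict()
    # (start_iteration <= 0, num_iteration=None) falls back to the best
    # iteration, while any explicit start_iteration > 0 uses all
    # remaining iterations from that point.
    if num_iteration is None:
        num_iteration = best_iteration if start_iteration <= 0 else -1
    return num_iteration

assert resolve_num_iteration(0, None, 37) == 37    # early-stopped model
assert resolve_num_iteration(10, None, 37) == -1   # explicit start: no limit
assert resolve_num_iteration(0, None, -1) == -1    # no early stopping: no limit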
@@ -612,7 +612,7 @@ class LGBMModel(_LGBMModelBase):
         del train_set, valid_sets
         return self

-    def predict(self, X, raw_score=False, start_iteration=None, num_iteration=None,
+    def predict(self, X, raw_score=False, start_iteration=0, num_iteration=None,
                 pred_leaf=False, pred_contrib=False, **kwargs):
         """Return the predicted value for each sample.

@@ -622,13 +622,14 @@ class LGBMModel(_LGBMModelBase):
             Input features matrix.
         raw_score : bool, optional (default=False)
             Whether to predict raw scores.
-        start_iteration : int or None, optional (default=None)
+        start_iteration : int, optional (default=0)
             Start index of the iteration to predict.
-            If None or <= 0, starts from the first iteration.
+            If <= 0, starts from the first iteration.
         num_iteration : int or None, optional (default=None)
-            Limit number of iterations in the prediction.
-            If None, if the best iteration exists, it is used; otherwise, all trees are used.
-            If <= 0, all trees are used (no limits).
+            Total number of iterations used in the prediction.
+            If None, if the best iteration exists and start_iteration <= 0, the best iteration is used;
+            otherwise, all iterations from ``start_iteration`` are used (no limits).
+            If <= 0, all iterations from ``start_iteration`` are used (no limits).
         pred_leaf : bool, optional (default=False)
             Whether to predict leaf index.
         pred_contrib : bool, optional (default=False)
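At the sklearn layer the simplification means start_iteration can simply be omitted. A hedged usage sketch on synthetic data (estimator choice and sizes are illustrative only):

import numpy as np
from lightgbm import LGBMRegressor

X = np.random.rand(200, 5)
y = np.random.rand(200)
model = LGBMRegressor(n_estimators=30).fit(X, y)

pred_all = model.predict(X)  # all fitted trees (or best iteration, if early stopping was used)
pred_window = model.predict(X, start_iteration=5, num_iteration=10)  # iterations 5..14 only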
@@ -835,7 +836,7 @@ class LGBMClassifier(LGBMModel, _LGBMClassifierBase):

     fit.__doc__ = LGBMModel.fit.__doc__

-    def predict(self, X, raw_score=False, start_iteration=None, num_iteration=None,
+    def predict(self, X, raw_score=False, start_iteration=0, num_iteration=None,
                 pred_leaf=False, pred_contrib=False, **kwargs):
         """Docstring is inherited from the LGBMModel."""
         result = self.predict_proba(X, raw_score, start_iteration, num_iteration,
@@ -848,7 +849,7 @@ class LGBMClassifier(LGBMModel, _LGBMClassifierBase):

     predict.__doc__ = LGBMModel.predict.__doc__

-    def predict_proba(self, X, raw_score=False, start_iteration=None, num_iteration=None,
+    def predict_proba(self, X, raw_score=False, start_iteration=0, num_iteration=None,
                       pred_leaf=False, pred_contrib=False, **kwargs):
         """Return the predicted probability for each class for each sample.

@@ -858,13 +859,14 @@ class LGBMClassifier(LGBMModel, _LGBMClassifierBase):
             Input features matrix.
         raw_score : bool, optional (default=False)
             Whether to predict raw scores.
-        start_iteration : int or None, optional (default=None)
+        start_iteration : int, optional (default=0)
             Start index of the iteration to predict.
-            If None or <= 0, starts from the first iteration.
+            If <= 0, starts from the first iteration.
         num_iteration : int or None, optional (default=None)
-            Limit number of iterations in the prediction.
-            If None, if the best iteration exists, it is used; otherwise, all trees are used.
-            If <= 0, all trees are used (no limits).
+            Total number of iterations used in the prediction.
+            If None, if the best iteration exists and start_iteration <= 0, the best iteration is used;
+            otherwise, all iterations from ``start_iteration`` are used (no limits).
+            If <= 0, all iterations from ``start_iteration`` are used (no limits).
         pred_leaf : bool, optional (default=False)
             Whether to predict leaf index.
         pred_contrib : bool, optional (default=False)
@@ -226,7 +226,6 @@ class Predictor {
                              data_size_t, const std::vector<std::string>& lines) {
     std::vector<std::pair<int, double>> oneline_features;
     std::vector<std::string> result_to_write(lines.size());
-    Log::Warning("before predict_fun_ is called");
     OMP_INIT_EX();
     #pragma omp parallel for schedule(static) firstprivate(oneline_features)
     for (data_size_t i = 0; i < static_cast<data_size_t>(lines.size()); ++i) {
@@ -241,7 +240,6 @@ class Predictor {
       result_to_write[i] = str_result;
       OMP_LOOP_EX_END();
     }
-    Log::Warning("after predict_fun_ is called");
     OMP_THROW_EX();
     for (data_size_t i = 0; i < static_cast<data_size_t>(result_to_write.size()); ++i) {
       writer->Write(result_to_write[i].c_str(), result_to_write[i].size());
@@ -78,7 +78,7 @@ void GBDT::PredictByMap(const std::unordered_map<int, double>& features, double*

 void GBDT::PredictLeafIndex(const double* features, double* output) const {
   int start_tree = start_iteration_for_pred_ * num_tree_per_iteration_;
-  int num_trees = num_iteration_for_pred_ * num_tree_per_iteration_;
+  int num_trees = num_iteration_for_pred_ * num_tree_per_iteration_;
   const auto* models_ptr = models_.data() + start_tree;
   for (int i = 0; i < num_trees; ++i) {
     output[i] = models_ptr[i]->PredictLeafIndex(features);
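The start_tree arithmetic matters because one boosting iteration holds num_tree_per_iteration_ trees (one per class in multiclass). Seen from the Python side, a sketch on synthetic 3-class data (shapes only; sizes are illustrative):

import numpy as np
import lightgbm as lgb

X = np.random.rand(300, 4)
y = np.random.randint(0, 3, size=300)
bst = lgb.train({"objective": "multiclass", "num_class": 3, "verbose": -1},
                lgb.Dataset(X, label=y), num_boost_round=20)

# Leaf indices for iterations 5..14: one column per tree, i.e.
# num_iteration * num_class columns, regardless of start_iteration.
leaves = bst.predict(X, start_iteration=5, num_iteration=10, pred_leaf=True)
assert leaves.shape == (300, 10 * 3)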
@@ -2321,7 +2321,7 @@ class TestEngine(unittest.TestCase):
            X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=42)
            train_data = lgb.Dataset(X_train, label=y_train)
            valid_data = lgb.Dataset(X_test, label=y_test)
-           booster = lgb.train(params, train_data, num_boost_round=100, early_stopping_rounds=early_stopping_rounds, valid_sets=[valid_data])
+           booster = lgb.train(params, train_data, num_boost_round=50, early_stopping_rounds=early_stopping_rounds, valid_sets=[valid_data])

            # test that the predict once with all iterations equals summed results with start_iteration and num_iteration
            all_pred = booster.predict(X, raw_score=True)
@@ -2330,17 +2330,15 @@ class TestEngine(unittest.TestCase):
            for step in steps:
                pred = np.zeros_like(all_pred)
                pred_contrib = np.zeros_like(all_pred_contrib)
-               for start_iter in range(0, 100, step):
-                   pred += booster.predict(X, num_iteration=step, start_iteration=start_iter, raw_score=True)
-                   pred_contrib += booster.predict(X, num_iteration=step, start_iteration=start_iter, pred_contrib=True)
+               for start_iter in range(0, 50, step):
+                   pred += booster.predict(X, start_iteration=start_iter, num_iteration=step, raw_score=True)
+                   pred_contrib += booster.predict(X, start_iteration=start_iter, num_iteration=step, pred_contrib=True)
                np.testing.assert_allclose(all_pred, pred)
                np.testing.assert_allclose(all_pred_contrib, pred_contrib)
            # test the case where start_iteration <= 0, and num_iteration is None
            pred1 = booster.predict(X, start_iteration=-1)
            pred2 = booster.predict(X, num_iteration=booster.best_iteration)
            pred3 = booster.predict(X, num_iteration=booster.best_iteration, start_iteration=0)
            np.testing.assert_allclose(pred1, pred2)
            np.testing.assert_allclose(pred1, pred3)

            # test the case where start_iteration > 0, and num_iteration <= 0
            pred4 = booster.predict(X, start_iteration=10, num_iteration=-1)
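The summation check above works because raw scores (and, column-wise, feature contributions) are additive across trees; transformed outputs are not, which is why the slices are predicted with raw_score=True. A small sketch of the distinction on synthetic binary data (sizes illustrative):

import numpy as np
import lightgbm as lgb

X = np.random.rand(400, 6)
y = np.random.randint(0, 2, size=400)
bst = lgb.train({"objective": "binary", "verbose": -1},
                lgb.Dataset(X, label=y), num_boost_round=20)

# Raw scores from two disjoint slices add up to the full raw prediction.
raw_sum = (bst.predict(X, start_iteration=0, num_iteration=10, raw_score=True)
           + bst.predict(X, start_iteration=10, num_iteration=10, raw_score=True))
np.testing.assert_allclose(bst.predict(X, raw_score=True), raw_sum)

# Sigmoid-transformed probabilities do NOT add up the same way.
prob_sum = (bst.predict(X, start_iteration=0, num_iteration=10)
            + bst.predict(X, start_iteration=10, num_iteration=10))
assert not np.allclose(bst.predict(X), prob_sum)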
@@ -2351,14 +2349,14 @@ class TestEngine(unittest.TestCase):

            # test the case where start_iteration > 0, and num_iteration <= 0, with pred_leaf=True
            pred4 = booster.predict(X, start_iteration=10, num_iteration=-1, pred_leaf=True)
-           pred5 = booster.predict(X, start_iteration=10, num_iteration=90, pred_leaf=True)
+           pred5 = booster.predict(X, start_iteration=10, num_iteration=40, pred_leaf=True)
            pred6 = booster.predict(X, start_iteration=10, num_iteration=0, pred_leaf=True)
            np.testing.assert_allclose(pred4, pred5)
            np.testing.assert_allclose(pred4, pred6)

            # test the case where start_iteration > 0, and num_iteration <= 0, with pred_contrib=True
            pred4 = booster.predict(X, start_iteration=10, num_iteration=-1, pred_contrib=True)
-           pred5 = booster.predict(X, start_iteration=10, num_iteration=90, pred_contrib=True)
+           pred5 = booster.predict(X, start_iteration=10, num_iteration=40, pred_contrib=True)
            pred6 = booster.predict(X, start_iteration=10, num_iteration=0, pred_contrib=True)
            np.testing.assert_allclose(pred4, pred5)
            np.testing.assert_allclose(pred4, pred6)
@@ -2373,7 +2371,7 @@ class TestEngine(unittest.TestCase):
        }
        # test both with and without early stopping
        inner_test(X, y, params, early_stopping_rounds=1)
-       inner_test(X, y, params, early_stopping_rounds=10)
+       inner_test(X, y, params, early_stopping_rounds=5)
        inner_test(X, y, params, early_stopping_rounds=None)

        # test for multi-class
@@ -2387,7 +2385,7 @@ class TestEngine(unittest.TestCase):
        }
        # test both with and without early stopping
        inner_test(X, y, params, early_stopping_rounds=1)
-       inner_test(X, y, params, early_stopping_rounds=10)
+       inner_test(X, y, params, early_stopping_rounds=5)
        inner_test(X, y, params, early_stopping_rounds=None)

        # test for binary
@@ -2400,5 +2398,5 @@ class TestEngine(unittest.TestCase):
        }
        # test both with and without early stopping
        inner_test(X, y, params, early_stopping_rounds=1)
-       inner_test(X, y, params, early_stopping_rounds=10)
+       inner_test(X, y, params, early_stopping_rounds=5)
        inner_test(X, y, params, early_stopping_rounds=None)