[python-package] [R-package] include more params in model text representation (fixes #6010) (#6077)

James Lamb 2023-09-13 17:35:38 -05:00 committed by GitHub
Parent 163416d2f5
Commit ab1eaa832d
No known key found for this signature
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 389 additions and 34 deletions
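The substance of the change: the model text representation now records many more of the parameters a model was trained with, so they survive save/load and serialization roundtrips (see #6010). A minimal Python sketch of how to inspect those entries (synthetic data; the exact set of entries depends on the LightGBM version):

import numpy as np
import lightgbm as lgb

# train a small model, then pull the "[key: value]" parameter entries
# out of its text representation
X = np.random.rand(100, 5)
y = np.random.rand(100)
bst = lgb.train(
    params={"objective": "regression", "verbose": -1},
    train_set=lgb.Dataset(X, label=y),
    num_boost_round=3,
)
param_lines = [
    line
    for line in bst.model_to_string().splitlines()
    if line.startswith("[") and line.endswith("]")
]
# after this change, entries like "[use_quantized_grad: 0]" appear here
print(len(param_lines), param_lines[:5])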

View File

@@ -29,3 +29,20 @@
.LGB_VERBOSITY <- as.integer(
Sys.getenv("LIGHTGBM_TEST_VERBOSITY", "-1")
)
# [description]
# test that every element of 'x' is in 'y'
#
# testthat::expect_in() is not available in the versions of {testthat}
# built for R 3.6; this helper provides a similar interface on R 3.6
.expect_in <- function(x, y) {
if (exists("expect_in")) {
expect_in(x, y)
} else {
missing_items <- x[!(x %in% y)]
if (length(missing_items) != 0L) {
error_msg <- paste0("Some expected items not found: ", toString(missing_items))
stop(error_msg)
}
}
}
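# Usage sketch for the helper above (hypothetical values):
#   .expect_in(c("a", "b"), c("a", "b", "c"))  # passes
#   .expect_in(c("a", "z"), c("a", "b", "c"))  # fails with
#     "Some expected items not found: z"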

View File

@@ -799,37 +799,166 @@ test_that("all parameters are stored correctly with save_model_to_string()", {
data = matrix(rnorm(500L), nrow = 100L)
, label = rnorm(100L)
)
nrounds <- 4L
bst <- lgb.train(
params = list(
objective = "regression"
, metric = "l2"
objective = "mape"
, metric = c("l2", "mae")
, num_threads = .LGB_MAX_THREADS
, seed = 708L
, data_sample_strategy = "bagging"
, sub_row = 0.8234
)
, data = dtrain
, nrounds = nrounds
, nrounds = 3L
, verbose = .LGB_VERBOSITY
)
model_str <- bst$save_model_to_string()
params_in_file <- .params_from_model_string(model_str = model_str)
# entries whose values should reflect params passed to lgb.train()
non_default_param_entries <- c(
"[objective: mape]"
# 'l1' was passed in with alias 'mae'
, "[metric: l2,l1]"
, "[data_sample_strategy: bagging]"
, "[seed: 708]"
# this was passed in with alias 'sub_row'
, "[bagging_fraction: 0.8234]"
, "[num_iterations: 3]"
)
# entries with default values of params
default_param_entries <- c(
"[boosting: gbdt]"
, "[tree_learner: serial]"
, "[device_type: cpu]"
, "[data: ]"
, "[valid: ]"
, "[learning_rate: 0.1]"
, "[num_leaves: 31]"
, sprintf("[num_threads: %i]", .LGB_MAX_THREADS)
, "[deterministic: 0]"
, "[histogram_pool_size: -1]"
, "[max_depth: -1]"
, "[min_data_in_leaf: 20]"
, "[min_sum_hessian_in_leaf: 0.001]"
, "[pos_bagging_fraction: 1]"
, "[neg_bagging_fraction: 1]"
, "[bagging_freq: 0]"
, "[bagging_seed: 15415]"
, "[feature_fraction: 1]"
, "[feature_fraction_bynode: 1]"
, "[feature_fraction_seed: 32671]"
, "[extra_trees: 0]"
, "[extra_seed: 6642]"
, "[early_stopping_round: 0]"
, "[first_metric_only: 0]"
, "[max_delta_step: 0]"
, "[lambda_l1: 0]"
, "[lambda_l2: 0]"
, "[linear_lambda: 0]"
, "[min_gain_to_split: 0]"
, "[drop_rate: 0.1]"
, "[max_drop: 50]"
, "[skip_drop: 0.5]"
, "[xgboost_dart_mode: 0]"
, "[uniform_drop: 0]"
, "[drop_seed: 20623]"
, "[top_rate: 0.2]"
, "[other_rate: 0.1]"
, "[min_data_per_group: 100]"
, "[max_cat_threshold: 32]"
, "[cat_l2: 10]"
, "[cat_smooth: 10]"
, "[max_cat_to_onehot: 4]"
, "[top_k: 20]"
, "[monotone_constraints: ]"
, "[monotone_constraints_method: basic]"
, "[monotone_penalty: 0]"
, "[feature_contri: ]"
, "[forcedsplits_filename: ]"
, "[force_col_wise: 0]"
, "[force_row_wise: 0]"
, "[refit_decay_rate: 0.9]"
, "[cegb_tradeoff: 1]"
, "[cegb_penalty_split: 0]"
, "[cegb_penalty_feature_lazy: ]"
, "[cegb_penalty_feature_coupled: ]"
, "[path_smooth: 0]"
, "[interaction_constraints: ]"
, sprintf("[verbosity: %i]", .LGB_VERBOSITY)
, "[saved_feature_importance_type: 0]"
, "[use_quantized_grad: 0]"
, "[num_grad_quant_bins: 4]"
, "[quant_train_renew_leaf: 0]"
, "[stochastic_rounding: 1]"
, "[linear_tree: 0]"
, "[max_bin: 255]"
, "[max_bin_by_feature: ]"
, "[min_data_in_bin: 3]"
, "[bin_construct_sample_cnt: 200000]"
, "[data_random_seed: 2350]"
, "[is_enable_sparse: 1]"
, "[enable_bundle: 1]"
, "[use_missing: 1]"
, "[zero_as_missing: 0]"
, "[feature_pre_filter: 1]"
, "[pre_partition: 0]"
, "[two_round: 0]"
, "[header: 0]"
, "[label_column: ]"
, "[weight_column: ]"
, "[group_column: ]"
, "[ignore_column: ]"
, "[categorical_feature: ]"
, "[forcedbins_filename: ]"
, "[precise_float_parser: 0]"
, "[parser_config_file: ]"
, "[objective_seed: 4309]"
, "[num_class: 1]"
, "[is_unbalance: 0]"
, "[scale_pos_weight: 1]"
, "[sigmoid: 1]"
, "[boost_from_average: 1]"
, "[reg_sqrt: 0]"
, "[alpha: 0.9]"
, "[fair_c: 1]"
, "[poisson_max_delta_step: 0.7]"
, "[tweedie_variance_power: 1.5]"
, "[lambdarank_truncation_level: 30]"
, "[lambdarank_norm: 1]"
, "[label_gain: ]"
, "[lambdarank_position_bias_regularization: 0]"
, "[eval_at: ]"
, "[multi_error_top_k: 1]"
, "[auc_mu_weights: ]"
, "[num_machines: 1]"
, "[local_listen_port: 12400]"
, "[time_out: 120]"
, "[machine_list_filename: ]"
, "[machines: ]"
, "[gpu_platform_id: -1]"
, "[gpu_device_id: -1]"
, "[gpu_use_dp: 0]"
, "[num_gpu: 1]"
)
all_param_entries <- c(non_default_param_entries, default_param_entries)
# parameters should match what was passed from the R package
expect_equal(sum(startsWith(params_in_file, "[metric:")), 1L)
expect_equal(sum(params_in_file == "[metric: l2]"), 1L)
expect_equal(sum(startsWith(params_in_file, "[num_iterations:")), 1L)
expect_equal(sum(params_in_file == "[num_iterations: 4]"), 1L)
expect_equal(sum(startsWith(params_in_file, "[objective:")), 1L)
expect_equal(sum(params_in_file == "[objective: regression]"), 1L)
expect_equal(sum(startsWith(params_in_file, "[verbosity:")), 1L)
expect_equal(sum(params_in_file == sprintf("[verbosity: %i]", .LGB_VERBOSITY)), 1L)
model_str <- bst$save_model_to_string()
params_in_file <- .params_from_model_string(model_str = model_str)
.expect_in(all_param_entries, params_in_file)
# early stopping should be off by default
expect_equal(sum(startsWith(params_in_file, "[early_stopping_round:")), 1L)
expect_equal(sum(params_in_file == "[early_stopping_round: 0]"), 1L)
# since save_model_to_string() is used when serializing with saveRDS(), check that parameters all
# roundtrip saveRDS()/readRDS() successfully
rds_file <- tempfile()
saveRDS(bst, rds_file)
bst_rds <- readRDS(rds_file)
model_str <- bst_rds$save_model_to_string()
params_in_file <- .params_from_model_string(model_str = model_str)
.expect_in(all_param_entries, params_in_file)
})
test_that("early_stopping, num_iterations are stored correctly in model string even with aliases", {

View File

@@ -330,7 +330,7 @@ def gen_parameter_code(
str_to_write += ' std::string tmp_str = "";\n'
for x in infos:
for y in x:
if "[doc-only]" in y:
if "[no-automatically-extract]" in y:
continue
param_type = y["inner_type"][0]
name = y["name"][0]
@@ -345,7 +345,7 @@ def gen_parameter_code(
str_to_write += " std::stringstream str_buf;\n"
for x in infos:
for y in x:
if "[doc-only]" in y or "[no-save]" in y:
if "[no-save]" in y:
continue
param_type = y["inner_type"][0]
name = y["name"][0]
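To summarize the generator change: [no-automatically-extract] now controls whether a parameter gets auto-generated extraction code in Config::GetMembersFromString(), while [no-save] controls whether it is written out by Config::SaveMembersToString(). A simplified sketch of that filtering, assuming infos is the same nested structure of per-parameter dicts iterated above:

def params_to_extract(infos):
    # parameters that should get auto-generated extraction code
    return [y for x in infos for y in x if "[no-automatically-extract]" not in y]

def params_to_save(infos):
    # parameters that SaveMembersToString() should write out
    return [y for x in infos for y in x if "[no-save]" not in y]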

View File

@@ -5,8 +5,13 @@
* \note
* - desc and descl2 fields must be written in reStructuredText format;
* - nested sections can be placed only at the bottom of parent's section;
* - [doc-only] tag indicates that only documentation for this param should be generated and all other actions are performed manually;
* - [no-save] tag indicates that this param should not be saved into a model text representation.
* - [no-automatically-extract]
* - do not automatically extract this parameter into a Config property with the same name in Config::GetMembersFromString(). Use if:
* - specialized extraction logic for this param exists in Config::GetMembersFromString()
* - [no-save]
* - this param should not be saved into a model text representation via Config::SaveMembersToString(). Use if:
* - param is only used by the CLI (especially the "predict" and "convert_model" tasks)
* - param is related to LightGBM writing files (e.g. "output_model", "save_binary")
*/
#ifndef LIGHTGBM_CONFIG_H_
#define LIGHTGBM_CONFIG_H_
@@ -97,15 +102,15 @@ struct Config {
#pragma region Core Parameters
#endif // __NVCC__
// [no-automatically-extract]
// [no-save]
// [doc-only]
// alias = config_file
// desc = path of config file
// desc = **Note**: can be used only in CLI version
std::string config = "";
// [no-automatically-extract]
// [no-save]
// [doc-only]
// type = enum
// default = train
// options = train, predict, convert_model, refit
@@ -118,7 +123,8 @@ struct Config {
// desc = **Note**: can be used only in CLI version; for language-specific packages you can use the correspondent functions
TaskType task = TaskType::kTrain;
// [doc-only]
// [no-automatically-extract]
// [no-save]
// type = enum
// options = regression, regression_l1, huber, fair, poisson, quantile, mape, gamma, tweedie, binary, multiclass, multiclassova, cross_entropy, cross_entropy_lambda, lambdarank, rank_xendcg
// alias = objective_type, app, application, loss
@@ -150,7 +156,8 @@ struct Config {
// descl2 = label should be ``int`` type, and larger number represents the higher relevance (e.g. 0:bad, 1:fair, 2:good, 3:perfect)
std::string objective = "regression";
// [doc-only]
// [no-automatically-extract]
// [no-save]
// type = enum
// alias = boosting_type, boost
// options = gbdt, rf, dart
@@ -160,7 +167,7 @@ struct Config {
// descl2 = **Note**: internally, LightGBM uses ``gbdt`` mode for the first ``1 / learning_rate`` iterations
std::string boosting = "gbdt";
// [doc-only]
// [no-automatically-extract]
// type = enum
// options = bagging, goss
// desc = ``bagging``, Randomly Bagging Sampling
@@ -200,7 +207,8 @@ struct Config {
// desc = max number of leaves in one tree
int num_leaves = kDefaultNumLeaves;
// [doc-only]
// [no-automatically-extract]
// [no-save]
// type = enum
// options = serial, feature, data, voting
// alias = tree, tree_type, tree_learner_type
@@ -222,7 +230,8 @@ struct Config {
// desc = **Note**: please **don't** change this during training, especially when running multiple jobs simultaneously by external packages, otherwise it may cause undesirable errors
int num_threads = 0;
// [doc-only]
// [no-automatically-extract]
// [no-save]
// type = enum
// options = cpu, gpu, cuda
// alias = device
@@ -235,7 +244,7 @@ struct Config {
// desc = **Note**: refer to `Installation Guide <./Installation-Guide.rst#build-gpu-version>`__ to build LightGBM with GPU support
std::string device_type = "cpu";
// [doc-only]
// [no-automatically-extract]
// alias = random_seed, random_state
// default = None
// desc = this seed is used to generate other seeds, e.g. ``data_random_seed``, ``feature_fraction_seed``, etc.
@@ -593,7 +602,6 @@ struct Config {
// desc = **Note**: can be used only in CLI version
int snapshot_freq = -1;
// [no-save]
// desc = whether to use gradient quantization when training
// desc = enabling this will discretize (quantize) the gradients and hessians into bins of ``num_grad_quant_bins``
// desc = with quantized training, most arithmetics in the training process will be integer operations
@@ -602,21 +610,18 @@ struct Config {
// desc = *New in version 4.0.0*
bool use_quantized_grad = false;
// [no-save]
// desc = number of bins used to quantize gradients and hessians
// desc = with more bins, the quantized training will be closer to full precision training
// desc = **Note**: can be used only with ``device_type = cpu``
// desc = *New in 4.0.0*
int num_grad_quant_bins = 4;
// [no-save]
// desc = whether to renew the leaf values with original gradients when doing quantized training
// desc = renewing is very helpful for good quantized training accuracy for ranking objectives
// desc = **Note**: can be used only with ``device_type = cpu``
// desc = *New in 4.0.0*
bool quant_train_renew_leaf = false;
// [no-save]
// desc = whether to use stochastic rounding in gradient quantization
// desc = *New in 4.0.0*
bool stochastic_rounding = true;
@@ -976,7 +981,8 @@ struct Config {
#pragma region Metric Parameters
#endif // __NVCC__
// [doc-only]
// [no-automatically-extract]
// [no-save]
// alias = metrics, metric_types
// default = ""
// type = multi-enum

View File

@@ -664,12 +664,14 @@ void Config::GetMembersFromString(const std::unordered_map<std::string, std::str
std::string Config::SaveMembersToString() const {
std::stringstream str_buf;
str_buf << "[data_sample_strategy: " << data_sample_strategy << "]\n";
str_buf << "[data: " << data << "]\n";
str_buf << "[valid: " << Common::Join(valid, ",") << "]\n";
str_buf << "[num_iterations: " << num_iterations << "]\n";
str_buf << "[learning_rate: " << learning_rate << "]\n";
str_buf << "[num_leaves: " << num_leaves << "]\n";
str_buf << "[num_threads: " << num_threads << "]\n";
str_buf << "[seed: " << seed << "]\n";
str_buf << "[deterministic: " << deterministic << "]\n";
str_buf << "[force_col_wise: " << force_col_wise << "]\n";
str_buf << "[force_row_wise: " << force_row_wise << "]\n";
@@ -722,6 +724,10 @@ std::string Config::SaveMembersToString() const {
str_buf << "[interaction_constraints: " << interaction_constraints << "]\n";
str_buf << "[verbosity: " << verbosity << "]\n";
str_buf << "[saved_feature_importance_type: " << saved_feature_importance_type << "]\n";
str_buf << "[use_quantized_grad: " << use_quantized_grad << "]\n";
str_buf << "[num_grad_quant_bins: " << num_grad_quant_bins << "]\n";
str_buf << "[quant_train_renew_leaf: " << quant_train_renew_leaf << "]\n";
str_buf << "[stochastic_rounding: " << stochastic_rounding << "]\n";
str_buf << "[linear_tree: " << linear_tree << "]\n";
str_buf << "[max_bin: " << max_bin << "]\n";
str_buf << "[max_bin_by_feature: " << Common::Join(max_bin_by_feature, ",") << "]\n";

View File

@@ -1534,6 +1534,203 @@ def test_save_load_copy_pickle():
assert ret_origin == pytest.approx(ret)
def test_all_expected_params_are_written_out_to_model_text(tmp_path):
X, y = make_synthetic_regression()
params = {
'objective': 'mape',
'metric': ['l2', 'mae'],
'seed': 708,
'data_sample_strategy': 'bagging',
'sub_row': 0.8234,
'verbose': -1
}
dtrain = lgb.Dataset(data=X, label=y)
gbm = lgb.train(
params=params,
train_set=dtrain,
num_boost_round=3
)
model_txt_from_memory = gbm.model_to_string()
model_file = tmp_path / "out.model"
gbm.save_model(filename=model_file)
with open(model_file, "r") as f:
model_txt_from_file = f.read()
assert model_txt_from_memory == model_txt_from_file
# entries whose values should reflect params passed to lgb.train()
non_default_param_entries = [
"[objective: mape]",
# 'l1' was passed in with alias 'mae'
"[metric: l2,l1]",
"[data_sample_strategy: bagging]",
"[seed: 708]",
# NOTE: this was passed in with alias 'sub_row'
"[bagging_fraction: 0.8234]",
"[num_iterations: 3]",
]
# entries with default values of params
default_param_entries = [
"[boosting: gbdt]",
"[tree_learner: serial]",
"[data: ]",
"[valid: ]",
"[learning_rate: 0.1]",
"[num_leaves: 31]",
"[num_threads: 0]",
"[deterministic: 0]",
"[histogram_pool_size: -1]",
"[max_depth: -1]",
"[min_data_in_leaf: 20]",
"[min_sum_hessian_in_leaf: 0.001]",
"[pos_bagging_fraction: 1]",
"[neg_bagging_fraction: 1]",
"[bagging_freq: 0]",
"[bagging_seed: 15415]",
"[feature_fraction: 1]",
"[feature_fraction_bynode: 1]",
"[feature_fraction_seed: 32671]",
"[extra_trees: 0]",
"[extra_seed: 6642]",
"[early_stopping_round: 0]",
"[first_metric_only: 0]",
"[max_delta_step: 0]",
"[lambda_l1: 0]",
"[lambda_l2: 0]",
"[linear_lambda: 0]",
"[min_gain_to_split: 0]",
"[drop_rate: 0.1]",
"[max_drop: 50]",
"[skip_drop: 0.5]",
"[xgboost_dart_mode: 0]",
"[uniform_drop: 0]",
"[drop_seed: 20623]",
"[top_rate: 0.2]",
"[other_rate: 0.1]",
"[min_data_per_group: 100]",
"[max_cat_threshold: 32]",
"[cat_l2: 10]",
"[cat_smooth: 10]",
"[max_cat_to_onehot: 4]",
"[top_k: 20]",
"[monotone_constraints: ]",
"[monotone_constraints_method: basic]",
"[monotone_penalty: 0]",
"[feature_contri: ]",
"[forcedsplits_filename: ]",
"[refit_decay_rate: 0.9]",
"[cegb_tradeoff: 1]",
"[cegb_penalty_split: 0]",
"[cegb_penalty_feature_lazy: ]",
"[cegb_penalty_feature_coupled: ]",
"[path_smooth: 0]",
"[interaction_constraints: ]",
"[verbosity: -1]",
"[saved_feature_importance_type: 0]",
"[use_quantized_grad: 0]",
"[num_grad_quant_bins: 4]",
"[quant_train_renew_leaf: 0]",
"[stochastic_rounding: 1]",
"[linear_tree: 0]",
"[max_bin: 255]",
"[max_bin_by_feature: ]",
"[min_data_in_bin: 3]",
"[bin_construct_sample_cnt: 200000]",
"[data_random_seed: 2350]",
"[is_enable_sparse: 1]",
"[enable_bundle: 1]",
"[use_missing: 1]",
"[zero_as_missing: 0]",
"[feature_pre_filter: 1]",
"[pre_partition: 0]",
"[two_round: 0]",
"[header: 0]",
"[label_column: ]",
"[weight_column: ]",
"[group_column: ]",
"[ignore_column: ]",
"[categorical_feature: ]",
"[forcedbins_filename: ]",
"[precise_float_parser: 0]",
"[parser_config_file: ]",
"[objective_seed: 4309]",
"[num_class: 1]",
"[is_unbalance: 0]",
"[scale_pos_weight: 1]",
"[sigmoid: 1]",
"[boost_from_average: 1]",
"[reg_sqrt: 0]",
"[alpha: 0.9]",
"[fair_c: 1]",
"[poisson_max_delta_step: 0.7]",
"[tweedie_variance_power: 1.5]",
"[lambdarank_truncation_level: 30]",
"[lambdarank_norm: 1]",
"[label_gain: ]",
"[lambdarank_position_bias_regularization: 0]",
"[eval_at: ]",
"[multi_error_top_k: 1]",
"[auc_mu_weights: ]",
"[num_machines: 1]",
"[local_listen_port: 12400]",
"[time_out: 120]",
"[machine_list_filename: ]",
"[machines: ]",
"[gpu_platform_id: -1]",
"[gpu_device_id: -1]",
"[num_gpu: 1]",
]
all_param_entries = non_default_param_entries + default_param_entries
# add device-specific entries
#
# passed-in force_col_wise / force_row_wise parameters are ignored on CUDA and GPU builds...
# https://github.com/microsoft/LightGBM/blob/1d7ee63686272bceffd522284127573b511df6be/src/io/config.cpp#L375-L377
if getenv('TASK', '') == 'cuda':
device_entries = [
"[force_col_wise: 0]",
"[force_row_wise: 1]",
"[device_type: cuda]",
"[gpu_use_dp: 1]"
]
elif getenv('TASK', '') == 'gpu':
device_entries = [
"[force_col_wise: 1]",
"[force_row_wise: 0]",
"[device_type: gpu]",
"[gpu_use_dp: 0]"
]
else:
device_entries = [
"[force_col_wise: 0]",
"[force_row_wise: 0]",
"[device_type: cpu]",
"[gpu_use_dp: 0]"
]
all_param_entries += device_entries
# check that model text has all expected param entries
for param_str in all_param_entries:
assert param_str in model_txt_from_file
assert param_str in model_txt_from_memory
# since Booster.model_to_string() is used when pickling, check that parameters all
# roundtrip pickling successfully too
gbm_pkl = pickle_and_unpickle_object(gbm, serializer="joblib")
model_txt_from_memory = gbm_pkl.model_to_string()
model_file = tmp_path / "out-pkl.model"
gbm_pkl.save_model(filename=model_file)
with open(model_file, "r") as f:
model_txt_from_file = f.read()
for param_str in all_param_entries:
assert param_str in model_txt_from_file
assert param_str in model_txt_from_memory
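# NOTE: pickle_and_unpickle_object() is a helper defined elsewhere in this
# test suite; a minimal joblib-based stand-in (assumed behavior) would be:
#
#     import joblib
#
#     def pickle_and_unpickle_object(obj, serializer="joblib", path="obj.pkl"):
#         joblib.dump(obj, path)    # write the object to disk
#         return joblib.load(path)  # read it back and return the copy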
def test_pandas_categorical():
pd = pytest.importorskip("pandas")
np.random.seed(42) # sometimes there is no difference how cols are treated (cat or not cat)