From 222775ca29f139afa1ef5a1c2d6e612707cdca12 Mon Sep 17 00:00:00 2001
From: Belinda Trotta
Date: Fri, 13 Dec 2019 07:27:27 +1100
Subject: [PATCH] auc-mu metric (#2567)

* Fix bug where small values of max_bin cause crash.
* Revert "Fix bug where small values of max_bin cause crash."
  This reverts commit fe5c8e2547057c1fa5750bcddd359dd7708fab4b.
* Add auc-mu multiclass metric.
* Fix bug where scores are equal.
* Merge.
* Change name to auc_mu everywhere (instead of auc-mu).
* Fix comparison between signed and unsigned int.
* Change name to AUC-mu in docs and output messages.
* Improve test.
* Use prefix increment.
* Update R package.
* Fix style issues.
* Tidy up test code.
* Read all lines first then process.
* Allow passing AUC-mu weights directly as a list in parameters.
* Remove unused code, improve example and docs.
---
 R-package/R/lgb.Booster.R                     |   2 +-
 docs/Parameters.rst                           |  14 ++
 examples/multiclass_classification/train.conf |  11 +-
 include/LightGBM/config.h                     |  16 ++-
 src/io/config.cpp                             |  33 +++++
 src/io/config_auto.cpp                        |   6 +
 src/metric/metric.cpp                         |   2 +
 src/metric/multiclass_metric.hpp              | 132 ++++++++++++++++++
 tests/python_package_test/test_engine.py      |  62 ++++++++
 9 files changed, 275 insertions(+), 3 deletions(-)

diff --git a/R-package/R/lgb.Booster.R b/R-package/R/lgb.Booster.R
index fe73bfbba..b2591edf8 100644
--- a/R-package/R/lgb.Booster.R
+++ b/R-package/R/lgb.Booster.R
@@ -563,7 +563,7 @@ Booster <- R6::R6Class(
       # Parse and store privately names
       names <- strsplit(names, "\t")[[1L]]
       private$eval_names <- names
-      private$higher_better_inner_eval <- grepl("^ndcg|^map|^auc$", names)
+      private$higher_better_inner_eval <- grepl("^ndcg|^map|^auc", names)
 
     }
 
diff --git a/docs/Parameters.rst b/docs/Parameters.rst
index d158ab6fb..b2ef341a5 100644
--- a/docs/Parameters.rst
+++ b/docs/Parameters.rst
@@ -881,6 +881,8 @@ Metric Parameters
 
    - ``binary_error``, for one sample: ``0`` for correct classification, ``1`` for error classification
 
+   - ``auc_mu``, `AUC-mu <http://proceedings.mlr.press/v97/kleiman19a/kleiman19a.pdf>`__
+
    - ``multi_logloss``, log loss for multi-class classification, aliases: ``multiclass``, ``softmax``, ``multiclassova``, ``multiclass_ova``, ``ova``, ``ovr``
 
    - ``multi_error``, error rate for multi-class classification
@@ -921,6 +923,18 @@ Metric Parameters
 
    - when ``multi_error_top_k=1`` this is equivalent to the usual multi-error metric
 
+- ``auc_mu_weights`` :raw-html:`🔗︎`, default = ``None``, type = multi-double
+
+   - used only with ``auc_mu`` metric
+
+   - list representing flattened matrix (in row-major order) giving loss weights for classification errors
+
+   - list should have ``n * n`` elements, where ``n`` is the number of classes
+
+   - the matrix co-ordinate ``[i, j]`` should correspond to the ``i * n + j``-th element of the list
+
+   - if not specified, will use equal weights for all classes
+
 Network Parameters
 ------------------
 
diff --git a/examples/multiclass_classification/train.conf b/examples/multiclass_classification/train.conf
index 3973630cf..67eb124bd 100644
--- a/examples/multiclass_classification/train.conf
+++ b/examples/multiclass_classification/train.conf
@@ -21,7 +21,16 @@ objective = multiclass
 # binary_error
 # multi_logloss
 # multi_error
-metric = multi_logloss
+# auc_mu
+metric = multi_logloss,auc_mu
+
+# AUC-mu weights; the matrix of loss weights below is passed in parameter auc_mu_weights as a list
+# 0 1 2 3 4
+# 5 0 6 7 8
+# 9 10 0 11 12
+# 13 14 15 0 16
+# 17 18 19 20 0
+auc_mu_weights = 0,1,2,3,4,5,0,6,7,8,9,10,0,11,12,13,14,15,0,16,17,18,19,20,0
 
 # number of class, for multiclass classification
 num_class = 5
diff --git a/include/LightGBM/config.h b/include/LightGBM/config.h
index 971684853..378fd51cf 100644
--- a/include/LightGBM/config.h
+++ b/include/LightGBM/config.h
@@ -773,6 +773,7 @@ struct Config {
   // descl2 = ``auc``, `AUC <https://en.wikipedia.org/wiki/Receiver_operating_characteristic#Area_under_the_curve>`__
   // descl2 = ``binary_logloss``, `log loss <https://en.wikipedia.org/wiki/Cross_entropy>`__, aliases: ``binary``
   // descl2 = ``binary_error``, for one sample: ``0`` for correct classification, ``1`` for error classification
+  // descl2 = ``auc_mu``, `AUC-mu <http://proceedings.mlr.press/v97/kleiman19a/kleiman19a.pdf>`__
   // descl2 = ``multi_logloss``, log loss for multi-class classification, aliases: ``multiclass``, ``softmax``, ``multiclassova``, ``multiclass_ova``, ``ova``, ``ovr``
   // descl2 = ``multi_error``, error rate for multi-class classification
   // descl2 = ``cross_entropy``, cross-entropy (with optional linear weights), aliases: ``xentropy``
@@ -806,6 +807,15 @@ struct Config {
   // desc = when ``multi_error_top_k=1`` this is equivalent to the usual multi-error metric
   int multi_error_top_k = 1;
 
+  // type = multi-double
+  // default = None
+  // desc = used only with ``auc_mu`` metric
+  // desc = list representing flattened matrix (in row-major order) giving loss weights for classification errors
+  // desc = list should have ``n * n`` elements, where ``n`` is the number of classes
+  // desc = the matrix co-ordinate ``[i, j]`` should correspond to the ``i * n + j``-th element of the list
+  // desc = if not specified, will use equal weights for all classes
+  std::vector<double> auc_mu_weights;
+
   #pragma endregion
 
   #pragma region Network Parameters
@@ -863,11 +873,13 @@ struct Config {
   LIGHTGBM_EXPORT void Set(const std::unordered_map<std::string, std::string>& params);
   static std::unordered_map<std::string, std::string> alias_table;
   static std::unordered_set<std::string> parameter_set;
+  std::vector<std::vector<double>> auc_mu_weights_matrix;
 
  private:
   void CheckParamConflict();
   void GetMembersFromString(const std::unordered_map<std::string, std::string>& params);
   std::string SaveMembersToString() const;
+  void GetAucMuWeights();
 };
 
 inline bool Config::GetString(
@@ -1013,6 +1025,8 @@ inline std::string ParseMetricAlias(const std::string& type) {
     return "kullback_leibler";
   } else if (type == std::string("mean_absolute_percentage_error") || type == std::string("mape")) {
     return "mape";
+  } else if (type == std::string("auc_mu")) {
+    return "auc_mu";
   } else if (type == std::string("none") || type == std::string("null") || type == std::string("custom") || type == std::string("na")) {
     return "custom";
   }
@@ -1021,4 +1035,4 @@ inline std::string ParseMetricAlias(const std::string& type) {
 
 }  // namespace LightGBM
 
-#endif   // LightGBM_CONFIG_H_
\ No newline at end of file
+#endif   // LightGBM_CONFIG_H_
diff --git a/src/io/config.cpp b/src/io/config.cpp
index fc61b13d6..eac1d7e2c 100644
--- a/src/io/config.cpp
+++ b/src/io/config.cpp
@@ -153,6 +153,36 @@ void GetTreeLearnerType(const std::unordered_map<std::string, std::string>& para
   }
 }
 
+void Config::GetAucMuWeights() {
+  if (auc_mu_weights.empty()) {
+    // equal weights for all classes
+    auc_mu_weights_matrix = std::vector<std::vector<double>>(num_class, std::vector<double>(num_class, 1));
+    for (size_t i = 0; i < static_cast<size_t>(num_class); ++i) {
+      auc_mu_weights_matrix[i][i] = 0;
+    }
+  } else {
+    auc_mu_weights_matrix = std::vector<std::vector<double>>(num_class, std::vector<double>(num_class, 0));
+    if (auc_mu_weights.size() != static_cast<size_t>(num_class * num_class)) {
+      Log::Fatal("auc_mu_weights must have %d elements, but found %d", num_class * num_class, auc_mu_weights.size());
+    }
+    for (size_t i = 0; i < static_cast<size_t>(num_class); ++i) {
+      for (size_t j = 0; j < static_cast<size_t>(num_class); ++j) {
+        if (i == j) {
+          auc_mu_weights_matrix[i][j] = 0;
+          if (std::fabs(auc_mu_weights[i * num_class + j]) > kZeroThreshold) {
+            Log::Info("AUC-mu matrix must have zeros on diagonal. Overwriting value in position %d of auc_mu_weights with 0.", i * num_class + j);
+          }
+        } else {
+          if (std::fabs(auc_mu_weights[i * num_class + j]) < kZeroThreshold) {
+            Log::Fatal("AUC-mu matrix must have non-zero values for non-diagonal entries. Found zero value in position %d of auc_mu_weights.", i * num_class + j);
+          }
+          auc_mu_weights_matrix[i][j] = auc_mu_weights[i * num_class + j];
+        }
+      }
+    }
+  }
+};
+
 void Config::Set(const std::unordered_map<std::string, std::string>& params) {
   // generate seeds by seed.
   if (GetInt(params, "seed", &seed)) {
@@ -173,6 +203,8 @@ void Config::Set(const std::unordered_map<std::string, std::string>& params) {
   GetMembersFromString(params);
 
+  GetAucMuWeights();
+
   // sort eval_at
   std::sort(eval_at.begin(), eval_at.end());
 
@@ -230,6 +262,7 @@ void Config::CheckParamConflict() {
   bool metric_type_multiclass = (CheckMultiClassObjective(metric_type)
                                  || metric_type == std::string("multi_logloss")
                                  || metric_type == std::string("multi_error")
+                                 || metric_type == std::string("auc_mu")
                                  || (metric_type == std::string("custom") && num_class_check > 1));
   if ((objective_type_multiclass && !metric_type_multiclass)
       || (!objective_type_multiclass && metric_type_multiclass)) {
diff --git a/src/io/config_auto.cpp b/src/io/config_auto.cpp
index 809b8a784..add6aed50 100644
--- a/src/io/config_auto.cpp
+++ b/src/io/config_auto.cpp
@@ -276,6 +276,7 @@ std::unordered_set<std::string> Config::parameter_set({
   "is_provide_training_metric",
   "eval_at",
   "multi_error_top_k",
+  "auc_mu_weights",
   "num_machines",
   "local_listen_port",
   "time_out",
@@ -561,6 +562,10 @@ void Config::GetMembersFromString(const std::unordered_map<std::string, std::string>& params) {
   GetInt(params, "multi_error_top_k", &multi_error_top_k);
   CHECK(multi_error_top_k >0);
 
+  if (GetString(params, "auc_mu_weights", &tmp_str)) {
+    auc_mu_weights = Common::StringToArray<double>(tmp_str, ',');
+  }
+
   GetInt(params, "num_machines", &num_machines);
   CHECK(num_machines >0);
 
@@ -683,6 +688,7 @@ std::string Config::SaveMembersToString() const {
   str_buf << "[is_provide_training_metric: " << is_provide_training_metric << "]\n";
   str_buf << "[eval_at: " << Common::Join(eval_at, ",") << "]\n";
   str_buf << "[multi_error_top_k: " << multi_error_top_k << "]\n";
+  str_buf << "[auc_mu_weights: " << Common::Join(auc_mu_weights, ",") << "]\n";
   str_buf << "[num_machines: " << num_machines << "]\n";
   str_buf << "[local_listen_port: " << local_listen_port << "]\n";
   str_buf << "[time_out: " << time_out << "]\n";
diff --git a/src/metric/metric.cpp b/src/metric/metric.cpp
index 715c78910..3ac70415d 100644
--- a/src/metric/metric.cpp
+++ b/src/metric/metric.cpp
@@ -34,6 +34,8 @@ Metric* Metric::CreateMetric(const std::string& type, const Config& config) {
     return new BinaryErrorMetric(config);
   } else if (type == std::string("auc")) {
     return new AUCMetric(config);
+  } else if (type == std::string("auc_mu")) {
+    return new AucMuMetric(config);
   } else if (type == std::string("ndcg")) {
     return new NDCGMetric(config);
   } else if (type == std::string("map")) {
diff --git a/src/metric/multiclass_metric.hpp b/src/metric/multiclass_metric.hpp
index 8cf92f67a..d8266ac11 100644
--- a/src/metric/multiclass_metric.hpp
+++ b/src/metric/multiclass_metric.hpp
@@ -178,5 +178,137 @@ class MultiSoftmaxLoglossMetric: public MulticlassMetric<MultiSoftmaxLoglossMetric> {
   }
 };
 
+/*! \brief AUC-mu for multiclass task */
+class AucMuMetric : public Metric {
+ public:
+  explicit AucMuMetric(const Config& config) : config_(config) {
+    num_class_ = config.num_class;
+    class_weights_ = config.auc_mu_weights_matrix;
+  }
+
+  virtual ~AucMuMetric() {}
+
+  const std::vector<std::string>& GetName() const override { return name_; }
+
+  double factor_to_bigger_better() const override { return 1.0f; }
+
+  void Init(const Metadata& metadata, data_size_t num_data) override {
+    name_.emplace_back("auc_mu");
+
+    num_data_ = num_data;
+    label_ = metadata.label();
+
+    // sort the data indices by true class
+    sorted_data_idx_ = std::vector<data_size_t>(num_data_, 0);
+    for (data_size_t i = 0; i < num_data_; ++i) {
+      sorted_data_idx_[i] = i;
+    }
+    Common::ParallelSort(sorted_data_idx_.begin(), sorted_data_idx_.end(),
+                         [this](data_size_t a, data_size_t b) { return label_[a] < label_[b]; });
+  }
+
+  std::vector<double> Eval(const double* score, const ObjectiveFunction*) const override {
+    // the notation follows that used in the paper introducing the auc-mu metric:
+    // http://proceedings.mlr.press/v97/kleiman19a/kleiman19a.pdf
+
+    // get size of each class
+    auto class_sizes = std::vector<data_size_t>(num_class_, 0);
+    for (data_size_t i = 0; i < num_data_; ++i) {
+      data_size_t curr_label = static_cast<data_size_t>(label_[i]);
+      ++class_sizes[curr_label];
+    }
+
+    auto S = std::vector<std::vector<double>>(num_class_, std::vector<double>(num_class_, 0));
+    int i_start = 0;
+    for (int i = 0; i < num_class_; ++i) {
+      int j_start = i_start + class_sizes[i];
+      for (int j = i + 1; j < num_class_; ++j) {
+        std::vector<double> curr_v;
+        for (int k = 0; k < num_class_; ++k) {
+          curr_v.emplace_back(class_weights_[i][k] - class_weights_[j][k]);
+        }
+        double t1 = curr_v[i] - curr_v[j];
+        // extract the data indices belonging to class i or j
+        std::vector<data_size_t> class_i_j_indices;
+        class_i_j_indices.assign(sorted_data_idx_.begin() + i_start, sorted_data_idx_.begin() + i_start + class_sizes[i]);
+        class_i_j_indices.insert(class_i_j_indices.end(),
+                                 sorted_data_idx_.begin() + j_start, sorted_data_idx_.begin() + j_start + class_sizes[j]);
+        // sort according to distance from separating hyperplane
+        std::vector<std::pair<data_size_t, double>> dist;
+        for (data_size_t k = 0; static_cast<size_t>(k) < class_i_j_indices.size(); ++k) {
+          data_size_t a = class_i_j_indices[k];
+          double v_a = 0;
+          for (int m = 0; m < num_class_; ++m) {
+            v_a += curr_v[m] * score[num_data_ * m + a];
+          }
+          dist.push_back(std::pair<data_size_t, double>(a, t1 * v_a));
+        }
+        Common::ParallelSort(dist.begin(), dist.end(),
+                             [this](std::pair<data_size_t, double> a, std::pair<data_size_t, double> b) {
+                               // if scores are equal, put j class first
+                               if (std::fabs(a.second - b.second) < kEpsilon) {
+                                 return label_[a.first] > label_[b.first];
+                               }
+                               else if (a.second < b.second) {
+                                 return true;
+                               } else {
+                                 return false;
+                               }
+                             });
+        // calculate auc
+        double num_j = 0;
+        double last_j_dist = 0;
+        double num_current_j = 0;
+        for (size_t k = 0; k < dist.size(); ++k) {
+          data_size_t a = dist[k].first;
+          double curr_dist = dist[k].second;
+          if (label_[a] == i) {
+            if (std::fabs(curr_dist - last_j_dist) < kEpsilon) {
+              S[i][j] += num_j - 0.5 * num_current_j;  // members of class j with same distance as a contribute 0.5
+            } else {
+              S[i][j] += num_j;
+            }
+          } else {
+            ++num_j;
+            if (std::fabs(curr_dist - last_j_dist) < kEpsilon) {
+              ++num_current_j;
+            } else {
+              last_j_dist = dist[k].second;
+              num_current_j = 1;
+            }
+          }
+        }
+        j_start += class_sizes[j];
+      }
+      i_start += class_sizes[i];
+    }
+
+    double ans = 0;
+    for (int i = 0; i < num_class_; ++i) {
+      for (int j = i + 1; j < num_class_; ++j) {
+        ans += S[i][j] / (class_sizes[i] * class_sizes[j]);
+      }
+    }
+    ans = 2 * ans / (num_class_ * (num_class_ - 1));
+    return std::vector<double>(1, ans);
+  }
+
+ private:
+  /*! \brief Number of data*/
+  data_size_t num_data_;
+  /*! \brief Pointer to label*/
+  const label_t* label_;
+  /*! \brief Name of this metric*/
+  std::vector<std::string> name_;
+  /*! \brief Number of classes*/
+  int num_class_;
+  /*! \brief class_weights*/
+  std::vector<std::vector<double>> class_weights_;
+  /*! \brief config parameters*/
+  Config config_;
+  /*! \brief index to data, sorted by true class*/
+  std::vector<data_size_t> sorted_data_idx_;
+};
+
 }  // namespace LightGBM
 #endif   // LightGBM_METRIC_MULTICLASS_METRIC_HPP_
diff --git a/tests/python_package_test/test_engine.py b/tests/python_package_test/test_engine.py
index fa20ba62b..68bd0a5b5 100644
--- a/tests/python_package_test/test_engine.py
+++ b/tests/python_package_test/test_engine.py
@@ -428,6 +428,68 @@ class TestEngine(unittest.TestCase):
                   valid_sets=[lgb_data], evals_result=results, verbose_eval=False)
         self.assertAlmostEqual(results['training']['multi_error@2'][-1], 0)
 
+    def test_auc_mu(self):
+        # should give same result as binary auc for 2 classes
+        X, y = load_digits(10, True)
+        y_new = np.zeros((len(y)))
+        y_new[y != 0] = 1
+        lgb_X = lgb.Dataset(X, label=y_new)
+        params = {'objective': 'multiclass',
+                  'metric': 'auc_mu',
+                  'verbose': -1,
+                  'num_classes': 2,
+                  'seed': 0}
+        results_auc_mu = {}
+        lgb.train(params, lgb_X, num_boost_round=10, valid_sets=[lgb_X], evals_result=results_auc_mu)
+        params = {'objective': 'binary',
+                  'metric': 'auc',
+                  'verbose': -1,
+                  'seed': 0}
+        results_auc = {}
+        lgb.train(params, lgb_X, num_boost_round=10, valid_sets=[lgb_X], evals_result=results_auc)
+        np.testing.assert_allclose(results_auc_mu['training']['auc_mu'], results_auc['training']['auc'])
+        # test the case where all predictions are equal
+        lgb_X = lgb.Dataset(X[:10], label=y_new[:10])
+        params = {'objective': 'multiclass',
+                  'metric': 'auc_mu',
+                  'verbose': -1,
+                  'num_classes': 2,
+                  'min_data_in_leaf': 20,
+                  'seed': 0}
+        results_auc_mu = {}
+        lgb.train(params, lgb_X, num_boost_round=10, valid_sets=[lgb_X], evals_result=results_auc_mu)
+        self.assertAlmostEqual(results_auc_mu['training']['auc_mu'][-1], 0.5)
+        # should give 1 when accuracy = 1
+        X = X[:10, :]
+        y = y[:10]
+        lgb_X = lgb.Dataset(X, label=y)
+        params = {'objective': 'multiclass',
+                  'metric': 'auc_mu',
+                  'num_classes': 10,
+                  'min_data_in_leaf': 1,
+                  'verbose': -1}
+        results = {}
+        lgb.train(params, lgb_X, num_boost_round=100, valid_sets=[lgb_X], evals_result=results)
+        self.assertAlmostEqual(results['training']['auc_mu'][-1], 1)
+        # test loading weights
+        Xy = np.loadtxt(os.path.join(os.path.dirname(os.path.realpath(__file__)),
+                                     '../../examples/multiclass_classification/multiclass.train'))
+        y = Xy[:, 0]
+        X = Xy[:, 1:]
+        lgb_X = lgb.Dataset(X, label=y)
+        params = {'objective': 'multiclass',
+                  'metric': 'auc_mu',
+                  'auc_mu_weights': [0, 2, 2, 2, 2, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0],
+                  'num_classes': 5,
+                  'verbose': -1,
+                  'seed': 0}
+        results_weight = {}
+        lgb.train(params, lgb_X, num_boost_round=5, valid_sets=[lgb_X], evals_result=results_weight)
+        params['auc_mu_weights'] = []
+        results_no_weight = {}
+        lgb.train(params, lgb_X, num_boost_round=5, valid_sets=[lgb_X], evals_result=results_no_weight)
+        self.assertNotEqual(results_weight['training']['auc_mu'][-1], results_no_weight['training']['auc_mu'][-1])
+
     def test_early_stopping(self):
         X, y = load_breast_cancer(True)
         params = {
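
For a quick sanity check of the new metric, the following is a rough NumPy sketch of the quantity AucMuMetric::Eval computes. It uses the plain O(n_i * n_j) pairwise definition from the Kleiman and Page paper rather than the sorted single pass in the patch, so its tie handling is the idealized pairwise version, not the kEpsilon scan above. The function name auc_mu_reference, the (n_samples, n_classes) score layout, and the assumption that every class appears in y_true are illustrative choices and not part of this patch.

# Rough reference sketch (not part of the patch): pairwise AUC-mu following
# Kleiman & Page (2019), http://proceedings.mlr.press/v97/kleiman19a/kleiman19a.pdf
# Assumes y_score has shape (n_samples, n_classes) and every class appears in y_true.
import numpy as np


def auc_mu_reference(y_true, y_score, weights=None):
    y_true = np.asarray(y_true, dtype=int)
    y_score = np.asarray(y_score, dtype=float)
    n_classes = y_score.shape[1]
    if weights is None:
        # default built by GetAucMuWeights(): equal weights, zeros on the diagonal
        weights = np.ones((n_classes, n_classes)) - np.eye(n_classes)
    weights = np.asarray(weights, dtype=float)

    total = 0.0
    for i in range(n_classes):
        for j in range(i + 1, n_classes):
            idx_i = np.flatnonzero(y_true == i)
            idx_j = np.flatnonzero(y_true == j)
            # direction from the difference of the two weight-matrix rows, as in Eval():
            # v = w_i - w_j, t1 = v[i] - v[j]
            v = weights[i] - weights[j]
            t1 = v[i] - v[j]
            d_i = t1 * (y_score[idx_i] @ v)  # decision values for class i samples
            d_j = t1 * (y_score[idx_j] @ v)  # decision values for class j samples
            # pairwise AUC of class i versus class j; tied pairs count one half
            diff = d_i[:, None] - d_j[None, :]
            ties = np.isclose(diff, 0.0)
            s_ij = ((diff > 0) & ~ties).sum() + 0.5 * ties.sum()
            total += s_ij / (len(idx_i) * len(idx_j))
    # average over all unordered class pairs
    return 2.0 * total / (n_classes * (n_classes - 1))

With two classes and the default weights this reduces to the ordinary binary AUC, which is exactly what the new test_auc_mu test checks against LightGBM's auc metric; in LightGBM itself the metric is enabled with metric=auc_mu together with num_class, and optionally auc_mu_weights given as the flattened n * n list shown in train.conf.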