зеркало из https://github.com/microsoft/LightGBM.git
auc-mu metric (#2567)
* Fix bug where small values of max_bin cause crash.
* Revert "Fix bug where small values of max_bin cause crash."
This reverts commit fe5c8e2547
.
* Add auc-mu multiclass metric.
* Fix bug where scores are equal.
* Merge.
* Change name to auc_mu everywhere (instead of auc-mu).
* Fix comparison between signed and unsigned int.
* Change name to AUC-mu in docs and output messages.
* Improve test.
* Use prefix increment.
* Update R package.
* Fix style issues.
* Tidy up test code.
* Read all lines first then process.
* Allow passing AUC-mu weights directly as a list in parameters.
* Remove unused code, improve example and docs.
This commit is contained in:
Родитель
9fd378e2bc
Коммит
222775ca29
|
@ -563,7 +563,7 @@ Booster <- R6::R6Class(
|
|||
# Parse and store privately names
|
||||
names <- strsplit(names, "\t")[[1L]]
|
||||
private$eval_names <- names
|
||||
private$higher_better_inner_eval <- grepl("^ndcg|^map|^auc$", names)
|
||||
private$higher_better_inner_eval <- grepl("^ndcg|^map|^auc", names)
|
||||
|
||||
}
|
||||
|
||||
|
|
|
@ -881,6 +881,8 @@ Metric Parameters
|
|||
|
||||
- ``binary_error``, for one sample: ``0`` for correct classification, ``1`` for error classification
|
||||
|
||||
- ``auc_mu``, `AUC-mu <http://proceedings.mlr.press/v97/kleiman19a/kleiman19a.pdf>`__
|
||||
|
||||
- ``multi_logloss``, log loss for multi-class classification, aliases: ``multiclass``, ``softmax``, ``multiclassova``, ``multiclass_ova``, ``ova``, ``ovr``
|
||||
|
||||
- ``multi_error``, error rate for multi-class classification
|
||||
|
@ -921,6 +923,18 @@ Metric Parameters
|
|||
|
||||
- when ``multi_error_top_k=1`` this is equivalent to the usual multi-error metric
|
||||
|
||||
- ``auc_mu_weights`` :raw-html:`<a id="auc_mu_weights" title="Permalink to this parameter" href="#auc_mu_weights">🔗︎</a>`, default = ``None``, type = multi-double
|
||||
|
||||
- used only with ``auc_mu`` metric
|
||||
|
||||
- list representing flattened matrix (in row-major order) giving loss weights for classification errors
|
||||
|
||||
- list should have ``n * n`` elements, where ``n`` is the number of classes
|
||||
|
||||
- the matrix co-ordinate ``[i, j]`` should correspond to the ``i * n + j``-th element of the list
|
||||
|
||||
- if not specified, will use equal weights for all classes
|
||||
|
||||
Network Parameters
|
||||
------------------
|
||||
|
||||
|
|
|
@ -21,7 +21,16 @@ objective = multiclass
|
|||
# binary_error
|
||||
# multi_logloss
|
||||
# multi_error
|
||||
metric = multi_logloss
|
||||
# auc_mu
|
||||
metric = multi_logloss,auc_mu
|
||||
|
||||
# AUC-mu weights; the matrix of loss weights below is passed in parameter auc_mu_weights as a list
|
||||
# 0 1 2 3 4
|
||||
# 5 0 6 7 8
|
||||
# 9 10 0 11 12
|
||||
# 13 14 15 0 16
|
||||
# 17 18 19 20 0
|
||||
auc_mu_weights = 0,1,2,3,4,5,0,6,7,8,9,10,0,11,12,13,14,15,0,16,17,18,19,20,0
|
||||
|
||||
# number of class, for multiclass classification
|
||||
num_class = 5
|
||||
|
|
|
@ -773,6 +773,7 @@ struct Config {
|
|||
// descl2 = ``auc``, `AUC <https://en.wikipedia.org/wiki/Receiver_operating_characteristic#Area_under_the_curve>`__
|
||||
// descl2 = ``binary_logloss``, `log loss <https://en.wikipedia.org/wiki/Cross_entropy>`__, aliases: ``binary``
|
||||
// descl2 = ``binary_error``, for one sample: ``0`` for correct classification, ``1`` for error classification
|
||||
// descl2 = ``auc_mu``, `AUC-mu <http://proceedings.mlr.press/v97/kleiman19a/kleiman19a.pdf>`__
|
||||
// descl2 = ``multi_logloss``, log loss for multi-class classification, aliases: ``multiclass``, ``softmax``, ``multiclassova``, ``multiclass_ova``, ``ova``, ``ovr``
|
||||
// descl2 = ``multi_error``, error rate for multi-class classification
|
||||
// descl2 = ``cross_entropy``, cross-entropy (with optional linear weights), aliases: ``xentropy``
|
||||
|
@ -806,6 +807,15 @@ struct Config {
|
|||
// desc = when ``multi_error_top_k=1`` this is equivalent to the usual multi-error metric
|
||||
int multi_error_top_k = 1;
|
||||
|
||||
// type = multi-double
|
||||
// default = None
|
||||
// desc = used only with ``auc_mu`` metric
|
||||
// desc = list representing flattened matrix (in row-major order) giving loss weights for classification errors
|
||||
// desc = list should have ``n * n`` elements, where ``n`` is the number of classes
|
||||
// desc = the matrix co-ordinate ``[i, j]`` should correspond to the ``i * n + j``-th element of the list
|
||||
// desc = if not specified, will use equal weights for all classes
|
||||
std::vector<double> auc_mu_weights;
|
||||
|
||||
#pragma endregion
|
||||
|
||||
#pragma region Network Parameters
|
||||
|
@ -863,11 +873,13 @@ struct Config {
|
|||
LIGHTGBM_EXPORT void Set(const std::unordered_map<std::string, std::string>& params);
|
||||
static std::unordered_map<std::string, std::string> alias_table;
|
||||
static std::unordered_set<std::string> parameter_set;
|
||||
std::vector<std::vector<double>> auc_mu_weights_matrix;
|
||||
|
||||
private:
|
||||
void CheckParamConflict();
|
||||
void GetMembersFromString(const std::unordered_map<std::string, std::string>& params);
|
||||
std::string SaveMembersToString() const;
|
||||
void GetAucMuWeights();
|
||||
};
|
||||
|
||||
inline bool Config::GetString(
|
||||
|
@ -1013,6 +1025,8 @@ inline std::string ParseMetricAlias(const std::string& type) {
|
|||
return "kullback_leibler";
|
||||
} else if (type == std::string("mean_absolute_percentage_error") || type == std::string("mape")) {
|
||||
return "mape";
|
||||
} else if (type == std::string("auc_mu")) {
|
||||
return "auc_mu";
|
||||
} else if (type == std::string("none") || type == std::string("null") || type == std::string("custom") || type == std::string("na")) {
|
||||
return "custom";
|
||||
}
|
||||
|
@ -1021,4 +1035,4 @@ inline std::string ParseMetricAlias(const std::string& type) {
|
|||
|
||||
} // namespace LightGBM
|
||||
|
||||
#endif // LightGBM_CONFIG_H_
|
||||
#endif // LightGBM_CONFIG_H_
|
||||
|
|
|
@ -153,6 +153,36 @@ void GetTreeLearnerType(const std::unordered_map<std::string, std::string>& para
|
|||
}
|
||||
}
|
||||
|
||||
void Config::GetAucMuWeights() {
|
||||
if (auc_mu_weights.empty()) {
|
||||
// equal weights for all classes
|
||||
auc_mu_weights_matrix = std::vector<std::vector<double>> (num_class, std::vector<double>(num_class, 1));
|
||||
for (size_t i = 0; i < static_cast<size_t>(num_class); ++i) {
|
||||
auc_mu_weights_matrix[i][i] = 0;
|
||||
}
|
||||
} else {
|
||||
auc_mu_weights_matrix = std::vector<std::vector<double>> (num_class, std::vector<double>(num_class, 0));
|
||||
if (auc_mu_weights.size() != static_cast<size_t>(num_class * num_class)) {
|
||||
Log::Fatal("auc_mu_weights must have %d elements, but found %d", num_class * num_class, auc_mu_weights.size());
|
||||
}
|
||||
for (size_t i = 0; i < static_cast<size_t>(num_class); ++i) {
|
||||
for (size_t j = 0; j < static_cast<size_t>(num_class); ++j) {
|
||||
if (i == j) {
|
||||
auc_mu_weights_matrix[i][j] = 0;
|
||||
if (std::fabs(auc_mu_weights[i * num_class + j]) > kZeroThreshold) {
|
||||
Log::Info("AUC-mu matrix must have zeros on diagonal. Overwriting value in position %d of auc_mu_weights with 0.", i * num_class + j);
|
||||
}
|
||||
} else {
|
||||
if (std::fabs(auc_mu_weights[i * num_class + j]) < kZeroThreshold) {
|
||||
Log::Fatal("AUC-mu matrix must have non-zero values for non-diagonal entries. Found zero value in position %d of auc_mu_weights.", i * num_class + j);
|
||||
}
|
||||
auc_mu_weights_matrix[i][j] = auc_mu_weights[i * num_class + j];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
void Config::Set(const std::unordered_map<std::string, std::string>& params) {
|
||||
// generate seeds by seed.
|
||||
if (GetInt(params, "seed", &seed)) {
|
||||
|
@ -173,6 +203,8 @@ void Config::Set(const std::unordered_map<std::string, std::string>& params) {
|
|||
|
||||
GetMembersFromString(params);
|
||||
|
||||
GetAucMuWeights();
|
||||
|
||||
// sort eval_at
|
||||
std::sort(eval_at.begin(), eval_at.end());
|
||||
|
||||
|
@ -230,6 +262,7 @@ void Config::CheckParamConflict() {
|
|||
bool metric_type_multiclass = (CheckMultiClassObjective(metric_type)
|
||||
|| metric_type == std::string("multi_logloss")
|
||||
|| metric_type == std::string("multi_error")
|
||||
|| metric_type == std::string("auc_mu")
|
||||
|| (metric_type == std::string("custom") && num_class_check > 1));
|
||||
if ((objective_type_multiclass && !metric_type_multiclass)
|
||||
|| (!objective_type_multiclass && metric_type_multiclass)) {
|
||||
|
|
|
@ -276,6 +276,7 @@ std::unordered_set<std::string> Config::parameter_set({
|
|||
"is_provide_training_metric",
|
||||
"eval_at",
|
||||
"multi_error_top_k",
|
||||
"auc_mu_weights",
|
||||
"num_machines",
|
||||
"local_listen_port",
|
||||
"time_out",
|
||||
|
@ -561,6 +562,10 @@ void Config::GetMembersFromString(const std::unordered_map<std::string, std::str
|
|||
GetInt(params, "multi_error_top_k", &multi_error_top_k);
|
||||
CHECK(multi_error_top_k >0);
|
||||
|
||||
if (GetString(params, "auc_mu_weights", &tmp_str)) {
|
||||
auc_mu_weights = Common::StringToArray<double>(tmp_str, ',');
|
||||
}
|
||||
|
||||
GetInt(params, "num_machines", &num_machines);
|
||||
CHECK(num_machines >0);
|
||||
|
||||
|
@ -683,6 +688,7 @@ std::string Config::SaveMembersToString() const {
|
|||
str_buf << "[is_provide_training_metric: " << is_provide_training_metric << "]\n";
|
||||
str_buf << "[eval_at: " << Common::Join(eval_at, ",") << "]\n";
|
||||
str_buf << "[multi_error_top_k: " << multi_error_top_k << "]\n";
|
||||
str_buf << "[auc_mu_weights: " << Common::Join(auc_mu_weights, ",") << "]\n";
|
||||
str_buf << "[num_machines: " << num_machines << "]\n";
|
||||
str_buf << "[local_listen_port: " << local_listen_port << "]\n";
|
||||
str_buf << "[time_out: " << time_out << "]\n";
|
||||
|
|
|
@ -34,6 +34,8 @@ Metric* Metric::CreateMetric(const std::string& type, const Config& config) {
|
|||
return new BinaryErrorMetric(config);
|
||||
} else if (type == std::string("auc")) {
|
||||
return new AUCMetric(config);
|
||||
} else if (type == std::string("auc_mu")) {
|
||||
return new AucMuMetric(config);
|
||||
} else if (type == std::string("ndcg")) {
|
||||
return new NDCGMetric(config);
|
||||
} else if (type == std::string("map")) {
|
||||
|
|
|
@ -178,5 +178,137 @@ class MultiSoftmaxLoglossMetric: public MulticlassMetric<MultiSoftmaxLoglossMetr
|
|||
}
|
||||
};
|
||||
|
||||
/*! \brief Auc-mu for multiclass task*/
|
||||
class AucMuMetric : public Metric {
|
||||
public:
|
||||
explicit AucMuMetric(const Config& config) : config_(config) {
|
||||
num_class_ = config.num_class;
|
||||
class_weights_ = config.auc_mu_weights_matrix;
|
||||
}
|
||||
|
||||
virtual ~AucMuMetric() {}
|
||||
|
||||
const std::vector<std::string>& GetName() const override { return name_; }
|
||||
|
||||
double factor_to_bigger_better() const override { return 1.0f; }
|
||||
|
||||
void Init(const Metadata& metadata, data_size_t num_data) override {
|
||||
name_.emplace_back("auc_mu");
|
||||
|
||||
num_data_ = num_data;
|
||||
label_ = metadata.label();
|
||||
|
||||
// sort the data indices by true class
|
||||
sorted_data_idx_ = std::vector<data_size_t>(num_data_, 0);
|
||||
for (data_size_t i = 0; i < num_data_; ++i) {
|
||||
sorted_data_idx_[i] = i;
|
||||
}
|
||||
Common::ParallelSort(sorted_data_idx_.begin(), sorted_data_idx_.end(),
|
||||
[this](data_size_t a, data_size_t b) { return label_[a] < label_[b]; });
|
||||
}
|
||||
|
||||
std::vector<double> Eval(const double* score, const ObjectiveFunction*) const override {
|
||||
// the notation follows that used in the paper introducing the auc-mu metric:
|
||||
// http://proceedings.mlr.press/v97/kleiman19a/kleiman19a.pdf
|
||||
|
||||
// get size of each class
|
||||
auto class_sizes = std::vector<data_size_t>(num_class_, 0);
|
||||
for (data_size_t i = 0; i < num_data_; ++i) {
|
||||
data_size_t curr_label = static_cast<data_size_t>(label_[i]);
|
||||
++class_sizes[curr_label];
|
||||
}
|
||||
|
||||
auto S = std::vector<std::vector<double>>(num_class_, std::vector<double>(num_class_, 0));
|
||||
int i_start = 0;
|
||||
for (int i = 0; i < num_class_; ++i) {
|
||||
int j_start = i_start + class_sizes[i];
|
||||
for (int j = i + 1; j < num_class_; ++j) {
|
||||
std::vector<double> curr_v;
|
||||
for (int k = 0; k < num_class_; ++k) {
|
||||
curr_v.emplace_back(class_weights_[i][k] - class_weights_[j][k]);
|
||||
}
|
||||
double t1 = curr_v[i] - curr_v[j];
|
||||
// extract the data indices belonging to class i or j
|
||||
std::vector<data_size_t> class_i_j_indices;
|
||||
class_i_j_indices.assign(sorted_data_idx_.begin() + i_start, sorted_data_idx_.begin() + i_start + class_sizes[i]);
|
||||
class_i_j_indices.insert(class_i_j_indices.end(),
|
||||
sorted_data_idx_.begin() + j_start, sorted_data_idx_.begin() + j_start + class_sizes[j]);
|
||||
// sort according to distance from separating hyperplane
|
||||
std::vector<std::pair<data_size_t, double>> dist;
|
||||
for (data_size_t k = 0; static_cast<size_t>(k) < class_i_j_indices.size(); ++k) {
|
||||
data_size_t a = class_i_j_indices[k];
|
||||
double v_a = 0;
|
||||
for (int m = 0; m < num_class_; ++m) {
|
||||
v_a += curr_v[m] * score[num_data_ * m + a];
|
||||
}
|
||||
dist.push_back(std::pair<data_size_t, double>(a, t1 * v_a));
|
||||
}
|
||||
Common::ParallelSort(dist.begin(), dist.end(),
|
||||
[this](std::pair<data_size_t, double> a, std::pair<data_size_t, double> b) {
|
||||
// if scores are equal, put j class first
|
||||
if (std::fabs(a.second - b.second) < kEpsilon) {
|
||||
return label_[a.first] > label_[b.first];
|
||||
}
|
||||
else if (a.second < b.second) {
|
||||
return true;
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
});
|
||||
// calculate auc
|
||||
double num_j = 0;
|
||||
double last_j_dist = 0;
|
||||
double num_current_j = 0;
|
||||
for (size_t k = 0; k < dist.size(); ++k) {
|
||||
data_size_t a = dist[k].first;
|
||||
double curr_dist = dist[k].second;
|
||||
if (label_[a] == i) {
|
||||
if (std::fabs(curr_dist - last_j_dist) < kEpsilon) {
|
||||
S[i][j] += num_j - 0.5 * num_current_j; // members of class j with same distance as a contribute 0.5
|
||||
} else {
|
||||
S[i][j] += num_j;
|
||||
}
|
||||
} else {
|
||||
++num_j;
|
||||
if (std::fabs(curr_dist - last_j_dist) < kEpsilon) {
|
||||
++num_current_j;
|
||||
} else {
|
||||
last_j_dist = dist[k].second;
|
||||
num_current_j = 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
j_start += class_sizes[j];
|
||||
}
|
||||
i_start += class_sizes[i];
|
||||
}
|
||||
|
||||
double ans = 0;
|
||||
for (int i = 0; i < num_class_; ++i) {
|
||||
for (int j = i + 1; j < num_class_; ++j) {
|
||||
ans += S[i][j] / (class_sizes[i] * class_sizes[j]);
|
||||
}
|
||||
}
|
||||
ans = 2 * ans / (num_class_ * (num_class_ - 1));
|
||||
return std::vector<double>(1, ans);
|
||||
}
|
||||
|
||||
private:
|
||||
/*! \brief Number of data*/
|
||||
data_size_t num_data_;
|
||||
/*! \brief Pointer to label*/
|
||||
const label_t* label_;
|
||||
/*! \brief Name of this metric*/
|
||||
std::vector<std::string> name_;
|
||||
/*! \brief Number of classes*/
|
||||
int num_class_;
|
||||
/*! \brief class_weights*/
|
||||
std::vector<std::vector<double>> class_weights_;
|
||||
/*! \brief config parameters*/
|
||||
Config config_;
|
||||
/*! \brief index to data, sorted by true class*/
|
||||
std::vector<data_size_t> sorted_data_idx_;
|
||||
};
|
||||
|
||||
} // namespace LightGBM
|
||||
#endif // LightGBM_METRIC_MULTICLASS_METRIC_HPP_
|
||||
|
|
|
@ -428,6 +428,68 @@ class TestEngine(unittest.TestCase):
|
|||
valid_sets=[lgb_data], evals_result=results, verbose_eval=False)
|
||||
self.assertAlmostEqual(results['training']['multi_error@2'][-1], 0)
|
||||
|
||||
def test_auc_mu(self):
|
||||
# should give same result as binary auc for 2 classes
|
||||
X, y = load_digits(10, True)
|
||||
y_new = np.zeros((len(y)))
|
||||
y_new[y != 0] = 1
|
||||
lgb_X = lgb.Dataset(X, label=y_new)
|
||||
params = {'objective': 'multiclass',
|
||||
'metric': 'auc_mu',
|
||||
'verbose': -1,
|
||||
'num_classes': 2,
|
||||
'seed': 0}
|
||||
results_auc_mu = {}
|
||||
lgb.train(params, lgb_X, num_boost_round=10, valid_sets=[lgb_X], evals_result=results_auc_mu)
|
||||
params = {'objective': 'binary',
|
||||
'metric': 'auc',
|
||||
'verbose': -1,
|
||||
'seed': 0}
|
||||
results_auc = {}
|
||||
lgb.train(params, lgb_X, num_boost_round=10, valid_sets=[lgb_X], evals_result=results_auc)
|
||||
np.testing.assert_allclose(results_auc_mu['training']['auc_mu'], results_auc['training']['auc'])
|
||||
# test the case where all predictions are equal
|
||||
lgb_X = lgb.Dataset(X[:10], label=y_new[:10])
|
||||
params = {'objective': 'multiclass',
|
||||
'metric': 'auc_mu',
|
||||
'verbose': -1,
|
||||
'num_classes': 2,
|
||||
'min_data_in_leaf': 20,
|
||||
'seed': 0}
|
||||
results_auc_mu = {}
|
||||
lgb.train(params, lgb_X, num_boost_round=10, valid_sets=[lgb_X], evals_result=results_auc_mu)
|
||||
self.assertAlmostEqual(results_auc_mu['training']['auc_mu'][-1], 0.5)
|
||||
# should give 1 when accuracy = 1
|
||||
X = X[:10, :]
|
||||
y = y[:10]
|
||||
lgb_X = lgb.Dataset(X, label=y)
|
||||
params = {'objective': 'multiclass',
|
||||
'metric': 'auc_mu',
|
||||
'num_classes': 10,
|
||||
'min_data_in_leaf': 1,
|
||||
'verbose': -1}
|
||||
results = {}
|
||||
lgb.train(params, lgb_X, num_boost_round=100, valid_sets=[lgb_X], evals_result=results)
|
||||
self.assertAlmostEqual(results['training']['auc_mu'][-1], 1)
|
||||
# test loading weights
|
||||
Xy = np.loadtxt(os.path.join(os.path.dirname(os.path.realpath(__file__)),
|
||||
'../../examples/multiclass_classification/multiclass.train'))
|
||||
y = Xy[:, 0]
|
||||
X = Xy[:, 1:]
|
||||
lgb_X = lgb.Dataset(X, label=y)
|
||||
params = {'objective': 'multiclass',
|
||||
'metric': 'auc_mu',
|
||||
'auc_mu_weights': [0, 2, 2, 2, 2, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0],
|
||||
'num_classes': 5,
|
||||
'verbose': -1,
|
||||
'seed': 0}
|
||||
results_weight = {}
|
||||
lgb.train(params, lgb_X, num_boost_round=5, valid_sets=[lgb_X], evals_result=results_weight)
|
||||
params['auc_mu_weights'] = []
|
||||
results_no_weight = {}
|
||||
lgb.train(params, lgb_X, num_boost_round=5, valid_sets=[lgb_X], evals_result=results_no_weight)
|
||||
self.assertNotEqual(results_weight['training']['auc_mu'][-1], results_no_weight['training']['auc_mu'][-1])
|
||||
|
||||
def test_early_stopping(self):
|
||||
X, y = load_breast_cancer(True)
|
||||
params = {
|
||||
|
|
Загрузка…
Ссылка в новой задаче