* Fix bug where small values of max_bin cause crash.

* Revert "Fix bug where small values of max_bin cause crash."

This reverts commit fe5c8e2547.

* Add auc-mu multiclass metric.

* Fix bug where scores are equal.

* Merge.

* Change name to auc_mu everywhere (instead of auc-mu).

* Fix comparison between signed and unsigned int.

* Change name to AUC-mu in docs and output messages.

* Improve test.

* Use prefix increment.

* Update R package.

* Fix style issues.

* Tidy up test code.

* Read all lines first then process.

* Allow passing AUC-mu weights directly as a list in parameters.

* Remove unused code, improve example and docs.
This commit is contained in:
Belinda Trotta 2019-12-13 07:27:27 +11:00 коммит произвёл Nikita Titov
Родитель 9fd378e2bc
Коммит 222775ca29
9 изменённых файлов: 275 добавлений и 3 удалений

Просмотреть файл

@ -563,7 +563,7 @@ Booster <- R6::R6Class(
# Parse and store privately names
names <- strsplit(names, "\t")[[1L]]
private$eval_names <- names
private$higher_better_inner_eval <- grepl("^ndcg|^map|^auc$", names)
private$higher_better_inner_eval <- grepl("^ndcg|^map|^auc", names)
}

Просмотреть файл

@ -881,6 +881,8 @@ Metric Parameters
- ``binary_error``, for one sample: ``0`` for correct classification, ``1`` for error classification
- ``auc_mu``, `AUC-mu <http://proceedings.mlr.press/v97/kleiman19a/kleiman19a.pdf>`__
- ``multi_logloss``, log loss for multi-class classification, aliases: ``multiclass``, ``softmax``, ``multiclassova``, ``multiclass_ova``, ``ova``, ``ovr``
- ``multi_error``, error rate for multi-class classification
@ -921,6 +923,18 @@ Metric Parameters
- when ``multi_error_top_k=1`` this is equivalent to the usual multi-error metric
- ``auc_mu_weights`` :raw-html:`<a id="auc_mu_weights" title="Permalink to this parameter" href="#auc_mu_weights">&#x1F517;&#xFE0E;</a>`, default = ``None``, type = multi-double
- used only with ``auc_mu`` metric
- list representing flattened matrix (in row-major order) giving loss weights for classification errors
- list should have ``n * n`` elements, where ``n`` is the number of classes
- the matrix co-ordinate ``[i, j]`` should correspond to the ``i * n + j``-th element of the list
- if not specified, will use equal weights for all classes
Network Parameters
------------------

Просмотреть файл

@ -21,7 +21,16 @@ objective = multiclass
# binary_error
# multi_logloss
# multi_error
metric = multi_logloss
# auc_mu
metric = multi_logloss,auc_mu
# AUC-mu weights; the matrix of loss weights below is passed in parameter auc_mu_weights as a list
# 0 1 2 3 4
# 5 0 6 7 8
# 9 10 0 11 12
# 13 14 15 0 16
# 17 18 19 20 0
auc_mu_weights = 0,1,2,3,4,5,0,6,7,8,9,10,0,11,12,13,14,15,0,16,17,18,19,20,0
# number of class, for multiclass classification
num_class = 5

Просмотреть файл

@ -773,6 +773,7 @@ struct Config {
// descl2 = ``auc``, `AUC <https://en.wikipedia.org/wiki/Receiver_operating_characteristic#Area_under_the_curve>`__
// descl2 = ``binary_logloss``, `log loss <https://en.wikipedia.org/wiki/Cross_entropy>`__, aliases: ``binary``
// descl2 = ``binary_error``, for one sample: ``0`` for correct classification, ``1`` for error classification
// descl2 = ``auc_mu``, `AUC-mu <http://proceedings.mlr.press/v97/kleiman19a/kleiman19a.pdf>`__
// descl2 = ``multi_logloss``, log loss for multi-class classification, aliases: ``multiclass``, ``softmax``, ``multiclassova``, ``multiclass_ova``, ``ova``, ``ovr``
// descl2 = ``multi_error``, error rate for multi-class classification
// descl2 = ``cross_entropy``, cross-entropy (with optional linear weights), aliases: ``xentropy``
@ -806,6 +807,15 @@ struct Config {
// desc = when ``multi_error_top_k=1`` this is equivalent to the usual multi-error metric
int multi_error_top_k = 1;
// type = multi-double
// default = None
// desc = used only with ``auc_mu`` metric
// desc = list representing flattened matrix (in row-major order) giving loss weights for classification errors
// desc = list should have ``n * n`` elements, where ``n`` is the number of classes
// desc = the matrix co-ordinate ``[i, j]`` should correspond to the ``i * n + j``-th element of the list
// desc = if not specified, will use equal weights for all classes
std::vector<double> auc_mu_weights;
#pragma endregion
#pragma region Network Parameters
@ -863,11 +873,13 @@ struct Config {
LIGHTGBM_EXPORT void Set(const std::unordered_map<std::string, std::string>& params);
static std::unordered_map<std::string, std::string> alias_table;
static std::unordered_set<std::string> parameter_set;
std::vector<std::vector<double>> auc_mu_weights_matrix;
private:
void CheckParamConflict();
void GetMembersFromString(const std::unordered_map<std::string, std::string>& params);
std::string SaveMembersToString() const;
void GetAucMuWeights();
};
inline bool Config::GetString(
@ -1013,6 +1025,8 @@ inline std::string ParseMetricAlias(const std::string& type) {
return "kullback_leibler";
} else if (type == std::string("mean_absolute_percentage_error") || type == std::string("mape")) {
return "mape";
} else if (type == std::string("auc_mu")) {
return "auc_mu";
} else if (type == std::string("none") || type == std::string("null") || type == std::string("custom") || type == std::string("na")) {
return "custom";
}
@ -1021,4 +1035,4 @@ inline std::string ParseMetricAlias(const std::string& type) {
} // namespace LightGBM
#endif // LightGBM_CONFIG_H_
#endif // LightGBM_CONFIG_H_

Просмотреть файл

@ -153,6 +153,36 @@ void GetTreeLearnerType(const std::unordered_map<std::string, std::string>& para
}
}
void Config::GetAucMuWeights() {
if (auc_mu_weights.empty()) {
// equal weights for all classes
auc_mu_weights_matrix = std::vector<std::vector<double>> (num_class, std::vector<double>(num_class, 1));
for (size_t i = 0; i < static_cast<size_t>(num_class); ++i) {
auc_mu_weights_matrix[i][i] = 0;
}
} else {
auc_mu_weights_matrix = std::vector<std::vector<double>> (num_class, std::vector<double>(num_class, 0));
if (auc_mu_weights.size() != static_cast<size_t>(num_class * num_class)) {
Log::Fatal("auc_mu_weights must have %d elements, but found %d", num_class * num_class, auc_mu_weights.size());
}
for (size_t i = 0; i < static_cast<size_t>(num_class); ++i) {
for (size_t j = 0; j < static_cast<size_t>(num_class); ++j) {
if (i == j) {
auc_mu_weights_matrix[i][j] = 0;
if (std::fabs(auc_mu_weights[i * num_class + j]) > kZeroThreshold) {
Log::Info("AUC-mu matrix must have zeros on diagonal. Overwriting value in position %d of auc_mu_weights with 0.", i * num_class + j);
}
} else {
if (std::fabs(auc_mu_weights[i * num_class + j]) < kZeroThreshold) {
Log::Fatal("AUC-mu matrix must have non-zero values for non-diagonal entries. Found zero value in position %d of auc_mu_weights.", i * num_class + j);
}
auc_mu_weights_matrix[i][j] = auc_mu_weights[i * num_class + j];
}
}
}
}
};
void Config::Set(const std::unordered_map<std::string, std::string>& params) {
// generate seeds by seed.
if (GetInt(params, "seed", &seed)) {
@ -173,6 +203,8 @@ void Config::Set(const std::unordered_map<std::string, std::string>& params) {
GetMembersFromString(params);
GetAucMuWeights();
// sort eval_at
std::sort(eval_at.begin(), eval_at.end());
@ -230,6 +262,7 @@ void Config::CheckParamConflict() {
bool metric_type_multiclass = (CheckMultiClassObjective(metric_type)
|| metric_type == std::string("multi_logloss")
|| metric_type == std::string("multi_error")
|| metric_type == std::string("auc_mu")
|| (metric_type == std::string("custom") && num_class_check > 1));
if ((objective_type_multiclass && !metric_type_multiclass)
|| (!objective_type_multiclass && metric_type_multiclass)) {

Просмотреть файл

@ -276,6 +276,7 @@ std::unordered_set<std::string> Config::parameter_set({
"is_provide_training_metric",
"eval_at",
"multi_error_top_k",
"auc_mu_weights",
"num_machines",
"local_listen_port",
"time_out",
@ -561,6 +562,10 @@ void Config::GetMembersFromString(const std::unordered_map<std::string, std::str
GetInt(params, "multi_error_top_k", &multi_error_top_k);
CHECK(multi_error_top_k >0);
if (GetString(params, "auc_mu_weights", &tmp_str)) {
auc_mu_weights = Common::StringToArray<double>(tmp_str, ',');
}
GetInt(params, "num_machines", &num_machines);
CHECK(num_machines >0);
@ -683,6 +688,7 @@ std::string Config::SaveMembersToString() const {
str_buf << "[is_provide_training_metric: " << is_provide_training_metric << "]\n";
str_buf << "[eval_at: " << Common::Join(eval_at, ",") << "]\n";
str_buf << "[multi_error_top_k: " << multi_error_top_k << "]\n";
str_buf << "[auc_mu_weights: " << Common::Join(auc_mu_weights, ",") << "]\n";
str_buf << "[num_machines: " << num_machines << "]\n";
str_buf << "[local_listen_port: " << local_listen_port << "]\n";
str_buf << "[time_out: " << time_out << "]\n";

Просмотреть файл

@ -34,6 +34,8 @@ Metric* Metric::CreateMetric(const std::string& type, const Config& config) {
return new BinaryErrorMetric(config);
} else if (type == std::string("auc")) {
return new AUCMetric(config);
} else if (type == std::string("auc_mu")) {
return new AucMuMetric(config);
} else if (type == std::string("ndcg")) {
return new NDCGMetric(config);
} else if (type == std::string("map")) {

Просмотреть файл

@ -178,5 +178,137 @@ class MultiSoftmaxLoglossMetric: public MulticlassMetric<MultiSoftmaxLoglossMetr
}
};
/*! \brief Auc-mu for multiclass task*/
class AucMuMetric : public Metric {
public:
explicit AucMuMetric(const Config& config) : config_(config) {
num_class_ = config.num_class;
class_weights_ = config.auc_mu_weights_matrix;
}
virtual ~AucMuMetric() {}
const std::vector<std::string>& GetName() const override { return name_; }
double factor_to_bigger_better() const override { return 1.0f; }
void Init(const Metadata& metadata, data_size_t num_data) override {
name_.emplace_back("auc_mu");
num_data_ = num_data;
label_ = metadata.label();
// sort the data indices by true class
sorted_data_idx_ = std::vector<data_size_t>(num_data_, 0);
for (data_size_t i = 0; i < num_data_; ++i) {
sorted_data_idx_[i] = i;
}
Common::ParallelSort(sorted_data_idx_.begin(), sorted_data_idx_.end(),
[this](data_size_t a, data_size_t b) { return label_[a] < label_[b]; });
}
std::vector<double> Eval(const double* score, const ObjectiveFunction*) const override {
// the notation follows that used in the paper introducing the auc-mu metric:
// http://proceedings.mlr.press/v97/kleiman19a/kleiman19a.pdf
// get size of each class
auto class_sizes = std::vector<data_size_t>(num_class_, 0);
for (data_size_t i = 0; i < num_data_; ++i) {
data_size_t curr_label = static_cast<data_size_t>(label_[i]);
++class_sizes[curr_label];
}
auto S = std::vector<std::vector<double>>(num_class_, std::vector<double>(num_class_, 0));
int i_start = 0;
for (int i = 0; i < num_class_; ++i) {
int j_start = i_start + class_sizes[i];
for (int j = i + 1; j < num_class_; ++j) {
std::vector<double> curr_v;
for (int k = 0; k < num_class_; ++k) {
curr_v.emplace_back(class_weights_[i][k] - class_weights_[j][k]);
}
double t1 = curr_v[i] - curr_v[j];
// extract the data indices belonging to class i or j
std::vector<data_size_t> class_i_j_indices;
class_i_j_indices.assign(sorted_data_idx_.begin() + i_start, sorted_data_idx_.begin() + i_start + class_sizes[i]);
class_i_j_indices.insert(class_i_j_indices.end(),
sorted_data_idx_.begin() + j_start, sorted_data_idx_.begin() + j_start + class_sizes[j]);
// sort according to distance from separating hyperplane
std::vector<std::pair<data_size_t, double>> dist;
for (data_size_t k = 0; static_cast<size_t>(k) < class_i_j_indices.size(); ++k) {
data_size_t a = class_i_j_indices[k];
double v_a = 0;
for (int m = 0; m < num_class_; ++m) {
v_a += curr_v[m] * score[num_data_ * m + a];
}
dist.push_back(std::pair<data_size_t, double>(a, t1 * v_a));
}
Common::ParallelSort(dist.begin(), dist.end(),
[this](std::pair<data_size_t, double> a, std::pair<data_size_t, double> b) {
// if scores are equal, put j class first
if (std::fabs(a.second - b.second) < kEpsilon) {
return label_[a.first] > label_[b.first];
}
else if (a.second < b.second) {
return true;
} else {
return false;
}
});
// calculate auc
double num_j = 0;
double last_j_dist = 0;
double num_current_j = 0;
for (size_t k = 0; k < dist.size(); ++k) {
data_size_t a = dist[k].first;
double curr_dist = dist[k].second;
if (label_[a] == i) {
if (std::fabs(curr_dist - last_j_dist) < kEpsilon) {
S[i][j] += num_j - 0.5 * num_current_j; // members of class j with same distance as a contribute 0.5
} else {
S[i][j] += num_j;
}
} else {
++num_j;
if (std::fabs(curr_dist - last_j_dist) < kEpsilon) {
++num_current_j;
} else {
last_j_dist = dist[k].second;
num_current_j = 1;
}
}
}
j_start += class_sizes[j];
}
i_start += class_sizes[i];
}
double ans = 0;
for (int i = 0; i < num_class_; ++i) {
for (int j = i + 1; j < num_class_; ++j) {
ans += S[i][j] / (class_sizes[i] * class_sizes[j]);
}
}
ans = 2 * ans / (num_class_ * (num_class_ - 1));
return std::vector<double>(1, ans);
}
private:
/*! \brief Number of data*/
data_size_t num_data_;
/*! \brief Pointer to label*/
const label_t* label_;
/*! \brief Name of this metric*/
std::vector<std::string> name_;
/*! \brief Number of classes*/
int num_class_;
/*! \brief class_weights*/
std::vector<std::vector<double>> class_weights_;
/*! \brief config parameters*/
Config config_;
/*! \brief index to data, sorted by true class*/
std::vector<data_size_t> sorted_data_idx_;
};
} // namespace LightGBM
#endif // LightGBM_METRIC_MULTICLASS_METRIC_HPP_

Просмотреть файл

@ -428,6 +428,68 @@ class TestEngine(unittest.TestCase):
valid_sets=[lgb_data], evals_result=results, verbose_eval=False)
self.assertAlmostEqual(results['training']['multi_error@2'][-1], 0)
def test_auc_mu(self):
# should give same result as binary auc for 2 classes
X, y = load_digits(10, True)
y_new = np.zeros((len(y)))
y_new[y != 0] = 1
lgb_X = lgb.Dataset(X, label=y_new)
params = {'objective': 'multiclass',
'metric': 'auc_mu',
'verbose': -1,
'num_classes': 2,
'seed': 0}
results_auc_mu = {}
lgb.train(params, lgb_X, num_boost_round=10, valid_sets=[lgb_X], evals_result=results_auc_mu)
params = {'objective': 'binary',
'metric': 'auc',
'verbose': -1,
'seed': 0}
results_auc = {}
lgb.train(params, lgb_X, num_boost_round=10, valid_sets=[lgb_X], evals_result=results_auc)
np.testing.assert_allclose(results_auc_mu['training']['auc_mu'], results_auc['training']['auc'])
# test the case where all predictions are equal
lgb_X = lgb.Dataset(X[:10], label=y_new[:10])
params = {'objective': 'multiclass',
'metric': 'auc_mu',
'verbose': -1,
'num_classes': 2,
'min_data_in_leaf': 20,
'seed': 0}
results_auc_mu = {}
lgb.train(params, lgb_X, num_boost_round=10, valid_sets=[lgb_X], evals_result=results_auc_mu)
self.assertAlmostEqual(results_auc_mu['training']['auc_mu'][-1], 0.5)
# should give 1 when accuracy = 1
X = X[:10, :]
y = y[:10]
lgb_X = lgb.Dataset(X, label=y)
params = {'objective': 'multiclass',
'metric': 'auc_mu',
'num_classes': 10,
'min_data_in_leaf': 1,
'verbose': -1}
results = {}
lgb.train(params, lgb_X, num_boost_round=100, valid_sets=[lgb_X], evals_result=results)
self.assertAlmostEqual(results['training']['auc_mu'][-1], 1)
# test loading weights
Xy = np.loadtxt(os.path.join(os.path.dirname(os.path.realpath(__file__)),
'../../examples/multiclass_classification/multiclass.train'))
y = Xy[:, 0]
X = Xy[:, 1:]
lgb_X = lgb.Dataset(X, label=y)
params = {'objective': 'multiclass',
'metric': 'auc_mu',
'auc_mu_weights': [0, 2, 2, 2, 2, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0],
'num_classes': 5,
'verbose': -1,
'seed': 0}
results_weight = {}
lgb.train(params, lgb_X, num_boost_round=5, valid_sets=[lgb_X], evals_result=results_weight)
params['auc_mu_weights'] = []
results_no_weight = {}
lgb.train(params, lgb_X, num_boost_round=5, valid_sets=[lgb_X], evals_result=results_no_weight)
self.assertNotEqual(results_weight['training']['auc_mu'][-1], results_no_weight['training']['auc_mu'][-1])
def test_early_stopping(self):
X, y = load_breast_cancer(True)
params = {