Switch RMSE to MSE (true L2 loss) (#408)

* RMSE (L2) -> MSE (true L2)

* Remove sqrt unneeded reference

* Square L2 test (RMSE to MSE)

* No square root on test

* Attempt to add RMSE
This commit is contained in:
Laurae 2017-04-13 12:43:41 +02:00 коммит произвёл Guolin Ke
Родитель 18d6a902c3
Коммит ba99bcddc6
4 изменённых файлов: 26 добавлений и 5 удалений

Просмотреть файл

@ -203,9 +203,10 @@ The parameter format is ```key1=value1 key2=value2 ... ``` . And parameters can
## Metric parameters
* ```metric```, default={```l2``` for regression}, {```binary_logloss``` for binary classification},{```ndcg``` for lambdarank}, type=multi-enum, options=```l1```,```l2```,```ndcg```,```auc```,```binary_logloss```,```binary_error```
* ```metric```, default={```l2``` for regression}, {```binary_logloss``` for binary classification},{```ndcg``` for lambdarank}, type=multi-enum, options=```l1```,```l2```,```ndcg```,```auc```,```binary_logloss```,```binary_error```...
* ```l1```, absolute loss, alias=```mean_absolute_error```, ```mae```
* ```l2```, square loss, alias=```mean_squared_error```, ```mse```
* ```l2_root```, root square loss, alias=```root_mean_squared_error```, ```rmse```
* ```huber```, [Huber loss](https://en.wikipedia.org/wiki/Huber_loss "Huber loss - Wikipedia")
* ```fair```, [Fair loss](https://www.kaggle.com/c/allstate-claims-severity/discussion/24520)
* ```poisson```, [Poisson regression](https://en.wikipedia.org/wiki/Poisson_regression "Poisson regression")

Просмотреть файл

@ -10,6 +10,8 @@ namespace LightGBM {
Metric* Metric::CreateMetric(const std::string& type, const MetricConfig& config) {
if (type == std::string("l2") || type == std::string("mean_squared_error") || type == std::string("mse")) {
return new L2Metric(config);
} else if (type == std::string("l2_root") || type == std::string("root_mean_squared_error") || type == std::string("rmse")) {
return new RMSEMetric(config);
} else if (type == std::string("l1") || type == std::string("mean_absolute_error") || type == std::string("mae")) {
return new L1Metric(config);
} else if (type == std::string("huber")) {

Просмотреть файл

@ -111,6 +111,25 @@ private:
std::vector<std::string> name_;
};
/*! \brief RMSE loss for regression task */
class RMSEMetric: public RegressionMetric<RMSEMetric> {
public:
explicit RMSEMetric(const MetricConfig& config) :RegressionMetric<RMSEMetric>(config) {}
inline static double LossOnPoint(float label, double score, double, double) {
return (score - label)*(score - label);
}
inline static double AverageLoss(double sum_loss, double sum_weights) {
// need sqrt the result for RMSE loss
return std::sqrt(sum_loss / sum_weights);
}
inline static const char* Name() {
return "rmse";
}
};
/*! \brief L2 loss for regression task */
class L2Metric: public RegressionMetric<L2Metric> {
public:
@ -121,8 +140,8 @@ public:
}
inline static double AverageLoss(double sum_loss, double sum_weights) {
// need sqrt the result for L2 loss
return std::sqrt(sum_loss / sum_weights);
// need mean of the result for L2 loss
return sum_loss / sum_weights;
}
inline static const char* Name() {

Просмотреть файл

@ -71,8 +71,7 @@ class TestEngine(unittest.TestCase):
def test_regreesion(self):
evals_result, ret = template.test_template()
ret **= 0.5
self.assertLess(ret, 4)
self.assertLess(ret, 16)
self.assertAlmostEqual(min(evals_result['eval']['l2']), ret, places=5)
def test_multiclass(self):