зеркало из https://github.com/microsoft/LightGBM.git
fix metric alias (#2273)
* fix metric alias * fix format * updated docs * simplify alias in objective function * move the alias parsing to config.cpp * updated docs * fix multi-class aliases * updated regression aliases in docs * fixed trailing space
This commit is contained in:
Родитель
716fe4d015
Коммит
5d3a3ea47e
|
@ -51,13 +51,13 @@ Core Parameters
|
|||
|
||||
- **Note**: can be used only in CLI version; for language-specific packages you can use the correspondent functions
|
||||
|
||||
- ``objective`` :raw-html:`<a id="objective" title="Permalink to this parameter" href="#objective">🔗︎</a>`, default = ``regression``, type = enum, options: ``regression``, ``regression_l1``, ``huber``, ``fair``, ``poisson``, ``quantile``, ``mape``, ``gamma``, ``tweedie``, ``binary``, ``multiclass``, ``multiclassova``, ``xentropy``, ``xentlambda``, ``lambdarank``, aliases: ``objective_type``, ``app``, ``application``
|
||||
- ``objective`` :raw-html:`<a id="objective" title="Permalink to this parameter" href="#objective">🔗︎</a>`, default = ``regression``, type = enum, options: ``regression``, ``regression_l1``, ``huber``, ``fair``, ``poisson``, ``quantile``, ``mape``, ``gamma``, ``tweedie``, ``binary``, ``multiclass``, ``multiclassova``, ``cross_entropy``, ``cross_entropy_lambda``, ``lambdarank``, aliases: ``objective_type``, ``app``, ``application``
|
||||
|
||||
- regression application
|
||||
|
||||
- ``regression_l2``, L2 loss, aliases: ``regression``, ``mean_squared_error``, ``mse``, ``l2_root``, ``root_mean_squared_error``, ``rmse``
|
||||
- ``regression``, L2 loss, aliases: ``regression_l2``, ``l2``, ``mean_squared_error``, ``mse``, ``l2_root``, ``root_mean_squared_error``, ``rmse``
|
||||
|
||||
- ``regression_l1``, L1 loss, aliases: ``mean_absolute_error``, ``mae``
|
||||
- ``regression_l1``, L1 loss, aliases: ``l1``, ``mean_absolute_error``, ``mae``
|
||||
|
||||
- ``huber``, `Huber loss <https://en.wikipedia.org/wiki/Huber_loss>`__
|
||||
|
||||
|
@ -85,9 +85,9 @@ Core Parameters
|
|||
|
||||
- cross-entropy application
|
||||
|
||||
- ``xentropy``, objective function for cross-entropy (with optional linear weights), aliases: ``cross_entropy``
|
||||
- ``cross_entropy``, objective function for cross-entropy (with optional linear weights), aliases: ``xentropy``
|
||||
|
||||
- ``xentlambda``, alternative parameterization of cross-entropy, aliases: ``cross_entropy_lambda``
|
||||
- ``cross_entropy_lambda``, alternative parameterization of cross-entropy, aliases: ``xentlambda``
|
||||
|
||||
- label is anything in interval [0, 1]
|
||||
|
||||
|
@ -857,11 +857,11 @@ Metric Parameters
|
|||
|
||||
- ``multi_error``, error rate for multi-class classification
|
||||
|
||||
- ``xentropy``, cross-entropy (with optional linear weights), aliases: ``cross_entropy``
|
||||
- ``cross_entropy``, cross-entropy (with optional linear weights), aliases: ``xentropy``
|
||||
|
||||
- ``xentlambda``, "intensity-weighted" cross-entropy, aliases: ``cross_entropy_lambda``
|
||||
- ``cross_entropy_lambda``, "intensity-weighted" cross-entropy, aliases: ``xentlambda``
|
||||
|
||||
- ``kldiv``, `Kullback-Leibler divergence <https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence>`__, aliases: ``kullback_leibler``
|
||||
- ``kullback_leibler``, `Kullback-Leibler divergence <https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence>`__, aliases: ``kldiv``
|
||||
|
||||
- support multiple metrics, separated by ``,``
|
||||
|
||||
|
|
|
@ -102,11 +102,11 @@ struct Config {
|
|||
|
||||
// [doc-only]
|
||||
// type = enum
|
||||
// options = regression, regression_l1, huber, fair, poisson, quantile, mape, gamma, tweedie, binary, multiclass, multiclassova, xentropy, xentlambda, lambdarank
|
||||
// options = regression, regression_l1, huber, fair, poisson, quantile, mape, gamma, tweedie, binary, multiclass, multiclassova, cross_entropy, cross_entropy_lambda, lambdarank
|
||||
// alias = objective_type, app, application
|
||||
// desc = regression application
|
||||
// descl2 = ``regression_l2``, L2 loss, aliases: ``regression``, ``mean_squared_error``, ``mse``, ``l2_root``, ``root_mean_squared_error``, ``rmse``
|
||||
// descl2 = ``regression_l1``, L1 loss, aliases: ``mean_absolute_error``, ``mae``
|
||||
// descl2 = ``regression``, L2 loss, aliases: ``regression_l2``, ``l2``, ``mean_squared_error``, ``mse``, ``l2_root``, ``root_mean_squared_error``, ``rmse``
|
||||
// descl2 = ``regression_l1``, L1 loss, aliases: ``l1``, ``mean_absolute_error``, ``mae``
|
||||
// descl2 = ``huber``, `Huber loss <https://en.wikipedia.org/wiki/Huber_loss>`__
|
||||
// descl2 = ``fair``, `Fair loss <https://www.kaggle.com/c/allstate-claims-severity/discussion/24520>`__
|
||||
// descl2 = ``poisson``, `Poisson regression <https://en.wikipedia.org/wiki/Poisson_regression>`__
|
||||
|
@ -120,8 +120,8 @@ struct Config {
|
|||
// descl2 = ``multiclassova``, `One-vs-All <https://en.wikipedia.org/wiki/Multiclass_classification#One-vs.-rest>`__ binary objective function, aliases: ``multiclass_ova``, ``ova``, ``ovr``
|
||||
// descl2 = ``num_class`` should be set as well
|
||||
// desc = cross-entropy application
|
||||
// descl2 = ``xentropy``, objective function for cross-entropy (with optional linear weights), aliases: ``cross_entropy``
|
||||
// descl2 = ``xentlambda``, alternative parameterization of cross-entropy, aliases: ``cross_entropy_lambda``
|
||||
// descl2 = ``cross_entropy``, objective function for cross-entropy (with optional linear weights), aliases: ``xentropy``
|
||||
// descl2 = ``cross_entropy_lambda``, alternative parameterization of cross-entropy, aliases: ``xentlambda``
|
||||
// descl2 = label is anything in interval [0, 1]
|
||||
// desc = ``lambdarank``, `lambdarank <https://papers.nips.cc/paper/2971-learning-to-rank-with-nonsmooth-cost-functions.pdf>`__ application
|
||||
// descl2 = label should be ``int`` type in lambdarank tasks, and larger number represents the higher relevance (e.g. 0:bad, 1:fair, 2:good, 3:perfect)
|
||||
|
@ -754,9 +754,9 @@ struct Config {
|
|||
// descl2 = ``binary_error``, for one sample: ``0`` for correct classification, ``1`` for error classification
|
||||
// descl2 = ``multi_logloss``, log loss for multi-class classification, aliases: ``multiclass``, ``softmax``, ``multiclassova``, ``multiclass_ova``, ``ova``, ``ovr``
|
||||
// descl2 = ``multi_error``, error rate for multi-class classification
|
||||
// descl2 = ``xentropy``, cross-entropy (with optional linear weights), aliases: ``cross_entropy``
|
||||
// descl2 = ``xentlambda``, "intensity-weighted" cross-entropy, aliases: ``cross_entropy_lambda``
|
||||
// descl2 = ``kldiv``, `Kullback-Leibler divergence <https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence>`__, aliases: ``kullback_leibler``
|
||||
// descl2 = ``cross_entropy``, cross-entropy (with optional linear weights), aliases: ``xentropy``
|
||||
// descl2 = ``cross_entropy_lambda``, "intensity-weighted" cross-entropy, aliases: ``xentlambda``
|
||||
// descl2 = ``kullback_leibler``, `Kullback-Leibler divergence <https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence>`__, aliases: ``kldiv``
|
||||
// desc = support multiple metrics, separated by ``,``
|
||||
std::vector<std::string> metric;
|
||||
|
||||
|
|
|
@ -63,41 +63,91 @@ void GetBoostingType(const std::unordered_map<std::string, std::string>& params,
|
|||
}
|
||||
}
|
||||
|
||||
std::string ParseObjectiveAlias(const std::string& type) {
|
||||
if (type == std::string("regression") || type == std::string("regression_l2")
|
||||
|| type == std::string("mean_squared_error") || type == std::string("mse") || type == std::string("l2")
|
||||
|| type == std::string("l2_root") || type == std::string("root_mean_squared_error") || type == std::string("rmse")) {
|
||||
return "regression";
|
||||
} else if (type == std::string("regression_l1") || type == std::string("mean_absolute_error")
|
||||
|| type == std::string("l1") || type == std::string("mae")) {
|
||||
return "regression_l1";
|
||||
} else if (type == std::string("multiclass") || type == std::string("softmax")) {
|
||||
return "multiclass";
|
||||
} else if (type == std::string("multiclassova") || type == std::string("multiclass_ova") || type == std::string("ova") || type == std::string("ovr")) {
|
||||
return "multiclassova";
|
||||
} else if (type == std::string("xentropy") || type == std::string("cross_entropy")) {
|
||||
return "cross_entropy";
|
||||
} else if (type == std::string("xentlambda") || type == std::string("cross_entropy_lambda")) {
|
||||
return "cross_entropy_lambda";
|
||||
} else if (type == std::string("mean_absolute_percentage_error") || type == std::string("mape")) {
|
||||
return "mape";
|
||||
} else if (type == std::string("none") || type == std::string("null") || type == std::string("custom") || type == std::string("na")) {
|
||||
return "custom";
|
||||
}
|
||||
return type;
|
||||
}
|
||||
|
||||
std::string ParseMetricAlias(const std::string& type) {
|
||||
if (type == std::string("regression") || type == std::string("regression_l2") || type == std::string("l2") || type == std::string("mean_squared_error") || type == std::string("mse")) {
|
||||
return "l2";
|
||||
} else if (type == std::string("l2_root") || type == std::string("root_mean_squared_error") || type == std::string("rmse")) {
|
||||
return "rmse";
|
||||
} else if (type == std::string("regression_l1") || type == std::string("l1") || type == std::string("mean_absolute_error") || type == std::string("mae")) {
|
||||
return "l1";
|
||||
} else if (type == std::string("binary_logloss") || type == std::string("binary")) {
|
||||
return "binary_logloss";
|
||||
} else if (type == std::string("ndcg") || type == std::string("lambdarank")) {
|
||||
return "ndcg";
|
||||
} else if (type == std::string("map") || type == std::string("mean_average_precision")) {
|
||||
return "map";
|
||||
} else if (type == std::string("multi_logloss") || type == std::string("multiclass") || type == std::string("softmax") || type == std::string("multiclassova") || type == std::string("multiclass_ova") || type == std::string("ova") || type == std::string("ovr")) {
|
||||
return "multi_logloss";
|
||||
} else if (type == std::string("xentropy") || type == std::string("cross_entropy")) {
|
||||
return "cross_entropy";
|
||||
} else if (type == std::string("xentlambda") || type == std::string("cross_entropy_lambda")) {
|
||||
return "cross_entropy_lambda";
|
||||
} else if (type == std::string("kldiv") || type == std::string("kullback_leibler")) {
|
||||
return "kullback_leibler";
|
||||
} else if (type == std::string("mean_absolute_percentage_error") || type == std::string("mape")) {
|
||||
return "mape";
|
||||
} else if (type == std::string("none") || type == std::string("null") || type == std::string("custom") || type == std::string("na")) {
|
||||
return "custom";
|
||||
}
|
||||
return type;
|
||||
}
|
||||
|
||||
void ParseMetrics(const std::string& value, std::vector<std::string>* out_metric) {
|
||||
std::unordered_set<std::string> metric_sets;
|
||||
out_metric->clear();
|
||||
std::vector<std::string> metrics = Common::Split(value.c_str(), ',');
|
||||
for (auto& met : metrics) {
|
||||
auto type = ParseMetricAlias(met);
|
||||
if (metric_sets.count(type) <= 0) {
|
||||
out_metric->push_back(type);
|
||||
metric_sets.insert(type);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void GetObjectiveType(const std::unordered_map<std::string, std::string>& params, std::string* objective) {
|
||||
std::string value;
|
||||
if (Config::GetString(params, "objective", &value)) {
|
||||
std::transform(value.begin(), value.end(), value.begin(), Common::tolower);
|
||||
*objective = value;
|
||||
*objective = ParseObjectiveAlias(value);
|
||||
}
|
||||
}
|
||||
|
||||
void GetMetricType(const std::unordered_map<std::string, std::string>& params, std::vector<std::string>* metric) {
|
||||
std::string value;
|
||||
if (Config::GetString(params, "metric", &value)) {
|
||||
// clear old metrics
|
||||
metric->clear();
|
||||
// to lower
|
||||
std::transform(value.begin(), value.end(), value.begin(), Common::tolower);
|
||||
// split
|
||||
std::vector<std::string> metrics = Common::Split(value.c_str(), ',');
|
||||
// remove duplicate
|
||||
std::unordered_set<std::string> metric_sets;
|
||||
for (auto& met : metrics) {
|
||||
std::transform(met.begin(), met.end(), met.begin(), Common::tolower);
|
||||
if (metric_sets.count(met) <= 0) {
|
||||
metric_sets.insert(met);
|
||||
}
|
||||
}
|
||||
for (auto& met : metric_sets) {
|
||||
metric->push_back(met);
|
||||
}
|
||||
metric->shrink_to_fit();
|
||||
ParseMetrics(value, metric);
|
||||
}
|
||||
// add names of objective function if not providing metric
|
||||
if (metric->empty() && value.size() == 0) {
|
||||
if (Config::GetString(params, "objective", &value)) {
|
||||
std::transform(value.begin(), value.end(), value.begin(), Common::tolower);
|
||||
metric->push_back(value);
|
||||
ParseMetrics(value, metric);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -196,20 +246,13 @@ void Config::Set(const std::unordered_map<std::string, std::string>& params) {
|
|||
}
|
||||
|
||||
bool CheckMultiClassObjective(const std::string& objective) {
|
||||
return (objective == std::string("multiclass")
|
||||
|| objective == std::string("multiclassova")
|
||||
|| objective == std::string("softmax")
|
||||
|| objective == std::string("multiclass_ova")
|
||||
|| objective == std::string("ova")
|
||||
|| objective == std::string("ovr"));
|
||||
return (objective == std::string("multiclass") || objective == std::string("multiclassova"));
|
||||
}
|
||||
|
||||
void Config::CheckParamConflict() {
|
||||
// check if objective, metric, and num_class match
|
||||
int num_class_check = num_class;
|
||||
bool objective_custom = objective == std::string("none") || objective == std::string("null")
|
||||
|| objective == std::string("custom") || objective == std::string("na");
|
||||
bool objective_type_multiclass = CheckMultiClassObjective(objective) || (objective_custom && num_class_check > 1);
|
||||
bool objective_type_multiclass = CheckMultiClassObjective(objective) || (objective == std::string("custom") && num_class_check > 1);
|
||||
|
||||
if (objective_type_multiclass) {
|
||||
if (num_class_check <= 1) {
|
||||
|
@ -221,12 +264,10 @@ void Config::CheckParamConflict() {
|
|||
}
|
||||
}
|
||||
for (std::string metric_type : metric) {
|
||||
bool metric_custom_or_none = metric_type == std::string("none") || metric_type == std::string("null")
|
||||
|| metric_type == std::string("custom") || metric_type == std::string("na");
|
||||
bool metric_type_multiclass = (CheckMultiClassObjective(metric_type)
|
||||
|| metric_type == std::string("multi_logloss")
|
||||
|| metric_type == std::string("multi_error")
|
||||
|| (metric_custom_or_none && num_class_check > 1));
|
||||
|| (metric_type == std::string("custom") && num_class_check > 1));
|
||||
if ((objective_type_multiclass && !metric_type_multiclass)
|
||||
|| (!objective_type_multiclass && metric_type_multiclass)) {
|
||||
Log::Fatal("Multiclass objective and metrics don't match");
|
||||
|
|
|
@ -14,11 +14,11 @@
|
|||
namespace LightGBM {
|
||||
|
||||
Metric* Metric::CreateMetric(const std::string& type, const Config& config) {
|
||||
if (type == std::string("regression") || type == std::string("regression_l2") || type == std::string("l2") || type == std::string("mean_squared_error") || type == std::string("mse")) {
|
||||
if (type == std::string("l2")) {
|
||||
return new L2Metric(config);
|
||||
} else if (type == std::string("l2_root") || type == std::string("root_mean_squared_error") || type == std::string("rmse")) {
|
||||
} else if (type == std::string("rmse")) {
|
||||
return new RMSEMetric(config);
|
||||
} else if (type == std::string("regression_l1") || type == std::string("l1") || type == std::string("mean_absolute_error") || type == std::string("mae")) {
|
||||
} else if (type == std::string("l1")) {
|
||||
return new L1Metric(config);
|
||||
} else if (type == std::string("quantile")) {
|
||||
return new QuantileMetric(config);
|
||||
|
@ -28,27 +28,27 @@ Metric* Metric::CreateMetric(const std::string& type, const Config& config) {
|
|||
return new FairLossMetric(config);
|
||||
} else if (type == std::string("poisson")) {
|
||||
return new PoissonMetric(config);
|
||||
} else if (type == std::string("binary_logloss") || type == std::string("binary")) {
|
||||
} else if (type == std::string("binary_logloss")) {
|
||||
return new BinaryLoglossMetric(config);
|
||||
} else if (type == std::string("binary_error")) {
|
||||
return new BinaryErrorMetric(config);
|
||||
} else if (type == std::string("auc")) {
|
||||
return new AUCMetric(config);
|
||||
} else if (type == std::string("ndcg") || type == std::string("lambdarank")) {
|
||||
} else if (type == std::string("ndcg")) {
|
||||
return new NDCGMetric(config);
|
||||
} else if (type == std::string("map") || type == std::string("mean_average_precision")) {
|
||||
} else if (type == std::string("map")) {
|
||||
return new MapMetric(config);
|
||||
} else if (type == std::string("multi_logloss") || type == std::string("multiclass") || type == std::string("softmax") || type == std::string("multiclassova") || type == std::string("multiclass_ova") || type == std::string("ova") || type == std::string("ovr")) {
|
||||
} else if (type == std::string("multi_logloss")) {
|
||||
return new MultiSoftmaxLoglossMetric(config);
|
||||
} else if (type == std::string("multi_error")) {
|
||||
return new MultiErrorMetric(config);
|
||||
} else if (type == std::string("xentropy") || type == std::string("cross_entropy")) {
|
||||
} else if (type == std::string("cross_entropy")) {
|
||||
return new CrossEntropyMetric(config);
|
||||
} else if (type == std::string("xentlambda") || type == std::string("cross_entropy_lambda")) {
|
||||
} else if (type == std::string("cross_entropy_lambda")) {
|
||||
return new CrossEntropyLambdaMetric(config);
|
||||
} else if (type == std::string("kldiv") || type == std::string("kullback_leibler")) {
|
||||
} else if (type == std::string("kullback_leibler")) {
|
||||
return new KullbackLeiblerDivergence(config);
|
||||
} else if (type == std::string("mean_absolute_percentage_error") || type == std::string("mape")) {
|
||||
} else if (type == std::string("mape")) {
|
||||
return new MAPEMetric(config);
|
||||
} else if (type == std::string("gamma")) {
|
||||
return new GammaMetric(config);
|
||||
|
|
|
@ -287,7 +287,7 @@ class GammaDevianceMetric : public RegressionMetric<GammaDevianceMetric> {
|
|||
return tmp - Common::SafeLog(tmp) - 1;
|
||||
}
|
||||
inline static const char* Name() {
|
||||
return "gamma-deviance";
|
||||
return "gamma_deviance";
|
||||
}
|
||||
inline static double AverageLoss(double sum_loss, double) {
|
||||
return sum_loss * 2;
|
||||
|
|
|
@ -74,7 +74,7 @@ class CrossEntropyMetric : public Metric {
|
|||
virtual ~CrossEntropyMetric() {}
|
||||
|
||||
void Init(const Metadata& metadata, data_size_t num_data) override {
|
||||
name_.emplace_back("xentropy");
|
||||
name_.emplace_back("cross_entropy");
|
||||
num_data_ = num_data;
|
||||
label_ = metadata.label();
|
||||
weights_ = metadata.weights();
|
||||
|
@ -169,7 +169,7 @@ class CrossEntropyLambdaMetric : public Metric {
|
|||
virtual ~CrossEntropyLambdaMetric() {}
|
||||
|
||||
void Init(const Metadata& metadata, data_size_t num_data) override {
|
||||
name_.emplace_back("xentlambda");
|
||||
name_.emplace_back("cross_entropy_lambda");
|
||||
num_data_ = num_data;
|
||||
label_ = metadata.label();
|
||||
weights_ = metadata.weights();
|
||||
|
@ -252,7 +252,7 @@ class KullbackLeiblerDivergence : public Metric {
|
|||
virtual ~KullbackLeiblerDivergence() {}
|
||||
|
||||
void Init(const Metadata& metadata, data_size_t num_data) override {
|
||||
name_.emplace_back("kldiv");
|
||||
name_.emplace_back("kullback_leibler");
|
||||
num_data_ = num_data;
|
||||
label_ = metadata.label();
|
||||
weights_ = metadata.weights();
|
||||
|
|
|
@ -13,11 +13,9 @@
|
|||
namespace LightGBM {
|
||||
|
||||
ObjectiveFunction* ObjectiveFunction::CreateObjectiveFunction(const std::string& type, const Config& config) {
|
||||
if (type == std::string("regression") || type == std::string("regression_l2")
|
||||
|| type == std::string("mean_squared_error") || type == std::string("mse")
|
||||
|| type == std::string("l2_root") || type == std::string("root_mean_squared_error") || type == std::string("rmse")) {
|
||||
if (type == std::string("regression")) {
|
||||
return new RegressionL2loss(config);
|
||||
} else if (type == std::string("regression_l1") || type == std::string("mean_absolute_error") || type == std::string("mae")) {
|
||||
} else if (type == std::string("regression_l1")) {
|
||||
return new RegressionL1loss(config);
|
||||
} else if (type == std::string("quantile")) {
|
||||
return new RegressionQuantileloss(config);
|
||||
|
@ -31,21 +29,21 @@ ObjectiveFunction* ObjectiveFunction::CreateObjectiveFunction(const std::string&
|
|||
return new BinaryLogloss(config);
|
||||
} else if (type == std::string("lambdarank")) {
|
||||
return new LambdarankNDCG(config);
|
||||
} else if (type == std::string("multiclass") || type == std::string("softmax")) {
|
||||
} else if (type == std::string("multiclass")) {
|
||||
return new MulticlassSoftmax(config);
|
||||
} else if (type == std::string("multiclassova") || type == std::string("multiclass_ova") || type == std::string("ova") || type == std::string("ovr")) {
|
||||
} else if (type == std::string("multiclassova")) {
|
||||
return new MulticlassOVA(config);
|
||||
} else if (type == std::string("xentropy") || type == std::string("cross_entropy")) {
|
||||
} else if (type == std::string("cross_entropy")) {
|
||||
return new CrossEntropy(config);
|
||||
} else if (type == std::string("xentlambda") || type == std::string("cross_entropy_lambda")) {
|
||||
} else if (type == std::string("cross_entropy_lambda")) {
|
||||
return new CrossEntropyLambda(config);
|
||||
} else if (type == std::string("mean_absolute_percentage_error") || type == std::string("mape")) {
|
||||
} else if (type == std::string("mape")) {
|
||||
return new RegressionMAPELOSS(config);
|
||||
} else if (type == std::string("gamma")) {
|
||||
return new RegressionGammaLoss(config);
|
||||
} else if (type == std::string("tweedie")) {
|
||||
return new RegressionTweedieLoss(config);
|
||||
} else if (type == std::string("none") || type == std::string("null") || type == std::string("custom") || type == std::string("na")) {
|
||||
} else if (type == std::string("custom")) {
|
||||
return nullptr;
|
||||
}
|
||||
Log::Fatal("Unknown objective type name: %s", type.c_str());
|
||||
|
@ -74,17 +72,17 @@ ObjectiveFunction* ObjectiveFunction::CreateObjectiveFunction(const std::string&
|
|||
return new MulticlassSoftmax(strs);
|
||||
} else if (type == std::string("multiclassova")) {
|
||||
return new MulticlassOVA(strs);
|
||||
} else if (type == std::string("xentropy") || type == std::string("cross_entropy")) {
|
||||
} else if (type == std::string("cross_entropy")) {
|
||||
return new CrossEntropy(strs);
|
||||
} else if (type == std::string("xentlambda") || type == std::string("cross_entropy_lambda")) {
|
||||
} else if (type == std::string("cross_entropy_lambda")) {
|
||||
return new CrossEntropyLambda(strs);
|
||||
} else if (type == std::string("mean_absolute_percentage_error") || type == std::string("mape")) {
|
||||
} else if (type == std::string("mape")) {
|
||||
return new RegressionMAPELOSS(strs);
|
||||
} else if (type == std::string("gamma")) {
|
||||
return new RegressionGammaLoss(strs);
|
||||
} else if (type == std::string("tweedie")) {
|
||||
return new RegressionTweedieLoss(strs);
|
||||
} else if (type == std::string("none") || type == std::string("null") || type == std::string("custom") || type == std::string("na")) {
|
||||
} else if (type == std::string("custom")) {
|
||||
return nullptr;
|
||||
}
|
||||
Log::Fatal("Unknown objective type name: %s", type.c_str());
|
||||
|
|
|
@ -94,7 +94,7 @@ class CrossEntropy: public ObjectiveFunction {
|
|||
}
|
||||
|
||||
const char* GetName() const override {
|
||||
return "xentropy";
|
||||
return "cross_entropy";
|
||||
}
|
||||
|
||||
// convert score to a probability
|
||||
|
@ -213,7 +213,7 @@ class CrossEntropyLambda: public ObjectiveFunction {
|
|||
}
|
||||
|
||||
const char* GetName() const override {
|
||||
return "xentlambda";
|
||||
return "cross_entropy_lambda";
|
||||
}
|
||||
|
||||
//
|
||||
|
|
Загрузка…
Ссылка в новой задаче