зеркало из https://github.com/microsoft/LightGBM.git
refactor CUDABinaryLoglossMetric
This commit is contained in:
Родитель
9f13ccc8ec
Коммит
7ffbf809bb
|
@ -96,7 +96,7 @@ class BinaryMetric: public Metric {
|
||||||
return std::vector<double>(1, loss);
|
return std::vector<double>(1, loss);
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
protected:
|
||||||
/*! \brief Number of data */
|
/*! \brief Number of data */
|
||||||
data_size_t num_data_;
|
data_size_t num_data_;
|
||||||
/*! \brief Pointer of label */
|
/*! \brief Pointer of label */
|
||||||
|
|
|
@ -20,16 +20,27 @@
|
||||||
namespace LightGBM {
|
namespace LightGBM {
|
||||||
|
|
||||||
template <typename HOST_METRIC, typename CUDA_METRIC>
|
template <typename HOST_METRIC, typename CUDA_METRIC>
|
||||||
class CUDABinaryMetricInterface: public CUDARegressionMetricInterface<HOST_METRIC, CUDA_METRIC> {
|
class CUDABinaryMetricInterface: public CUDAMetricInterface<HOST_METRIC> {
|
||||||
public:
|
public:
|
||||||
explicit CUDABinaryMetricInterface(const Config& config): CUDARegressionMetricInterface<HOST_METRIC, CUDA_METRIC>(config) {}
|
explicit CUDABinaryMetricInterface(const Config& config): CUDAMetricInterface<HOST_METRIC>(config) {}
|
||||||
|
|
||||||
virtual ~CUDABinaryMetricInterface() {}
|
virtual ~CUDABinaryMetricInterface() {}
|
||||||
|
|
||||||
|
void Init(const Metadata& metadata, data_size_t num_data) override;
|
||||||
|
|
||||||
|
std::vector<double> Eval(const double* score, const ObjectiveFunction* objective) const override;
|
||||||
|
|
||||||
|
protected:
|
||||||
|
double LaunchEvalKernel(const double* score_convert) const;
|
||||||
|
|
||||||
|
CUDAVector<double> score_convert_buffer_;
|
||||||
|
CUDAVector<double> reduce_block_buffer_;
|
||||||
|
CUDAVector<double> reduce_block_buffer_inner_;
|
||||||
};
|
};
|
||||||
|
|
||||||
class CUDABinaryLoglossMetric: public CUDABinaryMetricInterface<BinaryLoglossMetric, CUDABinaryLoglossMetric> {
|
class CUDABinaryLoglossMetric: public CUDABinaryMetricInterface<BinaryLoglossMetric, CUDABinaryLoglossMetric> {
|
||||||
public:
|
public:
|
||||||
explicit CUDABinaryLoglossMetric(const Config& config);
|
explicit CUDABinaryLoglossMetric(const Config& config): CUDABinaryMetricInterface<BinaryLoglossMetric, CUDABinaryLoglossMetric>(config) {}
|
||||||
|
|
||||||
virtual ~CUDABinaryLoglossMetric() {}
|
virtual ~CUDABinaryLoglossMetric() {}
|
||||||
|
|
||||||
|
|
Загрузка…
Ссылка в новой задаче