refactor CUDABinaryLoglossMetric

This commit is contained in:
Yu Shi 2023-03-23 04:59:31 +00:00
Родитель 9f13ccc8ec
Коммит 7ffbf809bb
2 изменённых файлов: 15 добавлений и 4 удалений

Просмотреть файл

@ -96,7 +96,7 @@ class BinaryMetric: public Metric {
return std::vector<double>(1, loss); return std::vector<double>(1, loss);
} }
private: protected:
/*! \brief Number of data */ /*! \brief Number of data */
data_size_t num_data_; data_size_t num_data_;
/*! \brief Pointer of label */ /*! \brief Pointer of label */

Просмотреть файл

@ -20,16 +20,27 @@
namespace LightGBM { namespace LightGBM {
template <typename HOST_METRIC, typename CUDA_METRIC> template <typename HOST_METRIC, typename CUDA_METRIC>
class CUDABinaryMetricInterface: public CUDARegressionMetricInterface<HOST_METRIC, CUDA_METRIC> { class CUDABinaryMetricInterface: public CUDAMetricInterface<HOST_METRIC> {
public: public:
explicit CUDABinaryMetricInterface(const Config& config): CUDARegressionMetricInterface<HOST_METRIC, CUDA_METRIC>(config) {} explicit CUDABinaryMetricInterface(const Config& config): CUDAMetricInterface<HOST_METRIC>(config) {}
virtual ~CUDABinaryMetricInterface() {} virtual ~CUDABinaryMetricInterface() {}
void Init(const Metadata& metadata, data_size_t num_data) override;
std::vector<double> Eval(const double* score, const ObjectiveFunction* objective) const override;
protected:
double LaunchEvalKernel(const double* score_convert) const;
CUDAVector<double> score_convert_buffer_;
CUDAVector<double> reduce_block_buffer_;
CUDAVector<double> reduce_block_buffer_inner_;
}; };
class CUDABinaryLoglossMetric: public CUDABinaryMetricInterface<BinaryLoglossMetric, CUDABinaryLoglossMetric> { class CUDABinaryLoglossMetric: public CUDABinaryMetricInterface<BinaryLoglossMetric, CUDABinaryLoglossMetric> {
public: public:
explicit CUDABinaryLoglossMetric(const Config& config); explicit CUDABinaryLoglossMetric(const Config& config): CUDABinaryMetricInterface<BinaryLoglossMetric, CUDABinaryLoglossMetric>(config) {}
virtual ~CUDABinaryLoglossMetric() {} virtual ~CUDABinaryLoglossMetric() {}