From 7ffbf809bb0a4e7ed1b04e92bebe1bf2ac6c63fc Mon Sep 17 00:00:00 2001 From: Yu Shi Date: Thu, 23 Mar 2023 04:59:31 +0000 Subject: [PATCH] refactor CUDABinaryLoglossMetric --- src/metric/binary_metric.hpp | 2 +- src/metric/cuda/cuda_binary_metric.hpp | 17 ++++++++++++++--- 2 files changed, 15 insertions(+), 4 deletions(-) diff --git a/src/metric/binary_metric.hpp b/src/metric/binary_metric.hpp index f70a4ef4a..037f54ba0 100644 --- a/src/metric/binary_metric.hpp +++ b/src/metric/binary_metric.hpp @@ -96,7 +96,7 @@ class BinaryMetric: public Metric { return std::vector(1, loss); } - private: + protected: /*! \brief Number of data */ data_size_t num_data_; /*! \brief Pointer of label */ diff --git a/src/metric/cuda/cuda_binary_metric.hpp b/src/metric/cuda/cuda_binary_metric.hpp index e4abae48c..b9489fc13 100644 --- a/src/metric/cuda/cuda_binary_metric.hpp +++ b/src/metric/cuda/cuda_binary_metric.hpp @@ -20,16 +20,27 @@ namespace LightGBM { template -class CUDABinaryMetricInterface: public CUDARegressionMetricInterface { +class CUDABinaryMetricInterface: public CUDAMetricInterface { public: - explicit CUDABinaryMetricInterface(const Config& config): CUDARegressionMetricInterface(config) {} + explicit CUDABinaryMetricInterface(const Config& config): CUDAMetricInterface(config) {} virtual ~CUDABinaryMetricInterface() {} + + void Init(const Metadata& metadata, data_size_t num_data) override; + + std::vector Eval(const double* score, const ObjectiveFunction* objective) const override; + + protected: + double LaunchEvalKernel(const double* score_convert) const; + + CUDAVector score_convert_buffer_; + CUDAVector reduce_block_buffer_; + CUDAVector reduce_block_buffer_inner_; }; class CUDABinaryLoglossMetric: public CUDABinaryMetricInterface { public: - explicit CUDABinaryLoglossMetric(const Config& config); + explicit CUDABinaryLoglossMetric(const Config& config): CUDABinaryMetricInterface(config) {} virtual ~CUDABinaryLoglossMetric() {}