refactor CUDABinaryLoglossMetric

This commit is contained in:
Yu Shi 2023-03-23 04:59:31 +00:00
Родитель 9f13ccc8ec
Коммит 7ffbf809bb
2 изменённых файлов: 15 добавлений и 4 удалений

Просмотреть файл

@ -96,7 +96,7 @@ class BinaryMetric: public Metric {
return std::vector<double>(1, loss);
}
private:
protected:
/*! \brief Number of data */
data_size_t num_data_;
/*! \brief Pointer of label */

Просмотреть файл

@ -20,16 +20,27 @@
namespace LightGBM {
template <typename HOST_METRIC, typename CUDA_METRIC>
class CUDABinaryMetricInterface: public CUDARegressionMetricInterface<HOST_METRIC, CUDA_METRIC> {
class CUDABinaryMetricInterface: public CUDAMetricInterface<HOST_METRIC> {
public:
explicit CUDABinaryMetricInterface(const Config& config): CUDARegressionMetricInterface<HOST_METRIC, CUDA_METRIC>(config) {}
explicit CUDABinaryMetricInterface(const Config& config): CUDAMetricInterface<HOST_METRIC>(config) {}
virtual ~CUDABinaryMetricInterface() {}
void Init(const Metadata& metadata, data_size_t num_data) override;
std::vector<double> Eval(const double* score, const ObjectiveFunction* objective) const override;
protected:
double LaunchEvalKernel(const double* score_convert) const;
CUDAVector<double> score_convert_buffer_;
CUDAVector<double> reduce_block_buffer_;
CUDAVector<double> reduce_block_buffer_inner_;
};
class CUDABinaryLoglossMetric: public CUDABinaryMetricInterface<BinaryLoglossMetric, CUDABinaryLoglossMetric> {
public:
explicit CUDABinaryLoglossMetric(const Config& config);
explicit CUDABinaryLoglossMetric(const Config& config): CUDABinaryMetricInterface<BinaryLoglossMetric, CUDABinaryLoglossMetric>(config) {}
virtual ~CUDABinaryLoglossMetric() {}