зеркало из https://github.com/microsoft/LightGBM.git
refactor CUDABinaryLoglossMetric
This commit is contained in:
Родитель
9f13ccc8ec
Коммит
7ffbf809bb
|
@ -96,7 +96,7 @@ class BinaryMetric: public Metric {
|
|||
return std::vector<double>(1, loss);
|
||||
}
|
||||
|
||||
private:
|
||||
protected:
|
||||
/*! \brief Number of data */
|
||||
data_size_t num_data_;
|
||||
/*! \brief Pointer of label */
|
||||
|
|
|
@ -20,16 +20,27 @@
|
|||
namespace LightGBM {
|
||||
|
||||
template <typename HOST_METRIC, typename CUDA_METRIC>
|
||||
class CUDABinaryMetricInterface: public CUDARegressionMetricInterface<HOST_METRIC, CUDA_METRIC> {
|
||||
class CUDABinaryMetricInterface: public CUDAMetricInterface<HOST_METRIC> {
|
||||
public:
|
||||
explicit CUDABinaryMetricInterface(const Config& config): CUDARegressionMetricInterface<HOST_METRIC, CUDA_METRIC>(config) {}
|
||||
explicit CUDABinaryMetricInterface(const Config& config): CUDAMetricInterface<HOST_METRIC>(config) {}
|
||||
|
||||
virtual ~CUDABinaryMetricInterface() {}
|
||||
|
||||
void Init(const Metadata& metadata, data_size_t num_data) override;
|
||||
|
||||
std::vector<double> Eval(const double* score, const ObjectiveFunction* objective) const override;
|
||||
|
||||
protected:
|
||||
double LaunchEvalKernel(const double* score_convert) const;
|
||||
|
||||
CUDAVector<double> score_convert_buffer_;
|
||||
CUDAVector<double> reduce_block_buffer_;
|
||||
CUDAVector<double> reduce_block_buffer_inner_;
|
||||
};
|
||||
|
||||
class CUDABinaryLoglossMetric: public CUDABinaryMetricInterface<BinaryLoglossMetric, CUDABinaryLoglossMetric> {
|
||||
public:
|
||||
explicit CUDABinaryLoglossMetric(const Config& config);
|
||||
explicit CUDABinaryLoglossMetric(const Config& config): CUDABinaryMetricInterface<BinaryLoglossMetric, CUDABinaryLoglossMetric>(config) {}
|
||||
|
||||
virtual ~CUDABinaryLoglossMetric() {}
|
||||
|
||||
|
|
Загрузка…
Ссылка в новой задаче