Benchmarks: Micro benchmarks - add source code of correctness check for cublas functions (#450)
**Description** Add C source code of a correctness check for cuBLAS functions (a minimal sketch of the check appears below, after the commit metadata).

**Major Revision**
- Add a correctness check for all supported cuBLAS functions.
- Add a --correctness option to the binary.

**Minor Revision**
- Fix a bug and template fill_data and prepare_tensor to get a correctly memory-aligned output matrix for different data types.
Parent: 9dfefce350
Commit: 678b1251f1
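Before the diff, here is a minimal, self-contained sketch of the idea behind the new check (an illustration only, not the benchmark's actual code; the sizes and fill values are made up): compute a CPU reference GEMM for the same inputs, then count output elements whose relative error against the GPU result exceeds an eps bound; the benchmark prints "Result = PASS" when the count stays at zero. The "GPU" result is stood in for by a second CPU computation so the snippet runs anywhere.

```cpp
#include <cmath>
#include <cstddef>
#include <cstdio>
#include <vector>

// Column-major reference GEMM: C = alpha * A(m x k) * B(k x n) + beta * C
void reference_gemm(int m, int n, int k, float alpha, const std::vector<float> &A,
                    const std::vector<float> &B, float beta, std::vector<float> &C) {
    for (int j = 0; j < n; j++) {
        for (int i = 0; i < m; i++) {
            float acc = 0.f;
            for (int p = 0; p < k; p++) {
                acc += A[p * m + i] * B[j * k + p];
            }
            C[i + j * m] = alpha * acc + beta * C[i + j * m];
        }
    }
}

// Element-wise criterion used by the benchmark: |cpu - gpu| / |cpu| / k < eps
int count_errors(const std::vector<float> &cpu, const std::vector<float> &gpu, int k, double eps) {
    int errors = 0;
    for (std::size_t i = 0; i < cpu.size(); i++) {
        double rel_err = std::fabs(cpu[i] - gpu[i]) / std::fabs(cpu[i]) / k;
        if (rel_err > eps) {
            errors++;
        }
    }
    return errors;
}

int main() {
    const int m = 4, n = 4, k = 8;
    std::vector<float> A(m * k, 0.5f), B(k * n, 0.25f), C_cpu(m * n, 0.f), C_gpu(m * n, 0.f);
    reference_gemm(m, n, k, 1.f, A, B, 0.f, C_cpu);
    reference_gemm(m, n, k, 1.f, A, B, 0.f, C_gpu); // stand-in for the device result
    int errors = count_errors(C_cpu, C_gpu, k, 1e-6);
    std::printf("%s, errors: %d\n", errors == 0 ? "Result = PASS" : "Result = FAIL", errors);
    return 0;
}
```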
@@ -202,6 +202,7 @@ class RuleOp:
categories (set): categories of violated rules
store_values (dict): including the number of the metrics that violate the rule, and the values of
the metrics for the rules with 'store' True

Returns:
number: 0 if the rule is passed, otherwise 1
"""
@@ -89,8 +89,9 @@ class ResultSummary(RuleBase):

Args:
category (str): category in the rule
summary_df_of_rule ([type]): summary df of a rule, the columns are metrics, the index are statistics
summary_df_of_rule (DataFrame): summary df of a rule, the columns are metrics, the index are statistics
statistics (list): statistics in the rule

Returns:
list: list of summary lines like [category, metric, statistic, value]
"""
@@ -9,6 +9,7 @@
#pragma once

#include <chrono>
#include <complex>
#include <iostream>
#include <stdexcept>
#include <stdlib.h>

@@ -51,6 +52,8 @@ class CublasFunction {
int warm_up; ///< the number of steps used to warm up
int num_in_step; ///< the number of functions invoking in a step
int random_seed; ///< the random seed used to generate random data
double eps; ///< the acceptable error bound for numeric stability
bool correctness; ///< whether enable correctness check or not
std::string name_; ///< the name of the cublas function
int m_; ///< the m dim of matrix
int k_; ///< the k dim of matrix
@@ -65,21 +68,15 @@ class CublasFunction {
cublasHandle_t cublas_handle; ///< the handle of cublas function

/**
* @brief Fill the random data into the input in float type
* @brief Fill the random data into the input
*/
void fill_data_float(float *Parameter_0_0_host, float *Parameter_1_0_host);
template <typename T> void fill_data(T *Parameter_0_0_host, T *Parameter_1_0_host);
/**
* @brief Fill the random data into the input in cuComplex type
* @brief Prepare memory and data of the input and output
*/
void fill_data_cucomplex(cuComplex *Parameter_0_0_host, cuComplex *Parameter_1_0_host);
/**
* @brief Prepare memory and data of the input and output in float type
*/
void prepare_tensor_float(float **Parameter_0_0, float **Parameter_1_0, float **Result_3_0);
/**
* @brief Prepare memory and data of the input and output in cuComplex type
*/
void prepare_tensor_cucomplex(cuComplex **Parameter_0_0, cuComplex **Parameter_1_0, cuComplex **Result_3_0);
template <typename T>
void prepare_tensor_template(T **Parameter_0_0, T **Parameter_1_0, T **Result_3_0, T **Parameter_0_0_host,
T **Parameter_1_0_host);
/**
* @brief Prepare memory and data of the input and output for kernel running
*/
@@ -88,6 +85,29 @@ class CublasFunction {
* @brief Execute the kernel/function
*/
virtual void kernel_entry() {}
/**
* @brief Transpose the column-order stored matrix
*/
template <typename T> T *transpose(const T *matrix, int m, int n, int batch_count);
/**
* @brief Matrix multiply calculation on CPU side with input data and output data
*/
template <typename T1, typename T2>
void matrix_calculation_on_cpu_with_data(const T1 *Parameter_0_0_host, const T1 *Parameter_1_0_host,
const T1 *Result_3_0, T2 **Result_cpu, T2 alpha = 1, T2 beta = 0);
/**
* @brief Check if the error < eps between the calculation result of GPU and CPU for each element in the matrix
*/
template <typename T1, typename T2>
int check_result(int batch_count, T1 *Result_3_0, T2 *Result_cpu, double eps = 1.e-6);
/**
* @brief Virtual function of Matrix multiply calculation on CPU side
*/
virtual void matrix_calculation_on_cpu() {}
/**
* @brief Virtual function of Check the cublas function calculation correctness
*/
virtual int correctness_check() { return 0; }

public:
/**
@@ -110,10 +130,21 @@ class CublasFunction {
* @param random_seed random seed
*/
void set_random_seed(int random_seed) { this->random_seed = random_seed; }
/**
* @brief Set the correctness
* @param correctness_check if check the correctness of the function result
*/
void set_correctness(int correctness_check) { this->correctness = correctness_check; }
/**
* @brief Set the eps
* @param eps the acceptable error bound for numeric stability
*/
void set_eps(double eps) { this->eps = eps; }
/**
* @brief Set the params string
* @param str the str representing the params of the function
*/
void set_function(std::string &str) { this->function_str_ = str; }
/**
* @brief Set the name member
@@ -195,95 +226,242 @@ class CublasFunction {
};

/**
* @brief Fill the random data into the input in cuComplex type
* @brief Fill the random data into the input in float type
*/
void CublasFunction::fill_data_float(float *Parameter_0_0_host, float *Parameter_1_0_host) {
template <> void CublasFunction::fill_data(float *Parameter_0_0_host, float *Parameter_1_0_host) {
srand(random_seed);
for (int i = 0; i < m_ * k_; i++) {
Parameter_0_0_host[i] = (float)rand() / (float)(RAND_MAX);
for (int i = 0; i < m_ * k_ * batch_count_; i++) {
Parameter_0_0_host[i] = ((float)rand() / (float)(RAND_MAX));
}
for (int i = 0; i < k_ * n_; ++i) {
Parameter_1_0_host[i] = (float)rand() / (float)(RAND_MAX);
for (int i = 0; i < k_ * n_ * batch_count_; ++i) {
Parameter_1_0_host[i] = ((float)rand() / (float)(RAND_MAX));
}
}
/**
* @brief Fill the random data into the input in half type
*/
template <> void CublasFunction::fill_data(half *Parameter_0_0_host, half *Parameter_1_0_host) {
srand(random_seed);
for (int i = 0; i < m_ * k_ * batch_count_; i++) {
Parameter_0_0_host[i] = half((float)rand() / (float)(RAND_MAX));
}
for (int i = 0; i < k_ * n_ * batch_count_; ++i) {
Parameter_1_0_host[i] = half((float)rand() / (float)(RAND_MAX));
}
}
/**
* @brief Fill the random data into the input in cuComplex type
*/
void CublasFunction::fill_data_cucomplex(cuComplex *Parameter_0_0_host, cuComplex *Parameter_1_0_host) {
template <> void CublasFunction::fill_data(cuComplex *Parameter_0_0_host, cuComplex *Parameter_1_0_host) {
srand(random_seed);
for (int i = 0; i < m_ * k_; i++) {
for (int i = 0; i < m_ * k_ * batch_count_; i++) {
Parameter_0_0_host[i] =
make_cuComplex(((float)rand() / (float)(RAND_MAX)), ((float)rand() / (float)(RAND_MAX)));
}
for (int i = 0; i < k_ * n_; ++i) {
for (int i = 0; i < k_ * n_ * batch_count_; ++i) {
Parameter_1_0_host[i] =
make_cuComplex(((float)rand() / (float)(RAND_MAX)), ((float)rand() / (float)(RAND_MAX)));
}
}
/**
* @brief Prepare memory and data of the input and output in float type
* @brief Prepare memory and data of the input and output
*/
void CublasFunction::prepare_tensor_float(float **Parameter_0_0, float **Parameter_1_0, float **Result_3_0) {
int m = this->m_;
int n = this->n_;
int k = this->k_;

float *Parameter_0_0_host, *Parameter_1_0_host;
template <typename T>
void CublasFunction::prepare_tensor_template(T **Parameter_0_0, T **Parameter_1_0, T **Result_3_0,
T **Parameter_0_0_host, T **Parameter_1_0_host) {
int m = this->m_, n = this->n_, k = this->k_, batch_count = this->batch_count_;
// input argument
CUDA_SAFE_CALL(cudaMallocHost((void **)&Parameter_0_0_host, sizeof(float) * m * k * this->batch_count_));
CUDA_SAFE_CALL(cudaMalloc((void **)Parameter_0_0, sizeof(float) * m * k * this->batch_count_));
CUDA_SAFE_CALL(cudaMallocHost((void **)Parameter_0_0_host, sizeof(T) * m * k * batch_count_));
CUDA_SAFE_CALL(cudaMalloc((void **)Parameter_0_0, sizeof(T) * m * k * batch_count_));
// input argument
CUDA_SAFE_CALL(cudaMallocHost((void **)&Parameter_1_0_host, sizeof(float) * n * k * this->batch_count_));
CUDA_SAFE_CALL(cudaMalloc((void **)Parameter_1_0, sizeof(float) * n * k * this->batch_count_));
CUDA_SAFE_CALL(cudaMallocHost((void **)Parameter_1_0_host, sizeof(T) * n * k * batch_count_));
CUDA_SAFE_CALL(cudaMalloc((void **)Parameter_1_0, sizeof(T) * n * k * batch_count_));

// fill input values
fill_data_float(Parameter_0_0_host, Parameter_1_0_host);
fill_data(reinterpret_cast<T *>(*Parameter_0_0_host), reinterpret_cast<T *>(*Parameter_1_0_host));

// copy input data from host to device
CUDA_SAFE_CALL(cudaMemcpy(*Parameter_0_0, Parameter_0_0_host, sizeof(float) * m * k * this->batch_count_,
cudaMemcpyHostToDevice));
CUDA_SAFE_CALL(cudaMemcpy(*Parameter_1_0, Parameter_1_0_host, sizeof(float) * k * n * this->batch_count_,
cudaMemcpyHostToDevice));
CUDA_SAFE_CALL(
cudaMemcpy(*Parameter_0_0, *Parameter_0_0_host, sizeof(T) * m * k * batch_count_, cudaMemcpyHostToDevice));
CUDA_SAFE_CALL(
cudaMemcpy(*Parameter_1_0, *Parameter_1_0_host, sizeof(T) * k * n * batch_count_, cudaMemcpyHostToDevice));

// output arguments
CUDA_SAFE_CALL(cudaMalloc((void **)Result_3_0, sizeof(float) * m * n * batch_count_));
CUDA_SAFE_CALL(cudaMemset((void *)*Result_3_0, 0, sizeof(float) * m * n * batch_count_));

CUDA_SAFE_CALL(cudaFreeHost(Parameter_0_0_host));
CUDA_SAFE_CALL(cudaFreeHost(Parameter_1_0_host));
CUDA_SAFE_CALL(cudaMalloc((void **)Result_3_0, sizeof(T) * m * n * batch_count_));
CUDA_SAFE_CALL(cudaMemset((void *)*Result_3_0, 0, sizeof(T) * m * n * batch_count_));
}
/**
* @brief Prepare memory and data of the input and output in cuComplex type
* @brief Transpose the column-order stored matrix with float or half datatype
*/
void CublasFunction::prepare_tensor_cucomplex(cuComplex **Parameter_0_0, cuComplex **Parameter_1_0,
cuComplex **Result_3_0) {
int m = this->m_;
int n = this->n_;
int k = this->k_;
template <typename T> T *CublasFunction::transpose(const T *matrix, int m, int n, int batch_count) {
T *transpose_matrix = (T *)malloc((unsigned long)m * (unsigned long)n * sizeof(T) * (unsigned long)batch_count);
for (int b = 0; b < batch_count; b++) {
for (int i = 0; i < m * n; i++) {
int c = i / m;
int r = i % m;
int tran_i = r * n + c;
transpose_matrix[tran_i + b * m * n] = matrix[i + b * m * n];
}
}
return transpose_matrix;
}
/**
* @brief Matrix multiply calculation on CPU side
*/
template <typename T1, typename T2>
void CublasFunction::matrix_calculation_on_cpu_with_data(const T1 *Parameter_0_0_host, const T1 *Parameter_1_0_host,
const T1 *Result_3_0, T2 **Result_cpu, T2 alpha, T2 beta) {
int m = this->m_, n = this->n_, k = this->k_, batch_count = this->batch_count_;
// Copy result from device to host
T1 *Result_3_0_host;
CUDA_SAFE_CALL(cudaMallocHost((void **)&Result_3_0_host,
sizeof(T1) * (unsigned long)m * (unsigned long)n * (unsigned long)batch_count));
CUDA_SAFE_CALL(cudaMemcpy(Result_3_0_host, Result_3_0,
sizeof(T1) * (unsigned long)m * (unsigned long)n * (unsigned long)batch_count,
cudaMemcpyDeviceToHost));
// Transpose the input matrix
T1 *Parameter_0_0_host_op, *Parameter_1_0_host_op;
Parameter_0_0_host_op = (T1 *)malloc((unsigned long)m * (unsigned long)k * sizeof(T1) * (unsigned long)batch_count);
Parameter_1_0_host_op = (T1 *)malloc((unsigned long)n * (unsigned long)k * sizeof(T1) * (unsigned long)batch_count);
memcpy(Parameter_0_0_host_op, Parameter_0_0_host,
(unsigned long)m * (unsigned long)k * sizeof(T1) * (unsigned long)batch_count);
memcpy(Parameter_1_0_host_op, Parameter_1_0_host,
(unsigned long)n * (unsigned long)k * sizeof(T1) * (unsigned long)batch_count);
if (this->transa_) {
Parameter_0_0_host_op = transpose(Parameter_0_0_host, k, m, batch_count);
}
if (this->transb_) {
Parameter_1_0_host_op = transpose(Parameter_1_0_host, n, k, batch_count);
}
// C + i*strideC = alpha*op(A+i*strideA)*op(B+i*strideB)+beta(C+i*strideC), for i in [0, batchcount -1 ]
// reference in https://docs.nvidia.com/cuda/cublas/index.html#cublas-level-3-function-reference
*Result_cpu = (T2 *)malloc((unsigned long)m * (unsigned long)n * sizeof(T2) * (unsigned long)batch_count);
for (int b = 0; b < batch_count; b++) {
for (int i = 0; i < m; i++) {
for (int j = 0; j < n; j++) {
(*Result_cpu)[i + j * m + b * m * n] = beta * (T2)(Result_3_0_host[i + j * m + b * m * n]);
for (int p = 0; p < k; p++) {
(*Result_cpu)[i + j * m + b * m * n] +=
Parameter_0_0_host_op[p * m + i + b * m * k] * Parameter_1_0_host_op[j * k + p + b * k * n];
(*Result_cpu)[i + j * m + b * m * n] *= alpha;
}
}
}
}
CUDA_SAFE_CALL(cudaFreeHost(Result_3_0_host));
free(Parameter_0_0_host_op);
free(Parameter_1_0_host_op);
}
/**
* @brief Transpose the column-order stored matrix with complex datatype
*/
template <>
void CublasFunction::matrix_calculation_on_cpu_with_data(const cuComplex *Parameter_0_0_host,
const cuComplex *Parameter_1_0_host,
const cuComplex *Result_3_0, std::complex<float> **Result_cpu,
std::complex<float> alpha, std::complex<float> beta) {
int m = this->m_, n = this->n_, k = this->k_, batch_count = this->batch_count_;
// Copy result from device to host
std::complex<float> *Result_3_0_host;
CUDA_SAFE_CALL(cudaMallocHost((void **)&Result_3_0_host, sizeof(std::complex<float>) * m * n * batch_count));
CUDA_SAFE_CALL(cudaMemcpy(Result_3_0_host, Result_3_0, sizeof(std::complex<float>) * m * n * batch_count,
cudaMemcpyDeviceToHost));
cuComplex *Parameter_0_0_host_op, *Parameter_1_0_host_op;
Parameter_0_0_host_op =
(cuComplex *)malloc((unsigned long)m * (unsigned long)k * sizeof(cuComplex) * (unsigned long)batch_count);
Parameter_1_0_host_op =
(cuComplex *)malloc((unsigned long)n * (unsigned long)k * sizeof(cuComplex) * (unsigned long)batch_count);
memcpy(Parameter_0_0_host_op, Parameter_0_0_host,
(unsigned long)m * (unsigned long)k * sizeof(cuComplex) * (unsigned long)batch_count);
memcpy(Parameter_1_0_host_op, Parameter_1_0_host,
(unsigned long)n * (unsigned long)k * sizeof(cuComplex) * (unsigned long)batch_count);
if (this->transa_) {
Parameter_0_0_host_op = transpose<cuComplex>(Parameter_0_0_host, k, m, batch_count);
}
if (this->transb_) {
Parameter_1_0_host_op = transpose<cuComplex>(Parameter_1_0_host, n, k, batch_count);
}

cuComplex *Parameter_0_0_host, *Parameter_1_0_host;
// input argument
CUDA_SAFE_CALL(cudaMallocHost((void **)&Parameter_0_0_host, sizeof(cuComplex) * m * k * this->batch_count_));
CUDA_SAFE_CALL(cudaMalloc((void **)Parameter_0_0, sizeof(cuComplex) * m * k * this->batch_count_));
// input argument
CUDA_SAFE_CALL(cudaMallocHost((void **)&Parameter_1_0_host, sizeof(cuComplex) * n * k * this->batch_count_));
CUDA_SAFE_CALL(cudaMalloc((void **)Parameter_1_0, sizeof(cuComplex) * n * k * this->batch_count_));
*Result_cpu = (std::complex<float> *)malloc((unsigned long)m * (unsigned long)n * sizeof(std::complex<float>) *
(unsigned long)batch_count);

// fill input values
fill_data_cucomplex(Parameter_0_0_host, Parameter_1_0_host);
for (int b = 0; b < batch_count; b++) {
for (int i = 0; i < m; i++) {
for (int j = 0; j < n; j++) {
(*Result_cpu)[i + j * m + b * m * n] = beta * Result_3_0_host[i + j * m + b * m * n];
for (int p = 0; p < k; p++) {
(*Result_cpu)[i + j * m + b * m * n] +=
std::complex<float>(Parameter_0_0_host_op[p * m + i + b * m * k].x,
Parameter_0_0_host_op[p * m + i + b * m * k].y) *
std::complex<float>(Parameter_1_0_host_op[j * k + p + b * k * n].x,
Parameter_1_0_host_op[j * k + p + b * k * n].y);
(*Result_cpu)[i + j * m + b * m * n] *= alpha;
}
}
}
}
CUDA_SAFE_CALL(cudaFreeHost(Result_3_0_host));
free(Parameter_0_0_host_op);
free(Parameter_1_0_host_op);
}
/**
* @brief Check if the error < eps between the calculation result of GPU and CPU for each element in the matrix
*/
template <typename T1, typename T2>
int CublasFunction::check_result(int batch_count, T1 *Result_3_0, T2 *Result_cpu, double eps) {
int m = this->m_, n = this->n_, k = this->k_;
// Copy result from device to host
T1 *Result_3_0_host;
CUDA_SAFE_CALL(cudaMallocHost((void **)&Result_3_0_host, sizeof(T1) * m * n * batch_count));
CUDA_SAFE_CALL(cudaMemcpy(Result_3_0_host, Result_3_0, sizeof(T1) * m * n * batch_count, cudaMemcpyDeviceToHost));

// copy input data from host to device
CUDA_SAFE_CALL(cudaMemcpy(*Parameter_0_0, Parameter_0_0_host, sizeof(cuComplex) * m * k * this->batch_count_,
cudaMemcpyHostToDevice));
CUDA_SAFE_CALL(cudaMemcpy(*Parameter_1_0, Parameter_1_0_host, sizeof(cuComplex) * k * n * this->batch_count_,
cudaMemcpyHostToDevice));

// output arguments
CUDA_SAFE_CALL(cudaMalloc((void **)Result_3_0, sizeof(cuComplex) * m * n * batch_count_));
CUDA_SAFE_CALL(cudaMemset((void *)*Result_3_0, 0, sizeof(cuComplex) * m * n * batch_count_));

CUDA_SAFE_CALL(cudaFreeHost(Parameter_0_0_host));
CUDA_SAFE_CALL(cudaFreeHost(Parameter_1_0_host));
// test relative error by the formula
// |<x, y>_cpu - <x,y>_gpu|/|<x, y>_cpu|/dot_length < eps
int error_count = 0;
for (int i = 0; i < static_cast<int>(m * n) * batch_count; i++) {
double abs_err = fabs(Result_cpu[i] - Result_3_0_host[i]);
double dot_length = k;
double abs_val = fabs(Result_cpu[i]);
double rel_err = abs_err / abs_val / dot_length;
if (rel_err > eps) {
printf("error! matrix[%05d]=%.8f, ref=%.8f error term %.8f is > %E\n", i, (float)Result_3_0_host[i],
Result_cpu[i], rel_err, eps);
error_count += 1;
}
}
CUDA_SAFE_CALL(cudaFreeHost(Result_3_0_host));
free(Result_cpu);
return error_count;
}
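For reference, the acceptance test encoded in check_result above (and repeated in the complex specialization that follows) can be written per output element $i$ as

$$\frac{\left| r_i^{\mathrm{cpu}} - r_i^{\mathrm{gpu}} \right|}{\left| r_i^{\mathrm{cpu}} \right| \cdot k} < \varepsilon,$$

where $k$ is the dot-product length (the shared GEMM dimension) and $\varepsilon$ defaults to 1e-6 unless overridden via the --eps option.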
/**
* @brief Check if the error < eps between the calculation result of GPU and CPU for each element in the matrix
*/
template <>
int CublasFunction::check_result(int batch_count, cuComplex *Result_3_0, std::complex<float> *Result_cpu, double eps) {
int m = this->m_, n = this->n_, k = this->k_;
// Copy result from device to host
std::complex<float> *Result_3_0_host;
CUDA_SAFE_CALL(cudaMallocHost((void **)&Result_3_0_host, sizeof(std::complex<float>) * m * n * batch_count));
CUDA_SAFE_CALL(cudaMemcpy(Result_3_0_host, Result_3_0, sizeof(std::complex<float>) * m * n * batch_count,
cudaMemcpyDeviceToHost));
// test relative error by the formula
// |<x, y>_cpu - <x,y>_gpu|/|<x, y>_cpu|/dot_length < eps
int error_count = 0;
for (int i = 0; i < static_cast<int>(m * n) * batch_count; i++) {
double abs_err = fabs(Result_cpu[i] - Result_3_0_host[i]);
double dot_length = k;
double abs_val = fabs(Result_cpu[i]);
double rel_err = abs_err / abs_val / dot_length;
if (rel_err > eps) {
printf("error! matrix[%05d]=%.8f,%.8f, ref=%.8f,%.8f error term %.8f is > %E\n", i,
Result_3_0_host[i].real(), Result_3_0_host[i].imag(), Result_cpu[i].real(), Result_cpu[i].imag(),
rel_err, eps);
error_count += 1;
}
}
CUDA_SAFE_CALL(cudaFreeHost(Result_3_0_host));
free(Result_cpu);
return error_count;
}
/**
* @brief The main procedure for cublas function test, including warmup, function test, time measurement and output raw
@@ -303,12 +481,18 @@ void CublasFunction::benchmark() {

// Prepare some variables for time measurement
std::vector<float> iteration_time;
int errors = 0;
// Benchmark in range of steps
for (int i_ = 0; i_ < num_test; i_++) {
// Collect time within each step, including #repeat_in_one_step times function invoking
auto start = std::chrono::high_resolution_clock::now();
for (int j = 0; j < num_in_step; j++) {
if (this->correctness)
this->matrix_calculation_on_cpu();
this->kernel_entry();
if (this->correctness) {
errors += this->correctness_check();
}
}
CUDA_SAFE_CALL(cudaDeviceSynchronize());
auto end = std::chrono::high_resolution_clock::now();

@@ -325,4 +509,9 @@ void CublasFunction::benchmark() {
std::cout << iteration_time[i] << ",";
}
std::cout << std::endl;
if (this->correctness) {
std::string correctness_str = errors == 0 ? "Result = PASS" : "Result = FAIL";
std::cout << "[correctness]: " << correctness_str
<< ", error rate: " << errors / (num_in_step * num_test * this->m_ * this->n_) << std::endl;
}
}
@@ -14,9 +14,13 @@
* @brief Class of SgemmFunction
*/
class SgemmFunction : public CublasFunction {
float *Parameter_0_0; ///< the pointer of the first input data
float *Parameter_1_0; ///< the pointer of the second input data
float *Result_3_0; ///< the pointer of output data
float *Parameter_0_0; ///< the pointer of the first input data
float *Parameter_1_0; ///< the pointer of the second input data
float *Result_3_0; ///< the pointer of output data
float *Parameter_0_0_host; ///< the pointer of the first input data on host
float *Parameter_1_0_host; ///< the pointer of the second input data on host
float *Result_cpu;

/**
* @brief Execute the kernel/function
*/
@@ -25,10 +29,26 @@ class SgemmFunction : public CublasFunction {
reinterpret_cast<const float *>(Parameter_0_0), reinterpret_cast<const float *>(Parameter_1_0),
reinterpret_cast<float *>(Result_3_0));
}
/**
* @brief Function calculation on CPU side
*/
virtual void matrix_calculation_on_cpu() {
matrix_calculation_on_cpu_with_data(Parameter_0_0_host, Parameter_1_0_host, Result_3_0, &Result_cpu, 1.0f,
1.0f);
}
/**
* @brief Prepare memory and data of the input and output for kernel running
*/
virtual void prepare_tensor() { CublasFunction::prepare_tensor_float(&Parameter_0_0, &Parameter_1_0, &Result_3_0); }
virtual void prepare_tensor() {
prepare_tensor_template(&Parameter_0_0, &Parameter_1_0, &Result_3_0, &Parameter_0_0_host, &Parameter_1_0_host);
}
/**
* @brief Check the correctness of function calculation result
*/
virtual int correctness_check() {
double eps = this->eps == 0.0 ? 1.e-6 : this->eps;
return check_result(1, Result_3_0, Result_cpu, eps);
}

public:
/**

@@ -54,6 +74,8 @@ class SgemmFunction : public CublasFunction {
CUDA_SAFE_CALL(cudaFree(Parameter_0_0));
CUDA_SAFE_CALL(cudaFree(Parameter_1_0));
CUDA_SAFE_CALL(cudaFree(Result_3_0));
CUDA_SAFE_CALL(cudaFreeHost(Parameter_0_0_host));
CUDA_SAFE_CALL(cudaFreeHost(Parameter_1_0_host));
cuda_free(&cublas_handle);
}
};
@@ -65,6 +87,9 @@ class CgemmFunction : public CublasFunction {
cuComplex *Parameter_0_0;
cuComplex *Parameter_1_0;
cuComplex *Result_3_0;
cuComplex *Parameter_0_0_host;
cuComplex *Parameter_1_0_host;
std::complex<float> *Result_cpu;
/**
* @brief Execute the kernel/function
*/

@@ -73,11 +98,24 @@ class CgemmFunction : public CublasFunction {
reinterpret_cast<const cuComplex *>(Parameter_0_0), reinterpret_cast<const cuComplex *>(Parameter_1_0),
reinterpret_cast<cuComplex *>(Result_3_0));
}
/**
* @brief Function calculation on CPU side
*/
virtual void matrix_calculation_on_cpu() {
matrix_calculation_on_cpu_with_data(Parameter_0_0_host, Parameter_1_0_host, Result_3_0, &Result_cpu);
}
/**
* @brief Prepare memory and data of the input and output for kernel running
*/
virtual void prepare_tensor() {
CublasFunction::prepare_tensor_cucomplex(&Parameter_0_0, &Parameter_1_0, &Result_3_0);
prepare_tensor_template(&Parameter_0_0, &Parameter_1_0, &Result_3_0, &Parameter_0_0_host, &Parameter_1_0_host);
}
/**
* @brief Check the correctness of function calculation result
*/
virtual int correctness_check() {
double eps = this->eps == 0.0 ? 1.e-6 : this->eps;
return check_result(1, Result_3_0, Result_cpu, eps);
}

public:

@@ -104,6 +142,8 @@ class CgemmFunction : public CublasFunction {
CUDA_SAFE_CALL(cudaFree(Parameter_0_0));
CUDA_SAFE_CALL(cudaFree(Parameter_1_0));
CUDA_SAFE_CALL(cudaFree(Result_3_0));
CUDA_SAFE_CALL(cudaFreeHost(Parameter_0_0_host));
CUDA_SAFE_CALL(cudaFreeHost(Parameter_1_0_host));
cuda_free(&cublas_handle);
}
};
@@ -112,9 +152,12 @@ class CgemmFunction : public CublasFunction {
* @brief Class of GemmExFunction
*/
class GemmExFunction : public CublasFunction {
float *Parameter_0_0;
float *Parameter_1_0;
float *Result_3_0;
void *Parameter_0_0;
void *Parameter_1_0;
void *Result_3_0;
void *Parameter_0_0_host;
void *Parameter_1_0_host;
void *Result_cpu;
/**
* @brief Execute the kernel/function
*/

@@ -126,7 +169,49 @@ class GemmExFunction : public CublasFunction {
/**
* @brief Prepare memory and data of the input and output for kernel running
*/
virtual void prepare_tensor() { CublasFunction::prepare_tensor_float(&Parameter_0_0, &Parameter_1_0, &Result_3_0); }
virtual void prepare_tensor() {
if (this->datatype_.compare("half")) {
CublasFunction::prepare_tensor_template<half>(
reinterpret_cast<half **>(&Parameter_0_0), reinterpret_cast<half **>(&Parameter_1_0),
reinterpret_cast<half **>(&Result_3_0), reinterpret_cast<half **>(&Parameter_0_0_host),
reinterpret_cast<half **>(&Parameter_1_0_host));
} else if (this->datatype_.compare("float")) {
CublasFunction::prepare_tensor_template<float>(
reinterpret_cast<float **>(&Parameter_0_0), reinterpret_cast<float **>(&Parameter_1_0),
reinterpret_cast<float **>(&Result_3_0), reinterpret_cast<float **>(&Parameter_0_0_host),
reinterpret_cast<float **>(&Parameter_1_0_host));
}
}
/**
* @brief Function calculation on CPU side
*/
virtual void matrix_calculation_on_cpu() {
if (this->datatype_.compare("half")) {
matrix_calculation_on_cpu_with_data(
reinterpret_cast<half *>(Parameter_0_0_host), reinterpret_cast<half *>(Parameter_1_0_host),
reinterpret_cast<half *>(Result_3_0), reinterpret_cast<float **>(&Result_cpu));
} else if (this->datatype_.compare("float")) {
matrix_calculation_on_cpu_with_data(
reinterpret_cast<float *>(Parameter_0_0_host), reinterpret_cast<float *>(Parameter_1_0_host),
reinterpret_cast<float *>(Result_3_0), reinterpret_cast<float **>(&Result_cpu));
}
}
/**
* @brief Check the correctness of function calculation result
*/
virtual int correctness_check() {
int result = 0;
if (this->datatype_.compare("half")) {
double eps = this->eps == 0.0 ? 1.e-3 : this->eps;
result = check_result(this->batch_count_, reinterpret_cast<half *>(Result_3_0),
reinterpret_cast<float *>(Result_cpu), eps);
} else if (this->datatype_.compare("float")) {
double eps = this->eps == 0.0 ? 1.e-6 : this->eps;
result = check_result(this->batch_count_, reinterpret_cast<float *>(Result_3_0),
reinterpret_cast<float *>(Result_cpu), eps);
}
return result;
}

public:
/**

@@ -152,6 +237,8 @@ class GemmExFunction : public CublasFunction {
CUDA_SAFE_CALL(cudaFree(Parameter_0_0));
CUDA_SAFE_CALL(cudaFree(Parameter_1_0));
CUDA_SAFE_CALL(cudaFree(Result_3_0));
CUDA_SAFE_CALL(cudaFreeHost(Parameter_0_0_host));
CUDA_SAFE_CALL(cudaFreeHost(Parameter_1_0_host));
cuda_free(&cublas_handle);
}
};
@@ -160,9 +247,12 @@ class GemmExFunction : public CublasFunction {
* @brief Class of GemmStridedBatchedExFunction
*/
class GemmStridedBatchedExFunction : public CublasFunction {
float *Parameter_0_0;
float *Parameter_1_0;
float *Result_3_0;
void *Parameter_0_0;
void *Parameter_1_0;
void *Result_3_0;
void *Parameter_0_0_host;
void *Parameter_1_0_host;
void *Result_cpu;
/**
* @brief Execute the kernel/function
*/

@@ -175,7 +265,49 @@ class GemmStridedBatchedExFunction : public CublasFunction {
/**
* @brief Prepare memory and data of the input and output for kernel running
*/
virtual void prepare_tensor() { CublasFunction::prepare_tensor_float(&Parameter_0_0, &Parameter_1_0, &Result_3_0); }
virtual void prepare_tensor() {
if (this->datatype_.compare("half")) {
prepare_tensor_template<half>(
reinterpret_cast<half **>(&Parameter_0_0), reinterpret_cast<half **>(&Parameter_1_0),
reinterpret_cast<half **>(&Result_3_0), reinterpret_cast<half **>(&Parameter_0_0_host),
reinterpret_cast<half **>(&Parameter_1_0_host));
} else if (this->datatype_.compare("float")) {
prepare_tensor_template<float>(
reinterpret_cast<float **>(&Parameter_0_0), reinterpret_cast<float **>(&Parameter_1_0),
reinterpret_cast<float **>(&Result_3_0), reinterpret_cast<float **>(&Parameter_0_0_host),
reinterpret_cast<float **>(&Parameter_1_0_host));
}
}
/**
* @brief Function calculation on CPU side
*/
virtual void matrix_calculation_on_cpu() {
if (this->datatype_.compare("half")) {
matrix_calculation_on_cpu_with_data(
reinterpret_cast<half *>(Parameter_0_0_host), reinterpret_cast<half *>(Parameter_1_0_host),
reinterpret_cast<half *>(Result_3_0), reinterpret_cast<float **>(&Result_cpu));
} else if (this->datatype_.compare("float")) {
matrix_calculation_on_cpu_with_data(
reinterpret_cast<float *>(Parameter_0_0_host), reinterpret_cast<float *>(Parameter_1_0_host),
reinterpret_cast<float *>(Result_3_0), reinterpret_cast<float **>(&Result_cpu), 1.0f, 1.0f);
}
}
/**
* @brief Check the correctness of function calculation result
*/
virtual int correctness_check() {
int result = 0;
if (this->datatype_.compare("half")) {
double eps = this->eps == 0.0 ? 1.e-3 : this->eps;
result = check_result(this->batch_count_, reinterpret_cast<half *>(Result_3_0),
reinterpret_cast<float *>(Result_cpu), eps);
} else if (this->datatype_.compare("float")) {
double eps = this->eps == 0.0 ? 1.e-6 : this->eps;
result = check_result(this->batch_count_, reinterpret_cast<float *>(Result_3_0),
reinterpret_cast<float *>(Result_cpu), eps);
}
return result;
}

public:
/**

@@ -195,6 +327,8 @@ class GemmStridedBatchedExFunction : public CublasFunction {
CUDA_SAFE_CALL(cudaFree(Parameter_0_0));
CUDA_SAFE_CALL(cudaFree(Parameter_1_0));
CUDA_SAFE_CALL(cudaFree(Result_3_0));
CUDA_SAFE_CALL(cudaFreeHost(Parameter_0_0_host));
CUDA_SAFE_CALL(cudaFreeHost(Parameter_1_0_host));
cuda_free(&cublas_handle);
}
};
@@ -206,6 +340,9 @@ class SgemmStridedBatchedFunction : public CublasFunction {
float *Parameter_0_0;
float *Parameter_1_0;
float *Result_3_0;
float *Parameter_0_0_host;
float *Parameter_1_0_host;
float *Result_cpu;
/**
* @brief Execute the kernel/function
*/

@@ -218,7 +355,23 @@ class SgemmStridedBatchedFunction : public CublasFunction {
/**
* @brief Prepare memory and data of the input and output for kernel running
*/
virtual void prepare_tensor() { CublasFunction::prepare_tensor_float(&Parameter_0_0, &Parameter_1_0, &Result_3_0); }
virtual void prepare_tensor() {
prepare_tensor_template(&Parameter_0_0, &Parameter_1_0, &Result_3_0, &Parameter_0_0_host, &Parameter_1_0_host);
}
/**
* @brief Function calculation on CPU side
*/
virtual void matrix_calculation_on_cpu() {
matrix_calculation_on_cpu_with_data(Parameter_0_0_host, Parameter_1_0_host, Result_3_0, &Result_cpu, 1.0f,
1.0f);
}
/**
* @brief Check the correctness of function calculation result
*/
virtual int correctness_check() {
double eps = this->eps == 0.0 ? 1.e-6 : this->eps;
return check_result(this->batch_count_, Result_3_0, Result_cpu, eps);
}

public:
/**

@@ -238,6 +391,8 @@ class SgemmStridedBatchedFunction : public CublasFunction {
CUDA_SAFE_CALL(cudaFree(Parameter_0_0));
CUDA_SAFE_CALL(cudaFree(Parameter_1_0));
CUDA_SAFE_CALL(cudaFree(Result_3_0));
CUDA_SAFE_CALL(cudaFreeHost(Parameter_0_0_host));
CUDA_SAFE_CALL(cudaFreeHost(Parameter_1_0_host));
cuda_free(&cublas_handle);
}
};
@@ -249,6 +404,9 @@ class Cgemm3mStridedBatchedFunction : public CublasFunction {
cuComplex *Parameter_0_0;
cuComplex *Parameter_1_0;
cuComplex *Result_3_0;
cuComplex *Parameter_0_0_host;
cuComplex *Parameter_1_0_host;
std::complex<float> *Result_cpu;
/**
* @brief Execute the kernel/function
*/

@@ -262,7 +420,20 @@ class Cgemm3mStridedBatchedFunction : public CublasFunction {
* @brief Prepare memory and data of the input and output for kernel running
*/
virtual void prepare_tensor() {
CublasFunction::prepare_tensor_cucomplex(&Parameter_0_0, &Parameter_1_0, &Result_3_0);
prepare_tensor_template(&Parameter_0_0, &Parameter_1_0, &Result_3_0, &Parameter_0_0_host, &Parameter_1_0_host);
}
/**
* @brief Function calculation on CPU side
*/
virtual void matrix_calculation_on_cpu() {
matrix_calculation_on_cpu_with_data(Parameter_0_0_host, Parameter_1_0_host, Result_3_0, &Result_cpu);
}
/**
* @brief Check the correctness of function calculation result
*/
virtual int correctness_check() {
double eps = this->eps == 0.0 ? 1.e-6 : this->eps;
return check_result(this->batch_count_, Result_3_0, Result_cpu, eps);
}

public:

@@ -283,6 +454,8 @@ class Cgemm3mStridedBatchedFunction : public CublasFunction {
CUDA_SAFE_CALL(cudaFree(Parameter_0_0));
CUDA_SAFE_CALL(cudaFree(Parameter_1_0));
CUDA_SAFE_CALL(cudaFree(Result_3_0));
CUDA_SAFE_CALL(cudaFreeHost(Parameter_0_0_host));
CUDA_SAFE_CALL(cudaFreeHost(Parameter_1_0_host));
cuda_free(&cublas_handle);
}
};
@@ -8,6 +8,7 @@

#pragma once

#include <cstring>
#include <fstream>
#include <iostream>
#include <limits>

@@ -52,6 +53,18 @@ class Options {
return 0;
}

/**
* @brief Get the double type value of cmd line argument
* @param option the cmd line argument
* @return double the double type value of cmd line argument 'option'
*/
double get_cmd_line_argument_double(const std::string &option) {
if (char *value = get_cmd_option(option)) {
return std::atof(value);
}
return 0.0;
}

/**
* @brief Get the string type value of cmd line argument
* @param option the cmd line argument
@@ -64,12 +77,27 @@ class Options {
return "";
}

/**
* @brief Get the bool type value of cmd line argument
* @param option the cmd line argument
* @return bool the bool type value of cmd line argument 'option'
*/
bool get_cmd_line_argument_bool(const std::string &option) {
char **itr = std::find(begin, end, option);
if (itr != end) {
return true;
}
return false;
}

public:
int num_test;
int warm_up;
int num_in_step;
int random_seed;
std::string para_info_json;
bool correctness_check;
double eps;

/**
* @brief Construct a options object according to cmd or set a default value used to test

@@ -90,6 +118,8 @@ class Options {
para_info_json = get_cmd_line_argument_string("--config_json");
para_info_json = para_info_json == "" ? R"({"name":"cublasCgemm","m":512,"n":512,"k":32,"transa":1,"transb":0})"
: para_info_json;
correctness_check = get_cmd_line_argument_bool("--correctness");
eps = get_cmd_line_argument_double("--eps");
}
};
@@ -197,6 +227,8 @@ void run_benchmark(Options &options) {
function.set_warm_up(options.warm_up);
function.set_num_in_step(options.num_in_step);
function.set_random_seed(options.random_seed);
function.set_correctness(options.correctness_check);
function.set_eps(options.eps);
CublasFunction *p_function = get_cublas_function_pointer(function);
p_function->benchmark();
delete p_function;