Add curandGetErrorString and use it to redefine CURAND_CHECK

This commit is contained in:
Kai Li 2014-03-31 22:29:36 +08:00
Родитель aa4139b34e
Коммит 14fe30cc09
2 изменённых файлов: 39 добавлений и 1 удалений

Просмотреть файл

@ -38,7 +38,12 @@ private:\
cublasStatus_t status = condition; \ cublasStatus_t status = condition; \
CHECK_EQ(status, CUBLAS_STATUS_SUCCESS) << cublasGetErrorString(status); \ CHECK_EQ(status, CUBLAS_STATUS_SUCCESS) << cublasGetErrorString(status); \
} while(0) } while(0)
#define CURAND_CHECK(condition) CHECK_EQ((condition), CURAND_STATUS_SUCCESS)
#define CURAND_CHECK(condition) \
do { \
curandStatus_t status = condition; \
CHECK_EQ(status, CURAND_STATUS_SUCCESS) << curandGetErrorString(status); \
} while(0)
// CUDA: grid stride looping // CUDA: grid stride looping
#define CUDA_KERNEL_LOOP(i, n) \ #define CUDA_KERNEL_LOOP(i, n) \
@ -135,6 +140,7 @@ class Caffe {
// NVIDIA_CUDA-5.5_Samples/common/inc/helper_cuda.h // NVIDIA_CUDA-5.5_Samples/common/inc/helper_cuda.h
const char* cublasGetErrorString(cublasStatus_t& error); const char* cublasGetErrorString(cublasStatus_t& error);
const char* curandGetErrorString(curandStatus_t error);
// CUDA: thread number configuration. // CUDA: thread number configuration.
// Use 1024 threads per block, which requires cuda sm_2x or above, // Use 1024 threads per block, which requires cuda sm_2x or above,

Просмотреть файл

@ -169,4 +169,36 @@ const char* cublasGetErrorString(cublasStatus_t& error) {
return "Unknown cublas status"; return "Unknown cublas status";
} }
const char* curandGetErrorString(curandStatus_t error) {
switch (error) {
case CURAND_STATUS_SUCCESS:
return "CURAND_STATUS_SUCCESS";
case CURAND_STATUS_VERSION_MISMATCH:
return "CURAND_STATUS_VERSION_MISMATCH";
case CURAND_STATUS_NOT_INITIALIZED:
return "CURAND_STATUS_NOT_INITIALIZED";
case CURAND_STATUS_ALLOCATION_FAILED:
return "CURAND_STATUS_ALLOCATION_FAILED";
case CURAND_STATUS_TYPE_ERROR:
return "CURAND_STATUS_TYPE_ERROR";
case CURAND_STATUS_OUT_OF_RANGE:
return "CURAND_STATUS_OUT_OF_RANGE";
case CURAND_STATUS_LENGTH_NOT_MULTIPLE:
return "CURAND_STATUS_LENGTH_NOT_MULTIPLE";
case CURAND_STATUS_DOUBLE_PRECISION_REQUIRED:
return "CURAND_STATUS_DOUBLE_PRECISION_REQUIRED";
case CURAND_STATUS_LAUNCH_FAILURE:
return "CURAND_STATUS_LAUNCH_FAILURE";
case CURAND_STATUS_PREEXISTING_FAILURE:
return "CURAND_STATUS_PREEXISTING_FAILURE";
case CURAND_STATUS_INITIALIZATION_FAILED:
return "CURAND_STATUS_INITIALIZATION_FAILED";
case CURAND_STATUS_ARCH_MISMATCH:
return "CURAND_STATUS_ARCH_MISMATCH";
case CURAND_STATUS_INTERNAL_ERROR:
return "CURAND_STATUS_INTERNAL_ERROR";
}
return "Unknown curand status";
}
} // namespace caffe } // namespace caffe