FP16Benchmark: Allow fp32 comparison using cblas (#56)

Summary:
FP16Benchmark: Allow comparison against fp32
using any local cblas library if MKL not found.
Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/56

Reviewed By: jianyuh

Differential Revision: D13645545

Pulled By: dskhudia

fbshipit-source-id: ca98e84bfb85eb3b0edebad664d211c3af8db309
This commit is contained in:
WilliamTambellini 2019-01-14 11:05:42 -08:00 коммит произвёл Facebook Github Bot
Родитель 36309fc567
Коммит 9a59fbd05f
3 изменённых файлов: 33 добавлений и 7 удалений

1
.gitignore поставляемый Normal file
Просмотреть файл

@ -0,0 +1 @@
build/

Просмотреть файл

@ -1,6 +1,10 @@
cmake_minimum_required(VERSION 3.5 FATAL_ERROR)
find_package(MKL)
if (NOT ${MKL_FOUND})
find_package(BLAS)
endif()
#benchmarks
macro(add_benchmark BENCHNAME)
add_executable(${BENCHNAME} ${ARGN}
@ -13,11 +17,18 @@ macro(add_benchmark BENCHNAME)
target_link_libraries(${BENCHNAME} fbgemm)
add_dependencies(${BENCHNAME} fbgemm)
if(${MKL_FOUND})
message(STATUS "MKL_LIBRARIES= ${MKL_LIBRARIES}")
target_include_directories(${BENCHNAME} PRIVATE "${MKL_INCLUDE_DIR}")
target_link_libraries(${BENCHNAME} "${MKL_LIBRARIES}")
target_compile_options(${BENCHNAME} PRIVATE
"-DUSE_MKL")
endif()
if (${BLAS_FOUND})
message(STATUS "BLAS_LIBRARIES= ${BLAS_LIBRARIES}")
target_compile_options(${BENCHNAME} PRIVATE "-DUSE_BLAS")
target_link_libraries(${BENCHNAME} "${BLAS_LIBRARIES}")
endif()
set_target_properties(${BENCHNAME} PROPERTIES FOLDER test)
endmacro()

Просмотреть файл

@ -12,6 +12,10 @@
#include <mkl.h>
#endif
#ifdef USE_BLAS
#include <cblas.h>
#endif
#ifdef _OPENMP
#include <omp.h>
#endif
@ -101,7 +105,7 @@ void performance_test() {
// warm up MKL and fbgemm
// check correctness at the same time
for (auto w = 0; w < 3; w++) {
#ifdef USE_MKL
#if defined(USE_MKL) || defined(USE_BLAS)
cblas_sgemm(
CblasRowMajor,
CblasNoTrans,
@ -121,19 +125,29 @@ void performance_test() {
cblas_gemm_compute(
matrix_op_t::NoTranspose, m, A.data(), Bp, beta, C_fb.data());
#ifdef USE_MKL
#if defined(USE_MKL) || defined(USE_BLAS)
// Compare results
for (auto i = 0; i < C_ref.size(); i++) {
// printf("%f %f\n", C_ref[i], C_fb[i]);
assert(std::abs(C_ref[i] - C_fb[i]) < 1e-3);
if (std::abs(C_ref[i] - C_fb[i]) > 1e-3) {
fprintf(
stderr,
"Error: too high diff between fp32 ref %f and fp16 %f\n",
C_ref[i],
C_fb[i]);
return;
}
}
#endif
}
chrono::time_point<chrono::system_clock> t_begin, t_end;
#ifdef USE_MKL
#if defined(USE_MKL) || defined(USE_BLAS)
// Gold via MKL sgemm
#if defined(USE_MKL)
type = "MKL_FP32";
#else
type = "BLAS_FP32";
#endif
ttot = 0;
for (auto it = -3; it < NITER; it++) {
if (flush) {
@ -166,7 +180,7 @@ void performance_test() {
gflops = nflops / ttot / 1e9;
gbs = nbytes / ttot / 1e9;
printf(
"\n%15s m = %5d n = %5d k = %5d Gflops = %8.4lf GBytes = %8.4lf\n",
"\n%30s m = %5d n = %5d k = %5d Gflops = %8.4lf GBytes = %8.4lf\n",
type.c_str(),
m,
n,
@ -199,7 +213,7 @@ void performance_test() {
gflops = nflops / ttot / 1e9;
gbs = nbytes / ttot / 1e9;
printf(
"%15s m = %5d n = %5d k = %5d Gflops = %8.4lf GBytes = %8.4lf\n",
"%30s m = %5d n = %5d k = %5d Gflops = %8.4lf GBytes = %8.4lf\n",
type.c_str(),
m,
n,