Integrate anthaue/addblockmultiplier into master

This commit is contained in:
Project Philly 2016-06-24 16:50:41 -07:00
Parents e04a9bd7f3 d2bf769c83
Commit d39410d2fc
12 changed files with 3902 additions and 3 deletions

View file

@@ -68,7 +68,7 @@ INCLUDEPATH:= $(addprefix $(SOURCEDIR)/, Common/Include CNTKv2LibraryDll CNTKv2L
# COMMON_FLAGS include settings that are passed both to NVCC and C++ compilers.
COMMON_FLAGS:= -D_POSIX_SOURCE -D_XOPEN_SOURCE=600 -D__USE_XOPEN2K -std=c++11
CPPFLAGS:=
CXXFLAGS:= -msse3 -std=c++0x -fopenmp -fpermissive -fPIC -Werror -fcheck-new
CXXFLAGS:= -msse3 -mssse3 -std=c++0x -fopenmp -fpermissive -fPIC -Werror -fcheck-new
LIBPATH:=
LIBS:=
LDFLAGS:=
@@ -87,7 +87,7 @@ SRC:=
all : buildall
# Set up basic nvcc options and add CUDA targets from above
CUFLAGS = -m 64
CUFLAGS = -m 64
ifdef CUDA_PATH
ifndef GDK_PATH
@@ -167,6 +167,10 @@ ifdef KALDI_PATH
KALDI_LIBS += -lkaldi-util -lkaldi-matrix -lkaldi-base -lkaldi-hmm -lkaldi-cudamatrix -lkaldi-nnet -lkaldi-lat
endif
ifdef SUPPORT_AVX2
CPPFLAGS += -mavx2
endif
# Set up nvcc target architectures (will generate code to support them all, i.e. fat-binary, in release mode)
# In debug mode we will rely on JIT to create code "on the fly" for the underlying architecture
GENCODE_SM30 := -gencode arch=compute_30,code=\"sm_30,compute_30\"
@@ -269,6 +273,7 @@ COMMON_SRC =\
$(SOURCEDIR)/Common/fileutil.cpp \
MATH_SRC =\
$(SOURCEDIR)/Math/BlockHandlerSSE.cpp \
$(SOURCEDIR)/Math/CPUMatrix.cpp \
$(SOURCEDIR)/Math/CPUSparseMatrix.cpp \
$(SOURCEDIR)/Math/CPURNGHandle.cpp \
@@ -282,6 +287,12 @@ MATH_SRC =\
$(SOURCEDIR)/Math/ConvolutionEngine.cpp \
$(SOURCEDIR)/Math/BatchNormalizationEngine.cpp \
ifdef SUPPORT_AVX2
MATH_SRC +=\
$(SOURCEDIR)/Math/BlockHandlerAVX.cpp \
endif
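The two SUPPORT_AVX2 blocks above add -mavx2 to CPPFLAGS and pull BlockHandlerAVX.cpp into MATH_SRC only when the variable is set, for example with make SUPPORT_AVX2=1 (how the flag is normally supplied is not shown in this diff). Since -mavx2 makes the compiler define __AVX2__, code that also has to build on SSE-only machines could pick its handler at compile time roughly as follows; this is an illustrative sketch, not part of the commit, and BlockMultiplier's template parameter is inferred from the unit test at the end of the commit.

// Illustrative sketch only (not part of this commit): compile-time handler selection.
#include "BlockHandlerSSE.h"
#ifdef __AVX2__                      // defined by GCC/Clang under -mavx2 (MSVC: /arch:AVX2)
#include "BlockHandlerAVX.h"
typedef Microsoft::MSR::CNTK::BlockHandlerAVX PreferredBlockHandler;
#else
typedef Microsoft::MSR::CNTK::BlockHandlerSSE PreferredBlockHandler;
#endif
// A caller would then instantiate BlockMultiplier<PreferredBlockHandler>.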
ifdef CUDA_PATH
MATH_SRC +=\
$(SOURCEDIR)/Math/GPUMatrix.cu \

View file

@@ -0,0 +1,48 @@
//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
//
#include "stdafx.h"
#include <malloc.h>
#include <xmmintrin.h>
#include <emmintrin.h>
#include <tmmintrin.h>
#include <assert.h>
#include <iostream>
#include <exception>
#include "BlockMultiplierMatrixUtil.h"
#include "BlockHandlerAVX.h"
namespace Microsoft { namespace MSR { namespace CNTK {
int BlockHandlerAVX::RowToColOffsetRewrittenA(int row, int kOffset, int blockSize, int rowsPerBlock, int origCols)
{
int rowIdx = row / rowsPerBlock;
int offsetFromBlockBeginning = row % rowsPerBlock;
int colIdx = kOffset * rowsPerBlock * blockSize + (offsetFromBlockBeginning * blockSize);
return (rowIdx * (origCols / blockSize) * rowsPerBlock * blockSize) + colIdx;
}
//col is the original column of B
//kOffset is the offset to the current block we are multiplying against (in absolute block units)
int BlockHandlerAVX::RowToColOffsetRewrittenB(int col, int kOffset, int blockSize, int origCols)
{
return (origCols * blockSize * kOffset) + (col * blockSize);
}
void BlockHandlerAVX::DumpM256(__m256i dumpMe)
{
union { int32_t i[8]; __m256i y; } u;
u.y = dumpMe;
for (int i = 0; i < 8; ++i)
{
std::cout << u.i[i] << " ";
}
std::cout << std::endl;
}
}}}
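To make the two offset helpers above concrete, here is a small self-contained sketch with made-up dimensions (not from the commit). RowToColOffsetRewrittenA addresses the rewritten A panel by panel: skip whole rowsPerBlock-row panels, then the kOffset earlier blocks of the current panel, then the row's slot within the block. RowToColOffsetRewrittenB lays the rewritten B out block-major, with the kOffset-th block of every column stored back to back.

// Self-contained sketch reproducing the two offset formulas above for a toy case
// (k = 16 columns of A, n = 4 columns of B, blockSize = 8, rowsPerBlock = 4).
#include <cstdio>

static int OffsetRewrittenA(int row, int kOffset, int blockSize, int rowsPerBlock, int origCols)
{
    int rowIdx = row / rowsPerBlock;
    int offsetFromBlockBeginning = row % rowsPerBlock;
    int colIdx = kOffset * rowsPerBlock * blockSize + offsetFromBlockBeginning * blockSize;
    return rowIdx * (origCols / blockSize) * rowsPerBlock * blockSize + colIdx;
}

static int OffsetRewrittenB(int col, int kOffset, int blockSize, int origCols)
{
    return origCols * blockSize * kOffset + col * blockSize;
}

int main()
{
    // Row 5 of A sits in the second 4-row panel; its second 8-wide block starts at
    // 1*(16/8)*4*8 + 1*4*8 + 1*8 = 104.
    std::printf("A offset: %d\n", OffsetRewrittenA(5, 1, 8, 4, 16));
    // Column 3 of B, second 8-element block, with n = 4 columns: 4*8*1 + 3*8 = 56.
    std::printf("B offset: %d\n", OffsetRewrittenB(3, 1, 8, 4));
    return 0;
}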

View file

@@ -0,0 +1,961 @@
//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
//
#pragma once
#include "BlockMultiplierPlatform.h"
#include <immintrin.h>
#include <emmintrin.h>
#include <assert.h>
#include <cstdint>
#define FOR_CNTK
#ifdef FOR_CNTK
#include "CommonMatrix.h"
#endif
namespace Microsoft { namespace MSR { namespace CNTK {
class MATH_API BlockHandlerAVX
{
private:
//USE SSE for the blocks of 8, borrowed from BlockHandlerSSE
FORCEINLINE static void kernelsse8x4(__m128i xmmRow0, __m128i xmmRow1, __m128i xmmRow2, __m128i xmmRow3,
short* B, __m128i* return1, __m128i* return2, __m128i* return3, __m128i* return4);
FORCEINLINE static void kernelavx16x4(__m256i xmmRow0B0a, __m256i xmmRow1B0a, __m256i xmmRow2B0a, __m256i xmmRow3B0a,
short* B, __m256i* return1, __m256i* return2, __m256i * return3, __m256i* return4);
FORCEINLINE static void kernelavx32x4(
__m256i xmmRow0B0a, __m256i xmmRow0B0b,
__m256i xmmRow1B0a, __m256i xmmRow1B0b,
__m256i xmmRow2B0a, __m256i xmmRow2B0b,
__m256i xmmRow3B0a, __m256i xmmRow3B0b,
short* B, __m256i* return1, __m256i* return2, __m256i * return3, __m256i* return4);
FORCEINLINE static void kernelavx64x4(
__m256i xmmRow0B0a, __m256i xmmRow0B0b, __m256i xmmRow0B0c, __m256i xmmRow0B0d,
__m256i xmmRow1B0a, __m256i xmmRow1B0b, __m256i xmmRow1B0c, __m256i xmmRow1B0d,
__m256i xmmRow2B0a, __m256i xmmRow2B0b, __m256i xmmRow2B0c, __m256i xmmRow2B0d,
__m256i xmmRow3B0a, __m256i xmmRow3B0b, __m256i xmmRow3B0c, __m256i xmmRow3B0d,
short* B, __m256i* return1, __m256i* return2, __m256i * return3, __m256i* return4);
FORCEINLINE static void kernelavx128x4(
__m256i xmmRow0B0a, __m256i xmmRow0B0b, __m256i xmmRow0B0c, __m256i xmmRow0B0d,
__m256i xmmRow0B0e, __m256i xmmRow0B0f, __m256i xmmRow0B0g, __m256i xmmRow0B0h,
__m256i xmmRow1B0a, __m256i xmmRow1B0b, __m256i xmmRow1B0c, __m256i xmmRow1B0d,
__m256i xmmRow1B0e, __m256i xmmRow1B0f, __m256i xmmRow1B0g, __m256i xmmRow1B0h,
__m256i xmmRow2B0a, __m256i xmmRow2B0b, __m256i xmmRow2B0c, __m256i xmmRow2B0d,
__m256i xmmRow2B0e, __m256i xmmRow2B0f, __m256i xmmRow2B0g, __m256i xmmRow2B0h,
__m256i xmmRow3B0a, __m256i xmmRow3B0b, __m256i xmmRow3B0c, __m256i xmmRow3B0d,
__m256i xmmRow3B0e, __m256i xmmRow3B0f, __m256i xmmRow3B0g, __m256i xmmRow3B0h,
short* B, __m256i* return1, __m256i* return2, __m256i* return3, __m256i* return4);
FORCEINLINE static void kernelsse8x1(__m128i xmmRow0,
short* B, __m128i* return1);
FORCEINLINE static void kernelavx16x1(__m256i xmmRow0B0a,
short* B, __m256i* return1 );
FORCEINLINE static void kernelavx32x1(
__m256i xmmRow0B0a, __m256i xmmRow0B0b,
short* B, __m256i* return1);
FORCEINLINE static void kernelavx64x1(
__m256i xmmRow0B0a, __m256i xmmRow0B0b, __m256i xmmRow0B0c, __m256i xmmRow0B0d,
short* B, __m256i* return1) ;
FORCEINLINE static void kernelavx128x1(
__m256i xmmRow0B0a, __m256i xmmRow0B0b, __m256i xmmRow0B0c, __m256i xmmRow0B0d,
__m256i xmmRow0B0e, __m256i xmmRow0B0f, __m256i xmmRow0B0g, __m256i xmmRow0B0h,
short* B, __m256i* return1);
//TODO: Should these be refactored somewhere else? Any BlockHandler will need access to these functions.
//Separate class with static functions? Maybe move the Block rewriting functions as well as these to a new
//static class.
static int RowToColOffsetRewrittenB(int col, int kOffset, int blockSize, int origCols);
static int RowToColOffsetRewrittenA(int row, int kOffset, int blockSize, int rowsPerBlock, int origCols);
static void DumpM256(__m256i dumpMe);
public:
typedef __m256i VectorT;
typedef int16_t ScalarAT;
typedef int16_t ScalarBT;
typedef int32_t ScalarCT;
FORCEINLINE static void HandleBlock8x4(int currBlock, int startRow, int k, int n, short* newA, short* B,
int blockCnt, __m128i* resultStorage);
FORCEINLINE static void HandleBlock32x4(int currBlock, int startRow, int k, int n, short* newA, short* B,
int blockCnt, __m256i* resultStorage);
FORCEINLINE static void HandleBlock64x4(int currBlock, int startRow, int k, int n, short* newA, short* B,
int blockCnt, __m256i* resultStorage);
FORCEINLINE static void HandleBlock128x4(int currBlock, int startRow, int k, int n, short* newA, short* B,
int blockCnt, __m256i* resultStorage, VectorT* subtractMe);
FORCEINLINE static void HandleBlock8x1(int currBlock, int startRow, int k, int n, short* newA, short* B,
int blockCnt, __m128i* resultStorage);
FORCEINLINE static void HandleBlock16x1(int currBlock, int startRow, int k, int n, short* newA, short* B,
int blockCnt, __m256i* resultStorage);
FORCEINLINE static void HandleBlock64x1(int currBlock, int startRow, int k, int n, short* newA, short* B,
int blockCnt, __m256i* resultStorage);
FORCEINLINE static void HandleBlock128x1(int currBlock, int startRow, int k, int n, short* newA, short* B,
int blockCnt, __m256i* resultStorage, VectorT* subtractMe);
FORCEINLINE static void HandleBlock16x4(int currBlock, int startRow, int k, int n, short* newA, short* B,
int blockCnt, __m256i* resultStorage);
//FORCEINLINE static void HandleBlock128x4(int currBlock, int startRow, int m, int k, int n, short* newA, short* B,
FORCEINLINE static void HandleBlock32x1(int currBlock, int startRow, int k, int n, short* newA, short* B,
int blockCnt, __m256i* resultStorage);
static VectorT* PrepareExtraB(const ScalarBT* /*prepareMe*/, int /*k*/, int /*n*/)
{
return nullptr;
}
static void FreePreparedB(VectorT* freeMe) { freeMe; assert(nullptr == freeMe); }
};
#define LOADAVX2_128x4 \
__m256i r0b0a2 = _mm256_load_si256((__m256i*)currA2); \
__m256i r0b0b2 = _mm256_load_si256((__m256i*)(currA2 + 16)); \
__m256i r0b0c2 = _mm256_load_si256((__m256i*)(currA2 + 32)); \
__m256i r0b0d2 = _mm256_load_si256((__m256i*)(currA2 + 48)); \
__m256i r0b0e2 = _mm256_load_si256((__m256i*)(currA2 + 64)); \
__m256i r0b0f2 = _mm256_load_si256((__m256i*)(currA2 + 80)); \
__m256i r0b0g2 = _mm256_load_si256((__m256i*)(currA2 + 96)); \
__m256i r0b0h2 = _mm256_load_si256((__m256i*)(currA2 + 112));\
\
__m256i r1b0a2 = _mm256_load_si256((__m256i*)(currA2 + 128));\
__m256i r1b0b2 = _mm256_load_si256((__m256i*)(currA2 + 144));\
__m256i r1b0c2 = _mm256_load_si256((__m256i*)(currA2 + 160));\
__m256i r1b0d2 = _mm256_load_si256((__m256i*)(currA2 + 176));\
__m256i r1b0e2 = _mm256_load_si256((__m256i*)(currA2 + 192));\
__m256i r1b0f2 = _mm256_load_si256((__m256i*)(currA2 + 208));\
__m256i r1b0g2 = _mm256_load_si256((__m256i*)(currA2 + 224));\
__m256i r1b0h2 = _mm256_load_si256((__m256i*)(currA2 + 240));\
\
__m256i r2b0a2 = _mm256_load_si256((__m256i*)(currA2 + 256));\
__m256i r2b0b2 = _mm256_load_si256((__m256i*)(currA2 + 272));\
__m256i r2b0c2 = _mm256_load_si256((__m256i*)(currA2 + 288));\
__m256i r2b0d2 = _mm256_load_si256((__m256i*)(currA2 + 304));\
__m256i r2b0e2 = _mm256_load_si256((__m256i*)(currA2 + 320));\
__m256i r2b0f2 = _mm256_load_si256((__m256i*)(currA2 + 336));\
__m256i r2b0g2 = _mm256_load_si256((__m256i*)(currA2 + 352));\
__m256i r2b0h2 = _mm256_load_si256((__m256i*)(currA2 + 368));\
\
__m256i r3b0a2 = _mm256_load_si256((__m256i*)(currA2 + 384));\
__m256i r3b0b2 = _mm256_load_si256((__m256i*)(currA2 + 400));\
__m256i r3b0c2 = _mm256_load_si256((__m256i*)(currA2 + 416));\
__m256i r3b0d2 = _mm256_load_si256((__m256i*)(currA2 + 432));\
__m256i r3b0e2 = _mm256_load_si256((__m256i*)(currA2 + 448));\
__m256i r3b0f2 = _mm256_load_si256((__m256i*)(currA2 + 464));\
__m256i r3b0g2 = _mm256_load_si256((__m256i*)(currA2 + 480));\
__m256i r3b0h2 = _mm256_load_si256((__m256i*)(currA2 + 496));\

#define LOADAVX2_128x1 \
__m256i r0b0a2 = _mm256_load_si256((__m256i*)currA2); \
__m256i r0b0b2 = _mm256_load_si256((__m256i*)(currA2 + 16)); \
__m256i r0b0c2 = _mm256_load_si256((__m256i*)(currA2 + 32)); \
__m256i r0b0d2 = _mm256_load_si256((__m256i*)(currA2 + 48)); \
__m256i r0b0e2 = _mm256_load_si256((__m256i*)(currA2 + 64)); \
__m256i r0b0f2 = _mm256_load_si256((__m256i*)(currA2 + 80)); \
__m256i r0b0g2 = _mm256_load_si256((__m256i*)(currA2 + 96)); \
__m256i r0b0h2 = _mm256_load_si256((__m256i*)(currA2 + 112));
#define LOADAVX_128x1 \
__m256i r0b0a = _mm256_load_si256((__m256i*)currA); \
__m256i r0b0b = _mm256_load_si256((__m256i*)(currA + 16)); \
__m256i r0b0c = _mm256_load_si256((__m256i*)(currA + 32)); \
__m256i r0b0d = _mm256_load_si256((__m256i*)(currA + 48)); \
__m256i r0b0e = _mm256_load_si256((__m256i*)(currA + 64)); \
__m256i r0b0f = _mm256_load_si256((__m256i*)(currA + 80)); \
__m256i r0b0g = _mm256_load_si256((__m256i*)(currA + 96)); \
__m256i r0b0h = _mm256_load_si256((__m256i*)(currA + 112));
#define LOADAVX_128x4 \
__m256i r0b0a = _mm256_load_si256((__m256i*)currA); \
__m256i r0b0b = _mm256_load_si256((__m256i*)(currA + 16)); \
__m256i r0b0c = _mm256_load_si256((__m256i*)(currA + 32)); \
__m256i r0b0d = _mm256_load_si256((__m256i*)(currA + 48)); \
__m256i r0b0e = _mm256_load_si256((__m256i*)(currA + 64)); \
__m256i r0b0f = _mm256_load_si256((__m256i*)(currA + 80)); \
__m256i r0b0g = _mm256_load_si256((__m256i*)(currA + 96)); \
__m256i r0b0h = _mm256_load_si256((__m256i*)(currA + 112));\
\
__m256i r1b0a = _mm256_load_si256((__m256i*)(currA + 128));\
__m256i r1b0b = _mm256_load_si256((__m256i*)(currA + 144));\
__m256i r1b0c = _mm256_load_si256((__m256i*)(currA + 160));\
__m256i r1b0d = _mm256_load_si256((__m256i*)(currA + 176));\
__m256i r1b0e = _mm256_load_si256((__m256i*)(currA + 192));\
__m256i r1b0f = _mm256_load_si256((__m256i*)(currA + 208));\
__m256i r1b0g = _mm256_load_si256((__m256i*)(currA + 224));\
__m256i r1b0h = _mm256_load_si256((__m256i*)(currA + 240));\
\
__m256i r2b0a = _mm256_load_si256((__m256i*)(currA + 256));\
__m256i r2b0b = _mm256_load_si256((__m256i*)(currA + 272));\
__m256i r2b0c = _mm256_load_si256((__m256i*)(currA + 288));\
__m256i r2b0d = _mm256_load_si256((__m256i*)(currA + 304));\
__m256i r2b0e = _mm256_load_si256((__m256i*)(currA + 320));\
__m256i r2b0f = _mm256_load_si256((__m256i*)(currA + 336));\
__m256i r2b0g = _mm256_load_si256((__m256i*)(currA + 352));\
__m256i r2b0h = _mm256_load_si256((__m256i*)(currA + 368));\
\
__m256i r3b0a = _mm256_load_si256((__m256i*)(currA + 384));\
__m256i r3b0b = _mm256_load_si256((__m256i*)(currA + 400));\
__m256i r3b0c = _mm256_load_si256((__m256i*)(currA + 416));\
__m256i r3b0d = _mm256_load_si256((__m256i*)(currA + 432));\
__m256i r3b0e = _mm256_load_si256((__m256i*)(currA + 448));\
__m256i r3b0f = _mm256_load_si256((__m256i*)(currA + 464));\
__m256i r3b0g = _mm256_load_si256((__m256i*)(currA + 480));\
__m256i r3b0h = _mm256_load_si256((__m256i*)(currA + 496));\

#define LOADAVX_64x4 \
__m256i r0b0a = _mm256_load_si256((__m256i*)currA); \
__m256i r0b0b = _mm256_load_si256((__m256i*)currA + 1); \
__m256i r0b0c = _mm256_load_si256((__m256i*)currA + 2); \
__m256i r0b0d = _mm256_load_si256((__m256i*)currA + 3); \
\
__m256i r1b0a = _mm256_load_si256((__m256i*)currA + 4);\
__m256i r1b0b = _mm256_load_si256((__m256i*)currA + 5);\
__m256i r1b0c = _mm256_load_si256((__m256i*)currA + 6);\
__m256i r1b0d = _mm256_load_si256((__m256i*)currA + 7);\
\
__m256i r2b0a = _mm256_load_si256((__m256i*)currA + 8);\
__m256i r2b0b = _mm256_load_si256((__m256i*)currA + 9);\
__m256i r2b0c = _mm256_load_si256((__m256i*)currA + 10);\
__m256i r2b0d = _mm256_load_si256((__m256i*)currA + 11);\
\
__m256i r3b0a = _mm256_load_si256((__m256i*)currA + 12);\
__m256i r3b0b = _mm256_load_si256((__m256i*)currA + 13);\
__m256i r3b0c = _mm256_load_si256((__m256i*)currA + 14);\
__m256i r3b0d = _mm256_load_si256((__m256i*)currA + 15);
#define LOADAVX_64x1 \
__m256i r0b0a = _mm256_load_si256((__m256i*)currA); \
__m256i r0b0b = _mm256_load_si256((__m256i*)currA + 1); \
__m256i r0b0c = _mm256_load_si256((__m256i*)currA + 2); \
__m256i r0b0d = _mm256_load_si256((__m256i*)currA + 3);
#define LOADAVX_32x4 \
__m256i r0b0a = _mm256_load_si256((__m256i*)currA); \
__m256i r0b0b = _mm256_load_si256((__m256i*)currA + 1); \
\
__m256i r1b0a = _mm256_load_si256((__m256i*)currA + 2);\
__m256i r1b0b = _mm256_load_si256((__m256i*)currA + 3);\
\
__m256i r2b0a = _mm256_load_si256((__m256i*)currA + 4);\
__m256i r2b0b = _mm256_load_si256((__m256i*)currA + 5);\
\
__m256i r3b0a = _mm256_load_si256((__m256i*)currA + 6);\
__m256i r3b0b = _mm256_load_si256((__m256i*)currA + 7);\

#define LOADAVX_32x1 \
__m256i r0b0a = _mm256_load_si256((__m256i*)currA); \
__m256i r0b0b = _mm256_load_si256((__m256i*)currA + 1);
#define LOADAVX_16x4 \
__m256i r0b0a = _mm256_load_si256((__m256i*)currA); \
__m256i r1b0a = _mm256_load_si256((__m256i*)currA + 1);\
__m256i r2b0a = _mm256_load_si256((__m256i*)currA + 2);\
__m256i r3b0a = _mm256_load_si256((__m256i*)currA + 3);\

#define LOADAVX_16x1 \
__m256i r0b0a = _mm256_load_si256((__m256i*)currA);
#define LOAD_8x4 \
__m128i r0b0a = _mm_load_si128((__m128i*)currA);\
__m128i r1b0a = _mm_load_si128((__m128i*)currA + 1);\
__m128i r2b0a = _mm_load_si128((__m128i*)currA + 2);\
__m128i r3b0a = _mm_load_si128((__m128i*)currA + 3);\

#define LOAD_8x1 \
__m128i r0b0a = _mm_load_si128((__m128i*)currA);
FORCEINLINE void BlockHandlerAVX::HandleBlock8x4(int currBlock, int startRow, int k, int n, short* newA, short* B,
int blockCnt, __m128i* resultStorage)
{
blockCnt; //warning 4100
int aOffset = RowToColOffsetRewrittenA(startRow, currBlock, 8, 4, k);
short* currA = &newA[aOffset];
LOAD_8x4;
for (int c = 0; c < n; ++c)
{
short* currB = &B[RowToColOffsetRewrittenB(c, currBlock, 8, n)];
__m128i accum1 = _mm_set_epi32(0, 0, 0, 0);
__m128i accum2 = _mm_set_epi32(0, 0, 0, 0);
__m128i accum3 = _mm_set_epi32(0, 0, 0, 0);
__m128i accum4 = _mm_set_epi32(0, 0, 0, 0);
kernelsse8x4(r0b0a, r1b0a, r2b0a, r3b0a,
currB, &accum1, &accum2, &accum3, &accum4);
resultStorage[RowColToOffset(0, c, n)] = _mm_add_epi32(resultStorage[RowColToOffset(0, c, n)], accum1);
resultStorage[RowColToOffset(1, c, n)] = _mm_add_epi32(resultStorage[RowColToOffset(1, c, n)], accum2);
resultStorage[RowColToOffset(2, c, n)] = _mm_add_epi32(resultStorage[RowColToOffset(2, c, n)], accum3);
resultStorage[RowColToOffset(3, c, n)] = _mm_add_epi32(resultStorage[RowColToOffset(3, c, n)], accum4);
}
}
FORCEINLINE void BlockHandlerAVX::HandleBlock8x1(int currBlock, int startRow, int k, int n, short* newA, short* B,
int /*blockCnt*/, __m128i* resultStorage)
{
int aOffset = RowToColOffsetRewrittenA(startRow, currBlock, 8, 4, k);
short* currA = &newA[aOffset];
LOAD_8x1;
for (int c = 0; c < n; ++c)
{
short* currB = &B[RowToColOffsetRewrittenB(c, currBlock, 8, n)];
__m128i accum1 = _mm_set_epi32(0, 0, 0, 0);
kernelsse8x1(r0b0a,
currB, &accum1);
resultStorage[RowColToOffset(0, c, n)] = _mm_add_epi32(resultStorage[RowColToOffset(0, c, n)], accum1);
}
}
FORCEINLINE void BlockHandlerAVX::HandleBlock16x4(int currBlock, int startRow, int k, int n, short* newA, short* B,
int /*blockCnt*/, __m256i* resultStorage)
{
int aOffset = RowToColOffsetRewrittenA(startRow, currBlock, 16, 4, k);
short* currA = &newA[aOffset];
LOADAVX_16x4;
//#pragma omp parallel for
for (int c = 0; c < n; ++c)
{
short* currB = &B[RowToColOffsetRewrittenB(c, currBlock, 16, n)];
//The gain comes when we have all the row values loaded up
//together and we multiply them all times each column, saving m_rowsPerBlock column
//loads.
__m256i accum1 = _mm256_set1_epi16(0);
__m256i accum2 = _mm256_set1_epi16(0);
__m256i accum3 = _mm256_set1_epi16(0);
__m256i accum4 = _mm256_set1_epi16(0);
kernelavx16x4(r0b0a, r1b0a, r2b0a, r3b0a,
currB, &accum1, &accum2, &accum3, &accum4);
resultStorage[RowColToOffset(0, c, n)] = _mm256_add_epi32(resultStorage[RowColToOffset(0, c, n)], accum1);
resultStorage[RowColToOffset(1, c, n)] = _mm256_add_epi32(resultStorage[RowColToOffset(1, c, n)], accum2);
resultStorage[RowColToOffset(2, c, n)] = _mm256_add_epi32(resultStorage[RowColToOffset(2, c, n)], accum3);
resultStorage[RowColToOffset(3, c, n)] = _mm256_add_epi32(resultStorage[RowColToOffset(3, c, n)], accum4);
}
}
FORCEINLINE void BlockHandlerAVX::HandleBlock16x1(int currBlock, int startRow, int k, int n, short* newA, short* B,
int /*blockCnt*/, __m256i* resultStorage)
{
int aOffset = RowToColOffsetRewrittenA(startRow, currBlock, 16, 1, k);
short* currA = &newA[aOffset];
LOADAVX_16x1;
//#pragma omp parallel for
for (int c = 0; c < n; ++c)
{
short* currB = &B[RowToColOffsetRewrittenB(c, currBlock, 16, n)];
//The gain comes when we have all the row values loaded up
//together and we multiply them all times each column, saving m_rowsPerBlock column
//loads.
__m256i accum1 = _mm256_set1_epi16(0);
kernelavx16x1(r0b0a, currB, &accum1);
resultStorage[RowColToOffset(0, c, n)] = _mm256_add_epi32(resultStorage[RowColToOffset(0, c, n)], accum1);
}
}
FORCEINLINE void BlockHandlerAVX::HandleBlock32x4(int currBlock, int startRow, int k, int n, short* newA, short* B,
int /*blockCnt*/, __m256i* resultStorage)
{
int aOffset = RowToColOffsetRewrittenA(startRow, currBlock, 32, 4, k);
short* currA = &newA[aOffset];
LOADAVX_32x4;
//#pragma omp parallel for
for (int c = 0; c < n; ++c)
{
short* currB = &B[RowToColOffsetRewrittenB(c, currBlock, 32, n)];
//The gain comes when we have all the row values loaded up
//together and we multiply them all times each column, saving m_rowsPerBlock column
//loads.
__m256i accum1 = _mm256_set1_epi16(0);
__m256i accum2 = _mm256_set1_epi16(0);
__m256i accum3 = _mm256_set1_epi16(0);
__m256i accum4 = _mm256_set1_epi16(0);
kernelavx32x4(
r0b0a, r0b0b,
r1b0a, r1b0b,
r2b0a, r2b0b,
r3b0a, r3b0b,
currB, &accum1, &accum2, &accum3, &accum4);
resultStorage[RowColToOffset(0, c, n)] = _mm256_add_epi32( resultStorage[RowColToOffset(0, c, n)], accum1);
resultStorage[RowColToOffset(1, c, n)] = _mm256_add_epi32( resultStorage[RowColToOffset(1, c, n)], accum2);
resultStorage[RowColToOffset(2, c, n)] = _mm256_add_epi32( resultStorage[RowColToOffset(2, c, n)], accum3);
resultStorage[RowColToOffset(3, c, n)] = _mm256_add_epi32( resultStorage[RowColToOffset(3, c, n)], accum4);
}
}
FORCEINLINE void BlockHandlerAVX::HandleBlock32x1(int currBlock, int startRow, int k, int n, short* newA, short* B,
int /*blockCnt*/, __m256i* resultStorage)
{
int aOffset = RowToColOffsetRewrittenA(startRow, currBlock, 32, 1, k);
short* currA = &newA[aOffset];
LOADAVX_32x1;
//#pragma omp parallel for
for (int c = 0; c < n; ++c)
{
short* currB = &B[RowToColOffsetRewrittenB(c, currBlock, 32, n)];
__m256i accum1 = _mm256_set1_epi16(0);
kernelavx32x1(
r0b0a, r0b0b, currB, &accum1);
resultStorage[RowColToOffset(0, c, n)] = _mm256_add_epi32( resultStorage[RowColToOffset(0, c, n)], accum1);
}
}
FORCEINLINE void BlockHandlerAVX::HandleBlock64x4(int currBlock, int startRow, int k, int n, short* newA, short* B,
int /*blockCnt*/, __m256i* resultStorage)
{
int aOffset = RowToColOffsetRewrittenA(startRow, currBlock, 64, 4, k);
short* currA = &newA[aOffset];
LOADAVX_64x4;
//#pragma omp parallel for
for (int c = 0; c < n; ++c)
{
short* currB = &B[RowToColOffsetRewrittenB(c, currBlock, 64, n)];
//The gain comes when we have all the row values loaded up
//together and we multiply them all times each column, saving m_rowsPerBlock column
//loads.
__m256i accum1 = _mm256_set1_epi16(0);
__m256i accum2 = _mm256_set1_epi16(0);
__m256i accum3 = _mm256_set1_epi16(0);
__m256i accum4 = _mm256_set1_epi16(0);
kernelavx64x4(
r0b0a, r0b0b, r0b0c, r0b0d,
r1b0a, r1b0b, r1b0c, r1b0d,
r2b0a, r2b0b, r2b0c, r2b0d,
r3b0a, r3b0b, r3b0c, r3b0d,
currB, &accum1, &accum2, &accum3, &accum4);
resultStorage[RowColToOffset(0, c, n)] = _mm256_add_epi32( resultStorage[RowColToOffset(0, c, n)], accum1);
resultStorage[RowColToOffset(1, c, n)] = _mm256_add_epi32( resultStorage[RowColToOffset(1, c, n)], accum2);
resultStorage[RowColToOffset(2, c, n)] = _mm256_add_epi32( resultStorage[RowColToOffset(2, c, n)], accum3);
resultStorage[RowColToOffset(3, c, n)] = _mm256_add_epi32( resultStorage[RowColToOffset(3, c, n)], accum4);
}
}
FORCEINLINE void BlockHandlerAVX::HandleBlock64x1(int currBlock, int startRow, int k, int n, short* newA, short* B,
int /*blockCnt*/, __m256i* resultStorage)
{
int aOffset = RowToColOffsetRewrittenA(startRow, currBlock, 64, 4, k);
short* currA = &newA[aOffset];
LOADAVX_64x1;
//#pragma omp parallel for
for (int c = 0; c < n; ++c)
{
short* currB = &B[RowToColOffsetRewrittenB(c, currBlock, 64, n)];
//The gain comes when we have all the row values loaded up
//together and we multiply them all times each column, saving m_rowsPerBlock column
//loads.
__m256i accum1 = _mm256_set1_epi16(0);
kernelavx64x1(
r0b0a, r0b0b, r0b0c, r0b0d,
currB, &accum1);
resultStorage[RowColToOffset(0, c, n)] = _mm256_add_epi32(resultStorage[RowColToOffset(0, c, n)], accum1);
}
}
FORCEINLINE void BlockHandlerAVX::HandleBlock128x4(int currBlock, int startRow, int k, int n, short* newA, short* B,
int blockCnt, __m256i* resultStorage, VectorT* /*subtractMe*/)
{
int aOffset = RowToColOffsetRewrittenA(startRow, currBlock, 128, 4, k);
int aOffset2 = RowToColOffsetRewrittenA(startRow, currBlock + 1, 128, 4, k);
short* currA = &newA[aOffset];
short* currA2 = &newA[aOffset2];
LOADAVX_128x4;
LOADAVX2_128x4;
//#pragma omp parallel for
for (int c = 0; c < n; ++c)
{
short* currB = &B[RowToColOffsetRewrittenB(c, currBlock, 128, n)];
short* currB2 = &B[RowToColOffsetRewrittenB(c, currBlock + 1, 128, n)];
//The gain comes when we have all the row values loaded up
//together and we multiply them all times each column, saving m_rowsPerBlock column
//loads.
__m256i accum1 = _mm256_set1_epi16(0);
__m256i accum2 = _mm256_set1_epi16(0);
__m256i accum3 = _mm256_set1_epi16(0);
__m256i accum4 = _mm256_set1_epi16(0);
__m256i accum5 = _mm256_set1_epi16(0);
__m256i accum6 = _mm256_set1_epi16(0);
__m256i accum7 = _mm256_set1_epi16(0);
__m256i accum8 = _mm256_set1_epi16(0);
kernelavx128x4(
r0b0a, r0b0b, r0b0c, r0b0d, r0b0e, r0b0f, r0b0g, r0b0h,
r1b0a, r1b0b, r1b0c, r1b0d, r1b0e, r1b0f, r1b0g, r1b0h,
r2b0a, r2b0b, r2b0c, r2b0d, r2b0e, r2b0f, r2b0g, r2b0h,
r3b0a, r3b0b, r3b0c, r3b0d, r3b0e, r3b0f, r3b0g, r3b0h,
currB, &accum1, &accum2, &accum3, &accum4);
if (blockCnt > 1)
{
kernelavx128x4(
r0b0a2, r0b0b2, r0b0c2, r0b0d2, r0b0e2, r0b0f2, r0b0g2, r0b0h2,
r1b0a2, r1b0b2, r1b0c2, r1b0d2, r1b0e2, r1b0f2, r1b0g2, r1b0h2,
r2b0a2, r2b0b2, r2b0c2, r2b0d2, r2b0e2, r2b0f2, r2b0g2, r2b0h2,
r3b0a2, r3b0b2, r3b0c2, r3b0d2, r3b0e2, r3b0f2, r3b0g2, r3b0h2,
currB2, &accum5, &accum6, &accum7, &accum8);
}
resultStorage[RowColToOffset(0, c, n)] = _mm256_add_epi32( resultStorage[RowColToOffset(0, c, n)], _mm256_add_epi32(accum1, accum5));
resultStorage[RowColToOffset(1, c, n)] = _mm256_add_epi32( resultStorage[RowColToOffset(1, c, n)], _mm256_add_epi32(accum2, accum6));
resultStorage[RowColToOffset(2, c, n)] = _mm256_add_epi32( resultStorage[RowColToOffset(2, c, n)], _mm256_add_epi32(accum3, accum7));
resultStorage[RowColToOffset(3, c, n)] = _mm256_add_epi32( resultStorage[RowColToOffset(3, c, n)], _mm256_add_epi32(accum4, accum8));
}
}
FORCEINLINE void BlockHandlerAVX::HandleBlock128x1(int currBlock, int startRow, int k, int n, short* newA, short* B,
int blockCnt, __m256i* resultStorage, VectorT* /*subtractMe*/)
{
int aOffset = RowToColOffsetRewrittenA(startRow, currBlock, 128, 4, k);
int aOffset2 = RowToColOffsetRewrittenA(startRow, currBlock + 1, 128, 4, k);
short* currA = &newA[aOffset];
short* currA2 = &newA[aOffset2];
LOADAVX_128x1;
LOADAVX2_128x1;
//#pragma omp parallel for
for (int c = 0; c < n; ++c)
{
short* currB = &B[RowToColOffsetRewrittenB(c, currBlock, 128, n)];
short* currB2 = &B[RowToColOffsetRewrittenB(c, currBlock + 1, 128, n)];
//The gain comes when we have all the row values loaded up
//together and we multiply them all times each column, saving m_rowsPerBlock column
//loads.
__m256i accum1 = _mm256_set1_epi16(0);
__m256i accum2 = _mm256_set1_epi16(0);
kernelavx128x1(
r0b0a, r0b0b, r0b0c, r0b0d, r0b0e, r0b0f, r0b0g, r0b0h,
currB, &accum1);
if (blockCnt > 1)
{
kernelavx128x1(
r0b0a2, r0b0b2, r0b0c2, r0b0d2, r0b0e2, r0b0f2, r0b0g2, r0b0h2,
currB2, &accum2); //accumulate the second block into accum2 so accum1 (first block) is not overwritten
}
resultStorage[RowColToOffset(0, c, n)] = _mm256_add_epi32( resultStorage[RowColToOffset(0, c, n)], _mm256_add_epi32(accum1, accum2));
}
}
FORCEINLINE void BlockHandlerAVX::kernelsse8x1(__m128i xmmRow0,
short* B, __m128i* return1)
{
__m128i xmmCol0 = _mm_load_si128((__m128i*)B);
__m128i result1 = _mm_madd_epi16(xmmRow0, xmmCol0);
*return1 = result1;
}
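For orientation: _mm_madd_epi16 multiplies the eight corresponding int16 lanes of its two operands and adds each adjacent pair of products into an int32 lane, so the 8x1 kernel above produces four partial sums of one 8-element dot product (the caller accumulates across blocks and reduces at the end). A scalar sketch of the same step, illustrative only:

// Scalar equivalent of kernelsse8x1: 8 int16 of an A row times 8 int16 of a
// B column block, yielding the four int32 partial sums that _mm_madd_epi16 returns.
#include <cstdint>

static void ScalarMadd8x1(const int16_t* rowA, const int16_t* colB, int32_t out[4])
{
    for (int i = 0; i < 4; ++i)
        out[i] = int32_t(rowA[2 * i]) * colB[2 * i]
               + int32_t(rowA[2 * i + 1]) * colB[2 * i + 1];
}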
FORCEINLINE void BlockHandlerAVX::kernelsse8x4(__m128i xmmRow0, __m128i xmmRow1, __m128i xmmRow2, __m128i xmmRow3,
short* B, __m128i* return1, __m128i* return2, __m128i* return3, __m128i* return4)
{
__m128i xmmCol0 = _mm_load_si128((__m128i*)B);
__m128i result1 = _mm_madd_epi16(xmmRow0, xmmCol0);
__m128i result2 = _mm_madd_epi16(xmmRow1, xmmCol0);
__m128i result3 = _mm_madd_epi16(xmmRow2, xmmCol0);
__m128i result4 = _mm_madd_epi16(xmmRow3, xmmCol0);
*return1 = result1;
*return2 = result2;
*return3 = result3;
*return4 = result4;
}
FORCEINLINE void BlockHandlerAVX::kernelavx16x1(__m256i xmmRow0B0a,
short* B, __m256i* return1)
{
__m256i xmmCol0B0a = _mm256_load_si256((__m256i*)B);
//Result for row 0
//Nomenclature:
//r0b0axc0b0a means "Row zero block zero part A times column zero block zero part A". (Blocks > 8 take up > 1 __m256i (ymm register) each.)
__m256i r0b0axc0b0a = _mm256_madd_epi16(xmmRow0B0a, xmmCol0B0a);
*return1 = r0b0axc0b0a;
}
FORCEINLINE void BlockHandlerAVX::kernelavx16x4(__m256i xmmRow0B0a, __m256i xmmRow1B0a, __m256i xmmRow2B0a, __m256i xmmRow3B0a,
short* B, __m256i* return1, __m256i* return2, __m256i * return3, __m256i* return4)
{
__m256i xmmCol0B0a = _mm256_load_si256((__m256i*)B);
//Result for row 0
//Nomenclature:
//r0b0axc0b0a means "Row zero block zero part A times column zero block zero part A". (Blocks > 8 take up > 1 __m256i (ymm register) each.)
__m256i r0b0axc0b0a = _mm256_madd_epi16(xmmRow0B0a, xmmCol0B0a);
//Result for row 1
__m256i r1b0axc0b0a = _mm256_madd_epi16(xmmRow1B0a, xmmCol0B0a);
//Result for row 2
__m256i r2b0axc0b0a = _mm256_madd_epi16(xmmRow2B0a, xmmCol0B0a);
//Result for row 3
__m256i r3b0axc0b0a = _mm256_madd_epi16(xmmRow3B0a, xmmCol0B0a);
*return1 = r0b0axc0b0a;
*return2 = r1b0axc0b0a;
*return3 = r2b0axc0b0a;
*return4 = r3b0axc0b0a;
}
FORCEINLINE void BlockHandlerAVX::kernelavx32x1(
__m256i xmmRow0B0a, __m256i xmmRow0B0b,
short* B, __m256i* return1)
{
__m256i xmmCol0B0a = _mm256_load_si256((__m256i*)B);
__m256i xmmCol0B0b = _mm256_load_si256((__m256i*)B + 1);
//Result for row 0
//Nomenclature:
//r0b0axc0b0a means "Row zero block zero part A times column zero block zero part A". (Blocks > 8 take up > 1 __m256i (ymm register) each.)
__m256i r0b0axc0b0a = _mm256_madd_epi16(xmmRow0B0a, xmmCol0B0a);
__m256i r0b0bxc0b0b = _mm256_madd_epi16(xmmRow0B0b, xmmCol0B0b);
__m256i result1a = _mm256_add_epi32(r0b0axc0b0a, r0b0bxc0b0b);
*return1 = result1a;
}
FORCEINLINE void BlockHandlerAVX::kernelavx32x4(
__m256i xmmRow0B0a, __m256i xmmRow0B0b,
__m256i xmmRow1B0a, __m256i xmmRow1B0b,
__m256i xmmRow2B0a, __m256i xmmRow2B0b,
__m256i xmmRow3B0a, __m256i xmmRow3B0b,
short* B, __m256i* return1, __m256i* return2, __m256i * return3, __m256i* return4)
{
__m256i xmmCol0B0a = _mm256_load_si256((__m256i*)B);
__m256i xmmCol0B0b = _mm256_load_si256((__m256i*)B + 1);
//Result for row 0
//Nomenclature:
//r0b0axc0b0a means "Row zero block zero part A times column zero block zero part A". (Blocks > 8 take up > 1 __m256i (ymm register) each.)
__m256i r0b0axc0b0a = _mm256_madd_epi16(xmmRow0B0a, xmmCol0B0a);
__m256i r0b0bxc0b0b = _mm256_madd_epi16(xmmRow0B0b, xmmCol0B0b);
__m256i result1a = _mm256_add_epi32(r0b0axc0b0a, r0b0bxc0b0b);
//Result for row 1
__m256i r1b0axc0b0a = _mm256_madd_epi16(xmmRow1B0a, xmmCol0B0a);
__m256i r1b0bxc0b0b = _mm256_madd_epi16(xmmRow1B0b, xmmCol0B0b);
__m256i result2a = _mm256_add_epi32(r1b0axc0b0a, r1b0bxc0b0b);
//Result for row 2
__m256i r2b0axc0b0a = _mm256_madd_epi16(xmmRow2B0a, xmmCol0B0a);
__m256i r2b0bxc0b0b = _mm256_madd_epi16(xmmRow2B0b, xmmCol0B0b);
__m256i result3a = _mm256_add_epi32(r2b0axc0b0a, r2b0bxc0b0b);
//Result for row 3
__m256i r3b0axc0b0a = _mm256_madd_epi16(xmmRow3B0a, xmmCol0B0a);
__m256i r3b0bxc0b0b = _mm256_madd_epi16(xmmRow3B0b, xmmCol0B0b);
__m256i result4a = _mm256_add_epi32(r3b0axc0b0a, r3b0bxc0b0b);
*return1 = result1a;
*return2 = result2a;
*return3 = result3a;
*return4 = result4a;
}
FORCEINLINE void BlockHandlerAVX::kernelavx64x1(
__m256i xmmRow0B0a, __m256i xmmRow0B0b, __m256i xmmRow0B0c, __m256i xmmRow0B0d,
short* B, __m256i* return1)
{
__m256i xmmCol0B0a = _mm256_load_si256((__m256i*)B);
__m256i xmmCol0B0b = _mm256_load_si256((__m256i*)B + 1);
__m256i xmmCol0B0c = _mm256_load_si256((__m256i*)B + 2);
__m256i xmmCol0B0d = _mm256_load_si256((__m256i*)B + 3);
__m256i r0b0axc0b0a = _mm256_madd_epi16(xmmRow0B0a, xmmCol0B0a);
__m256i r0b0bxc0b0b = _mm256_madd_epi16(xmmRow0B0b, xmmCol0B0b);
__m256i r0b0cxc0b0c = _mm256_madd_epi16(xmmRow0B0c, xmmCol0B0c);
__m256i r0b0dxc0b0d = _mm256_madd_epi16(xmmRow0B0d, xmmCol0B0d);
__m256i result1a = _mm256_add_epi32(r0b0axc0b0a, r0b0bxc0b0b);
__m256i result1b = _mm256_add_epi32(r0b0cxc0b0c, r0b0dxc0b0d);
__m256i result1ab = _mm256_add_epi32(result1a, result1b);
*return1 = result1ab;
//std::cout << "Returning " << u.i[0] << " + " << u.i[4] << "(" << u.i[0] + u.i[4] << ") for first row" << std::endl;
}
FORCEINLINE void BlockHandlerAVX::kernelavx64x4(
__m256i xmmRow0B0a, __m256i xmmRow0B0b, __m256i xmmRow0B0c, __m256i xmmRow0B0d,
__m256i xmmRow1B0a, __m256i xmmRow1B0b, __m256i xmmRow1B0c, __m256i xmmRow1B0d,
__m256i xmmRow2B0a, __m256i xmmRow2B0b, __m256i xmmRow2B0c, __m256i xmmRow2B0d,
__m256i xmmRow3B0a, __m256i xmmRow3B0b, __m256i xmmRow3B0c, __m256i xmmRow3B0d,
short* B, __m256i* return1, __m256i* return2, __m256i * return3, __m256i* return4)
{
__m256i xmmCol0B0a = _mm256_load_si256((__m256i*)B);
__m256i xmmCol0B0b = _mm256_load_si256((__m256i*)B + 1);
__m256i xmmCol0B0c = _mm256_load_si256((__m256i*)B + 2);
__m256i xmmCol0B0d = _mm256_load_si256((__m256i*)B + 3);
//Result for row 0
//Nomenclature:
//r0b0axc0b0a means "Row zero block zero part A times column zero block zero part A". (Blocks > 8 take up > 1 __m256i (ymm register) each.)
__m256i r0b0axc0b0a = _mm256_madd_epi16(xmmRow0B0a, xmmCol0B0a);
__m256i r0b0bxc0b0b = _mm256_madd_epi16(xmmRow0B0b, xmmCol0B0b);
__m256i r0b0cxc0b0c = _mm256_madd_epi16(xmmRow0B0c, xmmCol0B0c);
__m256i r0b0dxc0b0d = _mm256_madd_epi16(xmmRow0B0d, xmmCol0B0d);
__m256i result1a = _mm256_add_epi32(r0b0axc0b0a, r0b0bxc0b0b);
__m256i result1b = _mm256_add_epi32(r0b0cxc0b0c, r0b0dxc0b0d);
__m256i result1ab = _mm256_add_epi32(result1a, result1b);
//Result for row 1
__m256i r1b0axc0b0a = _mm256_madd_epi16(xmmRow1B0a, xmmCol0B0a);
__m256i r1b0bxc0b0b = _mm256_madd_epi16(xmmRow1B0b, xmmCol0B0b);
__m256i r1b0cxc0b0c = _mm256_madd_epi16(xmmRow1B0c, xmmCol0B0c);
__m256i r1b0dxc0b0d = _mm256_madd_epi16(xmmRow1B0d, xmmCol0B0d);
__m256i result2a = _mm256_add_epi32(r1b0axc0b0a, r1b0bxc0b0b);
__m256i result2b = _mm256_add_epi32(r1b0cxc0b0c, r1b0dxc0b0d);
__m256i result2ab = _mm256_add_epi32(result2a, result2b);
//Result for row 2
__m256i r2b0axc0b0a = _mm256_madd_epi16(xmmRow2B0a, xmmCol0B0a);
__m256i r2b0bxc0b0b = _mm256_madd_epi16(xmmRow2B0b, xmmCol0B0b);
__m256i r2b0cxc0b0c = _mm256_madd_epi16(xmmRow2B0c, xmmCol0B0c);
__m256i r2b0dxc0b0d = _mm256_madd_epi16(xmmRow2B0d, xmmCol0B0d);
__m256i result3a = _mm256_add_epi32(r2b0axc0b0a, r2b0bxc0b0b);
__m256i result3b = _mm256_add_epi32(r2b0cxc0b0c, r2b0dxc0b0d);
__m256i result3ab = _mm256_add_epi32(result3a, result3b);
//Result for row 3
__m256i r3b0axc0b0a = _mm256_madd_epi16(xmmRow3B0a, xmmCol0B0a);
__m256i r3b0bxc0b0b = _mm256_madd_epi16(xmmRow3B0b, xmmCol0B0b);
__m256i r3b0cxc0b0c = _mm256_madd_epi16(xmmRow3B0c, xmmCol0B0c);
__m256i r3b0dxc0b0d = _mm256_madd_epi16(xmmRow3B0d, xmmCol0B0d);
__m256i result4a = _mm256_add_epi32(r3b0axc0b0a, r3b0bxc0b0b);
__m256i result4b = _mm256_add_epi32(r3b0cxc0b0c, r3b0dxc0b0d);
__m256i result4ab = _mm256_add_epi32(result4a, result4b);
*return1 = result1ab;
*return2 = result2ab;
*return3 = result3ab;
*return4 = result4ab;
}
FORCEINLINE void BlockHandlerAVX::kernelavx128x1(
__m256i xmmRow0B0a, __m256i xmmRow0B0b, __m256i xmmRow0B0c, __m256i xmmRow0B0d,
__m256i xmmRow0B0e, __m256i xmmRow0B0f, __m256i xmmRow0B0g, __m256i xmmRow0B0h,
short* B, __m256i* return1)
{
__m256i xmmCol0B0a = _mm256_load_si256((__m256i*)B);
__m256i xmmCol0B0b = _mm256_load_si256((__m256i*)(B + 16));
__m256i xmmCol0B0c = _mm256_load_si256((__m256i*)(B + 32));
__m256i xmmCol0B0d = _mm256_load_si256((__m256i*)(B + 48));
__m256i xmmCol0B0e = _mm256_load_si256((__m256i*)(B + 64));
__m256i xmmCol0B0f = _mm256_load_si256((__m256i*)(B + 80));
__m256i xmmCol0B0g = _mm256_load_si256((__m256i*)(B + 96));
__m256i xmmCol0B0h = _mm256_load_si256((__m256i*)(B + 112));
//Result for row 0
//Nomenclature:
//r0b0axc0b0a means "Row zero block zero part A times column zero block zero part A". (Blocks > 8 take up > 1 __m256i (ymm register) each.)
__m256i r0b0axc0b0a = _mm256_madd_epi16(xmmRow0B0a, xmmCol0B0a);
__m256i r0b0bxc0b0b = _mm256_madd_epi16(xmmRow0B0b, xmmCol0B0b);
__m256i r0b0cxc0b0c = _mm256_madd_epi16(xmmRow0B0c, xmmCol0B0c);
__m256i r0b0dxc0b0d = _mm256_madd_epi16(xmmRow0B0d, xmmCol0B0d);
__m256i r0b0exc0b0e = _mm256_madd_epi16(xmmRow0B0e, xmmCol0B0e);
__m256i r0b0fxc0b0f = _mm256_madd_epi16(xmmRow0B0f, xmmCol0B0f);
__m256i r0b0gxc0b0g = _mm256_madd_epi16(xmmRow0B0g, xmmCol0B0g);
__m256i r0b0hxc0b0h = _mm256_madd_epi16(xmmRow0B0h, xmmCol0B0h);
__m256i result1a = _mm256_add_epi32(r0b0axc0b0a, r0b0bxc0b0b);
__m256i result1b = _mm256_add_epi32(r0b0cxc0b0c, r0b0dxc0b0d);
__m256i result1c = _mm256_add_epi32(r0b0exc0b0e, r0b0fxc0b0f);
__m256i result1d = _mm256_add_epi32(r0b0gxc0b0g, r0b0hxc0b0h);
__m256i result1ab = _mm256_add_epi32(result1a, result1b);
__m256i result1cd = _mm256_add_epi32(result1c, result1d);
__m256i result1abcd = _mm256_add_epi32(result1ab, result1cd);
*return1 = result1abcd;
//std::cout << "Returning " << u.i[0] << " + " << u.i[4] << "(" << u.i[0] + u.i[4] << ") for first row" << std::endl;
}
FORCEINLINE void BlockHandlerAVX::kernelavx128x4(
__m256i xmmRow0B0a, __m256i xmmRow0B0b, __m256i xmmRow0B0c, __m256i xmmRow0B0d,
__m256i xmmRow0B0e, __m256i xmmRow0B0f, __m256i xmmRow0B0g, __m256i xmmRow0B0h,
__m256i xmmRow1B0a, __m256i xmmRow1B0b, __m256i xmmRow1B0c, __m256i xmmRow1B0d,
__m256i xmmRow1B0e, __m256i xmmRow1B0f, __m256i xmmRow1B0g, __m256i xmmRow1B0h,
__m256i xmmRow2B0a, __m256i xmmRow2B0b, __m256i xmmRow2B0c, __m256i xmmRow2B0d,
__m256i xmmRow2B0e, __m256i xmmRow2B0f, __m256i xmmRow2B0g, __m256i xmmRow2B0h,
__m256i xmmRow3B0a, __m256i xmmRow3B0b, __m256i xmmRow3B0c, __m256i xmmRow3B0d,
__m256i xmmRow3B0e, __m256i xmmRow3B0f, __m256i xmmRow3B0g, __m256i xmmRow3B0h,
short* B, __m256i* return1, __m256i* return2, __m256i * return3, __m256i* return4)
{
__m256i xmmCol0B0a = _mm256_load_si256((__m256i*)B);
__m256i xmmCol0B0b = _mm256_load_si256((__m256i*)(B + 16));
__m256i xmmCol0B0c = _mm256_load_si256((__m256i*)(B + 32));
__m256i xmmCol0B0d = _mm256_load_si256((__m256i*)(B + 48));
__m256i xmmCol0B0e = _mm256_load_si256((__m256i*)(B + 64));
__m256i xmmCol0B0f = _mm256_load_si256((__m256i*)(B + 80));
__m256i xmmCol0B0g = _mm256_load_si256((__m256i*)(B + 96));
__m256i xmmCol0B0h = _mm256_load_si256((__m256i*)(B + 112));
//Result for row 0
//Nomenclature:
//r0b0axc0b0a means "Row zero block zero part A times column zero block zero part A". (Blocks > 8 take up > 1 __m256i (ymm register) each.)
__m256i r0b0axc0b0a = _mm256_madd_epi16(xmmRow0B0a, xmmCol0B0a);
__m256i r0b0bxc0b0b = _mm256_madd_epi16(xmmRow0B0b, xmmCol0B0b);
__m256i r0b0cxc0b0c = _mm256_madd_epi16(xmmRow0B0c, xmmCol0B0c);
__m256i r0b0dxc0b0d = _mm256_madd_epi16(xmmRow0B0d, xmmCol0B0d);
__m256i r0b0exc0b0e = _mm256_madd_epi16(xmmRow0B0e, xmmCol0B0e);
__m256i r0b0fxc0b0f = _mm256_madd_epi16(xmmRow0B0f, xmmCol0B0f);
__m256i r0b0gxc0b0g = _mm256_madd_epi16(xmmRow0B0g, xmmCol0B0g);
__m256i r0b0hxc0b0h = _mm256_madd_epi16(xmmRow0B0h, xmmCol0B0h);
__m256i result1a = _mm256_add_epi32(r0b0axc0b0a, r0b0bxc0b0b);
__m256i result1b = _mm256_add_epi32(r0b0cxc0b0c, r0b0dxc0b0d);
__m256i result1c = _mm256_add_epi32(r0b0exc0b0e, r0b0fxc0b0f);
__m256i result1d = _mm256_add_epi32(r0b0gxc0b0g, r0b0hxc0b0h);
__m256i result1ab = _mm256_add_epi32(result1a, result1b);
__m256i result1cd = _mm256_add_epi32(result1c, result1d);
__m256i result1abcd = _mm256_add_epi32(result1ab, result1cd);
//Result for row 1
__m256i r1b0axc0b0a = _mm256_madd_epi16(xmmRow1B0a, xmmCol0B0a);
__m256i r1b0bxc0b0b = _mm256_madd_epi16(xmmRow1B0b, xmmCol0B0b);
__m256i r1b0cxc0b0c = _mm256_madd_epi16(xmmRow1B0c, xmmCol0B0c);
__m256i r1b0dxc0b0d = _mm256_madd_epi16(xmmRow1B0d, xmmCol0B0d);
__m256i r1b0exc0b0e = _mm256_madd_epi16(xmmRow1B0e, xmmCol0B0e);
__m256i r1b0fxc0b0f = _mm256_madd_epi16(xmmRow1B0f, xmmCol0B0f);
__m256i r1b0gxc0b0g = _mm256_madd_epi16(xmmRow1B0g, xmmCol0B0g);
__m256i r1b0hxc0b0h = _mm256_madd_epi16(xmmRow1B0h, xmmCol0B0h);
__m256i result2a = _mm256_add_epi32(r1b0axc0b0a, r1b0bxc0b0b);
__m256i result2b = _mm256_add_epi32(r1b0cxc0b0c, r1b0dxc0b0d);
__m256i result2c = _mm256_add_epi32(r1b0exc0b0e, r1b0fxc0b0f);
__m256i result2d = _mm256_add_epi32(r1b0gxc0b0g, r1b0hxc0b0h);
__m256i result2ab = _mm256_add_epi32(result2a, result2b);
__m256i result2cd = _mm256_add_epi32(result2c, result2d);
__m256i result2abcd = _mm256_add_epi32(result2ab, result2cd);
//Result for row 2
__m256i r2b0axc0b0a = _mm256_madd_epi16(xmmRow2B0a, xmmCol0B0a);
__m256i r2b0bxc0b0b = _mm256_madd_epi16(xmmRow2B0b, xmmCol0B0b);
__m256i r2b0cxc0b0c = _mm256_madd_epi16(xmmRow2B0c, xmmCol0B0c);
__m256i r2b0dxc0b0d = _mm256_madd_epi16(xmmRow2B0d, xmmCol0B0d);
__m256i r2b0exc0b0e = _mm256_madd_epi16(xmmRow2B0e, xmmCol0B0e);
__m256i r2b0fxc0b0f = _mm256_madd_epi16(xmmRow2B0f, xmmCol0B0f);
__m256i r2b0gxc0b0g = _mm256_madd_epi16(xmmRow2B0g, xmmCol0B0g);
__m256i r2b0hxc0b0h = _mm256_madd_epi16(xmmRow2B0h, xmmCol0B0h);
__m256i result3a = _mm256_add_epi32(r2b0axc0b0a, r2b0bxc0b0b);
__m256i result3b = _mm256_add_epi32(r2b0cxc0b0c, r2b0dxc0b0d);
__m256i result3c = _mm256_add_epi32(r2b0exc0b0e, r2b0fxc0b0f);
__m256i result3d = _mm256_add_epi32(r2b0gxc0b0g, r2b0hxc0b0h);
__m256i result3ab = _mm256_add_epi32(result3a, result3b);
__m256i result3cd = _mm256_add_epi32(result3c, result3d);
__m256i result3abcd = _mm256_add_epi32(result3ab, result3cd);
//Result for row 3
__m256i r3b0axc0b0a = _mm256_madd_epi16(xmmRow3B0a, xmmCol0B0a);
__m256i r3b0bxc0b0b = _mm256_madd_epi16(xmmRow3B0b, xmmCol0B0b);
__m256i r3b0cxc0b0c = _mm256_madd_epi16(xmmRow3B0c, xmmCol0B0c);
__m256i r3b0dxc0b0d = _mm256_madd_epi16(xmmRow3B0d, xmmCol0B0d);
__m256i r3b0exc0b0e = _mm256_madd_epi16(xmmRow3B0e, xmmCol0B0e);
__m256i r3b0fxc0b0f = _mm256_madd_epi16(xmmRow3B0f, xmmCol0B0f);
__m256i r3b0gxc0b0g = _mm256_madd_epi16(xmmRow3B0g, xmmCol0B0g);
__m256i r3b0hxc0b0h = _mm256_madd_epi16(xmmRow3B0h, xmmCol0B0h);
__m256i result4a = _mm256_add_epi32(r3b0axc0b0a, r3b0bxc0b0b);
__m256i result4b = _mm256_add_epi32(r3b0cxc0b0c, r3b0dxc0b0d);
__m256i result4c = _mm256_add_epi32(r3b0exc0b0e, r3b0fxc0b0f);
__m256i result4d = _mm256_add_epi32(r3b0gxc0b0g, r3b0hxc0b0h);
__m256i result4ab = _mm256_add_epi32(result4a, result4b);
__m256i result4cd = _mm256_add_epi32(result4c, result4d);
__m256i result4abcd = _mm256_add_epi32(result4ab, result4cd);
//Now we can just add horizontally
*return1 = result1abcd;
*return2 = result2abcd;
*return3 = result3abcd;
*return4 = result4abcd;
}
}}}
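Each __m256i accumulator filled in above holds eight int32 partial dot products; the "add horizontally" step mentioned in the comments, which collapses them into a single int32 entry of C, presumably lives in BlockMultiplier.h (its diff is suppressed below as too large). A plain scalar version of that reduction, shown only for orientation:

// Illustrative only: collapse one AVX accumulator into a single int32 result.
#include <immintrin.h>
#include <cstdint>

static int32_t HorizontalSum(__m256i v)
{
    alignas(32) int32_t lanes[8];
    _mm256_store_si256(reinterpret_cast<__m256i*>(lanes), v);
    int32_t sum = 0;
    for (int i = 0; i < 8; ++i)
        sum += lanes[i];
    return sum;
}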

View file

@@ -0,0 +1,32 @@
//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
//
#include "stdafx.h"
#include <xmmintrin.h>
#include <emmintrin.h>
#include <tmmintrin.h>
#include "BlockHandlerSSE.h"
#include "BlockMultiplierMatrixUtil.h"
namespace Microsoft { namespace MSR { namespace CNTK {
int BlockHandlerSSE::RowToColOffsetRewrittenA(int row, int kOffset, int blockSize, int rowsPerBlock, int origCols)
{
int rowIdx = row / rowsPerBlock;
int offsetFromBlockBeginning = row % rowsPerBlock;
int colIdx = kOffset * rowsPerBlock * blockSize + (offsetFromBlockBeginning * blockSize);
return (rowIdx * (origCols / blockSize) * rowsPerBlock * blockSize) + colIdx;
}
//col is the original column of B
//kOffset is the offset to the current block we are multiplying against (in absolute block units)
int BlockHandlerSSE::RowToColOffsetRewrittenB(int col, int kOffset, int blockSize, int origCols)
{
return (origCols * blockSize * kOffset) + (col * blockSize);
}
}}}

File diff suppressed because it is too large.

File diff suppressed because it is too large.
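The two suppressed diffs above are most likely BlockHandlerSSE.h and BlockMultiplier.h, the only new files added to the project files below that are not shown elsewhere in this view. Pieced together from BlockMultiplierTests.cpp at the end of this commit, a typical call sequence for the new multiplier looks roughly like this (a sketch only; the exact BlockMultiplier interface lives in the suppressed header):

// Usage sketch assembled from BlockMultiplierTests.cpp; details are assumptions.
#include <cstdint>
#include "BlockMultiplier.h"
using namespace Microsoft::MSR::CNTK;

void MultiplyExample(int m, int k, int n)
{
    BlockMultiplier<BlockHandlerSSE> mult;            // BlockHandlerAVX on AVX2-enabled builds
    int16_t* A = mult.CreateMatrixA(m, k);
    int16_t* B = mult.CreateMatrixB(k, n);
    int32_t* C = mult.CreateMatrixC(m, n);
    // ... fill A and B with small int16 values ...
    int16_t* Bprepared = mult.PrepareB(B, k, n);      // rewrite B into the block layout
    mult.MultiplyMatrices(A, m, k, Bprepared, n, C);  // C = A * B
    if (Bprepared != B)
        mult.FreeMatrix(Bprepared);
    mult.FreeMatrix(A);
    mult.FreeMatrix(B);
    mult.FreeMatrix(C);
}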

View file

@@ -0,0 +1,161 @@
//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
//
#pragma once
#define NOMINMAX
#include <fstream>
#include <functional>
#include <iostream>
#include <limits>
#include <string.h>//for memset
#include "BlockMultiplierPlatform.h"
namespace Microsoft { namespace MSR { namespace CNTK {
template<typename ScalarT> void DumpMatrix(ScalarT* pDumpMe, int rows, int cols, std::ostream* pStream, int rowMax = std::numeric_limits<int>::max(),
int colMax = std::numeric_limits<int>::max())
{
for (int r = 0; r < std::min(rows, rowMax); ++r)
{
for (int c = 0; c < std::min(cols, colMax); ++c)
{
(*pStream) << pDumpMe[r * cols + c] << " ";
}
(*pStream) << std::endl;
}
}
// Turn a row+col into an absolute offset
FORCEINLINE int RowColToOffset(int idxRow, int idxCol, int numCols)
{
return idxRow * numCols + idxCol;
}
template<typename ScalarT>struct TransposeArgs
{
int r;
ScalarT* transposeMe;
ScalarT* transposed;
int origRows;
int origCols;
};
template<class ScalarT>void TransposeThread(TransposeArgs<ScalarT> ta)
{
for (int c = 0; c < ta.origCols; ++c)
{
//new c,r = old r,c
int oldOffset = RowColToOffset(ta.r, c, ta.origCols);
int newOffset = RowColToOffset(c, ta.r, ta.origRows);
ta.transposed[newOffset] = ta.transposeMe[oldOffset];
}
}
template<typename ScalarT> class TransposeThreadType
{
public:
void operator()(TransposeArgs<ScalarT> ta)
{
TransposeThread<ScalarT>(ta);
}
};
template<class ScalarT> void Transpose(ScalarT* transposeMe, ScalarT* transposed, int origRows, int origCols)
{
#pragma omp parallel for
for (int r = 0; r < origRows; ++r)
{
for (int c = 0; c < origCols; ++c)
{
int oldOffset = RowColToOffset(r, c, origCols);
int newOffset = RowColToOffset(c, r, origRows);
transposed[newOffset] = transposeMe[oldOffset];
}
}
}
template<typename ScalarT> ScalarT* CreateAlignedMatrix(int m, int n, ScalarT initVal, int alignment = 64)
{
ScalarT* ret = (ScalarT*)ALIGNED_ALLOC(sizeof(ScalarT) * (m * n), alignment);
if (initVal != 0)
{
for (int i = 0; i < m * n; ++i)
{
ret[i] = initVal;// +i;
}
}
else
{
memset(ret, 0, sizeof(ScalarT) * m * n);
}
return ret;
}
template<typename ScalarT> void FreeAlignedMatrix(ScalarT* destroyMe)
{
ALIGNED_FREE(destroyMe);
}
template<typename ScalarT> double MeanSquaredError(ScalarT* lhs, ScalarT* rhs, int m, int n)
{
double accumulatedError = 0.0;
for (int r = 0; r < m; ++r)
{
for(int c = 0; c < n; ++c)
{
double err = ((double)lhs[RowColToOffset(r, c, n)] - (double)rhs[RowColToOffset(r, c, n)]);
err = err * err;
accumulatedError += err;
}
}
return accumulatedError / (double)(m * n);
}
template<typename ScalarT> void RandInitIntMatrix(ScalarT* initMe, int m, int n, ScalarT bound)
{
ScalarT* curr = initMe;
for (int i = 0; i < m * n; ++i)
{
*curr++ = rand() % bound;
}
}
//Helper fn for tests
template<typename ScalarT>static void RandInitFloatMatrix(ScalarT* initMe, int m, int n, ScalarT min, ScalarT max)
{
for (int i = 0; i < m * n; ++i)
{
initMe[i] = min + ((max - min) * ((ScalarT)rand() / RAND_MAX));
}
}
//Viewing matrices and troubleshooting is a lot easier in Octave.
//Utility fn for exporting to Octave format
template<typename ScalarT>void DumpMatrixToOctaveFormat(const ScalarT* dumpMe, int rows, int cols, const char* fileName, const char* id)
{
std::ofstream ofs(fileName);
ofs << "# Created by gemmbenchmark" << std::endl <<
"# name: " << id << std::endl <<
"# type: matrix" << std::endl <<
"# rows: " << rows << std::endl <<
"# columns: " << cols << std::endl;
for (int r = 0; r < rows; ++r)
{
for (int c = 0; c < cols; ++c)
{
ofs << ' ' << (ScalarT)(dumpMe[(cols * r) + c]);
}
ofs << std::endl;
}
}
}}} //End namespaces
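A short sketch exercising the helpers above; the names and signatures are taken from this header, the dimensions are made up:

// Illustrative only: allocate, fill, transpose, and dump a small aligned matrix.
#include <cstdint>
#include "BlockMultiplierMatrixUtil.h"
using namespace Microsoft::MSR::CNTK;

void UtilExample()
{
    const int m = 4, n = 8;
    int16_t* a  = CreateAlignedMatrix<int16_t>(m, n, 0);   // zero-initialized, 64-byte aligned
    int16_t* at = CreateAlignedMatrix<int16_t>(n, m, 0);
    RandInitIntMatrix<int16_t>(a, m, n, 63);                // values in [0, 62]
    Transpose(a, at, m, n);                                 // at becomes the n x m transpose
    DumpMatrixToOctaveFormat(a, m, n, "a.mat", "a");        // inspect the matrix in Octave
    FreeAlignedMatrix(at);
    FreeAlignedMatrix(a);
}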

View file

@@ -0,0 +1,19 @@
//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
//
#pragma once
#ifdef _MSC_VER
#define ALIGNED_ALLOC(bytes,alignment) _aligned_malloc(bytes,alignment)
#define ALIGNED_FREE(ptr) _aligned_free(ptr)
#define FORCEINLINE __forceinline
#else
#ifdef __GNUC__
#include <stdlib.h>
#define ALIGNED_ALLOC(bytes,alignment) aligned_alloc(alignment,bytes)
#define ALIGNED_FREE(ptr) free(ptr)
//#define FORCEINLINE __attribute__((always_inline))
#define FORCEINLINE inline
#endif
#endif

View file

@@ -161,6 +161,11 @@
<ClInclude Include="..\Common\Include\File.h" />
<ClInclude Include="..\Common\Include\fileutil.h" />
<ClInclude Include="BatchNormalizationEngine.h" />
<ClInclude Include="BlockHandlerAVX.h" />
<ClInclude Include="BlockHandlerSSE.h" />
<ClInclude Include="BlockMultiplier.h" />
<ClInclude Include="BlockMultiplierMatrixUtil.h" />
<ClInclude Include="BlockMultiplierPlatform.h" />
<ClInclude Include="CommonMatrix.h" />
<ClInclude Include="ConvolutionEngine.h" />
<ClInclude Include="ConvolveGeometry.h" />
@@ -190,6 +195,8 @@
</ItemGroup>
<ItemGroup>
<ClCompile Include="BatchNormalizationEngine.cpp" />
<ClCompile Include="BlockHandlerAVX.cpp" />
<ClCompile Include="BlockHandlerSSE.cpp" />
<ClCompile Include="ConvolutionEngine.cpp" />
<ClCompile Include="CPURNGHandle.cpp" />
<ClCompile Include="CPUSparseMatrix.cpp" />

View file

@@ -42,6 +42,12 @@
<Filter>CPU</Filter>
</ClCompile>
<ClCompile Include="RNGHandle.cpp" />
<ClCompile Include="BlockHandlerAVX.cpp">
<Filter>CPU</Filter>
</ClCompile>
<ClCompile Include="BlockHandlerSSE.cpp">
<Filter>CPU</Filter>
</ClCompile>
</ItemGroup>
<ItemGroup>
<ClInclude Include="CommonMatrix.h" />
@@ -105,6 +111,18 @@
<ClInclude Include="CPURNGHandle.h">
<Filter>CPU</Filter>
</ClInclude>
<ClInclude Include="BlockHandlerAVX.h">
<Filter>CPU</Filter>
</ClInclude>
<ClInclude Include="BlockHandlerSSE.h">
<Filter>CPU</Filter>
</ClInclude>
<ClInclude Include="BlockMultiplier.h">
<Filter>CPU</Filter>
</ClInclude>
<ClInclude Include="BlockMultiplierPlatform.h">
<Filter>CPU</Filter>
</ClInclude>
</ItemGroup>
<ItemGroup>
<None Include="GPUMatrix.h">
@@ -155,4 +173,4 @@
<UniqueIdentifier>{8f982dac-298d-4e48-b060-8e6cba5ff554}</UniqueIdentifier>
</Filter>
</ItemGroup>
</Project>
</Project>

View file

@@ -0,0 +1,158 @@
//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
//
#include "stdafx.h"
#include "../../../Source/Math/BlockMultiplier.h"
namespace Microsoft { namespace MSR { namespace CNTK { namespace TEST {
//The simplest possible matrix multiplier, used here as a check.
template<typename ScalarAT, typename ScalarBT, typename ScalarCT, int MAXRANGE = 1 << ((8 * sizeof(ScalarAT)) - 3)> class ReferenceMultiplier
{
public:
typedef ScalarAT ScalarAT;
typedef ScalarBT ScalarBT;
typedef ScalarCT ScalarCT;
static const int MAXRANGE = MAXRANGE;
ScalarBT* PrepareB(ScalarBT* oldB, int k, int n) { return oldB; }
static ScalarAT* CreateMatrixA(int m, int n)
{
return CreateMatrix<ScalarAT>(m, n);
}
static ScalarBT* CreateMatrixB(int m, int n)
{
return CreateMatrix<ScalarBT>(m, n);
}
static ScalarCT* CreateMatrixC(int m, int n)
{
return CreateMatrix<ScalarCT>(m, n);
}
template<typename ScalarT> static ScalarT* CreateMatrix(int m, int n, ScalarT initVal = ScalarT())
{
ScalarT* ret = new ScalarT[m*n];
if (initVal != ScalarT())
{
for (int i = 0; i < m * n; ++i)
{
ret[i] = initVal;
}
}
return ret;
}
template<typename ScalarT> static void FreeMatrix(ScalarT* destroyMe)
{
delete[] destroyMe;
}
void MultiplyMatrices(ScalarAT* A, int m, int k, ScalarBT* B, int n, ScalarCT* C, ScalarAT alpha = (ScalarAT)1, ScalarBT beta = (ScalarBT)0)
{
alpha;
beta;
for (int r = 0; r < m; ++r)
{
for (int c = 0; c < n; ++c)
{
ScalarCT accum = (ScalarCT)0;
for (int d = 0; d < k; ++d)
{
ScalarCT prod = (ScalarCT)(A[(k * r) + d]) * (ScalarCT)(B[(n*d) + c]);
bool signsIdentical = ((accum > 0) == (prod > 0));
//signed overflow occurs iff the operands' signs are identical and the sum differs in sign from them.
accum += prod;
if (signsIdentical && (accum > 0) != (prod > 0))
{
throw std::runtime_error("overflow!");
}
}
C[(r * n) + c] = accum;
}
}
}
};
template<typename ScalarAT, typename ScalarBT, typename ScalarCT, typename MultiplierT>static void TestMultiplierSub(
int m, int k, int n, MultiplierT& testMult, int numThreads = 1, ScalarCT epsilon = ScalarCT())
{
epsilon;
testMult.SetNumThreads(numThreads);
ReferenceMultiplier<ScalarAT, ScalarBT, ScalarCT> refMult;
ScalarAT* refA = refMult.CreateMatrixA(m, k);
ScalarBT* refB = refMult.CreateMatrixB(k, n);
ScalarCT* refC = refMult.CreateMatrixC(m, n);
ScalarAT* testA = testMult.CreateMatrixA(m, k);
ScalarBT* testB = testMult.CreateMatrixB(k, n);
ScalarCT* testC = testMult.CreateMatrixC(m, n);
RandInitIntMatrix<ScalarAT>(refA, m, k, 63);
RandInitIntMatrix<ScalarBT>(refB, k, n, 63);
memcpy(testA, refA, sizeof(ScalarAT) * m * k);
memcpy(testB, refB, sizeof(ScalarBT) * k * n);
ScalarBT* testBPrepared = testMult.PrepareB(testB, k, n);
refMult.MultiplyMatrices(refA, m, k, refB, n, refC);
//Make sure we can multiply twice on the same matrix correctly.
for (int i = 0; i < 2; ++i)
{
testMult.MultiplyMatrices(testA, m, k, testBPrepared, n, testC);
//This will cause test failure and dump matrix to Octave format for debugging if they don't match
CompareMatricesAndDump(refC, testC, m, k, n);
memset(testC, (ScalarCT)0, sizeof(ScalarCT) * m * n);
}
refMult.FreeMatrix(refA);
refMult.FreeMatrix(refB);
refMult.FreeMatrix(refC);
testMult.FreeMatrix(testA);
testMult.FreeMatrix(testB);
testMult.FreeMatrix(testC);
if (testBPrepared != testB)
{
testMult.FreeMatrix(testBPrepared);
}
}
template<typename ScalarAT, typename ScalarBT, typename ScalarCT, typename MultiplierT>static void TestMultiplierSub(
int m, int k, int n, int numThreads = 1, ScalarCT epsilon = ScalarCT())
{
MultiplierT testMult;
TestMultiplierSub<ScalarAT, ScalarBT, ScalarCT, MultiplierT>(m, k, n, testMult, numThreads, epsilon);
}
template<typename ScalarCT> void CompareMatricesAndDump(const ScalarCT* ref, const ScalarCT* test,
int m, int /*k*/, int n)
{
for (int i = 0; i < m * n; ++i)
{
BOOST_CHECK_EQUAL(ref[i], test[i]);
}
}
BOOST_AUTO_TEST_SUITE(BlockMultiplierSuite)
BOOST_AUTO_TEST_CASE(BlockMultiplyTest)
{
int m = 8;
int k = 128;
int n = 8;
TestMultiplierSub<int16_t, int16_t, int32_t, BlockMultiplier<BlockHandlerSSE>>(m, k, n);
}
BOOST_AUTO_TEST_SUITE_END()
}}}} //end namespaces
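The suite above only exercises the SSE handler. On builds where the AVX2 path is compiled, an analogous case could presumably be added; the __AVX2__ guard is an assumption (the commit enables the AVX source through the SUPPORT_AVX2 make variable rather than a preprocessor define), and BlockHandlerAVX.h would need to be included near the top of the file if BlockMultiplier.h does not already pull it in.

// Hypothetical extra test case, mirroring BlockMultiplyTest above; to be placed
// inside the suite, before BOOST_AUTO_TEST_SUITE_END().
#ifdef __AVX2__
BOOST_AUTO_TEST_CASE(BlockMultiplyTestAVX)
{
    TestMultiplierSub<int16_t, int16_t, int32_t, BlockMultiplier<BlockHandlerAVX>>(8, 128, 8);
}
#endif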

View file

@@ -138,6 +138,7 @@
</ItemGroup>
<ItemGroup>
<ClCompile Include="BatchNormalizationEngineTests.cpp" />
<ClCompile Include="BlockMultiplierTests.cpp" />
<ClCompile Include="constants.cpp" />
<ClCompile Include="ConvolutionEngineTests.cpp" />
<ClCompile Include="CPUSparseMatrixTests.cpp" />