Quantized Convolution and MaxPool Implementation (#194)

* Initial

* Implement Convolution and MaxPooling for all filter sizes

* Incorporate Review

* Additional Documentation

* Add Support for Dilations in q_maxpool() and q_convolution()

* Extend MBConv To Take Even Number Filter Sizes
This commit is contained in:
Shikhar Jaiswal 2020-07-20 10:21:56 +05:30 committed by GitHub
Parent 9f68e66e46
Commit f835ab1abc
No key found matching this signature
GPG key ID: 4AEE18F83AFDEB23
7 changed files: 628 additions and 30 deletions

View file

@@ -259,7 +259,8 @@ typedef struct Q_FastGRNN_Buffers {
* <code>ERR_PRECOMP_NOT_INIT</code> if preComp2 not allocated
* <code>ERR_PRECOMP_NOT_INIT</code> if preComp3 not allocated
* <code>ERR_NORMFEAT_NOT_INIT</code> if normFeatures not allocated
*/
* @example Please refer to the file: c_reference/tests/fastgrnn/test_quantized_fastgrnn.c
*/
int q_fastgrnn(INT_T* const hiddenState, ITER_T hiddenDims,
const INT_T* const input, ITER_T inputDims, ITER_T steps,
const void* params, void* buffers, const void* scales,

View file

@@ -8,6 +8,7 @@
/**
* @brief Model parameters for Quantized MBConv Layer
* Note: This implementation doesn't support dilations yet.
* @param[in] input pointer to the input buffer
* @param[in] filter1 pointer to the first convolution filter buffer
* @param[in] BN1W pointer to the buffer holding the multiplication factor of the first BatchNorm computation
@@ -61,6 +62,7 @@
* @param[in] shlU3 scale to multiply with the third TreeSum output
* @param[in] shlB3 scale to multiply with the third BatchNorm addition factor
* @param[in] shlW3 scale to multiply with the third Convolution output
* @return none
*
* @brief The function computes the following three sub-parts:
* 1) Convolution(input, filter1) -> Batch Normalization(BN1W, BN1B) -> ReLU(limit1) -> convBuffer1
@@ -68,6 +70,8 @@
* 3) Convolution(convBuffer2, filter3) -> Batch Normalization(BN3W, BN3B) -> output
* Variables depth1, depth2 and depth3 are used along with treesumBuffer for accumulating the sums during convolutions.
* The rest of the variables are used as indicated.
*
* @example Please refer to the file: c_reference/tests/mbconv/test_quantized_mbconv.c
*/
void q_mbconv_block(const INT_T* const input, const INT_T* const filter1,
const INT_T* const BN1W, const INT_T* const BN1B, const INT_T* const filter2,

View file

@@ -25,6 +25,8 @@ typedef int (*q_rnn_t)(INT_T* const, ITER_T, const INT_T* const, ITER_T, ITER_T,
* @param[in] rnn2_scales pointer to the scales needed for RNN2
* @param[out] output pointer to output, initialized to size 4 * hiddenDims2
* @param[in,out] buffer pointer to buffer, initialized to size hiddenDims1 * max{nrows, cols}
* @return none
* @example Please refer to the file: c_reference/tests/rnnpool/test_quantized_rnnpool.c
*/
int q_rnnpool_block(const INT_T* const patch, ITER_T inputDims, ITER_T patchDim,
ITER_T stride, q_rnn_t rnn1, ITER_T hiddenDims1,

View file

@@ -32,36 +32,194 @@ inline INTM_T q_relu(INTM_T inp, INTM_T limit) {
}
}
// Functions for calculating quantized operations and activations.
// Function for computing TreeSum from a given vector holding intermediate results.
/**
* @brief Compute TreeSum from a given vector holding intermediate multiplications, and store the result in the first index of the input vector.
* @param[in, out] vec pointer to vector on which TreeSum operation is to be computed
* @param[in] len length of the input vector
* @param[in] H1 depth parameter for division-by-two used in TreeSum
* @param[in] H2 depth parameter for direct sum used in TreeSum
* @return none
* @example vec = {-425, -169, -3534, 524, -2739, 87, 52, 292}
* len = 8
* H1 = 3
* H2 = 0
* vec[0] = {-738}
*/
void v_q_treesum(INTM_T* const vec, ITER_T len, SCALE_T H1, SCALE_T H2);
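// A minimal reference sketch of the tree accumulation described above (illustrative only,
// assuming the integer typedefs from the quantized datatypes header; not the library code):
// the first H1 levels halve both operands before adding, to guard against overflow, and the
// remaining H2 levels add directly. With the example inputs above this reproduces vec[0] = -738.
static inline void v_q_treesum_sketch(INTM_T* const vec, ITER_T len, SCALE_T H1, SCALE_T H2) {
  ITER_T count = len;
  for (SCALE_T d = 0; d < H1 + H2; d++) {
    ITER_T pairs = count >> 1;
    for (ITER_T p = 0; p < pairs; p++) {
      INTM_T a = vec[2 * p], b = vec[2 * p + 1];
      vec[p] = (d < H1) ? (a / 2 + b / 2) : (a + b);
    }
    if (count & 1) {
      // An odd leftover element is carried forward (halved on the overflow-guarded levels).
      vec[pairs] = (d < H1) ? (vec[count - 1] / 2) : vec[count - 1];
    }
    count = pairs + (count & 1);
  }
  // The accumulated result is left in vec[0].
}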
// Function for computing the element-wise addition between two vectors.
/**
* @brief Compute the element-wise addition between two vectors.
* @param[in] vec1 pointer to the first input vector
* @param[in] vec2 pointer to the second input vector
* @param[in] len length of the input vectors
* @param[out] ret pointer to the vector storing the output
* @param[in] scvec1 scale factor of the first input vector
* @param[in] scvec2 scale factor of the second input vector
* @param[in] scret scale factor of the output vector
* @return none
* @example vec1 = {-425, -169, -3534, 524, -2739, 87, 52, 292}
* vec2 = {-18777, -9518, 4055, -7309, 8584, -17257, -5280, -7933}
* len = 8
* scvec1 = 1
* scvec2 = 8
* scret = 1
* ret = {-2772, -1358, -3028, -389, -1666, -2070, -608, -699}
*/
void v_q_add(const INT_T* const vec1, const INT_T* const vec2, ITER_T len,
INT_T* const ret, SCALE_T scvec1, SCALE_T scvec2, SCALE_T scret);
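// A minimal sketch of the scaling convention suggested by the worked example above (stated as
// an assumption for illustration; the SHIFT build seen in q_maxpool/q_convolution below turns
// such scales into shift amounts): each input element is down-scaled by its own scale factor
// and by the output scale before the addition.
static inline void v_q_add_sketch(const INT_T* const vec1, const INT_T* const vec2,
                                  ITER_T len, INT_T* const ret, SCALE_T scvec1,
                                  SCALE_T scvec2, SCALE_T scret) {
  for (ITER_T i = 0; i < len; i++) {
    ret[i] = (INT_T)((vec1[i] / scvec1) / scret + (vec2[i] / scvec2) / scret);
  }
}
// v_q_sub below and the scalar variants (v_q_scalar_add, v_q_scalar_sub, v_q_sub_scalar)
// follow the same pattern, with subtraction or a broadcast scalar in place of one operand.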
// Function for computing the element-wise difference between two vectors.
/**
* @brief Compute the element-wise subtraction between two vectors.
* @param[in] vec1 pointer to the first input vector
* @param[in] vec2 pointer to the second input vector
* @param[in] len length of the input vectors
* @param[out] ret pointer to the vector storing the output
* @param[in] scvec1 scale factor of the first input vector
* @param[in] scvec2 scale factor of the second input vector
* @param[in] scret scale factor of the output vector
* @return none
* @example vec1 = {-425, -169, -3534, 524, -2739, 87, 52, 292}
* vec2 = {-18777, -9518, 4055, -7309, 8584, -17257, -5280, -7933}
* len = 8
* scvec1 = 1
* scvec2 = 8
* scret = 1
* ret = {1922, 1020, -4040, 1437, -3812, 2244, 712, 1283}
*/
void v_q_sub(const INT_T* const vec1, const INT_T* const vec2, ITER_T len,
INT_T* const ret, SCALE_T scvec1, SCALE_T scvec2, SCALE_T scret);
// Function for computing the Hadamard product between two vectors.
/**
* @brief Compute the element-wise product (also known as Hadamard product) between two vectors.
* @param[in] vec1 pointer to the first input vector
* @param[in] vec2 pointer to the second input vector
* @param[in] len length of the input vectors
* @param[out] ret pointer to the vector storing the output
* @param[in] scvec1 scale factor of the first input vector
* @param[in] scvec2 scale factor of the second input vector
* @return none
* @example vec1 = {16378, 13638, 16378, 9787, 14861, 16378, 10661, 11018}
* vec2 = {178, 1064, -2048, 1718, -1663, 851, 1244, 1282}
* len = 8
* scvec1 = 32
* scvec2 = 64
* ret = {1423, 7085, -16378, 8209, -12067, 6805, 6475, 6897}
*/
void v_q_hadamard(const INT_T* const vec1, const INT_T* const vec2, ITER_T len,
INT_T* const ret, SCALE_T scvec1, SCALE_T scvec2);
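// Sketch of the Hadamard scaling implied by the example above (an illustrative assumption):
// the full-precision product is formed in the wider INTM_T type and then down-scaled by the
// product of the two scale factors, e.g. (16378 * 178) / (32 * 64) = 1423 after truncation.
static inline void v_q_hadamard_sketch(const INT_T* const vec1, const INT_T* const vec2,
                                       ITER_T len, INT_T* const ret, SCALE_T scvec1,
                                       SCALE_T scvec2) {
  for (ITER_T i = 0; i < len; i++) {
    ret[i] = (INT_T)((((INTM_T)vec1[i]) * ((INTM_T)vec2[i])) /
                     (((INTM_T)scvec1) * ((INTM_T)scvec2)));
  }
}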
// Function for computing the Sigmoid activation on the input vector.
/**
* @brief Compute the element-wise Sigmoid activation on the input vector.
* @param[in] vec pointer to the input vector
* @param[in] len length of the input vector
* @param[out] ret pointer to the vector storing the output
* @param[in] div division factor of the input vector
* @param[in] add addition offset of the input vector
* @param[in] sigmoid_limit saturation limit for the Sigmoid activation
* @param[in] scale_in scale factor of the input vector
* @param[in] scale_out scale factor of the output vector
* @return none
* @example formula = saturate(0, (vec_{i} / div) + add, sigmoid_limit) * 2^{scale_out - scale_in}
* vec = {-2772, -1358, -3028, -389, -1666, -2070, -608, -699}
* len = 8
* div = 2
* add = 1024
* sigmoid_limit = 2048
* scale_in = 11
* scale_out = 14
* ret = {0, 2760, 0, 6640, 1528, 0, 5760, 5400}
*/
void v_q_sigmoid(const INT_T* const vec, ITER_T len, INT_T* const ret, INT_T div,
INT_T add, INT_T sigmoid_limit, SCALE_T scale_in, SCALE_T scale_out);
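// Sketch of the quantized hard-sigmoid formula documented above (illustrative only; assumes
// scale_out >= scale_in so the final rescale is a left shift): shift-and-offset the input,
// clip it to [0, sigmoid_limit], then move it to the output fixed-point format.
static inline void v_q_sigmoid_sketch(const INT_T* const vec, ITER_T len, INT_T* const ret,
                                      INT_T div, INT_T add, INT_T sigmoid_limit,
                                      SCALE_T scale_in, SCALE_T scale_out) {
  for (ITER_T i = 0; i < len; i++) {
    INT_T x = (INT_T)(vec[i] / div + add);
    if (x < 0) {
      x = 0;
    } else if (x > sigmoid_limit) {
      x = sigmoid_limit;
    }
    ret[i] = (INT_T)(x << (scale_out - scale_in));  // * 2^{scale_out - scale_in}
  }
}
// With the example parameters: (-1358 / 2) + 1024 = 345, unchanged by clipping,
// and 345 * 2^{14 - 11} = 2760, matching ret[1] above.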
// Function for computing the TanHyperbolic activation on the input vector.
/**
* @brief Compute the element-wise TanHyperbolic activation on the input vector.
* @param[in] vec pointer to the input vector
* @param[in] len length of the input vector
* @param[out] ret pointer to the vector storing the output
* @param[in] scale_in scale factor of the input vector
* @param[in] scale_out scale factor of the output vector
* @return none
* @example formula = saturate(-2^{scale_in}, vec_{i}, 2^{scale_in}) * 2^{scale_out - scale_in}
* vec = {178, 1064, -4162, 1718, -1663, 851, 1244, 1282}
* len = 8
* scale_in = 11
* scale_out = 11
* ret = {178, 1064, -2048, 1718, -1663, 851, 1244, 1282}
*/
void v_q_tanh(const INT_T* const vec, ITER_T len, INT_T* const ret,
SCALE_T scale_in, SCALE_T scale_out);
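// Sketch of the quantized hard-tanh formula documented above (illustrative only): clip the
// input to [-2^{scale_in}, 2^{scale_in}] and rescale; with scale_in == scale_out == 11 the
// rescale is a no-op and only -4162 is clipped, to -2048, in the example above.
static inline void v_q_tanh_sketch(const INT_T* const vec, ITER_T len, INT_T* const ret,
                                   SCALE_T scale_in, SCALE_T scale_out) {
  INT_T limit = (INT_T)(1 << scale_in);
  for (ITER_T i = 0; i < len; i++) {
    INT_T x = vec[i];
    if (x < -limit) {
      x = (INT_T)(-limit);
    } else if (x > limit) {
      x = limit;
    }
    ret[i] = (INT_T)(x << (scale_out - scale_in));
  }
}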
// Function for adding a scalar to every element of a vector.
/**
* @brief Compute the addition of a scalar to every element of a vector.
* @param[in] scalar the input scalar to be added to a vector
* @param[in] vec pointer to the input vector
* @param[in] len length of the input vector
* @param[out] ret pointer to the vector storing the output
* @param[in] scscalar scale factor of the input scalar
* @param[in] scvec scale factor of the input vector
* @param[in] scret scale factor of the output vector
* @return none
* @example scalar = 30111
* vec = {16261, 13521, 16261, 9670, 14744, 16261, 10544, 10901}
* len = 8
* scscalar = 256
* scvec = 1
* scret = 1
* ret = {16378, 13638, 16378, 9787, 14861, 16378, 10661, 11018}
*/
void v_q_scalar_add(INT_T scalar, const INT_T* const vec, ITER_T len,
INT_T* const ret, SCALE_T scscalar, SCALE_T scvec, SCALE_T scret);
// Function for subtracting every element of a vector B from a scalar a.
// The resultant vector has elements C_{i} = a - B_{i}.
/**
* @brief Compute the subtraction of every element of a vector (B) from a scalar (a). The resultant vector has elements C_{i} = a - B_{i}.
* @param[in] scalar the input scalar
* @param[in] vec pointer to the input vector to be subtracted
* @param[in] len length of the input vector
* @param[out] ret pointer to the vector storing the output
* @param[in] scscalar scale factor of the input scalar
* @param[in] scvec scale factor of the input vector
* @param[in] scret scale factor of the output vector
* @return none
* @example scalar = 16384
* vec = {0, 2760, 0, 6640, 1528, 0, 5760, 5400}
* len = 8
* scscalar = 1
* scvec = 1
* scret = 1
* ret = {16384, 13624, 16384, 9744, 14856, 16384, 10624, 10984}
*/
void v_q_scalar_sub(INT_T scalar, const INT_T* const vec, ITER_T len,
INT_T* const ret, SCALE_T scscalar, SCALE_T scvec, SCALE_T scret);
// Function for subtracting a scalar b from every element of a vector A.
// The resultant vector has elements C_{i} = A_{i} - b.
/**
* @brief Compute the subtraction of a scalar (b) from every element of a vector (A). The resultant vector has elements C_{i} = A_{i} - b.
* @param[in] scalar the input scalar to be subtracted
* @param[in] vec pointer to the input vector
* @param[in] len length of the input vector
* @param[out] ret pointer to the vector storing the output
* @param[in] scscalar scale factor of the input scalar
* @param[in] scvec scale factor of the input vector
* @param[in] scret scale factor of the output vector
* @return none
* @example scalar = 16384
* vec = {0, 2760, 0, 6640, 1528, 0, 5760, 5400}
* len = 8
* scscalar = 1
* scvec = 1
* scret = 1
* ret = {-16384, -13624, -16384, -9744, -14856, -16384, -10624, -10984}
*/
void v_q_sub_scalar(const INT_T* const vec, INT_T scalar, ITER_T len,
INT_T* const ret, SCALE_T scvec, SCALE_T scscalar, SCALE_T scret);
// Function for multiplying a scalar to every element of a vector.
/**
* @brief Compute the multiplication of every element of a vector by a scalar.
* @param[in] scalar the input scalar to be multiplied
* @param[in] vec pointer to the input vector
* @param[in] len length of the input vector
* @param[out] ret pointer to the vector storing the output
* @param[in] scscalar scale factor of the input scalar
* @param[in] scvec scale factor of the input vector
* @return none
* @example scalar = 32522
* vec = {16384, 13624, 16384, 9744, 14856, 16384, 10624, 10984}
* len = 8
* scscalar = 128
* scvec = 256
* ret = {16261, 13521, 16261, 9670, 14744, 16261, 10544, 10901}
*/
void v_q_scalar_mul(INT_T scalar, const INT_T* const vec, ITER_T len,
INT_T* const ret, SCALE_T scscalar, SCALE_T scvec);
/**
@@ -214,7 +372,36 @@ void m_q_add_vec(const INT_T* const mat, const INT_T* const vec,
void m_q_sub_vec(const INT_T* const mat, const INT_T* const vec,
ITER_T nrows, ITER_T ncols, INT_T* const ret,
SCALE_T scmat, SCALE_T scvec, SCALE_T scret);
// Function for multiplying a matrix with a vector.
/**
* @brief Performs the multiplication of a matrix with a vector.
* @param[in] mat pointer to input matrix in row-major order
* @param[in] vec pointer to the input vector
* @param[in] nrows number of rows of the input matrix
* @param[in] ncols number of columns of the input matrix
* @param[out] ret pointer to the output vector
* @param[in] scmat scale factor of the input matrix
* @param[in] scvec scale factor of the input vector
* @param[in] H1 depth parameter for division-by-two used in TreeSum
* @param[in] H2 depth parameter for direct sum used in TreeSum
* @return none
* @example mat = { {7069, -10389, 1562, -1992},
* {3262, -37, -1143, -995},
* {5513, -17035, -14615, -6636},
* {4733, -403, 4106, -1104},
* {-2707, -1287, -18128, -1832},
* {-10108, -137, 2064, 1207},
* {5233, 226, 831, -1909},
* {4489, -1099, 2845, -1261} }
* vec = {1040, 1919, 4254, 4024}
* nrows = 8
* ncols = 4
* scmat = 128
* scvec = 64
* H1 = 2
* H2 = 0
* ret = {-425, -169, -3534, 524, -2739, 87, 52, 292}
*/
void m_q_mulvec(const INT_T* const mat, const INT_T* const vec, ITER_T nrows,
ITER_T ncols, INT_T* const ret, SCALE_T scmat, SCALE_T scvec,
SCALE_T H1, SCALE_T H2);
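// Sketch of the arithmetic implied by the parameters above (illustrative only; unlike the
// library routine it takes an explicit INTM_T scratch buffer of length ncols): each row's
// element-wise products are accumulated with v_q_treesum, and the tree-summed value is
// down-scaled by scmat and scvec. For the worked example this yields ret[0] = -425.
static inline void m_q_mulvec_sketch(const INT_T* const mat, const INT_T* const vec,
                                     ITER_T nrows, ITER_T ncols, INT_T* const ret,
                                     SCALE_T scmat, SCALE_T scvec, SCALE_T H1, SCALE_T H2,
                                     INTM_T* const scratch) {
  for (ITER_T i = 0; i < nrows; i++) {
    for (ITER_T j = 0; j < ncols; j++) {
      scratch[j] = ((INTM_T)mat[i * ncols + j]) * ((INTM_T)vec[j]);
    }
    v_q_treesum(scratch, ncols, H1, H2);
    ret[i] = (INT_T)((scratch[0] / scmat) / scvec);
  }
}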
@@ -228,13 +415,13 @@ void m_q_mulvec(const INT_T* const mat, const INT_T* const vec, ITER_T nrows,
* @param[in] mat_values pointer to input matrix which stores the non-zero values of matrix A
* @param[in] vec pointer to the input vector
* @param[in] ndims dimension of the multiplication vector
* @param[out] ret pointer to the output matrix
* @param[out] ret pointer to the output vector
* @param[in] scmat scale factor of the input matrix
* @param[in] scvec scale factor of the input vector
* @param[in] scret scale factor of the output matrix
@return none
* @example mat = {{10, 20, 30, 40, 50, 60, 70, 0, 0, 0, 0, 0, 0, 0},
* {0, 80, 0, 90, 0, 100, 0, 110, 0, 120, 0, 130, 0, 140}}
* @return none
* @example mat = { {10, 20, 30, 40, 50, 60, 70, 0, 0, 0, 0, 0, 0, 0},
* {0, 80, 0, 90, 0, 100, 0, 110, 0, 120, 0, 130, 0, 140} }
* col_indices = {1, 2, 3, 4, 5, 6, 7, 0, 2, 4, 6, 8, 10, 12, 14, 0}
* mat_values = {10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 110, 120, 130, 140}
* vec = {1, 2}
@@ -303,9 +490,9 @@ void t_q_add_vec(const INT_T* const mat, const INT_T* const vec,
* nrows = 4
* ncols = 2
* ret = { { {-1451, 2141}, {-386, 1132} },
{ {2281, 1120}, {593, 3882} } },
{ { {1334, 20}, {1282, 2127} },
{ {2375, 1684}, {2561, 4075} } }
* { {2281, 1120}, {593, 3882} } },
* { { {1334, 20}, {1282, 2127} },
* { {2375, 1684}, {2561, 4075} } }
* scmat = 1
* scvec = 2
* scret = 2
@@ -315,4 +502,80 @@ void t_q_sub_vec(const INT_T* const ten, const INT_T* const vec,
ITER_T nchannels, INT_T* const ret, SCALE_T scmat,
SCALE_T scvec, SCALE_T scret);
/**
* @brief Computes the maxpool operation on the input tensor with the given parameters.
* @param[in] input pointer to the tensor on which max-pooling is to be performed
* @param[out] output pointer to the output tensor
* @param[in] N number of batches of the input tensor
* @param[in] H number of rows of the input tensor
* @param[in] W number of columns of the input tensor
* @param[in] CIn number of channels of the input tensor
* @param[in] HF number of rows of the pooling filter
* @param[in] WF number of columns of the pooling filter
* @param[in] CF number of channels of the pooling filter
* @param[in] COut number of channels of the output tensor
* @param[in] HOut number of rows of the output tensor
* @param[in] WOut number of columns of the output tensor
* @param[in] G number of groups of pooling filters
* @param[in] HPadU padding over the top row
* @param[in] HPadD padding under the bottom row
* @param[in] WPadL padding before the leftmost column
* @param[in] WPadR padding after the rightmost column
* @param[in] HStride stride of the pooling filter along the rows, used for moving the receptive field vertically within the larger image
* @param[in] WStride stride of the pooling filter along the columns, used for moving the receptive field horizontally within the larger image
* @param[in] HDilation dilation of the pooling filter along the rows (number of skipped input rows between two consecutive filter rows is HDilation - 1)
* @param[in] WDilation dilation of the pooling filter along the columns (number of skipped input columns between two consecutive filter columns is WDilation - 1)
* @param[in] scinput scale of the input tensor
* @param[in] scoutput scale of the output tensor
* @return none
* @example Please refer to the test case test_quantized_maxpool() in the file: c_reference/tests/utils/test_quantized_utils.c
*/
void q_maxpool(const INT_T* const input, INT_T* const output, ITER_T N,
ITER_T H, ITER_T W, ITER_T CIn, ITER_T HF, ITER_T WF, ITER_T CF,
ITER_T COut, ITER_T HOut, ITER_T WOut, ITER_T G, S_ITER_T HPadU,
S_ITER_T HPadD, S_ITER_T WPadL, S_ITER_T WPadR, ITER_T HStride,
ITER_T WStride, ITER_T HDilation, ITER_T WDilation,
SCALE_T scinput, SCALE_T scoutput);
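// The loop bounds in the q_maxpool implementation imply the usual pooling output-size relation;
// the small helper below is an illustrative addition (not part of the library API) that mirrors
// it. For the first call in test_q_maxpool(): H = 2, no padding, HF = 2, HDilation = 1,
// HStride = 1 gives HOut = 1; with one row of padding on each side it gives HOut = 3.
static inline ITER_T q_out_dim_sketch(ITER_T in_dim, S_ITER_T pad_lo, S_ITER_T pad_hi,
                                      ITER_T filt, ITER_T dilation, ITER_T stride) {
  S_ITER_T span = (S_ITER_T)in_dim + pad_lo + pad_hi - (S_ITER_T)(dilation * (filt - 1));
  return (ITER_T)((span - 1) / (S_ITER_T)stride + 1);
}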
/**
* @brief Computes the convolution operation on the input tensor with the given parameters.
* @param[in] input pointer to the tensor on which convolution is to be performed
* @param[in] filter pointer to the convolutional filter tensor
* @param[out] output pointer to the output tensor
* @param[in] treesumBuffer pointer to the buffer for computing TreeSum accumulation
* @param[in] N number of batches of the input tensor
* @param[in] H number of rows of the input tensor
* @param[in] W number of columns of the input tensor
* @param[in] CIn number of channels of the input tensor
* @param[in] HF number of rows of the convolutional filter
* @param[in] WF number of columns of the convolutional filter
* @param[in] CF number of channels of the convolutional filter
* @param[in] COut number of channels of the output tensor
* @param[in] HOut number of rows of the output tensor
* @param[in] WOut number of columns of the output tensor
* @param[in] G number of groups of convolutional filters
* @param[in] HPadU padding over the top row
* @param[in] HPadD padding under the bottom row
* @param[in] WPadL padding before the leftmost column
* @param[in] WPadR padding after the rightmost column
* @param[in] HStride stride of the convolution filter along the rows, used for moving the receptive field vertically within the larger image
* @param[in] WStride stride of the convolution filter along the columns, used for moving the receptive field horizontally within the larger image
* @param[in] HDilation dilation of the convolution filter along the rows (number of skipped input rows between two consecutive filter rows is HDilation - 1)
* @param[in] WDilation dilation of the convolution filter along the columns (number of skipped input columns between two consecutive filter columns is WDilation - 1)
* @param[in] H1 depth parameter for division-by-two used in TreeSum
* @param[in] H2 depth parameter for direct sum used in TreeSum
* @param[in] scinput scale of the input tensor
* @param[in] scoutput scale of the output tensor
* @return none
* @example Please refer to the test case test_quantized_convolution() in the file: c_reference/tests/utils/test_quantized_utils.c
*/
void q_convolution(const INT_T* const input, const INT_T* const filter,
INT_T* const output, INTM_T* const treesumBuffer, ITER_T N,
ITER_T H, ITER_T W, ITER_T CIn, ITER_T HF, ITER_T WF,
ITER_T CF, ITER_T COut, ITER_T HOut, ITER_T WOut, ITER_T G,
S_ITER_T HPadU, S_ITER_T HPadD, S_ITER_T WPadL,
S_ITER_T WPadR, ITER_T HStride, ITER_T WStride,
ITER_T HDilation, ITER_T WDilation, SCALE_T H1, SCALE_T H2,
SCALE_T scinput, SCALE_T scoutput);
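// Layout note (inferred from the indexing in the q_convolution implementation below, stated
// here as an assumption): the input and output are NHWC tensors with CIn = G * CF input
// channels and G * COut output channels, and the filter is stored as [HF][WF][CF][COut], so
// the weight for filter row hf, column wf, input channel cf and output channel c lives at
//   hf * (WF * CF * COut) + wf * (CF * COut) + cf * COut + c.
// The same output-size relation as q_out_dim_sketch() above applies to HOut and WOut.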
#endif

View file

@@ -18,15 +18,15 @@ void q_mbconv_block(const INT_T* const input, const INT_T* const filter1,
L_SCALE_T shlB2, L_SCALE_T shlX2, L_SCALE_T shlU3, L_SCALE_T shlB3,
L_SCALE_T shlW3) {
S_ITER_T HOffsetL = ((S_ITER_T)(HF >> 1)) - HPadU;
S_ITER_T WOffsetL = ((S_ITER_T)(WF >> 1)) - WPadL;
S_ITER_T HOffsetL = ((S_ITER_T)((HF - 1) >> 1)) - HPadU;
S_ITER_T WOffsetL = ((S_ITER_T)((WF - 1) >> 1)) - WPadL;
S_ITER_T HOffsetR = ((S_ITER_T)(HF >> 1)) - HPadD;
S_ITER_T WOffsetR = ((S_ITER_T)(WF >> 1)) - WPadR;
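// With even filter sizes, ((HF - 1) >> 1) and (HF >> 1) differ: for HF = 3 both equal 1, so
// odd filters behave exactly as before, but for HF = 2 the left/top offset becomes 0 while
// the right/bottom offset stays 1. The receptive field of an even-sized filter therefore
// extends one step further to the bottom/right than to the top/left, which is what lets
// q_mbconv_block accept even filter sizes.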
for (ITER_T n = 0; n < N; n++) {
ITER_T margin = 0, nstart = 0;
if (HOffsetL + ((S_ITER_T)(HF >> 1) + 1) - ((S_ITER_T)HStride) > 0) {
margin = (ITER_T)(HOffsetL + ((S_ITER_T)(HF >> 1) + 1) - ((S_ITER_T)HStride));
if ((S_ITER_T)HF - HPadU - (S_ITER_T)HStride > 0) {
margin = (ITER_T)((S_ITER_T)HF - HPadU - (S_ITER_T)HStride);
}
if (HPadU < 0) {
// nstart will always be zero unless HPadU is negative.
@@ -66,7 +66,6 @@ void q_mbconv_block(const INT_T* const input, const INT_T* const filter1,
for (ITER_T k = 0; k < CTemp; k++) {
ITER_T iRed = (i + margin + hout * HStride) % HF;
ITER_T iFull = i + margin + hout * HStride;
convBuffer1[iRed * W * CTemp + j * CTemp + k] = 0;
for (ITER_T l = 0; l < CIn; l++) {
if (iFull < H) {
treesumBuffer[l] = ((INTM_T)input[n * H * W * CIn + iFull * W * CIn + j * CIn + l]) *
@@ -98,13 +97,13 @@ void q_mbconv_block(const INT_T* const input, const INT_T* const filter1,
for (S_ITER_T w = WOffsetL; w < ((S_ITER_T)W) - WOffsetR; wout++, w += ((S_ITER_T)WStride)) {
for (ITER_T g = 0; g < CTemp; g++) {
ITER_T counter = 0;
for (S_ITER_T hf = -(HF >> 1); hf <= (HF >> 1); hf++) {
for (S_ITER_T wf = -(WF >> 1); wf <= (WF >> 1); wf++) {
for (S_ITER_T hf = -((HF - 1) >> 1); hf <= (HF >> 1); hf++) {
for (S_ITER_T wf = -((WF - 1) >> 1); wf <= (WF >> 1); wf++) {
if (((h + hf) < 0) || ((h + hf) >= (S_ITER_T)H) || ((w + wf) < 0) || ((w + wf) >= (S_ITER_T)W)) {
treesumBuffer[counter] = 0;
} else {
treesumBuffer[counter] = ((INTM_T)convBuffer1[(((ITER_T)(h + hf)) % HF) * W * CTemp + ((ITER_T)(w + wf)) * CTemp + g]) *
((INTM_T)filter2[g * HF * WF + ((ITER_T)(hf + (HF >> 1))) * WF + ((ITER_T)(wf + (WF >> 1)))]);
((INTM_T)filter2[g * HF * WF + ((ITER_T)(hf + ((HF - 1) >> 1))) * WF + ((ITER_T)(wf + ((WF - 1) >> 1)))]);
}
counter++;
}

View file

@@ -355,3 +355,135 @@ void t_q_sub_vec(const INT_T* const mat, const INT_T* const vec,
#endif
}
}
void q_maxpool(const INT_T* const input, INT_T* const output, ITER_T N,
ITER_T H, ITER_T W, ITER_T CIn, ITER_T HF, ITER_T WF, ITER_T CF,
ITER_T COut, ITER_T HOut, ITER_T WOut, ITER_T G, S_ITER_T HPadU,
S_ITER_T HPadD, S_ITER_T WPadL, S_ITER_T WPadR, ITER_T HStride,
ITER_T WStride, ITER_T HDilation, ITER_T WDilation,
SCALE_T scinput, SCALE_T scoutput) {
S_ITER_T HOffsetL = ((S_ITER_T)HDilation * (S_ITER_T)((HF - 1) >> 1)) - HPadU;
S_ITER_T WOffsetL = ((S_ITER_T)WDilation * (S_ITER_T)((WF - 1) >> 1)) - WPadL;
S_ITER_T HOffsetR = ((S_ITER_T)HDilation * (S_ITER_T)(HF >> 1)) - HPadD;
S_ITER_T WOffsetR = ((S_ITER_T)WDilation * (S_ITER_T)(WF >> 1)) - WPadR;
ITER_T HOffsetIn = W * CIn;
ITER_T NOffsetIn = H * HOffsetIn;
ITER_T WOffsetOut = (COut * G);
ITER_T HOffsetOut = WOut * WOffsetOut;
ITER_T NOffsetOut = HOut * HOffsetOut;
for (ITER_T n = 0; n < N; n++) {
ITER_T hout = 0;
ITER_T NIndexIn = n * NOffsetIn;
ITER_T NIndexOut = n * NOffsetOut;
for (S_ITER_T h = HOffsetL; h < (S_ITER_T)H - HOffsetR; h += (S_ITER_T)HStride, hout++) {
ITER_T wout = 0;
ITER_T HIndexOut = hout * HOffsetOut;
for (S_ITER_T w = WOffsetL; w < (S_ITER_T)W - WOffsetR; w += (S_ITER_T)WStride, wout++) {
ITER_T WIndexOut = wout * WOffsetOut;
for (ITER_T g = 0; g < G; g++) {
ITER_T CIndexIn = g * CF;
ITER_T CIndexOut = g * COut;
for (ITER_T c = 0; c < COut; c++) {
INT_T max = INT_TMIN;
for (S_ITER_T hf = -((HF - 1) >> 1); hf <= (HF >> 1); hf++) {
S_ITER_T hoffset = h + ((S_ITER_T)HDilation * hf);
ITER_T HIndexIn = ((ITER_T)hoffset) * HOffsetIn;
for (S_ITER_T wf = -((WF - 1) >> 1); wf <= (WF >> 1); wf++) {
S_ITER_T woffset = w + ((S_ITER_T)WDilation * wf);
ITER_T WIndexIn = ((ITER_T)woffset) * CIn;
for (ITER_T cf = 0; cf < CF; cf++) {
if ((hoffset < 0) || (hoffset >= (S_ITER_T)H) || (woffset < 0) || (woffset >= (S_ITER_T)W)) {
if (max < 0) {
max = 0;
}
} else {
INT_T a = input[NIndexIn + HIndexIn + WIndexIn + (cf + CIndexIn)];
if (max < a) {
max = a;
}
}
}
}
}
#ifdef SHIFT
output[NIndexOut + HIndexOut + WIndexOut + (c + CIndexOut)] = (max >> (scinput + scoutput));
#else
output[NIndexOut + HIndexOut + WIndexOut + (c + CIndexOut)] = ((max / scinput) / scoutput);
#endif
}
}
}
}
}
}
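// Scaling note: under the SHIFT build scinput and scoutput act as shift amounts, so the pooled
// maximum is divided by 2^(scinput + scoutput); otherwise they are plain divisors applied in
// sequence. The calls in test_q_maxpool() mirror this: (1, 1) under SHIFT and (2, 2) otherwise,
// both scaling the result by 1/4. q_convolution() below applies the same convention to its
// tree-summed accumulator.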
void q_convolution(const INT_T* const input, const INT_T* const filter,
INT_T* const output, INTM_T* const treesumBuffer, ITER_T N,
ITER_T H, ITER_T W, ITER_T CIn, ITER_T HF, ITER_T WF,
ITER_T CF, ITER_T COut, ITER_T HOut, ITER_T WOut, ITER_T G,
S_ITER_T HPadU, S_ITER_T HPadD, S_ITER_T WPadL,
S_ITER_T WPadR, ITER_T HStride, ITER_T WStride,
ITER_T HDilation, ITER_T WDilation, SCALE_T H1, SCALE_T H2,
SCALE_T scinput, SCALE_T scoutput) {
S_ITER_T HOffsetL = ((S_ITER_T)HDilation * (S_ITER_T)((HF - 1) >> 1)) - HPadU;
S_ITER_T WOffsetL = ((S_ITER_T)WDilation * (S_ITER_T)((WF - 1) >> 1)) - WPadL;
S_ITER_T HOffsetR = ((S_ITER_T)HDilation * (S_ITER_T)(HF >> 1)) - HPadD;
S_ITER_T WOffsetR = ((S_ITER_T)WDilation * (S_ITER_T)(WF >> 1)) - WPadR;
ITER_T HOffsetIn = W * CIn;
ITER_T NOffsetIn = H * HOffsetIn;
ITER_T WOffsetF = CF * COut;
ITER_T HOffsetF = WF * WOffsetF;
ITER_T WOffsetOut = (COut * G);
ITER_T HOffsetOut = WOut * WOffsetOut;
ITER_T NOffsetOut = HOut * HOffsetOut;
for (ITER_T n = 0; n < N; n++) {
ITER_T hout = 0;
ITER_T NIndexIn = n * NOffsetIn;
ITER_T NIndexOut = n * NOffsetOut;
for (S_ITER_T h = HOffsetL; h < (S_ITER_T)H - HOffsetR; h += (S_ITER_T)HStride, hout++) {
ITER_T wout = 0;
ITER_T HIndexOut = hout * HOffsetOut;
for (S_ITER_T w = WOffsetL; w < (S_ITER_T)W - WOffsetR; w += (S_ITER_T)WStride, wout++) {
ITER_T WIndexOut = wout * WOffsetOut;
for (ITER_T g = 0; g < G; g++) {
ITER_T CIndexIn = g * CF;
ITER_T CIndexOut = g * COut;
for (ITER_T c = 0; c < COut; c++) {
ITER_T counter = 0;
for (S_ITER_T hf = -((HF - 1) >> 1); hf <= (HF >> 1); hf++) {
S_ITER_T hoffset = h + ((S_ITER_T)HDilation * hf);
ITER_T HIndexIn = ((ITER_T)hoffset) * HOffsetIn;
ITER_T HIndexF = ((ITER_T)(hf + ((HF - 1) >> 1))) * HOffsetF;
for (S_ITER_T wf = -((WF - 1) >> 1); wf <= (WF >> 1); wf++) {
S_ITER_T woffset = w + ((S_ITER_T)WDilation * wf);
ITER_T WIndexIn = ((ITER_T)woffset) * CIn;
ITER_T WIndexF = ((ITER_T)(wf + ((WF - 1) >> 1))) * WOffsetF;
for (ITER_T cf = 0; cf < CF; cf++) {
if ((hoffset < 0) || (hoffset >= (S_ITER_T)H) || (woffset < 0) || (woffset >= (S_ITER_T)W)) {
treesumBuffer[counter] = 0;
} else {
treesumBuffer[counter] = ((INTM_T)input[NIndexIn + HIndexIn + WIndexIn + (cf + CIndexIn)]) *
((INTM_T)filter[HIndexF + WIndexF + (c + cf * COut)]);
}
counter++;
}
}
}
v_q_treesum(&treesumBuffer[0], HF * WF * CF, H1, H2);
#ifdef SHIFT
output[NIndexOut + HIndexOut + WIndexOut + (c + CIndexOut)] = (treesumBuffer[0] >> (scinput + scoutput));
#else
output[NIndexOut + HIndexOut + WIndexOut + (c + CIndexOut)] = ((treesumBuffer[0] / scinput) / scoutput);
#endif
}
}
}
}
}
}

View file

@@ -402,6 +402,199 @@ int test_t_q_sub_vec() {
return check_output(pred, expected, 16);
}
// Test q_maxpool() function.
int test_q_maxpool() {
const INT_T qmat_A[2 * 2 * 2 * 2] = {11, 220,
130, 40,
50, 60,
66, 76,
86, 910,
411, 312,
513, 514,
715, 716};
const INT_T qmat_B[2 * 2 * 2 * 2] = {100, 992,
15, 26,
27, 8,
3, 4,
5, 2,
2, 2,
7, 8,
29, 140};
const INT_T expected_A[2 * 1 * 1 * 2] = {32, 55,
178, 227};
const INT_T expected_B[2 * 3 * 3 * 2] = {100, 992,
100, 992,
15, 26,
100, 992,
100, 992,
15, 26,
27, 8,
27, 8,
3, 4,
5, 2,
5, 2,
2, 2,
7, 8,
29, 140,
29, 140,
7, 8,
29, 140,
29, 140};
const INT_T expected_C[2 * 2 * 2 * 2] = {100, 992,
100, 992,
100, 992,
100, 992,
29, 140,
29, 140,
29, 140,
29, 140};
const INT_T expected_D[2 * 3 * 3 * 2] = {16, 19,
0, 0,
12, 15,
0, 0,
0, 0,
0, 0,
32, 10,
0, 0,
2, 55,
178, 179,
0, 0,
128, 128,
0, 0,
0, 0,
0, 0,
102, 78,
0, 0,
21, 227};
INT_T pred_A[2 * 1 * 1 * 2], pred_B[2 * 3 * 3 * 2], pred_C[2 * 2 * 2 * 2], pred_D[2 * 3 * 3 * 2];
#ifdef SHIFT
q_maxpool(qmat_A, pred_A, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 2, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1);
q_maxpool(qmat_B, pred_B, 2, 2, 2, 2, 2, 2, 1, 1, 3, 3, 2, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0);
q_maxpool(qmat_B, pred_C, 2, 2, 2, 2, 3, 3, 1, 1, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0);
q_maxpool(qmat_A, pred_D, 2, 2, 2, 2, 2, 2, 1, 1, 3, 3, 2, 2, 2, 2, 2, 1, 1, 3, 3, 1, 1);
#else
q_maxpool(qmat_A, pred_A, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 2, 0, 0, 0, 0, 1, 1, 1, 1, 2, 2);
q_maxpool(qmat_B, pred_B, 2, 2, 2, 2, 2, 2, 1, 1, 3, 3, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1);
q_maxpool(qmat_B, pred_C, 2, 2, 2, 2, 3, 3, 1, 1, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1);
q_maxpool(qmat_A, pred_D, 2, 2, 2, 2, 2, 2, 1, 1, 3, 3, 2, 2, 2, 2, 2, 1, 1, 3, 3, 2, 2);
#endif
return (check_output(pred_A, expected_A, 4) || check_output(pred_B, expected_B, 36) || check_output(pred_C, expected_C, 8) || check_output(pred_D, expected_D, 36));
}
// Test q_convolution() function.
int test_q_convolution() {
const INT_T qmat_A[2 * 2 * 2 * 2] = {11, 220,
130, 40,
50, 60,
66, 76,
86, 910,
411, 312,
513, 514,
715, 716};
// Convolution filters
const INT_T qmat_B[2 * 2 * 1 * 1] = {0, 1,
1, 0};
const INT_T qmat_C[2 * 2 * 2 * 1] = {0, 1,
1, 0,
1, 0,
0, 1};
const INT_T qmat_D[3 * 3 * 1 * 1] = {0, 0, 1,
0, 1, 0,
1, 0, 0};
const INT_T expected_A[2 * 1 * 1 * 2] = {44, 25,
230, 206};
const INT_T expected_B[2 * 1 * 1 * 1] = {58,
317};
const INT_T expected_C[2 * 2 * 2 * 2] = {1, 27,
22, 12,
22, 12,
8, 9,
10, 113,
115, 103,
115, 103,
89, 89};
const INT_T expected_D[2 * 3 * 3 * 2] = {0, 0,
0, 0,
1, 1,
0, 0,
0, 0,
0, 0,
4, 1,
0, 0,
0, 0,
0, 0,
0, 0,
16, 16,
0, 0,
0, 0,
0, 0,
12, 9,
0, 0,
0, 0};
INT_T pred_A[2 * 1 * 1 * 2], pred_B[2 * 1 * 1 * 1], pred_C[2 * 2 * 2 * 2], pred_D[2 * 3 * 3 * 2];
INTM_T temp[16];
#ifdef SHIFT
q_convolution(qmat_A, qmat_B, pred_A, temp, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 2, 0, 0, 0, 0, 1, 1, 1, 1, 2, 0, 0, 0);
q_convolution(qmat_A, qmat_C, pred_B, temp, 2, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 3, 0, 0, 0);
q_convolution(qmat_A, qmat_D, pred_C, temp, 2, 2, 2, 2, 3, 3, 1, 1, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 3, 0, 0, 0);
q_convolution(qmat_A, qmat_B, pred_D, temp, 2, 2, 2, 2, 2, 2, 1, 1, 3, 3, 2, 2, 2, 2, 2, 1, 1, 3, 3, 3, 0, 1, 1);
#else
q_convolution(qmat_A, qmat_B, pred_A, temp, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 2, 0, 0, 0, 0, 1, 1, 1, 1, 2, 0, 1, 1);
q_convolution(qmat_A, qmat_C, pred_B, temp, 2, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 3, 0, 1, 1);
q_convolution(qmat_A, qmat_D, pred_C, temp, 2, 2, 2, 2, 3, 3, 1, 1, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 3, 0, 1, 1);
q_convolution(qmat_A, qmat_B, pred_D, temp, 2, 2, 2, 2, 2, 2, 1, 1, 3, 3, 2, 2, 2, 2, 2, 1, 1, 3, 3, 3, 0, 2, 2);
#endif
return (check_output(pred_A, expected_A, 4) || check_output(pred_B, expected_B, 2) || check_output(pred_C, expected_C, 16) || check_output(pred_D, expected_D, 36));
}
int main() {
if (test_v_q_treesum()) {
printf("Test Failure for v_q_treesum()!\n");
@@ -449,6 +642,10 @@ int main() {
printf("Test Failure for t_q_add_vec()!\n");
} else if (test_t_q_sub_vec()) {
printf("Test Failure for t_q_sub_vec()!\n");
} else if (test_q_maxpool()) {
printf("Test Failure for q_maxpool()!\n");
} else if (test_q_convolution()) {
printf("Test Failure for q_convolution()!\n");
} else {
printf("All Tests Passed!\n");
return 0;