зеркало из https://github.com/microsoft/EdgeML.git
Quantized Convolution and MaxPool Implementation (#194)
* Initial * Implement Convolution and MaxPooling for all filter sizes * Incorporate Review * Additional Documentation * Add Support for Dilations in q_maxpool() and q_convolution() * Extend MBConv To Take Even Number Filter Sizes
This commit is contained in:
Родитель
9f68e66e46
Коммит
f835ab1abc
|
@ -259,7 +259,8 @@ typedef struct Q_FastGRNN_Buffers {
|
|||
* <code>ERR_PRECOMP_NOT_INIT</code> if preComp2 not allocated
|
||||
* <code>ERR_PRECOMP_NOT_INIT</code> if preComp3 not allocated
|
||||
* <code>ERR_NORMFEAT_NOT_INIT</code> if normFeatures not allocated
|
||||
*/
|
||||
* @example Please refer the file: c_reference/tests/fastgrnn/test_quantized_fastgrnn.c
|
||||
*/
|
||||
int q_fastgrnn(INT_T* const hiddenState, ITER_T hiddenDims,
|
||||
const INT_T* const input, ITER_T inputDims, ITER_T steps,
|
||||
const void* params, void* buffers, const void* scales,
|
||||
|
|
|
@ -8,6 +8,7 @@
|
|||
|
||||
/**
|
||||
* @brief Model parameters for Quantized MBConv Layer
|
||||
* Note: This implementation doesn't support dilations yet.
|
||||
* @param[in] input pointer to the input buffer
|
||||
* @param[in] filter1 pointer to the first convolution filter buffer
|
||||
* @param[in] BN1W pointer to the buffer holding the multiplication factor of the first BatchNorm computation
|
||||
|
@ -61,6 +62,7 @@
|
|||
* @param[in] shlU3 scale to multiply with the third TreeSum output
|
||||
* @param[in] shlB3 scale to multiply with the third BatchNorm addition factor
|
||||
* @param[in] shlW3 scale to multiply with the third Convolution output
|
||||
* @return none
|
||||
*
|
||||
* @brief The function computes the following three sub-parts:
|
||||
* 1) Convolution(input, filter1) -> Batch Normalization(BN1W, BN1B) -> ReLU(limit1) -> convBuffer1
|
||||
|
@ -68,6 +70,8 @@
|
|||
* 3) Convolution(convBuffer2, filter3) -> Batch Normalization(BN3W, BN3B) -> output
|
||||
* Variables depth1, depth2 and depth3 are used along with treesumBuffer for accumulating the sums during convolutions.
|
||||
* Rest of the variables are used as indicated.
|
||||
*
|
||||
* @example Please refer the file: c_reference/tests/mbconv/test_quantized_mbconv.c
|
||||
*/
|
||||
void q_mbconv_block(const INT_T* const input, const INT_T* const filter1,
|
||||
const INT_T* const BN1W, const INT_T* const BN1B, const INT_T* const filter2,
|
||||
|
|
|
@ -25,6 +25,8 @@ typedef int (*q_rnn_t)(INT_T* const, ITER_T, const INT_T* const, ITER_T, ITER_T,
|
|||
* @param[in] rnn2_scales pointer to the scales needed for RNN2
|
||||
* @param[out] output pointer to output, initialized to size 4 * hiddenDims2
|
||||
* @param[in,out] buffer pointer to buffer, intialized to size hiddenDims1 * max{nrows, cols}
|
||||
* @return none
|
||||
* @example Please refer the file: c_reference/tests/rnnpool/test_quantized_rnnpool.c
|
||||
*/
|
||||
int q_rnnpool_block(const INT_T* const patch, ITER_T inputDims, ITER_T patchDim,
|
||||
ITER_T stride, q_rnn_t rnn1, ITER_T hiddenDims1,
|
||||
|
|
|
@ -32,36 +32,194 @@ inline INTM_T q_relu(INTM_T inp, INTM_T limit) {
|
|||
}
|
||||
}
|
||||
|
||||
// Functions for calculating quantized operations and activations.
|
||||
// Function for computing TreeSum from a given vector holding intermediate results.
|
||||
/**
|
||||
* @brief Compute TreeSum from a given vector holding intermediate multiplications, and store the result in the first index of the input vector.
|
||||
* @param[in, out] vec pointer to vector on which TreeSum operation is to be computed
|
||||
* @param[in] len length of the input vector
|
||||
* @param[in] H1 depth parameter for division-by-two used in TreeSum
|
||||
* @param[in] H2 depth parameter for direct sum used in TreeSum
|
||||
* @return none
|
||||
* @example vec = {-425, -169, -3534, 524, -2739, 87, 52, 292}
|
||||
* len = 8
|
||||
* H1 = 3
|
||||
* H2 = 0
|
||||
* vec[0] = {-738}
|
||||
*/
|
||||
void v_q_treesum(INTM_T* const vec, ITER_T len, SCALE_T H1, SCALE_T H2);
|
||||
// Function for computing the element-wise addition between two vectors.
|
||||
/**
|
||||
* @brief Compute the element-wise addition between two vectors.
|
||||
* @param[in] vec1 pointer to the first input vector
|
||||
* @param[in] vec2 pointer to the second input vector
|
||||
* @param[in] len length of the input vectors
|
||||
* @param[out] ret pointer to the vector storing the output
|
||||
* @param[in] scvec1 scale factor of the first input vector
|
||||
* @param[in] scvec2 scale factor of the second input vector
|
||||
* @param[in] scret scale factor of the output vector
|
||||
* @return none
|
||||
* @example vec1 = {-425, -169, -3534, 524, -2739, 87, 52, 292}
|
||||
* vec2 = {-18777, -9518, 4055, -7309, 8584, -17257, -5280, -7933}
|
||||
* len = 8
|
||||
* scvec1 = 1
|
||||
* scvec2 = 8
|
||||
* scret = 1
|
||||
* ret = {-2772, -1358, -3028, -389, -1666, -2070, -608, -699}
|
||||
*/
|
||||
void v_q_add(const INT_T* const vec1, const INT_T* const vec2, ITER_T len,
|
||||
INT_T* const ret, SCALE_T scvec1, SCALE_T scvec2, SCALE_T scret);
|
||||
// Function for computing the element-wise difference between two vectors.
|
||||
/**
|
||||
* @brief Compute the element-wise subtraction between two vectors.
|
||||
* @param[in] vec1 pointer to the first input vector
|
||||
* @param[in] vec2 pointer to the second input vector
|
||||
* @param[in] len length of the input vectors
|
||||
* @param[out] ret pointer to the vector storing the output
|
||||
* @param[in] scvec1 scale factor of the first input vector
|
||||
* @param[in] scvec2 scale factor of the second input vector
|
||||
* @param[in] scret scale factor of the output vector
|
||||
* @return none
|
||||
* @example vec1 = {-425, -169, -3534, 524, -2739, 87, 52, 292}
|
||||
* vec2 = {-18777, -9518, 4055, -7309, 8584, -17257, -5280, -7933}
|
||||
* len = 8
|
||||
* scvec1 = 1
|
||||
* scvec2 = 8
|
||||
* scret = 1
|
||||
* ret = {1922, 1020, -4040, 1437, -3812, 2244, 712, 1283}
|
||||
*/
|
||||
void v_q_sub(const INT_T* const vec1, const INT_T* const vec2, ITER_T len,
|
||||
INT_T* const ret, SCALE_T scvec1, SCALE_T scvec2, SCALE_T scret);
|
||||
// Function for computing the Hadamard product between two vectors.
|
||||
/**
|
||||
* @brief Compute the element-wise product (also known as Hadamard product) between two vectors.
|
||||
* @param[in] vec1 pointer to the first input vector
|
||||
* @param[in] vec2 pointer to the second input vector
|
||||
* @param[in] len length of the input vectors
|
||||
* @param[out] ret pointer to the vector storing the output
|
||||
* @param[in] scvec1 scale factor of the first input vector
|
||||
* @param[in] scvec2 scale factor of the second input vector
|
||||
* @return none
|
||||
* @example vec1 = {16378, 13638, 16378, 9787, 14861, 16378, 10661, 11018}
|
||||
* vec2 = {178, 1064, -2048, 1718, -1663, 851, 1244, 1282}
|
||||
* len = 8
|
||||
* scvec1 = 32
|
||||
* scvec2 = 64
|
||||
* ret = {1423, 7085, -16378, 8209, -12067, 6805, 6475, 6897}
|
||||
*/
|
||||
void v_q_hadamard(const INT_T* const vec1, const INT_T* const vec2, ITER_T len,
|
||||
INT_T* const ret, SCALE_T scvec1, SCALE_T scvec2);
|
||||
// Function for computing the Sigmoid activation on the input vector.
|
||||
/**
|
||||
* @brief Compute the element-wise Sigmoid activation on the input vector.
|
||||
* @param[in] vec pointer to the input vector
|
||||
* @param[in] len length of the input vector
|
||||
* @param[out] ret pointer to the vector storing the output
|
||||
* @param[in] div division factor of the input vector
|
||||
* @param[in] add addition offset of the input vector
|
||||
* @param[in] sigmoid_limit saturation limit for the Sigmoid activation
|
||||
* @param[in] scale_in scale factor of the input vector
|
||||
* @param[in] scale_out scale factor of the output vector
|
||||
* @return none
|
||||
* @example formula = saturate(0, (vec_{i} / div) + add, sigmoid_limit) * 2^{scale_out - scale_in}
|
||||
* vec = {-2772, -1358, -3028, -389, -1666, -2070, -608, -699}
|
||||
* len = 8
|
||||
* div = 2
|
||||
* add = 1024
|
||||
* sigmoid_limit = 2048
|
||||
* scale_in = 11
|
||||
* scale_out = 14
|
||||
* ret = {0, 2760, 0, 6640, 1528, 0, 5760, 5400}
|
||||
*/
|
||||
void v_q_sigmoid(const INT_T* const vec, ITER_T len, INT_T* const ret, INT_T div,
|
||||
INT_T add, INT_T sigmoid_limit, SCALE_T scale_in, SCALE_T scale_out);
|
||||
// Function for computing the TanHyperbolic activation on the input vector.
|
||||
/**
|
||||
* @brief Compute the element-wise TanHyperbolic activation on the input vector.
|
||||
* @param[in] vec pointer to the input vector
|
||||
* @param[in] len length of the input vector
|
||||
* @param[out] ret pointer to the vector storing the output
|
||||
* @param[in] scale_in scale factor of the input vector
|
||||
* @param[in] scale_out scale factor of the output vector
|
||||
* @return none
|
||||
* @example formula = saturate(-2^{scale_in}, vec_{i}, 2^{scale_in}) * 2^{scale_out - scale_in}
|
||||
* vec = {178, 1064, -4162, 1718, -1663, 851, 1244, 1282}
|
||||
* len = 8
|
||||
* scale_in = 11
|
||||
* scale_out = 11
|
||||
* ret = {178, 1064, -2048, 1718, -1663, 851, 1244, 1282}
|
||||
*/
|
||||
void v_q_tanh(const INT_T* const vec, ITER_T len, INT_T* const ret,
|
||||
SCALE_T scale_in, SCALE_T scale_out);
|
||||
// Function for adding a scalar to every element of a vector.
|
||||
/**
|
||||
* @brief Compute the addition of a scalar to every element of a vector.
|
||||
* @param[in] scalar the input scalar to be added to a vector
|
||||
* @param[in] vec pointer to the input vector
|
||||
* @param[in] len length of the input vector
|
||||
* @param[out] ret pointer to the vector storing the output
|
||||
* @param[in] scscalar scale factor of the input scalar
|
||||
* @param[in] scvec scale factor of the input vector
|
||||
* @param[in] scret scale factor of the output vector
|
||||
* @return none
|
||||
* @example scalar = 30111
|
||||
* vec = {16261, 13521, 16261, 9670, 14744, 16261, 10544, 10901}
|
||||
* len = 8
|
||||
* scscalar = 256
|
||||
* scvec = 1
|
||||
* scret = 1
|
||||
* ret = {16378, 13638, 16378, 9787, 14861, 16378, 10661, 11018}
|
||||
*/
|
||||
void v_q_scalar_add(INT_T scalar, const INT_T* const vec, ITER_T len,
|
||||
INT_T* const ret, SCALE_T scscalar, SCALE_T scvec, SCALE_T scret);
|
||||
// Function for subtracting every element of a vector B from a scalar a.
|
||||
// The resultant vector has elements C_{i} = a - B_{i}.
|
||||
/**
|
||||
* @brief Compute the subtraction of every element of a vector (B) from a scalar (a). The resultant vector has elements C_{i} = a - B_{i}.
|
||||
* @param[in] scalar the input scalar
|
||||
* @param[in] vec pointer to the input vector to be subtracted
|
||||
* @param[in] len length of the input vector
|
||||
* @param[out] ret pointer to the vector storing the output
|
||||
* @param[in] scscalar scale factor of the input scalar
|
||||
* @param[in] scvec scale factor of the input vector
|
||||
* @param[in] scret scale factor of the output vector
|
||||
* @return none
|
||||
* @example scalar = 16384
|
||||
* vec = {0, 2760, 0, 6640, 1528, 0, 5760, 5400}
|
||||
* len = 8
|
||||
* scscalar = 1
|
||||
* scvec = 1
|
||||
* scret = 1
|
||||
* ret = {16384, 13624, 16384, 9744, 14856, 16384, 10624, 10984}
|
||||
*/
|
||||
void v_q_scalar_sub(INT_T scalar, const INT_T* const vec, ITER_T len,
|
||||
INT_T* const ret, SCALE_T scscalar, SCALE_T scvec, SCALE_T scret);
|
||||
// Function for subtracting a scalar b from every element of a vector A.
|
||||
// The resultant vector has elements C_{i} = A_{i} - b.
|
||||
/**
|
||||
* @brief Compute the subtraction of a scalar (b) from every element of a vector (A). The resultant vector has elements C_{i} = A_{i} - b.
|
||||
* @param[in] scalar the input scalar to be subtracted
|
||||
* @param[in] vec pointer to the input vector
|
||||
* @param[in] len length of the input vector
|
||||
* @param[out] ret pointer to the vector storing the output
|
||||
* @param[in] scscalar scale factor of the input scalar
|
||||
* @param[in] scvec scale factor of the input vector
|
||||
* @param[in] scret scale factor of the output vector
|
||||
* @return none
|
||||
* @example scalar = 16384
|
||||
* vec = {0, 2760, 0, 6640, 1528, 0, 5760, 5400}
|
||||
* len = 8
|
||||
* scscalar = 1
|
||||
* scvec = 1
|
||||
* scret = 1
|
||||
* ret = {-16384, -13624, -16384, -9744, -14856, -16384, -10624, -10984}
|
||||
*/
|
||||
void v_q_sub_scalar(const INT_T* const vec, INT_T scalar, ITER_T len,
|
||||
INT_T* const ret, SCALE_T scvec, SCALE_T scscalar, SCALE_T scret);
|
||||
// Function for multiplying a scalar to every element of a vector.
|
||||
/**
|
||||
* @brief Compute the multiplication of a scalar to every element of a vector.
|
||||
* @param[in] scalar the input scalar to be multiplied
|
||||
* @param[in] vec pointer to the input vector
|
||||
* @param[in] len length of the input vector
|
||||
* @param[out] ret pointer to the vector storing the output
|
||||
* @param[in] scscalar scale factor of the input scalar
|
||||
* @param[in] scvec scale factor of the input vector
|
||||
* @return none
|
||||
* @example scalar = 32522
|
||||
* vec = {16384, 13624, 16384, 9744, 14856, 16384, 10624, 10984}
|
||||
* len = 8
|
||||
* scscalar = 128
|
||||
* scvec = 256
|
||||
* ret = {16261, 13521, 16261, 9670, 14744, 16261, 10544, 10901}
|
||||
*/
|
||||
void v_q_scalar_mul(INT_T scalar, const INT_T* const vec, ITER_T len,
|
||||
INT_T* const ret, SCALE_T scscalar, SCALE_T scvec);
|
||||
/**
|
||||
|
@ -214,7 +372,36 @@ void m_q_add_vec(const INT_T* const mat, const INT_T* const vec,
|
|||
void m_q_sub_vec(const INT_T* const mat, const INT_T* const vec,
|
||||
ITER_T nrows, ITER_T ncols, INT_T* const ret,
|
||||
SCALE_T scmat, SCALE_T scvec, SCALE_T scret);
|
||||
// Function for multiplying a matrix with a vector.
|
||||
/**
|
||||
* @brief Performs the matrix multiplication of a matrix and a vector.
|
||||
* @param[in] mat pointer to input matrix in row-major order
|
||||
* @param[in] vec pointer to the input vector
|
||||
* @param[in] nrows number of rows of the input matrix
|
||||
* @param[in] ncols number of columns of the input matrix
|
||||
* @param[out] ret pointer to the output vector
|
||||
* @param[in] scmat scale factor of the input matrix
|
||||
* @param[in] scvec scale factor of the input vector
|
||||
* @param[in] H1 depth parameter for division-by-two used in TreeSum
|
||||
* @param[in] H2 depth parameter for direct sum used in TreeSum
|
||||
|
||||
* @return none
|
||||
* @example mat = { {7069, -10389, 1562, -1992},
|
||||
* {3262, -37, -1143, -995},
|
||||
* {5513, -17035, -14615, -6636},
|
||||
* {4733, -403, 4106, -1104},
|
||||
* {-2707, -1287, -18128, -1832},
|
||||
* {-10108, -137, 2064, 1207},
|
||||
* {5233, 226, 831, -1909},
|
||||
* {4489, -1099, 2845, -1261} }
|
||||
* vec = {1040, 1919, 4254, 4024}
|
||||
* nrows = 8
|
||||
* ncols = 4
|
||||
* scmat = 128
|
||||
* scvec = 64
|
||||
* H1 = 2
|
||||
* H2 = 0
|
||||
* ret = {-425, -169, -3534, 524, -2739, 87, 52, 292}
|
||||
*/
|
||||
void m_q_mulvec(const INT_T* const mat, const INT_T* const vec, ITER_T nrows,
|
||||
ITER_T ncols, INT_T* const ret, SCALE_T scmat, SCALE_T scvec,
|
||||
SCALE_T H1, SCALE_T H2);
|
||||
|
@ -228,13 +415,13 @@ void m_q_mulvec(const INT_T* const mat, const INT_T* const vec, ITER_T nrows,
|
|||
* @param[in] mat_values pointer to input matrix which stores the non-zero values of matrix A
|
||||
* @param[in] vec pointer to the input vector
|
||||
* @param[in] ndims dimension of the multiplication vector
|
||||
* @param[out] ret pointer to the output matrix
|
||||
* @param[out] ret pointer to the output vector
|
||||
* @param[in] scmat scale factor of the input matrix
|
||||
* @param[in] scvec scale factor of the input vector
|
||||
* @param[in] scret scale factor of the output matrix
|
||||
@return none
|
||||
* @example mat = {{10, 20, 30, 40, 50, 60, 70, 0, 0, 0, 0, 0, 0, 0},
|
||||
* {0, 80, 0, 90, 0, 100, 0, 110, 0, 120, 0, 130, 0, 140}}
|
||||
* @return none
|
||||
* @example mat = { {10, 20, 30, 40, 50, 60, 70, 0, 0, 0, 0, 0, 0, 0},
|
||||
* {0, 80, 0, 90, 0, 100, 0, 110, 0, 120, 0, 130, 0, 140} }
|
||||
* col_indices = {1, 2, 3, 4, 5, 6, 7, 0, 2, 4, 6, 8, 10, 12, 14, 0}
|
||||
* mat_values = {10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 110, 120, 130, 140}
|
||||
* vec = {1, 2}
|
||||
|
@ -303,9 +490,9 @@ void t_q_add_vec(const INT_T* const mat, const INT_T* const vec,
|
|||
* nrows = 4
|
||||
* ncols = 2
|
||||
* ret = { { {-1451, 2141}, {-386, 1132} },
|
||||
{ {2281, 1120}, {593, 3882} } },
|
||||
{ { {1334, 20}, {1282, 2127} },
|
||||
{ {2375, 1684}, {2561, 4075} } }
|
||||
* { {2281, 1120}, {593, 3882} } },
|
||||
* { { {1334, 20}, {1282, 2127} },
|
||||
* { {2375, 1684}, {2561, 4075} } }
|
||||
* scmat = 1
|
||||
* scvec = 2
|
||||
* scret = 2
|
||||
|
@ -315,4 +502,80 @@ void t_q_sub_vec(const INT_T* const ten, const INT_T* const vec,
|
|||
ITER_T nchannels, INT_T* const ret, SCALE_T scmat,
|
||||
SCALE_T scvec, SCALE_T scret);
|
||||
|
||||
/**
|
||||
* @brief Computes the maxpool operation on the input tensor with the given parameters.
|
||||
* @param[in] input pointer to the tensor on which max-pooling is to be performed
|
||||
* @param[out] output pointer to the output tensor
|
||||
* @param[in] N number of batches of the input tensor
|
||||
* @param[in] H number of rows of the input tensor
|
||||
* @param[in] W number of columns of the input tensor
|
||||
* @param[in] CIn number of channels of the input tensor
|
||||
* @param[in] HF number of rows of the pooling filter
|
||||
* @param[in] WF number of columns of the pooling filter
|
||||
* @param[in] CF number of channels of the pooling filter
|
||||
* @param[in] COut number of channels of the output tensor
|
||||
* @param[in] HOut number of rows of the output tensor
|
||||
* @param[in] WOut number of columns of the output tensor
|
||||
* @param[in] G number of groups of pooling filters
|
||||
* @param[in] HPadU padding over the top row
|
||||
* @param[in] HPadD padding under the bottom row
|
||||
* @param[in] WPadL padding before the leftmost column
|
||||
* @param[in] WPadR padding after the rightmost column
|
||||
* @param[in] HStride stride of the pooling filter along the rows, used for moving the receptive field horizontally within the larger image
|
||||
* @param[in] WStride stride of the pooling filter along the columns, used for moving the receptive field vertically within the larger image
|
||||
* @param[in] HDilation dilation of the convolution filter along the rows (number of skipped input rows between two consecutive filter rows is HDilation - 1)
|
||||
* @param[in] WDilation dilation of the convolution filter along the columns (number of skipped input columns between two consecutive filter rows is WDilation - 1)
|
||||
* @param[in] scinput scale of the input tensor
|
||||
* @param[in] scoutput scale of the output tensor
|
||||
* @return none
|
||||
* @example Please refer the test-case: test_quantized_maxpool() in file: c_reference/tests/utils/test_quantized_utils.c
|
||||
*/
|
||||
void q_maxpool(const INT_T* const input, INT_T* const output, ITER_T N,
|
||||
ITER_T H, ITER_T W, ITER_T CIn, ITER_T HF, ITER_T WF, ITER_T CF,
|
||||
ITER_T COut, ITER_T HOut, ITER_T WOut, ITER_T G, S_ITER_T HPadU,
|
||||
S_ITER_T HPadD, S_ITER_T WPadL, S_ITER_T WPadR, ITER_T HStride,
|
||||
ITER_T WStride, ITER_T HDilation, ITER_T WDilation,
|
||||
SCALE_T scinput, SCALE_T scoutput);
|
||||
|
||||
/**
|
||||
* @brief Computes the maxpool operation on the input tensor with the given parameters.
|
||||
* @param[in] input pointer to the tensor on which convolution is to be performed
|
||||
* @param[in] filter pointer to the convolutional filter tensor
|
||||
* @param[out] output pointer to the output tensor
|
||||
* @param[in] treesumBuffer pointer to the buffer for computing TreeSum accumulation
|
||||
* @param[in] N number of batches of the input tensor
|
||||
* @param[in] H number of rows of the input tensor
|
||||
* @param[in] W number of columns of the input tensor
|
||||
* @param[in] CIn number of channels of the input tensor
|
||||
* @param[in] HF number of rows of the convolutional filter
|
||||
* @param[in] WF number of columns of the convolutional filter
|
||||
* @param[in] CF number of channels of the convolutional filter
|
||||
* @param[in] COut number of channels of the output tensor
|
||||
* @param[in] HOut number of rows of the output tensor
|
||||
* @param[in] WOut number of columns of the output tensor
|
||||
* @param[in] G number of groups of convolutional filters
|
||||
* @param[in] HPadU padding over the top row
|
||||
* @param[in] HPadD padding under the bottom row
|
||||
* @param[in] WPadL padding before the leftmost column
|
||||
* @param[in] WPadR padding after the rightmost column
|
||||
* @param[in] HStride stride of the convolution filter along the rows, used for moving the receptive field horizontally within the larger image
|
||||
* @param[in] WStride stride of the convolution filter along the columns, used for moving the receptive field vertically within the larger image
|
||||
* @param[in] HDilation dilation of the convolution filter along the rows (number of skipped input rows between two consecutive filter rows is HDilation - 1)
|
||||
* @param[in] WDilation dilation of the convolution filter along the columns (number of skipped input columns between two consecutive filter rows is WDilation - 1)
|
||||
* @param[in] H1 depth parameter for division-by-two used in TreeSum
|
||||
* @param[in] H2 depth parameter for direct sum used in TreeSum
|
||||
* @param[in] scinput scale of the input tensor
|
||||
* @param[in] scoutput scale of the output tensor
|
||||
* @return none
|
||||
* @example Please refer the test-case: test_quantized_convolution() in file: c_reference/tests/utils/test_quantized_utils.c
|
||||
*/
|
||||
void q_convolution(const INT_T* const input, const INT_T* const filter,
|
||||
INT_T* const output, INTM_T* const treesumBuffer, ITER_T N,
|
||||
ITER_T H, ITER_T W, ITER_T CIn, ITER_T HF, ITER_T WF,
|
||||
ITER_T CF, ITER_T COut, ITER_T HOut, ITER_T WOut, ITER_T G,
|
||||
S_ITER_T HPadU, S_ITER_T HPadD, S_ITER_T WPadL,
|
||||
S_ITER_T WPadR, ITER_T HStride, ITER_T WStride,
|
||||
ITER_T HDilation, ITER_T WDilation, SCALE_T H1, SCALE_T H2,
|
||||
SCALE_T scinput, SCALE_T scoutput);
|
||||
|
||||
#endif
|
||||
|
|
|
@ -18,15 +18,15 @@ void q_mbconv_block(const INT_T* const input, const INT_T* const filter1,
|
|||
L_SCALE_T shlB2, L_SCALE_T shlX2, L_SCALE_T shlU3, L_SCALE_T shlB3,
|
||||
L_SCALE_T shlW3) {
|
||||
|
||||
S_ITER_T HOffsetL = ((S_ITER_T)(HF >> 1)) - HPadU;
|
||||
S_ITER_T WOffsetL = ((S_ITER_T)(WF >> 1)) - WPadL;
|
||||
S_ITER_T HOffsetL = ((S_ITER_T)((HF - 1) >> 1)) - HPadU;
|
||||
S_ITER_T WOffsetL = ((S_ITER_T)((WF - 1) >> 1)) - WPadL;
|
||||
S_ITER_T HOffsetR = ((S_ITER_T)(HF >> 1)) - HPadD;
|
||||
S_ITER_T WOffsetR = ((S_ITER_T)(WF >> 1)) - WPadR;
|
||||
|
||||
for (ITER_T n = 0; n < N; n++) {
|
||||
ITER_T margin = 0, nstart = 0;
|
||||
if (HOffsetL + ((S_ITER_T)(HF >> 1) + 1) - ((S_ITER_T)HStride) > 0) {
|
||||
margin = (ITER_T)(HOffsetL + ((S_ITER_T)(HF >> 1) + 1) - ((S_ITER_T)HStride));
|
||||
if ((S_ITER_T)HF - HPadU - (S_ITER_T)HStride > 0) {
|
||||
margin = (ITER_T)((S_ITER_T)HF - HPadU - (S_ITER_T)HStride);
|
||||
}
|
||||
if (HPadU < 0) {
|
||||
// nstart will always be zero unless HPadU is negative.
|
||||
|
@ -66,7 +66,6 @@ void q_mbconv_block(const INT_T* const input, const INT_T* const filter1,
|
|||
for (ITER_T k = 0; k < CTemp; k++) {
|
||||
ITER_T iRed = (i + margin + hout * HStride) % HF;
|
||||
ITER_T iFull = i + margin + hout * HStride;
|
||||
convBuffer1[iRed * W * CTemp + j * CTemp + k] = 0;
|
||||
for (ITER_T l = 0; l < CIn; l++) {
|
||||
if (iFull < H) {
|
||||
treesumBuffer[l] = ((INTM_T)input[n * H * W * CIn + iFull * W * CIn + j * CIn + l]) *
|
||||
|
@ -98,13 +97,13 @@ void q_mbconv_block(const INT_T* const input, const INT_T* const filter1,
|
|||
for (S_ITER_T w = WOffsetL; w < ((S_ITER_T)W) - WOffsetR; wout++, w += ((S_ITER_T)WStride)) {
|
||||
for (ITER_T g = 0; g < CTemp; g++) {
|
||||
ITER_T counter = 0;
|
||||
for (S_ITER_T hf = -(HF >> 1); hf <= (HF >> 1); hf++) {
|
||||
for (S_ITER_T wf = -(WF >> 1); wf <= (WF >> 1); wf++) {
|
||||
for (S_ITER_T hf = -((HF - 1) >> 1); hf <= (HF >> 1); hf++) {
|
||||
for (S_ITER_T wf = -((WF - 1) >> 1); wf <= (WF >> 1); wf++) {
|
||||
if (((h + hf) < 0) || ((h + hf) >= (S_ITER_T)H) || ((w + wf) < 0) || ((w + wf) >= (S_ITER_T)W)) {
|
||||
treesumBuffer[counter] = 0;
|
||||
} else {
|
||||
treesumBuffer[counter] = ((INTM_T)convBuffer1[(((ITER_T)(h + hf)) % HF) * W * CTemp + ((ITER_T)(w + wf)) * CTemp + g]) *
|
||||
((INTM_T)filter2[g * HF * WF + ((ITER_T)(hf + (HF >> 1))) * WF + ((ITER_T)(wf + (WF >> 1)))]);
|
||||
((INTM_T)filter2[g * HF * WF + ((ITER_T)(hf + ((HF - 1) >> 1))) * WF + ((ITER_T)(wf + ((WF - 1) >> 1)))]);
|
||||
}
|
||||
counter++;
|
||||
}
|
||||
|
|
|
@ -355,3 +355,135 @@ void t_q_sub_vec(const INT_T* const mat, const INT_T* const vec,
|
|||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
void q_maxpool(const INT_T* const input, INT_T* const output, ITER_T N,
|
||||
ITER_T H, ITER_T W, ITER_T CIn, ITER_T HF, ITER_T WF, ITER_T CF,
|
||||
ITER_T COut, ITER_T HOut, ITER_T WOut, ITER_T G, S_ITER_T HPadU,
|
||||
S_ITER_T HPadD, S_ITER_T WPadL, S_ITER_T WPadR, ITER_T HStride,
|
||||
ITER_T WStride, ITER_T HDilation, ITER_T WDilation,
|
||||
SCALE_T scinput, SCALE_T scoutput) {
|
||||
S_ITER_T HOffsetL = ((S_ITER_T)HDilation * (S_ITER_T)((HF - 1) >> 1)) - HPadU;
|
||||
S_ITER_T WOffsetL = ((S_ITER_T)WDilation * (S_ITER_T)((WF - 1) >> 1)) - WPadL;
|
||||
S_ITER_T HOffsetR = ((S_ITER_T)HDilation * (S_ITER_T)(HF >> 1)) - HPadD;
|
||||
S_ITER_T WOffsetR = ((S_ITER_T)WDilation * (S_ITER_T)(WF >> 1)) - WPadR;
|
||||
|
||||
ITER_T HOffsetIn = W * CIn;
|
||||
ITER_T NOffsetIn = H * HOffsetIn;
|
||||
ITER_T WOffsetOut = (COut * G);
|
||||
ITER_T HOffsetOut = WOut * WOffsetOut;
|
||||
ITER_T NOffsetOut = HOut * HOffsetOut;
|
||||
for (ITER_T n = 0; n < N; n++) {
|
||||
ITER_T hout = 0;
|
||||
ITER_T NIndexIn = n * NOffsetIn;
|
||||
ITER_T NIndexOut = n * NOffsetOut;
|
||||
for (S_ITER_T h = HOffsetL; h < (S_ITER_T)H - HOffsetR; h += (S_ITER_T)HStride, hout++) {
|
||||
ITER_T wout = 0;
|
||||
ITER_T HIndexOut = hout * HOffsetOut;
|
||||
for (S_ITER_T w = WOffsetL; w < (S_ITER_T)W - WOffsetR; w += (S_ITER_T)WStride, wout++) {
|
||||
ITER_T WIndexOut = wout * WOffsetOut;
|
||||
for (ITER_T g = 0; g < G; g++) {
|
||||
ITER_T CIndexIn = g * CF;
|
||||
ITER_T CIndexOut = g * COut;
|
||||
for (ITER_T c = 0; c < COut; c++) {
|
||||
|
||||
INT_T max = INT_TMIN;
|
||||
for (S_ITER_T hf = -((HF - 1) >> 1); hf <= (HF >> 1); hf++) {
|
||||
S_ITER_T hoffset = h + ((S_ITER_T)HDilation * hf);
|
||||
ITER_T HIndexIn = ((ITER_T)hoffset) * HOffsetIn;
|
||||
for (S_ITER_T wf = -((WF - 1) >> 1); wf <= (WF >> 1); wf++) {
|
||||
S_ITER_T woffset = w + ((S_ITER_T)WDilation * wf);
|
||||
ITER_T WIndexIn = ((ITER_T)woffset) * CIn;
|
||||
for (ITER_T cf = 0; cf < CF; cf++) {
|
||||
if ((hoffset < 0) || (hoffset >= (S_ITER_T)H) || (woffset < 0) || (woffset >= (S_ITER_T)W)) {
|
||||
if (max < 0) {
|
||||
max = 0;
|
||||
}
|
||||
} else {
|
||||
INT_T a = input[NIndexIn + HIndexIn + WIndexIn + (cf + CIndexIn)];
|
||||
if (max < a) {
|
||||
max = a;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef SHIFT
|
||||
output[NIndexOut + HIndexOut + WIndexOut + (c + CIndexOut)] = (max >> (scinput + scoutput));
|
||||
#else
|
||||
output[NIndexOut + HIndexOut + WIndexOut + (c + CIndexOut)] = ((max / scinput) / scoutput);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void q_convolution(const INT_T* const input, const INT_T* const filter,
|
||||
INT_T* const output, INTM_T* const treesumBuffer, ITER_T N,
|
||||
ITER_T H, ITER_T W, ITER_T CIn, ITER_T HF, ITER_T WF,
|
||||
ITER_T CF, ITER_T COut, ITER_T HOut, ITER_T WOut, ITER_T G,
|
||||
S_ITER_T HPadU, S_ITER_T HPadD, S_ITER_T WPadL,
|
||||
S_ITER_T WPadR, ITER_T HStride, ITER_T WStride,
|
||||
ITER_T HDilation, ITER_T WDilation, SCALE_T H1, SCALE_T H2,
|
||||
SCALE_T scinput, SCALE_T scoutput) {
|
||||
S_ITER_T HOffsetL = ((S_ITER_T)HDilation * (S_ITER_T)((HF - 1) >> 1)) - HPadU;
|
||||
S_ITER_T WOffsetL = ((S_ITER_T)WDilation * (S_ITER_T)((WF - 1) >> 1)) - WPadL;
|
||||
S_ITER_T HOffsetR = ((S_ITER_T)HDilation * (S_ITER_T)(HF >> 1)) - HPadD;
|
||||
S_ITER_T WOffsetR = ((S_ITER_T)WDilation * (S_ITER_T)(WF >> 1)) - WPadR;
|
||||
|
||||
ITER_T HOffsetIn = W * CIn;
|
||||
ITER_T NOffsetIn = H * HOffsetIn;
|
||||
ITER_T WOffsetF = CF * COut;
|
||||
ITER_T HOffsetF = WF * WOffsetF;
|
||||
ITER_T WOffsetOut = (COut * G);
|
||||
ITER_T HOffsetOut = WOut * WOffsetOut;
|
||||
ITER_T NOffsetOut = HOut * HOffsetOut;
|
||||
for (ITER_T n = 0; n < N; n++) {
|
||||
ITER_T hout = 0;
|
||||
ITER_T NIndexIn = n * NOffsetIn;
|
||||
ITER_T NIndexOut = n * NOffsetOut;
|
||||
for (S_ITER_T h = HOffsetL; h < (S_ITER_T)H - HOffsetR; h += (S_ITER_T)HStride, hout++) {
|
||||
ITER_T wout = 0;
|
||||
ITER_T HIndexOut = hout * HOffsetOut;
|
||||
for (S_ITER_T w = WOffsetL; w < (S_ITER_T)W - WOffsetR; w += (S_ITER_T)WStride, wout++) {
|
||||
ITER_T WIndexOut = wout * WOffsetOut;
|
||||
for (ITER_T g = 0; g < G; g++) {
|
||||
ITER_T CIndexIn = g * CF;
|
||||
ITER_T CIndexOut = g * COut;
|
||||
for (ITER_T c = 0; c < COut; c++) {
|
||||
|
||||
ITER_T counter = 0;
|
||||
for (S_ITER_T hf = -((HF - 1) >> 1); hf <= (HF >> 1); hf++) {
|
||||
S_ITER_T hoffset = h + ((S_ITER_T)HDilation * hf);
|
||||
ITER_T HIndexIn = ((ITER_T)hoffset) * HOffsetIn;
|
||||
ITER_T HIndexF = ((ITER_T)(hf + ((HF - 1) >> 1))) * HOffsetF;
|
||||
for (S_ITER_T wf = -((WF - 1) >> 1); wf <= (WF >> 1); wf++) {
|
||||
S_ITER_T woffset = w + ((S_ITER_T)WDilation * wf);
|
||||
ITER_T WIndexIn = ((ITER_T)woffset) * CIn;
|
||||
ITER_T WIndexF = ((ITER_T)(wf + ((WF - 1) >> 1))) * WOffsetF;
|
||||
for (ITER_T cf = 0; cf < CF; cf++) {
|
||||
if ((hoffset < 0) || (hoffset >= (S_ITER_T)H) || (woffset < 0) || (woffset >= (S_ITER_T)W)) {
|
||||
treesumBuffer[counter] = 0;
|
||||
} else {
|
||||
treesumBuffer[counter] = ((INTM_T)input[NIndexIn + HIndexIn + WIndexIn + (cf + CIndexIn)]) *
|
||||
((INTM_T)filter[HIndexF + WIndexF + (c + cf * COut)]);
|
||||
}
|
||||
counter++;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
v_q_treesum(&treesumBuffer[0], HF * WF * CF, H1, H2);
|
||||
#ifdef SHIFT
|
||||
output[NIndexOut + HIndexOut + WIndexOut + (c + CIndexOut)] = (treesumBuffer[0] >> (scinput + scoutput));
|
||||
#else
|
||||
output[NIndexOut + HIndexOut + WIndexOut + (c + CIndexOut)] = ((treesumBuffer[0] / scinput) / scoutput);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -402,6 +402,199 @@ int test_t_q_sub_vec() {
|
|||
return check_output(pred, expected, 16);
|
||||
}
|
||||
|
||||
// Test q_maxpool() function.
|
||||
int test_q_maxpool() {
|
||||
const INT_T qmat_A[2 * 2 * 2 * 2] = {11, 220,
|
||||
130, 40,
|
||||
|
||||
50, 60,
|
||||
66, 76,
|
||||
|
||||
|
||||
86, 910,
|
||||
411, 312,
|
||||
|
||||
513, 514,
|
||||
715, 716};
|
||||
const INT_T qmat_B[2 * 2 * 2 * 2] = {100, 992,
|
||||
15, 26,
|
||||
|
||||
27, 8,
|
||||
3, 4,
|
||||
|
||||
|
||||
5, 2,
|
||||
2, 2,
|
||||
|
||||
7, 8,
|
||||
29, 140};
|
||||
const INT_T expected_A[2 * 1 * 1 * 2] = {32, 55,
|
||||
|
||||
|
||||
178, 227};
|
||||
const INT_T expected_B[2 * 3 * 3 * 2] = {100, 992,
|
||||
100, 992,
|
||||
15, 26,
|
||||
|
||||
100, 992,
|
||||
100, 992,
|
||||
15, 26,
|
||||
|
||||
27, 8,
|
||||
27, 8,
|
||||
3, 4,
|
||||
|
||||
|
||||
5, 2,
|
||||
5, 2,
|
||||
2, 2,
|
||||
|
||||
7, 8,
|
||||
29, 140,
|
||||
29, 140,
|
||||
|
||||
7, 8,
|
||||
29, 140,
|
||||
29, 140};
|
||||
const INT_T expected_C[2 * 2 * 2 * 2] = {100, 992,
|
||||
100, 992,
|
||||
|
||||
100, 992,
|
||||
100, 992,
|
||||
|
||||
|
||||
29, 140,
|
||||
29, 140,
|
||||
|
||||
29, 140,
|
||||
29, 140};
|
||||
const INT_T expected_D[2 * 3 * 3 * 2] = {16, 19,
|
||||
0, 0,
|
||||
12, 15,
|
||||
|
||||
0, 0,
|
||||
0, 0,
|
||||
0, 0,
|
||||
|
||||
32, 10,
|
||||
0, 0,
|
||||
2, 55,
|
||||
|
||||
|
||||
178, 179,
|
||||
0, 0,
|
||||
128, 128,
|
||||
|
||||
0, 0,
|
||||
0, 0,
|
||||
0, 0,
|
||||
|
||||
102, 78,
|
||||
0, 0,
|
||||
21, 227};
|
||||
INT_T pred_A[2 * 1 * 1 * 2], pred_B[2 * 3 * 3 * 2], pred_C[2 * 2 * 2 * 2], pred_D[2 * 3 * 3 * 2];
|
||||
|
||||
#ifdef SHIFT
|
||||
q_maxpool(qmat_A, pred_A, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 2, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1);
|
||||
q_maxpool(qmat_B, pred_B, 2, 2, 2, 2, 2, 2, 1, 1, 3, 3, 2, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0);
|
||||
q_maxpool(qmat_B, pred_C, 2, 2, 2, 2, 3, 3, 1, 1, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0);
|
||||
q_maxpool(qmat_A, pred_D, 2, 2, 2, 2, 2, 2, 1, 1, 3, 3, 2, 2, 2, 2, 2, 1, 1, 3, 3, 1, 1);
|
||||
#else
|
||||
q_maxpool(qmat_A, pred_A, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 2, 0, 0, 0, 0, 1, 1, 1, 1, 2, 2);
|
||||
q_maxpool(qmat_B, pred_B, 2, 2, 2, 2, 2, 2, 1, 1, 3, 3, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1);
|
||||
q_maxpool(qmat_B, pred_C, 2, 2, 2, 2, 3, 3, 1, 1, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1);
|
||||
q_maxpool(qmat_A, pred_D, 2, 2, 2, 2, 2, 2, 1, 1, 3, 3, 2, 2, 2, 2, 2, 1, 1, 3, 3, 2, 2);
|
||||
#endif
|
||||
return (check_output(pred_A, expected_A, 4) || check_output(pred_B, expected_B, 36) || check_output(pred_C, expected_C, 8) || check_output(pred_D, expected_D, 36));
|
||||
}
|
||||
|
||||
// Test q_convolution() function.
|
||||
int test_q_convolution() {
|
||||
const INT_T qmat_A[2 * 2 * 2 * 2] = {11, 220,
|
||||
130, 40,
|
||||
|
||||
50, 60,
|
||||
66, 76,
|
||||
|
||||
|
||||
86, 910,
|
||||
411, 312,
|
||||
|
||||
513, 514,
|
||||
715, 716};
|
||||
//Convolution Filters
|
||||
const INT_T qmat_B[2 * 2 * 1 * 1] = {0, 1,
|
||||
1, 0};
|
||||
const INT_T qmat_C[2 * 2 * 2 * 1] = {0, 1,
|
||||
1, 0,
|
||||
|
||||
1, 0,
|
||||
0, 1};
|
||||
const INT_T qmat_D[3 * 3 * 1 * 1] = {0, 0, 1,
|
||||
0, 1, 0,
|
||||
1, 0, 0};
|
||||
|
||||
const INT_T expected_A[2 * 1 * 1 * 2] = {44, 25,
|
||||
|
||||
|
||||
230, 206};
|
||||
const INT_T expected_B[2 * 1 * 1 * 1] = {58,
|
||||
|
||||
|
||||
317};
|
||||
const INT_T expected_C[2 * 2 * 2 * 2] = {1, 27,
|
||||
22, 12,
|
||||
|
||||
22, 12,
|
||||
8, 9,
|
||||
|
||||
|
||||
10, 113,
|
||||
115, 103,
|
||||
|
||||
115, 103,
|
||||
89, 89};
|
||||
const INT_T expected_D[2 * 3 * 3 * 2] = {0, 0,
|
||||
0, 0,
|
||||
1, 1,
|
||||
|
||||
0, 0,
|
||||
0, 0,
|
||||
0, 0,
|
||||
|
||||
4, 1,
|
||||
0, 0,
|
||||
0, 0,
|
||||
|
||||
|
||||
0, 0,
|
||||
0, 0,
|
||||
16, 16,
|
||||
|
||||
0, 0,
|
||||
0, 0,
|
||||
0, 0,
|
||||
|
||||
12, 9,
|
||||
0, 0,
|
||||
0, 0};
|
||||
INT_T pred_A[2 * 1 * 1 * 2], pred_B[2 * 1 * 1 * 1], pred_C[2 * 2 * 2 * 2], pred_D[2 * 3 * 3 * 2];
|
||||
INTM_T temp[16];
|
||||
|
||||
#ifdef SHIFT
|
||||
q_convolution(qmat_A, qmat_B, pred_A, temp, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 2, 0, 0, 0, 0, 1, 1, 1, 1, 2, 0, 0, 0);
|
||||
q_convolution(qmat_A, qmat_C, pred_B, temp, 2, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 3, 0, 0, 0);
|
||||
q_convolution(qmat_A, qmat_D, pred_C, temp, 2, 2, 2, 2, 3, 3, 1, 1, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 3, 0, 0, 0);
|
||||
q_convolution(qmat_A, qmat_B, pred_D, temp, 2, 2, 2, 2, 2, 2, 1, 1, 3, 3, 2, 2, 2, 2, 2, 1, 1, 3, 3, 3, 0, 1, 1);
|
||||
#else
|
||||
q_convolution(qmat_A, qmat_B, pred_A, temp, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 2, 0, 0, 0, 0, 1, 1, 1, 1, 2, 0, 1, 1);
|
||||
q_convolution(qmat_A, qmat_C, pred_B, temp, 2, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 3, 0, 1, 1);
|
||||
q_convolution(qmat_A, qmat_D, pred_C, temp, 2, 2, 2, 2, 3, 3, 1, 1, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 3, 0, 1, 1);
|
||||
q_convolution(qmat_A, qmat_B, pred_D, temp, 2, 2, 2, 2, 2, 2, 1, 1, 3, 3, 2, 2, 2, 2, 2, 1, 1, 3, 3, 3, 0, 2, 2);
|
||||
#endif
|
||||
return (check_output(pred_A, expected_A, 4) || check_output(pred_B, expected_B, 2) || check_output(pred_C, expected_C, 16) || check_output(pred_D, expected_D, 36));
|
||||
}
|
||||
|
||||
int main() {
|
||||
if (test_v_q_treesum()) {
|
||||
printf("Test Failure for v_q_treesum()!\n");
|
||||
|
@ -449,6 +642,10 @@ int main() {
|
|||
printf("Test Failure for t_q_add_vec()!\n");
|
||||
} else if (test_t_q_sub_vec()) {
|
||||
printf("Test Failure for t_q_sub_vec()!\n");
|
||||
} else if (test_q_maxpool()) {
|
||||
printf("Test Failure for q_maxpool()!\n");
|
||||
} else if (test_q_convolution()) {
|
||||
printf("Test Failure for q_convolution()!\n");
|
||||
} else {
|
||||
printf("All Tests Passed!\n");
|
||||
return 0;
|
||||
|
|
Загрузка…
Ссылка в новой задаче