Fix sparse matrix MultiplyAndAdd bug where old values may not be kept

KeDengMS 2016-12-06 23:49:01 -08:00
Parent e7edfa72e8
Commit c12979b245
9 changed files with 301 additions and 123 deletions
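Background: Matrix::MultiplyAndWeightedAdd computes c = alpha * op(a) * op(b) + beta * c. When c is a sparse-block-column (SBC) matrix and beta == 1, the blocks already stored in c must be kept and accumulated into; before this change the sparse path reset c unconditionally, so values written by earlier calls were silently dropped. Below is a minimal sketch of the required accumulation semantics, using toy types and hypothetical names rather than the CNTK classes:

#include <cassert>
#include <map>
#include <vector>

// Toy stand-in for a sparse-block-column matrix: one dense column block
// per column that has ever received a value.
struct ToySBC
{
    size_t rows;
    std::map<size_t, std::vector<double>> blocks; // column id -> dense column block
};

// c(:, col) += alpha * colData: create a zeroed block on first touch,
// but never clear blocks owned by other columns (that was the bug).
void AddColumn(ToySBC& c, size_t col, double alpha, const std::vector<double>& colData)
{
    auto& block = c.blocks.try_emplace(col, std::vector<double>(c.rows, 0.0)).first->second;
    for (size_t r = 0; r < c.rows; r++)
        block[r] += alpha * colData[r];
}

int main()
{
    ToySBC c{2, {}};
    AddColumn(c, 0, 1.0, {1.0, 2.0}); // first call touches column 0
    AddColumn(c, 3, 1.0, {5.0, 6.0}); // second call touches column 3
    assert(c.blocks.at(0)[1] == 2.0); // column 0 must survive the second call
    return 0;
}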

View file

@@ -244,7 +244,7 @@ void CPUSparseMatrix<ElemType>::SetValue(const CPUSparseMatrix<ElemType>& v)
 {
     SetFormat(v.GetFormat());
-    RequireSizeAndAllocate(v.GetNumRows(), v.GetNumCols(), v.NzSize());
+    RequireSizeAndAllocate(v.GetNumRows(), v.GetNumCols(), v.NzCount() ); // TODO: rename to *Bytes/*Count instead of vague *Size if possible
     let nz = v.NzCount();

     auto matrixFormat = v.GetFormat();
@@ -262,7 +262,8 @@ void CPUSparseMatrix<ElemType>::SetValue(const CPUSparseMatrix<ElemType>& v)
         }
         else
         {
-            memcpy(GetBlockIds(), v.GetBlockIds(), v.GetBlockSize());
+            memcpy(GetBlockIds(), v.GetBlockIds(), v.GetBlockSize() * sizeof(size_t)); // TODO: change block id from size_t to CPUSPARSE_INDEX_TYPE, and rename BlockSize to BlockCount
+            SetBlockSize(v.GetBlockSize());
         }
     }

     if (v.m_sliceViewOffset > 0)
@@ -644,6 +645,22 @@ void CPUSparseMatrix<ElemType>::SetMatrixFromCSCFormat(const CPUSPARSE_INDEX_TYP
     memcpy(NzValues(), h_Val, sizeof(ElemType)*nz);
 }

+#if 0 // add it back with test
+template <class ElemType>
+void CPUSparseMatrix<ElemType>::SetMatrixFromSBCFormat(const size_t* blockIds, const ElemType* val, const size_t numBlocks, const size_t numRows, const size_t numCols)
+{
+    if (!OwnBuffer())
+        LogicError("Cannot modify since the buffer is managed externally.");
+
+    SetFormat(matrixFormatSparseBlockCol);
+    Resize(numRows, numCols, numBlocks * numRows);
+    SetBlockSize(numBlocks);
+
+    memcpy(GetBlockIds(), blockIds, sizeof(size_t)*(numBlocks));
+    memcpy(Data(), val, sizeof(ElemType)*numBlocks*numRows);
+}
+#endif
+
 template <class ElemType>
 ElemType* CPUSparseMatrix<ElemType>::Data() const
 {
@@ -959,8 +976,6 @@ void CPUSparseMatrix<ElemType>::MultiplyAndAdd(ElemType alpha, const CPUMatrix<E
         InvalidArgument("CPUSparseMatrix::MultiplyAndAdd: The inner dimensions of a (= %lu) and b (= %lu) don't match.", k, l);
     }

-    c.Reset();
-
     if (!transposeA && !transposeB)
     {
         NOT_IMPLEMENTED;
@@ -972,50 +987,56 @@ void CPUSparseMatrix<ElemType>::MultiplyAndAdd(ElemType alpha, const CPUMatrix<E
         // allocate enough memory
         c.SetFormat(matrixFormatSparseBlockCol);
-        c.RequireSizeAndAllocate(m, n, m * min(n, rhs.NzCount()), true, false);
+        size_t blockSizePrev = c.GetBlockSize();
-        map<size_t, size_t> w2Id;
-        for (size_t j = 0; j < rhs.GetNumCols(); j++)
-        { // j ranges over batches
-            size_t start = rhs.SecondaryIndexLocation()[j];
-            size_t end = rhs.SecondaryIndexLocation()[j + 1];
+        if (blockSizePrev == 0)
+        {
+            c.RequireSizeAndAllocate(m, n, 0, true); // allocate for blockIds
+        }
+
+        map<size_t, size_t> col2BlockId;
+        for (size_t blockId = 0; blockId < blockSizePrev; blockId++)
+        {
+            col2BlockId[c.GetBlockIds()[blockId]] = blockId;
+        }
+
+        size_t blockSizeCurr = blockSizePrev;
+        for (size_t rhsNz = 0; rhsNz < rhs.NzCount(); rhsNz++)
+        {
+            size_t resultCol = rhs.MajorIndexLocation()[rhsNz];
+            if (col2BlockId.find(resultCol) == col2BlockId.end())
+            {
+                col2BlockId[resultCol] = blockSizeCurr;
+                c.GetBlockIds()[blockSizeCurr] = resultCol;
+                blockSizeCurr ++;
+            }
+        }
+
+        if (blockSizeCurr > blockSizePrev)
+        {
+            c.RequireSizeAndAllocate(m, n, m * blockSizeCurr, true, true);
+            c.SetBlockSize(blockSizeCurr);
+            memset(c.Data() + m * blockSizePrev, 0, sizeof(ElemType) * m * (blockSizeCurr - blockSizePrev));
+        }
+
+        for (size_t rhsCol = 0; rhsCol < rhs.GetNumCols(); rhsCol++)
+        {
+            size_t start = rhs.SecondaryIndexLocation()[rhsCol];
+            size_t end = rhs.SecondaryIndexLocation()[rhsCol + 1];

             for (size_t p = start; p < end; p++)
             {
-                size_t i = rhs.MajorIndexLocation()[p]; // i ranges over words
-                ElemType val = rhs.Buffer()[p]; // 1 for(i, j)
+                size_t rhsRow = rhs.MajorIndexLocation()[p];
+                ElemType val = rhs.Buffer()[p];

-                bool first = true;
-                if (w2Id.find(i) == w2Id.end())
+                ElemType* results = c.Buffer() + col2BlockId[rhsRow] * m;
+#pragma omp parallel for
+                for (int lhsRow = 0; lhsRow < (int)m; lhsRow++)
                 {
-                    size_t id = w2Id.size();
-                    w2Id[i] = id;
-                    c.GetBlockIds()[c.GetBlockSize()] = i;
-                    c.SetBlockSize(c.GetBlockSize() + 1);
-                }
-                else
-                {
-                    first = false;
-                }
-                size_t pos = w2Id[i] * lhs.GetNumRows();
-                for (size_t h = 0; h < lhs.GetNumRows(); h++)
-                { // h range over hidden layer
-                    if (first == true)
-                    {
-                        c.Buffer()[pos] = alpha * lhs(h, j) * val;
-                    }
-                    else
-                    {
-                        c.Buffer()[pos] += alpha * lhs(h, j) * val;
-                    }
-                    pos++;
+                    results[lhsRow] += alpha * lhs((size_t)lhsRow, rhsCol) * val;
                 }
             }
         }
-
-        if (c.GetBlockSize() * m > c.GetSizeAllocated())
-        {
-            LogicError("Sparse matrix is unexpectedly out of range.");
-        }
     }
     else if (transposeA && !transposeB)
     {
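The fixed CPU path above is append-only: blocks already present in c keep their ids and data, result columns that gain values for the first time get blocks appended after the existing ones, and only that newly appended tail is zero-initialized before the += accumulation. A simplified host-side sketch of the allocation step, with assumed standard containers instead of the CNTK buffers:

#include <map>
#include <vector>

// Append blocks for newly touched columns; never disturb existing data.
void AppendNewBlocks(std::vector<double>& data,             // m values per block
                     std::vector<size_t>& blockIds,         // block id -> column id
                     std::map<size_t, size_t>& col2BlockId, // column id -> block id
                     const std::vector<size_t>& touchedCols,
                     size_t m)
{
    size_t blockSizeCurr = blockIds.size();
    for (size_t col : touchedCols)
    {
        if (col2BlockId.emplace(col, blockSizeCurr).second) // true if col was new
        {
            blockIds.push_back(col);
            blockSizeCurr++;
        }
    }
    // grow the value buffer; resize() zero-fills only the appended tail
    data.resize(m * blockSizeCurr, 0.0);
}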

View file

@@ -201,10 +201,19 @@ public:
             return 0;
         }
-        else
+        else if (GetFormat() == MatrixFormat::matrixFormatSparseBlockCol)
         {
-            NOT_IMPLEMENTED;
+            for (size_t blockId = 0; blockId < GetBlockSize(); blockId++)
+            {
+                size_t blockCol = GetBlockIds()[blockId] - GetBlockIdShift();
+                if (blockCol == col)
+                {
+                    return ((ElemType*)Buffer())[blockId * GetNumRows() + row];
+                }
+            }
+            return 0;
         }
+
+        NOT_IMPLEMENTED;
     }

public:
@@ -258,6 +267,11 @@ public:
         BaseMatrix<ElemType>::SetBlockSize(newBlockSize);
     }

+    size_t GetBlockSize() const
+    {
+        return BaseMatrix<ElemType>::GetBlockSize();
+    }
+
     size_t* BlockIdsLocation() const
     {
         if ((GetFormat() != matrixFormatSparseBlockCol) && (GetFormat() != matrixFormatSparseBlockRow))
@@ -325,7 +339,6 @@ public:
     {
         return (GetFormat() & matrixFormatRowMajor) ? MajorIndexSize() : SecondaryIndexSize();
     } // actual number of bytes in use
 };

typedef CPUSparseMatrix<float> CPUSingleSparseMatrix;
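For reference, the sparse-block-column layout that the new operator() branch above scans: GetBlockIds()[blockId] names the column a block stores, each block holds a full dense column of GetNumRows() values, and element (row, col) therefore sits at blockId * GetNumRows() + row once the matching block is found. A worked illustration with hypothetical values:

// A 3x5 matrix with nonzeros only in columns 1 and 4, stored as SBC:
//
//   blockIds = { 1, 4 }          // block 0 holds column 1, block 1 holds column 4
//   data     = { a0, a1, a2,     // block 0: all 3 rows of column 1
//                b0, b1, b2 }    // block 1: all 3 rows of column 4
//
// Reading element (row = 2, col = 4): scan blockIds for 4 -> blockId = 1,
// then data[blockId * numRows + row] = data[1 * 3 + 2] = b2.
// The scan is O(block count) per access; the new CPU test reads elements this way.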

View file

@@ -3008,6 +3008,10 @@ __global__ void _reshape(
     newColumnIndex[newNumCols] = oldColumnIndex[oldNumCols]; // set end pointer
 }

+// special markers in BlockId2ColOrRow()/ColOrRow2BlockId()
+static const GPUSPARSE_INDEX_TYPE Id_NotAssigned = -1;
+static const GPUSPARSE_INDEX_TYPE Id_Pending = INT_MAX;
+
//called before _determineBlockIds and _denseMulSparseCSCTransposeToSparseBlockCol to determine which columns have values and
//what's the mapping from the column id in the resulted SparseBlockCol format to the column id in the dense format
//input: rowIndexes: the row indexes of the CSC sparse matrix to be multiplied with
@@ -3015,13 +3019,14 @@ __global__ void _reshape(
//nnz: number of nonzero value or the size of rowIndexes;
template <class ElemType>
__global__ void _findColsWithValues(
-    const GPUSPARSE_INDEX_TYPE* rowIndexes, GPUSPARSE_INDEX_TYPE* blockIds, const size_t nnz)
+    const GPUSPARSE_INDEX_TYPE* rowIndexes, GPUSPARSE_INDEX_TYPE* col2BlockIds, const size_t nnz)
 {
-    const size_t index = blockIdx.x * blockDim.x + threadIdx.x;
-    if (index >= nnz)
+    const size_t nzIndex = blockIdx.x * blockDim.x + threadIdx.x;
+    if (nzIndex >= nnz)
         return;

-    blockIds[rowIndexes[index]] = 1; // this row has value.
+    if (col2BlockIds[rowIndexes[nzIndex]] == Id_NotAssigned)
+        col2BlockIds[rowIndexes[nzIndex]] = Id_Pending; // this row has value.
 }

//called before _denseMulSparseCSCTransposeToSparseBlockCol and after _findColsWithValuesto determine which columns have values and
@@ -3030,26 +3035,21 @@ __global__ void _findColsWithValues(
//blockId2Col: the blockID to colum id mapping in the resulting matrix;
//col2BlockId: the col2BlockId to blockID mapping in the resulting matrix;
//numCols: number of columns in the resulting matrix or the size of blockIDs
-//blockSize: return the blockSize with values, *blockSize must be zero before passed in.
+//blockSize: return the blockSize with values
template <class ElemType>
__global__ void _determineBlockIds(
-    GPUSPARSE_INDEX_TYPE* blockId2Col, GPUSPARSE_INDEX_TYPE* col2BlockId, const size_t numCols, size_t* blockSize)
+    GPUSPARSE_INDEX_TYPE* blockId2Col, GPUSPARSE_INDEX_TYPE* col2BlockId, size_t numCols, size_t* blockSize)
 {
-    const size_t index = blockIdx.x * blockDim.x + threadIdx.x;
-    if (index >= numCols)
+    const size_t col = blockIdx.x * blockDim.x + threadIdx.x;
+    if (col >= numCols)
         return;

-    size_t blockIndex = numCols;
-    if (blockId2Col[index] > 0)
+    if (col2BlockId[col] == Id_Pending)
     {
-        blockIndex = atomicAdd((unsigned int*) blockSize, (unsigned int) 1);
-        col2BlockId[index] = blockIndex;
+        GPUSPARSE_INDEX_TYPE blockIndex = atomicAdd((unsigned int*)blockSize, (unsigned int)1);
+        col2BlockId[col] = blockIndex;
+        blockId2Col[blockIndex] = col;
     }
-
-    __syncthreads();
-
-    if (blockIndex < numCols)
-        blockId2Col[blockIndex] = index;
 }

// backward pass from hidden layer to feature weight
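Taken together, the rewritten kernels implement a three-state scheme per output column: Id_NotAssigned (never touched), Id_Pending (marked by _findColsWithValues because some nonzero of the CSC input lands in that column), and a concrete block index (assigned by _determineBlockIds through atomicAdd). Since the device-side counter is now seeded with the previous block count instead of zero, columns that already own a block keep their index and only pending columns receive new ones. A sequential host-side sketch of the two passes, with assumed types, for illustration only:

#include <climits>
#include <vector>

static const int Id_NotAssigned = -1;
static const int Id_Pending = INT_MAX;

void AssignBlockIds(const std::vector<int>& rowIndexes, // CSC row index per nonzero
                    std::vector<int>& col2BlockId,      // per column, three-state
                    std::vector<int>& blockId2Col,
                    size_t& blockSize)                  // seeded with previous block count
{
    // pass 1: _findColsWithValues -- mark untouched columns that now have values
    for (int row : rowIndexes)
        if (col2BlockId[row] == Id_NotAssigned)
            col2BlockId[row] = Id_Pending;

    // pass 2: _determineBlockIds -- hand each pending column the next block index
    // (the kernel does this with atomicAdd on a device-side counter)
    for (size_t col = 0; col < col2BlockId.size(); col++)
    {
        if (col2BlockId[col] == Id_Pending)
        {
            col2BlockId[col] = (int)blockSize;
            blockId2Col[blockSize] = (int)col;
            blockSize++;
        }
    }
}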

View file

@@ -947,6 +947,42 @@ void GPUSparseMatrix<ElemType>::SetMatrixFromCSCFormat(const CPUSPARSE_INDEX_TYP
     }
 }

+#if 0 // add it back with test
+template <class ElemType>
+void GPUSparseMatrix<ElemType>::SetMatrixFromSBCFormat(const size_t* blockIds, const ElemType* val, const size_t numBlocks, const size_t numRows, const size_t numCols)
+{
+    VerifyWritable(__FUNCTION__);
+
+    if (blockIds == nullptr || val == nullptr)
+        LogicError("SetMatrixFromSBCFormat: nullptr passed in.");
+
+    SetFormat(matrixFormatSparseBlockCol);
+    SetBlockSize(numBlocks);
+
+    if (numBlocks == 0) return; // ====>
+
+    size_t nz = numBlocks * numRows;
+    RequireSizeAndAllocate(numRows, numCols, nz, true, false);
+
+    static std::vector<GPUSPARSE_INDEX_TYPE> gpuBlockId2Col(numBlocks);
+    static std::vector<GPUSPARSE_INDEX_TYPE> gpuCol2BlockId(numCols);
+    std::fill(gpuBlockId2Col.begin(), gpuBlockId2Col.end(), Id_NotAssigned);
+    std::fill(gpuCol2BlockId.begin(), gpuCol2BlockId.end(), Id_NotAssigned);
+
+#pragma omp parallel for
+    for (int i = 0; i < numBlocks; ++i)
+    {
+        gpuBlockId2Col[i] = (GPUSPARSE_INDEX_TYPE)blockIds[i];
+        gpuCol2BlockId[blockIds[i]] = i;
+    }
+
+    CUDA_CALL(cudaMemcpy(Data(), val, nz * sizeof(ElemType), cudaMemcpyHostToDevice));
+    CUDA_CALL(cudaMemcpy(BlockId2ColOrRow(), &gpuBlockId2Col[0], numBlocks * sizeof(GPUSPARSE_INDEX_TYPE), cudaMemcpyHostToDevice));
+    CUDA_CALL(cudaMemcpy(ColOrRow2BlockId(), &gpuCol2BlockId[0], numBlocks * sizeof(GPUSPARSE_INDEX_TYPE), cudaMemcpyHostToDevice));
+}
+#endif
+
// this function will allocate memory while the caller needs to release it
template <class ElemType>
void GPUSparseMatrix<ElemType>::GetMatrixFromCSCFormat(GPUSPARSE_INDEX_TYPE*& h_CSCCol, GPUSPARSE_INDEX_TYPE*& h_Row, ElemType*& h_Val, size_t& numElemAllocated, size_t& nz, size_t& numRows, size_t& numCols) const
@@ -1223,68 +1259,51 @@ void GPUSparseMatrix<ElemType>::MultiplyAndAdd(ElemType alpha, const GPUMatrix<E
         // based on the size of m_nz in rhs and numCols in the resulted matrix we use different approaches
         size_t rhs_nz = rhs.NzCount();
-        if (n * 10 < GridDim::maxThreadsPerBlock * rhs_nz)
+        size_t blockSizePrev = c.GetBlockSize();
+        if (blockSizePrev == 0)
         {
-            c.RequireSizeAndAllocate(m, n, 1, true, false); // reserve memory for BlockId2ColOrRow() and ColOrRow2BlockId()
+            c.Resize(m, n, 0);
+            CUDA_CALL(cudaMemset(c.ColOrRow2BlockId(), Id_NotAssigned, sizeof(GPUSPARSE_INDEX_TYPE) * (n)));
+            CUDA_CALL(cudaMemset(c.BlockId2ColOrRow(), Id_NotAssigned, sizeof(GPUSPARSE_INDEX_TYPE) * (n)));
+        }

-            size_t* blockSize = TracingGPUMemoryAllocator::Allocate<size_t>(lhs.GetComputeDeviceId(), 1);
-            CUDA_CALL(cudaMemset(blockSize, 0, sizeof(size_t)));
+        size_t* blockSize = TracingGPUMemoryAllocator::Allocate<size_t>(lhs.GetComputeDeviceId(), 1);
+        CUDA_CALL(cudaMemcpy(blockSize, &blockSizePrev, sizeof(size_t), cudaMemcpyHostToDevice));

-            CUDA_CALL(cudaMemset(c.BlockId2ColOrRow(), 0, sizeof(GPUSPARSE_INDEX_TYPE) * (n)));
-            blocksPerGrid = (int) ceil(((double) rhs_nz) / GridDim::maxThreadsPerBlock);
-            _findColsWithValues<ElemType><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, t_stream>>>(
-                rhs.RowLocation(), c.BlockId2ColOrRow(), rhs_nz);
+        blocksPerGrid = (int) ceil(((double) rhs_nz) / GridDim::maxThreadsPerBlock);
+        _findColsWithValues<ElemType><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, t_stream>>>(
+            rhs.RowLocation(), c.ColOrRow2BlockId(), rhs_nz);

-            blocksPerGrid = (int) ceil(((double) n) / GridDim::maxThreadsPerBlock);
-            _determineBlockIds<ElemType><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, t_stream>>>(
-                c.BlockId2ColOrRow(), c.ColOrRow2BlockId(), n, blockSize);
+        blocksPerGrid = (int) ceil(((double) n) / GridDim::maxThreadsPerBlock);
+        _determineBlockIds<ElemType><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, t_stream>>>(
+            c.BlockId2ColOrRow(), c.ColOrRow2BlockId(), n, blockSize);

-            size_t block = c.GetBlockSize();
-            CUDA_CALL(cudaMemcpy(&block, blockSize, sizeof(size_t), cudaMemcpyDeviceToHost));
-            TracingGPUMemoryAllocator::Free<size_t>(lhs.GetComputeDeviceId(), blockSize);
-            c.SetBlockSize(block);
+        size_t blockSizeCurr;
+        CUDA_CALL(cudaMemcpy(&blockSizeCurr, blockSize, sizeof(size_t), cudaMemcpyDeviceToHost));
+        TracingGPUMemoryAllocator::Free<size_t>(lhs.GetComputeDeviceId(), blockSize);
+        c.SetBlockSize(blockSizeCurr);

-            size_t nnz = m * c.GetBlockSize();
+        if (blockSizeCurr > blockSizePrev)
+        {
+            // zero initialize new blocks
+            size_t nnz = m * blockSizeCurr;
             c.RequireSizeAndAllocate(m, n, nnz, true, true); // we need to keep the col2blockid and blockid2col info when resizing.
-            CUDA_CALL(cudaMemset(c.Data(), 0, sizeof(ElemType) * (c.GetSizeAllocated())));
-            LONG64 N = (LONG64) lhs.GetNumElements(); // here we process for each row in lhs and each column in rhs (==columns in lhs)
-            blocksPerGrid = (int) ceil(((double) N) / GridDim::maxThreadsPerBlock);
-            _denseMulSparseCSCTransposeToSparseBlockCol2<ElemType><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, t_stream>>>(
-                alpha,
-                lhs.Data(),
-                m,
-                l,
-                rhs.Data(),
-                rhs.RowLocation(),
-                rhs.ColLocation(),
-                c.ColOrRow2BlockId(),
-                c.Data());
+            CUDA_CALL(cudaMemset(c.Data() + m * blockSizePrev, 0, sizeof(ElemType) * m * (blockSizeCurr - blockSizePrev)));
+        }
-        else
-        {
-            c.SetBlockSize(rhs.IdentifyRowsWithValues());
-            size_t nnz = m * c.GetBlockSize();
-            c.RequireSizeAndAllocate(m, n, nnz, true, false);
-            CUDA_CALL(cudaMemset(c.Data(), 0, sizeof(ElemType) * (c.GetSizeAllocated())));
-            CUDA_CALL(cudaMemset(c.BlockId2ColOrRow(), 0, sizeof(GPUSPARSE_INDEX_TYPE) * (c.GetBlockSize())));
-            LONG64 N = (LONG64) lhs.GetNumElements(); // here we process for each row in lhs and each column in rhs (==columns in lhs)
-            blocksPerGrid = (int) ceil(((double) N) / GridDim::maxThreadsPerBlock);
-            _denseMulSparseCSCTransposeToSparseBlockCol<ElemType><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, t_stream>>>(
-                alpha,
-                lhs.Data(),
-                m,
-                l,
-                rhs.Data(),
-                rhs.RowLocation(),
-                rhs.ColLocation(),
-                rhs.GetTempDeviceBuffer(),
-                c.Data(),
-                c.BlockId2ColOrRow());
-        }
+        LONG64 N = (LONG64) lhs.GetNumElements(); // here we process for each row in lhs and each column in rhs (==columns in lhs)
+        blocksPerGrid = (int) ceil(((double) N) / GridDim::maxThreadsPerBlock);
+        _denseMulSparseCSCTransposeToSparseBlockCol2<ElemType><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, t_stream>>>(
+            alpha,
+            lhs.Data(),
+            m,
+            l,
+            rhs.Data(),
+            rhs.RowLocation(),
+            rhs.ColLocation(),
+            c.ColOrRow2BlockId(),
+            c.Data());
     }
     else if (transposeA && !transposeB)
     {

View file

@@ -4544,7 +4544,18 @@ void Matrix<ElemType>::MultiplyAndWeightedAdd(ElemType alpha, const Matrix<ElemT
         }
         else if (c.GetMatrixType() == MatrixType::SPARSE) // CPU, DENSE * SPARSE -> SPARSE
         {
-            CPUSparseMatrix<ElemType>::MultiplyAndAdd(alpha, *a.m_CPUMatrix, transposeA, *b.m_CPUSparseMatrix, transposeB, *c.m_CPUSparseMatrix);
+            if (beta != 0 && beta != 1)
+            {
+                NOT_IMPLEMENTED;
+            }
+            else
+            {
+                if (beta == 0)
+                {
+                    c.Reset();
+                }
+                CPUSparseMatrix<ElemType>::MultiplyAndAdd(alpha, *a.m_CPUMatrix, transposeA, *b.m_CPUSparseMatrix, transposeB, *c.m_CPUSparseMatrix);
+            }
             c.SetDataLocation(CPU, SPARSE);
         }
         else
@@ -4578,7 +4589,18 @@ void Matrix<ElemType>::MultiplyAndWeightedAdd(ElemType alpha, const Matrix<ElemT
        }
        else if (a.m_matrixType == MatrixType::DENSE && b.m_matrixType == MatrixType::SPARSE && c.m_matrixType == MatrixType::SPARSE) // GPU, DENSE * SPARSE -> SPARSE
        {
-           GPUSparseMatrix<ElemType>::MultiplyAndAdd(alpha, *a.m_GPUMatrix, transposeA, *b.m_GPUSparseMatrix, transposeB, *c.m_GPUSparseMatrix);
+           if (beta != 0 && beta != 1)
+           {
+               NOT_IMPLEMENTED;
+           }
+           else
+           {
+               if (beta == 0)
+               {
+                   c.Reset();
+               }
+               GPUSparseMatrix<ElemType>::MultiplyAndAdd(alpha, *a.m_GPUMatrix, transposeA, *b.m_GPUSparseMatrix, transposeB, *c.m_GPUSparseMatrix);
+           }
            c.SetDataLocation(GPU, SPARSE);
        }
        else if (a.m_matrixType == MatrixType::SPARSE && b.m_matrixType == MatrixType::SPARSE && c.m_matrixType == MatrixType::SPARSE) // GPU, SPARSE * SPARSE -> SPARSE
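At this dispatch level the contract is: beta == 0 gives assignment semantics (c is Reset(), then accumulated into), beta == 1 keeps c's existing blocks and accumulates, and any other beta would require scaling the blocks already stored in the sparse target, which remains NOT_IMPLEMENTED. A usage sketch with the signature exercised by the tests below, where a is a dense matrix, b a CSC sparse matrix, and c a sparse-block-column target (placeholder variables):

// c = 1.0f * a * b^T   (beta == 0: the sparse target is reset first)
Matrix<float>::MultiplyAndWeightedAdd(1.0f, a, false, b, true, 0.0f, c);

// c += 1.0f * a * b^T  (beta == 1: existing blocks of c are kept)
Matrix<float>::MultiplyAndWeightedAdd(1.0f, a, false, b, true, 1.0f, c);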

View file

@@ -212,6 +212,13 @@ void GPUSparseMatrix<ElemType>::SetMatrixFromCSCFormat(const CPUSPARSE_INDEX_TYP
 {
 }

+#if 0
+template <class ElemType>
+void GPUSparseMatrix<ElemType>::SetMatrixFromSBCFormat(const size_t*, const ElemType*, const size_t, const size_t, const size_t)
+{
+}
+#endif
+
// forward pass from feature to hidden layer
template <class ElemType>
void GPUSparseMatrix<ElemType>::MultiplyAndWeightedAdd(ElemType alpha, const GPUMatrix<ElemType>& lhs, const bool transposeA,

View file

@@ -6051,8 +6051,8 @@ Test module "MathTests" has passed with:
   770 assertions out of 770 passed
 Test suite "CPUMatrixSuite" has passed with:
-  30 test cases out of 30 passed
-  5228662 assertions out of 5228662 passed
+  31 test cases out of 31 passed
+  5248662 assertions out of 5248662 passed
 Test case "CPUMatrixSuite/CPUSparseMatrixColumnSlice" has passed with:
   1 assertion out of 1 passed
@@ -6060,6 +6060,9 @@ Test module "MathTests" has passed with:
 Test case "CPUMatrixSuite/CPUSparseMatrixCopyColumnSliceToDense" has passed with:
   1 assertion out of 1 passed
+Test case "CPUMatrixSuite/CPUSparseMatrixMultiplyAndAdd" has passed with:
+  20000 assertions out of 20000 passed
 Test case "CPUMatrixSuite/MatrixMultiplyTest" has passed with:
   1484 assertions out of 1484 passed
@@ -6146,7 +6149,7 @@ Test module "MathTests" has passed with:
 Test suite "GPUMatrixSuite" has passed with:
   56 test cases out of 56 passed
-  1263760 assertions out of 1263760 passed
+  1263762 assertions out of 1263762 passed
 Test case "GPUMatrixSuite/GPUBlasMultiplyAndWeightedAdd" has passed with:
   1 assertion out of 1 passed
@@ -6290,7 +6293,7 @@ Test module "MathTests" has passed with:
   4 assertions out of 4 passed
 Test case "GPUMatrixSuite/MatrixSparseTimesDense" has passed with:
-  1 assertion out of 1 passed
+  3 assertions out of 3 passed
 Test case "GPUMatrixSuite/MatrixDenseTimesSparse" has passed with:
   2 assertions out of 2 passed
@@ -6398,13 +6401,13 @@ Test module "MathTests" has passed with:
   2 assertions out of 2 passed
 Test suite "QuantizersUnitTests" has passed with:
-  2 test case out of 2 passed
+  2 test cases out of 2 passed
   12 assertions out of 12 passed
 Test case "QuantizersUnitTests/FloatToShort" has passed with:
   6 assertions out of 6 passed
-Test case "QuantizersUnitTests/QuantizeZeros" has passed with:
+Test case "QuantizersUnitTests/QuantizeZeros" has passed with:
   6 assertions out of 6 passed
 Test suite "QuantizedOperationsUnitTests" has passed with:
@@ -6413,7 +6416,7 @@ Test module "MathTests" has passed with:
 Test case "QuantizedOperationsUnitTests/MultiplyIntToShort" has passed with:
   65 assertions out of 65 passed
 Test suite "MathTensorTests" has passed with:
   9 test cases out of 9 passed
   6 assertions out of 6 passed

View file

@@ -61,6 +61,61 @@ BOOST_FIXTURE_TEST_CASE(CPUSparseMatrixCopyColumnSliceToDense, RandomSeedFixture
     BOOST_CHECK(dm1.IsEqualTo(dm2, c_epsilonFloatE4));
 }

+BOOST_FIXTURE_TEST_CASE(CPUSparseMatrixMultiplyAndAdd, RandomSeedFixture)
+{
+    const size_t m = 100;
+    const size_t n = 50;
+
+    DenseMatrix dm0(m, n);
+    dm0.SetUniformRandomValue(-1, 1, IncrementCounter());
+
+    DenseMatrix dm1(m, n);
+    dm1.SetUniformRandomValue(-300, 1, IncrementCounter());
+    dm1.InplaceTruncateBottom(0);
+
+    SparseMatrix sm1(MatrixFormat::matrixFormatSparseCSC, m, n, 0);
+    foreach_coord(row, col, dm1)
+    {
+        if (dm1(row, col) != 0)
+        {
+            sm1.SetValue(row, col, dm1(row, col));
+        }
+    }
+
+    DenseMatrix dm2(m, n);
+    dm2.SetUniformRandomValue(-200, 1, IncrementCounter());
+    dm2.InplaceTruncateBottom(0);
+
+    SparseMatrix sm2(MatrixFormat::matrixFormatSparseCSC, m, n, 0);
+    foreach_coord(row, col, dm2)
+    {
+        if (dm2(row, col) != 0)
+        {
+            sm2.SetValue(row, col, dm2(row, col));
+        }
+    }
+
+    // generate SparseBlockCol matrix
+    DenseMatrix dmMul(m, m);
+    DenseMatrix::MultiplyAndAdd(dm0, false, dm1, true, dmMul);
+
+    SparseMatrix smMul(MatrixFormat::matrixFormatSparseBlockCol, m, m, 0);
+    SparseMatrix::MultiplyAndAdd(1, dm0, false, sm1, true, smMul);
+
+    foreach_coord(row, col, dmMul)
+    {
+        BOOST_CHECK(abs(smMul(row, col) - dmMul(row, col)) < c_epsilonFloatE4);
+    }
+
+    SparseMatrix::MultiplyAndAdd(1, dm0, false, sm2, true, smMul);
+    DenseMatrix::MultiplyAndAdd(dm0, false, dm2, true, dmMul);
+
+    foreach_coord(row, col, dmMul)
+    {
+        BOOST_CHECK(abs(smMul(row, col) - dmMul(row, col)) < c_epsilonFloatE4);
+    }
+}
+
 BOOST_AUTO_TEST_SUITE_END()
 }
 } } }

View file

@@ -44,7 +44,7 @@ BOOST_FIXTURE_TEST_CASE(MatrixSparseTimesDense, RandomSeedFixture)
 {
     // DENSE
     Matrix<float> mAdense(c_deviceIdZero);
-    mAdense.AssignTruncateBottomOf(Matrix<float>::RandomUniform(dim1, dim2, c_deviceIdZero, -3.0f, 0.1f, IncrementCounter()), 0);
+    mAdense.AssignTruncateBottomOf(Matrix<float>::RandomUniform(dim1, dim2, c_deviceIdZero, -300.0f, 0.1f, IncrementCounter()), 0);

     // MATRIX mAsparse becomes sparse
     Matrix<float> mAsparse(mAdense.DeepClone());
@@ -66,6 +66,44 @@ BOOST_FIXTURE_TEST_CASE(MatrixSparseTimesDense, RandomSeedFixture)
     Matrix<float>::MultiplyAndWeightedAdd(alpha, mAsparse, transposeA, mB, transposeB, beta, mD);
     BOOST_CHECK(mD.IsEqualTo(mC, c_epsilonFloatE4));

+    // SPARSE * DENSE -> SPARSE
+    beta = 1; // note that dense only allow beta == 1, while sparse CPU support arbitrary beta
+    transposeA = false;
+    transposeB = true; // only support !transposeA && tranposeB for Dense * CSC -> SBC
+
+    Matrix<float> mA1sparseCSC(mAdense.DeepClone());
+    mA1sparseCSC.SwitchToMatrixType(MatrixType::SPARSE, matrixFormatSparseCSC, true);
+
+    Matrix<float> mA2dense(c_deviceIdZero);
+    mA2dense.AssignTruncateBottomOf(Matrix<float>::RandomUniform(dim1, dim2, c_deviceIdZero, -300.0f, 0.1f, IncrementCounter()), 0);
+
+    Matrix<float> mA2sparseCSC(mA2dense.DeepClone());
+    mA2sparseCSC.SwitchToMatrixType(MatrixType::SPARSE, matrixFormatSparseCSC, true);
+
+    // dense for comparison
+    mC.Resize(dim1, dim1);
+    mC.SetValue(0.0f);
+    Matrix<float>::MultiplyAndAdd(mAdense, transposeA, mAdense, transposeB, mC);
+    Matrix<float>::MultiplyAndWeightedAdd(alpha, mAdense, transposeA, mA2dense, transposeB, beta, mC);
+
+    // make sure (dense * sparse -> dense) == (dense * dense -> dense)
+    mD.Resize(dim1, dim1);
+    mD.SetValue(0.0f);
+    Matrix<float>::MultiplyAndAdd(mAdense, transposeA, mA1sparseCSC, transposeB, mD);
+    Matrix<float>::MultiplyAndWeightedAdd(alpha, mAdense, transposeA, mA2sparseCSC, transposeB, beta, mD);
+    BOOST_CHECK(mD.IsEqualTo(mC, c_epsilonFloatE4));
+
+    // test on sparse
+    mD.SwitchToMatrixType(MatrixType::SPARSE, matrixFormatSparseBlockCol, false);
+    Matrix<float>::MultiplyAndAdd(mAdense, transposeA, mA1sparseCSC, transposeB, mD);
+    Matrix<float>::MultiplyAndWeightedAdd(alpha, mAdense, transposeA, mA2sparseCSC, transposeB, beta, mD);
+
+    // copy mD to dense and compare
+    Matrix<float> mE = Matrix<float>::Zeros(dim1, dim1, c_deviceIdZero);
+    Matrix<float>::ScaleAndAdd(1, mD, mE);
+    BOOST_CHECK(mE.IsEqualTo(mC, c_epsilonFloatE4));
 }

BOOST_FIXTURE_TEST_CASE(MatrixDenseTimesSparse, RandomSeedFixture)