diff --git a/Source/Math/CPUSparseMatrix.cpp b/Source/Math/CPUSparseMatrix.cpp
index 4afd3f4b1..151727f5e 100644
--- a/Source/Math/CPUSparseMatrix.cpp
+++ b/Source/Math/CPUSparseMatrix.cpp
@@ -244,7 +244,7 @@ void CPUSparseMatrix<ElemType>::SetValue(const CPUSparseMatrix<ElemType>& v)
 {
     SetFormat(v.GetFormat());
-    RequireSizeAndAllocate(v.GetNumRows(), v.GetNumCols(), v.NzSize());
+    RequireSizeAndAllocate(v.GetNumRows(), v.GetNumCols(), v.NzCount()); // TODO: rename to *Bytes/*Count instead of vague *Size if possible
     let nz = v.NzCount();
 
     auto matrixFormat = v.GetFormat();
@@ -262,7 +262,8 @@ void CPUSparseMatrix<ElemType>::SetValue(const CPUSparseMatrix<ElemType>& v)
         }
         else
         {
-            memcpy(GetBlockIds(), v.GetBlockIds(), v.GetBlockSize());
+            memcpy(GetBlockIds(), v.GetBlockIds(), v.GetBlockSize() * sizeof(size_t)); // TODO: change block id from size_t to CPUSPARSE_INDEX_TYPE, and rename BlockSize to BlockCount
+            SetBlockSize(v.GetBlockSize());
         }
     }
     if (v.m_sliceViewOffset > 0)
@@ -644,6 +645,22 @@ void CPUSparseMatrix<ElemType>::SetMatrixFromCSCFormat(const CPUSPARSE_INDEX_TYP
     memcpy(NzValues(), h_Val, sizeof(ElemType)*nz);
 }
 
+#if 0 // add it back with test
+template <class ElemType>
+void CPUSparseMatrix<ElemType>::SetMatrixFromSBCFormat(const size_t* blockIds, const ElemType* val, const size_t numBlocks, const size_t numRows, const size_t numCols)
+{
+    if (!OwnBuffer())
+        LogicError("Cannot modify since the buffer is managed externally.");
+
+    SetFormat(matrixFormatSparseBlockCol);
+    Resize(numRows, numCols, numBlocks * numRows);
+    SetBlockSize(numBlocks);
+
+    memcpy(GetBlockIds(), blockIds, sizeof(size_t)*(numBlocks));
+    memcpy(Data(), val, sizeof(ElemType)*numBlocks*numRows);
+}
+#endif
+
 template <class ElemType>
 ElemType* CPUSparseMatrix<ElemType>::Data() const
 {
@@ -959,8 +976,6 @@ void CPUSparseMatrix<ElemType>::MultiplyAndAdd(ElemType alpha, const CPUMatrix<ElemType>& lhs, const bool transposeA,
         // allocate enough memory
         c.SetFormat(matrixFormatSparseBlockCol);
+        size_t blockSizePrev = c.GetBlockSize();
 
-        map<size_t, size_t> w2Id;
-        for (size_t j = 0; j < rhs.GetNumCols(); j++)
-        { // j ranges over batches
-            size_t start = rhs.SecondaryIndexLocation()[j];
-            size_t end = rhs.SecondaryIndexLocation()[j + 1];
+        if (blockSizePrev == 0)
+        {
+            c.RequireSizeAndAllocate(m, n, 0, true); // allocate for blockIds
+        }
+
+        map<size_t, size_t> col2BlockId;
+        for (size_t blockId = 0; blockId < blockSizePrev; blockId++)
+        {
+            col2BlockId[c.GetBlockIds()[blockId]] = blockId;
+        }
+
+        size_t blockSizeCurr = blockSizePrev;
+        for (size_t rhsNz = 0; rhsNz < rhs.NzCount(); rhsNz++)
+        {
+            size_t resultCol = rhs.MajorIndexLocation()[rhsNz];
+            if (col2BlockId.find(resultCol) == col2BlockId.end())
+            {
+                col2BlockId[resultCol] = blockSizeCurr;
+                c.GetBlockIds()[blockSizeCurr] = resultCol;
+                blockSizeCurr++;
+            }
+        }
+
+        if (blockSizeCurr > blockSizePrev)
+        {
+            c.RequireSizeAndAllocate(m, n, m * blockSizeCurr, true, true);
+            c.SetBlockSize(blockSizeCurr);
+            memset(c.Data() + m * blockSizePrev, 0, sizeof(ElemType) * m * (blockSizeCurr - blockSizePrev));
+        }
+
+        for (size_t rhsCol = 0; rhsCol < rhs.GetNumCols(); rhsCol++)
+        {
+            size_t start = rhs.SecondaryIndexLocation()[rhsCol];
+            size_t end = rhs.SecondaryIndexLocation()[rhsCol + 1];
 
             for (size_t p = start; p < end; p++)
             {
-                size_t i = rhs.MajorIndexLocation()[p]; // i ranges over words
-                ElemType val = rhs.Buffer()[p]; // 1 for(i, j)
+                size_t rhsRow = rhs.MajorIndexLocation()[p];
+                ElemType val = rhs.Buffer()[p];
 
-                bool first = true;
-                if (w2Id.find(i) == w2Id.end())
+                ElemType* results = c.Buffer() + col2BlockId[rhsRow] * m;
+                #pragma omp parallel for
+                for (int lhsRow = 0; lhsRow < (int)m; lhsRow++)
                 {
-                    size_t id = w2Id.size();
-                    w2Id[i] = id;
-                    c.GetBlockIds()[c.GetBlockSize()] = i;
-                    c.SetBlockSize(c.GetBlockSize() + 1);
-                }
-                else
-                {
-                    first = false;
-                }
-                size_t pos = w2Id[i] * lhs.GetNumRows();
-                for (size_t h = 0; h < lhs.GetNumRows(); h++)
-                { // h range over hidden layer
-                    if (first == true)
-                    {
-                        c.Buffer()[pos] = alpha * lhs(h, j) * val;
-                    }
-                    else
-                    {
-                        c.Buffer()[pos] += alpha * lhs(h, j) * val;
-                    }
-                    pos++;
+                    results[lhsRow] += alpha * lhs((size_t)lhsRow, rhsCol) * val;
                 }
             }
         }
-        if (c.GetBlockSize() * m > c.GetSizeAllocated())
-        {
-            LogicError("Sparse matrix is unexpectedly out of range.");
-        }
     }
     else if (transposeA && !transposeB)
     {
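Note: the CPU rewrite above replaces the old per-word w2Id map, which rebuilt the output from scratch on every call, with a col2BlockId map seeded from the blocks already present in c, so repeated calls accumulate into existing blocks. A minimal standalone sketch of the SparseBlockCol (SBC) layout and of this block-reuse step, using hypothetical types rather than the CNTK classes:

    // Sketch only: SBC stores just the nonzero columns, each as a dense block
    // of numRows values; blockIds[b] records which column block b represents.
    #include <cstddef>
    #include <map>
    #include <vector>

    struct SbcSketch
    {
        size_t numRows;
        std::vector<size_t> blockIds; // blockIds[b] = column held by block b
        std::vector<double> data;     // data[b * numRows + r] = value at (r, blockIds[b])

        // Reuse an existing block for `col` if present; otherwise append a
        // zero-initialized block, as the patched MultiplyAndAdd does.
        size_t GetOrAddBlock(size_t col, std::map<size_t, size_t>& col2BlockId)
        {
            auto it = col2BlockId.find(col);
            if (it != col2BlockId.end())
                return it->second;                   // accumulate into existing block
            size_t blockId = blockIds.size();
            col2BlockId[col] = blockId;
            blockIds.push_back(col);
            data.resize(data.size() + numRows, 0.0); // zero only the new block
            return blockId;
        }
    };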
diff --git a/Source/Math/CPUSparseMatrix.h b/Source/Math/CPUSparseMatrix.h
index 370edfc52..1dbe682b1 100644
--- a/Source/Math/CPUSparseMatrix.h
+++ b/Source/Math/CPUSparseMatrix.h
@@ -201,10 +201,19 @@ public:
             return 0;
         }
-        else
+        else if (GetFormat() == MatrixFormat::matrixFormatSparseBlockCol)
         {
-            NOT_IMPLEMENTED;
+            for (size_t blockId = 0; blockId < GetBlockSize(); blockId++)
+            {
+                size_t blockCol = GetBlockIds()[blockId] - GetBlockIdShift();
+                if (blockCol == col)
+                {
+                    return ((ElemType*)Buffer())[blockId * GetNumRows() + row];
+                }
+            }
+            return 0;
         }
+        NOT_IMPLEMENTED;
     }
 
 public:
@@ -258,6 +267,11 @@ public:
         BaseMatrix<ElemType>::SetBlockSize(newBlockSize);
     }
 
+    size_t GetBlockSize() const
+    {
+        return BaseMatrix<ElemType>::GetBlockSize();
+    }
+
     size_t* BlockIdsLocation() const
     {
         if ((GetFormat() != matrixFormatSparseBlockCol) && (GetFormat() != matrixFormatSparseBlockRow))
@@ -325,7 +339,6 @@ public:
     {
         return (GetFormat() & matrixFormatRowMajor) ? MajorIndexSize() : SecondaryIndexSize();
     } // actual number of bytes in use
-
 };
 
 typedef CPUSparseMatrix<float> CPUSingleSparseMatrix;
diff --git a/Source/Math/GPUMatrixCUDAKernels.cuh b/Source/Math/GPUMatrixCUDAKernels.cuh
index 62d1fa5aa..002c78113 100644
--- a/Source/Math/GPUMatrixCUDAKernels.cuh
+++ b/Source/Math/GPUMatrixCUDAKernels.cuh
@@ -3008,6 +3008,10 @@ __global__ void _reshape(
     newColumnIndex[newNumCols] = oldColumnIndex[oldNumCols]; // set end pointer
 }
 
+// special markers in BlockId2ColOrRow()/ColOrRow2BlockId()
+static const GPUSPARSE_INDEX_TYPE Id_NotAssigned = -1;
+static const GPUSPARSE_INDEX_TYPE Id_Pending = INT_MAX;
+
 //called before _determineBlockIds and _denseMulSparseCSCTransposeToSparseBlockCol to determine which columns have values and
 //what's the mapping from the column id in the resulted SparseBlockCol format to the column id in the dense format
 //input: rowIndexes: the row indexes of the CSC sparse matrix to be multiplied with
@@ -3015,13 +3019,14 @@ __global__ void _reshape(
 //nnz: number of nonzero value or the size of rowIndexes;
 template <class ElemType>
 __global__ void _findColsWithValues(
-    const GPUSPARSE_INDEX_TYPE* rowIndexes, GPUSPARSE_INDEX_TYPE* blockIds, const size_t nnz)
+    const GPUSPARSE_INDEX_TYPE* rowIndexes, GPUSPARSE_INDEX_TYPE* col2BlockIds, const size_t nnz)
 {
-    const size_t index = blockIdx.x * blockDim.x + threadIdx.x;
-    if (index >= nnz)
+    const size_t nzIndex = blockIdx.x * blockDim.x + threadIdx.x;
+    if (nzIndex >= nnz)
         return;
 
-    blockIds[rowIndexes[index]] = 1; // this row has value.
+    if (col2BlockIds[rowIndexes[nzIndex]] == Id_NotAssigned)
+        col2BlockIds[rowIndexes[nzIndex]] = Id_Pending; // this row has value.
 }
 
 //called before _denseMulSparseCSCTransposeToSparseBlockCol and after _findColsWithValues to determine which columns have values and
@@ -3030,26 +3035,21 @@ __global__ void _findColsWithValues(
 //blockId2Col: the blockID to colum id mapping in the resulting matrix;
 //col2BlockId: the col2BlockId to blockID mapping in the resulting matrix;
 //numCols: number of columns in the resulting matrix or the size of blockIDs
-//blockSize: return the blockSize with values, *blockSize must be zero before passed in.
+//blockSize: return the blockSize with values
 template <class ElemType>
 __global__ void _determineBlockIds(
-    GPUSPARSE_INDEX_TYPE* blockId2Col, GPUSPARSE_INDEX_TYPE* col2BlockId, const size_t numCols, size_t* blockSize)
+    GPUSPARSE_INDEX_TYPE* blockId2Col, GPUSPARSE_INDEX_TYPE* col2BlockId, size_t numCols, size_t* blockSize)
 {
-    const size_t index = blockIdx.x * blockDim.x + threadIdx.x;
-    if (index >= numCols)
+    const size_t col = blockIdx.x * blockDim.x + threadIdx.x;
+    if (col >= numCols)
         return;
 
-    size_t blockIndex = numCols;
-    if (blockId2Col[index] > 0)
+    if (col2BlockId[col] == Id_Pending)
     {
-        blockIndex = atomicAdd((unsigned int*) blockSize, (unsigned int) 1);
-        col2BlockId[index] = blockIndex;
+        GPUSPARSE_INDEX_TYPE blockIndex = atomicAdd((unsigned int*)blockSize, (unsigned int)1);
+        col2BlockId[col] = blockIndex;
+        blockId2Col[blockIndex] = col;
     }
-
-    __syncthreads();
-
-    if (blockIndex < numCols)
-        blockId2Col[blockIndex] = index;
 }
 
 // backward pass from hidden layer to feature weight
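Why two marker values: _findColsWithValues may be run by many threads that hit the same column, so it only flips Id_NotAssigned to Id_Pending (idempotent, no atomics needed); _determineBlockIds then atomically hands out one compact block id per pending column. Columns assigned in a previous call keep their ids, which is what lets the product accumulate across calls. A host-side simulation of the protocol (a sketch, not the CUDA code):

    #include <climits>
    #include <cstddef>
    #include <vector>

    static const int kNotAssigned = -1;  // mirrors Id_NotAssigned
    static const int kPending = INT_MAX; // mirrors Id_Pending

    // rowIndexes: CSC row indexes of rhs = column ids of the result.
    // col2BlockId must be pre-filled with kNotAssigned on first use; blockSize
    // carries the block count from the previous call (the atomicAdd seed).
    void AssignBlockIds(const std::vector<int>& rowIndexes,
                        std::vector<int>& col2BlockId,
                        std::vector<int>& blockId2Col,
                        size_t& blockSize)
    {
        for (int col : rowIndexes) // pass 1: _findColsWithValues
        {
            if (col2BlockId[col] == kNotAssigned)
                col2BlockId[col] = kPending;
        }
        for (size_t col = 0; col < col2BlockId.size(); col++) // pass 2: _determineBlockIds
        {
            if (col2BlockId[col] == kPending)
            {
                int blockIndex = (int)blockSize++; // atomicAdd(blockSize, 1) on the device
                col2BlockId[col] = blockIndex;
                blockId2Col[blockIndex] = (int)col;
            }
        }
    }

This also shows why the __syncthreads() in the old kernel could go: it read and wrote blockId2Col in the same pass, and __syncthreads() only synchronizes one thread block, not the whole grid. The new version keeps the markers in col2BlockId and writes blockId2Col separately, so no barrier is needed.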
diff --git a/Source/Math/GPUSparseMatrix.cu b/Source/Math/GPUSparseMatrix.cu
index 7fc64632f..4d760aa0e 100644
--- a/Source/Math/GPUSparseMatrix.cu
+++ b/Source/Math/GPUSparseMatrix.cu
@@ -947,6 +947,42 @@ void GPUSparseMatrix<ElemType>::SetMatrixFromCSCFormat(const CPUSPARSE_INDEX_TYP
     }
 }
 
+#if 0 // add it back with test
+template <class ElemType>
+void GPUSparseMatrix<ElemType>::SetMatrixFromSBCFormat(const size_t* blockIds, const ElemType* val, const size_t numBlocks, const size_t numRows, const size_t numCols)
+{
+    VerifyWritable(__FUNCTION__);
+
+    if (blockIds == nullptr || val == nullptr)
+        LogicError("SetMatrixFromSBCFormat: nullptr passed in.");
+
+    SetFormat(matrixFormatSparseBlockCol);
+    SetBlockSize(numBlocks);
+
+    if (numBlocks == 0) return; // ====>
+
+    size_t nz = numBlocks * numRows;
+    RequireSizeAndAllocate(numRows, numCols, nz, true, false);
+
+    static std::vector<GPUSPARSE_INDEX_TYPE> gpuBlockId2Col(numBlocks);
+    static std::vector<GPUSPARSE_INDEX_TYPE> gpuCol2BlockId(numCols);
+
+    std::fill(gpuBlockId2Col.begin(), gpuBlockId2Col.end(), Id_NotAssigned);
+    std::fill(gpuCol2BlockId.begin(), gpuCol2BlockId.end(), Id_NotAssigned);
+
+    #pragma omp parallel for
+    for (int i = 0; i < numBlocks; ++i)
+    {
+        gpuBlockId2Col[i] = (GPUSPARSE_INDEX_TYPE)blockIds[i];
+        gpuCol2BlockId[blockIds[i]] = i;
+    }
+
+    CUDA_CALL(cudaMemcpy(Data(), val, nz * sizeof(ElemType), cudaMemcpyHostToDevice));
+    CUDA_CALL(cudaMemcpy(BlockId2ColOrRow(), &gpuBlockId2Col[0], numBlocks * sizeof(GPUSPARSE_INDEX_TYPE), cudaMemcpyHostToDevice));
+    CUDA_CALL(cudaMemcpy(ColOrRow2BlockId(), &gpuCol2BlockId[0], numBlocks * sizeof(GPUSPARSE_INDEX_TYPE), cudaMemcpyHostToDevice));
+}
+#endif
+
 // this function will allocate memory while the caller needs to release it
 template <class ElemType>
 void GPUSparseMatrix<ElemType>::GetMatrixFromCSCFormat(GPUSPARSE_INDEX_TYPE*& h_CSCCol, GPUSPARSE_INDEX_TYPE*& h_Row, ElemType*& h_Val, size_t& numElemAllocated, size_t& nz, size_t& numRows, size_t& numCols) const
@@ -1223,68 +1259,51 @@ void GPUSparseMatrix<ElemType>::MultiplyAndAdd(ElemType alpha, const GPUMatrix<ElemType>& lhs, const bool transposeA,
+        size_t blockSizePrev = c.GetBlockSize();
+        if (blockSizePrev == 0)
+        {
+            c.Resize(m, n, 0);
+            CUDA_CALL(cudaMemset(c.ColOrRow2BlockId(), Id_NotAssigned, sizeof(GPUSPARSE_INDEX_TYPE) * (n)));
+            CUDA_CALL(cudaMemset(c.BlockId2ColOrRow(), Id_NotAssigned, sizeof(GPUSPARSE_INDEX_TYPE) * (n)));
+        }
+
-        size_t* blockSize = TracingGPUMemoryAllocator::Allocate<size_t>(lhs.GetComputeDeviceId(), 1);
-        CUDA_CALL(cudaMemset(blockSize, 0, sizeof(size_t)));
+        size_t* blockSize = TracingGPUMemoryAllocator::Allocate<size_t>(lhs.GetComputeDeviceId(), 1);
+        CUDA_CALL(cudaMemcpy(blockSize, &blockSizePrev, sizeof(size_t), cudaMemcpyHostToDevice));
 
-        CUDA_CALL(cudaMemset(c.BlockId2ColOrRow(), 0, sizeof(GPUSPARSE_INDEX_TYPE) * (n)));
-
-        blocksPerGrid = (int) ceil(((double) rhs_nz) / GridDim::maxThreadsPerBlock);
-        _findColsWithValues<ElemType><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, t_stream>>>(
-            rhs.RowLocation(), c.BlockId2ColOrRow(), rhs_nz);
+        blocksPerGrid = (int) ceil(((double) rhs_nz) / GridDim::maxThreadsPerBlock);
+        _findColsWithValues<ElemType><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, t_stream>>>(
+            rhs.RowLocation(), c.ColOrRow2BlockId(), rhs_nz);
 
-        blocksPerGrid = (int) ceil(((double) n) / GridDim::maxThreadsPerBlock);
-        _determineBlockIds<ElemType><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, t_stream>>>(
-            c.BlockId2ColOrRow(), c.ColOrRow2BlockId(), n, blockSize);
+        blocksPerGrid = (int) ceil(((double) n) / GridDim::maxThreadsPerBlock);
+        _determineBlockIds<ElemType><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, t_stream>>>(
+            c.BlockId2ColOrRow(), c.ColOrRow2BlockId(), n, blockSize);
 
-        size_t block = c.GetBlockSize();
-        CUDA_CALL(cudaMemcpy(&block, blockSize, sizeof(size_t), cudaMemcpyDeviceToHost));
-        TracingGPUMemoryAllocator::Free<size_t>(lhs.GetComputeDeviceId(), blockSize);
-        c.SetBlockSize(block);
+        size_t blockSizeCurr;
+        CUDA_CALL(cudaMemcpy(&blockSizeCurr, blockSize, sizeof(size_t), cudaMemcpyDeviceToHost));
+        TracingGPUMemoryAllocator::Free<size_t>(lhs.GetComputeDeviceId(), blockSize);
+        c.SetBlockSize(blockSizeCurr);
 
-        size_t nnz = m * c.GetBlockSize();
+        if (blockSizeCurr > blockSizePrev)
+        {
+            // zero initialize new blocks
+            size_t nnz = m * blockSizeCurr;
             c.RequireSizeAndAllocate(m, n, nnz, true, true); // we need to keep the col2blockid and blockid2col info when resizing.
-            CUDA_CALL(cudaMemset(c.Data(), 0, sizeof(ElemType) * (c.GetSizeAllocated())));
-
-            LONG64 N = (LONG64) lhs.GetNumElements(); // here we process for each row in lhs and each column in rhs (==columns in lhs)
-            blocksPerGrid = (int) ceil(((double) N) / GridDim::maxThreadsPerBlock);
-            _denseMulSparseCSCTransposeToSparseBlockCol2<ElemType><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, t_stream>>>(
-                alpha,
-                lhs.Data(),
-                m,
-                l,
-                rhs.Data(),
-                rhs.RowLocation(),
-                rhs.ColLocation(),
-                c.ColOrRow2BlockId(),
-                c.Data());
+            CUDA_CALL(cudaMemset(c.Data() + m * blockSizePrev, 0, sizeof(ElemType) * m * (blockSizeCurr - blockSizePrev)));
         }
-        else
-        {
-            c.SetBlockSize(rhs.IdentifyRowsWithValues());
-            size_t nnz = m * c.GetBlockSize();
-            c.RequireSizeAndAllocate(m, n, nnz, true, false);
-            CUDA_CALL(cudaMemset(c.Data(), 0, sizeof(ElemType) * (c.GetSizeAllocated())));
-            CUDA_CALL(cudaMemset(c.BlockId2ColOrRow(), 0, sizeof(GPUSPARSE_INDEX_TYPE) * (c.GetBlockSize())));
-
-            LONG64 N = (LONG64) lhs.GetNumElements(); // here we process for each row in lhs and each column in rhs (==columns in lhs)
-            blocksPerGrid = (int) ceil(((double) N) / GridDim::maxThreadsPerBlock);
-            _denseMulSparseCSCTransposeToSparseBlockCol<ElemType><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, t_stream>>>(
-                alpha,
-                lhs.Data(),
-                m,
-                l,
-                rhs.Data(),
-                rhs.RowLocation(),
-                rhs.ColLocation(),
-                rhs.GetTempDeviceBuffer(),
-                c.Data(),
-                c.BlockId2ColOrRow());
-        }
+
+        LONG64 N = (LONG64) lhs.GetNumElements(); // here we process for each row in lhs and each column in rhs (==columns in lhs)
+        blocksPerGrid = (int) ceil(((double) N) / GridDim::maxThreadsPerBlock);
+        _denseMulSparseCSCTransposeToSparseBlockCol2<ElemType><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, t_stream>>>(
+            alpha,
+            lhs.Data(),
+            m,
+            l,
+            rhs.Data(),
+            rhs.RowLocation(),
+            rhs.ColLocation(),
+            c.ColOrRow2BlockId(),
+            c.Data());
     }
     else if (transposeA && !transposeB)
    {
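The GPU path now mirrors the CPU path's incremental growth: blockSize is seeded with the previous block count instead of zero, only columns still unassigned get new block ids, and only the newly appended blocks are zeroed, so values accumulated in earlier calls survive. A sketch of the growth rule (illustrative only, std::vector standing in for the device buffer managed by RequireSizeAndAllocate and cudaMemset):

    #include <cstddef>
    #include <vector>

    // Grow the SBC value buffer from blockSizePrev to blockSizeCurr blocks of
    // m rows each. resize() value-initializes the tail, playing the role of
    // cudaMemset(c.Data() + m * blockSizePrev, 0, ...) in the patch; the first
    // m * blockSizePrev entries keep their accumulated values.
    void GrowBlocks(std::vector<double>& data, size_t m,
                    size_t blockSizePrev, size_t blockSizeCurr)
    {
        if (blockSizeCurr > blockSizePrev)
            data.resize(m * blockSizeCurr, 0.0);
    }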
diff --git a/Source/Math/Matrix.cpp b/Source/Math/Matrix.cpp
index 72b027c05..bc64817b3 100644
--- a/Source/Math/Matrix.cpp
+++ b/Source/Math/Matrix.cpp
@@ -4544,7 +4544,18 @@ void Matrix<ElemType>::MultiplyAndWeightedAdd(ElemType alpha, const Matrix<ElemType>& a, const bool transposeA,
         else if (a.m_matrixType == MatrixType::DENSE && b.m_matrixType == MatrixType::SPARSE && c.m_matrixType == MatrixType::SPARSE) // CPU, DENSE * SPARSE -> SPARSE
         {
-            CPUSparseMatrix<ElemType>::MultiplyAndAdd(alpha, *a.m_CPUMatrix, transposeA, *b.m_CPUSparseMatrix, transposeB, *c.m_CPUSparseMatrix);
+            if (beta != 0 && beta != 1)
+            {
+                NOT_IMPLEMENTED;
+            }
+            else
+            {
+                if (beta == 0)
+                {
+                    c.Reset();
+                }
+                CPUSparseMatrix<ElemType>::MultiplyAndAdd(alpha, *a.m_CPUMatrix, transposeA, *b.m_CPUSparseMatrix, transposeB, *c.m_CPUSparseMatrix);
+            }
             c.SetDataLocation(CPU, SPARSE);
         }
         else
@@ -4578,7 +4589,18 @@ void Matrix<ElemType>::MultiplyAndWeightedAdd(ElemType alpha, const Matrix<ElemType>& a, const bool transposeA,
         else if (a.m_matrixType == MatrixType::DENSE && b.m_matrixType == MatrixType::SPARSE && c.m_matrixType == MatrixType::SPARSE) // GPU, DENSE * SPARSE -> SPARSE
         {
-            GPUSparseMatrix<ElemType>::MultiplyAndAdd(alpha, *a.m_GPUMatrix, transposeA, *b.m_GPUSparseMatrix, transposeB, *c.m_GPUSparseMatrix);
+            if (beta != 0 && beta != 1)
+            {
+                NOT_IMPLEMENTED;
+            }
+            else
+            {
+                if (beta == 0)
+                {
+                    c.Reset();
+                }
+                GPUSparseMatrix<ElemType>::MultiplyAndAdd(alpha, *a.m_GPUMatrix, transposeA, *b.m_GPUSparseMatrix, transposeB, *c.m_GPUSparseMatrix);
+            }
             c.SetDataLocation(GPU, SPARSE);
         }
         else if (a.m_matrixType == MatrixType::SPARSE && b.m_matrixType == MatrixType::SPARSE && c.m_matrixType == MatrixType::SPARSE) // GPU, SPARSE * SPARSE -> SPARSE
diff --git a/Source/Math/NoGPU.cpp b/Source/Math/NoGPU.cpp
index 9eb21cc9a..f1ed021ad 100644
--- a/Source/Math/NoGPU.cpp
+++ b/Source/Math/NoGPU.cpp
@@ -212,6 +212,13 @@ void GPUSparseMatrix<ElemType>::SetMatrixFromCSCFormat(const CPUSPARSE_INDEX_TYP
 {
 }
 
+#if 0
+template <class ElemType>
+void GPUSparseMatrix<ElemType>::SetMatrixFromSBCFormat(const size_t*, const ElemType*, const size_t, const size_t, const size_t)
+{
+}
+#endif
+
 // forward pass from feature to hidden layer
 template <class ElemType>
 void GPUSparseMatrix<ElemType>::MultiplyAndWeightedAdd(ElemType alpha, const GPUMatrix<ElemType>& lhs, const bool transposeA,
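With Matrix.cpp routing both device paths through the same check, the dense * sparse -> sparse entry point now has well-defined semantics for exactly two weights: beta == 0 clears the SBC accumulator first (Reset() drops all blocks) and beta == 1 accumulates into it; any other beta would require scaling the existing blocks in place, which neither backend implements. A condensed sketch of the dispatch (hypothetical names; the real code calls the CPU/GPU class methods directly):

    #include <stdexcept>

    template <class DenseMat, class SparseMat>
    void MultiplyAndWeightedAddSketch(float alpha, const DenseMat& a, bool transposeA,
                                      const SparseMat& b, bool transposeB,
                                      float beta, SparseMat& c)
    {
        if (beta != 0 && beta != 1)
            throw std::logic_error("beta must be 0 or 1");  // NOT_IMPLEMENTED in the patch
        if (beta == 0)
            c.Reset();                                      // c = alpha * op(a) * op(b)
        SparseMat::MultiplyAndAdd(alpha, a, transposeA, b, transposeB, c); // c += alpha * op(a) * op(b)
    }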
diff --git a/Tests/EndToEndTests/UnitTests/MathTests/baseline.txt b/Tests/EndToEndTests/UnitTests/MathTests/baseline.txt
index 677c458cd..3aa31c0dc 100644
--- a/Tests/EndToEndTests/UnitTests/MathTests/baseline.txt
+++ b/Tests/EndToEndTests/UnitTests/MathTests/baseline.txt
@@ -6051,8 +6051,8 @@ Test module "MathTests" has passed with:
     770 assertions out of 770 passed
 
   Test suite "CPUMatrixSuite" has passed with:
-    30 test cases out of 30 passed
-    5228662 assertions out of 5228662 passed
+    31 test cases out of 31 passed
+    5248662 assertions out of 5248662 passed
 
   Test case "CPUMatrixSuite/CPUSparseMatrixColumnSlice" has passed with:
     1 assertion out of 1 passed
@@ -6060,6 +6060,9 @@ Test module "MathTests" has passed with:
   Test case "CPUMatrixSuite/CPUSparseMatrixCopyColumnSliceToDense" has passed with:
     1 assertion out of 1 passed
 
+  Test case "CPUMatrixSuite/CPUSparseMatrixMultiplyAndAdd" has passed with:
+    20000 assertions out of 20000 passed
+
   Test case "CPUMatrixSuite/MatrixMultiplyTest" has passed with:
     1484 assertions out of 1484 passed
 
@@ -6146,7 +6149,7 @@ Test module "MathTests" has passed with:
   Test suite "GPUMatrixSuite" has passed with:
     56 test cases out of 56 passed
-    1263760 assertions out of 1263760 passed
+    1263762 assertions out of 1263762 passed
 
   Test case "GPUMatrixSuite/GPUBlasMultiplyAndWeightedAdd" has passed with:
     1 assertion out of 1 passed
@@ -6290,7 +6293,7 @@ Test module "MathTests" has passed with:
     4 assertions out of 4 passed
 
   Test case "GPUMatrixSuite/MatrixSparseTimesDense" has passed with:
-    1 assertion out of 1 passed
+    3 assertions out of 3 passed
 
   Test case "GPUMatrixSuite/MatrixDenseTimesSparse" has passed with:
     2 assertions out of 2 passed
@@ -6398,13 +6401,13 @@ Test module "MathTests" has passed with:
     2 assertions out of 2 passed
 
   Test suite "QuantizersUnitTests" has passed with:
-    2 test case out of 2 passed
+    2 test cases out of 2 passed
     12 assertions out of 12 passed
 
   Test case "QuantizersUnitTests/FloatToShort" has passed with:
     6 assertions out of 6 passed
-
-  Test case "QuantizersUnitTests/QuantizeZeros" has passed with:
+
+  Test case "QuantizersUnitTests/QuantizeZeros" has passed with:
     6 assertions out of 6 passed
 
   Test suite "QuantizedOperationsUnitTests" has passed with:
@@ -6413,7 +6416,7 @@ Test module "MathTests" has passed with:
   Test case "QuantizedOperationsUnitTests/MultiplyIntToShort" has passed with:
     65 assertions out of 65 passed
-
+
   Test suite "MathTensorTests" has passed with:
     9 test cases out of 9 passed
     6 assertions out of 6 passed
diff --git a/Tests/UnitTests/MathTests/CPUSparseMatrixTests.cpp b/Tests/UnitTests/MathTests/CPUSparseMatrixTests.cpp
index 8f8206bf2..3ae8639a3 100644
--- a/Tests/UnitTests/MathTests/CPUSparseMatrixTests.cpp
+++ b/Tests/UnitTests/MathTests/CPUSparseMatrixTests.cpp
@@ -61,6 +61,61 @@ BOOST_FIXTURE_TEST_CASE(CPUSparseMatrixCopyColumnSliceToDense, RandomSeedFixture
     BOOST_CHECK(dm1.IsEqualTo(dm2, c_epsilonFloatE4));
 }
 
+BOOST_FIXTURE_TEST_CASE(CPUSparseMatrixMultiplyAndAdd, RandomSeedFixture)
+{
+    const size_t m = 100;
+    const size_t n = 50;
+
+    DenseMatrix dm0(m, n);
+    dm0.SetUniformRandomValue(-1, 1, IncrementCounter());
+
+    DenseMatrix dm1(m, n);
+    dm1.SetUniformRandomValue(-300, 1, IncrementCounter());
+    dm1.InplaceTruncateBottom(0);
+
+    SparseMatrix sm1(MatrixFormat::matrixFormatSparseCSC, m, n, 0);
+    foreach_coord(row, col, dm1)
+    {
+        if (dm1(row, col) != 0)
+        {
+            sm1.SetValue(row, col, dm1(row, col));
+        }
+    }
+
+    DenseMatrix dm2(m, n);
+    dm2.SetUniformRandomValue(-200, 1, IncrementCounter());
+    dm2.InplaceTruncateBottom(0);
+
+    SparseMatrix sm2(MatrixFormat::matrixFormatSparseCSC, m, n, 0);
+    foreach_coord(row, col, dm2)
+    {
+        if (dm2(row, col) != 0)
+        {
+            sm2.SetValue(row, col, dm2(row, col));
+        }
+    }
+
+    // generate SparseBlockCol matrix
+    DenseMatrix dmMul(m, m);
+    DenseMatrix::MultiplyAndAdd(dm0, false, dm1, true, dmMul);
+
+    SparseMatrix smMul(MatrixFormat::matrixFormatSparseBlockCol, m, m, 0);
+    SparseMatrix::MultiplyAndAdd(1, dm0, false, sm1, true, smMul);
+
+    foreach_coord(row, col, dmMul)
+    {
+        BOOST_CHECK(abs(smMul(row, col) - dmMul(row, col)) < c_epsilonFloatE4);
+    }
+
+    SparseMatrix::MultiplyAndAdd(1, dm0, false, sm2, true, smMul);
+    DenseMatrix::MultiplyAndAdd(dm0, false, dm2, true, dmMul);
+
+    foreach_coord(row, col, dmMul)
+    {
+        BOOST_CHECK(abs(smMul(row, col) - dmMul(row, col)) < c_epsilonFloatE4);
+    }
+}
+
 BOOST_AUTO_TEST_SUITE_END()
 
 } } } }
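The baseline numbers are consistent with the new test: CPUSparseMatrixMultiplyAndAdd checks a 100 x 100 result element-wise in two passes, i.e. 2 * 100 * 100 = 20000 assertions, which matches both the new test-case line and the CPUMatrixSuite totals growing from 5228662 to 5248662 (one extra case, 20000 extra assertions). The remaining QuantizersUnitTests hunks fix a singular/plural ("test case" to "test cases") and strip trailing whitespace, which is why several - / + line pairs above look identical.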
diff --git a/Tests/UnitTests/MathTests/MatrixSparseDenseInteractionsTests.cpp b/Tests/UnitTests/MathTests/MatrixSparseDenseInteractionsTests.cpp
index fab6b5de0..6d80283cf 100644
--- a/Tests/UnitTests/MathTests/MatrixSparseDenseInteractionsTests.cpp
+++ b/Tests/UnitTests/MathTests/MatrixSparseDenseInteractionsTests.cpp
@@ -44,7 +44,7 @@ BOOST_FIXTURE_TEST_CASE(MatrixSparseTimesDense, RandomSeedFixture)
 {
     // DENSE
     Matrix<float> mAdense(c_deviceIdZero);
-    mAdense.AssignTruncateBottomOf(Matrix<float>::RandomUniform(dim1, dim2, c_deviceIdZero, -3.0f, 0.1f, IncrementCounter()), 0);
+    mAdense.AssignTruncateBottomOf(Matrix<float>::RandomUniform(dim1, dim2, c_deviceIdZero, -300.0f, 0.1f, IncrementCounter()), 0);
 
     // MATRIX mAsparse becomes sparse
     Matrix<float> mAsparse(mAdense.DeepClone());
@@ -66,6 +66,44 @@ BOOST_FIXTURE_TEST_CASE(MatrixSparseTimesDense, RandomSeedFixture)
     Matrix<float>::MultiplyAndWeightedAdd(alpha, mAsparse, transposeA, mB, transposeB, beta, mD);
     BOOST_CHECK(mD.IsEqualTo(mC, c_epsilonFloatE4));
+
+    // SPARSE * DENSE -> SPARSE
+    beta = 1; // note that dense only allow beta == 1, while sparse CPU support arbitrary beta
+    transposeA = false;
+    transposeB = true; // only support !transposeA && transposeB for Dense * CSC -> SBC
+
+    Matrix<float> mA1sparseCSC(mAdense.DeepClone());
+    mA1sparseCSC.SwitchToMatrixType(MatrixType::SPARSE, matrixFormatSparseCSC, true);
+
+    Matrix<float> mA2dense(c_deviceIdZero);
+    mA2dense.AssignTruncateBottomOf(Matrix<float>::RandomUniform(dim1, dim2, c_deviceIdZero, -300.0f, 0.1f, IncrementCounter()), 0);
+
+    Matrix<float> mA2sparseCSC(mA2dense.DeepClone());
+    mA2sparseCSC.SwitchToMatrixType(MatrixType::SPARSE, matrixFormatSparseCSC, true);
+
+    // dense for comparison
+    mC.Resize(dim1, dim1);
+    mC.SetValue(0.0f);
+    Matrix<float>::MultiplyAndAdd(mAdense, transposeA, mAdense, transposeB, mC);
+    Matrix<float>::MultiplyAndWeightedAdd(alpha, mAdense, transposeA, mA2dense, transposeB, beta, mC);
+
+    // make sure (dense * sparse -> dense) == (dense * dense -> dense)
+    mD.Resize(dim1, dim1);
+    mD.SetValue(0.0f);
+    Matrix<float>::MultiplyAndAdd(mAdense, transposeA, mA1sparseCSC, transposeB, mD);
+    Matrix<float>::MultiplyAndWeightedAdd(alpha, mAdense, transposeA, mA2sparseCSC, transposeB, beta, mD);
+
+    BOOST_CHECK(mD.IsEqualTo(mC, c_epsilonFloatE4));
+
+    // test on sparse
+    mD.SwitchToMatrixType(MatrixType::SPARSE, matrixFormatSparseBlockCol, false);
+    Matrix<float>::MultiplyAndAdd(mAdense, transposeA, mA1sparseCSC, transposeB, mD);
+    Matrix<float>::MultiplyAndWeightedAdd(alpha, mAdense, transposeA, mA2sparseCSC, transposeB, beta, mD);
+
+    // copy mD to dense and compare
+    Matrix<float> mE = Matrix<float>::Zeros(dim1, dim1, c_deviceIdZero);
+    Matrix<float>::ScaleAndAdd(1, mD, mE);
+    BOOST_CHECK(mE.IsEqualTo(mC, c_epsilonFloatE4));
 }
 
 BOOST_FIXTURE_TEST_CASE(MatrixDenseTimesSparse, RandomSeedFixture)
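Likewise, the two BOOST_CHECKs added to MatrixSparseTimesDense account for that case's baseline going from 1 to 3 assertions and for the GPUMatrixSuite total growing by exactly 2 (1263760 to 1263762). Widening the RandomUniform range from -3.0f to -300.0f before truncation makes the inputs far sparser (roughly 0.03% instead of 3% of entries survive the truncate-at-zero), which exercises the block-allocation path rather than producing an almost-dense SBC result.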