Fix sparse matrix MultiplyAndAdd bug where old values may not be kept
This commit is contained in:
Parent: e7edfa72e8
Commit: c12979b245
@@ -244,7 +244,7 @@ void CPUSparseMatrix<ElemType>::SetValue(const CPUSparseMatrix<ElemType>& v)
 {
     SetFormat(v.GetFormat());

-    RequireSizeAndAllocate(v.GetNumRows(), v.GetNumCols(), v.NzSize());
+    RequireSizeAndAllocate(v.GetNumRows(), v.GetNumCols(), v.NzCount()); // TODO: rename to *Bytes/*Count instead of vague *Size if possible
     let nz = v.NzCount();

     auto matrixFormat = v.GetFormat();
@@ -262,7 +262,8 @@ void CPUSparseMatrix<ElemType>::SetValue(const CPUSparseMatrix<ElemType>& v)
     }
     else
     {
-        memcpy(GetBlockIds(), v.GetBlockIds(), v.GetBlockSize());
+        memcpy(GetBlockIds(), v.GetBlockIds(), v.GetBlockSize() * sizeof(size_t)); // TODO: change block id from size_t to CPUSPARSE_INDEX_TYPE, and rename BlockSize to BlockCount
+        SetBlockSize(v.GetBlockSize());
     }
 }
 if (v.m_sliceViewOffset > 0)
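Both SetValue changes above fix the same class of mistake: mixing up byte counts and element counts. RequireSizeAndAllocate expects a nonzero-element count, but NzSize() returns bytes; conversely, memcpy expects bytes, but GetBlockSize() counts blocks, so the old call copied only GetBlockSize() bytes of block ids (one eighth of them on a 64-bit build) and never carried the block count over to the target. A minimal sketch of the memcpy pitfall, with made-up buffers rather than CNTK code:

    #include <cstdio>
    #include <cstring>

    int main()
    {
        size_t src[4] = {0x1122334455667788ULL, 42, 99, 3}; // e.g. block ids
        size_t dst[4] = {0, 0, 0, 0};

        memcpy(dst, src, 4);                  // buggy: 4 *bytes*, half of src[0] on x64
        printf("%zx\n", dst[0]);              // 55667788 on a little-endian build

        memcpy(dst, src, 4 * sizeof(size_t)); // fixed: 4 *elements*
        printf("%zx %zu %zu %zu\n", dst[0], dst[1], dst[2], dst[3]);
        return 0;
    }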
@@ -644,6 +645,22 @@ void CPUSparseMatrix<ElemType>::SetMatrixFromCSCFormat(const CPUSPARSE_INDEX_TYP
     memcpy(NzValues(), h_Val, sizeof(ElemType)*nz);
 }

+#if 0 // add it back with test
+template <class ElemType>
+void CPUSparseMatrix<ElemType>::SetMatrixFromSBCFormat(const size_t* blockIds, const ElemType* val, const size_t numBlocks, const size_t numRows, const size_t numCols)
+{
+    if (!OwnBuffer())
+        LogicError("Cannot modify since the buffer is managed externally.");
+
+    SetFormat(matrixFormatSparseBlockCol);
+    Resize(numRows, numCols, numBlocks * numRows);
+    SetBlockSize(numBlocks);
+
+    memcpy(GetBlockIds(), blockIds, sizeof(size_t)*(numBlocks));
+    memcpy(Data(), val, sizeof(ElemType)*numBlocks*numRows);
+}
+#endif
+
 template <class ElemType>
 ElemType* CPUSparseMatrix<ElemType>::Data() const
 {
@@ -959,8 +976,6 @@ void CPUSparseMatrix<ElemType>::MultiplyAndAdd(ElemType alpha, const CPUMatrix<E
         InvalidArgument("CPUSparseMatrix::MultiplyAndAdd: The inner dimensions of a (= %lu) and b (= %lu) don't match.", k, l);
     }

-    c.Reset();
-
     if (!transposeA && !transposeB)
     {
         NOT_IMPLEMENTED;
@@ -972,50 +987,56 @@ void CPUSparseMatrix<ElemType>::MultiplyAndAdd(ElemType alpha, const CPUMatrix<E

             // allocate enough memory
             c.SetFormat(matrixFormatSparseBlockCol);
-            c.RequireSizeAndAllocate(m, n, m * min(n, rhs.NzCount()), true, false);
+            size_t blockSizePrev = c.GetBlockSize();
+
+            if (blockSizePrev == 0)
+            {
+                c.RequireSizeAndAllocate(m, n, 0, true); // allocate for blockIds
+            }

-            map<size_t, size_t> w2Id;
-            for (size_t j = 0; j < rhs.GetNumCols(); j++)
-            { // j ranges over batches
-                size_t start = rhs.SecondaryIndexLocation()[j];
-                size_t end = rhs.SecondaryIndexLocation()[j + 1];
+            map<size_t, size_t> col2BlockId;
+            for (size_t blockId = 0; blockId < blockSizePrev; blockId++)
+            {
+                col2BlockId[c.GetBlockIds()[blockId]] = blockId;
+            }
+
+            size_t blockSizeCurr = blockSizePrev;
+            for (size_t rhsNz = 0; rhsNz < rhs.NzCount(); rhsNz++)
+            {
+                size_t resultCol = rhs.MajorIndexLocation()[rhsNz];
+                if (col2BlockId.find(resultCol) == col2BlockId.end())
+                {
+                    col2BlockId[resultCol] = blockSizeCurr;
+                    c.GetBlockIds()[blockSizeCurr] = resultCol;
+                    blockSizeCurr++;
+                }
+            }
+
+            if (blockSizeCurr > blockSizePrev)
+            {
+                c.RequireSizeAndAllocate(m, n, m * blockSizeCurr, true, true);
+                c.SetBlockSize(blockSizeCurr);
+                memset(c.Data() + m * blockSizePrev, 0, sizeof(ElemType) * m * (blockSizeCurr - blockSizePrev));
+            }
+
+            for (size_t rhsCol = 0; rhsCol < rhs.GetNumCols(); rhsCol++)
+            {
+                size_t start = rhs.SecondaryIndexLocation()[rhsCol];
+                size_t end = rhs.SecondaryIndexLocation()[rhsCol + 1];

                 for (size_t p = start; p < end; p++)
                 {
-                    size_t i = rhs.MajorIndexLocation()[p]; // i ranges over words
-                    ElemType val = rhs.Buffer()[p]; // 1 for(i, j)
+                    size_t rhsRow = rhs.MajorIndexLocation()[p];
+                    ElemType val = rhs.Buffer()[p];

-                    bool first = true;
-                    if (w2Id.find(i) == w2Id.end())
-                    {
-                        size_t id = w2Id.size();
-                        w2Id[i] = id;
-                        c.GetBlockIds()[c.GetBlockSize()] = i;
-                        c.SetBlockSize(c.GetBlockSize() + 1);
-                    }
-                    else
-                    {
-                        first = false;
-                    }
-                    size_t pos = w2Id[i] * lhs.GetNumRows();
-                    for (size_t h = 0; h < lhs.GetNumRows(); h++)
-                    { // h range over hidden layer
-                        if (first == true)
-                        {
-                            c.Buffer()[pos] = alpha * lhs(h, j) * val;
-                        }
-                        else
-                        {
-                            c.Buffer()[pos] += alpha * lhs(h, j) * val;
-                        }
-                        pos++;
+                    ElemType* results = c.Buffer() + col2BlockId[rhsRow] * m;
+#pragma omp parallel for
+                    for (int lhsRow = 0; lhsRow < (int)m; lhsRow++)
+                    {
+                        results[lhsRow] += alpha * lhs((size_t)lhsRow, rhsCol) * val;
                     }
                 }
             }
-            if (c.GetBlockSize() * m > c.GetSizeAllocated())
-            {
-                LogicError("Sparse matrix is unexpectedly out of range.");
-            }
         }
         else if (transposeA && !transposeB)
         {

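The CPU rewrite above is the heart of the commit. The old loop called c.Reset() up front and rebuilt the block list from scratch, using the first-touch flag `first` to pick `=` over `+=`, so whatever c already held was discarded: MultiplyAndAdd behaved as multiply-and-overwrite. The new code seeds col2BlockId with c's existing blocks, appends and zero-initializes only blocks that are new in this call, and always accumulates. The contract it must now satisfy is the dense reference below, a sketch assuming plain column-major buffers (not CNTK's actual API):

    #include <vector>
    #include <cstddef>

    // c (m x n) += alpha * lhs * rhs^T, where lhs is m x k and rhs is n x k,
    // all column-major; repeated calls keep accumulating into c.
    void MultiplyAndAddRef(float alpha, const std::vector<float>& lhs, size_t m, size_t k,
                           const std::vector<float>& rhs, size_t n, std::vector<float>& c)
    {
        for (size_t r = 0; r < n; r++)          // result column == rhs row
            for (size_t i = 0; i < m; i++)      // result row == lhs row
                for (size_t j = 0; j < k; j++)  // shared column of lhs and rhs
                    c[r * m + i] += alpha * lhs[j * m + i] * rhs[j * n + r];
    }

In the sparse version only the columns of c touched by a nonzero of rhs exist as blocks, which is why the block bookkeeping has to survive across calls.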
@@ -201,10 +201,19 @@ public:

             return 0;
         }
-        else
+        else if (GetFormat() == MatrixFormat::matrixFormatSparseBlockCol)
         {
-            NOT_IMPLEMENTED;
+            for (size_t blockId = 0; blockId < GetBlockSize(); blockId++)
+            {
+                size_t blockCol = GetBlockIds()[blockId] - GetBlockIdShift();
+                if (blockCol == col)
+                {
+                    return ((ElemType*)Buffer())[blockId * GetNumRows() + row];
+                }
+            }
+            return 0;
         }
+        NOT_IMPLEMENTED;
     }

 public:
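The accessor added above is what lets the new unit tests read single elements out of a SparseBlockCol matrix. In that format only nonzero columns are stored: each gets a dense block of GetNumRows() values, and GetBlockIds() maps block id to column index (shifted by GetBlockIdShift() for column slices). A self-contained toy lookup over hypothetical data, not the CNTK class:

    #include <cstdio>
    #include <cstddef>

    int main()
    {
        // SparseBlockCol sketch: a 3 x 6 matrix whose only nonzero columns are 2 and 5.
        // values holds the two columns back to back; blockIds maps blockId -> column.
        const size_t numRows = 3;
        const size_t blockIds[] = {2, 5};              // blockId 0 is column 2, blockId 1 is column 5
        const float  values[]   = {1, 2, 3,  4, 5, 6}; // block 0 | block 1, column-major

        size_t row = 1, col = 5;                       // look up element (1, 5)
        for (size_t blockId = 0; blockId < 2; blockId++)
            if (blockIds[blockId] == col)
                printf("(%zu,%zu) = %g\n", row, col, values[blockId * numRows + row]); // 5
        return 0;                                      // columns absent from blockIds read as 0
    }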
@@ -258,6 +267,11 @@ public:
         BaseMatrix<ElemType>::SetBlockSize(newBlockSize);
     }

+    size_t GetBlockSize() const
+    {
+        return BaseMatrix<ElemType>::GetBlockSize();
+    }
+
     size_t* BlockIdsLocation() const
     {
         if ((GetFormat() != matrixFormatSparseBlockCol) && (GetFormat() != matrixFormatSparseBlockRow))
@@ -325,7 +339,6 @@ public:
     {
         return (GetFormat() & matrixFormatRowMajor) ? MajorIndexSize() : SecondaryIndexSize();
     } // actual number of bytes in use
-
 };

 typedef CPUSparseMatrix<float> CPUSingleSparseMatrix;

@@ -3008,6 +3008,10 @@ __global__ void _reshape(
     newColumnIndex[newNumCols] = oldColumnIndex[oldNumCols]; // set end pointer
 }

+// special markers in BlockId2ColOrRow()/ColOrRow2BlockId()
+static const GPUSPARSE_INDEX_TYPE Id_NotAssigned = -1;
+static const GPUSPARSE_INDEX_TYPE Id_Pending = INT_MAX;
+
 //called before _determineBlockIds and _denseMulSparseCSCTransposeToSparseBlockCol to determine which columns have values and
 //what's the mapping from the column id in the resulted SparseBlockCol format to the column id in the dense format
 //input: rowIndexes: the row indexes of the CSC sparse matrix to be multiplied with
@@ -3015,13 +3019,14 @@ __global__ void _reshape(
 //nnz: number of nonzero values or the size of rowIndexes;
 template <class ElemType>
 __global__ void _findColsWithValues(
-    const GPUSPARSE_INDEX_TYPE* rowIndexes, GPUSPARSE_INDEX_TYPE* blockIds, const size_t nnz)
+    const GPUSPARSE_INDEX_TYPE* rowIndexes, GPUSPARSE_INDEX_TYPE* col2BlockIds, const size_t nnz)
 {
-    const size_t index = blockIdx.x * blockDim.x + threadIdx.x;
-    if (index >= nnz)
+    const size_t nzIndex = blockIdx.x * blockDim.x + threadIdx.x;
+    if (nzIndex >= nnz)
         return;

-    blockIds[rowIndexes[index]] = 1; // this row has value.
+    if (col2BlockIds[rowIndexes[nzIndex]] == Id_NotAssigned)
+        col2BlockIds[rowIndexes[nzIndex]] = Id_Pending; // this row has value.
 }

 //called before _denseMulSparseCSCTransposeToSparseBlockCol and after _findColsWithValues to determine which columns have values and
@@ -3030,26 +3035,21 @@ __global__ void _findColsWithValues(
 //blockId2Col: the blockID to column id mapping in the resulting matrix;
 //col2BlockId: the col2BlockId to blockID mapping in the resulting matrix;
 //numCols: number of columns in the resulting matrix or the size of blockIDs
-//blockSize: return the blockSize with values, *blockSize must be zero before passed in.
+//blockSize: return the blockSize with values
 template <class ElemType>
 __global__ void _determineBlockIds(
-    GPUSPARSE_INDEX_TYPE* blockId2Col, GPUSPARSE_INDEX_TYPE* col2BlockId, const size_t numCols, size_t* blockSize)
+    GPUSPARSE_INDEX_TYPE* blockId2Col, GPUSPARSE_INDEX_TYPE* col2BlockId, size_t numCols, size_t* blockSize)
 {
-    const size_t index = blockIdx.x * blockDim.x + threadIdx.x;
-    if (index >= numCols)
+    const size_t col = blockIdx.x * blockDim.x + threadIdx.x;
+    if (col >= numCols)
         return;

-    size_t blockIndex = numCols;
-    if (blockId2Col[index] > 0)
+    if (col2BlockId[col] == Id_Pending)
     {
-        blockIndex = atomicAdd((unsigned int*) blockSize, (unsigned int) 1);
-        col2BlockId[index] = blockIndex;
+        GPUSPARSE_INDEX_TYPE blockIndex = atomicAdd((unsigned int*)blockSize, (unsigned int)1);
+        col2BlockId[col] = blockIndex;
+        blockId2Col[blockIndex] = col;
     }
-
-    __syncthreads();
-
-    if (blockIndex < numCols)
-        blockId2Col[blockIndex] = index;
 }

 // backward pass from hidden layer to feature weight

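Together the two kernels implement a mark-then-compact protocol that, unlike the old flag-with-1 scheme, preserves block assignments from earlier calls: _findColsWithValues flags only still-unassigned columns that received values as Id_Pending, and _determineBlockIds atomically appends just the pending columns after the pre-existing blocks (the counter is seeded with the previous block count by the host code further down). A sequential C++ emulation of that protocol, hypothetical data standing in for the CUDA kernels:

    #include <vector>
    #include <climits>
    #include <cstdio>

    const int Id_NotAssigned = -1;
    const int Id_Pending = INT_MAX;

    int main()
    {
        // col2BlockId already holds an assignment from a previous call: column 4 -> block 0.
        std::vector<int> col2BlockId = {Id_NotAssigned, Id_NotAssigned, Id_NotAssigned,
                                        Id_NotAssigned, 0, Id_NotAssigned};
        std::vector<int> blockId2Col = {4, -1, -1, -1, -1, -1};
        int blockSize = 1; // seeded with the previous block count, not 0

        // pass 1 (_findColsWithValues): mark columns hit by new nonzeros, keep existing ids
        int newNzCols[] = {1, 4, 5};
        for (int col : newNzCols)
            if (col2BlockId[col] == Id_NotAssigned)
                col2BlockId[col] = Id_Pending;

        // pass 2 (_determineBlockIds): append pending columns after the old blocks
        // (on the GPU the increment is an atomicAdd on a device counter)
        for (int col = 0; col < (int)col2BlockId.size(); col++)
            if (col2BlockId[col] == Id_Pending)
            {
                int blockIndex = blockSize++;
                col2BlockId[col] = blockIndex;
                blockId2Col[blockIndex] = col;
            }

        printf("blockSize = %d\n", blockSize); // 3: column 4 kept block 0; 1 and 5 got 1 and 2
        return 0;
    }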
@@ -947,6 +947,42 @@ void GPUSparseMatrix<ElemType>::SetMatrixFromCSCFormat(const CPUSPARSE_INDEX_TYP
     }
 }

+#if 0 // add it back with test
+template <class ElemType>
+void GPUSparseMatrix<ElemType>::SetMatrixFromSBCFormat(const size_t* blockIds, const ElemType* val, const size_t numBlocks, const size_t numRows, const size_t numCols)
+{
+    VerifyWritable(__FUNCTION__);
+
+    if (blockIds == nullptr || val == nullptr)
+        LogicError("SetMatrixFromSBCFormat: nullptr passed in.");
+
+    SetFormat(matrixFormatSparseBlockCol);
+    SetBlockSize(numBlocks);
+
+    if (numBlocks == 0) return; // ====>
+
+    size_t nz = numBlocks * numRows;
+    RequireSizeAndAllocate(numRows, numCols, nz, true, false);
+
+    static std::vector<GPUSPARSE_INDEX_TYPE> gpuBlockId2Col(numBlocks);
+    static std::vector<GPUSPARSE_INDEX_TYPE> gpuCol2BlockId(numCols);
+
+    std::fill(gpuBlockId2Col.begin(), gpuBlockId2Col.end(), Id_NotAssigned);
+    std::fill(gpuCol2BlockId.begin(), gpuCol2BlockId.end(), Id_NotAssigned);
+
+#pragma omp parallel for
+    for (int i = 0; i < numBlocks; ++i)
+    {
+        gpuBlockId2Col[i] = (GPUSPARSE_INDEX_TYPE)blockIds[i];
+        gpuCol2BlockId[blockIds[i]] = i;
+    }
+
+    CUDA_CALL(cudaMemcpy(Data(), val, nz * sizeof(ElemType), cudaMemcpyHostToDevice));
+    CUDA_CALL(cudaMemcpy(BlockId2ColOrRow(), &gpuBlockId2Col[0], numBlocks * sizeof(GPUSPARSE_INDEX_TYPE), cudaMemcpyHostToDevice));
+    CUDA_CALL(cudaMemcpy(ColOrRow2BlockId(), &gpuCol2BlockId[0], numBlocks * sizeof(GPUSPARSE_INDEX_TYPE), cudaMemcpyHostToDevice));
+}
+#endif
+
 // this function will allocate memory while the caller needs to release it
 template <class ElemType>
 void GPUSparseMatrix<ElemType>::GetMatrixFromCSCFormat(GPUSPARSE_INDEX_TYPE*& h_CSCCol, GPUSPARSE_INDEX_TYPE*& h_Row, ElemType*& h_Val, size_t& numElemAllocated, size_t& nz, size_t& numRows, size_t& numCols) const
@@ -1223,68 +1259,51 @@ void GPUSparseMatrix<ElemType>::MultiplyAndAdd(ElemType alpha, const GPUMatrix<E

-        // based on the size of m_nz in rhs and numCols in the resulted matrix we use different approaches
         size_t rhs_nz = rhs.NzCount();
-        if (n * 10 < GridDim::maxThreadsPerBlock * rhs_nz)
+
+        size_t blockSizePrev = c.GetBlockSize();
+        if (blockSizePrev == 0)
         {
-            c.RequireSizeAndAllocate(m, n, 1, true, false); // reserve memory for BlockId2ColOrRow() and ColOrRow2BlockId()
+            c.Resize(m, n, 0);
+            CUDA_CALL(cudaMemset(c.ColOrRow2BlockId(), Id_NotAssigned, sizeof(GPUSPARSE_INDEX_TYPE) * (n)));
+            CUDA_CALL(cudaMemset(c.BlockId2ColOrRow(), Id_NotAssigned, sizeof(GPUSPARSE_INDEX_TYPE) * (n)));
+        }

-            size_t* blockSize = TracingGPUMemoryAllocator::Allocate<size_t>(lhs.GetComputeDeviceId(), 1);
-            CUDA_CALL(cudaMemset(blockSize, 0, sizeof(size_t)));
+        size_t* blockSize = TracingGPUMemoryAllocator::Allocate<size_t>(lhs.GetComputeDeviceId(), 1);
+        CUDA_CALL(cudaMemcpy(blockSize, &blockSizePrev, sizeof(size_t), cudaMemcpyHostToDevice));

-            CUDA_CALL(cudaMemset(c.BlockId2ColOrRow(), 0, sizeof(GPUSPARSE_INDEX_TYPE) * (n)));
-
-            blocksPerGrid = (int) ceil(((double) rhs_nz) / GridDim::maxThreadsPerBlock);
-            _findColsWithValues<ElemType><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, t_stream>>>(
-                rhs.RowLocation(), c.BlockId2ColOrRow(), rhs_nz);
+        blocksPerGrid = (int) ceil(((double) rhs_nz) / GridDim::maxThreadsPerBlock);
+        _findColsWithValues<ElemType><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, t_stream>>>(
+            rhs.RowLocation(), c.ColOrRow2BlockId(), rhs_nz);

-            blocksPerGrid = (int) ceil(((double) n) / GridDim::maxThreadsPerBlock);
-            _determineBlockIds<ElemType><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, t_stream>>>(
-                c.BlockId2ColOrRow(), c.ColOrRow2BlockId(), n, blockSize);
+        blocksPerGrid = (int) ceil(((double) n) / GridDim::maxThreadsPerBlock);
+        _determineBlockIds<ElemType><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, t_stream>>>(
+            c.BlockId2ColOrRow(), c.ColOrRow2BlockId(), n, blockSize);

-            size_t block = c.GetBlockSize();
-            CUDA_CALL(cudaMemcpy(&block, blockSize, sizeof(size_t), cudaMemcpyDeviceToHost));
-            TracingGPUMemoryAllocator::Free<size_t>(lhs.GetComputeDeviceId(), blockSize);
-            c.SetBlockSize(block);
+        size_t blockSizeCurr;
+        CUDA_CALL(cudaMemcpy(&blockSizeCurr, blockSize, sizeof(size_t), cudaMemcpyDeviceToHost));
+        TracingGPUMemoryAllocator::Free<size_t>(lhs.GetComputeDeviceId(), blockSize);
+        c.SetBlockSize(blockSizeCurr);

-            size_t nnz = m * c.GetBlockSize();
-            c.RequireSizeAndAllocate(m, n, nnz, true, true); // we need to keep the col2blockid and blockid2col info when resizing.
-            CUDA_CALL(cudaMemset(c.Data(), 0, sizeof(ElemType) * (c.GetSizeAllocated())));
-
-            LONG64 N = (LONG64) lhs.GetNumElements(); // here we process for each row in lhs and each column in rhs (==columns in lhs)
-            blocksPerGrid = (int) ceil(((double) N) / GridDim::maxThreadsPerBlock);
-            _denseMulSparseCSCTransposeToSparseBlockCol2<ElemType><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, t_stream>>>(
-                alpha,
-                lhs.Data(),
-                m,
-                l,
-                rhs.Data(),
-                rhs.RowLocation(),
-                rhs.ColLocation(),
-                c.ColOrRow2BlockId(),
-                c.Data());
-        }
-        else
+        if (blockSizeCurr > blockSizePrev)
         {
-            c.SetBlockSize(rhs.IdentifyRowsWithValues());
-            size_t nnz = m * c.GetBlockSize();
-            c.RequireSizeAndAllocate(m, n, nnz, true, false);
-            CUDA_CALL(cudaMemset(c.Data(), 0, sizeof(ElemType) * (c.GetSizeAllocated())));
-            CUDA_CALL(cudaMemset(c.BlockId2ColOrRow(), 0, sizeof(GPUSPARSE_INDEX_TYPE) * (c.GetBlockSize())));
-
-            LONG64 N = (LONG64) lhs.GetNumElements(); // here we process for each row in lhs and each column in rhs (==columns in lhs)
-            blocksPerGrid = (int) ceil(((double) N) / GridDim::maxThreadsPerBlock);
-            _denseMulSparseCSCTransposeToSparseBlockCol<ElemType><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, t_stream>>>(
-                alpha,
-                lhs.Data(),
-                m,
-                l,
-                rhs.Data(),
-                rhs.RowLocation(),
-                rhs.ColLocation(),
-                rhs.GetTempDeviceBuffer(),
-                c.Data(),
-                c.BlockId2ColOrRow());
+            // zero initialize new blocks
+            size_t nnz = m * blockSizeCurr;
+            c.RequireSizeAndAllocate(m, n, nnz, true, true); // we need to keep the col2blockid and blockid2col info when resizing.
+            CUDA_CALL(cudaMemset(c.Data() + m * blockSizePrev, 0, sizeof(ElemType) * m * (blockSizeCurr - blockSizePrev)));
         }
+
+        LONG64 N = (LONG64) lhs.GetNumElements(); // here we process for each row in lhs and each column in rhs (==columns in lhs)
+        blocksPerGrid = (int) ceil(((double) N) / GridDim::maxThreadsPerBlock);
+        _denseMulSparseCSCTransposeToSparseBlockCol2<ElemType><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, t_stream>>>(
+            alpha,
+            lhs.Data(),
+            m,
+            l,
+            rhs.Data(),
+            rhs.RowLocation(),
+            rhs.ColLocation(),
+            c.ColOrRow2BlockId(),
+            c.Data());
     }
     else if (transposeA && !transposeB)
     {

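On both devices the reallocation obeys the same invariant, visible in the CPU memset and the cudaMemset above: grow the value buffer only when new blocks appeared, keep the previously accumulated blocks intact, and zero nothing but the new tail (the old code zeroed the entire buffer, wiping prior results). A toy illustration with a std::vector standing in for the matrix buffer:

    #include <vector>
    #include <algorithm>
    #include <cstdio>

    int main()
    {
        const size_t m = 3;                   // rows per block
        std::vector<double> data(m * 1, 1.0); // one existing block, already holding results
        size_t blockSizePrev = 1, blockSizeCurr = 3;

        if (blockSizeCurr > blockSizePrev)
        {
            data.resize(m * blockSizeCurr);   // like RequireSizeAndAllocate(..., keepExistingValues = true)
            // zero only the new tail; the old block keeps its accumulated values
            std::fill(data.begin() + m * blockSizePrev, data.end(), 0.0);
        }
        printf("%g %g\n", data[0], data[m]);  // 1 0
        return 0;
    }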
@@ -4544,7 +4544,18 @@ void Matrix<ElemType>::MultiplyAndWeightedAdd(ElemType alpha, const Matrix<ElemT
             }
             else if (c.GetMatrixType() == MatrixType::SPARSE) // CPU, DENSE * SPARSE -> SPARSE
             {
-                CPUSparseMatrix<ElemType>::MultiplyAndAdd(alpha, *a.m_CPUMatrix, transposeA, *b.m_CPUSparseMatrix, transposeB, *c.m_CPUSparseMatrix);
+                if (beta != 0 && beta != 1)
+                {
+                    NOT_IMPLEMENTED;
+                }
+                else
+                {
+                    if (beta == 0)
+                    {
+                        c.Reset();
+                    }
+                    CPUSparseMatrix<ElemType>::MultiplyAndAdd(alpha, *a.m_CPUMatrix, transposeA, *b.m_CPUSparseMatrix, transposeB, *c.m_CPUSparseMatrix);
+                }
                 c.SetDataLocation(CPU, SPARSE);
             }
             else
@@ -4578,7 +4589,18 @@ void Matrix<ElemType>::MultiplyAndWeightedAdd(ElemType alpha, const Matrix<ElemT
             }
             else if (a.m_matrixType == MatrixType::DENSE && b.m_matrixType == MatrixType::SPARSE && c.m_matrixType == MatrixType::SPARSE) // GPU, DENSE * SPARSE -> SPARSE
             {
-                GPUSparseMatrix<ElemType>::MultiplyAndAdd(alpha, *a.m_GPUMatrix, transposeA, *b.m_GPUSparseMatrix, transposeB, *c.m_GPUSparseMatrix);
+                if (beta != 0 && beta != 1)
+                {
+                    NOT_IMPLEMENTED;
+                }
+                else
+                {
+                    if (beta == 0)
+                    {
+                        c.Reset();
+                    }
+                    GPUSparseMatrix<ElemType>::MultiplyAndAdd(alpha, *a.m_GPUMatrix, transposeA, *b.m_GPUSparseMatrix, transposeB, *c.m_GPUSparseMatrix);
+                }
                 c.SetDataLocation(GPU, SPARSE);
             }
             else if (a.m_matrixType == MatrixType::SPARSE && b.m_matrixType == MatrixType::SPARSE && c.m_matrixType == MatrixType::SPARSE) // GPU, SPARSE * SPARSE -> SPARSE
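With MultiplyAndAdd now strictly accumulating, the beta handling moves to the caller: in MultiplyAndWeightedAdd (c = alpha*a*b + beta*c) a sparse target with beta == 0 becomes an explicit Reset() followed by accumulation, beta == 1 is plain accumulation, and any other beta stays NOT_IMPLEMENTED. A compact sketch of that dispatch with stub types; the real code calls the CPU/GPU sparse matrix members shown in the hunks above:

    #include <stdexcept>

    struct Sparse { void Reset() {} };
    struct Dense {};

    void MultiplyAndAdd(float, const Dense&, const Sparse&, Sparse&) {} // accumulating stub

    // c = alpha * a * b + beta * c, with a sparse c: only beta in {0, 1} is supported.
    void MultiplyAndWeightedAddSketch(float alpha, const Dense& a, const Sparse& b, float beta, Sparse& c)
    {
        if (beta != 0 && beta != 1)
            throw std::logic_error("NOT_IMPLEMENTED");
        if (beta == 0)
            c.Reset();                  // forget old contents: c = alpha * a * b
        MultiplyAndAdd(alpha, a, b, c); // accumulate: c += alpha * a * b
    }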
@@ -212,6 +212,13 @@ void GPUSparseMatrix<ElemType>::SetMatrixFromCSCFormat(const CPUSPARSE_INDEX_TYP
 {
 }

+#if 0
+template <class ElemType>
+void GPUSparseMatrix<ElemType>::SetMatrixFromSBCFormat(const size_t*, const ElemType*, const size_t, const size_t, const size_t)
+{
+}
+#endif
+
 // forward pass from feature to hidden layer
 template <class ElemType>
 void GPUSparseMatrix<ElemType>::MultiplyAndWeightedAdd(ElemType alpha, const GPUMatrix<ElemType>& lhs, const bool transposeA,

@@ -6051,8 +6051,8 @@ Test module "MathTests" has passed with:
     770 assertions out of 770 passed

 Test suite "CPUMatrixSuite" has passed with:
-    30 test cases out of 30 passed
-    5228662 assertions out of 5228662 passed
+    31 test cases out of 31 passed
+    5248662 assertions out of 5248662 passed

 Test case "CPUMatrixSuite/CPUSparseMatrixColumnSlice" has passed with:
     1 assertion out of 1 passed
@@ -6060,6 +6060,9 @@ Test module "MathTests" has passed with:
 Test case "CPUMatrixSuite/CPUSparseMatrixCopyColumnSliceToDense" has passed with:
     1 assertion out of 1 passed

+Test case "CPUMatrixSuite/CPUSparseMatrixMultiplyAndAdd" has passed with:
+    20000 assertions out of 20000 passed
+
 Test case "CPUMatrixSuite/MatrixMultiplyTest" has passed with:
     1484 assertions out of 1484 passed
@@ -6146,7 +6149,7 @@ Test module "MathTests" has passed with:

 Test suite "GPUMatrixSuite" has passed with:
     56 test cases out of 56 passed
-    1263760 assertions out of 1263760 passed
+    1263762 assertions out of 1263762 passed

 Test case "GPUMatrixSuite/GPUBlasMultiplyAndWeightedAdd" has passed with:
     1 assertion out of 1 passed
@@ -6290,7 +6293,7 @@ Test module "MathTests" has passed with:
     4 assertions out of 4 passed

 Test case "GPUMatrixSuite/MatrixSparseTimesDense" has passed with:
-    1 assertion out of 1 passed
+    3 assertions out of 3 passed

 Test case "GPUMatrixSuite/MatrixDenseTimesSparse" has passed with:
     2 assertions out of 2 passed
@@ -6398,13 +6401,13 @@ Test module "MathTests" has passed with:
     2 assertions out of 2 passed

 Test suite "QuantizersUnitTests" has passed with:
-    2 test case out of 2 passed
+    2 test cases out of 2 passed
     12 assertions out of 12 passed

 Test case "QuantizersUnitTests/FloatToShort" has passed with:
     6 assertions out of 6 passed

 Test case "QuantizersUnitTests/QuantizeZeros" has passed with:
     6 assertions out of 6 passed

 Test suite "QuantizedOperationsUnitTests" has passed with:
@@ -6413,7 +6416,7 @@ Test module "MathTests" has passed with:

 Test case "QuantizedOperationsUnitTests/MultiplyIntToShort" has passed with:
     65 assertions out of 65 passed

 Test suite "MathTensorTests" has passed with:
     9 test cases out of 9 passed
     6 assertions out of 6 passed

@@ -61,6 +61,61 @@ BOOST_FIXTURE_TEST_CASE(CPUSparseMatrixCopyColumnSliceToDense, RandomSeedFixture
     BOOST_CHECK(dm1.IsEqualTo(dm2, c_epsilonFloatE4));
 }

+BOOST_FIXTURE_TEST_CASE(CPUSparseMatrixMultiplyAndAdd, RandomSeedFixture)
+{
+    const size_t m = 100;
+    const size_t n = 50;
+
+    DenseMatrix dm0(m, n);
+    dm0.SetUniformRandomValue(-1, 1, IncrementCounter());
+
+    DenseMatrix dm1(m, n);
+    dm1.SetUniformRandomValue(-300, 1, IncrementCounter());
+    dm1.InplaceTruncateBottom(0);
+
+    SparseMatrix sm1(MatrixFormat::matrixFormatSparseCSC, m, n, 0);
+    foreach_coord(row, col, dm1)
+    {
+        if (dm1(row, col) != 0)
+        {
+            sm1.SetValue(row, col, dm1(row, col));
+        }
+    }
+
+    DenseMatrix dm2(m, n);
+    dm2.SetUniformRandomValue(-200, 1, IncrementCounter());
+    dm2.InplaceTruncateBottom(0);
+
+    SparseMatrix sm2(MatrixFormat::matrixFormatSparseCSC, m, n, 0);
+    foreach_coord(row, col, dm2)
+    {
+        if (dm2(row, col) != 0)
+        {
+            sm2.SetValue(row, col, dm2(row, col));
+        }
+    }
+
+    // generate SparseBlockCol matrix
+    DenseMatrix dmMul(m, m);
+    DenseMatrix::MultiplyAndAdd(dm0, false, dm1, true, dmMul);
+
+    SparseMatrix smMul(MatrixFormat::matrixFormatSparseBlockCol, m, m, 0);
+    SparseMatrix::MultiplyAndAdd(1, dm0, false, sm1, true, smMul);
+
+    foreach_coord(row, col, dmMul)
+    {
+        BOOST_CHECK(abs(smMul(row, col) - dmMul(row, col)) < c_epsilonFloatE4);
+    }
+
+    SparseMatrix::MultiplyAndAdd(1, dm0, false, sm2, true, smMul);
+    DenseMatrix::MultiplyAndAdd(dm0, false, dm2, true, dmMul);
+
+    foreach_coord(row, col, dmMul)
+    {
+        BOOST_CHECK(abs(smMul(row, col) - dmMul(row, col)) < c_epsilonFloatE4);
+    }
+}
+
 BOOST_AUTO_TEST_SUITE_END()
 }
 } } }

@@ -44,7 +44,7 @@ BOOST_FIXTURE_TEST_CASE(MatrixSparseTimesDense, RandomSeedFixture)
 {
     // DENSE
     Matrix<float> mAdense(c_deviceIdZero);
-    mAdense.AssignTruncateBottomOf(Matrix<float>::RandomUniform(dim1, dim2, c_deviceIdZero, -3.0f, 0.1f, IncrementCounter()), 0);
+    mAdense.AssignTruncateBottomOf(Matrix<float>::RandomUniform(dim1, dim2, c_deviceIdZero, -300.0f, 0.1f, IncrementCounter()), 0);

     // MATRIX mAsparse becomes sparse
     Matrix<float> mAsparse(mAdense.DeepClone());
@@ -66,6 +66,44 @@ BOOST_FIXTURE_TEST_CASE(MatrixSparseTimesDense, RandomSeedFixture)
     Matrix<float>::MultiplyAndWeightedAdd(alpha, mAsparse, transposeA, mB, transposeB, beta, mD);

     BOOST_CHECK(mD.IsEqualTo(mC, c_epsilonFloatE4));
+
+    // SPARSE * DENSE -> SPARSE
+    beta = 1; // note that dense only allows beta == 1, while sparse CPU supports arbitrary beta
+    transposeA = false;
+    transposeB = true; // only support !transposeA && transposeB for Dense * CSC -> SBC
+
+    Matrix<float> mA1sparseCSC(mAdense.DeepClone());
+    mA1sparseCSC.SwitchToMatrixType(MatrixType::SPARSE, matrixFormatSparseCSC, true);
+
+    Matrix<float> mA2dense(c_deviceIdZero);
+    mA2dense.AssignTruncateBottomOf(Matrix<float>::RandomUniform(dim1, dim2, c_deviceIdZero, -300.0f, 0.1f, IncrementCounter()), 0);
+
+    Matrix<float> mA2sparseCSC(mA2dense.DeepClone());
+    mA2sparseCSC.SwitchToMatrixType(MatrixType::SPARSE, matrixFormatSparseCSC, true);
+
+    // dense for comparison
+    mC.Resize(dim1, dim1);
+    mC.SetValue(0.0f);
+    Matrix<float>::MultiplyAndAdd(mAdense, transposeA, mAdense, transposeB, mC);
+    Matrix<float>::MultiplyAndWeightedAdd(alpha, mAdense, transposeA, mA2dense, transposeB, beta, mC);
+
+    // make sure (dense * sparse -> dense) == (dense * dense -> dense)
+    mD.Resize(dim1, dim1);
+    mD.SetValue(0.0f);
+    Matrix<float>::MultiplyAndAdd(mAdense, transposeA, mA1sparseCSC, transposeB, mD);
+    Matrix<float>::MultiplyAndWeightedAdd(alpha, mAdense, transposeA, mA2sparseCSC, transposeB, beta, mD);
+
+    BOOST_CHECK(mD.IsEqualTo(mC, c_epsilonFloatE4));
+
+    // test on sparse
+    mD.SwitchToMatrixType(MatrixType::SPARSE, matrixFormatSparseBlockCol, false);
+    Matrix<float>::MultiplyAndAdd(mAdense, transposeA, mA1sparseCSC, transposeB, mD);
+    Matrix<float>::MultiplyAndWeightedAdd(alpha, mAdense, transposeA, mA2sparseCSC, transposeB, beta, mD);
+
+    // copy mD to dense and compare
+    Matrix<float> mE = Matrix<float>::Zeros(dim1, dim1, c_deviceIdZero);
+    Matrix<float>::ScaleAndAdd(1, mD, mE);
+    BOOST_CHECK(mE.IsEqualTo(mC, c_epsilonFloatE4));
 }

 BOOST_FIXTURE_TEST_CASE(MatrixDenseTimesSparse, RandomSeedFixture)