From 4ff3ab53f3fbc09f6e159f51b626f98ef1b33fad Mon Sep 17 00:00:00 2001
From: Dong Yu
Date: Fri, 9 Jan 2015 00:19:17 -0800
Subject: [PATCH] Add SetNzCount to CommonMatrix to support setting the number
 of non-zero values externally.

Change the Resize function in the sparse matrices to make it clear that numNZ
is used only to reserve memory, so that we can call Resize repeatedly without
affecting the actual number of non-zero values. Change the CPU sparse matrix's
Resize to support keeping existing values when memory is reallocated. Change
the SetValue function in CPUSparseMatrix to support automatic resizing.
---
 Math/Math/CPUSparseMatrix.cpp | 77 +++++++++++++++++++++++------------
 Math/Math/CPUSparseMatrix.h   |  4 +-
 Math/Math/CommonMatrix.h      |  1 +
 Math/Math/GPUSparseMatrix.cu  | 54 ++++++++++++++++--------
 Math/Math/GPUSparseMatrix.h   | 17 ++++++--
 Math/Math/Matrix.cpp          |  6 +--
 Math/Math/Matrix.h            |  2 +-
 7 files changed, 108 insertions(+), 53 deletions(-)

diff --git a/Math/Math/CPUSparseMatrix.cpp b/Math/Math/CPUSparseMatrix.cpp
index cb419aa8d..519502877 100644
--- a/Math/Math/CPUSparseMatrix.cpp
+++ b/Math/Math/CPUSparseMatrix.cpp
@@ -95,6 +95,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
         m_numRows = 0;
         m_numCols = 0;
         m_elemSizeAllocated = 0;
+        m_compIndexSize = 0;
         m_externalBuffer = false;
         m_computeDevice = CPUDEVICE;
         m_nz = 0;
@@ -181,11 +182,13 @@ namespace Microsoft { namespace MSR { namespace CNTK {
             throw std::logic_error("CPUSparseMatrix: unsupported SetValue() call.");
         }
 
-        if(m_elemSizeAllocated < m_nz +1) {
-            throw std::logic_error("CPUSparseMatrix: allocated size is too small.");
+        if(m_elemSizeAllocated < m_nz +1) //automatic resize
+        {
+            Resize(m_numRows, m_numCols, m_nz + 100); //allocate 100 more elements and keep existing values
         }
 
-        if(rIdx < 0 || rIdx >= m_numRows) {
+        if(rIdx < 0 || rIdx >= m_numRows)
+        {
             throw std::logic_error("CPUSparseMatrix: SetValue() invalid row id");
         }
 
@@ -228,43 +231,62 @@ namespace Microsoft { namespace MSR { namespace CNTK {
     }
 
     template<class ElemType>
-    void CPUSparseMatrix<ElemType>::Resize(const size_t numRows, const size_t numCols, size_t size, const bool growOnly)
+    void CPUSparseMatrix<ElemType>::Resize(const size_t numRows, const size_t numCols, size_t numNZElemToReserve, const bool growOnly, const bool keepExistingValues)
     {
-        m_nz = 0;
-        m_colIdx = -1;
-        m_numRows = numRows;
-        m_numCols = numCols;
+        size_t newCompIndexSize = (numCols > numRows ? numCols : numRows) + 1;
+        bool reallocate = (m_elemSizeAllocated < numNZElemToReserve || (m_elemSizeAllocated > numNZElemToReserve && !growOnly) || m_compIndexSize < newCompIndexSize);
 
-        if (m_elemSizeAllocated < size || (m_elemSizeAllocated > size && ! growOnly))
+        m_numRows = numRows;
+        m_numCols = numCols;
+
+        if (reallocate)
         {
-            m_elemSizeAllocated = size;
-            if(m_format == MatrixFormat::matrixFormatSparseCSC || m_format == MatrixFormat::matrixFormatSparseCSR)
+            if (m_format == MatrixFormat::matrixFormatSparseCSC || m_format == MatrixFormat::matrixFormatSparseCSR)
             {
-                if(m_pArray != NULL)
+                ElemType *pArray = new ElemType[numNZElemToReserve];
+                size_t *unCompIndex = new size_t[numNZElemToReserve];
+                size_t *compIndex = new size_t[newCompIndexSize];
+
+                if (keepExistingValues && m_nz > 0)
+                {
+                    memcpy(pArray, m_pArray, sizeof(ElemType)*m_nz);
+                    memcpy(unCompIndex, m_unCompIndex, sizeof(size_t)*m_nz);
+                    memcpy(compIndex, m_compIndex, sizeof(size_t)*m_compIndexSize);
+                }
+
+                if (m_pArray != NULL)
                     delete[] m_pArray;
-                if(m_unCompIndex != NULL)
+                if (m_unCompIndex != NULL)
                     delete[] m_unCompIndex;
-                if(m_compIndex != NULL)
-                    delete[] m_compIndex;
-
-                //int len = m_format == MatrixFormat::matrixFormatSparseCSC ? numCols : numRows;
-                size_t len = numCols > numRows ? numCols : numRows;
-                m_pArray = new ElemType[size];
-                m_unCompIndex = new size_t[size];
-                m_compIndex = new size_t[len+1];
-
-            }
+                if (m_compIndex != NULL)
+                    delete[] m_compIndex;
+
+                m_pArray = pArray;
+                m_unCompIndex = unCompIndex;
+                m_compIndex = compIndex;
+            }
             else if(m_format == MatrixFormat::matrixFormatSparseBlockCol || m_format == MatrixFormat::matrixFormatSparseBlockRow)
             {
-                if(m_blockVal != NULL)
+                ElemType *blockVal = new ElemType[numNZElemToReserve];
+                size_t *blockIds = new size_t[newCompIndexSize];
+
+                if (keepExistingValues && m_elemSizeAllocated > 0)
+                {
+                    memcpy(blockVal, m_blockVal, sizeof(ElemType)*m_elemSizeAllocated);
+                    memcpy(blockIds, m_blockIds, sizeof(size_t)*m_compIndexSize);
+                }
+
+                if (m_blockVal != NULL)
                     delete[] m_blockVal;
                 if(m_blockIds != NULL)
                     delete[] m_blockIds;
-                size_t max = numCols > numRows ? numCols : numRows;
-                m_blockVal = new ElemType[size];
-                m_blockIds = new size_t[max];
+                m_blockVal = blockVal;
+                m_blockIds = blockIds;
             }
+
+            m_elemSizeAllocated = numNZElemToReserve;
+            m_compIndexSize = newCompIndexSize;
         }
     }
 
@@ -274,6 +296,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
     {
         m_nz = 0;
         m_colIdx = -1;
+        m_compIndexSize = 0;
         m_blockSize = 0;
     }
 
diff --git a/Math/Math/CPUSparseMatrix.h b/Math/Math/CPUSparseMatrix.h
index fab7f58e1..7bf18d034 100644
--- a/Math/Math/CPUSparseMatrix.h
+++ b/Math/Math/CPUSparseMatrix.h
@@ -86,7 +86,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
 
         int GetComputeDeviceId() const {return -1;}
 
-        void Resize(const size_t numRows, const size_t numCols, size_t size = 0, const bool growOnly = true);
+        void Resize(const size_t numRows, const size_t numCols, size_t numNZElemToReserve = 0, const bool growOnly = true, const bool keepExistingValues = true);
         void Reset();
 
     public:
@@ -133,6 +133,8 @@ namespace Microsoft { namespace MSR { namespace CNTK {
 
     private:
         int m_colIdx; //used to SetValue()
+        size_t m_compIndexSize;
+
         //non-zero values are stored in m_pArray
         size_t *m_unCompIndex; //row/col ids in CSC/CSR format
         size_t *m_compIndex; //begin ids of col/row in CSC/CSR format
diff --git a/Math/Math/CommonMatrix.h b/Math/Math/CommonMatrix.h
index fddd75dca..bab1673cc 100644
--- a/Math/Math/CommonMatrix.h
+++ b/Math/Math/CommonMatrix.h
@@ -81,6 +81,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
         void SetOwnBuffer(bool own) {m_externalBuffer = !own;}
         wchar_t* GetMatrixName() const { return m_matrixName; }
         size_t NzCount() const {return m_nz;}
+        void SetNzCount(const size_t nz) { m_nz = nz; }
         size_t GetSizeAllocated() const {return m_elemSizeAllocated; }
         void SetMatrixName(const wchar_t* s)
         {
diff --git a/Math/Math/GPUSparseMatrix.cu b/Math/Math/GPUSparseMatrix.cu
index d759b4d3a..600cfef45 100644
--- a/Math/Math/GPUSparseMatrix.cu
+++ b/Math/Math/GPUSparseMatrix.cu
@@ -130,6 +130,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
         deepCopy.PrepareDevice();
 
         Resize(deepCopy.m_numRows, deepCopy.m_numCols, deepCopy.m_nz, deepCopy.m_format);
+        m_nz = deepCopy.m_nz;
         CUDACALL(cudaMemcpy(NzValues(), deepCopy.NzValues(), NzSize(), cudaMemcpyDeviceToDevice));
         CUDACALL(cudaMemcpy(MajorIndexLocation(), deepCopy.MajorIndexLocation(), MajorIndexSize(), cudaMemcpyDeviceToDevice));
         CUDACALL(cudaMemcpy(SecondaryIndexLocation(), deepCopy.SecondaryIndexLocation(), SecondaryIndexSize(), cudaMemcpyDeviceToDevice));
@@ -199,6 +200,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
             //we need to do conversion because CPUSparseMatrix uses size_t for indexes while GPUSparseMatrix uses int
             GPUSPARSE_INDEX_TYPE *h_CSRRow, *h_Col;
             cpuSparseMatrix.Resize(GetNumRows(), GetNumCols(), GetNumNZElements());
+            cpuSparseMatrix.SetNzCount(GetNumNZElements());
 
             PrepareDevice();
             h_CSRRow = new GPUSPARSE_INDEX_TYPE[m_numRows + 1];
@@ -219,6 +221,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
             //we need to do conversion because CPUSparseMatrix uses size_t for indexes while GPUSparseMatrix uses int
             GPUSPARSE_INDEX_TYPE *h_CSCCol, *h_Row;
             cpuSparseMatrix.Resize(GetNumRows(), GetNumCols(), GetNumNZElements());
+            cpuSparseMatrix.SetNzCount(GetNumNZElements());
 
             PrepareDevice();
             h_CSCCol = new GPUSPARSE_INDEX_TYPE[m_numCols + 1];
@@ -322,6 +325,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
             outMatrix.ChangeDeviceTo(GetComputeDeviceId());
 
         outMatrix.Resize(m_numRows, m_numCols, m_nz,newFormat);
+        outMatrix.SetNzCount(m_nz);
 
         if (oldFormat == matrixFormatSparseCSR && newFormat == matrixFormatSparseCSC)
         {
@@ -475,6 +479,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
         CUDACALL(cudaEventDestroy(done));
 
         Resize(numRows, numCols, nnzTotalDevHostPtr, matrixFormat);
+        SetNzCount(nnzTotalDevHostPtr);
 
         CUDACALL(cudaEventCreate(&done));
 
@@ -605,6 +610,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
     void GPUSparseMatrix<ElemType>::ResizeAsAndCopyIndexFrom(const GPUSparseMatrix<ElemType>& a, const bool growOnly /*= true*/)
    {
         Resize(a.m_numRows, a.m_numCols, a.m_nz, a.m_format, growOnly);
+        SetNzCount(a.m_nz);
 
         CUDACALL(cudaMemcpy(MajorIndexLocation(), a.MajorIndexLocation(), MajorIndexSize(), cudaMemcpyDeviceToDevice));
         CUDACALL(cudaMemcpy(SecondaryIndexLocation(), a.SecondaryIndexLocation(), SecondaryIndexSize(), cudaMemcpyDeviceToDevice));
@@ -630,30 +636,29 @@ namespace Microsoft { namespace MSR { namespace CNTK {
     }
 
     template<class ElemType>
-    void GPUSparseMatrix<ElemType>::Resize(const size_t numRows, const size_t numCols, const size_t numNZ, const bool growOnly)
+    void GPUSparseMatrix<ElemType>::Resize(const size_t numRows, const size_t numCols, const size_t numNZElemToReserve, const bool growOnly)
     {
-        Resize(numRows, numCols, numNZ, GetFormat(), growOnly);
+        Resize(numRows, numCols, numNZElemToReserve, GetFormat(), growOnly);
     }
 
+    //WARNING: when memory is reallocated, existing information will be lost; the workaround is to allocate enough memory from the start.
+    //TODO: add a keepExistingValues argument (defaulting to true) so that existing values are kept even after reallocation
     template<class ElemType>
-    void GPUSparseMatrix<ElemType>::Resize(const size_t numRows, const size_t numCols, const size_t numNZ, const MatrixFormat matrixFormat, const bool growOnly /*= true*/)
+    void GPUSparseMatrix<ElemType>::Resize(const size_t numRows, const size_t numCols, const size_t numNZElemToReserve, const MatrixFormat matrixFormat, const bool growOnly /*= true*/)
     {
         m_numRows = numRows;
         m_numCols = numCols;
-        m_nz = numNZ;
 
         if (matrixFormat == MatrixFormat::matrixFormatSparseCSC || matrixFormat == MatrixFormat::matrixFormatSparseCSR)
         {
-            bool reallocate = (m_totalBufferSizeAllocated < BufferSizeNeeded() || (!growOnly && m_totalBufferSizeAllocated > BufferSizeNeeded()));
+            size_t bufferSizeNeeded = BufferSizeNeeded(numNZElemToReserve);
+            bool reallocate = (m_totalBufferSizeAllocated < bufferSizeNeeded || (!growOnly && m_totalBufferSizeAllocated > bufferSizeNeeded));
 
             if (reallocate)
             {
                 if (!OwnBuffer())
                     throw logic_error("Cannot Resize since the buffer is managed externally.");
 
-                m_totalBufferSizeAllocated = BufferSizeNeeded();
-                m_elemSizeAllocated = numNZ;
-
                 if (m_pArray != nullptr)
                     CUDACALL(cudaFree(m_pArray));
                 if (m_block2Id != nullptr)
@@ -663,21 +668,29 @@ namespace Microsoft { namespace MSR { namespace CNTK {
 
                 PrepareDevice();
-                CUDACALL(cudaMalloc((void **)&m_pArray, m_totalBufferSizeAllocated));
+                CUDACALL(cudaMalloc((void **)&m_pArray, bufferSizeNeeded));
                 CUDACALL(cudaMalloc((void **)&m_block2Id, sizeof(size_t)*(numCols * 2)));
                 CUDACALL(cudaMalloc((void **)&m_block2UniqId, sizeof(size_t)*(numCols * 2)));
+
+                m_totalBufferSizeAllocated = bufferSizeNeeded;
+                m_elemSizeAllocated = numNZElemToReserve;
             }
         }
         else if (matrixFormat == MatrixFormat::matrixFormatSparseBlockCol || matrixFormat == MatrixFormat::matrixFormatSparseBlockRow)
         {
-            if (m_blockVal != nullptr)
-                CUDACALL(cudaFree(m_blockVal));
-            if (m_blockIds != nullptr)
-                CUDACALL(cudaFree(m_blockIds));
-            PrepareDevice();
-            CUDACALL(cudaMalloc((void **)&m_blockVal, sizeof(ElemType)*numNZ));
-            int max = numCols > numRows ? numCols : numRows;
-            CUDACALL(cudaMalloc((void **)&m_blockIds, sizeof(size_t)*max));
+            if (m_elemSizeAllocated < numNZElemToReserve || (m_elemSizeAllocated > numNZElemToReserve && !growOnly))
+            {
+                if (m_blockVal != nullptr)
+                    CUDACALL(cudaFree(m_blockVal));
+                if (m_blockIds != nullptr)
+                    CUDACALL(cudaFree(m_blockIds));
+                PrepareDevice();
+                CUDACALL(cudaMalloc((void **)&m_blockVal, sizeof(ElemType)*numNZElemToReserve));
+                int max = numCols > numRows ? numCols : numRows;
+                CUDACALL(cudaMalloc((void **)&m_blockIds, sizeof(size_t)*max));
+
+                m_elemSizeAllocated = numNZElemToReserve;
+            }
         }
         else
             NOT_IMPLEMENTED;
 
@@ -701,6 +714,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
 
         m_format = matrixFormatSparseCSR;
         Resize(numRows, numCols, nz);
+        SetNzCount(nz);
 
         cudaMemcpyKind kind = IsOnDevice ? cudaMemcpyDeviceToDevice : cudaMemcpyHostToDevice;
         CUDACALL(cudaMemcpy(RowLocation(), h_CSRRow, RowSize(), kind));
@@ -741,6 +755,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
         SetComputeDeviceId(devId);
         m_format = matrixFormatSparseCSC;
         Resize(numRows, numCols, nz);
+        SetNzCount(nz);
 
         cudaMemcpyKind kind = IsOnDevice ? cudaMemcpyDeviceToDevice : cudaMemcpyHostToDevice;
         CUDACALL(cudaMemcpy(RowLocation(), h_Row, RowSize(), kind));
@@ -792,6 +807,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
     {
         m_format = matrixFormatSparseCSC;
         Resize(m_numRows, m_numCols, labelSize);
+        SetNzCount(labelSize);
 
         m_expandedSize = expandedSize;
         m_blockSize = blockSize;
@@ -1320,6 +1336,8 @@ namespace Microsoft { namespace MSR { namespace CNTK {
 
         // now we know the number of Non-zeros in the result set, set the output size
         c.Resize(m, n, nnzC);
+        c.m_nz = nnzC;
+
         CUDACALL(cudaMemcpy(c.SecondaryIndexLocation(),csrRowPtrC,c.SecondaryIndexSize(),cudaMemcpyDeviceToDevice));
 
         // if we allocated the buffer, free it here
@@ -1805,6 +1823,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
         PrepareDevice();
         GPUSparseMatrix<ElemType> c(GetFormat(), GetComputeDeviceId());
         c.Resize(n, m, nnz, GetFormat());
+        c.m_nz = nnz;
 
         cusparseHandle_t cusparseHandle = 0;
         CUSPARSECALL(cusparseCreate(&cusparseHandle));
@@ -2283,6 +2302,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
             NOT_IMPLEMENTED;
 
         us.Resize(rownum, colnum, nz);
+        us.SetNzCount(nz);
 
         if (nz > 0)
         {
diff --git a/Math/Math/GPUSparseMatrix.h b/Math/Math/GPUSparseMatrix.h
index 3ca54c9ac..1ba149ed0 100644
--- a/Math/Math/GPUSparseMatrix.h
+++ b/Math/Math/GPUSparseMatrix.h
@@ -58,7 +58,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
         size_t MajorIndexSize() const { return sizeof(GPUSPARSE_INDEX_TYPE)*MajorIndexCount(); } // actual number of major index bytes in use
 
         GPUSPARSE_INDEX_TYPE* SecondaryIndexLocation() const { return MajorIndexLocation() + m_elemSizeAllocated; } //this is the compressed index, col/row in CSC/CSR format
-        size_t SecondaryIndexCount() const
+        size_t SecondaryIndexCount(const size_t numNZ) const
         {
             if (m_format&matrixFormatCompressed)
             {
@@ -67,12 +67,21 @@ namespace Microsoft { namespace MSR { namespace CNTK {
                 return cnt;
             }
             else
-                return m_nz; // COO format
+                return numNZ; // COO format
         }
+
+        size_t SecondaryIndexCount() const
+        {
+            return SecondaryIndexCount(m_nz);
+        }
+
         // get size for compressed index
         size_t SecondaryIndexSize() const { return (SecondaryIndexCount())*sizeof(GPUSPARSE_INDEX_TYPE); }
 
         size_t BufferSizeNeeded() const { return NzSize() + MajorIndexSize() + SecondaryIndexSize(); }
+        size_t BufferSizeNeeded(const size_t numNZ) const
+        { return sizeof(ElemType)*numNZ + sizeof(GPUSPARSE_INDEX_TYPE)*(numNZ + SecondaryIndexCount(numNZ)); }
+
         size_t BufferSizeAllocated() const { return m_totalBufferSizeAllocated; }
 
         ElemType* BufferPointer() const;
@@ -88,8 +97,8 @@ namespace Microsoft { namespace MSR { namespace CNTK {
         void SetValue(const GPUMatrix<ElemType>& denseMatrix);
 
         void ResizeAsAndCopyIndexFrom(const GPUSparseMatrix<ElemType>& a, const bool growOnly = true);
-        void Resize(const size_t numRows, const size_t numCols, const size_t numNZ, const MatrixFormat matrixFormat, const bool growOnly = true); //matrix format will affect the size to allocate
-        void Resize(const size_t numRows, const size_t numCols, const size_t numNZ, const bool growOnly = true);
+        void Resize(const size_t numRows, const size_t numCols, const size_t numNZElemToReserve, const MatrixFormat matrixFormat, const bool growOnly = true); //matrix format will affect the size to allocate
+        void Resize(const size_t numRows, const size_t numCols, const size_t numNZElemToReserve, const bool growOnly = true);
 
         GPUSparseMatrix<ElemType> Transpose() const;
         void InplaceTranspose();
diff --git a/Math/Math/Matrix.cpp b/Math/Math/Matrix.cpp
index f813c6a1f..dafc01036 100644
--- a/Math/Math/Matrix.cpp
+++ b/Math/Math/Matrix.cpp
@@ -1150,14 +1150,14 @@ namespace Microsoft { namespace MSR { namespace CNTK {
     }
 
     template<class ElemType>
-    void Matrix<ElemType>::Resize(const size_t numRows, const size_t numCols, const size_t allocatedSize /*=0*/, bool growOnly /*=true*/)
+    void Matrix<ElemType>::Resize(const size_t numRows, const size_t numCols, const size_t numNZElemToReserve /*=0*/, bool growOnly /*=true*/)
     {
         DISPATCH_MATRIX_ON_FLAG(this,
             this,
             m_CPUMatrix->Resize(numRows,numCols,growOnly),
             m_GPUMatrix->Resize(numRows,numCols,growOnly),
-            m_CPUSparseMatrix->Resize(numRows, numCols, allocatedSize, growOnly),
-            m_GPUSparseMatrix->Resize(numRows, numCols, allocatedSize, growOnly)
+            m_CPUSparseMatrix->Resize(numRows, numCols, numNZElemToReserve, growOnly),
+            m_GPUSparseMatrix->Resize(numRows, numCols, numNZElemToReserve, growOnly)
             );
     }
 
diff --git a/Math/Math/Matrix.h b/Math/Math/Matrix.h
index ced6638c7..26a825bcc 100644
--- a/Math/Math/Matrix.h
+++ b/Math/Math/Matrix.h
@@ -112,7 +112,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
         void RmsProp(Matrix<ElemType>& gradients, ElemType RMS_GAMMA, ElemType RMS_WGT_INC, ElemType RMS_WGT_MAX, ElemType RMS_WGT_DEC, ElemType RMS_WGT_MIN);
 
         void Reshape(const size_t numRows, const size_t numCols);
-        void Resize(const size_t numRows, const size_t numCols, const size_t allocatedSize = 0, bool growOnly = true); //by default we only reallocate if need to grow
+        void Resize(const size_t numRows, const size_t numCols, const size_t numNZElemToReserve = 0, bool growOnly = true); //by default we only reallocate if need to grow
 
         size_t GetAllocatedSize() const;
         void Reset(); //reset for sparse matrix
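
For reviewers, a minimal usage sketch of the revised CPU-side API follows. It is not part of the patch: the format-only constructor and the element-wise SetValue(row, col, value) overload are assumed from the existing CPUSparseMatrix interface, and the dimensions, reserve size, and values are invented for illustration.

#include "CPUSparseMatrix.h"

using namespace Microsoft::MSR::CNTK;

void SparseResizeSketch()
{
    CPUSparseMatrix<float> m(matrixFormatSparseCSC);

    // Resize now only reserves room for numNZElemToReserve non-zero elements;
    // it no longer resets the stored non-zero count, so it can be called
    // repeatedly (keepExistingValues defaults to true on the CPU side).
    m.Resize(1000 /*numRows*/, 500 /*numCols*/, 2000 /*numNZElemToReserve*/);

    // SetValue grows the buffers automatically (reserving 100 extra elements
    // and keeping existing values) once the reserved space is exhausted.
    m.SetValue(3 /*row*/, 0 /*col*/, 1.5f);

    // SetNzCount, added to CommonMatrix, lets code that fills the buffers
    // directly (as the GPU-to-CPU copy paths in this patch do) record the
    // number of non-zero values in use; here it restates the single value set above.
    m.SetNzCount(1);
}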