From 43ea68b59e1cff1d51d00d5905455775997bff53 Mon Sep 17 00:00:00 2001
From: Dong Yu
Date: Sat, 17 Jan 2015 13:37:00 -0800
Subject: [PATCH] change SparseInputValue node and CPU Sparse matrix to make
 LM CPU training work.

---
 MachineLearning/cn/ComputationNetwork.h     |   5 +-
 MachineLearning/cn/ComputationNode.h        | 104 ++++----------------
 MachineLearning/cn/SimpleNetworkBuilder.cpp |  12 +--
 Math/Math/CPUSparseMatrix.cpp               |  53 +++++-----
 Math/Math/GPUMatrix.h                       |   6 +-
 5 files changed, 58 insertions(+), 122 deletions(-)

diff --git a/MachineLearning/cn/ComputationNetwork.h b/MachineLearning/cn/ComputationNetwork.h
index 8dcd09535..524e8a1da 100644
--- a/MachineLearning/cn/ComputationNetwork.h
+++ b/MachineLearning/cn/ComputationNetwork.h
@@ -811,10 +811,9 @@ namespace Microsoft { namespace MSR { namespace CNTK {
             return newNode;
         }
 
-        //sparse matrix size is optionally specified
-        ComputationNodePtr CreateSparseInputNode(const std::wstring inputName, const size_t rows, const size_t cols, const size_t size = 0)
+        ComputationNodePtr CreateSparseInputNode(const std::wstring inputName, const size_t rows, const size_t cols)
         {
-            ComputationNodePtr newNode(new SparseInputValue<ElemType>(rows, cols, size, m_deviceId, inputName));
+            ComputationNodePtr newNode(new SparseInputValue<ElemType>(rows, cols, m_deviceId, inputName));
             AddNodeToNet(newNode);
             return newNode;
         }
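The factory above no longer takes the optional nonzero-count hint: the node now derives its reservation from the logical dimensions and lets the sparse matrix grow on demand. A minimal, hypothetical call site under the new signature (vocabSize and mbSize are illustrative names, not part of this patch):

    // Hypothetical call site for the new two-dimension signature; the
    // nonzero reservation is now handled internally by Resize on demand.
    ComputationNodePtr features =
        m_net->CreateSparseInputNode(L"features", /*rows*/ vocabSize, /*cols*/ mbSize);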
diff --git a/MachineLearning/cn/ComputationNode.h b/MachineLearning/cn/ComputationNode.h
index f03ad9334..a3b577b46 100644
--- a/MachineLearning/cn/ComputationNode.h
+++ b/MachineLearning/cn/ComputationNode.h
@@ -1150,112 +1150,38 @@ protected:  \
     template class InputValue<double>;
 
     template<class ElemType>
-    class SparseInputValue : public ComputationNode<ElemType>
+    class SparseInputValue : public InputValue<ElemType>
     {
         UsingComputationNodeMembers;
     public:
-        SparseInputValue (size_t rows, size_t cols, size_t size, const DEVICEID_TYPE deviceId=AUTOPLACEMATRIX, const std::wstring name = L"") : ComputationNode<ElemType>(deviceId)
+        SparseInputValue (size_t rows, size_t cols, const DEVICEID_TYPE deviceId=AUTOPLACEMATRIX, const std::wstring name = L"") : InputValue<ElemType>(rows, cols, deviceId, name)
         {
-            if (rows * cols == 0)
-                throw std::logic_error("This InputValue dimension is 0.");
-
-            m_outputWidth = 1;
-            m_outputHeight = rows;
-            m_outputChannels = 1;
-
-            m_nodeName = (name == L""? CreateUniqNodeName() : name);
-            m_deviceId = deviceId;
-            MoveMatricesToDevice(deviceId);
-            m_functionValues.SwitchToMatrixType(MatrixType::SPARSE, matrixFormatSparseCSC);
-            m_functionValues.Resize(rows, cols, size);
-            m_needGradient = false;
-            InitRecurrentNode();
+            ConvertToSparseMatrix();
         }
 
         SparseInputValue (size_t imageWidth, size_t imageHeight, size_t imageChannels, size_t numImages, const DEVICEID_TYPE deviceId=AUTOPLACEMATRIX, const std::wstring name = L"")
-            : ComputationNode<ElemType>(deviceId)
+            : InputValue<ElemType>(imageWidth, imageHeight, imageChannels, numImages, deviceId, name)
         {
-            size_t rows = imageWidth * imageHeight * imageChannels;
-            size_t cols = numImages;
-
-            if (rows * cols == 0)
-                throw std::logic_error("This InputValue dimension is 0.");
-
-            m_outputWidth = imageWidth;
-            m_outputHeight = imageHeight;
-            m_outputChannels = imageChannels;
-
-            m_nodeName = (name == L""? CreateUniqNodeName() : name);
-            m_deviceId = deviceId;
-            MoveMatricesToDevice(deviceId);
-            m_functionValues.SwitchToMatrixType(MatrixType::SPARSE);
-            m_functionValues.Resize(rows, cols);
-            m_needGradient = false;
-            InitRecurrentNode();
-        }
-
-        SparseInputValue (File& fstream, const size_t modelVersion, const DEVICEID_TYPE deviceId=AUTOPLACEMATRIX, const std::wstring name = L"") : ComputationNode<ElemType>(deviceId)
-        {
-            m_nodeName = (name == L""? CreateUniqNodeName() : name);
-            LoadFromFile(fstream, modelVersion, deviceId);
+            ConvertToSparseMatrix();
         }
 
-        virtual void SaveToFile(File& fstream) const
+        SparseInputValue (File& fstream, const size_t modelVersion, const DEVICEID_TYPE deviceId=AUTOPLACEMATRIX, const std::wstring name = L"") : InputValue<ElemType>(fstream, modelVersion, deviceId, name)
         {
-            ComputationNode<ElemType>::SaveToFile(fstream);
-
-            fstream << FunctionValues().GetNumRows() << FunctionValues().GetNumCols();
-            fstream << FunctionValues().GetAllocatedSize();
-            fstream << m_outputWidth << m_outputHeight << m_outputChannels;
+            ConvertToSparseMatrix();
         }
 
         virtual void LoadFromFile(File& fstream, const size_t modelVersion, const DEVICEID_TYPE deviceId = AUTOPLACEMATRIX)
         {
-            ComputationNode<ElemType>::LoadFromFile(fstream, modelVersion, deviceId);
-
-            size_t rows, cols;
-            fstream >> rows >> cols;
-            if (rows * cols == 0)
-                throw std::logic_error("This InputValue dimension is 0.");
-
-            size_t size; //sparse matrix size
-            fstream >> size;
-
-            fstream >> m_outputWidth >> m_outputHeight >> m_outputChannels;
-
-            m_functionValues.SwitchToMatrixType(MatrixType::SPARSE, matrixFormatSparseCSC);
-            m_functionValues.Resize(rows, cols, size);
-            m_needGradient = false;
+            InputValue<ElemType>::LoadFromFile(fstream, modelVersion, deviceId);
+            ConvertToSparseMatrix();
         }
 
         virtual const std::wstring OperationName() const {return TypeName();}
         static const std::wstring TypeName() {return L"SparseInputValue";}
 
-        virtual void EvaluateThisNode() {}
-        virtual void EvaluateThisNode(const size_t /*timeIdxInSeq*/) {}
-
-        virtual void ComputeInputPartial(const size_t /*inputIndex*/) {}
-        virtual void ComputeInputPartial(const size_t /*inputIndex*/, const size_t /*timeIdxInSeq*/) {}
-
-        virtual void Validate()
-        {
-            PrintSelfBeforeValidation();
-            //CopyImageSizeFromInputs(); //not necessary since InputValue are leafs. put it here for consistent
-        }
-
-        virtual void DumpNodeInfo(const bool printValues, File& fstream) const
-        {
-            ComputationNode<ElemType>::DumpNodeInfo(printValues, fstream);
-
-            char str[4096];
-            sprintf(str, "[%lu,%lu]", FunctionValues().GetNumRows(), FunctionValues().GetNumCols());
-            fstream << string(str);
-        }
-
         // copy constructor
-        SparseInputValue (const SparseInputValue<ElemType>* node, const std::wstring& newName, const CopyNodeFlags flags) : ComputationNode<ElemType>(node->m_deviceId)
+        SparseInputValue (const SparseInputValue<ElemType>* node, const std::wstring& newName, const CopyNodeFlags flags) : InputValue<ElemType>(node, newName, flags)
         {
-            node->CopyTo(this, newName, flags);
         }
 
         virtual ComputationNodePtr Duplicate(const std::wstring& newName, const CopyNodeFlags flags) const
@@ -1266,11 +1192,15 @@ protected:  \
             return node;
         }
 
-        virtual TaskDescriptor<ElemType>* GetPTaskDescriptor(TaskType /*taskType*/, size_t inputIndex=0) const
+    private:
+        void ConvertToSparseMatrix()
         {
-            inputIndex;
-            return nullptr;
+            size_t rows = m_functionValues.GetNumRows();
+            size_t cols = m_functionValues.GetNumCols();
+            m_functionValues.SwitchToMatrixType(MatrixType::SPARSE, matrixFormatSparseCSC);
+            m_functionValues.Resize(rows, cols); //SwitchToMatrixType does not preserve the contents right now
        }
+    };
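SparseInputValue now inherits all of InputValue's construction, (de)serialization, and validation, and only flips the storage of m_functionValues to CSC (compressed sparse column). For reference, CSC is the layout the CPU code below walks via m_pArray, m_unCompIndex, and m_compIndex; here is a self-contained sketch of that layout using plain STL containers (not the CNTK class):

    #include <cstddef>
    #include <cstdio>
    #include <vector>

    // Minimal CSC walk-through: nonzero values plus their row indices,
    // with compIndex[j]..compIndex[j+1] bounding the nonzeros of column j.
    int main()
    {
        // 3x3 matrix [1 0 2; 0 0 3; 4 5 0], nonzeros stored column by column:
        std::vector<float>  pArray      = {1, 4, 5, 2, 3}; // nonzero values
        std::vector<size_t> unCompIndex = {0, 2, 2, 0, 1}; // row of each value
        std::vector<size_t> compIndex   = {0, 2, 3, 5};    // column start offsets

        for (size_t j = 0; j + 1 < compIndex.size(); j++)
            for (size_t p = compIndex[j]; p < compIndex[j + 1]; p++)
                std::printf("A(%zu,%zu) = %g\n", unCompIndex[p], j, pArray[p]);
        return 0;
    }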
diff --git a/MachineLearning/cn/SimpleNetworkBuilder.cpp b/MachineLearning/cn/SimpleNetworkBuilder.cpp
index a0064e961..513b3a3f6 100644
--- a/MachineLearning/cn/SimpleNetworkBuilder.cpp
+++ b/MachineLearning/cn/SimpleNetworkBuilder.cpp
@@ -28,8 +28,8 @@ namespace Microsoft { namespace MSR { namespace CNTK {
             size_t numRecurrentLayers = m_recurrentLayers.size();
 
             ComputationNodePtr input = nullptr, w = nullptr, b = nullptr, u = nullptr, delay = nullptr, output = nullptr, label = nullptr, prior = nullptr;
-            //TODO: to figure out sparse matrix size
-            input = m_net->CreateSparseInputNode(L"features", m_layerSizes[0], mbSize, 0);
+
+            input = m_net->CreateSparseInputNode(L"features", m_layerSizes[0], mbSize);
             m_net->FeatureNodes().push_back(input);
 
             if (m_applyMeanVarNorm)
@@ -628,8 +628,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
             ComputationNodePtr ot=nullptr, it=nullptr, ft=nullptr, gt=nullptr, ct=nullptr, ht=nullptr;
             ComputationNodePtr delayXI = nullptr, delayXII = nullptr, delayXIII = nullptr, delayXIV = nullptr;
 
-            //TODO: to figure out sparse matrix size
-            input = m_net->CreateSparseInputNode(L"features", m_layerSizes[0], mbSize, 0);
+            input = m_net->CreateSparseInputNode(L"features", m_layerSizes[0], mbSize);
             m_net->FeatureNodes().push_back(input);
 
             if (m_applyMeanVarNorm)
@@ -739,8 +738,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
                 w = m_net->CreateSparseLearnableParameter(msra::strfun::wstrprintf (L"W%d", numHiddenLayers), m_layerSizes[numHiddenLayers+1], m_layerSizes[numHiddenLayers], 0);
                 m_net->InitLearnableParameters(w, m_uniformInit, randomSeed++, m_initValueScale);
                 // b = m_net->CreateLearnableParameter(msra::strfun::wstrprintf (L"B%d", numHiddenLayers), m_layerSizes[numHiddenLayers+1], 1);
-                //TODO: to figure out sparse matrix size
-                label = m_net->CreateSparseInputNode(L"labels", m_layerSizes[numHiddenLayers+1], mbSize, 0);
+                label = m_net->CreateSparseInputNode(L"labels", m_layerSizes[numHiddenLayers+1], mbSize);
                 AddTrainAndEvalCriterionNodes(input, label, w);
 
                 output = m_net->Times(w, input);
@@ -1102,7 +1100,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
             ComputationNodePtr directWIO = nullptr, directInput=nullptr, directOutput=nullptr;
             ComputationNodePtr outputFromEachLayer[MAX_DEPTH] = {nullptr};
 
-            input = m_net->CreateSparseInputNode(L"features", m_layerSizes[0], mbSize, m_layerSizes[0] * mbSize);
+            input = m_net->CreateSparseInputNode(L"features", m_layerSizes[0], mbSize);
             m_net->FeatureNodes().push_back(input);
 
             if (m_applyMeanVarNorm)
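The dropped size hints make sense for the language-model case: each feature column is a one-hot word, so a (vocabSize x mbSize) CSC input carries exactly mbSize nonzeros, not the m_layerSizes[0] * mbSize worst case the last hunk used to reserve. A self-contained sketch of building such a one-hot minibatch (the word ids are made-up values):

    #include <cstddef>
    #include <vector>

    // Illustrative one-hot LM minibatch in CSC form: one nonzero per
    // column, so storage is O(mbSize) regardless of the vocabulary size.
    int main()
    {
        const size_t mbSize = 4;
        const size_t words[mbSize] = {42, 7, 9999, 42};    // word id per column

        std::vector<float>  pArray;                        // one 1.0f per column
        std::vector<size_t> unCompIndex;                   // row index = word id
        std::vector<size_t> compIndex(mbSize + 1, 0);      // column start offsets

        for (size_t j = 0; j < mbSize; j++)
        {
            pArray.push_back(1.0f);
            unCompIndex.push_back(words[j]);
            compIndex[j + 1] = pArray.size();              // cumulative nonzero count
        }
        return 0;
    }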
diff --git a/Math/Math/CPUSparseMatrix.cpp b/Math/Math/CPUSparseMatrix.cpp
index 519502877..faf5c8fde 100644
--- a/Math/Math/CPUSparseMatrix.cpp
+++ b/Math/Math/CPUSparseMatrix.cpp
@@ -101,14 +101,14 @@ namespace Microsoft { namespace MSR { namespace CNTK {
         m_nz = 0;
         m_matrixName = NULL;
 
-        if(m_format == MatrixFormat::matrixFormatSparseCSC || m_format == MatrixFormat::matrixFormatSparseCSR)
+        //if(m_format == MatrixFormat::matrixFormatSparseCSC || m_format == MatrixFormat::matrixFormatSparseCSR)
         {
             m_colIdx = -1;
             m_pArray = NULL;
             m_unCompIndex = NULL;
             m_compIndex = NULL;
         }
-        else if (m_format == MatrixFormat::matrixFormatSparseBlockCol || m_format == MatrixFormat::matrixFormatSparseBlockRow)
+        //else if (m_format == MatrixFormat::matrixFormatSparseBlockCol || m_format == MatrixFormat::matrixFormatSparseBlockRow)
         {
             m_blockSize = 0;
             m_blockVal = NULL;
@@ -247,8 +247,12 @@ namespace Microsoft { namespace MSR { namespace CNTK {
                 size_t *unCompIndex = new size_t[numNZElemToReserve];
                 size_t *compIndex = new size_t[newCompIndexSize];
 
+                if (keepExistingValues && (m_nz > numNZElemToReserve || m_compIndexSize > newCompIndexSize))
+                    throw std::logic_error("Resize: to keep the existing values, m_nz must be <= numNZElemToReserve and m_compIndexSize must be <= newCompIndexSize");
+
                 if (keepExistingValues && m_nz > 0)
                 {
+                    assert(m_compIndexSize > 0 && m_nz <= numNZElemToReserve);
                     memcpy(pArray, m_pArray, sizeof(ElemType)*m_nz);
                     memcpy(unCompIndex, m_unCompIndex, sizeof(size_t)*m_nz);
                     memcpy(compIndex, m_compIndex, sizeof(size_t)*m_compIndexSize);
@@ -270,9 +274,13 @@ namespace Microsoft { namespace MSR { namespace CNTK {
                 ElemType *blockVal = new ElemType[numNZElemToReserve];
                 size_t *blockIds = new size_t[newCompIndexSize];
 
+                if (keepExistingValues && (m_nz > numNZElemToReserve || m_compIndexSize > newCompIndexSize))
+                    throw std::logic_error("Resize: to keep the existing values, m_nz must be <= numNZElemToReserve and m_compIndexSize must be <= newCompIndexSize");
+
                 if (keepExistingValues && m_elemSizeAllocated > 0)
                 {
-                    memcpy(blockVal, m_blockVal, sizeof(ElemType)*m_elemSizeAllocated);
+                    assert(m_compIndexSize > 0 && m_elemSizeAllocated <= numNZElemToReserve);
+                    memcpy(blockVal, m_blockVal, sizeof(ElemType)*m_nz);
                     memcpy(blockIds, m_blockIds, sizeof(size_t)*m_compIndexSize);
                 }
@@ -296,18 +304,17 @@ namespace Microsoft { namespace MSR { namespace CNTK {
     {
         m_nz = 0;
         m_colIdx = -1;
-        m_compIndexSize = 0;
         m_blockSize = 0;
     }
 
-    //c = op(a) * op(this) or c += op(a) * op(this)
+    //c = alpha * op(lhs) * op(rhs) + beta * c
     template<class ElemType>
     void CPUSparseMatrix<ElemType>::MultiplyAndWeightedAdd(ElemType alpha, const CPUMatrix<ElemType>& lhs, const bool transposeA,
         const CPUSparseMatrix<ElemType>& rhs, const bool transposeB, ElemType beta, CPUMatrix<ElemType>& c)
     {
         if (lhs.IsEmpty() || rhs.IsEmpty())
-            throw std::logic_error("LeftMultiplyAndAdd: one of the input matrix is empty.");
+            throw std::logic_error("MultiplyAndWeightedAdd: one of the input matrices is empty.");
 
         int m = transposeA? (int)lhs.GetNumCols(): (int)lhs.GetNumRows();
         int k = transposeA? (int)lhs.GetNumRows(): (int)lhs.GetNumCols();
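The new guards make the keep-existing-values contract of Resize explicit. A hedged usage sketch follows; the trailing booleans are read as growOnly and keepExistingValues to match the five-argument calls later in this patch, and the format-taking constructor is likewise an assumption:

    // Assumed signature:
    //   Resize(rows, cols, numNZElemToReserve, growOnly, keepExistingValues)
    CPUSparseMatrix<float> m(matrixFormatSparseCSC);
    m.Resize(10000, 128, 4096);                  // fresh buffers
    // ... fill nonzeros; m_nz grows up to 4096 ...
    m.Resize(10000, 128, 8192, true, true);      // grow and keep the values
    m.Resize(10000, 128, 16, true, true);        // now throws std::logic_error
                                                 // instead of copying out of bounds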
@@ -318,7 +325,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
         assert (k == l);
         if (k != l)
         {
-            throw std::invalid_argument("CPUSparseMatrix::MultiplyAndAdd: The inner dimensions of a and b must match.");
+            throw std::invalid_argument("CPUSparseMatrix::MultiplyAndWeightedAdd: The inner dimensions of a and b must match.");
         }
 
         if (c.GetNumRows() != m || c.GetNumCols() != n)
@@ -330,7 +337,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
         {
             memset(c.GetArray(), 0, sizeof(ElemType) * c.GetNumElements());
         }
-        else
+        else if (beta != 1)
         {
 #pragma omp parallel for
             foreach_coord(i,j,c)
@@ -339,15 +346,18 @@ namespace Microsoft { namespace MSR { namespace CNTK {
             }
         }
 
+        if (rhs.GetFormat() != matrixFormatSparseCSC)
+            NOT_IMPLEMENTED;
+
         if (!transposeA && !transposeB)
         {
             for(size_t j = 0; j < rhs.GetNumCols(); j++)
             {
-                size_t start = rhs.m_compIndex[j];
+                size_t start = rhs.m_compIndex[j];   //ColLocation
                 size_t end = rhs.m_compIndex[j+1];
                 for(size_t p = start; p < end; p++)
                 {
-                    size_t i = rhs.m_unCompIndex[p];
+                    size_t i = rhs.m_unCompIndex[p]; //RowLocation
                     ElemType val = rhs.m_pArray[p];
 
                     for(size_t h = 0; h < lhs.GetNumRows(); h++)
@@ -385,7 +395,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
         }
     }
 
-    //c = alpha * op(a) * op(this)
+    //c = alpha * op(lhs) * op(rhs)
     template<class ElemType>
     void CPUSparseMatrix<ElemType>::MultiplyAndAdd(ElemType alpha, const CPUMatrix<ElemType>& lhs, const bool transposeA,
         const CPUSparseMatrix<ElemType>& rhs, const bool transposeB, CPUSparseMatrix<ElemType>& c)
@@ -414,10 +424,8 @@ namespace Microsoft { namespace MSR { namespace CNTK {
         else if (!transposeA && transposeB)
         {
             //allocate enough memory
-            if(c.m_elemSizeAllocated < lhs.GetNumElements())
-            {
-                c.Resize(c.GetNumRows(), c.GetNumCols(), lhs.GetNumElements());
-            }
+            c.SetFormat(matrixFormatSparseBlockCol);
+            c.Resize(c.GetNumRows(), c.GetNumCols(), lhs.GetNumElements());
 
             map<size_t, size_t> w2Id;
             for(size_t j = 0; j < rhs.GetNumCols(); j++)
@@ -460,7 +468,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
             {
                 throw std::logic_error("sparse matrix out of range.");
             }
-            c.SetFormat(matrixFormatSparseBlockCol);
+            //c.SetFormat(matrixFormatSparseBlockCol);
         }
         else if (transposeA && !transposeB)
         {
@@ -552,8 +560,10 @@ namespace Microsoft { namespace MSR { namespace CNTK {
         //allocate enough memory
         if(etp.m_elemSizeAllocated < etp.GetNumElements())
         {
-            etp.Resize(etp.GetNumRows(), etp.GetNumCols(), etp.GetNumElements());
+            etp.Resize(etp.GetNumRows(), etp.GetNumCols(), etp.GetNumElements(), true, false);
         }
+        etp.Reset();
+        entropyScore(0, 0) = 0;
 
         for(size_t j = 0; j < label.GetNumCols(); j++)
         {
@@ -655,11 +665,10 @@ namespace Microsoft { namespace MSR { namespace CNTK {
         const CPUMatrix<ElemType>& /*idx2cls*/,
         CPUSparseMatrix<ElemType>& grd)
     {
+        grd.SetFormat(matrixFormatSparseBlockRow);
         //allocate enough memory
-        if(grd.m_elemSizeAllocated < error.m_nz*input.GetNumRows())
-        {
-            grd.Resize(grd.GetNumRows(), grd.GetNumCols(), error.m_nz*input.GetNumRows());
-        }
+        grd.Resize(grd.GetNumRows(), grd.GetNumCols(), error.m_nz*input.GetNumRows(), true, false);
+        grd.Reset();
 
         map<size_t, size_t> w2Id;
         for(size_t j = 0; j < error.GetNumCols(); j++)
@@ -701,7 +710,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
         {
             throw std::logic_error("sparse matrix out of range.");
         }
-        grd.SetFormat(matrixFormatSparseBlockRow);
+        //grd.SetFormat(matrixFormatSparseBlockRow);
     }
 
     // normal update for smoothed gradients c and current gradients (this)
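The CSC branch of MultiplyAndWeightedAdd visits each column's nonzeros once and accumulates a scaled column of lhs into c. A self-contained reference kernel with plain arrays (not the CNTK classes) that computes c += alpha * lhs * rhs for a CSC rhs, mirroring the loop structure above:

    #include <cstddef>

    // Reference kernel: c += alpha * lhs * rhs, with rhs given in CSC form.
    // lhs is dense column-major (m x k); rhs is (k x n) described by
    // compIndex (column starts), unCompIndex (row ids), pArray (values);
    // c is dense column-major (m x n) and already initialized.
    void MultiplyDenseByCsc(float alpha, const float* lhs, size_t m, size_t k,
                            const size_t* compIndex, const size_t* unCompIndex,
                            const float* pArray, size_t n, float* c)
    {
        (void)k; // documents rhs's row count; not needed by the CSC walk
        for (size_t j = 0; j < n; j++)                     // each column of rhs
            for (size_t p = compIndex[j]; p < compIndex[j + 1]; p++)
            {
                size_t i  = unCompIndex[p];                // row of this nonzero
                float val = pArray[p];
                for (size_t h = 0; h < m; h++)             // column i of lhs
                    c[j * m + h] += alpha * lhs[i * m + h] * val;
            }
    }

Each nonzero is touched exactly once, so the cost is O(nnz * m) rather than the O(k * n * m) of a dense product.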
diff --git a/Math/Math/GPUMatrix.h b/Math/Math/GPUMatrix.h
index dd0c40f3f..a84d76f39 100644
--- a/Math/Math/GPUMatrix.h
+++ b/Math/Math/GPUMatrix.h
@@ -365,10 +365,10 @@ namespace Microsoft { namespace MSR { namespace CNTK {
             stream << s << format;
             stream<
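For the transposeB path of MultiplyAndAdd and the class-entropy gradient above, the result is kept in the block formats instead of CSC: only the columns (or rows) that actually receive updates are materialized, located through the w2Id map. A self-contained sketch of the block-column idea with plain STL (not the CNTK class; member names echo the code above):

    #include <cstddef>
    #include <map>
    #include <vector>

    // Sketch of sparse block-column storage: each touched column gets a
    // dense slot of `rows` values, found via a column-id -> block-id map.
    struct BlockColMatrix
    {
        size_t rows;
        std::map<size_t, size_t> w2Id;   // column index -> block slot
        std::vector<float> blockVal;     // one dense column per block slot

        float* Column(size_t col)        // get (or lazily create) a column
        {
            auto it = w2Id.find(col);
            if (it == w2Id.end())
            {
                it = w2Id.emplace(col, w2Id.size()).first;
                blockVal.resize(w2Id.size() * rows, 0.0f);
            }
            return &blockVal[it->second * rows];
        }
    };

A gradient update then touches only the handful of word columns present in the minibatch, e.g. BlockColMatrix g{hiddenSize}; g.Column(wordId)[h] += delta; where hiddenSize, wordId, h, and delta are illustrative.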