//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
//

#define _CRT_SECURE_NO_WARNINGS // "secure" CRT not available on all platforms --add this at the top of all CPP files that give "function or variable may be unsafe" warnings

#include "Basics.h"
#include "ComputationNode.h"
#include "ComputationNetwork.h"
#include "ComputationNetworkBuilder.h" // used for load & save
#include "LinearAlgebraNodes.h"
#include "NonlinearityNodes.h"
#include "ConvolutionalNodes.h"
#include "RecurrentNodes.h"
#include "ReshapingNodes.h"
#include "TrainingNodes.h"
#include "PreComputeNodes.h"
#include "EvaluationNodes.h"
#include "SpecialPurposeNodes.h"
#include "MPIWrapper.h" // TODO: does not belong here
#include <string>
#include <vector>
#include <stack>
#include <list>
#include <set>

using namespace std;

namespace Microsoft { namespace MSR { namespace CNTK {

// -----------------------------------------------------------------------
// MatrixPool methods
// -----------------------------------------------------------------------

template <>
vector<shared_ptr<Matrix<float>>>& MatrixPool::GetReleasedMatrices<float>()
{
    return m_releasedFloatMatrices;
}

template <>
vector<shared_ptr<Matrix<double>>>& MatrixPool::GetReleasedMatrices<double>()
{
    return m_releasedDoubleMatrices;
}

// -----------------------------------------------------------------------
// construction
// -----------------------------------------------------------------------

// clear the object to empty state; this is used in the destructor, and also when loading
// This is necessary to make sure we don't leave nodes hanging due to recurrent cyclic references.
void ComputationNetwork::ClearNetwork()
{
    // release all references to nodes
    InvalidateCompiledNetwork();

    for (auto groupIter : GetAllNodeGroups())
        groupIter->clear();

    // break cycles
    // BUGBUG: This only works if nodes are not shared across networks.
    // Once we allow that (BrainScript editing), we need proper cycle detectors. Luckily, we know our cycles, so it won't be too hard.
    // Or just use weak ptrs.
    for (auto& iter : m_nameToNodeMap)
        iter.second->DetachInputs();

    m_nameToNodeMap.clear();

    m_pMBLayout->Init(1, 0);
}

// -----------------------------------------------------------------------
// serialization
// -----------------------------------------------------------------------

// after editing--network is possibly not validated/compiled
void ComputationNetwork::SaveEdited(const wstring& fileName, const FileOptions fileFormat)
{
    if (!IsCompiled())
        CompileNetwork();
    Save(fileName, fileFormat);
}

void ComputationNetwork::Save(const wstring& fileName, const FileOptions fileFormat) const
{
    VerifyIsCompiled("Save");
    // In case of parallel training, only the main node should be saving the model, to prevent
    // the parallel training nodes from colliding to write the same file.
    // TODO: This does not belong here.
    if ((g_mpi == nullptr) || g_mpi->IsMainNode())
    {
        // Saving into a temporary file and then renaming it to the requested fileName.
        // This is a standard trick to avoid having corrupted model files if the process dies during writing.
        wstring tmpFileName = fileName + L".tmp";
        SaveToFileImpl(tmpFileName, fileFormat);
        renameOrDie(tmpFileName, fileName);
    }
}

// TODO: how does the file distinguish float vs double nodes?
void ComputationNetwork::SaveToFileImpl(const wstring& fileName, const FileOptions fileFormat) const
{
    File fstream(fileName, fileFormat | FileOptions::fileOptionsWrite);
    fstream.PutMarker(FileMarker::fileMarkerBeginSection, L"BCN");

    // model version
    fstream.PutMarker(FileMarker::fileMarkerBeginSection, L"BVersion");
    fstream << (size_t) CURRENT_CNTK_MODEL_VERSION;
    fstream.PutMarker(FileMarker::fileMarkerEndSection, L"EVersion");

    fstream << (size_t) m_nameToNodeMap.size();

    // put all node info first
    fstream.PutMarker(FileMarker::fileMarkerBeginSection, L"BNodeList");
    for (auto nodeIter = m_nameToNodeMap.begin(); nodeIter != m_nameToNodeMap.end(); nodeIter++)
    {
        ComputationNodeBasePtr nodePtr = nodeIter->second;
        nodePtr->Save(fstream);
    }

    fstream.PutMarker(FileMarker::fileMarkerEndSection, L"ENodeList");

    // put relationships
    fstream.PutMarker(FileMarker::fileMarkerBeginSection, L"BRelation");
    for (auto nodeIter = m_nameToNodeMap.begin(); nodeIter != m_nameToNodeMap.end(); nodeIter++)
    {
        ComputationNodeBasePtr nodePtr = nodeIter->second;
        fstream << nodePtr->NodeName() << nodePtr->GetNumInputs();
        for (size_t i = 0; i < nodePtr->GetNumInputs(); i++)
        {
            if (!nodePtr->Input(i))
                fprintf(stderr, "Warning: node %ls's child is null, please check your ndl/mel file.\n", nodePtr->NodeName().c_str());
            else
                fstream << nodePtr->Input(i)->NodeName();
        }
    }
    fstream.PutMarker(FileMarker::fileMarkerEndSection, L"ERelation");

    fstream.PutMarker(FileMarker::fileMarkerBeginSection, L"BRootNodes");

    fstream.PutMarker(FileMarker::fileMarkerBeginSection, L"BFeatureNodes");
    fstream << m_features.size();
    for (size_t i = 0; i < m_features.size(); i++)
        fstream << m_features[i]->NodeName();
    fstream.PutMarker(FileMarker::fileMarkerEndSection, L"EFeatureNodes");

    fstream.PutMarker(FileMarker::fileMarkerBeginSection, L"BLabelNodes");
    fstream << m_labels.size();
    for (size_t i = 0; i < m_labels.size(); i++)
        fstream << m_labels[i]->NodeName();
    fstream.PutMarker(FileMarker::fileMarkerEndSection, L"ELabelNodes");

    fstream.PutMarker(FileMarker::fileMarkerBeginSection, L"BCriterionNodes");
    fstream << m_finalCriteria.size();
    for (size_t i = 0; i < m_finalCriteria.size(); i++)
        fstream << m_finalCriteria[i]->NodeName();
    fstream.PutMarker(FileMarker::fileMarkerEndSection, L"ECriterionNodes");

    fstream.PutMarker(FileMarker::fileMarkerBeginSection, L"BEvalNodes");
    fstream << m_evalNodes.size();
    for (size_t i = 0; i < m_evalNodes.size(); i++)
        fstream << m_evalNodes[i]->NodeName();
    fstream.PutMarker(FileMarker::fileMarkerEndSection, L"EEvalNodes");

    fstream.PutMarker(FileMarker::fileMarkerBeginSection, L"BOutputNodes");
    fstream << m_outputNodes.size();
    for (size_t i = 0; i < m_outputNodes.size(); i++)
    {
        fstream << m_outputNodes[i]->NodeName();
    }
    fstream.PutMarker(FileMarker::fileMarkerEndSection, L"EOutputNodes");

    fstream.PutMarker(FileMarker::fileMarkerEndSection, L"ERootNodes");

    fstream.PutMarker(FileMarker::fileMarkerEndSection, L"ECN");

    fstream.Flush();
}
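
// For reference, a sketch of the on-disk layout written above (reconstructed from the
// Put/GetMarker calls in this file; the exact binary encoding of each field is whatever
// the File class's operator<< produces):
//
//   BCN
//     BVersion <modelVersion> EVersion
//     <numNodes>
//     BNodeList   { <opName> <nodeName> <node payload> }*      ENodeList
//     BRelation   { <nodeName> <numInputs> <inputName>* }*     ERelation
//     BRootNodes
//       BFeatureNodes   <n> <name>*  EFeatureNodes
//       BLabelNodes     <n> <name>*  ELabelNodes
//       BCriterionNodes <n> <name>*  ECriterionNodes
//       BEvalNodes      <n> <name>*  EEvalNodes
//       BOutputNodes    <n> <name>*  EOutputNodes
//     ERootNodes
//   ECN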

// load the section of nodes that contain persistable parameters
// This is used for reloading a model without recreating it, e.g. during training.
// TODO: Why not just reload it? Because SGD::Train() holds pointers to the parameters directly? That should be fixed.
template <class ElemType>
void ComputationNetwork::ReadPersistableParameters(File& fstream, bool create)
{
    fstream.GetMarker(FileMarker::fileMarkerBeginSection, L"BCN");

    // model version
    size_t modelVersion = CNTK_MODEL_VERSION_1; // if version info is not there, it is version 1
    if (fstream.TryGetMarker(FileMarker::fileMarkerBeginSection, L"BVersion"))
    {
        fstream >> modelVersion;
        fstream.GetMarker(FileMarker::fileMarkerEndSection, L"EVersion");
    }
    if (modelVersion > CURRENT_CNTK_MODEL_VERSION)
        InvalidArgument("Read: The model file has a newer format version (%d) than this CNTK version can handle (%d).", (int)modelVersion, (int)CURRENT_CNTK_MODEL_VERSION);

    size_t numNodes;
    fstream >> numNodes;

    // get all node info first
    fstream.GetMarker(FileMarker::fileMarkerBeginSection, L"BNodeList");
    for (size_t i = 0; i < numNodes; i++)
    {
        wstring opName, nodeName;
        fstream >> opName >> nodeName;

        ComputationNodeBasePtr node;
        if (create) // loading from scratch
            node = ComputationNetworkBuilder<ElemType>::NewNode(opName, m_deviceId, nodeName);
        else // reloading existing
            node = GetNodeFromName(nodeName);

        node->Load(fstream, modelVersion);

        if (create) // loaded from scratch
            AddNodeToNet(node);
        else // reloaded existing
            node->Validate(true); // nothing that propagates should have changed --TODO: have a more rigid mechanism to prevent resizing; this should only reload the model parameters
    }

    fstream.GetMarker(FileMarker::fileMarkerEndSection, L"ENodeList");
}

// deserialize the model
// This does not post-process the model (CompileNetwork()). Use Load() instead.
template <class ElemType>
void ComputationNetwork::Read(const wstring& fileName)
{
    ClearNetwork();

    File fstream(fileName, FileOptions::fileOptionsBinary | FileOptions::fileOptionsRead);

    ReadPersistableParameters<ElemType>(fstream, true);

    size_t numNodes = m_nameToNodeMap.size();

    // get relationships
    fstream.GetMarker(FileMarker::fileMarkerBeginSection, L"BRelation");
    for (size_t i = 0; i < numNodes; i++)
    {
        wstring nodeName;
        size_t numChildren;
        fstream >> nodeName >> numChildren;
        if (numChildren > 0)
        {
            vector<wstring> childrenNames;
            childrenNames.resize(numChildren);
            for (size_t j = 0; j < numChildren; j++)
                fstream >> childrenNames[j];

            // TODO: how does the file distinguish float from double?
            ComputationNodeBasePtr nodePtr = GetNodeFromName(nodeName);
            vector<ComputationNodeBasePtr> childrenNodes;
            childrenNodes.resize(numChildren);
            for (size_t j = 0; j < numChildren; j++) // (size_t, to match numChildren and avoid a signed/unsigned comparison)
                childrenNodes[j] = GetNodeFromName(childrenNames[j]);

            nodePtr->AttachInputs(childrenNodes);
        }
    }
    fstream.GetMarker(FileMarker::fileMarkerEndSection, L"ERelation");

    fstream.GetMarker(FileMarker::fileMarkerBeginSection, L"BRootNodes");
    {
        wstring nodeName;
        size_t num;

        if (fstream.TryGetMarker(FileMarker::fileMarkerBeginSection, L"BFeatureNodes"))
        {
            fstream >> num;

            for (size_t i = 0; i < num; i++)
            {
                fstream >> nodeName;
                m_features.push_back(GetNodeFromName(nodeName));
            }
            fstream.GetMarker(FileMarker::fileMarkerEndSection, L"EFeatureNodes");
        }

        if (fstream.TryGetMarker(FileMarker::fileMarkerBeginSection, L"BLabelNodes"))
        {
            fstream >> num;
            for (size_t i = 0; i < num; i++)
            {
                fstream >> nodeName;
                m_labels.push_back(GetNodeFromName(nodeName));
            }
        }
        // BUGBUG: Should this be inside the block?
        fstream.GetMarker(FileMarker::fileMarkerEndSection, L"ELabelNodes");

        if (fstream.TryGetMarker(FileMarker::fileMarkerBeginSection, L"BCriterionNodes") ||
            fstream.TryGetMarker(FileMarker::fileMarkerBeginSection, L"BCriteriaNodes" /*legacy*/))
        {
            fstream >> num;
            for (size_t i = 0; i < num; i++)
            {
                fstream >> nodeName;
                m_finalCriteria.push_back(GetNodeFromName(nodeName));
            }

            if (!fstream.TryGetMarker(FileMarker::fileMarkerEndSection, L"ECriteriaNodes" /*legacy*/))
            {
                fstream.GetMarker(FileMarker::fileMarkerEndSection, L"ECriterionNodes"); // check the legacy name first, so that the error message will use the new name
            }
        }

        // this section is for back compat only, skip over it
        if (fstream.TryGetMarker(FileMarker::fileMarkerBeginSection, L"BNodesReqMultiSeqHandling"))
        {
            fprintf(stderr, "WARNING: Ignoring defunct 'BNodesReqMultiSeqHandling' section in input file.\n");
            fstream >> num;
            for (size_t i = 0; i < num; i++)
                fstream >> nodeName; // dummy
            fstream.GetMarker(FileMarker::fileMarkerEndSection, L"ENodesReqMultiSeqHandling");
        }

        if (fstream.TryGetMarker(FileMarker::fileMarkerBeginSection, L"BEvalNodes"))
        {
            fstream >> num;
            for (size_t i = 0; i < num; i++)
            {
                fstream >> nodeName;
                m_evalNodes.push_back(GetNodeFromName(nodeName));
            }
            fstream.GetMarker(FileMarker::fileMarkerEndSection, L"EEvalNodes");
        }

        if (fstream.TryGetMarker(FileMarker::fileMarkerBeginSection, L"BOutputNodes"))
        {
            fstream >> num;
            for (size_t i = 0; i < num; i++)
            {
                fstream >> nodeName;
                m_outputNodes.push_back(GetNodeFromName(nodeName));
            }
            fstream.GetMarker(FileMarker::fileMarkerEndSection, L"EOutputNodes");
        }

        // this section is for back compat only, skip over it
        if (fstream.TryGetMarker(FileMarker::fileMarkerBeginSection, L"BPairNodes"))
        {
            fstream >> num;
            if (num > 0)
                RuntimeError("Read: PairNodes are no longer supported");
            fstream.GetMarker(FileMarker::fileMarkerEndSection, L"EPairNodes");
        }
    }
    fstream.GetMarker(FileMarker::fileMarkerEndSection, L"ERootNodes");

    fstream.GetMarker(FileMarker::fileMarkerEndSection, L"ECN");
}
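
// A typical load sequence, for orientation (a sketch only; Load() in the header is the
// real entry point and, per the comment above Read(), pairs deserialization with
// post-processing):
//
//   auto net = make_shared<ComputationNetwork>(deviceId);
//   net->Read<float>(modelPath); // deserialize nodes, input relationships, root-node lists
//   net->CompileNetwork();       // post-process: validate, form recurrent loops, etc.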

// -----------------------------------------------------------------------
// node construction
// -----------------------------------------------------------------------

// non-static version needed because it accesses m_randomSeedOffset
// Excessively used by SimpleNetworkBuilder, but always after CreateLearnableParameter(), so we should really absorb it there
template <class ElemType>
void ComputationNetwork::InitLearnableParameters(const ComputationNodeBasePtr& node, const bool uniformInit, const unsigned long randomSeed, const ElemType initValueScale, bool initOnCPUOnly)
{
    auto learnableParameterNode = dynamic_pointer_cast<LearnableParameter<ElemType>>(node);
    learnableParameterNode->InitRandom(uniformInit, randomSeed + GetRandomSeedOffset(), initValueScale, initOnCPUOnly);
}

bool ComputationNetwork::IsTypicalCriterionNode(ComputationNodeBasePtr nodePtr)
{
    // TODO: just use return!
    if (nodePtr->OperationName() == OperationNameOf(SquareErrorNode) ||
        nodePtr->OperationName() == OperationNameOf(LogisticNode) ||
        nodePtr->OperationName() == OperationNameOf(CrossEntropyWithSoftmaxNode) ||
        nodePtr->OperationName() == OperationNameOf(SequenceWithSoftmaxNode) ||
        nodePtr->OperationName() == OperationNameOf(CrossEntropyNode) ||
        nodePtr->OperationName() == OperationNameOf(ClassBasedCrossEntropyWithSoftmaxNode) ||
        nodePtr->OperationName() == OperationNameOf(ErrorPredictionNode) ||
#ifdef COMING_SOON
        nodePtr->OperationName() == OperationNameOf(CRFNode) ||
#endif
        nodePtr->OperationName() == OperationNameOf(DummyCriterionNode))
        return true;

    return false;
}

// return the list of nodes that require precomputation and have not been precomputed yet
list<ComputationNodeBasePtr> ComputationNetwork::GetNodesRequiringPreComputation(const ComputationNodeBasePtr& rootNode, bool checkComputed)
{
    list<ComputationNodeBasePtr> nodes;
    for (const auto& node : GetEvalOrder(rootNode))
    {
        auto pcnode = dynamic_pointer_cast<IPreComputeNode>(node);
        if (pcnode)
        {
            assert(node->RequiresPreCompute());
            if (!checkComputed || !pcnode->HasComputed())
                nodes.push_back(node);
        }
    }
    return nodes;
}

// create the m_inputValues[] and m_learnableParameters[] lists
// This enumerates all leaves reachable from rootNode.
// Leaves are:
//  - inputs
//  - learnable parameters
// It does not traverse disabled ones, i.e.
//  - inputs that are only reachable through PrecomputeNodes that have completed computation
//  - learnable parameters that are constants
void ComputationNetwork::CollectInputAndLearnableParameters(const ComputationNodeBasePtr& rootNode)
{
    assert(m_inputValues.find(rootNode) == m_inputValues.end()); // this function must only be called once
    assert(m_learnableParameters.find(rootNode) == m_learnableParameters.end());

    // gather the lists
    set<ComputationNodeBasePtr> visited;
    list<ComputationNodeBasePtr> inputs, learnableParameters;
    if (rootNode)
        CollectInputAndLearnableParametersRec(rootNode, visited, inputs, learnableParameters);
    else
        for (const auto& root : m_allRoots)
            CollectInputAndLearnableParametersRec(root, visited, inputs, learnableParameters);

    // sort learnable parameters by name so that we get a consistent order when loading from a saved file
    learnableParameters.sort([](const ComputationNodeBasePtr& a, const ComputationNodeBasePtr& b)
                             {
                                 return a->NodeName() < b->NodeName();
                             });

    m_inputValues[rootNode] = move(inputs);
    m_learnableParameters[rootNode] = move(learnableParameters);
}

void ComputationNetwork::CollectInputAndLearnableParametersRec(const ComputationNodeBasePtr& node, set<ComputationNodeBasePtr>& visited, list<ComputationNodeBasePtr>& inputs, list<ComputationNodeBasePtr>& learnableParameters)
{
    if (visited.find(node) != visited.end()) // already got this one
        return;
    visited.insert(node); // (mark leaves as visited as well, so that a leaf shared by several nodes is not listed twice)
    if (node->OperationName() == OperationNameOf(InputValue) || node->OperationName() == OperationNameOf(SparseInputValue))
        inputs.push_back(node);
    else if (node->OperationName() == OperationNameOf(LearnableParameter) && node->IsParameterUpdateRequired())
        learnableParameters.push_back(node);
    else
    {
        // PreComputeNodes that are already done should not be traversed
        auto pcnode = dynamic_pointer_cast<IPreComputeNode>(node);
        if (pcnode && pcnode->HasComputed())
            return;
        // recurse
        for (const auto& input : node->GetInputs())
            CollectInputAndLearnableParametersRec(input, visited, inputs, learnableParameters);
    }
}

template <class ElemType>
/*static*/ void ComputationNetwork::SetDropoutRate(ComputationNetworkPtr net, const ComputationNodeBasePtr& criterionNode, const double dropoutRate, double& prevDropoutRate, unsigned long& dropOutSeed)
{
    if (dropoutRate != prevDropoutRate)
    {
        fprintf(stderr, "Switching dropout rate to %.8g.\n", dropoutRate);
        // TODO: Change this to use an interface that is independent of <ElemType>.
        list<ComputationNodeBasePtr> dropoutNodes = net->GetNodesWithType(OperationNameOf(DropoutNode), criterionNode);
        if (dropoutNodes.size() == 0 && dropoutRate > 0)
            fprintf(stderr, "WARNING: there is no dropout node.\n");
        else
            for (auto nodeIter = dropoutNodes.begin(); nodeIter != dropoutNodes.end(); nodeIter++)
            {
                auto node = dynamic_pointer_cast<DropoutNode<ElemType>>(*nodeIter);
                node->SetDropoutRate(dropoutRate);
                node->SetRandomSeed(dropOutSeed++);
            }

        prevDropoutRate = dropoutRate;
    }
}
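
// Hypothetical usage sketch (not from this file): SGD-style training code keeps the previous
// rate and a running seed across epochs and calls this when the schedule changes, e.g.:
//
//   double prevDropoutRate = 0;
//   unsigned long dropOutSeed = 1;
//   ComputationNetwork::SetDropoutRate<float>(net, criterionNode, 0.5, prevDropoutRate, dropOutSeed);
//
// Re-seeding each DropoutNode from the running dropOutSeed keeps their random streams distinct.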

// set sequence training parameters, e.g. smoothing weight, frame drop threshold
template <class ElemType>
void ComputationNetwork::SetSeqParam(ComputationNetworkPtr net,
                                     const ComputationNodeBasePtr criterionNode,
                                     const double& hsmoothingWeight,
                                     const double& frameDropThresh,
                                     const bool& doreferencealign,
                                     const double& amf /*= 14.0f*/,
                                     const double& lmf /*= 14.0f*/,
                                     const double& wp /*= 0.0f*/,
                                     const double& bMMIfactor /*= 0.0f*/,
                                     const bool& sMBR /*= false*/
                                     )
{
    fprintf(stderr, "Setting Hsmoothing weight to %.8g and frame-dropping threshold to %.8g\n", hsmoothingWeight, frameDropThresh);
    fprintf(stderr, "Setting SeqGammar-related parameters: amf=%.2f, lmf=%.2f, wp=%.2f, bMMIFactor=%.2f, usesMBR=%s\n",
            amf, lmf, wp, bMMIfactor, sMBR ? "true" : "false");
    list<ComputationNodeBasePtr> seqNodes = net->GetNodesWithType(OperationNameOf(SequenceWithSoftmaxNode), criterionNode);
    if (seqNodes.size() == 0)
    {
        fprintf(stderr, "WARNING: there is no sequence node.\n");
    }
    else
    {
        for (auto nodeIter = seqNodes.begin(); nodeIter != seqNodes.end(); nodeIter++)
        {
            auto node = dynamic_pointer_cast<SequenceWithSoftmaxNode<ElemType>>(*nodeIter);
            node->SetSmoothWeight(hsmoothingWeight);
            node->SetFrameDropThresh(frameDropThresh);
            node->SetReferenceAlign(doreferencealign);
            node->SetGammarCalculationParam(amf, lmf, wp, bMMIfactor, sMBR);
        }
    }
}

/*static*/ void ComputationNetwork::SetMaxTempMemSizeForCNN(ComputationNetworkPtr net, const ComputationNodeBasePtr& criterionNode, const size_t maxTempMemSizeInSamples)
{
    if (maxTempMemSizeInSamples > 0)
        fprintf(stderr, "Setting max temp memory size for Convolution operations to %lu samples.\n", (unsigned long) maxTempMemSizeInSamples); // (cast so that %lu is correct on platforms where size_t is not unsigned long)
    list<ComputationNodeBasePtr> convolutionNodes = net->GetNodesWithType(OperationNameOf(ConvolutionNode), criterionNode);
    if (convolutionNodes.size() == 0 && maxTempMemSizeInSamples != 0)
    {
        fprintf(stderr, "WARNING: No Convolution operation found.\n");
    }
    else
    {
        for (auto nodeIter = convolutionNodes.begin(); nodeIter != convolutionNodes.end(); nodeIter++)
        {
            auto nodef = dynamic_pointer_cast<ConvolutionNode<float>>(*nodeIter);
            if (nodef)
                nodef->SetmMaxTempMemSizeInSamples(maxTempMemSizeInSamples);
            auto noded = dynamic_pointer_cast<ConvolutionNode<double>>(*nodeIter);
            if (noded)
                noded->SetmMaxTempMemSizeInSamples(maxTempMemSizeInSamples);
        }
    }
}
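
// Informal note (inferred from the setter name and its use above, not stated in this file):
// maxTempMemSizeInSamples caps the scratch workspace a ConvolutionNode may allocate (e.g. for
// unrolling inputs into matrix form), expressed in samples rather than bytes; 0 means no limit.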

// -----------------------------------------------------------------------
// unit test
// -----------------------------------------------------------------------

/**
call unit test of each node
this adds a verification of the correctness of node operations.
*/
bool ComputationNetwork::UnitTest(bool allowFragment)
{
    vector<wstring> vErrors;
    // currently only validates nodes; we should validate everything we can
    if (FeatureNodes().size() == 0 && !allowFragment)
        RuntimeError("No Feature nodes specified");
    // first pass the criterion nodes as root nodes
    if (FinalCriterionNodes().size() > 0)
    {
        for (auto& node : FinalCriterionNodes())
        {
            if (!allowFragment)
                FormRecurrentLoops(node);
            // this->SetActualMiniBatchSizeFromFeatures();
            if (!UnitTest(node))
                vErrors.push_back(node->NodeName().c_str());
        }
    }
    else if (!allowFragment)
        RuntimeError("No Criterion nodes specified");
    // now output nodes
    if (OutputNodes().size() > 0)
    {
        for (auto& node : OutputNodes())
            if (!UnitTest(node))
                vErrors.push_back(node->NodeName().c_str());
    }
    else if (!allowFragment)
        RuntimeError("No Output nodes specified");
    // now evaluation nodes
    if (EvaluationNodes().size() > 0)
    {
        for (auto& node : EvaluationNodes())
            if (!UnitTest(node))
                vErrors.push_back(node->NodeName().c_str());
    }
    return vErrors.empty();
}

bool ComputationNetwork::UnitTest(const ComputationNodeBasePtr& rootNode)
{
    fprintf(stderr, "\n\n Unit test node %ls \n", rootNode->NodeName().c_str());

    for (const auto& nodeIter : GetEvalOrder(rootNode))
        if (!nodeIter->UnitTest())
            return false;

    fprintf(stderr, "\n\n");

    return true;
}

// -----------------------------------------------------------------------
// topological plot [erw]
// -----------------------------------------------------------------------

class DotGraphConfigure
{
public:
    wstring m_LearnableParameterStyle;
    wstring m_featuresStyle;
    wstring m_CriteriaStyle;
    wstring m_nodesReqMultiSeqHandlingStyle;
    wstring m_labelsStyle;
    wstring m_normalNodeStyle;
    wstring m_PrecomputingNodeStyle;
    wstring m_pastValueNodeStyle;
    wstring m_futureValueNodeStyle;

    DotGraphConfigure()
    {
        m_LearnableParameterStyle = L"node [ shape = box , color = gray , style = \"filled, rounded\" ]; ";
        m_featuresStyle = L"node [ shape = ellipse , color = red , fillcolor = white ]; ";
        m_CriteriaStyle = L"node [ shape = doublecircle , color = red , fillcolor = white ]; ";
        m_nodesReqMultiSeqHandlingStyle = L"node [ shape = doublecircle , color = brown , fillcolor = white ]; ";
        m_normalNodeStyle = L"node [ shape = ellipse, color = blue, fillcolor = white, style = solid ]; ";
        m_PrecomputingNodeStyle = L"node [ shape = box , color = black, style = \"dashed, filled\", fillcolor= limegreen ] ;";
        m_labelsStyle = L"node [ shape = diamond, color = brown, style = bold ] ; ";
        m_pastValueNodeStyle = L"node [ shape = box3d , color = lightgray, style = \"filled\" , fillcolor = white ] ";
        m_futureValueNodeStyle = L"node [ shape = box3d , color = red, style = \"filled\" , fillcolor = white ] ";
    }
};
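
// For orientation, the generated file (see DescribeNetworkUsingDot below) is plain Graphviz
// dot; e.g. the learnable-parameter style above renders a group of nodes roughly like this
// (illustrative output, node names made up):
//
//   strict digraph {
//   rankdir = BT ;
//   node [ shape = box , color = gray , style = "filled, rounded" ]; "W0" "W1" ;
//   "B" -> "A" ;
//   }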

wstring ComputationNetwork::FormSpecialNodes(wstring style, vector<ComputationNodeBasePtr>& specialNodes)
{
    if (specialNodes.empty())
        return L"";

    wstring str = style;

    for (const auto& x : specialNodes)
        str = str + msra::strfun::wstrprintf(L"\"%ls\" ", x->GetName().c_str());
    return str + L"; \n";
}

void ComputationNetwork::DescribeNetworkUsingDot(list<ComputationArc>& arcs,
                                                 wstring outFile)
{
    DotGraphConfigure dotcfg;

    File fstream(outFile, FileOptions::fileOptionsText | FileOptions::fileOptionsWrite);

    // get precompute nodes
    vector<ComputationNodeBasePtr> PreComputedNodes;
    vector<ComputationNodeBasePtr> allnodes = GetAllNodes();
    for (const auto& n : allnodes)
    {
        if (n->RequiresPreCompute())
            PreComputedNodes.push_back(n);
    }

    // get PastValue nodes
    vector<ComputationNodeBasePtr> pastValueNodes;
    for (const auto& n : allnodes)
    {
        if (n->OperationName() == OperationNameOf(PastValueNode) || n->OperationName() == L"Delay")
            pastValueNodes.push_back(n);
    }

    // get FutureValue nodes
    vector<ComputationNodeBasePtr> futureValueNodes;
    for (const auto& n : allnodes)
    {
        if (n->OperationName() == OperationNameOf(FutureValueNode))
            futureValueNodes.push_back(n);
    }
    // get learnableParameters
    vector<ComputationNodeBasePtr> learnableParameters;
    for (const auto& n : allnodes)
    {
        if (n->OperationName() == OperationNameOf(LearnableParameter))
            learnableParameters.push_back(n);
    }

    fstream << "strict digraph {\n";
    fstream << "rankdir = BT ; \n";

    // ////////////////////////////////////////////////////////////////////////
    // special nodes
    // ////////////////////////////////////////////////////////////////////////
    fstream << L"// special nodes \n";

    // learnable parameters
    fstream << FormSpecialNodes(dotcfg.m_LearnableParameterStyle, learnableParameters);
    // features
    fstream << FormSpecialNodes(dotcfg.m_featuresStyle, m_features);
    // labels
    fstream << FormSpecialNodes(dotcfg.m_labelsStyle, m_labels);
    // criteria
    fstream << FormSpecialNodes(dotcfg.m_CriteriaStyle, m_finalCriteria);
    // pre-compute nodes
    fstream << FormSpecialNodes(dotcfg.m_PrecomputingNodeStyle, PreComputedNodes);
    // PastValue nodes
    fstream << FormSpecialNodes(dotcfg.m_pastValueNodeStyle, pastValueNodes);
    // FutureValue nodes
    fstream << FormSpecialNodes(dotcfg.m_futureValueNodeStyle, futureValueNodes);
    // normal nodes
    fstream << dotcfg.m_normalNodeStyle << L"\n";

    // ////////////////////////////////////////////////////////////////////////
    // add labels for each node
    // ////////////////////////////////////////////////////////////////////////
    fstream << L"\n// add labels and operation name\n";
    wstring line;
    for (const auto& x : allnodes)
    {
        line.clear();
        line = msra::strfun::wstrprintf(L" \"%ls\" [ label = \"%ls [%s%s]\\n%ls\" ] ;\n",
                                        x->GetName().c_str(), x->GetName().c_str(), string(x->GetSampleLayout()).c_str(), x->HasMBLayout() ? " x *" : "",
                                        x->OperationName().c_str());
        fstream << line;
    }

    // ////////////////////////////////////////////////////////////////////////
    // sub-graph
    // ////////////////////////////////////////////////////////////////////////
    // subgraph source
    fstream << L"subgraph {\n";
    fstream << L"\t\t rank=source ; ";
    line.clear();
    for (const auto& x : m_features)
        line = line + msra::strfun::wstrprintf(L"\"%ls\" ", x->GetName().c_str());
    fstream << line << L"\n}\n";

    // subgraph eval/output/criteria
    fstream << L"subgraph {\n";
    fstream << L"\t\t rank=sink ; ";
    line.clear();
    for (const auto& x : m_finalCriteria)
        line = line + msra::strfun::wstrprintf(L"\"%ls\" ", x->GetName().c_str());
    for (const auto& x : m_outputNodes)
        line = line + msra::strfun::wstrprintf(L"\"%ls\" ", x->GetName().c_str());
    for (const auto& x : m_evalNodes)
        line = line + msra::strfun::wstrprintf(L"\"%ls\" ", x->GetName().c_str());

    fstream << line << L"\n}\n";

    // ////////////////////////////////////////////////////////////////////////
    // specify arc connections
    // ////////////////////////////////////////////////////////////////////////
    for (auto x = arcs.begin(); x != arcs.end(); x++)
    {
        ComputationNodeBasePtr src = (*x).first;
        ComputationNodeBasePtr des = (*x).second;

        wstring srcname = src->GetName();
        wstring desname = des->GetName();

        if (des->OperationName() == OperationNameOf(PastValueNode) || des->OperationName() == L"Delay")
        {
            // special treatment for arcs whose child is a PastValue node:
            // create a dummy node
            ComputationNodeBasePtr pastValueNode = des;
            wstring dummyName = des->GetName() + L".dummy";
            wstring out = msra::strfun::wstrprintf(L"node [ shape = box3d , color = lightgray, style = \"filled\" , label = \"%ls\" ] ; \"%ls\"\n",
                                                   (pastValueNode->GetName() + L"\\n(PastValue)").c_str(),
                                                   dummyName.c_str());
            line = out;
            line += msra::strfun::wstrprintf(L"\"%ls\" -> \"%ls\" ; \n", dummyName.c_str(), srcname.c_str());
        }
        else if (des->OperationName() == OperationNameOf(FutureValueNode))
        {
            // special treatment for arcs whose child is a FutureValue node:
            // create a dummy node
            ComputationNodeBasePtr futureValueNode = des;
            wstring dummyName = des->GetName() + L".dummy";
            wstring out = msra::strfun::wstrprintf(L"node [ shape = box3d , color = red, style = \"filled\" , label = \"%ls\" ] ; \"%ls\"\n",
                                                   (futureValueNode->GetName() + L"\\n(FutureValue)").c_str(),
                                                   dummyName.c_str());
            line = out;
            line += msra::strfun::wstrprintf(L"\"%ls\" -> \"%ls\" ; \n", dummyName.c_str(), srcname.c_str());
        }
        else
        {
            line = msra::strfun::wstrprintf(L"\"%ls\" -> \"%ls\" ; \n", desname.c_str(), srcname.c_str());
        }

        fstream << line;
    }
    fstream << L"\n}\n";
}
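
// The output can be rendered with the standard Graphviz tool chain, e.g.:
//
//   dot -Tpng network.dot -o network.png
//
// where "network.dot" is whatever path was passed as outFile above.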

void ComputationNetwork::PlotNetworkTopology(const wstring outputFile) // [1/13/2015 erw] plot network topology using the dot language
{
    VerifyIsCompiled("PlotNetworkTopology");
    // ValidateNetwork(false, true);

    // ////////////////////////////////////////////////////////////////////////
    // step 1. get all the arcs in the network
    // ////////////////////////////////////////////////////////////////////////
    unordered_set<ComputationNodeBasePtr> visited;
    list<ComputationArc> arcs;

    for (auto groupIter : GetAllNodeGroups())
    {
        // note: this will also loop over m_features and m_labels, which will do nothing since they have no inputs
        // TODO: test whether that is true
        const auto& group = *groupIter;
        for (size_t i = 0; i < group.size(); i++)
            group[i]->EnumerateArcs(visited, arcs);
    }

    // ////////////////////////////////////////////////////////////////////////
    // step 2. output dot description
    // ////////////////////////////////////////////////////////////////////////
    DescribeNetworkUsingDot(arcs, outputFile);
}

// enumerate all arcs that can be reached starting from the current node's children
// [in/out] visited: records already-visited nodes
void ComputationNodeBase::EnumerateArcs(std::unordered_set<ComputationNodeBasePtr>& visited, std::list<ComputationArc>& arcs)
{
    std::list<ComputationNodeBasePtr> tovisit;

    if (visited.find(shared_from_this()) == visited.end()) // only do this when this node has not been visited before
    {
        tovisit.push_back(shared_from_this());

        while (!tovisit.empty())
        {
            ComputationNodeBasePtr curNode = tovisit.front();
            tovisit.pop_front();

            if (visited.find(curNode) == visited.end())
            {
                for (size_t i = 0; i < curNode->m_inputs.size(); i++)
                {
                    arcs.push_back(ComputationArc(curNode, curNode->m_inputs[i]));

                    if (visited.find(curNode->m_inputs[i]) == visited.end()) // this child has not been visited before
                        tovisit.push_front(curNode->m_inputs[i]); // going to visit each of the children
                }
                visited.insert(curNode);
            }
        }
    }
}
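
// (This is an iterative depth-first traversal with an explicit work list: each node is
// expanded at most once and each parent->child edge is emitted once, so the cost is roughly
// O(nodes + arcs), times the set-lookup overhead.)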

// -----------------------------------------------------------------------
// specialized operations
// -----------------------------------------------------------------------

// TODO: Lift this into config language, move underlying code to math lib. This should be a model-editing operation.

// ========================================
// This function performs SVD decomposition for different groups of learnable parameters
// we perform SVD decomposition such that
// A \approx B*C, where rank(B)=rank(C)=r < rank(A)
// After SVD decomposition, the node A will become an intermediate node whose children are B,C ;
// B and C are two learnable parameters
// ========================================
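// In the notation of the code below (a restatement of what it computes, not new behavior):
// given the full SVD A = U * diag(S) * VT, pick the smallest rank r whose leading singular
// values carry at least keepRatio of the total "energy" sum_i S(i,0), round r up to a
// multiple of AlignedSize, and factor
//
//   B = U(:, 0:r) * diag(sqrt(S(0:r)))    // an m x r LearnableParameter, name + L"-U"
//   C = diag(sqrt(S(0:r))) * VT(0:r, :)   // an r x n LearnableParameter, name + L"-V"
//
// so that A ~= B*C, realized as a TimesNode (name + L"-SVD") replacing the original leaf.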
// BUGBUG: this currently only works for one ElemType, not both
template <class ElemType>
void ComputationNetwork::PerformSVDecomposition(const map<wstring, float>& SVDConfig, size_t AlignedSize)
{
    vector<pair<vector<wstring>, float>> nodeGroups;
    wregex NameFilter;

    for (const auto& e : SVDConfig)
    {
        wstring regexStr = e.first;
        float keepRatio = e.second;
        vector<wstring> NamesInGroup;

        NameFilter.assign(regexStr);

        for (auto n = m_nameToNodeMap.begin(); n != m_nameToNodeMap.end(); n++)
        {
            if (!regexStr.empty() && !regex_match(n->first, NameFilter))
            {
                // if regexStr is not empty and the node name does not match it
                continue;
            }

            shared_ptr<ComputationNode<ElemType>> ptr = dynamic_pointer_cast<LearnableParameter<ElemType>>(n->second);
            if (!ptr)
                continue;

            if (ptr->Value().GetNumCols() == 1 || ptr->Value().GetNumRows() == 1)
                continue;

            // still here?
            NamesInGroup.push_back(n->first);
        }
        nodeGroups.push_back(make_pair(NamesInGroup, keepRatio));
    }

    size_t groupID = 0;
    for (auto& group : nodeGroups)
    {
        float keepratio = group.second;
        fprintf(stderr,
                "--------------------------------------------------------------------------------------------\n");
        fprintf(stderr,
                "ParameterSVD: start to process group %d with KeepRatio=%.2f\n",
                (int) groupID++, keepratio);
        fprintf(stderr,
                "--------------------------------------------------------------------------------------------\n");

        for (const auto& name : group.first)
        {
            if (m_nameToNodeMap.find(name) == m_nameToNodeMap.end())
            {
                // could have been deleted in the previous groups
                continue;
            }

            shared_ptr<ComputationNode<ElemType>> pNode = dynamic_pointer_cast<LearnableParameter<ElemType>>(m_nameToNodeMap[name]);

            // Step 1. do SVD decomposition
            Matrix<ElemType> A = pNode->ValueAsMatrix();

            // it is a vector, no need to do it
            if (A.GetNumCols() == 1 || A.GetNumRows() == 1)
                continue;

            size_t m = A.GetNumRows();
            size_t n = A.GetNumCols();

            Matrix<ElemType> S(-1), U(-1), VT(-1), W(-1);
            chrono::time_point<chrono::system_clock> stTime = chrono::system_clock::now();
            Matrix<ElemType>::SVD(A, S, U, VT, W);
            chrono::time_point<chrono::system_clock> enTime = chrono::system_clock::now();

            // A \in R^{mXn}
            // U \in R^{mXm}
            // VT \in R^{nXn}
            // S \in R^{min(m,n),1}
            // S is in descending order

            ElemType totalenergy = 0.0f;
            for (size_t i = 0; i < S.GetNumRows(); i++)
                totalenergy += S(i, 0);
            ElemType keepenergy = totalenergy * keepratio;
            ElemType runenergy = 0.0f;

            size_t r = 0;
            for (size_t indx = 0; indx < S.GetNumRows(); indx++)
            {
                runenergy += S(indx, 0);
                if (runenergy > keepenergy)
                {
                    r = indx + 1;
                    break;
                }
            }

            r = r > S.GetNumRows() ? S.GetNumRows() : r;

            if (r % AlignedSize != 0)
            {
                r -= r % AlignedSize;
                r = r + AlignedSize > S.GetNumRows() ? S.GetNumRows() : r + AlignedSize;
            }
            // r = (r + 7) & (~7); // to keep the number of rows/cols of the resultant matrix a multiple of 8,
            // which can be helpful at runtime

            chrono::duration<double> elapsedtime = enTime - stTime;
            fprintf(stderr,
                    "Performing SVD for a %5d-by-%-5d matrix (node name: %-20ls) --- computation time %5.2f secs ; keep %4.1f%% energy ===> keep %5d svd values (reduce to %4.1f%% parameters) \n",
                    (int) m, (int) n, name.c_str(), elapsedtime.count(),
                    keepratio * 100, (int) r,
                    ((m + n) * r + 0.0f) / m / n * 100);

            // redU in R^{mXr}
            Matrix<ElemType> redU = U.ColumnSlice(0, r);
            Matrix<ElemType> redVT(-1);

            // redVT in R^{rXn}
            redVT.Resize(r, n);
            redVT.AssignRowSliceValuesOf(VT, 0, r);

            Matrix<ElemType> redS(r, (size_t)1, A.GetDeviceId());
            for (size_t i = 0; i < r; i++)
            {
                ElemType sqrtsigma = (ElemType) sqrt((double) S(i, 0));
                redS(i, 0) = sqrtsigma;
            }

            redU.RowElementMultiplyWith(redS.Transpose());
            redVT.ColumnElementMultiplyWith(redS);

            // Step 2. create two new Parameter nodes and one Times node
            wstring leftChildName = name + L"-U";
            wstring rightChildName = name + L"-V";
            shared_ptr<ComputationNode<ElemType>> pLeft = AddNodeToNetWithElemType(New<LearnableParameter<ElemType>>(m_deviceId, leftChildName, m, r));
            shared_ptr<ComputationNode<ElemType>> pRight = AddNodeToNetWithElemType(New<LearnableParameter<ElemType>>(m_deviceId, rightChildName, r, n));

            pLeft->ValueAsMatrix() = redU;
            pRight->ValueAsMatrix() = redVT;

            shared_ptr<ComputationNode<ElemType>> pTimes = AddNodeToNetAndAttachInputs(New<TimesNode<ElemType>>(m_deviceId, name + L"-SVD"), pLeft, pRight);

            // Step 3. remove the old node
            ReplaceLeafNode(name, pTimes);
        }
    }

    // redo necessary post-processing
    CompileNetwork();
}
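
// Hypothetical call, for orientation (the regex and ratio are made up): compress every
// parameter matrix whose name starts with "W" to 60% of its singular-value energy, with r
// rounded to multiples of 8:
//
//   map<wstring, float> svdConfig = { { L"W.*", 0.6f } };
//   net->PerformSVDecomposition<float>(svdConfig, 8);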

// save network to legacy DBN.exe format
class DbnLayer
{
public:
    DbnLayer() : Node(nullptr), Bias(nullptr), Sigmoided(false) {}
    ComputationNodeBasePtr Node;
    ComputationNodeBasePtr Bias;
    bool Sigmoided;
    ~DbnLayer() {};
};

template <class ElemType>
void ComputationNetwork::SaveToDbnFile(ComputationNetworkPtr net, const std::wstring& fileName) const
{
    // Helper methods
    auto VerifyTypeAll = [](const std::vector<ComputationNodeBasePtr>& nodes, const std::wstring& typeValue) -> bool
    {
        return std::find_if(nodes.begin(), nodes.end(), [&typeValue](ComputationNodeBasePtr node)->bool { return node->OperationName() != typeValue; }) == nodes.end();
    };
    auto GetNodeConsumers = [&net](const ComputationNodeBasePtr node) -> std::vector<ComputationNodeBasePtr>
    {
        std::vector<ComputationNodeBasePtr> consumers;
        for (auto& item : net->GetAllNodes())
        {
            for (auto& input : item->GetInputs())
            {
                if (input == node)
                {
                    consumers.push_back(item);
                    break;
                }
            }
        }

        return consumers;
    };
    auto GetFirstDifferentNode = [](const std::vector<ComputationNodeBasePtr>& list, const ComputationNodeBasePtr node) -> ComputationNodeBasePtr
    {
        auto foundNode = std::find_if(list.begin(), list.end(), [&node](ComputationNodeBasePtr item)->bool { return item != node; });
        return foundNode == list.end() ? nullptr : *foundNode;
    };
    auto GetFirstNodeWithDifferentType = [](const std::vector<ComputationNodeBasePtr>& list, const std::wstring& type) -> ComputationNodeBasePtr
    {
        auto foundNode = std::find_if(list.begin(), list.end(), [&type](ComputationNodeBasePtr item)->bool { return item->OperationName() != type; });
        return foundNode == list.end() ? nullptr : *foundNode;
    };
    auto WhereNode = [](const std::vector<ComputationNodeBasePtr>& nodes, const function<bool(ComputationNodeBasePtr)>& predicate) -> std::vector<ComputationNodeBasePtr>
    {
        std::vector<ComputationNodeBasePtr> results;

        for (auto& node : nodes)
        {
            if (predicate(node))
            {
                results.push_back(node);
            }
        }

        return results;
    };
    auto GetNodesWithType = [](const std::vector<ComputationNodeBasePtr>& list, const std::wstring& type) -> std::vector<ComputationNodeBasePtr>
    {
        std::vector<ComputationNodeBasePtr> results;

        for (auto& node : list)
        {
            if (node->OperationName() == type)
            {
                results.push_back(node);
            }
        }

        return results;
    };
    auto GetAllPriorNodes = [](ComputationNodeBasePtr node)->bool
    {
        std::wstring lowerName = node->GetName();
        std::transform(lowerName.begin(), lowerName.end(), lowerName.begin(), ::tolower);

        return node->OperationName() == OperationNameOf(LearnableParameter) && (lowerName.find(L"prior") != wstring::npos);
    };

    // Get output node
    std::list<ComputationNodeBasePtr> outputNodes = net->GetNodesWithType(OperationNameOf(ErrorPredictionNode));
    ComputationNodeBasePtr outputNode = GetFirstNodeWithDifferentType(outputNodes.front()->GetInputs(), OperationNameOf(InputValue));

    if (outputNode == nullptr)
    {
        RuntimeError("Cannot find output node");
    }

    std::list<ComputationNodeBasePtr> orderList;
    std::stack<ComputationNodeBasePtr> nodeStack;

    nodeStack.push(outputNode);

    while (nodeStack.size() > 0)
    {
        auto node = nodeStack.top();
        nodeStack.pop();
        auto nodeInputs = node->GetInputs();
        for (auto& input : nodeInputs)
        {
            for (auto& item : orderList)
            {
                if (item == input)
                {
                    RuntimeError("Cyclic dependency on node '%ls'", item->GetName().c_str());
                }
            }

            nodeStack.push(input);
        }

        orderList.push_back(node);
    }

    orderList.reverse();

    // All multiplication nodes that multiply a symbolic variable
    std::list<ComputationNodeBasePtr> multNodes;
    typedef shared_ptr<DbnLayer> DbnLayerPtr;
    std::list<DbnLayerPtr> dbnLayers;

    for (auto& item : orderList)
    {
        if (item->OperationName() == OperationNameOf(TimesNode) && !VerifyTypeAll(item->GetInputs(), OperationNameOf(LearnableParameter)))
        {
            multNodes.push_back(item);
        }
    }

    for (auto& node : multNodes)
    {
        std::vector<ComputationNodeBasePtr> consumers = GetNodeConsumers(node);
        if (consumers.size() == 1)
        {
            bool sigmoided = false;
            std::wstring layerId(node->GetName());

            ComputationNodeBasePtr firstConsumer = consumers.front();

            if (firstConsumer->OperationName() != OperationNameOf(PlusNode))
            {
                RuntimeError("Expected a plus node to consume the times node.");
            }

            ComputationNodeBasePtr bias = GetFirstDifferentNode(firstConsumer->GetInputs(), node);

            auto consumer2 = GetNodeConsumers(consumers.front()).front();
            if (consumer2->OperationName() == L"Sigmoid")
            {
                sigmoided = true;
                layerId = consumer2->GetName();
            }
            else
            {
                layerId = firstConsumer->GetName();
            }

            // If one of its inputs was itself a multiplication node, then split it out
            // into dbn-style.
            std::vector<ComputationNodeBasePtr> aggTimes = GetNodesWithType(node->GetInputs(), OperationNameOf(TimesNode));
            if (aggTimes.size() > 0)
            {
                ComputationNodeBasePtr multNode = aggTimes.front();
                DbnLayerPtr l1 = make_shared<DbnLayer>();
                DbnLayerPtr l2 = make_shared<DbnLayer>();

                auto firstInput = multNode->GetInputs()[0];
                auto secondInput = multNode->GetInputs()[1];
                l2->Bias = bias;
                l2->Node = firstInput;

                l1->Bias = nullptr;
                l1->Node = secondInput;

                l1->Sigmoided = false;
                l2->Sigmoided = sigmoided;

                dbnLayers.push_back(l1);
                dbnLayers.push_back(l2);
            }
            else
            {
                auto paramNode = GetNodesWithType(node->GetInputs(), OperationNameOf(LearnableParameter)).front();
                DbnLayerPtr l1 = make_shared<DbnLayer>();
                l1->Bias = bias;
                l1->Node = paramNode;
                l1->Sigmoided = sigmoided;

                dbnLayers.push_back(l1);
            }
        }
    }

    // Write the layers to the output
    // DBN wants std not invstd, so we need to invert each element
    std::vector<ComputationNodeBasePtr> normalizationNodes = GetNodesWithType(net->GetAllNodes(), OperationNameOf(PerDimMeanVarNormalizationNode));
    if (normalizationNodes.size() == 0)
    {
        RuntimeError("Model does not contain at least one node with the '%ls' operation.", OperationNameOf(PerDimMeanVarNormalizationNode).c_str());
    }

    ComputationNodeBasePtr meanNode = normalizationNodes.front()->GetInputs()[1];
    ComputationNodeBasePtr stdNode = normalizationNodes.front()->GetInputs()[2];

    Matrix<ElemType> meanNodeMatrix = meanNode->As<ComputationNode<ElemType>>()->Value();
    Matrix<ElemType> stdNodeMatrix = stdNode->As<ComputationNode<ElemType>>()->Value();
    Matrix<ElemType> invStdNodeMatrix(stdNodeMatrix.ElementInverse());

    std::vector<ComputationNodeBasePtr> priorNodes = WhereNode(net->GetAllNodes(), GetAllPriorNodes);
    if (priorNodes.size() != 1)
    {
        RuntimeError("Could not reliably determine the prior node!");
    }

    // =================
    // Write to the file
    // =================
    File fstream(fileName, FileOptions::fileOptionsBinary | FileOptions::fileOptionsWrite);

    // local helper functions for writing stuff in DBN.exe-expected format
    auto PutTag = [&fstream](const char * tag) { while (*tag) fstream << *tag++; };
    auto PutString = [&fstream](const char * string) { fstream.WriteString(string, 0); };
    auto PutInt = [&fstream](int val) { fstream << val; };

    // write a DBN matrix object, optionally applying a function
    auto PutMatrixConverted = [&](const Matrix<ElemType> * m, size_t maxelem, const char * name, float(*f)(float))
    {
        PutTag("BMAT");
        PutString(name);
        size_t numRows = m->GetNumRows();
        size_t numCols = m->GetNumCols();

        if (maxelem == SIZE_MAX)
        {
            PutInt((int) numRows);
            PutInt((int) numCols);
        }
        else // this allows shortening a vector, as we need for mean/invstd
        {
            PutInt((int) maxelem);
            PutInt(1);
        }

        // this code transposes the matrix on the fly, and outputs at most maxelem floating point numbers to the stream
        size_t k = 0;
        for (size_t j = 0; j < numCols && k < maxelem; j++)
            for (size_t i = 0; i < numRows && k < maxelem; i++, k++)
                fstream << f((float)(*m)(i, j));

        PutTag("EMAT");
    };
    auto PutMatrix = [&](const Matrix<ElemType> * m, const char * name) { PutMatrixConverted(m, SIZE_MAX, name, [](float v) { return v; }); };

    // write out the data
    // Dump DBN header
    PutString("DBN\ncomment=dbn finetune epoch 121\niter.state.frameacc=61.327412\niter.state.logp=0.090508");
    PutTag("BDBN");
    PutInt(0); // a version number
    PutInt(static_cast<int>(dbnLayers.size())); // number of layers

    // Dump feature norm
    PutMatrixConverted(&meanNodeMatrix, meanNodeMatrix.GetNumRows() / 4, "gmean", [](float v) { return v; });
    PutMatrixConverted(&invStdNodeMatrix, invStdNodeMatrix.GetNumRows() / 4, "gstddev", [](float v) { return v; });

    PutTag("BNET");
    auto lastOne = dbnLayers.end();
    --lastOne;
    for (auto ii = dbnLayers.begin(), e = dbnLayers.end(); ii != e; ++ii)
    {
        DbnLayerPtr& layer = *ii;

        if (ii == dbnLayers.begin())
        {
            PutString("rbmgaussbernoulli");
        }
        else if (ii == lastOne)
        {
            PutString("perceptron");
        }
        else if (layer->Sigmoided)
        {
            PutString("rbmbernoullibernoulli");
        }
        else
        {
            PutString("rbmisalinearbernoulli");
        }

        // Write out the main weight matrix
        auto weight = (layer->Node->As<ComputationNode<ElemType>>()->Value());
        auto transpose = weight.Transpose();
        PutMatrix(&transpose, "W");

        // Write out the biasing vector
        // It is mandatory, so pad with zeroes if not given
        auto rows = layer->Node->GetAsMatrixNumRows();
        if (layer->Bias == nullptr)
        {
            auto zeros = Matrix<ElemType>::Zeros(rows, 1, CPUDEVICE);
            PutMatrixConverted(&zeros, rows, "a", [](float v) { return v; });
        }
        else
        {
            PutMatrixConverted(&(layer->Bias->As<ComputationNode<ElemType>>()->Value()), rows, "a", [](float v) { return v; });
        }

        // Some sort of legacy vector that is useless
        auto zeros = Matrix<ElemType>::Zeros(0, 0, CPUDEVICE);
        PutMatrix(&(zeros), "b");
    }

    // Dump the priors
    PutTag("ENET");
    PutMatrix(&(priorNodes.front()->As<ComputationNode<ElemType>>()->Value()), "Pu");
    PutTag("EDBN");
}
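
// Rough shape of the emitted stream, reconstructed from the Put* calls above (the field
// encodings are whatever File's operator<< produces):
//
//   "DBN\ncomment=..."                          header string
//   BDBN <version=0> <numLayers>
//   BMAT "gmean"   ... EMAT                     feature mean (truncated to 1/4 length)
//   BMAT "gstddev" ... EMAT                     feature std dev (element-wise inverted invstd)
//   BNET { <layerType> BMAT "W" ... EMAT  BMAT "a" ... EMAT  BMAT "b" ... EMAT }*
//   ENET
//   BMAT "Pu" ... EMAT                          class priors
//   EDBN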

template void ComputationNetwork::InitLearnableParameters<float>(const ComputationNodeBasePtr& node, const bool uniformInit, const unsigned long randomSeed, const float initValueScale, bool initOnCPUOnly);
template void ComputationNetwork::Read<float>(const wstring& fileName);
template void ComputationNetwork::ReadPersistableParameters<float>(File& fstream, bool create);
template void ComputationNetwork::PerformSVDecomposition<float>(const map<wstring, float>& SVDConfig, size_t alignedsize);
template /*static*/ void ComputationNetwork::SetDropoutRate<float>(ComputationNetworkPtr net, const ComputationNodeBasePtr& criterionNode, const double dropoutRate, double& prevDropoutRate, unsigned long& dropOutSeed);
template void ComputationNetwork::SetSeqParam<float>(ComputationNetworkPtr net, const ComputationNodeBasePtr criterionNode, const double& hsmoothingWeight, const double& frameDropThresh, const bool& doreferencealign,
                                                     const double& amf, const double& lmf, const double& wp, const double& bMMIfactor, const bool& sMBR);
template void ComputationNetwork::SaveToDbnFile<float>(ComputationNetworkPtr net, const std::wstring& fileName) const;

template void ComputationNetwork::InitLearnableParameters<double>(const ComputationNodeBasePtr& node, const bool uniformInit, const unsigned long randomSeed, const double initValueScale, bool initOnCPUOnly);
template void ComputationNetwork::Read<double>(const wstring& fileName);
template void ComputationNetwork::ReadPersistableParameters<double>(File& fstream, bool create);
template void ComputationNetwork::PerformSVDecomposition<double>(const map<wstring, float>& SVDConfig, size_t alignedsize);
template /*static*/ void ComputationNetwork::SetDropoutRate<double>(ComputationNetworkPtr net, const ComputationNodeBasePtr& criterionNode, const double dropoutRate, double& prevDropoutRate, unsigned long& dropOutSeed);
template void ComputationNetwork::SetSeqParam<double>(ComputationNetworkPtr net, const ComputationNodeBasePtr criterionNode, const double& hsmoothingWeight, const double& frameDropThresh, const bool& doreferencealign,
                                                      const double& amf, const double& lmf, const double& wp, const double& bMMIfactor, const bool& sMBR);
template void ComputationNetwork::SaveToDbnFile<double>(ComputationNetworkPtr net, const std::wstring& fileName) const;

// register ComputationNetwork with the ScriptableObject system
ScriptableObjects::ConfigurableRuntimeTypeRegister::Add<ComputationNetwork> registerComputationNetwork(L"ComputationNetwork");

} } }