Optimize the flow of post-batch-normalization statistics, and allow disabling the regularization terms in batch normalization
This commit is contained in:
Родитель
d2de39c993
Коммит
9cb329a0da
|
@ -88,11 +88,11 @@ Train=[
|
||||||
]
|
]
|
||||||
|
|
||||||
PBN=[
|
PBN=[
|
||||||
action="pbn"
|
action="bnstat"
|
||||||
modelPath="$ModelDir$/ResNet_50"
|
modelPath="$ModelDir$/ResNet_50"
|
||||||
# Set minibatch size for testing.
|
# Set minibatch size for testing.
|
||||||
minibatchSize=256
|
minibatchSize=256
|
||||||
iters=30
|
itersPerNode=30
|
||||||
|
|
||||||
reader=[
|
reader=[
|
||||||
readerType="ImageReader"
|
readerType="ImageReader"
|
||||||
|
|
3
Makefile
3
Makefile
|
@ -439,7 +439,8 @@ EVAL:=eval
|
||||||
|
|
||||||
SGDLIB_SRC=\
|
SGDLIB_SRC=\
|
||||||
$(SOURCEDIR)/SGDLib/Profiler.cpp \
|
$(SOURCEDIR)/SGDLib/Profiler.cpp \
|
||||||
$(SOURCEDIR)/SGDLib/SGD.cpp
|
$(SOURCEDIR)/SGDLib/SGD.cpp \
|
||||||
|
$(SOURCEDIR)/SGDLib/PostComputingActions.cpp \
|
||||||
|
|
||||||
EVAL_SRC=\
|
EVAL_SRC=\
|
||||||
$(SOURCEDIR)/EvalDll/CNTKEval.cpp \
|
$(SOURCEDIR)/EvalDll/CNTKEval.cpp \
|
||||||
|
|
|
@ -42,11 +42,11 @@ template <typename ElemType>
|
||||||
void DoDumpNodes(const ConfigParameters& config);
|
void DoDumpNodes(const ConfigParameters& config);
|
||||||
template <typename ElemType>
|
template <typename ElemType>
|
||||||
void DoEdit(const ConfigParameters& config);
|
void DoEdit(const ConfigParameters& config);
|
||||||
|
template <typename ElemType>
|
||||||
|
void DoBatchNormalizationStat(const ConfigParameters& config);
|
||||||
|
|
||||||
// evaluation (EvalActions.cpp)
|
// evaluation (EvalActions.cpp)
|
||||||
template <typename ElemType>
|
template <typename ElemType>
|
||||||
void DoEvalBN(const ConfigParameters& config);
|
|
||||||
template <typename ElemType>
|
|
||||||
void DoEval(const ConfigParameters& config);
|
void DoEval(const ConfigParameters& config);
|
||||||
template <typename ElemType>
|
template <typename ElemType>
|
||||||
void DoCrossValidate(const ConfigParameters& config);
|
void DoCrossValidate(const ConfigParameters& config);
|
||||||
|
|
|
@ -78,62 +78,6 @@ static void DoEvalBase(const ConfigParameters& config, IDataReader& reader)
|
||||||
eval.Evaluate(&reader, evalNodeNamesVector, mbSize[0], epochSize);
|
eval.Evaluate(&reader, evalNodeNamesVector, mbSize[0], epochSize);
|
||||||
}
|
}
|
||||||
|
|
||||||
// ===========================================================================
|
|
||||||
// DoEvalBNBase() - implements CNTK "pbn" command
|
|
||||||
// ===========================================================================
|
|
||||||
|
|
||||||
template <typename ElemType>
|
|
||||||
static void DoEvalBNBase(const ConfigParameters& config, IDataReader& reader)
|
|
||||||
{
|
|
||||||
// DEVICEID_TYPE deviceId = DeviceFromConfig(config);
|
|
||||||
ConfigArray minibatchSize = config(L"minibatchSize", "40960");
|
|
||||||
size_t epochSize = config(L"epochSize", "0");
|
|
||||||
if (epochSize == 0)
|
|
||||||
{
|
|
||||||
epochSize = requestDataSize;
|
|
||||||
}
|
|
||||||
wstring modelPath = config(L"modelPath");
|
|
||||||
wstring exportPath = modelPath + L".PBN";
|
|
||||||
intargvector mbSize = minibatchSize;
|
|
||||||
|
|
||||||
int iters = config(L"iters", 240);
|
|
||||||
|
|
||||||
int traceLevel = config(L"traceLevel", "0");
|
|
||||||
size_t numMBsToShowResult = config(L"numMBsToShowResult", "100");
|
|
||||||
size_t firstMBsToShowResult = config(L"firstMBsToShowResult", "0");
|
|
||||||
size_t maxSamplesInRAM = config(L"maxSamplesInRAM", (size_t)SIZE_MAX);
|
|
||||||
size_t numSubminiBatches = config(L"numSubminibatches", (size_t)1);
|
|
||||||
|
|
||||||
bool enableDistributedMBReading = config(L"distributedMBReading", false);
|
|
||||||
|
|
||||||
vector<wstring> evalNodeNamesVector;
|
|
||||||
|
|
||||||
let net = GetModelFromConfig<ConfigParameters, ElemType>(config, L"evalNodeNames", evalNodeNamesVector);
|
|
||||||
|
|
||||||
// set tracing flags
|
|
||||||
net->EnableNodeTracing(config(L"traceNodeNamesReal", ConfigParameters::Array(stringargvector())),
|
|
||||||
config(L"traceNodeNamesCategory", ConfigParameters::Array(stringargvector())),
|
|
||||||
config(L"traceNodeNamesSparse", ConfigParameters::Array(stringargvector())));
|
|
||||||
|
|
||||||
SimpleEvaluator<ElemType> eval(net, MPIWrapper::GetInstance(), enableDistributedMBReading, numMBsToShowResult,
|
|
||||||
firstMBsToShowResult, traceLevel, maxSamplesInRAM, numSubminiBatches);
|
|
||||||
eval.EvaluateBN(&reader, evalNodeNamesVector, exportPath, mbSize[0], iters, epochSize);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename ElemType>
|
|
||||||
void DoEvalBN(const ConfigParameters& config)
|
|
||||||
{
|
|
||||||
// evaluate batch normalization mean and variance
|
|
||||||
ConfigParameters readerConfig(config(L"reader"));
|
|
||||||
|
|
||||||
// Should trace level to zero in Post BN?
|
|
||||||
//readerConfig.Insert("traceLevel", config(L"traceLevel", "0"));
|
|
||||||
|
|
||||||
DataReader evaBNDataReader(readerConfig);
|
|
||||||
|
|
||||||
DoEvalBNBase<ElemType>(config, evaBNDataReader);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename ElemType>
|
template <typename ElemType>
|
||||||
void DoEval(const ConfigParameters& config)
|
void DoEval(const ConfigParameters& config)
|
||||||
{
|
{
|
||||||
|
@ -146,8 +90,6 @@ void DoEval(const ConfigParameters& config)
|
||||||
DoEvalBase<ElemType>(config, testDataReader);
|
DoEvalBase<ElemType>(config, testDataReader);
|
||||||
}
|
}
|
||||||
|
|
||||||
template void DoEvalBN<double>(const ConfigParameters& config);
|
|
||||||
template void DoEvalBN<float>(const ConfigParameters& config);
|
|
||||||
template void DoEval<double>(const ConfigParameters& config);
|
template void DoEval<double>(const ConfigParameters& config);
|
||||||
template void DoEval<float>(const ConfigParameters& config);
|
template void DoEval<float>(const ConfigParameters& config);
|
||||||
|
|
||||||
|
|
|
@ -26,6 +26,7 @@
|
||||||
#include "ScriptableObjects.h"
|
#include "ScriptableObjects.h"
|
||||||
#include "BrainScriptEvaluator.h"
|
#include "BrainScriptEvaluator.h"
|
||||||
#include "BrainScriptParser.h"
|
#include "BrainScriptParser.h"
|
||||||
|
#include "PostComputingActions.h"
|
||||||
|
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <chrono>
|
#include <chrono>
|
||||||
|
@ -235,3 +236,46 @@ void DoEdit(const ConfigParameters& config)
|
||||||
|
|
||||||
template void DoEdit<double>(const ConfigParameters& config);
|
template void DoEdit<double>(const ConfigParameters& config);
|
||||||
template void DoEdit<float>(const ConfigParameters& config);
|
template void DoEdit<float>(const ConfigParameters& config);
|
||||||
|
|
||||||
|
// ===========================================================================
|
||||||
|
// DoBatchNormalizationStat() - implements CNTK "bnstat" command
|
||||||
|
// ===========================================================================
|
||||||
|
|
||||||
|
template <typename ElemType>
|
||||||
|
void DoBatchNormalizationStat(const ConfigParameters& config)
|
||||||
|
{
|
||||||
|
ConfigParameters readerConfig(config(L"reader"));
|
||||||
|
readerConfig.Insert("traceLevel", config(L"traceLevel", "0"));
|
||||||
|
|
||||||
|
auto dataReader = make_shared<DataReader>(readerConfig);
|
||||||
|
|
||||||
|
int traceLevel = config(L"traceLevel", "0");
|
||||||
|
int itersPerNode = config(L"itersPerNode", 30);
|
||||||
|
|
||||||
|
ConfigArray minibatchSize = config(L"minibatchSize", "40960");
|
||||||
|
intargvector mbSize = minibatchSize;
|
||||||
|
|
||||||
|
bool enableDistributedMBReading = config(L"enableDistributedMBReading", false);
|
||||||
|
|
||||||
|
wstring curModelPath = config(L"modelPath", L"");
|
||||||
|
wstring newModelPath = config(L"newModelPath", L"");
|
||||||
|
if (newModelPath == L"")
|
||||||
|
{
|
||||||
|
newModelPath = curModelPath + L".PBN";
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<std::wstring> evalNodeNames;
|
||||||
|
let net = GetModelFromConfig<ConfigParameters, ElemType>(config, L"evalNodeNames", evalNodeNames);
|
||||||
|
// set tracing flags
|
||||||
|
net->EnableNodeTracing(config(L"traceNodeNamesReal", ConfigParameters::Array(stringargvector())),
|
||||||
|
config(L"traceNodeNamesCategory", ConfigParameters::Array(stringargvector())),
|
||||||
|
config(L"traceNodeNamesSparse", ConfigParameters::Array(stringargvector())));
|
||||||
|
|
||||||
|
PostComputingActions<ElemType> postComputingActions(net, MPIWrapper::GetInstance(), enableDistributedMBReading, traceLevel);
|
||||||
|
|
||||||
|
postComputingActions.BatchNormalizationStatistics(dataReader.get(), evalNodeNames, newModelPath, mbSize[0], itersPerNode);
|
||||||
|
}
|
||||||
|
|
||||||
|
template void DoBatchNormalizationStat<double>(const ConfigParameters& config);
|
||||||
|
template void DoBatchNormalizationStat<float>(const ConfigParameters& config);
|
||||||
|
|
||||||
|
|
|
@ -154,7 +154,7 @@ static void DisableLegacyUsage(const ConfigParameters& TopLevelConfig, const Con
|
||||||
|
|
||||||
// When running in parallel with MPI, only commands in 'commandstoRunOnAllRanks' should
|
// When running in parallel with MPI, only commands in 'commandstoRunOnAllRanks' should
|
||||||
// be run in parallel across multiple ranks. Others should only run on rank 0
|
// be run in parallel across multiple ranks. Others should only run on rank 0
|
||||||
const std::set<std::string> commandstoRunOnAllRanks = { "train", "trainRNN", "adapt", "test", "eval", "cv", "devtest", "pbn" };
|
const std::set<std::string> commandstoRunOnAllRanks = { "train", "trainRNN", "adapt", "test", "eval", "cv", "devtest", "bnstat" };
|
||||||
|
|
||||||
// process the command
|
// process the command
|
||||||
template <typename ElemType>
|
template <typename ElemType>
|
||||||
|
@ -243,9 +243,9 @@ void DoCommands(const ConfigParameters& config, const shared_ptr<MPIWrapper>& mp
|
||||||
LOGPRINTF(stderr, "CNTKCommandTrainEnd: %s\n", command[i].c_str());
|
LOGPRINTF(stderr, "CNTKCommandTrainEnd: %s\n", command[i].c_str());
|
||||||
fullEpochsOffset += GetMaxEpochs(commandParams);
|
fullEpochsOffset += GetMaxEpochs(commandParams);
|
||||||
}
|
}
|
||||||
else if (thisAction == "pbn")
|
else if (thisAction == "bnstat")
|
||||||
{
|
{
|
||||||
DoEvalBN<ElemType>(commandParams);
|
DoBatchNormalizationStat<ElemType>(commandParams);
|
||||||
}
|
}
|
||||||
else if (thisAction == "adapt")
|
else if (thisAction == "adapt")
|
||||||
{
|
{
|
||||||
|
|
|
@ -136,10 +136,6 @@ public:
|
||||||
// main entry point for backprop
|
// main entry point for backprop
|
||||||
void Backprop(const ComputationNodeBasePtr rootNode);
|
void Backprop(const ComputationNodeBasePtr rootNode);
|
||||||
|
|
||||||
// partial forward entry
|
|
||||||
void ForwardProp(const ComputationNodeBasePtr rootNode, const ComputationNodeBasePtr startNode,
|
|
||||||
const ComputationNodeBasePtr endNode);
|
|
||||||
|
|
||||||
template <class NODESET> // version that takes multiple nodes
|
template <class NODESET> // version that takes multiple nodes
|
||||||
void ForwardProp(const NODESET& nodes)
|
void ForwardProp(const NODESET& nodes)
|
||||||
{
|
{
|
||||||
|
@ -678,6 +674,44 @@ public:
|
||||||
return nodesWithType;
|
return nodesWithType;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Get the eval nodes with names
|
||||||
|
// if evalNodeNames are not specified, return all the default evalnodes and training criterion nodes.
|
||||||
|
std::vector<ComputationNodeBasePtr> GetEvalNodesWithName(const std::vector<wstring> evalNodeNames)
|
||||||
|
{
|
||||||
|
// determine nodes to evaluate
|
||||||
|
std::vector<ComputationNodeBasePtr> evalNodes;
|
||||||
|
|
||||||
|
set<ComputationNodeBasePtr> criteriaLogged; // (keeps track ot duplicates to avoid we don't double-log critera)
|
||||||
|
if (evalNodeNames.size() == 0)
|
||||||
|
{
|
||||||
|
fprintf(stderr, "evalNodeNames are not specified, using all the default evalnodes and training criterion nodes.\n");
|
||||||
|
if (EvaluationNodes().empty() && FinalCriterionNodes().empty())
|
||||||
|
InvalidArgument("There is no default evaluation node or training criterion specified in the network.");
|
||||||
|
|
||||||
|
for (const auto& node : EvaluationNodes())
|
||||||
|
if (criteriaLogged.insert(node).second)
|
||||||
|
evalNodes.push_back(node);
|
||||||
|
|
||||||
|
for (const auto& node : FinalCriterionNodes())
|
||||||
|
if (criteriaLogged.insert(node).second)
|
||||||
|
evalNodes.push_back(node);
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
for (int i = 0; i < evalNodeNames.size(); i++)
|
||||||
|
{
|
||||||
|
const auto& node = GetNodeFromName(evalNodeNames[i]);
|
||||||
|
if (!criteriaLogged.insert(node).second)
|
||||||
|
continue;
|
||||||
|
if (node->GetSampleLayout().GetNumElements() != 1)
|
||||||
|
InvalidArgument("Criterion nodes to evaluate must have dimension 1x1.");
|
||||||
|
evalNodes.push_back(node);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return evalNodes;
|
||||||
|
}
|
||||||
|
|
||||||
public:
|
public:
|
||||||
// return list of nodes that require precomputation and not precomputed yet
|
// return list of nodes that require precomputation and not precomputed yet
|
||||||
std::list<ComputationNodeBasePtr> GetNodesRequiringPreComputation(const ComputationNodeBasePtr& rootNode = nullptr, bool checkComputed = true);
|
std::list<ComputationNodeBasePtr> GetNodesRequiringPreComputation(const ComputationNodeBasePtr& rootNode = nullptr, bool checkComputed = true);
|
||||||
|
@ -1014,7 +1048,7 @@ protected:
|
||||||
virtual const std::wstring OperationName() const override
|
virtual const std::wstring OperationName() const override
|
||||||
{
|
{
|
||||||
return L"PARTraversalFlowControlNode";
|
return L"PARTraversalFlowControlNode";
|
||||||
}
|
}
|
||||||
virtual void BeginForwardProp() override
|
virtual void BeginForwardProp() override
|
||||||
{
|
{
|
||||||
}
|
}
|
||||||
|
@ -1038,8 +1072,6 @@ protected:
|
||||||
virtual void AllocateGradientMatricesForInputs(MatrixPool& matrixPool);
|
virtual void AllocateGradientMatricesForInputs(MatrixPool& matrixPool);
|
||||||
virtual void RequestMatricesBeforeBackprop(MatrixPool& matrixPool);
|
virtual void RequestMatricesBeforeBackprop(MatrixPool& matrixPool);
|
||||||
virtual void ReleaseMatricesAfterBackprop(MatrixPool& matrixPool);
|
virtual void ReleaseMatricesAfterBackprop(MatrixPool& matrixPool);
|
||||||
|
|
||||||
virtual void ForwardProp(const FrameRange&, const ComputationNodeBasePtr, const ComputationNodeBasePtr) override;
|
|
||||||
|
|
||||||
public:
|
public:
|
||||||
// this special constructor constructs the top-level network node
|
// this special constructor constructs the top-level network node
|
||||||
|
|
|
@ -79,17 +79,6 @@ void ComputationNetwork::Backprop(const ComputationNodeBasePtr rootNode) // trai
|
||||||
GetNestedNetwork(rootNode)->Backprop(FrameRange(nullptr), true, true);
|
GetNestedNetwork(rootNode)->Backprop(FrameRange(nullptr), true, true);
|
||||||
}
|
}
|
||||||
|
|
||||||
void ComputationNetwork::ForwardProp(const ComputationNodeBasePtr rootNode, const ComputationNodeBasePtr startNode, const ComputationNodeBasePtr endNode)
|
|
||||||
{
|
|
||||||
VerifyIsCompiled("ForwardProp");
|
|
||||||
|
|
||||||
// traverse partial nodes as inputs
|
|
||||||
shared_ptr<FlowControlNode> network = dynamic_pointer_cast<FlowControlNode>(GetNestedNetwork(rootNode));
|
|
||||||
assert(network);
|
|
||||||
|
|
||||||
network->ForwardProp(FrameRange(nullptr), startNode, endNode);
|
|
||||||
}
|
|
||||||
|
|
||||||
void ComputationNetwork::FormNestedNetwork(const ComputationNodeBasePtr& rootNode)
|
void ComputationNetwork::FormNestedNetwork(const ComputationNodeBasePtr& rootNode)
|
||||||
{
|
{
|
||||||
if (m_nestedNetworks.find(rootNode) != m_nestedNetworks.end())
|
if (m_nestedNetworks.find(rootNode) != m_nestedNetworks.end())
|
||||||
|
@ -158,7 +147,6 @@ ComputationNetwork::PARTraversalFlowControlNode::PARTraversalFlowControlNode(con
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/*virtual*/ void ComputationNetwork::PARTraversalFlowControlNode::Backprop(const FrameRange& fr, bool childrenInThisLoop, bool childrenInOuterLoop) /*override*/
|
/*virtual*/ void ComputationNetwork::PARTraversalFlowControlNode::Backprop(const FrameRange& fr, bool childrenInThisLoop, bool childrenInOuterLoop) /*override*/
|
||||||
{
|
{
|
||||||
childrenInThisLoop, childrenInOuterLoop; // TODO: think through what these mean when coming from PAR mode
|
childrenInThisLoop, childrenInOuterLoop; // TODO: think through what these mean when coming from PAR mode
|
||||||
|
@ -187,36 +175,7 @@ ComputationNetwork::PARTraversalFlowControlNode::PARTraversalFlowControlNode(con
|
||||||
/*virtual*/ void ComputationNetwork::PARTraversalFlowControlNode::ReleaseMatricesAfterBackprop(MatrixPool& matrixPool) /*override*/
|
/*virtual*/ void ComputationNetwork::PARTraversalFlowControlNode::ReleaseMatricesAfterBackprop(MatrixPool& matrixPool) /*override*/
|
||||||
{
|
{
|
||||||
}
|
}
|
||||||
/*virtual*/ void ComputationNetwork::PARTraversalFlowControlNode::ForwardProp(const FrameRange & fr, ComputationNodeBasePtr startNode, ComputationNodeBasePtr endNode)
|
|
||||||
{
|
|
||||||
// if startNode is nullptr, forward propagation is enabled from the very first node
|
|
||||||
bool enableForward = startNode ? false : true;
|
|
||||||
|
|
||||||
for (auto& node : m_nestedNodes)
|
|
||||||
{
|
|
||||||
#if 0
|
|
||||||
if (dynamic_pointer_cast<LearnableParameter<float>>(node))
|
|
||||||
dynamic_pointer_cast<ComputationNode<float>>(node)->DebugLogMinibatch();
|
|
||||||
#endif
|
|
||||||
if (node->IsOutOfDateWrtInputs() && enableForward)
|
|
||||||
{
|
|
||||||
node->BeginForwardProp();
|
|
||||||
node->ForwardProp(fr.WithLayout(node->GetMBLayout()));
|
|
||||||
node->EndForwardProp();
|
|
||||||
|
|
||||||
node->BumpEvalTimeStamp();
|
|
||||||
}
|
|
||||||
|
|
||||||
if (node == startNode)
|
|
||||||
{
|
|
||||||
enableForward = true;
|
|
||||||
}
|
|
||||||
else if (node == endNode)
|
|
||||||
{
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// -----------------------------------------------------------------------
|
// -----------------------------------------------------------------------
|
||||||
// SEQTraversalFlowControlNode methods -- implements SEQ traversal (loop unrolling)
|
// SEQTraversalFlowControlNode methods -- implements SEQ traversal (loop unrolling)
|
||||||
|
|
|
@ -1878,10 +1878,6 @@ public:
|
||||||
virtual void DumpNodeInfo(const bool /*printValues*/, const bool /*printMetadata*/, File& fstream) const override {}
|
virtual void DumpNodeInfo(const bool /*printValues*/, const bool /*printMetadata*/, File& fstream) const override {}
|
||||||
virtual std::set<std::pair<const MatrixBase*, std::wstring>> GetMatrixInfo() const override { NOT_IMPLEMENTED; }
|
virtual std::set<std::pair<const MatrixBase*, std::wstring>> GetMatrixInfo() const override { NOT_IMPLEMENTED; }
|
||||||
|
|
||||||
virtual void ForwardProp(const FrameRange&, const ComputationNodeBasePtr, const ComputationNodeBasePtr) { NOT_IMPLEMENTED; }
|
|
||||||
|
|
||||||
std::vector<ComputationNodeBasePtr> GetNestedNodes() { return m_nestedNodes; }
|
|
||||||
|
|
||||||
protected: public: // needed in ComputationNetwork::FindInRecurrentLoops(), which really should be part of SEQTraversalFlowControlNode
|
protected: public: // needed in ComputationNetwork::FindInRecurrentLoops(), which really should be part of SEQTraversalFlowControlNode
|
||||||
std::vector<ComputationNodeBasePtr> m_nestedNodes; // nodes tucked away in this node, in evaluation order
|
std::vector<ComputationNodeBasePtr> m_nestedNodes; // nodes tucked away in this node, in evaluation order
|
||||||
};
|
};
|
||||||
|
|
|
@ -37,6 +37,7 @@ public:
|
||||||
MarkValueNonSharable();
|
MarkValueNonSharable();
|
||||||
m_initString = L"fromValue"; // default init is with 0; typically overwritten
|
m_initString = L"fromValue"; // default init is with 0; typically overwritten
|
||||||
m_initValue = 0;
|
m_initValue = 0;
|
||||||
|
m_regMultiplier = 1.0f; // enable reg in update by default
|
||||||
}
|
}
|
||||||
LearnableParameter(DEVICEID_TYPE deviceId, const wstring& name, const TensorShape& shape) :
|
LearnableParameter(DEVICEID_TYPE deviceId, const wstring& name, const TensorShape& shape) :
|
||||||
LearnableParameter(deviceId, name)
|
LearnableParameter(deviceId, name)
|
||||||
|
@ -101,6 +102,14 @@ public:
|
||||||
// called from CloneFunction(..., parameters="constant")
|
// called from CloneFunction(..., parameters="constant")
|
||||||
virtual void FreezeParameters() override; // from IFreezable
|
virtual void FreezeParameters() override; // from IFreezable
|
||||||
|
|
||||||
|
// Setting the reg multiplier for a learnable node, effecting L1Reg and L2Reg both.
|
||||||
|
void SetRegMultiplier(float regMultiplier)
|
||||||
|
{
|
||||||
|
m_regMultiplier = regMultiplier;
|
||||||
|
}
|
||||||
|
// called from SGD UpdateWeights, to adjust the reg for each node
|
||||||
|
float GetRegMultiplier() const { return m_regMultiplier; }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
// init parameters for deferred initialization (which happens in Validate())
|
// init parameters for deferred initialization (which happens in Validate())
|
||||||
std::wstring m_initString; // if non-empty then deferred initialization is needed. Gets cleared upon completion of deferred init.
|
std::wstring m_initString; // if non-empty then deferred initialization is needed. Gets cleared upon completion of deferred init.
|
||||||
|
@ -109,6 +118,9 @@ private:
|
||||||
int m_initOutputRank;
|
int m_initOutputRank;
|
||||||
bool m_initOnCPUOnly;
|
bool m_initOnCPUOnly;
|
||||||
ElemType m_initValue;
|
ElemType m_initValue;
|
||||||
|
|
||||||
|
// flags related to gradient update
|
||||||
|
float m_regMultiplier; // The multiplier to adjust the L1Reg and L2Reg for Learnable node
|
||||||
};
|
};
|
||||||
|
|
||||||
// -----------------------------------------------------------------------
|
// -----------------------------------------------------------------------
|
||||||
|
|
|
@ -8,6 +8,7 @@
|
||||||
#include "ComputationNode.h"
|
#include "ComputationNode.h"
|
||||||
#include "BatchNormalizationEngine.h"
|
#include "BatchNormalizationEngine.h"
|
||||||
#include "RNGHandle.h"
|
#include "RNGHandle.h"
|
||||||
|
#include "InputAndParamNodes.h"
|
||||||
|
|
||||||
#define __STDC_FORMAT_MACROS
|
#define __STDC_FORMAT_MACROS
|
||||||
#include <inttypes.h>
|
#include <inttypes.h>
|
||||||
|
@ -1587,15 +1588,15 @@ class BatchNormalizationNode : public ComputationNodeNonLooping<ElemType>, publi
|
||||||
public:
|
public:
|
||||||
BatchNormalizationNode(DEVICEID_TYPE deviceId, const wstring& name) :
|
BatchNormalizationNode(DEVICEID_TYPE deviceId, const wstring& name) :
|
||||||
Base(deviceId, name), m_spatial(false), m_normTimeConst(0), m_blendTimeConst(0), m_epsilon(0), m_useCntkEngine(true),
|
Base(deviceId, name), m_spatial(false), m_normTimeConst(0), m_blendTimeConst(0), m_epsilon(0), m_useCntkEngine(true),
|
||||||
m_samplesSeen(0), m_imageLayoutKind(ImageLayoutKind::CHW), m_postBatchNormalization(false), m_swapNormTimeConst(0),
|
m_samplesSeen(0), m_imageLayoutKind(ImageLayoutKind::CHW),
|
||||||
m_swapBlendTimeConst(0), m_convertRunningVariancePending(false)
|
m_convertRunningVariancePending(false)
|
||||||
{
|
{
|
||||||
}
|
}
|
||||||
BatchNormalizationNode(DEVICEID_TYPE deviceId, const wstring& name, bool spatial, double normalizationTimeConstant, double blendTimeConstant,
|
BatchNormalizationNode(DEVICEID_TYPE deviceId, const wstring& name, bool spatial, double normalizationTimeConstant, double blendTimeConstant,
|
||||||
double epsilon, bool useCntkEngine, ImageLayoutKind imageLayoutKind) :
|
double epsilon, bool useCntkEngine, ImageLayoutKind imageLayoutKind) :
|
||||||
Base(deviceId, name), m_spatial(spatial), m_normTimeConst(normalizationTimeConstant), m_blendTimeConst(blendTimeConstant),
|
Base(deviceId, name), m_spatial(spatial), m_normTimeConst(normalizationTimeConstant), m_blendTimeConst(blendTimeConstant),
|
||||||
m_epsilon(epsilon), m_useCntkEngine(useCntkEngine), m_imageLayoutKind(imageLayoutKind), m_samplesSeen(0), m_postBatchNormalization(false),
|
m_epsilon(epsilon), m_useCntkEngine(useCntkEngine), m_imageLayoutKind(imageLayoutKind), m_samplesSeen(0),
|
||||||
m_swapNormTimeConst(0), m_swapBlendTimeConst(0), m_convertRunningVariancePending(false)
|
m_convertRunningVariancePending(false)
|
||||||
{
|
{
|
||||||
}
|
}
|
||||||
BatchNormalizationNode(const ScriptableObjects::IConfigRecordPtr configp) :
|
BatchNormalizationNode(const ScriptableObjects::IConfigRecordPtr configp) :
|
||||||
|
@ -1605,9 +1606,6 @@ public:
|
||||||
ImageLayoutKindFrom(configp->Get(L"imageLayout")))
|
ImageLayoutKindFrom(configp->Get(L"imageLayout")))
|
||||||
{
|
{
|
||||||
AttachInputsFromConfig(configp, this->GetExpectedNumInputs());
|
AttachInputsFromConfig(configp, this->GetExpectedNumInputs());
|
||||||
m_postBatchNormalization = false;
|
|
||||||
m_swapNormTimeConst = 0;
|
|
||||||
m_swapBlendTimeConst = 0;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void Save(File& fstream) const override
|
void Save(File& fstream) const override
|
||||||
|
@ -1724,7 +1722,7 @@ private: // time-constant conversions
|
||||||
double ComputeExpAvgFactor() const
|
double ComputeExpAvgFactor() const
|
||||||
{
|
{
|
||||||
// in inference mode, only use long-term mean and do not update running estimates
|
// in inference mode, only use long-term mean and do not update running estimates
|
||||||
if (!Environment().IsTraining() && !m_postBatchNormalization)
|
if (!Environment().IsTraining())
|
||||||
{
|
{
|
||||||
if (m_samplesSeen == 0)
|
if (m_samplesSeen == 0)
|
||||||
RuntimeError("%ls: inference mode is used, but nothing has been trained.", NodeName().c_str());
|
RuntimeError("%ls: inference mode is used, but nothing has been trained.", NodeName().c_str());
|
||||||
|
@ -1756,7 +1754,7 @@ private: // time-constant conversions
|
||||||
double ComputeBlendFactor() const
|
double ComputeBlendFactor() const
|
||||||
{
|
{
|
||||||
// in inference mode, only use long-term mean and do not update running estimates
|
// in inference mode, only use long-term mean and do not update running estimates
|
||||||
if (!Environment().IsTraining() && !m_postBatchNormalization)
|
if (!Environment().IsTraining())
|
||||||
{
|
{
|
||||||
if (m_samplesSeen == 0)
|
if (m_samplesSeen == 0)
|
||||||
RuntimeError("%ls: inference mode is used, but nothing has been trained.", NodeName().c_str());
|
RuntimeError("%ls: inference mode is used, but nothing has been trained.", NodeName().c_str());
|
||||||
|
@ -1805,7 +1803,7 @@ public:
|
||||||
// In inference-only mode, m_savedMean and m_saveInvStdDev will not be
|
// In inference-only mode, m_savedMean and m_saveInvStdDev will not be
|
||||||
// produced and BackpropToNonLooping() may not be called. In
|
// produced and BackpropToNonLooping() may not be called. In
|
||||||
// non-inference (training) mode, saved statistics must be produced.
|
// non-inference (training) mode, saved statistics must be produced.
|
||||||
bool inferenceOnly = !Environment().IsTraining() && !m_postBatchNormalization;
|
bool inferenceOnly = !Environment().IsTraining();
|
||||||
m_bnEng->Forward(/*in=*/ sliceInputValue, scale, bias, // (in)
|
m_bnEng->Forward(/*in=*/ sliceInputValue, scale, bias, // (in)
|
||||||
inferenceOnly, expAvgFactor, blendFactor,
|
inferenceOnly, expAvgFactor, blendFactor,
|
||||||
runMean, runVariance, // (in/out) running estimates, updated from the current MB mean/variance
|
runMean, runVariance, // (in/out) running estimates, updated from the current MB mean/variance
|
||||||
|
@ -1870,14 +1868,6 @@ public:
|
||||||
}
|
}
|
||||||
|
|
||||||
virtual void EndForwardProp() override
|
virtual void EndForwardProp() override
|
||||||
{
|
|
||||||
if(m_postBatchNormalization)
|
|
||||||
m_samplesSeen += GetMBLayout()->GetActualNumSamples();
|
|
||||||
|
|
||||||
Base::EndForwardProp();
|
|
||||||
}
|
|
||||||
|
|
||||||
virtual void EndBackprop() override
|
|
||||||
{
|
{
|
||||||
// Update samples if not locked.
|
// Update samples if not locked.
|
||||||
double expAvgFactor = ComputeExpAvgFactor(); // weight for the new MB statistics in the running estimate. The previous value of the running statistics is kept with weight (1-this)
|
double expAvgFactor = ComputeExpAvgFactor(); // weight for the new MB statistics in the running estimate. The previous value of the running statistics is kept with weight (1-this)
|
||||||
|
@ -2019,28 +2009,29 @@ public:
|
||||||
m_blendTimeConst = std::numeric_limits<double>::infinity();
|
m_blendTimeConst = std::numeric_limits<double>::infinity();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ResetStatisticsState will set the batch normal statistics into initial state
|
||||||
|
// used for re-statistics the mean and variance of BN
|
||||||
|
// any others use may lead undependable results, please be careful
|
||||||
|
void ResetStatisticsState()
|
||||||
|
{
|
||||||
|
m_samplesSeen = 0;
|
||||||
|
m_normTimeConst = 0;
|
||||||
|
m_blendTimeConst = 0;
|
||||||
|
}
|
||||||
|
// Turn off the L1 and L2 regularization
|
||||||
|
void DisableRegInBatchNormalization()
|
||||||
|
{
|
||||||
|
let scaleNode = dynamic_pointer_cast<LearnableParameter<ElemType>>(Input(1));
|
||||||
|
let biasNode = dynamic_pointer_cast<LearnableParameter<ElemType>>(Input(2));
|
||||||
|
scaleNode->SetRegMultiplier(0.f);
|
||||||
|
biasNode->SetRegMultiplier(0.f);
|
||||||
|
}
|
||||||
double NormalizationTimeConstant() const { return m_normTimeConst; }
|
double NormalizationTimeConstant() const { return m_normTimeConst; }
|
||||||
double BlendTimeConstant() const { return m_blendTimeConst; }
|
double BlendTimeConstant() const { return m_blendTimeConst; }
|
||||||
bool Spatial() const { return m_spatial; }
|
bool Spatial() const { return m_spatial; }
|
||||||
double Epsilon() const { return m_epsilon; }
|
double Epsilon() const { return m_epsilon; }
|
||||||
bool UseCNTKEngine() const { return m_useCntkEngine; }
|
bool UseCNTKEngine() const { return m_useCntkEngine; }
|
||||||
|
|
||||||
void SetPostBatchNormalizationBegin()
|
|
||||||
{
|
|
||||||
m_postBatchNormalization = true;
|
|
||||||
m_samplesSeen = 0;
|
|
||||||
m_swapNormTimeConst = m_normTimeConst;
|
|
||||||
m_swapBlendTimeConst = m_blendTimeConst;
|
|
||||||
m_normTimeConst = -1;
|
|
||||||
m_blendTimeConst = 0;
|
|
||||||
}
|
|
||||||
void SetPostBatchNormalizationEnd()
|
|
||||||
{
|
|
||||||
m_postBatchNormalization = false;
|
|
||||||
m_normTimeConst = m_swapNormTimeConst;
|
|
||||||
m_blendTimeConst = m_swapBlendTimeConst;
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
// Old versioning - do not use. Do not remove until we're sure there are no old models around.
|
// Old versioning - do not use. Do not remove until we're sure there are no old models around.
|
||||||
struct VersionInfo
|
struct VersionInfo
|
||||||
|
@ -2104,11 +2095,6 @@ private:
|
||||||
|
|
||||||
std::unique_ptr<BatchNormEngine<ElemType>> m_bnEng;
|
std::unique_ptr<BatchNormEngine<ElemType>> m_bnEng;
|
||||||
|
|
||||||
// post batch normalization process mark
|
|
||||||
bool m_postBatchNormalization;
|
|
||||||
|
|
||||||
double m_swapNormTimeConst;
|
|
||||||
double m_swapBlendTimeConst;
|
|
||||||
bool m_convertRunningVariancePending;
|
bool m_convertRunningVariancePending;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,161 @@
|
||||||
|
//
|
||||||
|
// Copyright (c) Microsoft. All rights reserved.
|
||||||
|
// Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
|
||||||
|
//
|
||||||
|
// PostComputingActions.cpp -- CNTK post-computing statistics related actions
|
||||||
|
//
|
||||||
|
|
||||||
|
#include "PostComputingActions.h"
|
||||||
|
|
||||||
|
#include "TrainingNodes.h"
|
||||||
|
#include "ProgressTracing.h"
|
||||||
|
#include "DataReaderHelpers.h"
|
||||||
|
#include "SimpleDistGradAggregator.h"
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
namespace Microsoft { namespace MSR{ namespace CNTK {
|
||||||
|
|
||||||
|
template <class ElemType>
|
||||||
|
void PostComputingActions<ElemType>::BatchNormalizationStatistics(IDataReader * dataReader, const vector<wstring>& evalNodeNames,
|
||||||
|
const wstring newModelPath, const size_t mbSize, const int iters)
|
||||||
|
{
|
||||||
|
// since the mean and variance of bn will be modified in statistics,
|
||||||
|
// training mode will make it work. And there is no back prop, other parameters
|
||||||
|
// are fixed during computing.
|
||||||
|
ScopedNetworkOperationMode modeGuard(m_net, NetworkOperationMode::training);
|
||||||
|
|
||||||
|
// bn nodes need to be computed from bottom to top with evaluating order
|
||||||
|
let evalNodes = m_net->GetEvalNodesWithName(evalNodeNames);
|
||||||
|
|
||||||
|
// find all the BN nodes by evalOrder
|
||||||
|
std::vector<ComputationNodeBasePtr> bnNodes;
|
||||||
|
std::set<ComputationNodeBasePtr> bnNodesLogged; // (avoid double record of batch normalization nodes)
|
||||||
|
for (auto& evalNode : evalNodes)
|
||||||
|
{
|
||||||
|
for (auto& node : m_net->GetEvalOrder(evalNode))
|
||||||
|
{
|
||||||
|
let bnNode = dynamic_pointer_cast<BatchNormalizationNode<ElemType>>(node);
|
||||||
|
if (bnNode)
|
||||||
|
{
|
||||||
|
if (bnNodesLogged.insert(node).second)
|
||||||
|
{
|
||||||
|
// reset the statistics states of bn nodes
|
||||||
|
bnNode->ResetStatisticsState();
|
||||||
|
bnNode->SetNormalizationTimeConstants(-1, bnNode->NormalizationTimeConstant(),
|
||||||
|
0, bnNode->BlendTimeConstant());
|
||||||
|
bnNodes.push_back(node);
|
||||||
|
// add BN nodes into the evaluation group, then they will be added into root nodes when
|
||||||
|
// the network is re-compiled
|
||||||
|
m_net->AddToNodeGroup(L"evaluation", bnNode);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// re-compile the network to add bn nodes as rootNodes.
|
||||||
|
m_net->CompileNetwork();
|
||||||
|
|
||||||
|
// allocate memory for all bnNodes evalOrder
|
||||||
|
m_net->AllocateAllMatrices(bnNodes, std::vector<ComputationNodeBasePtr>(), nullptr);
|
||||||
|
|
||||||
|
// prepare features
|
||||||
|
auto& featureNodes = m_net->FeatureNodes();
|
||||||
|
|
||||||
|
StreamMinibatchInputs inputMatrices;
|
||||||
|
for (auto& node : featureNodes)
|
||||||
|
inputMatrices.AddInput(node->NodeName(), node->ValuePtr(), node->GetMBLayout(), node->GetSampleLayout());
|
||||||
|
|
||||||
|
bool useParallelTrain = (m_mpi != nullptr);
|
||||||
|
bool useDistributedMBReading = useParallelTrain && m_enableDistributedMBReading && dataReader->SupportsDistributedMBRead();
|
||||||
|
size_t totalEpochSize = bnNodes.size() * mbSize * iters;
|
||||||
|
|
||||||
|
m_net->StartEvaluateMinibatchLoop(bnNodes);
|
||||||
|
|
||||||
|
if (useDistributedMBReading)
|
||||||
|
dataReader->StartDistributedMinibatchLoop(mbSize, 0, m_mpi->CurrentNodeRank(), m_mpi->NumNodesInUse(), totalEpochSize);
|
||||||
|
else
|
||||||
|
dataReader->StartMinibatchLoop(mbSize, 0, totalEpochSize);
|
||||||
|
|
||||||
|
for (auto& node : bnNodes)
|
||||||
|
{
|
||||||
|
let bnNode = static_pointer_cast<BatchNormalizationNode<ElemType>>(node);
|
||||||
|
size_t actualMBSize = 0;
|
||||||
|
|
||||||
|
LOGPRINTF(stderr, "Estimating Statistics --> %ls\n", bnNode->GetName().c_str());
|
||||||
|
|
||||||
|
|
||||||
|
// for every single bn node, the statistics are the mean and variance averaged over several forward prop passes
|
||||||
|
// the forward prop is from the feature to the current bn node
|
||||||
|
for (int iter = 0; iter < iters; iter++)
|
||||||
|
{
|
||||||
|
// during the bn stat, dataRead must be ensured
|
||||||
|
bool wasDataRead = DataReaderHelpers::GetMinibatchIntoNetwork<ElemType>(*dataReader, m_net,
|
||||||
|
nullptr, useDistributedMBReading, useParallelTrain, inputMatrices, actualMBSize, m_mpi);
|
||||||
|
|
||||||
|
if (!wasDataRead) LogicError("DataRead Failure in batch normalization statistics");
|
||||||
|
|
||||||
|
ComputationNetwork::BumpEvalTimeStamp(featureNodes);
|
||||||
|
|
||||||
|
// forward prop till reaching the current bn node
|
||||||
|
m_net->ForwardProp(node);
|
||||||
|
}
|
||||||
|
|
||||||
|
// after the statistics are finished, the mean and variance of the bn node should be frozen.
|
||||||
|
bnNode->FreezeParameters();
|
||||||
|
|
||||||
|
// Sync during or after all iters of a BN node are equivalent
|
||||||
|
if (useParallelTrain)
|
||||||
|
{
|
||||||
|
if (m_gradHeader == nullptr)
|
||||||
|
{
|
||||||
|
m_gradHeader.reset(DistGradHeader::Create(evalNodes.size()), [](DistGradHeader* ptr)
|
||||||
|
{
|
||||||
|
DistGradHeader::Destroy(ptr);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
// push the statistics results of mean and variance of bn nodes into mpi updating vector
|
||||||
|
std::vector<Matrix<ElemType>*> learnParamsValues(2, nullptr);
|
||||||
|
|
||||||
|
SimpleDistGradAggregator<ElemType> distGradAgg(m_mpi, false /*useAsyncAggregation*/, 0 /*syncStatsTrace*/);
|
||||||
|
|
||||||
|
auto runMeanParameterPtr = node->Input(3);
|
||||||
|
auto runStdParameterPtr = node->Input(4);
|
||||||
|
|
||||||
|
shared_ptr<ComputationNode<ElemType>> runMeanNode = static_pointer_cast<ComputationNode<ElemType>>(runMeanParameterPtr);
|
||||||
|
shared_ptr<ComputationNode<ElemType>> runStdNode = static_pointer_cast<ComputationNode<ElemType>>(runStdParameterPtr);
|
||||||
|
|
||||||
|
learnParamsValues[0] = &(runMeanNode->Value());
|
||||||
|
learnParamsValues[1] = &(runStdNode->Value());
|
||||||
|
|
||||||
|
m_gradHeader->numSamples = actualMBSize ? 1 : actualMBSize;
|
||||||
|
distGradAgg.AggregateGradients(learnParamsValues, m_gradHeader.get(), 0);
|
||||||
|
|
||||||
|
// get the average mean and variance across all the workers
|
||||||
|
for (auto& parameter : learnParamsValues)
|
||||||
|
{
|
||||||
|
(*parameter) /= (ElemType)m_mpi->NumNodesInUse();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
dataReader->DataEnd();
|
||||||
|
|
||||||
|
// remove all the added BN nodes from evaluation group
|
||||||
|
for (auto& bnNode : bnNodes)
|
||||||
|
{
|
||||||
|
m_net->RemoveFromNodeGroup(L"evaluation", bnNode);
|
||||||
|
}
|
||||||
|
|
||||||
|
// save model
|
||||||
|
if (!useParallelTrain || m_mpi->CurrentNodeRank() == m_mpi->MainNodeRank())
|
||||||
|
m_net->Save(newModelPath);
|
||||||
|
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
template class PostComputingActions<float>;
|
||||||
|
template class PostComputingActions<double>;
|
||||||
|
|
||||||
|
}}}
|
|
@ -0,0 +1,65 @@
|
||||||
|
//
|
||||||
|
// Copyright (c) Microsoft. All rights reserved.
|
||||||
|
// Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
|
||||||
|
//
|
||||||
|
// PostComputingActions.h -- CNTK post statistics related actions
|
||||||
|
//
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
#include "ComputationNode.h"
|
||||||
|
#include "ComputationNetwork.h"
|
||||||
|
#include "MPIWrapper.h"
|
||||||
|
#include "DataReader.h"
|
||||||
|
#include "IDistGradAggregator.h"
|
||||||
|
#include "DistGradHeader.h"
|
||||||
|
|
||||||
|
using namespace std;
|
||||||
|
|
||||||
|
namespace Microsoft { namespace MSR { namespace CNTK {
|
||||||
|
|
||||||
|
template <class ElemType>
|
||||||
|
class IDistGradAggregator;
|
||||||
|
|
||||||
|
// Post statistics normally called between training and evaluating, to generate the statistics results used by evaluating
|
||||||
|
// For now, the only application is computing the mean and variance statistics of Batch Normalization nodes after training
|
||||||
|
template <class ElemType>
|
||||||
|
class PostComputingActions
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
PostComputingActions(ComputationNetworkPtr net, const MPIWrapperPtr& mpi, bool enableDistributedMBReading = false, const int traceLevel = 0) :
|
||||||
|
m_net(net),
|
||||||
|
m_traceLevel(traceLevel),
|
||||||
|
m_mpi(mpi),
|
||||||
|
m_distGradAgg(nullptr),
|
||||||
|
m_gradHeader(nullptr),
|
||||||
|
m_enableDistributedMBReading(enableDistributedMBReading)
|
||||||
|
{
|
||||||
|
}
|
||||||
|
|
||||||
|
// This function is used for evaluating the mean and variance of all batch normalization nodes after training.
|
||||||
|
// Details will link to the wiki https://github.com/Microsoft/CNTK/wiki/Post-Batch-Normalization-Statistics
|
||||||
|
// The reason why it is put into evaluate is that the action takes place after training and involves no backprop processing, which makes me believe
|
||||||
|
// this function is like a kind of evaluate function.
|
||||||
|
// In this function,
|
||||||
|
// 1. since all other weights are fixed and only the mean and variance of the bn nodes are updated (without back prop), the network operation mode is set to training.
|
||||||
|
// 2. The next thing is to load the network model and data source; I follow the Evaluate function to do so, however, I deleted the things that
|
||||||
|
// seem useless, like error statistics etc.
|
||||||
|
// 3. Finding the BN nodes in the network and put them into a vector with evaluate order (This links the nestedNode vector I got in
|
||||||
|
// ControlFlowNetwork)
|
||||||
|
// 4. Go from node to node in the BN vector to generate the mean and variance (This links to the changes of BatchNormalizationNode
|
||||||
|
// in TrainingNodes.h, since I need to make the nodes "learn" mean and variance in inferring mode)
|
||||||
|
// 5. Considering multi-GPU, we need to sync up the BN results between all the workers and average the values.
|
||||||
|
void BatchNormalizationStatistics(IDataReader* dataReader, const vector<wstring>& evalNodeNames, const wstring newModelPath,
|
||||||
|
const size_t mbSize, const int iters = 30);
|
||||||
|
|
||||||
|
private:
|
||||||
|
ComputationNetworkPtr m_net;
|
||||||
|
MPIWrapperPtr m_mpi;
|
||||||
|
bool m_enableDistributedMBReading;
|
||||||
|
|
||||||
|
int m_traceLevel;
|
||||||
|
|
||||||
|
std::shared_ptr<IDistGradAggregator<ElemType>> m_distGradAgg;
|
||||||
|
std::shared_ptr<struct DistGradHeader> m_gradHeader;
|
||||||
|
};
|
||||||
|
}}}
|
|
@ -8,6 +8,7 @@
|
||||||
#include "SpecialPurposeNodes.h" // for SequenceWithSoftmaxNode
|
#include "SpecialPurposeNodes.h" // for SequenceWithSoftmaxNode
|
||||||
#include "DataReaderHelpers.h"
|
#include "DataReaderHelpers.h"
|
||||||
#include "MatrixQuantizerImpl.h"
|
#include "MatrixQuantizerImpl.h"
|
||||||
|
#include "InputAndParamNodes.h"
|
||||||
|
|
||||||
#ifdef CNTK_PARALLEL_TRAINING_SUPPORT
|
#ifdef CNTK_PARALLEL_TRAINING_SUPPORT
|
||||||
//static inline bool operator==(const std::pair<double,size_t>& a, double b) { assert(b==0); return a.first == b; }
|
//static inline bool operator==(const std::pair<double,size_t>& a, double b) { assert(b==0); return a.first == b; }
|
||||||
|
@ -875,23 +876,13 @@ size_t SGD<ElemType>::TrainOneEpoch(ComputationNetworkPtr net,
|
||||||
EpochCriterion epochCriterionLastLogged = epochCriterion;
|
EpochCriterion epochCriterionLastLogged = epochCriterion;
|
||||||
vector<EpochCriterion> epochEvalErrorsLastLogged = epochEvalErrors;
|
vector<EpochCriterion> epochEvalErrorsLastLogged = epochEvalErrors;
|
||||||
|
|
||||||
// Now, we need to use a switch to enable/disable wk in BatchNormalization.
|
// NOTE: For ResNet, the regularization in BatchNormalization should be disabled.
|
||||||
// If we can determine whether wk added or not for each node, then, discard this
|
if (m_disableRegInBatchNormalization) {
|
||||||
std::unordered_set<ComputationNodeBasePtr> batchNormalizationWeights;
|
let bnNodes = net->GetNodesWithType(L"BatchNormalization");
|
||||||
if (m_disableWkInBatchNormal) {
|
for (auto &node : bnNodes)
|
||||||
for (auto& evalNode : evaluationNodes)
|
|
||||||
{
|
{
|
||||||
shared_ptr<FlowControlNode> nestedNetwork = static_pointer_cast<FlowControlNode>(net->GetNestedNetwork(evalNode));
|
let bnNode = dynamic_pointer_cast<BatchNormalizationNode<ElemType>>(node);
|
||||||
for (auto& node : nestedNetwork->GetNestedNodes())
|
bnNode->DisableRegInBatchNormalization();
|
||||||
{
|
|
||||||
shared_ptr<BatchNormalizationNode<ElemType>> castNode =
|
|
||||||
dynamic_pointer_cast<BatchNormalizationNode<ElemType>>(node);
|
|
||||||
if (castNode)
|
|
||||||
{
|
|
||||||
batchNormalizationWeights.insert(castNode->GetInputs()[1]);
|
|
||||||
batchNormalizationWeights.insert(castNode->GetInputs()[2]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1110,11 +1101,10 @@ size_t SGD<ElemType>::TrainOneEpoch(ComputationNetworkPtr net,
|
||||||
if (smoothedGradient.HasNan("TrainOneEpoch/UpdateWeights(): "))
|
if (smoothedGradient.HasNan("TrainOneEpoch/UpdateWeights(): "))
|
||||||
LogicError("%ls %ls operation has NaNs in smoothedGradient.", node->NodeName().c_str(), node->OperationName().c_str());
|
LogicError("%ls %ls operation has NaNs in smoothedGradient.", node->NodeName().c_str(), node->OperationName().c_str());
|
||||||
#endif
|
#endif
|
||||||
double l2Factor = batchNormalizationWeights.find(node) == batchNormalizationWeights.end() ? 1.0 : 0.0;
|
|
||||||
// BUGBUG (Issue #95): Access to net MBLayout can no longer be done if we have multiple input layouts
|
// BUGBUG (Issue #95): Access to net MBLayout can no longer be done if we have multiple input layouts
|
||||||
UpdateWeights(node, smoothedGradient, learnRatePerSample,
|
UpdateWeights(node, smoothedGradient, learnRatePerSample,
|
||||||
GetMomentumPerSample(epochNumber /*BUGBUG workaround:*/, net->GetMBLayoutPtrOfNetwork()->GetNumParallelSequences()), numSamplesInMinibatch,
|
GetMomentumPerSample(epochNumber /*BUGBUG workaround:*/, net->GetMBLayoutPtrOfNetwork()->GetNumParallelSequences()), numSamplesInMinibatch,
|
||||||
m_L2RegWeight * l2Factor, m_L1RegWeight,
|
m_L2RegWeight, m_L1RegWeight,
|
||||||
m_needAveMultiplier, m_useNesterovMomentum);
|
m_needAveMultiplier, m_useNesterovMomentum);
|
||||||
#ifdef _DEBUG
|
#ifdef _DEBUG
|
||||||
if (dynamic_pointer_cast<ComputationNode<ElemType>>(node)->Value().HasNan("TrainOneEpoch/UpdateWeights(): "))
|
if (dynamic_pointer_cast<ComputationNode<ElemType>>(node)->Value().HasNan("TrainOneEpoch/UpdateWeights(): "))
|
||||||
|
@ -2017,9 +2007,10 @@ void SGD<ElemType>::UpdateWeights(const ComputationNodeBasePtr& node,
|
||||||
LogicError("UpdateWeights() called for a learnable ComputationNode which has m_learningRateMultiplier == 0!");
|
LogicError("UpdateWeights() called for a learnable ComputationNode which has m_learningRateMultiplier == 0!");
|
||||||
|
|
||||||
double nodeDependentLearningRatePerSample = learnRatePerSample * node->GetLearningRateMultiplier();
|
double nodeDependentLearningRatePerSample = learnRatePerSample * node->GetLearningRateMultiplier();
|
||||||
|
double nodeDependentRegMultiplier = dynamic_pointer_cast<LearnableParameter<ElemType>>(node)->GetRegMultiplier();
|
||||||
UpdateWeightsS(this, dynamic_pointer_cast<ComputationNode<ElemType>>(node)->Value(), dynamic_pointer_cast<ComputationNode<ElemType>>(node)->Gradient(),
|
UpdateWeightsS(this, dynamic_pointer_cast<ComputationNode<ElemType>>(node)->Value(), dynamic_pointer_cast<ComputationNode<ElemType>>(node)->Gradient(),
|
||||||
smoothedGradient, nodeDependentLearningRatePerSample, momentumPerSample,
|
smoothedGradient, nodeDependentLearningRatePerSample, momentumPerSample,
|
||||||
actualMBSize, L2RegWeight, L1RegWeight,
|
actualMBSize, L2RegWeight * nodeDependentRegMultiplier, L1RegWeight * nodeDependentRegMultiplier,
|
||||||
needAveMultiplier, m_useNesterovMomentum);
|
needAveMultiplier, m_useNesterovMomentum);
|
||||||
node->BumpEvalTimeStamp();
|
node->BumpEvalTimeStamp();
|
||||||
}
|
}
|
||||||
|
@ -2475,7 +2466,7 @@ SGDParams::SGDParams(const ConfigRecordType& configSGD, size_t sizeofElemType)
|
||||||
m_seqGammarCalcbMMIFactor = configSGD(L"seqGammarBMMIFactor", 0.0);
|
m_seqGammarCalcbMMIFactor = configSGD(L"seqGammarBMMIFactor", 0.0);
|
||||||
m_seqGammarCalcWP = configSGD(L"seqGammarWordPen", 0.0);
|
m_seqGammarCalcWP = configSGD(L"seqGammarWordPen", 0.0);
|
||||||
|
|
||||||
m_disableWkInBatchNormal = configSGD(L"disableWkInBatchNormal", false);
|
m_disableRegInBatchNormalization = configSGD(L"disableRegInBatchNormalization", false);
|
||||||
|
|
||||||
m_dropoutRates = configSGD(L"dropoutRate", ConfigRecordType::Array(doubleargvector(vector<double>{0.0})));
|
m_dropoutRates = configSGD(L"dropoutRate", ConfigRecordType::Array(doubleargvector(vector<double>{0.0})));
|
||||||
m_batchNormalizationTimeConstant = configSGD(L"batchNormalizationTimeConstant", ConfigRecordType::Array(doubleargvector(vector<double>{0})));
|
m_batchNormalizationTimeConstant = configSGD(L"batchNormalizationTimeConstant", ConfigRecordType::Array(doubleargvector(vector<double>{0})));
|
||||||
|
|
|
@ -291,7 +291,10 @@ protected:
|
||||||
double m_seqGammarCalcbMMIFactor;
|
double m_seqGammarCalcbMMIFactor;
|
||||||
bool m_seqGammarCalcUsesMBR;
|
bool m_seqGammarCalcUsesMBR;
|
||||||
|
|
||||||
bool m_disableWkInBatchNormal;
|
// decides whether L2 regularization should be applied to BatchNormalizationNode
|
||||||
|
// true: disable L2 Regularization
|
||||||
|
// false: enable L2 Regularization (default)
|
||||||
|
bool m_disableRegInBatchNormalization;
|
||||||
};
|
};
|
||||||
|
|
||||||
template <class ElemType>
|
template <class ElemType>
|
||||||
|
|
|
@ -124,6 +124,7 @@
|
||||||
<ClInclude Include="..\ComputationNetworkLib\NonlinearityNodes.h" />
|
<ClInclude Include="..\ComputationNetworkLib\NonlinearityNodes.h" />
|
||||||
<ClInclude Include="..\ComputationNetworkLib\RecurrentNodes.h" />
|
<ClInclude Include="..\ComputationNetworkLib\RecurrentNodes.h" />
|
||||||
<ClInclude Include="MASGD.h" />
|
<ClInclude Include="MASGD.h" />
|
||||||
|
<ClInclude Include="PostComputingActions.h" />
|
||||||
<ClInclude Include="SimpleDistGradAggregator.h" />
|
<ClInclude Include="SimpleDistGradAggregator.h" />
|
||||||
<ClInclude Include="SimpleEvaluator.h" />
|
<ClInclude Include="SimpleEvaluator.h" />
|
||||||
<ClInclude Include="SimpleOutputWriter.h" />
|
<ClInclude Include="SimpleOutputWriter.h" />
|
||||||
|
@ -132,10 +133,11 @@
|
||||||
<ClInclude Include="targetver.h" />
|
<ClInclude Include="targetver.h" />
|
||||||
</ItemGroup>
|
</ItemGroup>
|
||||||
<ItemGroup>
|
<ItemGroup>
|
||||||
|
<ClCompile Include="PostComputingActions.cpp" />
|
||||||
<ClCompile Include="Profiler.cpp" />
|
<ClCompile Include="Profiler.cpp" />
|
||||||
<ClCompile Include="SGD.cpp" />
|
<ClCompile Include="SGD.cpp" />
|
||||||
<ClCompile Include="stdafx.cpp" />
|
<ClCompile Include="stdafx.cpp" />
|
||||||
</ItemGroup>
|
</ItemGroup>
|
||||||
<Import Project="$(VCTargetsPath)\Microsoft.Cpp.targets" />
|
<Import Project="$(VCTargetsPath)\Microsoft.Cpp.targets" />
|
||||||
<ImportGroup Label="ExtensionTargets" />
|
<ImportGroup Label="ExtensionTargets" />
|
||||||
</Project>
|
</Project>
|
|
@ -1,32 +1,17 @@
|
||||||
<?xml version="1.0" encoding="utf-8"?>
|
<?xml version="1.0" encoding="utf-8"?>
|
||||||
<Project ToolsVersion="4.0" xmlns="http://schemas.microsoft.com/developer/msbuild/2003">
|
<Project ToolsVersion="4.0" xmlns="http://schemas.microsoft.com/developer/msbuild/2003">
|
||||||
<ItemGroup>
|
<ItemGroup>
|
||||||
<ClCompile Include="..\Common\DataReader.cpp">
|
|
||||||
<Filter>Common</Filter>
|
|
||||||
</ClCompile>
|
|
||||||
<ClCompile Include="..\Common\DataWriter.cpp">
|
|
||||||
<Filter>Common</Filter>
|
|
||||||
</ClCompile>
|
|
||||||
<ClCompile Include="..\Common\File.cpp">
|
|
||||||
<Filter>Common</Filter>
|
|
||||||
</ClCompile>
|
|
||||||
<ClCompile Include="..\Common\fileutil.cpp">
|
|
||||||
<Filter>Common</Filter>
|
|
||||||
</ClCompile>
|
|
||||||
<ClCompile Include="stdafx.cpp">
|
<ClCompile Include="stdafx.cpp">
|
||||||
<Filter>Misc</Filter>
|
<Filter>Misc</Filter>
|
||||||
</ClCompile>
|
</ClCompile>
|
||||||
<ClCompile Include="..\Common\TimerUtility.cpp">
|
|
||||||
<Filter>Common</Filter>
|
|
||||||
</ClCompile>
|
|
||||||
<ClCompile Include="Profiler.cpp">
|
<ClCompile Include="Profiler.cpp">
|
||||||
<Filter>GPU Interfacing</Filter>
|
<Filter>GPU Interfacing</Filter>
|
||||||
</ClCompile>
|
</ClCompile>
|
||||||
<ClCompile Include="SGD.cpp">
|
<ClCompile Include="SGD.cpp">
|
||||||
<Filter>SGD</Filter>
|
<Filter>SGD</Filter>
|
||||||
</ClCompile>
|
</ClCompile>
|
||||||
<ClCompile Include="..\Common\Config.cpp">
|
<ClCompile Include="PostComputingActions.cpp">
|
||||||
<Filter>Common</Filter>
|
<Filter>Stat</Filter>
|
||||||
</ClCompile>
|
</ClCompile>
|
||||||
</ItemGroup>
|
</ItemGroup>
|
||||||
<ItemGroup>
|
<ItemGroup>
|
||||||
|
@ -144,6 +129,9 @@
|
||||||
<ClInclude Include="Criterion.h">
|
<ClInclude Include="Criterion.h">
|
||||||
<Filter>SGD</Filter>
|
<Filter>SGD</Filter>
|
||||||
</ClInclude>
|
</ClInclude>
|
||||||
|
<ClInclude Include="PostComputingActions.h">
|
||||||
|
<Filter>Stat</Filter>
|
||||||
|
</ClInclude>
|
||||||
</ItemGroup>
|
</ItemGroup>
|
||||||
<ItemGroup>
|
<ItemGroup>
|
||||||
<Filter Include="Common">
|
<Filter Include="Common">
|
||||||
|
@ -182,5 +170,8 @@
|
||||||
<Filter Include="Data Reading">
|
<Filter Include="Data Reading">
|
||||||
<UniqueIdentifier>{b866d513-7bd0-497c-98c2-f62dbcd4cde4}</UniqueIdentifier>
|
<UniqueIdentifier>{b866d513-7bd0-497c-98c2-f62dbcd4cde4}</UniqueIdentifier>
|
||||||
</Filter>
|
</Filter>
|
||||||
|
<Filter Include="Stat">
|
||||||
|
<UniqueIdentifier>{f406217f-5a11-44ca-bb34-52254dbee8af}</UniqueIdentifier>
|
||||||
|
</Filter>
|
||||||
</ItemGroup>
|
</ItemGroup>
|
||||||
</Project>
|
</Project>
|
|
@ -52,36 +52,7 @@ public:
|
||||||
{
|
{
|
||||||
ScopedNetworkOperationMode modeGuard(m_net, NetworkOperationMode::inferring);
|
ScopedNetworkOperationMode modeGuard(m_net, NetworkOperationMode::inferring);
|
||||||
|
|
||||||
// determine nodes to evaluate
|
let evalNodes = m_net->GetEvalNodesWithName(evalNodeNames);
|
||||||
std::vector<ComputationNodeBasePtr> evalNodes;
|
|
||||||
|
|
||||||
set<ComputationNodeBasePtr> criteriaLogged; // (keeps track ot duplicates to avoid we don't double-log critera)
|
|
||||||
if (evalNodeNames.size() == 0)
|
|
||||||
{
|
|
||||||
fprintf(stderr, "evalNodeNames are not specified, using all the default evalnodes and training criterion nodes.\n");
|
|
||||||
if (m_net->EvaluationNodes().empty() && m_net->FinalCriterionNodes().empty())
|
|
||||||
InvalidArgument("There is no default evaluation node or training criterion specified in the network.");
|
|
||||||
|
|
||||||
for (const auto& node : m_net->EvaluationNodes())
|
|
||||||
if (criteriaLogged.insert(node).second)
|
|
||||||
evalNodes.push_back(node);
|
|
||||||
|
|
||||||
for (const auto& node : m_net->FinalCriterionNodes())
|
|
||||||
if (criteriaLogged.insert(node).second)
|
|
||||||
evalNodes.push_back(node);
|
|
||||||
}
|
|
||||||
else
|
|
||||||
{
|
|
||||||
for (int i = 0; i < evalNodeNames.size(); i++)
|
|
||||||
{
|
|
||||||
const auto& node = m_net->GetNodeFromName(evalNodeNames[i]);
|
|
||||||
if (!criteriaLogged.insert(node).second)
|
|
||||||
continue;
|
|
||||||
if (node->GetSampleLayout().GetNumElements() != 1)
|
|
||||||
InvalidArgument("Criterion nodes to evaluate must have dimension 1x1.");
|
|
||||||
evalNodes.push_back(node);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// initialize eval results
|
// initialize eval results
|
||||||
std::vector<EpochCriterion> evalResults(evalNodes.size(), EpochCriterion(0));
|
std::vector<EpochCriterion> evalResults(evalNodes.size(), EpochCriterion(0));
|
||||||
|
@ -112,7 +83,7 @@ public:
|
||||||
if (useDistributedMBReading)
|
if (useDistributedMBReading)
|
||||||
dataReader->StartDistributedMinibatchLoop(mbSize, 0, m_mpi->CurrentNodeRank(), m_mpi->NumNodesInUse(), testSize);
|
dataReader->StartDistributedMinibatchLoop(mbSize, 0, m_mpi->CurrentNodeRank(), m_mpi->NumNodesInUse(), testSize);
|
||||||
else
|
else
|
||||||
dataReader->StartMinibatchLoop(mbSize, 0, testSize);
|
dataReader->StartMinibatchLoop(mbSize, 0, testSize);
|
||||||
|
|
||||||
m_net->StartEvaluateMinibatchLoop(evalNodes);
|
m_net->StartEvaluateMinibatchLoop(evalNodes);
|
||||||
|
|
||||||
|
@ -257,153 +228,6 @@ public:
|
||||||
return evalResults;
|
return evalResults;
|
||||||
}
|
}
|
||||||
|
|
||||||
void EvaluateBN(IDataReader* dataReader, const vector<wstring>& evalNodeNames, const wstring exportPath, const size_t mbSize, const int iters = 240, const size_t testSize = requestDataSize)
|
|
||||||
{
|
|
||||||
ScopedNetworkOperationMode modeGuard(m_net, NetworkOperationMode::inferring);
|
|
||||||
|
|
||||||
// determine nodes to evaluate
|
|
||||||
std::vector<ComputationNodeBasePtr> evalNodes;
|
|
||||||
|
|
||||||
set<ComputationNodeBasePtr> criteriaLogged; // (keeps track ot duplicates to avoid we don't double-log critera)
|
|
||||||
if (evalNodeNames.size() == 0)
|
|
||||||
{
|
|
||||||
fprintf(stderr, "evalNodeNames are not specified, using all the default evalnodes and training criterion nodes.\n");
|
|
||||||
if (m_net->EvaluationNodes().empty() && m_net->FinalCriterionNodes().empty())
|
|
||||||
InvalidArgument("There is no default evaluation node or training criterion specified in the network.");
|
|
||||||
|
|
||||||
for (const auto& node : m_net->EvaluationNodes())
|
|
||||||
if (criteriaLogged.insert(node).second)
|
|
||||||
evalNodes.push_back(node);
|
|
||||||
|
|
||||||
for (const auto& node : m_net->FinalCriterionNodes())
|
|
||||||
if (criteriaLogged.insert(node).second)
|
|
||||||
evalNodes.push_back(node);
|
|
||||||
}
|
|
||||||
else
|
|
||||||
{
|
|
||||||
for (int i = 0; i < evalNodeNames.size(); i++)
|
|
||||||
{
|
|
||||||
const auto& node = m_net->GetNodeFromName(evalNodeNames[i]);
|
|
||||||
if (!criteriaLogged.insert(node).second)
|
|
||||||
continue;
|
|
||||||
if (node->GetSampleLayout().GetNumElements() != 1)
|
|
||||||
InvalidArgument("Criterion nodes to evaluate must have dimension 1x1.");
|
|
||||||
evalNodes.push_back(node);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// allocate memory for forward computation
|
|
||||||
m_net->AllocateAllMatrices(evalNodes, {}, nullptr);
|
|
||||||
|
|
||||||
// prepare features and labels
|
|
||||||
auto& featureNodes = m_net->FeatureNodes();
|
|
||||||
auto& labelNodes = m_net->LabelNodes();
|
|
||||||
|
|
||||||
StreamMinibatchInputs inputMatrices;
|
|
||||||
for (auto& node : featureNodes)
|
|
||||||
inputMatrices.AddInput(node->NodeName(), node->ValuePtr(), node->GetMBLayout(), node->GetSampleLayout());
|
|
||||||
for (auto& node : labelNodes)
|
|
||||||
inputMatrices.AddInput(node->NodeName(), node->ValuePtr(), node->GetMBLayout(), node->GetSampleLayout());
|
|
||||||
|
|
||||||
bool useParallelTrain = (m_mpi != nullptr);
|
|
||||||
bool useDistributedMBReading = useParallelTrain && m_enableDistributedMBReading && dataReader->SupportsDistributedMBRead();
|
|
||||||
if (useDistributedMBReading)
|
|
||||||
dataReader->StartDistributedMinibatchLoop(mbSize, 0, m_mpi->CurrentNodeRank(), m_mpi->NumNodesInUse(), testSize);
|
|
||||||
else
|
|
||||||
dataReader->StartMinibatchLoop(mbSize, 0, testSize);
|
|
||||||
|
|
||||||
m_net->StartEvaluateMinibatchLoop(evalNodes);
|
|
||||||
|
|
||||||
// Passing in two empty node lists so the dispatcher can work for the evalNodes.
|
|
||||||
std::list<ComputationNodeBasePtr> learnableNodes;
|
|
||||||
std::vector<ComputationNodeBasePtr> criterionNodes;
|
|
||||||
|
|
||||||
// First, all batch normalization nodes should be marked.
|
|
||||||
std::vector<ComputationNodeBasePtr> batchNormalNodes;
|
|
||||||
shared_ptr<FlowControlNode> nestedNetwork = static_pointer_cast<FlowControlNode>(m_net->GetNestedNetwork(evalNodes[0]));
|
|
||||||
for (auto& node : nestedNetwork->GetNestedNodes())
|
|
||||||
{
|
|
||||||
shared_ptr<BatchNormalizationNode<ElemType>> castNode =
|
|
||||||
dynamic_pointer_cast<BatchNormalizationNode<ElemType>>(node);
|
|
||||||
if (castNode)
|
|
||||||
{
|
|
||||||
batchNormalNodes.push_back(node);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Push all batch normalization mean and std into learn params values for mpi update
|
|
||||||
std::vector<Matrix<ElemType>*> learnParamsValues(2, nullptr);
|
|
||||||
|
|
||||||
bool noMoreSamplesToProcess = false;
|
|
||||||
for (auto& node : batchNormalNodes)
|
|
||||||
{
|
|
||||||
shared_ptr<BatchNormalizationNode<ElemType>> batchNode =
|
|
||||||
static_pointer_cast<BatchNormalizationNode<ElemType>>(node);
|
|
||||||
batchNode->SetPostBatchNormalizationBegin();
|
|
||||||
size_t actualMBSize = 0;
|
|
||||||
|
|
||||||
LOGPRINTF(stderr, "Start evaluating: %ls\n", batchNode->GetName().c_str());
|
|
||||||
|
|
||||||
// Post batch normal iters
|
|
||||||
for (int iter = 0; iter < iters; iter++)
|
|
||||||
{
|
|
||||||
bool wasDataRead = DataReaderHelpers::GetMinibatchIntoNetwork<ElemType>(*dataReader, m_net,
|
|
||||||
nullptr, useDistributedMBReading, useParallelTrain, inputMatrices, actualMBSize, m_mpi);
|
|
||||||
|
|
||||||
if (!wasDataRead && (!useDistributedMBReading || noMoreSamplesToProcess))
|
|
||||||
break;
|
|
||||||
|
|
||||||
// TODO should handle it, since post BN exist no samples in iters
|
|
||||||
if (!wasDataRead)
|
|
||||||
actualMBSize = 0;
|
|
||||||
|
|
||||||
// Batch Normalization Evaluate don't need to support subMinibatches
|
|
||||||
ComputationNetwork::BumpEvalTimeStamp(featureNodes);
|
|
||||||
ComputationNetwork::BumpEvalTimeStamp(labelNodes);
|
|
||||||
|
|
||||||
m_net->ForwardProp(evalNodes[0], nullptr, node);
|
|
||||||
dataReader->DataEnd();
|
|
||||||
}
|
|
||||||
batchNode->SetPostBatchNormalizationEnd();
|
|
||||||
|
|
||||||
// Sync during or after all iters of a BN node are equivalent
|
|
||||||
if (useParallelTrain)
|
|
||||||
{
|
|
||||||
if (m_gradHeader == nullptr)
|
|
||||||
{
|
|
||||||
m_gradHeader.reset(DistGradHeader::Create(evalNodes.size()), [](DistGradHeader* ptr)
|
|
||||||
{
|
|
||||||
DistGradHeader::Destroy(ptr);
|
|
||||||
});
|
|
||||||
}
|
|
||||||
SimpleDistGradAggregator<ElemType> distGradAgg(m_mpi, false /*useAsyncAggregation*/, 0 /*syncStatsTrace*/);
|
|
||||||
|
|
||||||
auto runMeanParameterPtr = node->GetInputs()[3];
|
|
||||||
auto runStdParameterPtr = node->GetInputs()[4];
|
|
||||||
|
|
||||||
shared_ptr<ComputationNode<ElemType>> runMeanNode = static_pointer_cast<ComputationNode<ElemType>>(runMeanParameterPtr);
|
|
||||||
shared_ptr<ComputationNode<ElemType>> runStdNode = static_pointer_cast<ComputationNode<ElemType>>(runStdParameterPtr);
|
|
||||||
|
|
||||||
learnParamsValues[0] = &(runMeanNode->Value());
|
|
||||||
learnParamsValues[1] = &(runStdNode->Value());
|
|
||||||
|
|
||||||
m_gradHeader->numSamples = actualMBSize ? 1 : actualMBSize;
|
|
||||||
distGradAgg.AggregateGradients(learnParamsValues, m_gradHeader.get(), 0);
|
|
||||||
|
|
||||||
for (auto& parameter : learnParamsValues)
|
|
||||||
{
|
|
||||||
(*parameter) /= (ElemType)m_mpi->NumNodesInUse();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Save Model
|
|
||||||
if (!useParallelTrain || m_mpi->CurrentNodeRank() == m_mpi->MainNodeRank())
|
|
||||||
m_net->Save(exportPath);
|
|
||||||
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
void DisplayEvalStatistics(const size_t startMBNum, const size_t endMBNum, const size_t numSamplesLastLogged,
|
void DisplayEvalStatistics(const size_t startMBNum, const size_t endMBNum, const size_t numSamplesLastLogged,
|
||||||
const vector<ComputationNodeBasePtr>& evalNodes,
|
const vector<ComputationNodeBasePtr>& evalNodes,
|
||||||
|
|
Загрузка…
Ссылка в новой задаче