Optimize the flow of post-batch-normalization statistics, and allow disabling regularization terms in batch normalization

This commit is contained in:
yuxiao.guo 2016-09-09 16:41:14 +08:00
Родитель d2de39c993
Коммит 9cb329a0da
18 изменённых файлов: 384 добавлений и 375 удалений

Просмотреть файл

@ -88,11 +88,11 @@ Train=[
] ]
PBN=[ PBN=[
action="pbn" action="bnstat"
modelPath="$ModelDir$/ResNet_50" modelPath="$ModelDir$/ResNet_50"
# Set minibatch size for testing. # Set minibatch size for testing.
minibatchSize=256 minibatchSize=256
iters=30 itersPerNode=30
reader=[ reader=[
readerType="ImageReader" readerType="ImageReader"

Просмотреть файл

@ -439,7 +439,8 @@ EVAL:=eval
SGDLIB_SRC=\ SGDLIB_SRC=\
$(SOURCEDIR)/SGDLib/Profiler.cpp \ $(SOURCEDIR)/SGDLib/Profiler.cpp \
$(SOURCEDIR)/SGDLib/SGD.cpp $(SOURCEDIR)/SGDLib/SGD.cpp \
$(SOURCEDIR)/SGDLib/PostComputingActions.cpp \
EVAL_SRC=\ EVAL_SRC=\
$(SOURCEDIR)/EvalDll/CNTKEval.cpp \ $(SOURCEDIR)/EvalDll/CNTKEval.cpp \

Просмотреть файл

@ -42,11 +42,11 @@ template <typename ElemType>
void DoDumpNodes(const ConfigParameters& config); void DoDumpNodes(const ConfigParameters& config);
template <typename ElemType> template <typename ElemType>
void DoEdit(const ConfigParameters& config); void DoEdit(const ConfigParameters& config);
template <typename ElemType>
void DoBatchNormalizationStat(const ConfigParameters& config);
// evaluation (EvalActions.cpp) // evaluation (EvalActions.cpp)
template <typename ElemType> template <typename ElemType>
void DoEvalBN(const ConfigParameters& config);
template <typename ElemType>
void DoEval(const ConfigParameters& config); void DoEval(const ConfigParameters& config);
template <typename ElemType> template <typename ElemType>
void DoCrossValidate(const ConfigParameters& config); void DoCrossValidate(const ConfigParameters& config);

Просмотреть файл

@ -78,62 +78,6 @@ static void DoEvalBase(const ConfigParameters& config, IDataReader& reader)
eval.Evaluate(&reader, evalNodeNamesVector, mbSize[0], epochSize); eval.Evaluate(&reader, evalNodeNamesVector, mbSize[0], epochSize);
} }
// ===========================================================================
// DoEvalBNBase() - implements CNTK "pbn" command
// ===========================================================================
template <typename ElemType>
static void DoEvalBNBase(const ConfigParameters& config, IDataReader& reader)
{
// DEVICEID_TYPE deviceId = DeviceFromConfig(config);
ConfigArray minibatchSize = config(L"minibatchSize", "40960");
size_t epochSize = config(L"epochSize", "0");
if (epochSize == 0)
{
epochSize = requestDataSize;
}
wstring modelPath = config(L"modelPath");
wstring exportPath = modelPath + L".PBN";
intargvector mbSize = minibatchSize;
int iters = config(L"iters", 240);
int traceLevel = config(L"traceLevel", "0");
size_t numMBsToShowResult = config(L"numMBsToShowResult", "100");
size_t firstMBsToShowResult = config(L"firstMBsToShowResult", "0");
size_t maxSamplesInRAM = config(L"maxSamplesInRAM", (size_t)SIZE_MAX);
size_t numSubminiBatches = config(L"numSubminibatches", (size_t)1);
bool enableDistributedMBReading = config(L"distributedMBReading", false);
vector<wstring> evalNodeNamesVector;
let net = GetModelFromConfig<ConfigParameters, ElemType>(config, L"evalNodeNames", evalNodeNamesVector);
// set tracing flags
net->EnableNodeTracing(config(L"traceNodeNamesReal", ConfigParameters::Array(stringargvector())),
config(L"traceNodeNamesCategory", ConfigParameters::Array(stringargvector())),
config(L"traceNodeNamesSparse", ConfigParameters::Array(stringargvector())));
SimpleEvaluator<ElemType> eval(net, MPIWrapper::GetInstance(), enableDistributedMBReading, numMBsToShowResult,
firstMBsToShowResult, traceLevel, maxSamplesInRAM, numSubminiBatches);
eval.EvaluateBN(&reader, evalNodeNamesVector, exportPath, mbSize[0], iters, epochSize);
}
template <typename ElemType>
void DoEvalBN(const ConfigParameters& config)
{
// evaluate batch normalization mean and various
ConfigParameters readerConfig(config(L"reader"));
// Should trace level to zero in Post BN?
//readerConfig.Insert("traceLevel", config(L"traceLevel", "0"));
DataReader evaBNDataReader(readerConfig);
DoEvalBNBase<ElemType>(config, evaBNDataReader);
}
template <typename ElemType> template <typename ElemType>
void DoEval(const ConfigParameters& config) void DoEval(const ConfigParameters& config)
{ {
@ -146,8 +90,6 @@ void DoEval(const ConfigParameters& config)
DoEvalBase<ElemType>(config, testDataReader); DoEvalBase<ElemType>(config, testDataReader);
} }
template void DoEvalBN<double>(const ConfigParameters& config);
template void DoEvalBN<float>(const ConfigParameters& config);
template void DoEval<double>(const ConfigParameters& config); template void DoEval<double>(const ConfigParameters& config);
template void DoEval<float>(const ConfigParameters& config); template void DoEval<float>(const ConfigParameters& config);

Просмотреть файл

@ -26,6 +26,7 @@
#include "ScriptableObjects.h" #include "ScriptableObjects.h"
#include "BrainScriptEvaluator.h" #include "BrainScriptEvaluator.h"
#include "BrainScriptParser.h" #include "BrainScriptParser.h"
#include "PostComputingActions.h"
#include <string> #include <string>
#include <chrono> #include <chrono>
@ -235,3 +236,46 @@ void DoEdit(const ConfigParameters& config)
template void DoEdit<double>(const ConfigParameters& config); template void DoEdit<double>(const ConfigParameters& config);
template void DoEdit<float>(const ConfigParameters& config); template void DoEdit<float>(const ConfigParameters& config);
// ===========================================================================
// DoBatchNormalizationStat() - implements CNTK "bnstat" command
// ===========================================================================
// Implements the CNTK "bnstat" command: re-estimates the mean/variance statistics of all
// BatchNormalization nodes of a trained model and saves the updated model.
// Config keys read: reader, traceLevel, itersPerNode, minibatchSize, enableDistributedMBReading,
// modelPath, newModelPath (defaults to modelPath + ".PBN"), evalNodeNames, traceNodeNames*.
template <typename ElemType>
void DoBatchNormalizationStat(const ConfigParameters& config)
{
ConfigParameters readerConfig(config(L"reader"));
readerConfig.Insert("traceLevel", config(L"traceLevel", "0"));
auto dataReader = make_shared<DataReader>(readerConfig);
int traceLevel = config(L"traceLevel", "0");
// number of minibatches used per BN node (and per MPI node) to estimate the statistics
int itersPerNode = config(L"itersPerNode", 30);
ConfigArray minibatchSize = config(L"minibatchSize", "40960");
intargvector mbSize = minibatchSize;
bool enableDistributedMBReading = config(L"enableDistributedMBReading", false);
wstring curModelPath = config(L"modelPath", L"");
wstring newModelPath = config(L"newModelPath", L"");
// if no explicit output path is given, write next to the input model with a ".PBN" suffix
if (newModelPath == L"")
{
newModelPath = curModelPath + L".PBN";
}
std::vector<std::wstring> evalNodeNames;
let net = GetModelFromConfig<ConfigParameters, ElemType>(config, L"evalNodeNames", evalNodeNames);
// set tracing flags
net->EnableNodeTracing(config(L"traceNodeNamesReal", ConfigParameters::Array(stringargvector())),
config(L"traceNodeNamesCategory", ConfigParameters::Array(stringargvector())),
config(L"traceNodeNamesSparse", ConfigParameters::Array(stringargvector())));
PostComputingActions<ElemType> postComputingActions(net, MPIWrapper::GetInstance(), enableDistributedMBReading, traceLevel);
postComputingActions.BatchNormalizationStatistics(dataReader.get(), evalNodeNames, newModelPath, mbSize[0], itersPerNode);
}
template void DoBatchNormalizationStat<double>(const ConfigParameters& config);
template void DoBatchNormalizationStat<float>(const ConfigParameters& config);

Просмотреть файл

@ -154,7 +154,7 @@ static void DisableLegacyUsage(const ConfigParameters& TopLevelConfig, const Con
// When running in parallel with MPI, only commands in 'commandstoRunOnAllRanks' should // When running in parallel with MPI, only commands in 'commandstoRunOnAllRanks' should
// be run in parallel across multiple ranks. Others should only run on rank 0 // be run in parallel across multiple ranks. Others should only run on rank 0
const std::set<std::string> commandstoRunOnAllRanks = { "train", "trainRNN", "adapt", "test", "eval", "cv", "devtest", "pbn" }; const std::set<std::string> commandstoRunOnAllRanks = { "train", "trainRNN", "adapt", "test", "eval", "cv", "devtest", "bnstat" };
// process the command // process the command
template <typename ElemType> template <typename ElemType>
@ -243,9 +243,9 @@ void DoCommands(const ConfigParameters& config, const shared_ptr<MPIWrapper>& mp
LOGPRINTF(stderr, "CNTKCommandTrainEnd: %s\n", command[i].c_str()); LOGPRINTF(stderr, "CNTKCommandTrainEnd: %s\n", command[i].c_str());
fullEpochsOffset += GetMaxEpochs(commandParams); fullEpochsOffset += GetMaxEpochs(commandParams);
} }
else if (thisAction == "pbn") else if (thisAction == "bnstat")
{ {
DoEvalBN<ElemType>(commandParams); DoBatchNormalizationStat<ElemType>(commandParams);
} }
else if (thisAction == "adapt") else if (thisAction == "adapt")
{ {

Просмотреть файл

@ -136,10 +136,6 @@ public:
// main entry point for backprop // main entry point for backprop
void Backprop(const ComputationNodeBasePtr rootNode); void Backprop(const ComputationNodeBasePtr rootNode);
// partial forward entry
void ForwardProp(const ComputationNodeBasePtr rootNode, const ComputationNodeBasePtr startNode,
const ComputationNodeBasePtr endNode);
template <class NODESET> // version that takes multiple nodes template <class NODESET> // version that takes multiple nodes
void ForwardProp(const NODESET& nodes) void ForwardProp(const NODESET& nodes)
{ {
@ -678,6 +674,44 @@ public:
return nodesWithType; return nodesWithType;
} }
// Returns the evaluation nodes matching the given names.
// If evalNodeNames is empty, returns all default evaluation nodes plus the training
// criterion nodes; raises InvalidArgument if neither group exists.
// Named nodes must be scalar (1x1) criteria; duplicates are returned only once.
std::vector<ComputationNodeBasePtr> GetEvalNodesWithName(const std::vector<wstring> evalNodeNames)
{
// determine nodes to evaluate
std::vector<ComputationNodeBasePtr> evalNodes;
set<ComputationNodeBasePtr> criteriaLogged; // keeps track of duplicates so we don't return the same criterion twice
if (evalNodeNames.size() == 0)
{
fprintf(stderr, "evalNodeNames are not specified, using all the default evalnodes and training criterion nodes.\n");
if (EvaluationNodes().empty() && FinalCriterionNodes().empty())
InvalidArgument("There is no default evaluation node or training criterion specified in the network.");
for (const auto& node : EvaluationNodes())
if (criteriaLogged.insert(node).second)
evalNodes.push_back(node);
for (const auto& node : FinalCriterionNodes())
if (criteriaLogged.insert(node).second)
evalNodes.push_back(node);
}
else
{
for (int i = 0; i < evalNodeNames.size(); i++)
{
const auto& node = GetNodeFromName(evalNodeNames[i]);
// skip duplicates that were already collected
if (!criteriaLogged.insert(node).second)
continue;
if (node->GetSampleLayout().GetNumElements() != 1)
InvalidArgument("Criterion nodes to evaluate must have dimension 1x1.");
evalNodes.push_back(node);
}
}
return evalNodes;
}
public: public:
// return list of nodes that require precomputation and not precomputed yet // return list of nodes that require precomputation and not precomputed yet
std::list<ComputationNodeBasePtr> GetNodesRequiringPreComputation(const ComputationNodeBasePtr& rootNode = nullptr, bool checkComputed = true); std::list<ComputationNodeBasePtr> GetNodesRequiringPreComputation(const ComputationNodeBasePtr& rootNode = nullptr, bool checkComputed = true);
@ -1014,7 +1048,7 @@ protected:
virtual const std::wstring OperationName() const override virtual const std::wstring OperationName() const override
{ {
return L"PARTraversalFlowControlNode"; return L"PARTraversalFlowControlNode";
} }
virtual void BeginForwardProp() override virtual void BeginForwardProp() override
{ {
} }
@ -1038,8 +1072,6 @@ protected:
virtual void AllocateGradientMatricesForInputs(MatrixPool& matrixPool); virtual void AllocateGradientMatricesForInputs(MatrixPool& matrixPool);
virtual void RequestMatricesBeforeBackprop(MatrixPool& matrixPool); virtual void RequestMatricesBeforeBackprop(MatrixPool& matrixPool);
virtual void ReleaseMatricesAfterBackprop(MatrixPool& matrixPool); virtual void ReleaseMatricesAfterBackprop(MatrixPool& matrixPool);
virtual void ForwardProp(const FrameRange&, const ComputationNodeBasePtr, const ComputationNodeBasePtr) override;
public: public:
// this special constructor constructs the top-level network node // this special constructor constructs the top-level network node

Просмотреть файл

@ -79,17 +79,6 @@ void ComputationNetwork::Backprop(const ComputationNodeBasePtr rootNode) // trai
GetNestedNetwork(rootNode)->Backprop(FrameRange(nullptr), true, true); GetNestedNetwork(rootNode)->Backprop(FrameRange(nullptr), true, true);
} }
void ComputationNetwork::ForwardProp(const ComputationNodeBasePtr rootNode, const ComputationNodeBasePtr startNode, const ComputationNodeBasePtr endNode)
{
VerifyIsCompiled("ForwardProp");
// traverse partial nodes as inputs
shared_ptr<FlowControlNode> network = dynamic_pointer_cast<FlowControlNode>(GetNestedNetwork(rootNode));
assert(network);
network->ForwardProp(FrameRange(nullptr), startNode, endNode);
}
void ComputationNetwork::FormNestedNetwork(const ComputationNodeBasePtr& rootNode) void ComputationNetwork::FormNestedNetwork(const ComputationNodeBasePtr& rootNode)
{ {
if (m_nestedNetworks.find(rootNode) != m_nestedNetworks.end()) if (m_nestedNetworks.find(rootNode) != m_nestedNetworks.end())
@ -158,7 +147,6 @@ ComputationNetwork::PARTraversalFlowControlNode::PARTraversalFlowControlNode(con
} }
} }
} }
/*virtual*/ void ComputationNetwork::PARTraversalFlowControlNode::Backprop(const FrameRange& fr, bool childrenInThisLoop, bool childrenInOuterLoop) /*override*/ /*virtual*/ void ComputationNetwork::PARTraversalFlowControlNode::Backprop(const FrameRange& fr, bool childrenInThisLoop, bool childrenInOuterLoop) /*override*/
{ {
childrenInThisLoop, childrenInOuterLoop; // TODO: think through what these mean when coming from PAR mode childrenInThisLoop, childrenInOuterLoop; // TODO: think through what these mean when coming from PAR mode
@ -187,36 +175,7 @@ ComputationNetwork::PARTraversalFlowControlNode::PARTraversalFlowControlNode(con
/*virtual*/ void ComputationNetwork::PARTraversalFlowControlNode::ReleaseMatricesAfterBackprop(MatrixPool& matrixPool) /*override*/ /*virtual*/ void ComputationNetwork::PARTraversalFlowControlNode::ReleaseMatricesAfterBackprop(MatrixPool& matrixPool) /*override*/
{ {
} }
/*virtual*/ void ComputationNetwork::PARTraversalFlowControlNode::ForwardProp(const FrameRange & fr, ComputationNodeBasePtr startNode, ComputationNodeBasePtr endNode)
{
// if start node is nullptr, forward will be enable
bool enableForward = startNode ? false : true;
for (auto& node : m_nestedNodes)
{
#if 0
if (dynamic_pointer_cast<LearnableParameter<float>>(node))
dynamic_pointer_cast<ComputationNode<float>>(node)->DebugLogMinibatch();
#endif
if (node->IsOutOfDateWrtInputs() && enableForward)
{
node->BeginForwardProp();
node->ForwardProp(fr.WithLayout(node->GetMBLayout()));
node->EndForwardProp();
node->BumpEvalTimeStamp();
}
if (node == startNode)
{
enableForward = true;
}
else if (node == endNode)
{
break;
}
}
}
// ----------------------------------------------------------------------- // -----------------------------------------------------------------------
// SEQTraversalFlowControlNode methods -- implements SEQ traversal (loop unrolling) // SEQTraversalFlowControlNode methods -- implements SEQ traversal (loop unrolling)

Просмотреть файл

@ -1878,10 +1878,6 @@ public:
virtual void DumpNodeInfo(const bool /*printValues*/, const bool /*printMetadata*/, File& fstream) const override {} virtual void DumpNodeInfo(const bool /*printValues*/, const bool /*printMetadata*/, File& fstream) const override {}
virtual std::set<std::pair<const MatrixBase*, std::wstring>> GetMatrixInfo() const override { NOT_IMPLEMENTED; } virtual std::set<std::pair<const MatrixBase*, std::wstring>> GetMatrixInfo() const override { NOT_IMPLEMENTED; }
virtual void ForwardProp(const FrameRange&, const ComputationNodeBasePtr, const ComputationNodeBasePtr) { NOT_IMPLEMENTED; }
std::vector<ComputationNodeBasePtr> GetNestedNodes() { return m_nestedNodes; }
protected: public: // needed in ComputationNetwork::FindInRecurrentLoops(), which really should be part of SEQTraversalFlowControlNode protected: public: // needed in ComputationNetwork::FindInRecurrentLoops(), which really should be part of SEQTraversalFlowControlNode
std::vector<ComputationNodeBasePtr> m_nestedNodes; // nodes tucked away in this node, in evaluation order std::vector<ComputationNodeBasePtr> m_nestedNodes; // nodes tucked away in this node, in evaluation order
}; };

Просмотреть файл

@ -37,6 +37,7 @@ public:
MarkValueNonSharable(); MarkValueNonSharable();
m_initString = L"fromValue"; // default init is with 0; typically overwritten m_initString = L"fromValue"; // default init is with 0; typically overwritten
m_initValue = 0; m_initValue = 0;
m_regMultiplier = 1.0f; // enable reg in update by default
} }
LearnableParameter(DEVICEID_TYPE deviceId, const wstring& name, const TensorShape& shape) : LearnableParameter(DEVICEID_TYPE deviceId, const wstring& name, const TensorShape& shape) :
LearnableParameter(deviceId, name) LearnableParameter(deviceId, name)
@ -101,6 +102,14 @@ public:
// called from CloneFunction(..., parameters="constant") // called from CloneFunction(..., parameters="constant")
virtual void FreezeParameters() override; // from IFreezable virtual void FreezeParameters() override; // from IFreezable
// Sets the regularization multiplier for this learnable node, affecting both L1Reg and L2Reg.
// A multiplier of 0 effectively disables regularization for this parameter.
void SetRegMultiplier(float regMultiplier)
{
m_regMultiplier = regMultiplier;
}
// called from SGD UpdateWeights, to scale the L1/L2 regularization applied to this node
float GetRegMultiplier() const { return m_regMultiplier; }
private: private:
// init parameters for deferred initialization (which happens in Validate()) // init parameters for deferred initialization (which happens in Validate())
std::wstring m_initString; // if non-empty then deferred initialization is needed. Gets cleared upon completion of deferred init. std::wstring m_initString; // if non-empty then deferred initialization is needed. Gets cleared upon completion of deferred init.
@ -109,6 +118,9 @@ private:
int m_initOutputRank; int m_initOutputRank;
bool m_initOnCPUOnly; bool m_initOnCPUOnly;
ElemType m_initValue; ElemType m_initValue;
// flags related to gradient update
float m_regMultiplier; // The multiplier to adjust the L1Reg and L2Reg for Learnable node
}; };
// ----------------------------------------------------------------------- // -----------------------------------------------------------------------

Просмотреть файл

@ -8,6 +8,7 @@
#include "ComputationNode.h" #include "ComputationNode.h"
#include "BatchNormalizationEngine.h" #include "BatchNormalizationEngine.h"
#include "RNGHandle.h" #include "RNGHandle.h"
#include "InputAndParamNodes.h"
#define __STDC_FORMAT_MACROS #define __STDC_FORMAT_MACROS
#include <inttypes.h> #include <inttypes.h>
@ -1587,15 +1588,15 @@ class BatchNormalizationNode : public ComputationNodeNonLooping<ElemType>, publi
public: public:
BatchNormalizationNode(DEVICEID_TYPE deviceId, const wstring& name) : BatchNormalizationNode(DEVICEID_TYPE deviceId, const wstring& name) :
Base(deviceId, name), m_spatial(false), m_normTimeConst(0), m_blendTimeConst(0), m_epsilon(0), m_useCntkEngine(true), Base(deviceId, name), m_spatial(false), m_normTimeConst(0), m_blendTimeConst(0), m_epsilon(0), m_useCntkEngine(true),
m_samplesSeen(0), m_imageLayoutKind(ImageLayoutKind::CHW), m_postBatchNormalization(false), m_swapNormTimeConst(0), m_samplesSeen(0), m_imageLayoutKind(ImageLayoutKind::CHW),
m_swapBlendTimeConst(0), m_convertRunningVariancePending(false) m_convertRunningVariancePending(false)
{ {
} }
BatchNormalizationNode(DEVICEID_TYPE deviceId, const wstring& name, bool spatial, double normalizationTimeConstant, double blendTimeConstant, BatchNormalizationNode(DEVICEID_TYPE deviceId, const wstring& name, bool spatial, double normalizationTimeConstant, double blendTimeConstant,
double epsilon, bool useCntkEngine, ImageLayoutKind imageLayoutKind) : double epsilon, bool useCntkEngine, ImageLayoutKind imageLayoutKind) :
Base(deviceId, name), m_spatial(spatial), m_normTimeConst(normalizationTimeConstant), m_blendTimeConst(blendTimeConstant), Base(deviceId, name), m_spatial(spatial), m_normTimeConst(normalizationTimeConstant), m_blendTimeConst(blendTimeConstant),
m_epsilon(epsilon), m_useCntkEngine(useCntkEngine), m_imageLayoutKind(imageLayoutKind), m_samplesSeen(0), m_postBatchNormalization(false), m_epsilon(epsilon), m_useCntkEngine(useCntkEngine), m_imageLayoutKind(imageLayoutKind), m_samplesSeen(0),
m_swapNormTimeConst(0), m_swapBlendTimeConst(0), m_convertRunningVariancePending(false) m_convertRunningVariancePending(false)
{ {
} }
BatchNormalizationNode(const ScriptableObjects::IConfigRecordPtr configp) : BatchNormalizationNode(const ScriptableObjects::IConfigRecordPtr configp) :
@ -1605,9 +1606,6 @@ public:
ImageLayoutKindFrom(configp->Get(L"imageLayout"))) ImageLayoutKindFrom(configp->Get(L"imageLayout")))
{ {
AttachInputsFromConfig(configp, this->GetExpectedNumInputs()); AttachInputsFromConfig(configp, this->GetExpectedNumInputs());
m_postBatchNormalization = false;
m_swapNormTimeConst = 0;
m_swapBlendTimeConst = 0;
} }
void Save(File& fstream) const override void Save(File& fstream) const override
@ -1724,7 +1722,7 @@ private: // time-constant conversions
double ComputeExpAvgFactor() const double ComputeExpAvgFactor() const
{ {
// in inference mode, only use long-term mean and do not update running estimates // in inference mode, only use long-term mean and do not update running estimates
if (!Environment().IsTraining() && !m_postBatchNormalization) if (!Environment().IsTraining())
{ {
if (m_samplesSeen == 0) if (m_samplesSeen == 0)
RuntimeError("%ls: inference mode is used, but nothing has been trained.", NodeName().c_str()); RuntimeError("%ls: inference mode is used, but nothing has been trained.", NodeName().c_str());
@ -1756,7 +1754,7 @@ private: // time-constant conversions
double ComputeBlendFactor() const double ComputeBlendFactor() const
{ {
// in inference mode, only use long-term mean and do not update running estimates // in inference mode, only use long-term mean and do not update running estimates
if (!Environment().IsTraining() && !m_postBatchNormalization) if (!Environment().IsTraining())
{ {
if (m_samplesSeen == 0) if (m_samplesSeen == 0)
RuntimeError("%ls: inference mode is used, but nothing has been trained.", NodeName().c_str()); RuntimeError("%ls: inference mode is used, but nothing has been trained.", NodeName().c_str());
@ -1805,7 +1803,7 @@ public:
// In inference-only mode, m_savedMean and m_saveInvStdDev will not be // In inference-only mode, m_savedMean and m_saveInvStdDev will not be
// produced and BackpropToNonLooping() may not be called. In // produced and BackpropToNonLooping() may not be called. In
// non-inference (training) mode, saved statistics must be produced. // non-inference (training) mode, saved statistics must be produced.
bool inferenceOnly = !Environment().IsTraining() && !m_postBatchNormalization; bool inferenceOnly = !Environment().IsTraining();
m_bnEng->Forward(/*in=*/ sliceInputValue, scale, bias, // (in) m_bnEng->Forward(/*in=*/ sliceInputValue, scale, bias, // (in)
inferenceOnly, expAvgFactor, blendFactor, inferenceOnly, expAvgFactor, blendFactor,
runMean, runVariance, // (in/out) running estimates, updated from the current MB mean/variance runMean, runVariance, // (in/out) running estimates, updated from the current MB mean/variance
@ -1870,14 +1868,6 @@ public:
} }
virtual void EndForwardProp() override virtual void EndForwardProp() override
{
if(m_postBatchNormalization)
m_samplesSeen += GetMBLayout()->GetActualNumSamples();
Base::EndForwardProp();
}
virtual void EndBackprop() override
{ {
// Update samples if not locked. // Update samples if not locked.
double expAvgFactor = ComputeExpAvgFactor(); // weight for the new MB statistics in the running estimate. The previous value of the running statistics is kept with weight (1-this) double expAvgFactor = ComputeExpAvgFactor(); // weight for the new MB statistics in the running estimate. The previous value of the running statistics is kept with weight (1-this)
@ -2019,28 +2009,29 @@ public:
m_blendTimeConst = std::numeric_limits<double>::infinity(); m_blendTimeConst = std::numeric_limits<double>::infinity();
} }
// ResetStatisticsState resets the batch-normalization statistics to their initial state.
// Used when re-estimating the mean and variance of BN (the "bnstat" action).
// Any other use may lead to unreliable results; please be careful.
void ResetStatisticsState()
{
m_samplesSeen = 0;
m_normTimeConst = 0;
m_blendTimeConst = 0;
}
// Turn off the L1 and L2 regularization for this BN node's scale and bias parameters
// by setting their reg multipliers to zero (see LearnableParameter::SetRegMultiplier).
// NOTE(review): assumes Input(1)/Input(2) are LearnableParameter nodes; the casts are
// not null-checked — confirm inputs are always learnable parameters here.
void DisableRegInBatchNormalization()
{
let scaleNode = dynamic_pointer_cast<LearnableParameter<ElemType>>(Input(1));
let biasNode = dynamic_pointer_cast<LearnableParameter<ElemType>>(Input(2));
scaleNode->SetRegMultiplier(0.f);
biasNode->SetRegMultiplier(0.f);
}
double NormalizationTimeConstant() const { return m_normTimeConst; } double NormalizationTimeConstant() const { return m_normTimeConst; }
double BlendTimeConstant() const { return m_blendTimeConst; } double BlendTimeConstant() const { return m_blendTimeConst; }
bool Spatial() const { return m_spatial; } bool Spatial() const { return m_spatial; }
double Epsilon() const { return m_epsilon; } double Epsilon() const { return m_epsilon; }
bool UseCNTKEngine() const { return m_useCntkEngine; } bool UseCNTKEngine() const { return m_useCntkEngine; }
void SetPostBatchNormalizationBegin()
{
m_postBatchNormalization = true;
m_samplesSeen = 0;
m_swapNormTimeConst = m_normTimeConst;
m_swapBlendTimeConst = m_blendTimeConst;
m_normTimeConst = -1;
m_blendTimeConst = 0;
}
void SetPostBatchNormalizationEnd()
{
m_postBatchNormalization = false;
m_normTimeConst = m_swapNormTimeConst;
m_blendTimeConst = m_swapBlendTimeConst;
}
private: private:
// Old versioning - do not use. Do not remove until we're sure there are no old models around. // Old versioning - do not use. Do not remove until we're sure there are no old models around.
struct VersionInfo struct VersionInfo
@ -2104,11 +2095,6 @@ private:
std::unique_ptr<BatchNormEngine<ElemType>> m_bnEng; std::unique_ptr<BatchNormEngine<ElemType>> m_bnEng;
// post batch normalization process mark
bool m_postBatchNormalization;
double m_swapNormTimeConst;
double m_swapBlendTimeConst;
bool m_convertRunningVariancePending; bool m_convertRunningVariancePending;
}; };

Просмотреть файл

@ -0,0 +1,161 @@
//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
//
// PostStat.cpp -- CNTK post statistics related actions
//
#include "PostComputingActions.h"
#include "TrainingNodes.h"
#include "ProgressTracing.h"
#include "DataReaderHelpers.h"
#include "SimpleDistGradAggregator.h"
#include <vector>
namespace Microsoft { namespace MSR{ namespace CNTK {
// Re-estimates the running mean/variance of every BatchNormalization node in the network
// by forward-propagating `iters` minibatches of size `mbSize` per BN node, then saves the
// updated model to `newModelPath`. When MPI is active, statistics are averaged across ranks.
template <class ElemType>
void PostComputingActions<ElemType>::BatchNormalizationStatistics(IDataReader * dataReader, const vector<wstring>& evalNodeNames,
const wstring newModelPath, const size_t mbSize, const int iters)
{
// since the mean and variance of bn will be modified in statistics,
// training mode will make it work. And there is no back prop, other parameters
// are fixed during computing.
ScopedNetworkOperationMode modeGuard(m_net, NetworkOperationMode::training);
// bn nodes need to be computed from bottom to top with evaluating order
let evalNodes = m_net->GetEvalNodesWithName(evalNodeNames);
// find all the BN nodes by evalOrder
std::vector<ComputationNodeBasePtr> bnNodes;
std::set<ComputationNodeBasePtr> bnNodesLogged; // (avoid double record of batch normalization nodes)
for (auto& evalNode : evalNodes)
{
for (auto& node : m_net->GetEvalOrder(evalNode))
{
let bnNode = dynamic_pointer_cast<BatchNormalizationNode<ElemType>>(node);
if (bnNode)
{
if (bnNodesLogged.insert(node).second)
{
// reset the statistics states of bn nodes
bnNode->ResetStatisticsState();
// expAvgFactor=1 (normTimeConst -1), blendFactor=0: use only the current MB statistics
bnNode->SetNormalizationTimeConstants(-1, bnNode->NormalizationTimeConstant(),
0, bnNode->BlendTimeConstant());
bnNodes.push_back(node);
// add BN nodes into the evaluation group, then they will be added into root nodes when
// the network re-compile
m_net->AddToNodeGroup(L"evaluation", bnNode);
}
}
}
}
// re-compile the network to add bn nodes as rootNodes.
m_net->CompileNetwork();
// allocate memory for all bnNodes evalOrder
m_net->AllocateAllMatrices(bnNodes, std::vector<ComputationNodeBasePtr>(), nullptr);
// prepare features
auto& featureNodes = m_net->FeatureNodes();
StreamMinibatchInputs inputMatrices;
for (auto& node : featureNodes)
inputMatrices.AddInput(node->NodeName(), node->ValuePtr(), node->GetMBLayout(), node->GetSampleLayout());
bool useParallelTrain = (m_mpi != nullptr);
bool useDistributedMBReading = useParallelTrain && m_enableDistributedMBReading && dataReader->SupportsDistributedMBRead();
// each BN node consumes `iters` minibatches, so the epoch must cover all of them
size_t totalEpochSize = bnNodes.size() * mbSize * iters;
m_net->StartEvaluateMinibatchLoop(bnNodes);
if (useDistributedMBReading)
dataReader->StartDistributedMinibatchLoop(mbSize, 0, m_mpi->CurrentNodeRank(), m_mpi->NumNodesInUse(), totalEpochSize);
else
dataReader->StartMinibatchLoop(mbSize, 0, totalEpochSize);
// process the BN nodes one at a time, in evaluation order (bottom to top)
for (auto& node : bnNodes)
{
let bnNode = static_pointer_cast<BatchNormalizationNode<ElemType>>(node);
size_t actualMBSize = 0;
LOGPRINTF(stderr, "Estimating Statistics --> %ls\n", bnNode->GetName().c_str());
// for every single bn node, the statistics is the average of mean and variance for several times in forward prop
// the forward prop is from the feature to the current bn node
for (int iter = 0; iter < iters; iter++)
{
// during the bn stat, dataRead must be ensured
bool wasDataRead = DataReaderHelpers::GetMinibatchIntoNetwork<ElemType>(*dataReader, m_net,
nullptr, useDistributedMBReading, useParallelTrain, inputMatrices, actualMBSize, m_mpi);
if (!wasDataRead) LogicError("DataRead Failure in batch normalization statistics");
ComputationNetwork::BumpEvalTimeStamp(featureNodes);
// forward prop till reaching the current bn node
m_net->ForwardProp(node);
}
// after finishing the statistics, the mean and variance of the bn node should be frozen.
bnNode->FreezeParameters();
// Sync during or after all iters of a BN node are equivalent
if (useParallelTrain)
{
if (m_gradHeader == nullptr)
{
m_gradHeader.reset(DistGradHeader::Create(evalNodes.size()), [](DistGradHeader* ptr)
{
DistGradHeader::Destroy(ptr);
});
}
// push the statistics results of mean and variance of bn nodes into mpi updating vector
std::vector<Matrix<ElemType>*> learnParamsValues(2, nullptr);
SimpleDistGradAggregator<ElemType> distGradAgg(m_mpi, false /*useAsyncAggregation*/, 0 /*syncStatsTrace*/);
// inputs 3 and 4 of a BN node hold the running mean and running variance/std parameters
auto runMeanParameterPtr = node->Input(3);
auto runStdParameterPtr = node->Input(4);
shared_ptr<ComputationNode<ElemType>> runMeanNode = static_pointer_cast<ComputationNode<ElemType>>(runMeanParameterPtr);
shared_ptr<ComputationNode<ElemType>> runStdNode = static_pointer_cast<ComputationNode<ElemType>>(runStdParameterPtr);
learnParamsValues[0] = &(runMeanNode->Value());
learnParamsValues[1] = &(runStdNode->Value());
// NOTE(review): `actualMBSize ? 1 : actualMBSize` evaluates to 1 when data was read and 0
// otherwise — if the intent was to record the actual sample count, the operands look
// swapped; confirm the intended semantics before changing.
m_gradHeader->numSamples = actualMBSize ? 1 : actualMBSize;
distGradAgg.AggregateGradients(learnParamsValues, m_gradHeader.get(), 0);
// get the average mean and variance across all the workers
for (auto& parameter : learnParamsValues)
{
(*parameter) /= (ElemType)m_mpi->NumNodesInUse();
}
}
}
dataReader->DataEnd();
// remove all the added BN nodes from evaluation group
for (auto& bnNode : bnNodes)
{
m_net->RemoveFromNodeGroup(L"evaluation", bnNode);
}
// save model; only the main rank writes to avoid concurrent writes under MPI
if (!useParallelTrain || m_mpi->CurrentNodeRank() == m_mpi->MainNodeRank())
m_net->Save(newModelPath);
return;
}
template class PostComputingActions<float>;
template class PostComputingActions<double>;
}}}

Просмотреть файл

@ -0,0 +1,65 @@
//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
//
// PostStat.h -- CNTK post statistics related actions
//
#pragma once
#include "ComputationNode.h"
#include "ComputationNetwork.h"
#include "MPIWrapper.h"
#include "DataReader.h"
#include "IDistGradAggregator.h"
#include "DistGradHeader.h"
using namespace std;
namespace Microsoft { namespace MSR { namespace CNTK {
template <class ElemType>
class IDistGradAggregator;
// Post-computing statistics are gathered between training and evaluation, producing
// the statistics that evaluation will consume.
// Currently the only application is re-estimating the mean and variance of
// BatchNormalization nodes after training has finished.
template <class ElemType>
class PostComputingActions
{
public:
    // net                       -- network whose BN statistics will be re-estimated
    // mpi                       -- MPI wrapper; nullptr means single-worker operation
    // enableDistributedMBReading -- let each worker read its own minibatch shard
    // traceLevel                -- verbosity of progress logging
    //
    // NOTE: initializers are listed in member declaration order; members are
    // initialized in declaration order regardless of the list order, so keeping
    // them in sync avoids -Wreorder warnings and reader confusion.
    PostComputingActions(ComputationNetworkPtr net, const MPIWrapperPtr& mpi, bool enableDistributedMBReading = false, const int traceLevel = 0) :
        m_net(net),
        m_mpi(mpi),
        m_enableDistributedMBReading(enableDistributedMBReading),
        m_traceLevel(traceLevel),
        m_distGradAgg(nullptr),
        m_gradHeader(nullptr)
    {
    }

    // Re-estimates the mean and variance of all batch-normalization nodes after training.
    // Details: https://github.com/Microsoft/CNTK/wiki/Post-Batch-Normalization-Statistics
    //
    // This action runs after training and involves no backpropagation, so it behaves
    // like an evaluation pass. It proceeds as follows:
    // 1. All weights except the BN statistics are fixed, so the network is switched
    //    into inferring mode.
    // 2. The network model and data source are loaded, following the same steps as
    //    Evaluate() but without error statistics and other evaluation-only bookkeeping.
    // 3. The BN nodes of the network are collected into a vector in evaluation order
    //    (via the nested-node vector obtained from the control-flow network).
    // 4. Mean and variance are re-estimated node by node over that vector (this relies
    //    on BatchNormalizationNode in TrainingNodes.h being able to accumulate the
    //    statistics while the network is in inferring mode).
    // 5. With multiple workers, the per-worker BN statistics are aggregated across all
    //    workers and averaged.
    void BatchNormalizationStatistics(IDataReader* dataReader, const vector<wstring>& evalNodeNames, const wstring newModelPath,
        const size_t mbSize, const int iters = 30);

private:
    ComputationNetworkPtr m_net;                               // network being post-processed
    MPIWrapperPtr m_mpi;                                       // nullptr when not running data-parallel

    bool m_enableDistributedMBReading;                         // each worker reads its own shard when true

    int m_traceLevel;                                          // logging verbosity

    std::shared_ptr<IDistGradAggregator<ElemType>> m_distGradAgg; // aggregator used to sync BN stats across workers
    std::shared_ptr<struct DistGradHeader> m_gradHeader;          // header reused across aggregation calls
};
}}}

Просмотреть файл

@ -8,6 +8,7 @@
#include "SpecialPurposeNodes.h" // for SequenceWithSoftmaxNode #include "SpecialPurposeNodes.h" // for SequenceWithSoftmaxNode
#include "DataReaderHelpers.h" #include "DataReaderHelpers.h"
#include "MatrixQuantizerImpl.h" #include "MatrixQuantizerImpl.h"
#include "InputAndParamNodes.h"
#ifdef CNTK_PARALLEL_TRAINING_SUPPORT #ifdef CNTK_PARALLEL_TRAINING_SUPPORT
//static inline bool operator==(const std::pair<double,size_t>& a, double b) { assert(b==0); return a.first == b; } //static inline bool operator==(const std::pair<double,size_t>& a, double b) { assert(b==0); return a.first == b; }
@ -875,23 +876,13 @@ size_t SGD<ElemType>::TrainOneEpoch(ComputationNetworkPtr net,
EpochCriterion epochCriterionLastLogged = epochCriterion; EpochCriterion epochCriterionLastLogged = epochCriterion;
vector<EpochCriterion> epochEvalErrorsLastLogged = epochEvalErrors; vector<EpochCriterion> epochEvalErrorsLastLogged = epochEvalErrors;
// Now, we need to use a switch to enable/disable wk in BatchNormalization. // NOTE: For ResNet, the regularization in BatchNormalization should be disable.
// If we can determine whether wk added or not for each node, then, discard this if (m_disableRegInBatchNormalization) {
std::unordered_set<ComputationNodeBasePtr> batchNormalizationWeights; let bnNodes = net->GetNodesWithType(L"BatchNormalization");
if (m_disableWkInBatchNormal) { for (auto &node : bnNodes)
for (auto& evalNode : evaluationNodes)
{ {
shared_ptr<FlowControlNode> nestedNetwork = static_pointer_cast<FlowControlNode>(net->GetNestedNetwork(evalNode)); let bnNode = dynamic_pointer_cast<BatchNormalizationNode<ElemType>>(node);
for (auto& node : nestedNetwork->GetNestedNodes()) bnNode->DisableRegInBatchNormalization();
{
shared_ptr<BatchNormalizationNode<ElemType>> castNode =
dynamic_pointer_cast<BatchNormalizationNode<ElemType>>(node);
if (castNode)
{
batchNormalizationWeights.insert(castNode->GetInputs()[1]);
batchNormalizationWeights.insert(castNode->GetInputs()[2]);
}
}
} }
} }
@ -1110,11 +1101,10 @@ size_t SGD<ElemType>::TrainOneEpoch(ComputationNetworkPtr net,
if (smoothedGradient.HasNan("TrainOneEpoch/UpdateWeights(): ")) if (smoothedGradient.HasNan("TrainOneEpoch/UpdateWeights(): "))
LogicError("%ls %ls operation has NaNs in smoothedGradient.", node->NodeName().c_str(), node->OperationName().c_str()); LogicError("%ls %ls operation has NaNs in smoothedGradient.", node->NodeName().c_str(), node->OperationName().c_str());
#endif #endif
double l2Factor = batchNormalizationWeights.find(node) == batchNormalizationWeights.end() ? 1.0 : 0.0;
// BUGBUG (Issue #95): Access to net MBLayout can no longer be done if we have multiple input layouts // BUGBUG (Issue #95): Access to net MBLayout can no longer be done if we have multiple input layouts
UpdateWeights(node, smoothedGradient, learnRatePerSample, UpdateWeights(node, smoothedGradient, learnRatePerSample,
GetMomentumPerSample(epochNumber /*BUGBUG workaround:*/, net->GetMBLayoutPtrOfNetwork()->GetNumParallelSequences()), numSamplesInMinibatch, GetMomentumPerSample(epochNumber /*BUGBUG workaround:*/, net->GetMBLayoutPtrOfNetwork()->GetNumParallelSequences()), numSamplesInMinibatch,
m_L2RegWeight * l2Factor, m_L1RegWeight, m_L2RegWeight, m_L1RegWeight,
m_needAveMultiplier, m_useNesterovMomentum); m_needAveMultiplier, m_useNesterovMomentum);
#ifdef _DEBUG #ifdef _DEBUG
if (dynamic_pointer_cast<ComputationNode<ElemType>>(node)->Value().HasNan("TrainOneEpoch/UpdateWeights(): ")) if (dynamic_pointer_cast<ComputationNode<ElemType>>(node)->Value().HasNan("TrainOneEpoch/UpdateWeights(): "))
@ -2017,9 +2007,10 @@ void SGD<ElemType>::UpdateWeights(const ComputationNodeBasePtr& node,
LogicError("UpdateWeights() called for a learnable ComputationNode which has m_learningRateMultiplier == 0!"); LogicError("UpdateWeights() called for a learnable ComputationNode which has m_learningRateMultiplier == 0!");
double nodeDependentLearningRatePerSample = learnRatePerSample * node->GetLearningRateMultiplier(); double nodeDependentLearningRatePerSample = learnRatePerSample * node->GetLearningRateMultiplier();
double nodeDependentRegMultiplier = dynamic_pointer_cast<LearnableParameter<ElemType>>(node)->GetRegMultiplier();
UpdateWeightsS(this, dynamic_pointer_cast<ComputationNode<ElemType>>(node)->Value(), dynamic_pointer_cast<ComputationNode<ElemType>>(node)->Gradient(), UpdateWeightsS(this, dynamic_pointer_cast<ComputationNode<ElemType>>(node)->Value(), dynamic_pointer_cast<ComputationNode<ElemType>>(node)->Gradient(),
smoothedGradient, nodeDependentLearningRatePerSample, momentumPerSample, smoothedGradient, nodeDependentLearningRatePerSample, momentumPerSample,
actualMBSize, L2RegWeight, L1RegWeight, actualMBSize, L2RegWeight * nodeDependentRegMultiplier, L1RegWeight * nodeDependentRegMultiplier,
needAveMultiplier, m_useNesterovMomentum); needAveMultiplier, m_useNesterovMomentum);
node->BumpEvalTimeStamp(); node->BumpEvalTimeStamp();
} }
@ -2475,7 +2466,7 @@ SGDParams::SGDParams(const ConfigRecordType& configSGD, size_t sizeofElemType)
m_seqGammarCalcbMMIFactor = configSGD(L"seqGammarBMMIFactor", 0.0); m_seqGammarCalcbMMIFactor = configSGD(L"seqGammarBMMIFactor", 0.0);
m_seqGammarCalcWP = configSGD(L"seqGammarWordPen", 0.0); m_seqGammarCalcWP = configSGD(L"seqGammarWordPen", 0.0);
m_disableWkInBatchNormal = configSGD(L"disableWkInBatchNormal", false); m_disableRegInBatchNormalization = configSGD(L"disableRegInBatchNormalization", false);
m_dropoutRates = configSGD(L"dropoutRate", ConfigRecordType::Array(doubleargvector(vector<double>{0.0}))); m_dropoutRates = configSGD(L"dropoutRate", ConfigRecordType::Array(doubleargvector(vector<double>{0.0})));
m_batchNormalizationTimeConstant = configSGD(L"batchNormalizationTimeConstant", ConfigRecordType::Array(doubleargvector(vector<double>{0}))); m_batchNormalizationTimeConstant = configSGD(L"batchNormalizationTimeConstant", ConfigRecordType::Array(doubleargvector(vector<double>{0})));

Просмотреть файл

@ -291,7 +291,10 @@ protected:
double m_seqGammarCalcbMMIFactor; double m_seqGammarCalcbMMIFactor;
bool m_seqGammarCalcUsesMBR; bool m_seqGammarCalcUsesMBR;
bool m_disableWkInBatchNormal; // decide whether should apply L2 regularization into BatchNormalizationNode
// true: disable L2 Regularization
// false: enable L2 Regularization (default)
bool m_disableRegInBatchNormalization;
}; };
template <class ElemType> template <class ElemType>

Просмотреть файл

@ -124,6 +124,7 @@
<ClInclude Include="..\ComputationNetworkLib\NonlinearityNodes.h" /> <ClInclude Include="..\ComputationNetworkLib\NonlinearityNodes.h" />
<ClInclude Include="..\ComputationNetworkLib\RecurrentNodes.h" /> <ClInclude Include="..\ComputationNetworkLib\RecurrentNodes.h" />
<ClInclude Include="MASGD.h" /> <ClInclude Include="MASGD.h" />
<ClInclude Include="PostComputingActions.h" />
<ClInclude Include="SimpleDistGradAggregator.h" /> <ClInclude Include="SimpleDistGradAggregator.h" />
<ClInclude Include="SimpleEvaluator.h" /> <ClInclude Include="SimpleEvaluator.h" />
<ClInclude Include="SimpleOutputWriter.h" /> <ClInclude Include="SimpleOutputWriter.h" />
@ -132,10 +133,11 @@
<ClInclude Include="targetver.h" /> <ClInclude Include="targetver.h" />
</ItemGroup> </ItemGroup>
<ItemGroup> <ItemGroup>
<ClCompile Include="PostComputingActions.cpp" />
<ClCompile Include="Profiler.cpp" /> <ClCompile Include="Profiler.cpp" />
<ClCompile Include="SGD.cpp" /> <ClCompile Include="SGD.cpp" />
<ClCompile Include="stdafx.cpp" /> <ClCompile Include="stdafx.cpp" />
</ItemGroup> </ItemGroup>
<Import Project="$(VCTargetsPath)\Microsoft.Cpp.targets" /> <Import Project="$(VCTargetsPath)\Microsoft.Cpp.targets" />
<ImportGroup Label="ExtensionTargets" /> <ImportGroup Label="ExtensionTargets" />
</Project> </Project>

Просмотреть файл

@ -1,32 +1,17 @@
<?xml version="1.0" encoding="utf-8"?> <?xml version="1.0" encoding="utf-8"?>
<Project ToolsVersion="4.0" xmlns="http://schemas.microsoft.com/developer/msbuild/2003"> <Project ToolsVersion="4.0" xmlns="http://schemas.microsoft.com/developer/msbuild/2003">
<ItemGroup> <ItemGroup>
<ClCompile Include="..\Common\DataReader.cpp">
<Filter>Common</Filter>
</ClCompile>
<ClCompile Include="..\Common\DataWriter.cpp">
<Filter>Common</Filter>
</ClCompile>
<ClCompile Include="..\Common\File.cpp">
<Filter>Common</Filter>
</ClCompile>
<ClCompile Include="..\Common\fileutil.cpp">
<Filter>Common</Filter>
</ClCompile>
<ClCompile Include="stdafx.cpp"> <ClCompile Include="stdafx.cpp">
<Filter>Misc</Filter> <Filter>Misc</Filter>
</ClCompile> </ClCompile>
<ClCompile Include="..\Common\TimerUtility.cpp">
<Filter>Common</Filter>
</ClCompile>
<ClCompile Include="Profiler.cpp"> <ClCompile Include="Profiler.cpp">
<Filter>GPU Interfacing</Filter> <Filter>GPU Interfacing</Filter>
</ClCompile> </ClCompile>
<ClCompile Include="SGD.cpp"> <ClCompile Include="SGD.cpp">
<Filter>SGD</Filter> <Filter>SGD</Filter>
</ClCompile> </ClCompile>
<ClCompile Include="..\Common\Config.cpp"> <ClCompile Include="PostComputingActions.cpp">
<Filter>Common</Filter> <Filter>Stat</Filter>
</ClCompile> </ClCompile>
</ItemGroup> </ItemGroup>
<ItemGroup> <ItemGroup>
@ -144,6 +129,9 @@
<ClInclude Include="Criterion.h"> <ClInclude Include="Criterion.h">
<Filter>SGD</Filter> <Filter>SGD</Filter>
</ClInclude> </ClInclude>
<ClInclude Include="PostComputingActions.h">
<Filter>Stat</Filter>
</ClInclude>
</ItemGroup> </ItemGroup>
<ItemGroup> <ItemGroup>
<Filter Include="Common"> <Filter Include="Common">
@ -182,5 +170,8 @@
<Filter Include="Data Reading"> <Filter Include="Data Reading">
<UniqueIdentifier>{b866d513-7bd0-497c-98c2-f62dbcd4cde4}</UniqueIdentifier> <UniqueIdentifier>{b866d513-7bd0-497c-98c2-f62dbcd4cde4}</UniqueIdentifier>
</Filter> </Filter>
<Filter Include="Stat">
<UniqueIdentifier>{f406217f-5a11-44ca-bb34-52254dbee8af}</UniqueIdentifier>
</Filter>
</ItemGroup> </ItemGroup>
</Project> </Project>

Просмотреть файл

@ -52,36 +52,7 @@ public:
{ {
ScopedNetworkOperationMode modeGuard(m_net, NetworkOperationMode::inferring); ScopedNetworkOperationMode modeGuard(m_net, NetworkOperationMode::inferring);
// determine nodes to evaluate let evalNodes = m_net->GetEvalNodesWithName(evalNodeNames);
std::vector<ComputationNodeBasePtr> evalNodes;
set<ComputationNodeBasePtr> criteriaLogged; // (keeps track ot duplicates to avoid we don't double-log critera)
if (evalNodeNames.size() == 0)
{
fprintf(stderr, "evalNodeNames are not specified, using all the default evalnodes and training criterion nodes.\n");
if (m_net->EvaluationNodes().empty() && m_net->FinalCriterionNodes().empty())
InvalidArgument("There is no default evaluation node or training criterion specified in the network.");
for (const auto& node : m_net->EvaluationNodes())
if (criteriaLogged.insert(node).second)
evalNodes.push_back(node);
for (const auto& node : m_net->FinalCriterionNodes())
if (criteriaLogged.insert(node).second)
evalNodes.push_back(node);
}
else
{
for (int i = 0; i < evalNodeNames.size(); i++)
{
const auto& node = m_net->GetNodeFromName(evalNodeNames[i]);
if (!criteriaLogged.insert(node).second)
continue;
if (node->GetSampleLayout().GetNumElements() != 1)
InvalidArgument("Criterion nodes to evaluate must have dimension 1x1.");
evalNodes.push_back(node);
}
}
// initialize eval results // initialize eval results
std::vector<EpochCriterion> evalResults(evalNodes.size(), EpochCriterion(0)); std::vector<EpochCriterion> evalResults(evalNodes.size(), EpochCriterion(0));
@ -112,7 +83,7 @@ public:
if (useDistributedMBReading) if (useDistributedMBReading)
dataReader->StartDistributedMinibatchLoop(mbSize, 0, m_mpi->CurrentNodeRank(), m_mpi->NumNodesInUse(), testSize); dataReader->StartDistributedMinibatchLoop(mbSize, 0, m_mpi->CurrentNodeRank(), m_mpi->NumNodesInUse(), testSize);
else else
dataReader->StartMinibatchLoop(mbSize, 0, testSize); dataReader->StartMinibatchLoop(mbSize, 0, testSize);
m_net->StartEvaluateMinibatchLoop(evalNodes); m_net->StartEvaluateMinibatchLoop(evalNodes);
@ -257,153 +228,6 @@ public:
return evalResults; return evalResults;
} }
// Re-estimates the running mean/variance of every BatchNormalizationNode by forward-propagating
// up to 'iters' minibatches per BN node with the network in inferring mode, averages the
// statistics across MPI workers if present, and saves the updated model to 'exportPath'.
// dataReader    -- source of minibatches used for the statistics passes
// evalNodeNames -- evaluation/criterion roots to drive forward propagation (all defaults if empty)
// exportPath    -- path the updated model is written to (main node only under MPI)
// mbSize        -- minibatch size for the statistics passes
// iters         -- number of minibatches consumed per BN node
// testSize      -- total number of samples requested from the reader
void EvaluateBN(IDataReader* dataReader, const vector<wstring>& evalNodeNames, const wstring exportPath, const size_t mbSize, const int iters = 240, const size_t testSize = requestDataSize)
{
    // run the whole pass in inferring mode; the guard restores the previous mode on exit
    ScopedNetworkOperationMode modeGuard(m_net, NetworkOperationMode::inferring);

    // determine nodes to evaluate
    std::vector<ComputationNodeBasePtr> evalNodes;

    set<ComputationNodeBasePtr> criteriaLogged; // (keeps track of duplicates so that we don't double-log criteria)
    if (evalNodeNames.size() == 0)
    {
        fprintf(stderr, "evalNodeNames are not specified, using all the default evalnodes and training criterion nodes.\n");
        if (m_net->EvaluationNodes().empty() && m_net->FinalCriterionNodes().empty())
            InvalidArgument("There is no default evaluation node or training criterion specified in the network.");
        for (const auto& node : m_net->EvaluationNodes())
            if (criteriaLogged.insert(node).second)
                evalNodes.push_back(node);
        for (const auto& node : m_net->FinalCriterionNodes())
            if (criteriaLogged.insert(node).second)
                evalNodes.push_back(node);
    }
    else
    {
        for (int i = 0; i < evalNodeNames.size(); i++)
        {
            const auto& node = m_net->GetNodeFromName(evalNodeNames[i]);
            if (!criteriaLogged.insert(node).second)
                continue;
            if (node->GetSampleLayout().GetNumElements() != 1)
                InvalidArgument("Criterion nodes to evaluate must have dimension 1x1.");
            evalNodes.push_back(node);
        }
    }

    // allocate memory for forward computation
    m_net->AllocateAllMatrices(evalNodes, {}, nullptr);

    // prepare features and labels
    auto& featureNodes = m_net->FeatureNodes();
    auto& labelNodes = m_net->LabelNodes();

    StreamMinibatchInputs inputMatrices;
    for (auto& node : featureNodes)
        inputMatrices.AddInput(node->NodeName(), node->ValuePtr(), node->GetMBLayout(), node->GetSampleLayout());
    for (auto& node : labelNodes)
        inputMatrices.AddInput(node->NodeName(), node->ValuePtr(), node->GetMBLayout(), node->GetSampleLayout());

    bool useParallelTrain = (m_mpi != nullptr);
    bool useDistributedMBReading = useParallelTrain && m_enableDistributedMBReading && dataReader->SupportsDistributedMBRead();
    if (useDistributedMBReading)
        dataReader->StartDistributedMinibatchLoop(mbSize, 0, m_mpi->CurrentNodeRank(), m_mpi->NumNodesInUse(), testSize);
    else
        dataReader->StartMinibatchLoop(mbSize, 0, testSize);

    m_net->StartEvaluateMinibatchLoop(evalNodes);

    // Passing in two empty node lists so the dispatcher can work for the evalNodes.
    std::list<ComputationNodeBasePtr> learnableNodes;
    std::vector<ComputationNodeBasePtr> criterionNodes;

    // First, all batch normalization nodes should be marked.
    std::vector<ComputationNodeBasePtr> batchNormalNodes;
    shared_ptr<FlowControlNode> nestedNetwork = static_pointer_cast<FlowControlNode>(m_net->GetNestedNetwork(evalNodes[0]));
    for (auto& node : nestedNetwork->GetNestedNodes())
    {
        shared_ptr<BatchNormalizationNode<ElemType>> castNode =
            dynamic_pointer_cast<BatchNormalizationNode<ElemType>>(node);
        if (castNode)
        {
            batchNormalNodes.push_back(node);
        }
    }

    // Push all batch normalization mean and std into learn params values for mpi update
    std::vector<Matrix<ElemType>*> learnParamsValues(2, nullptr);

    bool noMoreSamplesToProcess = false;
    // process BN nodes one at a time, in evaluation order, so each node's statistics are
    // computed with all upstream BN nodes already finalized
    for (auto& node : batchNormalNodes)
    {
        shared_ptr<BatchNormalizationNode<ElemType>> batchNode =
            static_pointer_cast<BatchNormalizationNode<ElemType>>(node);
        batchNode->SetPostBatchNormalizationBegin();
        size_t actualMBSize = 0;

        LOGPRINTF(stderr, "Start evaluating: %ls\n", batchNode->GetName().c_str());

        // accumulate statistics over 'iters' minibatches for this BN node
        for (int iter = 0; iter < iters; iter++)
        {
            bool wasDataRead = DataReaderHelpers::GetMinibatchIntoNetwork<ElemType>(*dataReader, m_net,
                nullptr, useDistributedMBReading, useParallelTrain, inputMatrices, actualMBSize, m_mpi);

            if (!wasDataRead && (!useDistributedMBReading || noMoreSamplesToProcess))
                break;

            // TODO: handle the case where the reader yields no samples during the post-BN iterations
            if (!wasDataRead)
                actualMBSize = 0;

            // the batch-normalization statistics pass does not need to support sub-minibatches
            ComputationNetwork::BumpEvalTimeStamp(featureNodes);
            ComputationNetwork::BumpEvalTimeStamp(labelNodes);

            // forward-propagate only up to the current BN node
            m_net->ForwardProp(evalNodes[0], nullptr, node);
            dataReader->DataEnd();
        }
        batchNode->SetPostBatchNormalizationEnd();

        // Syncing during or after all iters of a BN node is equivalent, so sync once per node.
        if (useParallelTrain)
        {
            if (m_gradHeader == nullptr)
            {
                m_gradHeader.reset(DistGradHeader::Create(evalNodes.size()), [](DistGradHeader* ptr)
                {
                    DistGradHeader::Destroy(ptr);
                });
            }
            SimpleDistGradAggregator<ElemType> distGradAgg(m_mpi, false /*useAsyncAggregation*/, 0 /*syncStatsTrace*/);
            // inputs 3 and 4 of a BatchNormalizationNode hold the running mean and running std
            auto runMeanParameterPtr = node->GetInputs()[3];
            auto runStdParameterPtr = node->GetInputs()[4];

            shared_ptr<ComputationNode<ElemType>> runMeanNode = static_pointer_cast<ComputationNode<ElemType>>(runMeanParameterPtr);
            shared_ptr<ComputationNode<ElemType>> runStdNode = static_pointer_cast<ComputationNode<ElemType>>(runStdParameterPtr);

            learnParamsValues[0] = &(runMeanNode->Value());
            learnParamsValues[1] = &(runStdNode->Value());

            // NOTE(review): 'actualMBSize ? 1 : actualMBSize' always evaluates to 1 or 0 --
            // confirm whether the real sample count was intended here instead.
            m_gradHeader->numSamples = actualMBSize ? 1 : actualMBSize;
            distGradAgg.AggregateGradients(learnParamsValues, m_gradHeader.get(), 0);

            // average the summed mean/variance across all workers
            for (auto& parameter : learnParamsValues)
            {
                (*parameter) /= (ElemType)m_mpi->NumNodesInUse();
            }
        }
    }

    // Save Model -- under MPI only the main node writes the file
    if (!useParallelTrain || m_mpi->CurrentNodeRank() == m_mpi->MainNodeRank())
        m_net->Save(exportPath);

    return;
}
protected: protected:
void DisplayEvalStatistics(const size_t startMBNum, const size_t endMBNum, const size_t numSamplesLastLogged, void DisplayEvalStatistics(const size_t startMBNum, const size_t endMBNum, const size_t numSamplesLastLogged,
const vector<ComputationNodeBasePtr>& evalNodes, const vector<ComputationNodeBasePtr>& evalNodes,