Optimize the flow of post batch normalization statistics, and allow disable regularization terms in batch normalization
This commit is contained in:
Родитель
d2de39c993
Коммит
9cb329a0da
|
@ -88,11 +88,11 @@ Train=[
|
|||
]
|
||||
|
||||
PBN=[
|
||||
action="pbn"
|
||||
action="bnstat"
|
||||
modelPath="$ModelDir$/ResNet_50"
|
||||
# Set minibatch size for testing.
|
||||
minibatchSize=256
|
||||
iters=30
|
||||
itersPerNode=30
|
||||
|
||||
reader=[
|
||||
readerType="ImageReader"
|
||||
|
|
3
Makefile
3
Makefile
|
@ -439,7 +439,8 @@ EVAL:=eval
|
|||
|
||||
SGDLIB_SRC=\
|
||||
$(SOURCEDIR)/SGDLib/Profiler.cpp \
|
||||
$(SOURCEDIR)/SGDLib/SGD.cpp
|
||||
$(SOURCEDIR)/SGDLib/SGD.cpp \
|
||||
$(SOURCEDIR)/SGDLib/PostComputingActions.cpp \
|
||||
|
||||
EVAL_SRC=\
|
||||
$(SOURCEDIR)/EvalDll/CNTKEval.cpp \
|
||||
|
|
|
@ -42,11 +42,11 @@ template <typename ElemType>
|
|||
void DoDumpNodes(const ConfigParameters& config);
|
||||
template <typename ElemType>
|
||||
void DoEdit(const ConfigParameters& config);
|
||||
template <typename ElemType>
|
||||
void DoBatchNormalizationStat(const ConfigParameters& config);
|
||||
|
||||
// evaluation (EvalActions.cpp)
|
||||
template <typename ElemType>
|
||||
void DoEvalBN(const ConfigParameters& config);
|
||||
template <typename ElemType>
|
||||
void DoEval(const ConfigParameters& config);
|
||||
template <typename ElemType>
|
||||
void DoCrossValidate(const ConfigParameters& config);
|
||||
|
|
|
@ -78,62 +78,6 @@ static void DoEvalBase(const ConfigParameters& config, IDataReader& reader)
|
|||
eval.Evaluate(&reader, evalNodeNamesVector, mbSize[0], epochSize);
|
||||
}
|
||||
|
||||
// ===========================================================================
|
||||
// DoEvalBNBase() - implements CNTK "pbn" command
|
||||
// ===========================================================================
|
||||
|
||||
template <typename ElemType>
|
||||
static void DoEvalBNBase(const ConfigParameters& config, IDataReader& reader)
|
||||
{
|
||||
// DEVICEID_TYPE deviceId = DeviceFromConfig(config);
|
||||
ConfigArray minibatchSize = config(L"minibatchSize", "40960");
|
||||
size_t epochSize = config(L"epochSize", "0");
|
||||
if (epochSize == 0)
|
||||
{
|
||||
epochSize = requestDataSize;
|
||||
}
|
||||
wstring modelPath = config(L"modelPath");
|
||||
wstring exportPath = modelPath + L".PBN";
|
||||
intargvector mbSize = minibatchSize;
|
||||
|
||||
int iters = config(L"iters", 240);
|
||||
|
||||
int traceLevel = config(L"traceLevel", "0");
|
||||
size_t numMBsToShowResult = config(L"numMBsToShowResult", "100");
|
||||
size_t firstMBsToShowResult = config(L"firstMBsToShowResult", "0");
|
||||
size_t maxSamplesInRAM = config(L"maxSamplesInRAM", (size_t)SIZE_MAX);
|
||||
size_t numSubminiBatches = config(L"numSubminibatches", (size_t)1);
|
||||
|
||||
bool enableDistributedMBReading = config(L"distributedMBReading", false);
|
||||
|
||||
vector<wstring> evalNodeNamesVector;
|
||||
|
||||
let net = GetModelFromConfig<ConfigParameters, ElemType>(config, L"evalNodeNames", evalNodeNamesVector);
|
||||
|
||||
// set tracing flags
|
||||
net->EnableNodeTracing(config(L"traceNodeNamesReal", ConfigParameters::Array(stringargvector())),
|
||||
config(L"traceNodeNamesCategory", ConfigParameters::Array(stringargvector())),
|
||||
config(L"traceNodeNamesSparse", ConfigParameters::Array(stringargvector())));
|
||||
|
||||
SimpleEvaluator<ElemType> eval(net, MPIWrapper::GetInstance(), enableDistributedMBReading, numMBsToShowResult,
|
||||
firstMBsToShowResult, traceLevel, maxSamplesInRAM, numSubminiBatches);
|
||||
eval.EvaluateBN(&reader, evalNodeNamesVector, exportPath, mbSize[0], iters, epochSize);
|
||||
}
|
||||
|
||||
template <typename ElemType>
|
||||
void DoEvalBN(const ConfigParameters& config)
|
||||
{
|
||||
// evaluate batch normalization mean and various
|
||||
ConfigParameters readerConfig(config(L"reader"));
|
||||
|
||||
// Should trace level to zero in Post BN?
|
||||
//readerConfig.Insert("traceLevel", config(L"traceLevel", "0"));
|
||||
|
||||
DataReader evaBNDataReader(readerConfig);
|
||||
|
||||
DoEvalBNBase<ElemType>(config, evaBNDataReader);
|
||||
}
|
||||
|
||||
template <typename ElemType>
|
||||
void DoEval(const ConfigParameters& config)
|
||||
{
|
||||
|
@ -146,8 +90,6 @@ void DoEval(const ConfigParameters& config)
|
|||
DoEvalBase<ElemType>(config, testDataReader);
|
||||
}
|
||||
|
||||
template void DoEvalBN<double>(const ConfigParameters& config);
|
||||
template void DoEvalBN<float>(const ConfigParameters& config);
|
||||
template void DoEval<double>(const ConfigParameters& config);
|
||||
template void DoEval<float>(const ConfigParameters& config);
|
||||
|
||||
|
|
|
@ -26,6 +26,7 @@
|
|||
#include "ScriptableObjects.h"
|
||||
#include "BrainScriptEvaluator.h"
|
||||
#include "BrainScriptParser.h"
|
||||
#include "PostComputingActions.h"
|
||||
|
||||
#include <string>
|
||||
#include <chrono>
|
||||
|
@ -235,3 +236,46 @@ void DoEdit(const ConfigParameters& config)
|
|||
|
||||
template void DoEdit<double>(const ConfigParameters& config);
|
||||
template void DoEdit<float>(const ConfigParameters& config);
|
||||
|
||||
// ===========================================================================
|
||||
// DoBatchNormalizationStat() - implements CNTK "bnstat" command
|
||||
// ===========================================================================
|
||||
|
||||
template <typename ElemType>
|
||||
void DoBatchNormalizationStat(const ConfigParameters& config)
|
||||
{
|
||||
ConfigParameters readerConfig(config(L"reader"));
|
||||
readerConfig.Insert("traceLevel", config(L"traceLevel", "0"));
|
||||
|
||||
auto dataReader = make_shared<DataReader>(readerConfig);
|
||||
|
||||
int traceLevel = config(L"traceLevel", "0");
|
||||
int itersPerNode = config(L"itersPerNode", 30);
|
||||
|
||||
ConfigArray minibatchSize = config(L"minibatchSize", "40960");
|
||||
intargvector mbSize = minibatchSize;
|
||||
|
||||
bool enableDistributedMBReading = config(L"enableDistributedMBReading", false);
|
||||
|
||||
wstring curModelPath = config(L"modelPath", L"");
|
||||
wstring newModelPath = config(L"newModelPath", L"");
|
||||
if (newModelPath == L"")
|
||||
{
|
||||
newModelPath = curModelPath + L".PBN";
|
||||
}
|
||||
|
||||
std::vector<std::wstring> evalNodeNames;
|
||||
let net = GetModelFromConfig<ConfigParameters, ElemType>(config, L"evalNodeNames", evalNodeNames);
|
||||
// set tracing flags
|
||||
net->EnableNodeTracing(config(L"traceNodeNamesReal", ConfigParameters::Array(stringargvector())),
|
||||
config(L"traceNodeNamesCategory", ConfigParameters::Array(stringargvector())),
|
||||
config(L"traceNodeNamesSparse", ConfigParameters::Array(stringargvector())));
|
||||
|
||||
PostComputingActions<ElemType> postComputingActions(net, MPIWrapper::GetInstance(), enableDistributedMBReading, traceLevel);
|
||||
|
||||
postComputingActions.BatchNormalizationStatistics(dataReader.get(), evalNodeNames, newModelPath, mbSize[0], itersPerNode);
|
||||
}
|
||||
|
||||
template void DoBatchNormalizationStat<double>(const ConfigParameters& config);
|
||||
template void DoBatchNormalizationStat<float>(const ConfigParameters& config);
|
||||
|
||||
|
|
|
@ -154,7 +154,7 @@ static void DisableLegacyUsage(const ConfigParameters& TopLevelConfig, const Con
|
|||
|
||||
// When running in parallel with MPI, only commands in 'commandstoRunOnAllRanks' should
|
||||
// be run in parallel across multiple ranks. Others should only run on rank 0
|
||||
const std::set<std::string> commandstoRunOnAllRanks = { "train", "trainRNN", "adapt", "test", "eval", "cv", "devtest", "pbn" };
|
||||
const std::set<std::string> commandstoRunOnAllRanks = { "train", "trainRNN", "adapt", "test", "eval", "cv", "devtest", "bnstat" };
|
||||
|
||||
// process the command
|
||||
template <typename ElemType>
|
||||
|
@ -243,9 +243,9 @@ void DoCommands(const ConfigParameters& config, const shared_ptr<MPIWrapper>& mp
|
|||
LOGPRINTF(stderr, "CNTKCommandTrainEnd: %s\n", command[i].c_str());
|
||||
fullEpochsOffset += GetMaxEpochs(commandParams);
|
||||
}
|
||||
else if (thisAction == "pbn")
|
||||
else if (thisAction == "bnstat")
|
||||
{
|
||||
DoEvalBN<ElemType>(commandParams);
|
||||
DoBatchNormalizationStat<ElemType>(commandParams);
|
||||
}
|
||||
else if (thisAction == "adapt")
|
||||
{
|
||||
|
|
|
@ -136,10 +136,6 @@ public:
|
|||
// main entry point for backprop
|
||||
void Backprop(const ComputationNodeBasePtr rootNode);
|
||||
|
||||
// partial forward entry
|
||||
void ForwardProp(const ComputationNodeBasePtr rootNode, const ComputationNodeBasePtr startNode,
|
||||
const ComputationNodeBasePtr endNode);
|
||||
|
||||
template <class NODESET> // version that takes multiple nodes
|
||||
void ForwardProp(const NODESET& nodes)
|
||||
{
|
||||
|
@ -678,6 +674,44 @@ public:
|
|||
return nodesWithType;
|
||||
}
|
||||
|
||||
// Get the eval nodes with names
|
||||
// if evalNodeNames are not specified, return all the default evalnodes and training criterion nodes.
|
||||
std::vector<ComputationNodeBasePtr> GetEvalNodesWithName(const std::vector<wstring> evalNodeNames)
|
||||
{
|
||||
// determine nodes to evaluate
|
||||
std::vector<ComputationNodeBasePtr> evalNodes;
|
||||
|
||||
set<ComputationNodeBasePtr> criteriaLogged; // (keeps track ot duplicates to avoid we don't double-log critera)
|
||||
if (evalNodeNames.size() == 0)
|
||||
{
|
||||
fprintf(stderr, "evalNodeNames are not specified, using all the default evalnodes and training criterion nodes.\n");
|
||||
if (EvaluationNodes().empty() && FinalCriterionNodes().empty())
|
||||
InvalidArgument("There is no default evaluation node or training criterion specified in the network.");
|
||||
|
||||
for (const auto& node : EvaluationNodes())
|
||||
if (criteriaLogged.insert(node).second)
|
||||
evalNodes.push_back(node);
|
||||
|
||||
for (const auto& node : FinalCriterionNodes())
|
||||
if (criteriaLogged.insert(node).second)
|
||||
evalNodes.push_back(node);
|
||||
}
|
||||
else
|
||||
{
|
||||
for (int i = 0; i < evalNodeNames.size(); i++)
|
||||
{
|
||||
const auto& node = GetNodeFromName(evalNodeNames[i]);
|
||||
if (!criteriaLogged.insert(node).second)
|
||||
continue;
|
||||
if (node->GetSampleLayout().GetNumElements() != 1)
|
||||
InvalidArgument("Criterion nodes to evaluate must have dimension 1x1.");
|
||||
evalNodes.push_back(node);
|
||||
}
|
||||
}
|
||||
|
||||
return evalNodes;
|
||||
}
|
||||
|
||||
public:
|
||||
// return list of nodes that require precomputation and not precomputed yet
|
||||
std::list<ComputationNodeBasePtr> GetNodesRequiringPreComputation(const ComputationNodeBasePtr& rootNode = nullptr, bool checkComputed = true);
|
||||
|
@ -1039,8 +1073,6 @@ protected:
|
|||
virtual void RequestMatricesBeforeBackprop(MatrixPool& matrixPool);
|
||||
virtual void ReleaseMatricesAfterBackprop(MatrixPool& matrixPool);
|
||||
|
||||
virtual void ForwardProp(const FrameRange&, const ComputationNodeBasePtr, const ComputationNodeBasePtr) override;
|
||||
|
||||
public:
|
||||
// this special constructor constructs the top-level network node
|
||||
// There is currently no other constructor for inner nested PAR-traversed sub-networks, but there will be.
|
||||
|
|
|
@ -79,17 +79,6 @@ void ComputationNetwork::Backprop(const ComputationNodeBasePtr rootNode) // trai
|
|||
GetNestedNetwork(rootNode)->Backprop(FrameRange(nullptr), true, true);
|
||||
}
|
||||
|
||||
void ComputationNetwork::ForwardProp(const ComputationNodeBasePtr rootNode, const ComputationNodeBasePtr startNode, const ComputationNodeBasePtr endNode)
|
||||
{
|
||||
VerifyIsCompiled("ForwardProp");
|
||||
|
||||
// traverse partial nodes as inputs
|
||||
shared_ptr<FlowControlNode> network = dynamic_pointer_cast<FlowControlNode>(GetNestedNetwork(rootNode));
|
||||
assert(network);
|
||||
|
||||
network->ForwardProp(FrameRange(nullptr), startNode, endNode);
|
||||
}
|
||||
|
||||
void ComputationNetwork::FormNestedNetwork(const ComputationNodeBasePtr& rootNode)
|
||||
{
|
||||
if (m_nestedNetworks.find(rootNode) != m_nestedNetworks.end())
|
||||
|
@ -158,7 +147,6 @@ ComputationNetwork::PARTraversalFlowControlNode::PARTraversalFlowControlNode(con
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
/*virtual*/ void ComputationNetwork::PARTraversalFlowControlNode::Backprop(const FrameRange& fr, bool childrenInThisLoop, bool childrenInOuterLoop) /*override*/
|
||||
{
|
||||
childrenInThisLoop, childrenInOuterLoop; // TODO: think through what these mean when coming from PAR mode
|
||||
|
@ -187,36 +175,7 @@ ComputationNetwork::PARTraversalFlowControlNode::PARTraversalFlowControlNode(con
|
|||
/*virtual*/ void ComputationNetwork::PARTraversalFlowControlNode::ReleaseMatricesAfterBackprop(MatrixPool& matrixPool) /*override*/
|
||||
{
|
||||
}
|
||||
/*virtual*/ void ComputationNetwork::PARTraversalFlowControlNode::ForwardProp(const FrameRange & fr, ComputationNodeBasePtr startNode, ComputationNodeBasePtr endNode)
|
||||
{
|
||||
// if start node is nullptr, forward will be enable
|
||||
bool enableForward = startNode ? false : true;
|
||||
|
||||
for (auto& node : m_nestedNodes)
|
||||
{
|
||||
#if 0
|
||||
if (dynamic_pointer_cast<LearnableParameter<float>>(node))
|
||||
dynamic_pointer_cast<ComputationNode<float>>(node)->DebugLogMinibatch();
|
||||
#endif
|
||||
if (node->IsOutOfDateWrtInputs() && enableForward)
|
||||
{
|
||||
node->BeginForwardProp();
|
||||
node->ForwardProp(fr.WithLayout(node->GetMBLayout()));
|
||||
node->EndForwardProp();
|
||||
|
||||
node->BumpEvalTimeStamp();
|
||||
}
|
||||
|
||||
if (node == startNode)
|
||||
{
|
||||
enableForward = true;
|
||||
}
|
||||
else if (node == endNode)
|
||||
{
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// -----------------------------------------------------------------------
|
||||
// SEQTraversalFlowControlNode methods -- implements SEQ traversal (loop unrolling)
|
||||
|
|
|
@ -1878,10 +1878,6 @@ public:
|
|||
virtual void DumpNodeInfo(const bool /*printValues*/, const bool /*printMetadata*/, File& fstream) const override {}
|
||||
virtual std::set<std::pair<const MatrixBase*, std::wstring>> GetMatrixInfo() const override { NOT_IMPLEMENTED; }
|
||||
|
||||
virtual void ForwardProp(const FrameRange&, const ComputationNodeBasePtr, const ComputationNodeBasePtr) { NOT_IMPLEMENTED; }
|
||||
|
||||
std::vector<ComputationNodeBasePtr> GetNestedNodes() { return m_nestedNodes; }
|
||||
|
||||
protected: public: // needed in ComputationNetwork::FindInRecurrentLoops(), which really should be part of SEQTraversalFlowControlNode
|
||||
std::vector<ComputationNodeBasePtr> m_nestedNodes; // nodes tucked away in this node, in evaluation order
|
||||
};
|
||||
|
|
|
@ -37,6 +37,7 @@ public:
|
|||
MarkValueNonSharable();
|
||||
m_initString = L"fromValue"; // default init is with 0; typically overwritten
|
||||
m_initValue = 0;
|
||||
m_regMultiplier = 1.0f; // enable reg in update by default
|
||||
}
|
||||
LearnableParameter(DEVICEID_TYPE deviceId, const wstring& name, const TensorShape& shape) :
|
||||
LearnableParameter(deviceId, name)
|
||||
|
@ -101,6 +102,14 @@ public:
|
|||
// called from CloneFunction(..., parameters="constant")
|
||||
virtual void FreezeParameters() override; // from IFreezable
|
||||
|
||||
// Setting the reg multiplier for a learnable node, effecting L1Reg and L2Reg both.
|
||||
void SetRegMultiplier(float regMultiplier)
|
||||
{
|
||||
m_regMultiplier = regMultiplier;
|
||||
}
|
||||
// called from SGD UpdateWeights, to adjust the reg for each node
|
||||
float GetRegMultiplier() const { return m_regMultiplier; }
|
||||
|
||||
private:
|
||||
// init parameters for deferred initialization (which happens in Validate())
|
||||
std::wstring m_initString; // if non-empty then deferred initialization is needed. Gets cleared upon completion of deferred init.
|
||||
|
@ -109,6 +118,9 @@ private:
|
|||
int m_initOutputRank;
|
||||
bool m_initOnCPUOnly;
|
||||
ElemType m_initValue;
|
||||
|
||||
// flags related to gradient update
|
||||
float m_regMultiplier; // The multiplier to adjust the L1Reg and L2Reg for Learnable node
|
||||
};
|
||||
|
||||
// -----------------------------------------------------------------------
|
||||
|
|
|
@ -8,6 +8,7 @@
|
|||
#include "ComputationNode.h"
|
||||
#include "BatchNormalizationEngine.h"
|
||||
#include "RNGHandle.h"
|
||||
#include "InputAndParamNodes.h"
|
||||
|
||||
#define __STDC_FORMAT_MACROS
|
||||
#include <inttypes.h>
|
||||
|
@ -1587,15 +1588,15 @@ class BatchNormalizationNode : public ComputationNodeNonLooping<ElemType>, publi
|
|||
public:
|
||||
BatchNormalizationNode(DEVICEID_TYPE deviceId, const wstring& name) :
|
||||
Base(deviceId, name), m_spatial(false), m_normTimeConst(0), m_blendTimeConst(0), m_epsilon(0), m_useCntkEngine(true),
|
||||
m_samplesSeen(0), m_imageLayoutKind(ImageLayoutKind::CHW), m_postBatchNormalization(false), m_swapNormTimeConst(0),
|
||||
m_swapBlendTimeConst(0), m_convertRunningVariancePending(false)
|
||||
m_samplesSeen(0), m_imageLayoutKind(ImageLayoutKind::CHW),
|
||||
m_convertRunningVariancePending(false)
|
||||
{
|
||||
}
|
||||
BatchNormalizationNode(DEVICEID_TYPE deviceId, const wstring& name, bool spatial, double normalizationTimeConstant, double blendTimeConstant,
|
||||
double epsilon, bool useCntkEngine, ImageLayoutKind imageLayoutKind) :
|
||||
Base(deviceId, name), m_spatial(spatial), m_normTimeConst(normalizationTimeConstant), m_blendTimeConst(blendTimeConstant),
|
||||
m_epsilon(epsilon), m_useCntkEngine(useCntkEngine), m_imageLayoutKind(imageLayoutKind), m_samplesSeen(0), m_postBatchNormalization(false),
|
||||
m_swapNormTimeConst(0), m_swapBlendTimeConst(0), m_convertRunningVariancePending(false)
|
||||
m_epsilon(epsilon), m_useCntkEngine(useCntkEngine), m_imageLayoutKind(imageLayoutKind), m_samplesSeen(0),
|
||||
m_convertRunningVariancePending(false)
|
||||
{
|
||||
}
|
||||
BatchNormalizationNode(const ScriptableObjects::IConfigRecordPtr configp) :
|
||||
|
@ -1605,9 +1606,6 @@ public:
|
|||
ImageLayoutKindFrom(configp->Get(L"imageLayout")))
|
||||
{
|
||||
AttachInputsFromConfig(configp, this->GetExpectedNumInputs());
|
||||
m_postBatchNormalization = false;
|
||||
m_swapNormTimeConst = 0;
|
||||
m_swapBlendTimeConst = 0;
|
||||
}
|
||||
|
||||
void Save(File& fstream) const override
|
||||
|
@ -1724,7 +1722,7 @@ private: // time-constant conversions
|
|||
double ComputeExpAvgFactor() const
|
||||
{
|
||||
// in inference mode, only use long-term mean and do not update running estimates
|
||||
if (!Environment().IsTraining() && !m_postBatchNormalization)
|
||||
if (!Environment().IsTraining())
|
||||
{
|
||||
if (m_samplesSeen == 0)
|
||||
RuntimeError("%ls: inference mode is used, but nothing has been trained.", NodeName().c_str());
|
||||
|
@ -1756,7 +1754,7 @@ private: // time-constant conversions
|
|||
double ComputeBlendFactor() const
|
||||
{
|
||||
// in inference mode, only use long-term mean and do not update running estimates
|
||||
if (!Environment().IsTraining() && !m_postBatchNormalization)
|
||||
if (!Environment().IsTraining())
|
||||
{
|
||||
if (m_samplesSeen == 0)
|
||||
RuntimeError("%ls: inference mode is used, but nothing has been trained.", NodeName().c_str());
|
||||
|
@ -1805,7 +1803,7 @@ public:
|
|||
// In inference-only mode, m_savedMean and m_saveInvStdDev will not be
|
||||
// produced and BackpropToNonLooping() may not be called. In
|
||||
// non-inference (training) mode, saved statistics must be produced.
|
||||
bool inferenceOnly = !Environment().IsTraining() && !m_postBatchNormalization;
|
||||
bool inferenceOnly = !Environment().IsTraining();
|
||||
m_bnEng->Forward(/*in=*/ sliceInputValue, scale, bias, // (in)
|
||||
inferenceOnly, expAvgFactor, blendFactor,
|
||||
runMean, runVariance, // (in/out) running estimates, updated from the current MB mean/variance
|
||||
|
@ -1870,14 +1868,6 @@ public:
|
|||
}
|
||||
|
||||
virtual void EndForwardProp() override
|
||||
{
|
||||
if(m_postBatchNormalization)
|
||||
m_samplesSeen += GetMBLayout()->GetActualNumSamples();
|
||||
|
||||
Base::EndForwardProp();
|
||||
}
|
||||
|
||||
virtual void EndBackprop() override
|
||||
{
|
||||
// Update samples if not locked.
|
||||
double expAvgFactor = ComputeExpAvgFactor(); // weight for the new MB statistics in the running estimate. The previous value of the running statistics is kept with weight (1-this)
|
||||
|
@ -2019,28 +2009,29 @@ public:
|
|||
m_blendTimeConst = std::numeric_limits<double>::infinity();
|
||||
}
|
||||
|
||||
// ResetStatisticsState will set the batch normal statistics into initial state
|
||||
// used for re-statistics the mean and variance of BN
|
||||
// any others use may lead undependable results, please be careful
|
||||
void ResetStatisticsState()
|
||||
{
|
||||
m_samplesSeen = 0;
|
||||
m_normTimeConst = 0;
|
||||
m_blendTimeConst = 0;
|
||||
}
|
||||
// Turn off the L1 and L2 regularization
|
||||
void DisableRegInBatchNormalization()
|
||||
{
|
||||
let scaleNode = dynamic_pointer_cast<LearnableParameter<ElemType>>(Input(1));
|
||||
let biasNode = dynamic_pointer_cast<LearnableParameter<ElemType>>(Input(2));
|
||||
scaleNode->SetRegMultiplier(0.f);
|
||||
biasNode->SetRegMultiplier(0.f);
|
||||
}
|
||||
double NormalizationTimeConstant() const { return m_normTimeConst; }
|
||||
double BlendTimeConstant() const { return m_blendTimeConst; }
|
||||
bool Spatial() const { return m_spatial; }
|
||||
double Epsilon() const { return m_epsilon; }
|
||||
bool UseCNTKEngine() const { return m_useCntkEngine; }
|
||||
|
||||
void SetPostBatchNormalizationBegin()
|
||||
{
|
||||
m_postBatchNormalization = true;
|
||||
m_samplesSeen = 0;
|
||||
m_swapNormTimeConst = m_normTimeConst;
|
||||
m_swapBlendTimeConst = m_blendTimeConst;
|
||||
m_normTimeConst = -1;
|
||||
m_blendTimeConst = 0;
|
||||
}
|
||||
void SetPostBatchNormalizationEnd()
|
||||
{
|
||||
m_postBatchNormalization = false;
|
||||
m_normTimeConst = m_swapNormTimeConst;
|
||||
m_blendTimeConst = m_swapBlendTimeConst;
|
||||
}
|
||||
|
||||
private:
|
||||
// Old versioning - do not use. Do not remove until we're sure there are no old models around.
|
||||
struct VersionInfo
|
||||
|
@ -2104,11 +2095,6 @@ private:
|
|||
|
||||
std::unique_ptr<BatchNormEngine<ElemType>> m_bnEng;
|
||||
|
||||
// post batch normalization process mark
|
||||
bool m_postBatchNormalization;
|
||||
|
||||
double m_swapNormTimeConst;
|
||||
double m_swapBlendTimeConst;
|
||||
bool m_convertRunningVariancePending;
|
||||
};
|
||||
|
||||
|
|
|
@ -0,0 +1,161 @@
|
|||
//
|
||||
// Copyright (c) Microsoft. All rights reserved.
|
||||
// Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
|
||||
//
|
||||
// PostStat.cpp -- CNTK post statistics related actions
|
||||
//
|
||||
|
||||
#include "PostComputingActions.h"
|
||||
|
||||
#include "TrainingNodes.h"
|
||||
#include "ProgressTracing.h"
|
||||
#include "DataReaderHelpers.h"
|
||||
#include "SimpleDistGradAggregator.h"
|
||||
|
||||
#include <vector>
|
||||
|
||||
namespace Microsoft { namespace MSR{ namespace CNTK {
|
||||
|
||||
template <class ElemType>
|
||||
void PostComputingActions<ElemType>::BatchNormalizationStatistics(IDataReader * dataReader, const vector<wstring>& evalNodeNames,
|
||||
const wstring newModelPath, const size_t mbSize, const int iters)
|
||||
{
|
||||
// since the mean and variance of bn will be modified in statistics,
|
||||
// training mode will make it work. And there is no back prop, other parameters
|
||||
// are fixed during computing.
|
||||
ScopedNetworkOperationMode modeGuard(m_net, NetworkOperationMode::training);
|
||||
|
||||
// bn nodes need to be computed from bottom to top with evaluating order
|
||||
let evalNodes = m_net->GetEvalNodesWithName(evalNodeNames);
|
||||
|
||||
// find all the BN nodes by evalOrder
|
||||
std::vector<ComputationNodeBasePtr> bnNodes;
|
||||
std::set<ComputationNodeBasePtr> bnNodesLogged; // (avoid double record of batch normalization nodes)
|
||||
for (auto& evalNode : evalNodes)
|
||||
{
|
||||
for (auto& node : m_net->GetEvalOrder(evalNode))
|
||||
{
|
||||
let bnNode = dynamic_pointer_cast<BatchNormalizationNode<ElemType>>(node);
|
||||
if (bnNode)
|
||||
{
|
||||
if (bnNodesLogged.insert(node).second)
|
||||
{
|
||||
// reset the statistics states of bn nodes
|
||||
bnNode->ResetStatisticsState();
|
||||
bnNode->SetNormalizationTimeConstants(-1, bnNode->NormalizationTimeConstant(),
|
||||
0, bnNode->BlendTimeConstant());
|
||||
bnNodes.push_back(node);
|
||||
// add BN nodes into the evaluation group, then they will be added into root nodes when
|
||||
// the network re-compile
|
||||
m_net->AddToNodeGroup(L"evaluation", bnNode);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// re-compile the network to add bn nodes as rootNodes.
|
||||
m_net->CompileNetwork();
|
||||
|
||||
// allocate memory for all bnNodes evalOrder
|
||||
m_net->AllocateAllMatrices(bnNodes, std::vector<ComputationNodeBasePtr>(), nullptr);
|
||||
|
||||
// prepare features
|
||||
auto& featureNodes = m_net->FeatureNodes();
|
||||
|
||||
StreamMinibatchInputs inputMatrices;
|
||||
for (auto& node : featureNodes)
|
||||
inputMatrices.AddInput(node->NodeName(), node->ValuePtr(), node->GetMBLayout(), node->GetSampleLayout());
|
||||
|
||||
bool useParallelTrain = (m_mpi != nullptr);
|
||||
bool useDistributedMBReading = useParallelTrain && m_enableDistributedMBReading && dataReader->SupportsDistributedMBRead();
|
||||
size_t totalEpochSize = bnNodes.size() * mbSize * iters;
|
||||
|
||||
m_net->StartEvaluateMinibatchLoop(bnNodes);
|
||||
|
||||
if (useDistributedMBReading)
|
||||
dataReader->StartDistributedMinibatchLoop(mbSize, 0, m_mpi->CurrentNodeRank(), m_mpi->NumNodesInUse(), totalEpochSize);
|
||||
else
|
||||
dataReader->StartMinibatchLoop(mbSize, 0, totalEpochSize);
|
||||
|
||||
for (auto& node : bnNodes)
|
||||
{
|
||||
let bnNode = static_pointer_cast<BatchNormalizationNode<ElemType>>(node);
|
||||
size_t actualMBSize = 0;
|
||||
|
||||
LOGPRINTF(stderr, "Estimating Statistics --> %ls\n", bnNode->GetName().c_str());
|
||||
|
||||
|
||||
// for every single bn node, the statistics is the average of mean and variance for several times in forward prop
|
||||
// the forward prop is from the feature to the current bn node
|
||||
for (int iter = 0; iter < iters; iter++)
|
||||
{
|
||||
// during the bn stat, dataRead must be ensured
|
||||
bool wasDataRead = DataReaderHelpers::GetMinibatchIntoNetwork<ElemType>(*dataReader, m_net,
|
||||
nullptr, useDistributedMBReading, useParallelTrain, inputMatrices, actualMBSize, m_mpi);
|
||||
|
||||
if (!wasDataRead) LogicError("DataRead Failure in batch normalization statistics");
|
||||
|
||||
ComputationNetwork::BumpEvalTimeStamp(featureNodes);
|
||||
|
||||
// forward prop till reaching the current bn node
|
||||
m_net->ForwardProp(node);
|
||||
}
|
||||
|
||||
// after finished statistics, the mean and variance of the bn node should be freezd.
|
||||
bnNode->FreezeParameters();
|
||||
|
||||
// Sync during or after all iters of a BN node are equivalent
|
||||
if (useParallelTrain)
|
||||
{
|
||||
if (m_gradHeader == nullptr)
|
||||
{
|
||||
m_gradHeader.reset(DistGradHeader::Create(evalNodes.size()), [](DistGradHeader* ptr)
|
||||
{
|
||||
DistGradHeader::Destroy(ptr);
|
||||
});
|
||||
}
|
||||
|
||||
// push the statistics results of mean and variance of bn nodes into mpi updating vector
|
||||
std::vector<Matrix<ElemType>*> learnParamsValues(2, nullptr);
|
||||
|
||||
SimpleDistGradAggregator<ElemType> distGradAgg(m_mpi, false /*useAsyncAggregation*/, 0 /*syncStatsTrace*/);
|
||||
|
||||
auto runMeanParameterPtr = node->Input(3);
|
||||
auto runStdParameterPtr = node->Input(4);
|
||||
|
||||
shared_ptr<ComputationNode<ElemType>> runMeanNode = static_pointer_cast<ComputationNode<ElemType>>(runMeanParameterPtr);
|
||||
shared_ptr<ComputationNode<ElemType>> runStdNode = static_pointer_cast<ComputationNode<ElemType>>(runStdParameterPtr);
|
||||
|
||||
learnParamsValues[0] = &(runMeanNode->Value());
|
||||
learnParamsValues[1] = &(runStdNode->Value());
|
||||
|
||||
m_gradHeader->numSamples = actualMBSize ? 1 : actualMBSize;
|
||||
distGradAgg.AggregateGradients(learnParamsValues, m_gradHeader.get(), 0);
|
||||
|
||||
// get the average mean and variance across all the workers
|
||||
for (auto& parameter : learnParamsValues)
|
||||
{
|
||||
(*parameter) /= (ElemType)m_mpi->NumNodesInUse();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
dataReader->DataEnd();
|
||||
|
||||
// remove all the added BN nodes from evaluation group
|
||||
for (auto& bnNode : bnNodes)
|
||||
{
|
||||
m_net->RemoveFromNodeGroup(L"evaluation", bnNode);
|
||||
}
|
||||
|
||||
// save model
|
||||
if (!useParallelTrain || m_mpi->CurrentNodeRank() == m_mpi->MainNodeRank())
|
||||
m_net->Save(newModelPath);
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
template class PostComputingActions<float>;
|
||||
template class PostComputingActions<double>;
|
||||
|
||||
}}}
|
|
@ -0,0 +1,65 @@
|
|||
//
|
||||
// Copyright (c) Microsoft. All rights reserved.
|
||||
// Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
|
||||
//
|
||||
// PostStat.h -- CNTK post statistics related actions
|
||||
//
|
||||
|
||||
#pragma once
|
||||
#include "ComputationNode.h"
|
||||
#include "ComputationNetwork.h"
|
||||
#include "MPIWrapper.h"
|
||||
#include "DataReader.h"
|
||||
#include "IDistGradAggregator.h"
|
||||
#include "DistGradHeader.h"
|
||||
|
||||
using namespace std;
|
||||
|
||||
namespace Microsoft { namespace MSR { namespace CNTK {
|
||||
|
||||
template <class ElemType>
|
||||
class IDistGradAggregator;
|
||||
|
||||
// Post statistics normally called between training and evaluating, to generate the statistics results used by evaluating
|
||||
// For now, the application is only with statistics mean and variance of Batch Normalization nodes after training
|
||||
template <class ElemType>
|
||||
class PostComputingActions
|
||||
{
|
||||
public:
|
||||
PostComputingActions(ComputationNetworkPtr net, const MPIWrapperPtr& mpi, bool enableDistributedMBReading = false, const int traceLevel = 0) :
|
||||
m_net(net),
|
||||
m_traceLevel(traceLevel),
|
||||
m_mpi(mpi),
|
||||
m_distGradAgg(nullptr),
|
||||
m_gradHeader(nullptr),
|
||||
m_enableDistributedMBReading(enableDistributedMBReading)
|
||||
{
|
||||
}
|
||||
|
||||
// This function is used for evaluating the mean and variance of all batch normalization nodes after training.
|
||||
// Details will link to the wiki https://github.com/Microsoft/CNTK/wiki/Post-Batch-Normalization-Statistics
|
||||
// The reason why put it into evalute is the action take place after trainning and non-backprop processing, which makes me believe
|
||||
// this function is like a kind of evaluate function.
|
||||
// In this function,
|
||||
// 1. since all other weights are fix except the un-pbn nodes, I set the networkoperationMode into inferring.
|
||||
// 2. The next thing is to load the network model and data source, I follow the Evaluate function to do so, however, I delete something
|
||||
// seem useless, like error statistics etc.
|
||||
// 3. Finding the BN nodes in the network and put them into a vector with evaluate order (This links the nestedNode vector I got in
|
||||
// ControlFlowNetwork)
|
||||
// 4. From node to node in the BN vector to generate the mean and various (This links to the changes of BatchNormalizationNode
|
||||
// in TrainingNodes.h, since I need to make the nodes "learn" mean and variance in inferring mode)
|
||||
// 5. Consider the multi-GPU, we need to sync up the BN results between all the worker and average the value.
|
||||
void BatchNormalizationStatistics(IDataReader* dataReader, const vector<wstring>& evalNodeNames, const wstring newModelPath,
|
||||
const size_t mbSize, const int iters = 30);
|
||||
|
||||
private:
|
||||
ComputationNetworkPtr m_net;
|
||||
MPIWrapperPtr m_mpi;
|
||||
bool m_enableDistributedMBReading;
|
||||
|
||||
int m_traceLevel;
|
||||
|
||||
std::shared_ptr<IDistGradAggregator<ElemType>> m_distGradAgg;
|
||||
std::shared_ptr<struct DistGradHeader> m_gradHeader;
|
||||
};
|
||||
}}}
|
|
@ -8,6 +8,7 @@
|
|||
#include "SpecialPurposeNodes.h" // for SequenceWithSoftmaxNode
|
||||
#include "DataReaderHelpers.h"
|
||||
#include "MatrixQuantizerImpl.h"
|
||||
#include "InputAndParamNodes.h"
|
||||
|
||||
#ifdef CNTK_PARALLEL_TRAINING_SUPPORT
|
||||
//static inline bool operator==(const std::pair<double,size_t>& a, double b) { assert(b==0); return a.first == b; }
|
||||
|
@ -875,23 +876,13 @@ size_t SGD<ElemType>::TrainOneEpoch(ComputationNetworkPtr net,
|
|||
EpochCriterion epochCriterionLastLogged = epochCriterion;
|
||||
vector<EpochCriterion> epochEvalErrorsLastLogged = epochEvalErrors;
|
||||
|
||||
// Now, we need to use a switch to enable/disable wk in BatchNormalization.
|
||||
// If we can determine whether wk added or not for each node, then, discard this
|
||||
std::unordered_set<ComputationNodeBasePtr> batchNormalizationWeights;
|
||||
if (m_disableWkInBatchNormal) {
|
||||
for (auto& evalNode : evaluationNodes)
|
||||
// NOTE: For ResNet, the regularization in BatchNormalization should be disable.
|
||||
if (m_disableRegInBatchNormalization) {
|
||||
let bnNodes = net->GetNodesWithType(L"BatchNormalization");
|
||||
for (auto &node : bnNodes)
|
||||
{
|
||||
shared_ptr<FlowControlNode> nestedNetwork = static_pointer_cast<FlowControlNode>(net->GetNestedNetwork(evalNode));
|
||||
for (auto& node : nestedNetwork->GetNestedNodes())
|
||||
{
|
||||
shared_ptr<BatchNormalizationNode<ElemType>> castNode =
|
||||
dynamic_pointer_cast<BatchNormalizationNode<ElemType>>(node);
|
||||
if (castNode)
|
||||
{
|
||||
batchNormalizationWeights.insert(castNode->GetInputs()[1]);
|
||||
batchNormalizationWeights.insert(castNode->GetInputs()[2]);
|
||||
}
|
||||
}
|
||||
let bnNode = dynamic_pointer_cast<BatchNormalizationNode<ElemType>>(node);
|
||||
bnNode->DisableRegInBatchNormalization();
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1110,11 +1101,10 @@ size_t SGD<ElemType>::TrainOneEpoch(ComputationNetworkPtr net,
|
|||
if (smoothedGradient.HasNan("TrainOneEpoch/UpdateWeights(): "))
|
||||
LogicError("%ls %ls operation has NaNs in smoothedGradient.", node->NodeName().c_str(), node->OperationName().c_str());
|
||||
#endif
|
||||
double l2Factor = batchNormalizationWeights.find(node) == batchNormalizationWeights.end() ? 1.0 : 0.0;
|
||||
// BUGBUG (Issue #95): Access to net MBLayout can no longer be done if we have multiple input layouts
|
||||
UpdateWeights(node, smoothedGradient, learnRatePerSample,
|
||||
GetMomentumPerSample(epochNumber /*BUGBUG workaround:*/, net->GetMBLayoutPtrOfNetwork()->GetNumParallelSequences()), numSamplesInMinibatch,
|
||||
m_L2RegWeight * l2Factor, m_L1RegWeight,
|
||||
m_L2RegWeight, m_L1RegWeight,
|
||||
m_needAveMultiplier, m_useNesterovMomentum);
|
||||
#ifdef _DEBUG
|
||||
if (dynamic_pointer_cast<ComputationNode<ElemType>>(node)->Value().HasNan("TrainOneEpoch/UpdateWeights(): "))
|
||||
|
@ -2017,9 +2007,10 @@ void SGD<ElemType>::UpdateWeights(const ComputationNodeBasePtr& node,
|
|||
LogicError("UpdateWeights() called for a learnable ComputationNode which has m_learningRateMultiplier == 0!");
|
||||
|
||||
double nodeDependentLearningRatePerSample = learnRatePerSample * node->GetLearningRateMultiplier();
|
||||
double nodeDependentRegMultiplier = dynamic_pointer_cast<LearnableParameter<ElemType>>(node)->GetRegMultiplier();
|
||||
UpdateWeightsS(this, dynamic_pointer_cast<ComputationNode<ElemType>>(node)->Value(), dynamic_pointer_cast<ComputationNode<ElemType>>(node)->Gradient(),
|
||||
smoothedGradient, nodeDependentLearningRatePerSample, momentumPerSample,
|
||||
actualMBSize, L2RegWeight, L1RegWeight,
|
||||
actualMBSize, L2RegWeight * nodeDependentRegMultiplier, L1RegWeight * nodeDependentRegMultiplier,
|
||||
needAveMultiplier, m_useNesterovMomentum);
|
||||
node->BumpEvalTimeStamp();
|
||||
}
|
||||
|
@ -2475,7 +2466,7 @@ SGDParams::SGDParams(const ConfigRecordType& configSGD, size_t sizeofElemType)
|
|||
m_seqGammarCalcbMMIFactor = configSGD(L"seqGammarBMMIFactor", 0.0);
|
||||
m_seqGammarCalcWP = configSGD(L"seqGammarWordPen", 0.0);
|
||||
|
||||
m_disableWkInBatchNormal = configSGD(L"disableWkInBatchNormal", false);
|
||||
m_disableRegInBatchNormalization = configSGD(L"disableRegInBatchNormalization", false);
|
||||
|
||||
m_dropoutRates = configSGD(L"dropoutRate", ConfigRecordType::Array(doubleargvector(vector<double>{0.0})));
|
||||
m_batchNormalizationTimeConstant = configSGD(L"batchNormalizationTimeConstant", ConfigRecordType::Array(doubleargvector(vector<double>{0})));
|
||||
|
|
|
@ -291,7 +291,10 @@ protected:
|
|||
double m_seqGammarCalcbMMIFactor;
|
||||
bool m_seqGammarCalcUsesMBR;
|
||||
|
||||
bool m_disableWkInBatchNormal;
|
||||
// decide whether should apply L2 regularization into BatchNormalizationNode
|
||||
// true: disable L2 Regularization
|
||||
// false: enable L2 Regularization (default)
|
||||
bool m_disableRegInBatchNormalization;
|
||||
};
|
||||
|
||||
template <class ElemType>
|
||||
|
|
|
@ -124,6 +124,7 @@
|
|||
<ClInclude Include="..\ComputationNetworkLib\NonlinearityNodes.h" />
|
||||
<ClInclude Include="..\ComputationNetworkLib\RecurrentNodes.h" />
|
||||
<ClInclude Include="MASGD.h" />
|
||||
<ClInclude Include="PostComputingActions.h" />
|
||||
<ClInclude Include="SimpleDistGradAggregator.h" />
|
||||
<ClInclude Include="SimpleEvaluator.h" />
|
||||
<ClInclude Include="SimpleOutputWriter.h" />
|
||||
|
@ -132,6 +133,7 @@
|
|||
<ClInclude Include="targetver.h" />
|
||||
</ItemGroup>
|
||||
<ItemGroup>
|
||||
<ClCompile Include="PostComputingActions.cpp" />
|
||||
<ClCompile Include="Profiler.cpp" />
|
||||
<ClCompile Include="SGD.cpp" />
|
||||
<ClCompile Include="stdafx.cpp" />
|
||||
|
|
|
@ -1,32 +1,17 @@
|
|||
<?xml version="1.0" encoding="utf-8"?>
|
||||
<Project ToolsVersion="4.0" xmlns="http://schemas.microsoft.com/developer/msbuild/2003">
|
||||
<ItemGroup>
|
||||
<ClCompile Include="..\Common\DataReader.cpp">
|
||||
<Filter>Common</Filter>
|
||||
</ClCompile>
|
||||
<ClCompile Include="..\Common\DataWriter.cpp">
|
||||
<Filter>Common</Filter>
|
||||
</ClCompile>
|
||||
<ClCompile Include="..\Common\File.cpp">
|
||||
<Filter>Common</Filter>
|
||||
</ClCompile>
|
||||
<ClCompile Include="..\Common\fileutil.cpp">
|
||||
<Filter>Common</Filter>
|
||||
</ClCompile>
|
||||
<ClCompile Include="stdafx.cpp">
|
||||
<Filter>Misc</Filter>
|
||||
</ClCompile>
|
||||
<ClCompile Include="..\Common\TimerUtility.cpp">
|
||||
<Filter>Common</Filter>
|
||||
</ClCompile>
|
||||
<ClCompile Include="Profiler.cpp">
|
||||
<Filter>GPU Interfacing</Filter>
|
||||
</ClCompile>
|
||||
<ClCompile Include="SGD.cpp">
|
||||
<Filter>SGD</Filter>
|
||||
</ClCompile>
|
||||
<ClCompile Include="..\Common\Config.cpp">
|
||||
<Filter>Common</Filter>
|
||||
<ClCompile Include="PostComputingActions.cpp">
|
||||
<Filter>Stat</Filter>
|
||||
</ClCompile>
|
||||
</ItemGroup>
|
||||
<ItemGroup>
|
||||
|
@ -144,6 +129,9 @@
|
|||
<ClInclude Include="Criterion.h">
|
||||
<Filter>SGD</Filter>
|
||||
</ClInclude>
|
||||
<ClInclude Include="PostComputingActions.h">
|
||||
<Filter>Stat</Filter>
|
||||
</ClInclude>
|
||||
</ItemGroup>
|
||||
<ItemGroup>
|
||||
<Filter Include="Common">
|
||||
|
@ -182,5 +170,8 @@
|
|||
<Filter Include="Data Reading">
|
||||
<UniqueIdentifier>{b866d513-7bd0-497c-98c2-f62dbcd4cde4}</UniqueIdentifier>
|
||||
</Filter>
|
||||
<Filter Include="Stat">
|
||||
<UniqueIdentifier>{f406217f-5a11-44ca-bb34-52254dbee8af}</UniqueIdentifier>
|
||||
</Filter>
|
||||
</ItemGroup>
|
||||
</Project>
|
|
@ -52,36 +52,7 @@ public:
|
|||
{
|
||||
ScopedNetworkOperationMode modeGuard(m_net, NetworkOperationMode::inferring);
|
||||
|
||||
// determine nodes to evaluate
|
||||
std::vector<ComputationNodeBasePtr> evalNodes;
|
||||
|
||||
set<ComputationNodeBasePtr> criteriaLogged; // (keeps track ot duplicates to avoid we don't double-log critera)
|
||||
if (evalNodeNames.size() == 0)
|
||||
{
|
||||
fprintf(stderr, "evalNodeNames are not specified, using all the default evalnodes and training criterion nodes.\n");
|
||||
if (m_net->EvaluationNodes().empty() && m_net->FinalCriterionNodes().empty())
|
||||
InvalidArgument("There is no default evaluation node or training criterion specified in the network.");
|
||||
|
||||
for (const auto& node : m_net->EvaluationNodes())
|
||||
if (criteriaLogged.insert(node).second)
|
||||
evalNodes.push_back(node);
|
||||
|
||||
for (const auto& node : m_net->FinalCriterionNodes())
|
||||
if (criteriaLogged.insert(node).second)
|
||||
evalNodes.push_back(node);
|
||||
}
|
||||
else
|
||||
{
|
||||
for (int i = 0; i < evalNodeNames.size(); i++)
|
||||
{
|
||||
const auto& node = m_net->GetNodeFromName(evalNodeNames[i]);
|
||||
if (!criteriaLogged.insert(node).second)
|
||||
continue;
|
||||
if (node->GetSampleLayout().GetNumElements() != 1)
|
||||
InvalidArgument("Criterion nodes to evaluate must have dimension 1x1.");
|
||||
evalNodes.push_back(node);
|
||||
}
|
||||
}
|
||||
let evalNodes = m_net->GetEvalNodesWithName(evalNodeNames);
|
||||
|
||||
// initialize eval results
|
||||
std::vector<EpochCriterion> evalResults(evalNodes.size(), EpochCriterion(0));
|
||||
|
@ -257,153 +228,6 @@ public:
|
|||
return evalResults;
|
||||
}
|
||||
|
||||
void EvaluateBN(IDataReader* dataReader, const vector<wstring>& evalNodeNames, const wstring exportPath, const size_t mbSize, const int iters = 240, const size_t testSize = requestDataSize)
|
||||
{
|
||||
ScopedNetworkOperationMode modeGuard(m_net, NetworkOperationMode::inferring);
|
||||
|
||||
// determine nodes to evaluate
|
||||
std::vector<ComputationNodeBasePtr> evalNodes;
|
||||
|
||||
set<ComputationNodeBasePtr> criteriaLogged; // (keeps track ot duplicates to avoid we don't double-log critera)
|
||||
if (evalNodeNames.size() == 0)
|
||||
{
|
||||
fprintf(stderr, "evalNodeNames are not specified, using all the default evalnodes and training criterion nodes.\n");
|
||||
if (m_net->EvaluationNodes().empty() && m_net->FinalCriterionNodes().empty())
|
||||
InvalidArgument("There is no default evaluation node or training criterion specified in the network.");
|
||||
|
||||
for (const auto& node : m_net->EvaluationNodes())
|
||||
if (criteriaLogged.insert(node).second)
|
||||
evalNodes.push_back(node);
|
||||
|
||||
for (const auto& node : m_net->FinalCriterionNodes())
|
||||
if (criteriaLogged.insert(node).second)
|
||||
evalNodes.push_back(node);
|
||||
}
|
||||
else
|
||||
{
|
||||
for (int i = 0; i < evalNodeNames.size(); i++)
|
||||
{
|
||||
const auto& node = m_net->GetNodeFromName(evalNodeNames[i]);
|
||||
if (!criteriaLogged.insert(node).second)
|
||||
continue;
|
||||
if (node->GetSampleLayout().GetNumElements() != 1)
|
||||
InvalidArgument("Criterion nodes to evaluate must have dimension 1x1.");
|
||||
evalNodes.push_back(node);
|
||||
}
|
||||
}
|
||||
|
||||
// allocate memory for forward computation
|
||||
m_net->AllocateAllMatrices(evalNodes, {}, nullptr);
|
||||
|
||||
// prepare features and labels
|
||||
auto& featureNodes = m_net->FeatureNodes();
|
||||
auto& labelNodes = m_net->LabelNodes();
|
||||
|
||||
StreamMinibatchInputs inputMatrices;
|
||||
for (auto& node : featureNodes)
|
||||
inputMatrices.AddInput(node->NodeName(), node->ValuePtr(), node->GetMBLayout(), node->GetSampleLayout());
|
||||
for (auto& node : labelNodes)
|
||||
inputMatrices.AddInput(node->NodeName(), node->ValuePtr(), node->GetMBLayout(), node->GetSampleLayout());
|
||||
|
||||
bool useParallelTrain = (m_mpi != nullptr);
|
||||
bool useDistributedMBReading = useParallelTrain && m_enableDistributedMBReading && dataReader->SupportsDistributedMBRead();
|
||||
if (useDistributedMBReading)
|
||||
dataReader->StartDistributedMinibatchLoop(mbSize, 0, m_mpi->CurrentNodeRank(), m_mpi->NumNodesInUse(), testSize);
|
||||
else
|
||||
dataReader->StartMinibatchLoop(mbSize, 0, testSize);
|
||||
|
||||
m_net->StartEvaluateMinibatchLoop(evalNodes);
|
||||
|
||||
// Passing in two empty node lists so the dispatcher can work for the evalNodes.
|
||||
std::list<ComputationNodeBasePtr> learnableNodes;
|
||||
std::vector<ComputationNodeBasePtr> criterionNodes;
|
||||
|
||||
// First, all batch normalization nodes should be marked.
|
||||
std::vector<ComputationNodeBasePtr> batchNormalNodes;
|
||||
shared_ptr<FlowControlNode> nestedNetwork = static_pointer_cast<FlowControlNode>(m_net->GetNestedNetwork(evalNodes[0]));
|
||||
for (auto& node : nestedNetwork->GetNestedNodes())
|
||||
{
|
||||
shared_ptr<BatchNormalizationNode<ElemType>> castNode =
|
||||
dynamic_pointer_cast<BatchNormalizationNode<ElemType>>(node);
|
||||
if (castNode)
|
||||
{
|
||||
batchNormalNodes.push_back(node);
|
||||
}
|
||||
}
|
||||
|
||||
// Push all batch normalization mean and std into learn params values for mpi update
|
||||
std::vector<Matrix<ElemType>*> learnParamsValues(2, nullptr);
|
||||
|
||||
bool noMoreSamplesToProcess = false;
|
||||
for (auto& node : batchNormalNodes)
|
||||
{
|
||||
shared_ptr<BatchNormalizationNode<ElemType>> batchNode =
|
||||
static_pointer_cast<BatchNormalizationNode<ElemType>>(node);
|
||||
batchNode->SetPostBatchNormalizationBegin();
|
||||
size_t actualMBSize = 0;
|
||||
|
||||
LOGPRINTF(stderr, "Start evaluating: %ls\n", batchNode->GetName().c_str());
|
||||
|
||||
// Post batch normal iters
|
||||
for (int iter = 0; iter < iters; iter++)
|
||||
{
|
||||
bool wasDataRead = DataReaderHelpers::GetMinibatchIntoNetwork<ElemType>(*dataReader, m_net,
|
||||
nullptr, useDistributedMBReading, useParallelTrain, inputMatrices, actualMBSize, m_mpi);
|
||||
|
||||
if (!wasDataRead && (!useDistributedMBReading || noMoreSamplesToProcess))
|
||||
break;
|
||||
|
||||
// TODO should handle it, since post BN exist no samples in iters
|
||||
if (!wasDataRead)
|
||||
actualMBSize = 0;
|
||||
|
||||
// Batch Normalization Evaluate don't need to support subMinibatches
|
||||
ComputationNetwork::BumpEvalTimeStamp(featureNodes);
|
||||
ComputationNetwork::BumpEvalTimeStamp(labelNodes);
|
||||
|
||||
m_net->ForwardProp(evalNodes[0], nullptr, node);
|
||||
dataReader->DataEnd();
|
||||
}
|
||||
batchNode->SetPostBatchNormalizationEnd();
|
||||
|
||||
// Sync during or after all iters of a BN node are equivalent
|
||||
if (useParallelTrain)
|
||||
{
|
||||
if (m_gradHeader == nullptr)
|
||||
{
|
||||
m_gradHeader.reset(DistGradHeader::Create(evalNodes.size()), [](DistGradHeader* ptr)
|
||||
{
|
||||
DistGradHeader::Destroy(ptr);
|
||||
});
|
||||
}
|
||||
SimpleDistGradAggregator<ElemType> distGradAgg(m_mpi, false /*useAsyncAggregation*/, 0 /*syncStatsTrace*/);
|
||||
|
||||
auto runMeanParameterPtr = node->GetInputs()[3];
|
||||
auto runStdParameterPtr = node->GetInputs()[4];
|
||||
|
||||
shared_ptr<ComputationNode<ElemType>> runMeanNode = static_pointer_cast<ComputationNode<ElemType>>(runMeanParameterPtr);
|
||||
shared_ptr<ComputationNode<ElemType>> runStdNode = static_pointer_cast<ComputationNode<ElemType>>(runStdParameterPtr);
|
||||
|
||||
learnParamsValues[0] = &(runMeanNode->Value());
|
||||
learnParamsValues[1] = &(runStdNode->Value());
|
||||
|
||||
m_gradHeader->numSamples = actualMBSize ? 1 : actualMBSize;
|
||||
distGradAgg.AggregateGradients(learnParamsValues, m_gradHeader.get(), 0);
|
||||
|
||||
for (auto& parameter : learnParamsValues)
|
||||
{
|
||||
(*parameter) /= (ElemType)m_mpi->NumNodesInUse();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Save Model
|
||||
if (!useParallelTrain || m_mpi->CurrentNodeRank() == m_mpi->MainNodeRank())
|
||||
m_net->Save(exportPath);
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
protected:
|
||||
void DisplayEvalStatistics(const size_t startMBNum, const size_t endMBNum, const size_t numSamplesLastLogged,
|
||||
const vector<ComputationNodeBasePtr>& evalNodes,
|
||||
|
|
Загрузка…
Ссылка в новой задаче