Merge branch 'master' into qiwye/asgd-dev
Conflicts: Source/SGDLib/SGD.cpp Source/SGDLib/SGDLib.vcxproj.filters
This commit is contained in:
Коммит
4dd1625797
3
CNTK.sln
3
CNTK.sln
|
@ -1278,13 +1278,16 @@ Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "PythonBindings", "bindings\
|
|||
{9BD0A711-0BBD-45B6-B81C-053F03C26CFB} = {9BD0A711-0BBD-45B6-B81C-053F03C26CFB}
|
||||
{33D2FD22-DEF2-4507-A58A-368F641AEBE5} = {33D2FD22-DEF2-4507-A58A-368F641AEBE5}
|
||||
{D667AF32-028A-4A5D-BE19-F46776F0F6B2} = {D667AF32-028A-4A5D-BE19-F46776F0F6B2}
|
||||
{7B7A563D-AA8E-4660-A805-D50235A02120} = {7B7A563D-AA8E-4660-A805-D50235A02120}
|
||||
{9A2F2441-5972-4EA8-9215-4119FCE0FB68} = {9A2F2441-5972-4EA8-9215-4119FCE0FB68}
|
||||
{60BDB847-D0C4-4FD3-A947-0C15C08BCDB5} = {60BDB847-D0C4-4FD3-A947-0C15C08BCDB5}
|
||||
{91973E60-A7BE-4C86-8FDB-59C88A0B3715} = {91973E60-A7BE-4C86-8FDB-59C88A0B3715}
|
||||
{014DA766-B37B-4581-BC26-963EA5507931} = {014DA766-B37B-4581-BC26-963EA5507931}
|
||||
{CE429AA2-3778-4619-8FD1-49BA3B81197B} = {CE429AA2-3778-4619-8FD1-49BA3B81197B}
|
||||
{EF766CAE-9CB1-494C-9153-0030631A6340} = {EF766CAE-9CB1-494C-9153-0030631A6340}
|
||||
{62836DC1-DF77-4B98-BF2D-45C943B7DDC6} = {62836DC1-DF77-4B98-BF2D-45C943B7DDC6}
|
||||
{E5606ECE-48CA-4464-BB12-09D81D02B9EF} = {E5606ECE-48CA-4464-BB12-09D81D02B9EF}
|
||||
{482999D1-B7E2-466E-9F8D-2119F93EAFD9} = {482999D1-B7E2-466E-9F8D-2119F93EAFD9}
|
||||
{1D5787D4-52E4-45DB-951B-82F220EE0C6A} = {1D5787D4-52E4-45DB-951B-82F220EE0C6A}
|
||||
{7B7A51ED-AA8E-4660-A805-D50235A02120} = {7B7A51ED-AA8E-4660-A805-D50235A02120}
|
||||
{E6646FFE-3588-4276-8A15-8D65C22711C1} = {E6646FFE-3588-4276-8A15-8D65C22711C1}
|
||||
|
|
|
@ -46,7 +46,7 @@ Train=[
|
|||
L2RegWeight=0.0001
|
||||
dropoutRate=0
|
||||
|
||||
disableWkInBatchNormal=true
|
||||
disableRegInBatchNormalization=true
|
||||
|
||||
ParallelTrain=[
|
||||
parallelizationMethod="DataParallelSGD"
|
||||
|
@ -88,11 +88,12 @@ Train=[
|
|||
]
|
||||
|
||||
PBN=[
|
||||
action="pbn"
|
||||
action="bnstat"
|
||||
modelPath="$ModelDir$/ResNet_50"
|
||||
# Set minibatch size for testing.
|
||||
minibatchSize=256
|
||||
iters=30
|
||||
itersPerNode=30
|
||||
enableDistributedMBReading=true
|
||||
|
||||
reader=[
|
||||
readerType="ImageReader"
|
||||
|
|
3
Makefile
3
Makefile
|
@ -467,7 +467,8 @@ EVAL:=eval
|
|||
|
||||
SGDLIB_SRC=\
|
||||
$(SOURCEDIR)/SGDLib/Profiler.cpp \
|
||||
$(SOURCEDIR)/SGDLib/SGD.cpp
|
||||
$(SOURCEDIR)/SGDLib/SGD.cpp \
|
||||
$(SOURCEDIR)/SGDLib/PostComputingActions.cpp \
|
||||
|
||||
EVAL_SRC=\
|
||||
$(SOURCEDIR)/EvalDll/CNTKEval.cpp \
|
||||
|
|
|
@ -42,11 +42,11 @@ template <typename ElemType>
|
|||
void DoDumpNodes(const ConfigParameters& config);
|
||||
template <typename ElemType>
|
||||
void DoEdit(const ConfigParameters& config);
|
||||
template <typename ElemType>
|
||||
void DoBatchNormalizationStat(const ConfigParameters& config);
|
||||
|
||||
// evaluation (EvalActions.cpp)
|
||||
template <typename ElemType>
|
||||
void DoEvalBN(const ConfigParameters& config);
|
||||
template <typename ElemType>
|
||||
void DoEval(const ConfigParameters& config);
|
||||
template <typename ElemType>
|
||||
void DoCrossValidate(const ConfigParameters& config);
|
||||
|
|
|
@ -80,8 +80,6 @@
|
|||
<ItemDefinitionGroup Condition="$(GpuBuild)">
|
||||
<ClCompile>
|
||||
<AdditionalIncludeDirectories>%(AdditionalIncludeDirectories);$(CudaInclude)</AdditionalIncludeDirectories>
|
||||
<BrowseInformation Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</BrowseInformation>
|
||||
<BrowseInformation Condition="'$(Configuration)|$(Platform)'=='Release_NoOpt|x64'">true</BrowseInformation>
|
||||
</ClCompile>
|
||||
<Link>
|
||||
<AdditionalLibraryDirectories>%(AdditionalLibraryDirectories);$(CudaLibPath)</AdditionalLibraryDirectories>
|
||||
|
@ -91,10 +89,6 @@
|
|||
<Command>if exist "%ProgramW6432%\NVIDIA Corporation\NVSMI" xcopy /I /D /Y "%ProgramW6432%\NVIDIA Corporation\NVSMI\nvml*.dll" "$(TargetDir)"</Command>
|
||||
<Message>Copying NVidia GDK extension DLL to target folder</Message>
|
||||
</PostBuildEvent>
|
||||
<Bscmake>
|
||||
<PreserveSbr Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</PreserveSbr>
|
||||
<PreserveSbr Condition="'$(Configuration)|$(Platform)'=='Release_NoOpt|x64'">true</PreserveSbr>
|
||||
</Bscmake>
|
||||
</ItemDefinitionGroup>
|
||||
<ItemGroup>
|
||||
<ClInclude Include="..\Common\CrossProcessMutex.h" />
|
||||
|
|
|
@ -84,65 +84,6 @@ static void DoEvalBase(const ConfigParameters& config, IDataReader& reader)
|
|||
eval.Evaluate(&reader, evalNodeNamesVector, mbSize[0], epochSize);
|
||||
}
|
||||
|
||||
// ===========================================================================
|
||||
// DoEvalBNBase() - implements CNTK "pbn" command
|
||||
// ===========================================================================
|
||||
|
||||
template <typename ElemType>
|
||||
static void DoEvalBNBase(const ConfigParameters& config, IDataReader& reader)
|
||||
{
|
||||
// DEVICEID_TYPE deviceId = DeviceFromConfig(config);
|
||||
ConfigArray minibatchSize = config(L"minibatchSize", "40960");
|
||||
size_t epochSize = config(L"epochSize", "0");
|
||||
if (epochSize == 0)
|
||||
{
|
||||
epochSize = requestDataSize;
|
||||
}
|
||||
wstring modelPath = config(L"modelPath");
|
||||
wstring exportPath = modelPath + L".PBN";
|
||||
intargvector mbSize = minibatchSize;
|
||||
|
||||
int iters = config(L"iters", 240);
|
||||
|
||||
int traceLevel = config(L"traceLevel", "0");
|
||||
size_t numMBsToShowResult = config(L"numMBsToShowResult", "100");
|
||||
size_t firstMBsToShowResult = config(L"firstMBsToShowResult", "0");
|
||||
size_t maxSamplesInRAM = config(L"maxSamplesInRAM", (size_t)SIZE_MAX);
|
||||
size_t numSubminiBatches = config(L"numSubminibatches", (size_t)1);
|
||||
|
||||
bool enableDistributedMBReading = config(L"distributedMBReading", GetDistributedMBReadingDefaultValue(config, reader));
|
||||
|
||||
vector<wstring> evalNodeNamesVector;
|
||||
|
||||
let net = GetModelFromConfig<ConfigParameters, ElemType>(config, L"evalNodeNames", evalNodeNamesVector);
|
||||
|
||||
// set tracing flags
|
||||
net->EnableNodeTracing(config(L"traceNodeNamesReal", ConfigParameters::Array(stringargvector())),
|
||||
config(L"traceNodeNamesCategory", ConfigParameters::Array(stringargvector())),
|
||||
config(L"traceNodeNamesSparse", ConfigParameters::Array(stringargvector())));
|
||||
|
||||
SimpleEvaluator<ElemType> eval(net, MPIWrapper::GetInstance(), enableDistributedMBReading, numMBsToShowResult,
|
||||
firstMBsToShowResult, traceLevel, maxSamplesInRAM, numSubminiBatches);
|
||||
eval.EvaluateBN(&reader, evalNodeNamesVector, exportPath, mbSize[0], iters, epochSize);
|
||||
}
|
||||
|
||||
template <typename ElemType>
|
||||
void DoEvalBN(const ConfigParameters& config)
|
||||
{
|
||||
// This is actually used for re-estimating the BN node. It *should* actually randomize.
|
||||
// TODO: rename to DoEstimateBN.
|
||||
|
||||
// evaluate batch normalization mean and various
|
||||
ConfigParameters readerConfig(config(L"reader"));
|
||||
|
||||
// Should trace level to zero in Post BN?
|
||||
//readerConfig.Insert("traceLevel", config(L"traceLevel", "0"));
|
||||
|
||||
DataReader evaBNDataReader(readerConfig);
|
||||
|
||||
DoEvalBNBase<ElemType>(config, evaBNDataReader);
|
||||
}
|
||||
|
||||
template <typename ElemType>
|
||||
void DoEval(const ConfigParameters& config)
|
||||
{
|
||||
|
@ -158,8 +99,6 @@ void DoEval(const ConfigParameters& config)
|
|||
DoEvalBase<ElemType>(config, testDataReader);
|
||||
}
|
||||
|
||||
template void DoEvalBN<double>(const ConfigParameters& config);
|
||||
template void DoEvalBN<float>(const ConfigParameters& config);
|
||||
template void DoEval<double>(const ConfigParameters& config);
|
||||
template void DoEval<float>(const ConfigParameters& config);
|
||||
|
||||
|
|
|
@ -26,6 +26,7 @@
|
|||
#include "ScriptableObjects.h"
|
||||
#include "BrainScriptEvaluator.h"
|
||||
#include "BrainScriptParser.h"
|
||||
#include "PostComputingActions.h"
|
||||
|
||||
#include <string>
|
||||
#include <chrono>
|
||||
|
@ -259,3 +260,46 @@ void DoEdit(const ConfigParameters& config)
|
|||
|
||||
template void DoEdit<double>(const ConfigParameters& config);
|
||||
template void DoEdit<float>(const ConfigParameters& config);
|
||||
|
||||
// ===========================================================================
|
||||
// DoBatchNormalizationStat() - implements CNTK "bnstat" command
|
||||
// ===========================================================================
|
||||
|
||||
template <typename ElemType>
|
||||
void DoBatchNormalizationStat(const ConfigParameters& config)
|
||||
{
|
||||
ConfigParameters readerConfig(config(L"reader"));
|
||||
readerConfig.Insert("traceLevel", config(L"traceLevel", "0"));
|
||||
|
||||
auto dataReader = make_shared<DataReader>(readerConfig);
|
||||
|
||||
int traceLevel = config(L"traceLevel", "0");
|
||||
int itersPerNode = config(L"itersPerNode", 30);
|
||||
|
||||
ConfigArray minibatchSize = config(L"minibatchSize", "40960");
|
||||
intargvector mbSize = minibatchSize;
|
||||
|
||||
bool enableDistributedMBReading = config(L"enableDistributedMBReading", false);
|
||||
|
||||
wstring curModelPath = config(L"modelPath", L"");
|
||||
wstring newModelPath = config(L"newModelPath", L"");
|
||||
if (newModelPath == L"")
|
||||
{
|
||||
newModelPath = curModelPath + L".PBN";
|
||||
}
|
||||
|
||||
std::vector<std::wstring> evalNodeNames;
|
||||
let net = GetModelFromConfig<ConfigParameters, ElemType>(config, L"evalNodeNames", evalNodeNames);
|
||||
// set tracing flags
|
||||
net->EnableNodeTracing(config(L"traceNodeNamesReal", ConfigParameters::Array(stringargvector())),
|
||||
config(L"traceNodeNamesCategory", ConfigParameters::Array(stringargvector())),
|
||||
config(L"traceNodeNamesSparse", ConfigParameters::Array(stringargvector())));
|
||||
|
||||
PostComputingActions<ElemType> postComputingActions(net, MPIWrapper::GetInstance(), enableDistributedMBReading, traceLevel);
|
||||
|
||||
postComputingActions.BatchNormalizationStatistics(dataReader.get(), evalNodeNames, newModelPath, mbSize[0], itersPerNode);
|
||||
}
|
||||
|
||||
template void DoBatchNormalizationStat<double>(const ConfigParameters& config);
|
||||
template void DoBatchNormalizationStat<float>(const ConfigParameters& config);
|
||||
|
||||
|
|
|
@ -165,7 +165,7 @@ static void DisableLegacyUsage(const ConfigParameters& TopLevelConfig, const Con
|
|||
|
||||
// When running in parallel with MPI, only commands in 'commandstoRunOnAllRanks' should
|
||||
// be run in parallel across multiple ranks. Others should only run on rank 0
|
||||
const std::set<std::string> commandstoRunOnAllRanks = { "train", "trainRNN", "adapt", "test", "eval", "cv", "devtest", "pbn" };
|
||||
const std::set<std::string> commandstoRunOnAllRanks = { "train", "trainRNN", "adapt", "test", "eval", "cv", "devtest", "bnstat" };
|
||||
|
||||
// process the command
|
||||
template <typename ElemType>
|
||||
|
@ -273,10 +273,9 @@ void DoCommands(const ConfigParameters& config, const shared_ptr<MPIWrapper>& mp
|
|||
}
|
||||
fullEpochsOffset += GetMaxEpochs(commandParams);
|
||||
}
|
||||
// TODO: Choose a clearer name.
|
||||
else if (thisAction == "pbn")
|
||||
else if (thisAction == "bnstat")
|
||||
{
|
||||
DoEvalBN<ElemType>(commandParams);
|
||||
DoBatchNormalizationStat<ElemType>(commandParams);
|
||||
}
|
||||
else if (thisAction == "adapt")
|
||||
{
|
||||
|
|
|
@ -131,8 +131,6 @@
|
|||
<ItemDefinitionGroup Condition="$(GpuBuild)">
|
||||
<ClCompile>
|
||||
<AdditionalIncludeDirectories>%(AdditionalIncludeDirectories);$(CudaInclude)</AdditionalIncludeDirectories>
|
||||
<BrowseInformation Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</BrowseInformation>
|
||||
<BrowseInformation Condition="'$(Configuration)|$(Platform)'=='Release_NoOpt|x64'">true</BrowseInformation>
|
||||
</ClCompile>
|
||||
<Link>
|
||||
<AdditionalLibraryDirectories>%(AdditionalLibraryDirectories);$(CudaLibPath)</AdditionalLibraryDirectories>
|
||||
|
@ -141,10 +139,6 @@
|
|||
<Command>xcopy /I /D /Y "$(ProjectDir)BrainScript\CNTKCoreLib\CNTK.core.bs" "$(TargetDir)" && if exist "%ProgramW6432%\NVIDIA Corporation\NVSMI" xcopy /I /D /Y "%ProgramW6432%\NVIDIA Corporation\NVSMI\nvml*.dll" "$(TargetDir)"</Command>
|
||||
<Message>Copying dependencies</Message>
|
||||
</PostBuildEvent>
|
||||
<Bscmake>
|
||||
<PreserveSbr Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</PreserveSbr>
|
||||
<PreserveSbr Condition="'$(Configuration)|$(Platform)'=='Release_NoOpt|x64'">true</PreserveSbr>
|
||||
</Bscmake>
|
||||
</ItemDefinitionGroup>
|
||||
<ItemGroup>
|
||||
<Text Include="BrainScript\Doc\Notes.txt" />
|
||||
|
|
|
@ -119,8 +119,6 @@
|
|||
<ItemDefinitionGroup Condition="$(GpuBuild)">
|
||||
<ClCompile>
|
||||
<AdditionalIncludeDirectories>%(AdditionalIncludeDirectories);$(CudaInclude)</AdditionalIncludeDirectories>
|
||||
<BrowseInformation Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</BrowseInformation>
|
||||
<BrowseInformation Condition="'$(Configuration)|$(Platform)'=='Release_NoOpt|x64'">true</BrowseInformation>
|
||||
</ClCompile>
|
||||
<Link>
|
||||
<AdditionalLibraryDirectories>%(AdditionalLibraryDirectories);$(CudaLibPath)</AdditionalLibraryDirectories>
|
||||
|
@ -129,10 +127,6 @@
|
|||
<Command>if exist "%ProgramW6432%\NVIDIA Corporation\NVSMI" xcopy /I /D /Y "%ProgramW6432%\NVIDIA Corporation\NVSMI\nvml*.dll" "$(TargetDir)"</Command>
|
||||
<Message>Copying NVidia GDK extension DLL to target folder</Message>
|
||||
</PostBuildEvent>
|
||||
<Bscmake>
|
||||
<PreserveSbr Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</PreserveSbr>
|
||||
<PreserveSbr Condition="'$(Configuration)|$(Platform)'=='Release_NoOpt|x64'">true</PreserveSbr>
|
||||
</Bscmake>
|
||||
</ItemDefinitionGroup>
|
||||
<ItemGroup>
|
||||
<ClInclude Include="API\CNTKLibrary.h" />
|
||||
|
|
|
@ -54,13 +54,7 @@
|
|||
<ItemDefinitionGroup Condition="$(ReleaseBuild)">
|
||||
<ClCompile>
|
||||
<AdditionalOptions>/d2Zi+ %(AdditionalOptions)</AdditionalOptions>
|
||||
<BrowseInformation Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</BrowseInformation>
|
||||
<BrowseInformation Condition="'$(Configuration)|$(Platform)'=='Release_NoOpt|x64'">true</BrowseInformation>
|
||||
</ClCompile>
|
||||
<Bscmake>
|
||||
<PreserveSbr Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</PreserveSbr>
|
||||
<PreserveSbr Condition="'$(Configuration)|$(Platform)'=='Release_NoOpt|x64'">true</PreserveSbr>
|
||||
</Bscmake>
|
||||
</ItemDefinitionGroup>
|
||||
<ItemDefinitionGroup Condition="$(CpuOnlyBuild)">
|
||||
<ClCompile>
|
||||
|
|
|
@ -44,6 +44,10 @@ struct ComputationEnvironment
|
|||
|
||||
// traceLevel
|
||||
int traceLevel = 0;
|
||||
|
||||
// Extreme tracing of node outputs. Make space on your disk.
|
||||
bool IsLogLevelNodeTrace() const { return traceLevel >= 1000000; }
|
||||
|
||||
// more properties should be added here as needed
|
||||
};
|
||||
typedef std::shared_ptr<ComputationEnvironment> ComputationEnvironmentPtr;
|
||||
|
|
|
@ -136,10 +136,6 @@ public:
|
|||
// main entry point for backprop
|
||||
void Backprop(const ComputationNodeBasePtr rootNode);
|
||||
|
||||
// partial forward entry
|
||||
void ForwardProp(const ComputationNodeBasePtr rootNode, const ComputationNodeBasePtr startNode,
|
||||
const ComputationNodeBasePtr endNode);
|
||||
|
||||
template <class NODESET> // version that takes multiple nodes
|
||||
void ForwardProp(const NODESET& nodes)
|
||||
{
|
||||
|
@ -689,6 +685,44 @@ public:
|
|||
return GetNodesWhere(predicate, rootNode);
|
||||
}
|
||||
|
||||
// Get the eval nodes with names
|
||||
// if evalNodeNames are not specified, return all the default evalnodes and training criterion nodes.
|
||||
std::vector<ComputationNodeBasePtr> GetEvalNodesWithName(const std::vector<wstring> evalNodeNames)
|
||||
{
|
||||
// determine nodes to evaluate
|
||||
std::vector<ComputationNodeBasePtr> evalNodes;
|
||||
|
||||
set<ComputationNodeBasePtr> criteriaLogged; // (keeps track ot duplicates to avoid we don't double-log critera)
|
||||
if (evalNodeNames.size() == 0)
|
||||
{
|
||||
fprintf(stderr, "evalNodeNames are not specified, using all the default evalnodes and training criterion nodes.\n");
|
||||
if (EvaluationNodes().empty() && FinalCriterionNodes().empty())
|
||||
InvalidArgument("There is no default evaluation node or training criterion specified in the network.");
|
||||
|
||||
for (const auto& node : EvaluationNodes())
|
||||
if (criteriaLogged.insert(node).second)
|
||||
evalNodes.push_back(node);
|
||||
|
||||
for (const auto& node : FinalCriterionNodes())
|
||||
if (criteriaLogged.insert(node).second)
|
||||
evalNodes.push_back(node);
|
||||
}
|
||||
else
|
||||
{
|
||||
for (int i = 0; i < evalNodeNames.size(); i++)
|
||||
{
|
||||
const auto& node = GetNodeFromName(evalNodeNames[i]);
|
||||
if (!criteriaLogged.insert(node).second)
|
||||
continue;
|
||||
if (node->GetSampleLayout().GetNumElements() != 1)
|
||||
InvalidArgument("Criterion nodes to evaluate must have dimension 1x1.");
|
||||
evalNodes.push_back(node);
|
||||
}
|
||||
}
|
||||
|
||||
return evalNodes;
|
||||
}
|
||||
|
||||
public:
|
||||
// return list of nodes that require precomputation and not precomputed yet
|
||||
std::list<ComputationNodeBasePtr> GetNodesRequiringPreComputation(const ComputationNodeBasePtr& rootNode = nullptr, bool checkComputed = true);
|
||||
|
@ -1056,9 +1090,6 @@ protected:
|
|||
virtual void RequestMatricesBeforeBackprop(MatrixPool& matrixPool);
|
||||
virtual void ReleaseMatricesAfterBackprop(MatrixPool& matrixPool);
|
||||
|
||||
// TODO: Why is this virtual?
|
||||
virtual void ForwardProp(const FrameRange&, const ComputationNodeBasePtr, const ComputationNodeBasePtr) override;
|
||||
|
||||
public:
|
||||
// this special constructor constructs the top-level network node
|
||||
// There is currently no other constructor for inner nested PAR-traversed sub-networks, but there will be.
|
||||
|
|
|
@ -79,17 +79,6 @@ void ComputationNetwork::Backprop(const ComputationNodeBasePtr rootNode) // trai
|
|||
GetNestedNetwork(rootNode)->Backprop(FrameRange(nullptr), true, true);
|
||||
}
|
||||
|
||||
void ComputationNetwork::ForwardProp(const ComputationNodeBasePtr rootNode, const ComputationNodeBasePtr startNode, const ComputationNodeBasePtr endNode)
|
||||
{
|
||||
VerifyIsCompiled("ForwardProp");
|
||||
|
||||
// traverse partial nodes as inputs
|
||||
shared_ptr<FlowControlNode> network = dynamic_pointer_cast<FlowControlNode>(GetNestedNetwork(rootNode));
|
||||
assert(network);
|
||||
|
||||
network->ForwardProp(FrameRange(nullptr), startNode, endNode);
|
||||
}
|
||||
|
||||
void ComputationNetwork::FormNestedNetwork(const ComputationNodeBasePtr& rootNode)
|
||||
{
|
||||
if (m_nestedNetworks.find(rootNode) != m_nestedNetworks.end())
|
||||
|
@ -159,12 +148,11 @@ ComputationNetwork::PARTraversalFlowControlNode::PARTraversalFlowControlNode(con
|
|||
node->BumpEvalTimeStamp();
|
||||
}
|
||||
|
||||
// more extreme tracing for the ultimate debugging experience. Make space on your disk.
|
||||
if (node->GetEnvironmentPtr() && node->Environment().traceLevel >= 1000000) // very high number, since this spews like hell
|
||||
// Extreme Tracing, part 1/4
|
||||
if (node->HasEnvironmentPtr() && node->Environment().IsLogLevelNodeTrace())
|
||||
DumpNode<float>(node, /*dumpGradient=*/false) || DumpNode<double>(node, false);
|
||||
}
|
||||
}
|
||||
|
||||
/*virtual*/ void ComputationNetwork::PARTraversalFlowControlNode::Backprop(const FrameRange& fr, bool childrenInThisLoop, bool childrenInOuterLoop) /*override*/
|
||||
{
|
||||
childrenInThisLoop, childrenInOuterLoop; // TODO: think through what these mean when coming from PAR mode
|
||||
|
@ -177,8 +165,8 @@ ComputationNetwork::PARTraversalFlowControlNode::PARTraversalFlowControlNode(con
|
|||
node->Backprop(fr.WithLayout(node->GetMBLayout()), true /*childrenInThisLoop*/, true /*childrenInOuterLoop*/);
|
||||
node->EndBackprop();
|
||||
|
||||
// more extreme tracing for the ultimate debugging experience. Make space on your disk.
|
||||
if (node->GetEnvironmentPtr() && node->Environment().traceLevel >= 1000000 && node->NeedsGradient()) // very high number, since this spews like hell
|
||||
// Extreme Tracing, part 2/4
|
||||
if (node->HasEnvironmentPtr() && node->Environment().IsLogLevelNodeTrace() && node->NeedsGradient())
|
||||
DumpNode<float>(node, /*dumpGradient=*/true) || DumpNode<double>(node, true);
|
||||
}
|
||||
}
|
||||
|
@ -197,37 +185,7 @@ ComputationNetwork::PARTraversalFlowControlNode::PARTraversalFlowControlNode(con
|
|||
/*virtual*/ void ComputationNetwork::PARTraversalFlowControlNode::ReleaseMatricesAfterBackprop(MatrixPool& matrixPool) /*override*/
|
||||
{
|
||||
}
|
||||
// TODO: merge with the main ForwardProp() function.
|
||||
/*virtual*/ void ComputationNetwork::PARTraversalFlowControlNode::ForwardProp(const FrameRange & fr, ComputationNodeBasePtr startNode, ComputationNodeBasePtr endNode)
|
||||
{
|
||||
// if start node is nullptr, forward will be enable
|
||||
bool enableForward = startNode ? false : true;
|
||||
|
||||
for (auto& node : m_nestedNodes)
|
||||
{
|
||||
#if 0
|
||||
if (dynamic_pointer_cast<LearnableParameter<float>>(node))
|
||||
dynamic_pointer_cast<ComputationNode<float>>(node)->DebugLogMinibatch();
|
||||
#endif
|
||||
if (node->IsOutOfDateWrtInputs() && enableForward)
|
||||
{
|
||||
node->BeginForwardProp();
|
||||
node->ForwardProp(fr.WithLayout(node->GetMBLayout()));
|
||||
node->EndForwardProp();
|
||||
|
||||
node->BumpEvalTimeStamp();
|
||||
}
|
||||
|
||||
if (node == startNode)
|
||||
{
|
||||
enableForward = true;
|
||||
}
|
||||
else if (node == endNode)
|
||||
{
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
// helper for logging. Returns false if it was not able to dynamic-cast nodep to ComputationNode<ElemType>
|
||||
template<class ElemType>
|
||||
static bool DumpNode(ComputationNodeBasePtr nodep, bool dumpGradient)
|
||||
|
@ -294,6 +252,15 @@ static bool DumpNode(ComputationNodeBasePtr nodep, bool dumpGradient)
|
|||
node->BumpEvalTimeStamp();
|
||||
}
|
||||
}
|
||||
|
||||
// Extreme Tracing, part 3/4
|
||||
for (auto& node : m_nestedNodes)
|
||||
{
|
||||
if (node->HasEnvironmentPtr() && node->Environment().IsLogLevelNodeTrace())
|
||||
{
|
||||
DumpNode<float>(node, /*dumpGradient=*/false) || DumpNode<double>(node, false);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/*virtual*/ void ComputationNetwork::SEQTraversalFlowControlNode::EndForwardProp() /*override*/
|
||||
|
@ -326,6 +293,15 @@ static bool DumpNode(ComputationNodeBasePtr nodep, bool dumpGradient)
|
|||
// a node that is outside the loop, which is done later in EndBackprop() in PAR mode.
|
||||
}
|
||||
}
|
||||
|
||||
// Extreme Tracing, part 4/4
|
||||
for (auto& node : m_nestedNodes)
|
||||
{
|
||||
if (node->HasEnvironmentPtr() && node->Environment().IsLogLevelNodeTrace() && node->NeedsGradient())
|
||||
{
|
||||
DumpNode<float>(node, /*dumpGradient=*/true) || DumpNode<double>(node, true);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// called after last iteration step of ComputeGradient()
|
||||
|
|
|
@ -72,8 +72,6 @@
|
|||
<ItemDefinitionGroup Condition="$(GpuBuild)">
|
||||
<ClCompile>
|
||||
<AdditionalIncludeDirectories>%(AdditionalIncludeDirectories);$(CudaInclude)</AdditionalIncludeDirectories>
|
||||
<BrowseInformation Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</BrowseInformation>
|
||||
<BrowseInformation Condition="'$(Configuration)|$(Platform)'=='Release_NoOpt|x64'">true</BrowseInformation>
|
||||
</ClCompile>
|
||||
<Link>
|
||||
<AdditionalLibraryDirectories>%(AdditionalLibraryDirectories);$(CudaLibPath)</AdditionalLibraryDirectories>
|
||||
|
@ -83,10 +81,6 @@
|
|||
<Command>if exist "%ProgramW6432%\NVIDIA Corporation\NVSMI" xcopy /I /D /Y "%ProgramW6432%\NVIDIA Corporation\NVSMI\nvml*.dll" "$(TargetDir)"</Command>
|
||||
<Message>Copying NVidia GDK extension DLL to target folder</Message>
|
||||
</PostBuildEvent>
|
||||
<Bscmake>
|
||||
<PreserveSbr Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</PreserveSbr>
|
||||
<PreserveSbr Condition="'$(Configuration)|$(Platform)'=='Release_NoOpt|x64'">true</PreserveSbr>
|
||||
</Bscmake>
|
||||
</ItemDefinitionGroup>
|
||||
<ItemGroup>
|
||||
<ClInclude Include="..\Common\CrossProcessMutex.h" />
|
||||
|
|
|
@ -647,6 +647,8 @@ public:
|
|||
LogicError("Environment: No environment has been set.");
|
||||
return *m_environment;
|
||||
}
|
||||
|
||||
// True if an environment has been set on this node (i.e. m_environment is non-null).
bool HasEnvironmentPtr() const
{
    return m_environment != nullptr;
}
|
||||
// Returns the environment pointer held by this node; it may be null if none has been set.
ComputationEnvironmentPtr GetEnvironmentPtr() const { return m_environment; }
|
||||
// Replaces the environment pointer held by this node with 'environment'.
void SetEnvironment(ComputationEnvironmentPtr environment) { m_environment = environment; }
|
||||
|
||||
|
@ -1886,10 +1888,6 @@ public:
|
|||
virtual void DumpNodeInfo(const bool /*printValues*/, const bool /*printMetadata*/, File& fstream) const override {}
|
||||
virtual std::set<std::pair<const MatrixBase*, std::wstring>> GetMatrixInfo() const override { NOT_IMPLEMENTED; }
|
||||
|
||||
virtual void ForwardProp(const FrameRange&, const ComputationNodeBasePtr, const ComputationNodeBasePtr) { NOT_IMPLEMENTED; }
|
||||
|
||||
// Returns the nested nodes by value, i.e. the caller receives a copy of m_nestedNodes.
std::vector<ComputationNodeBasePtr> GetNestedNodes() { return m_nestedNodes; }
|
||||
|
||||
protected: public: // needed in ComputationNetwork::FindInRecurrentLoops(), which really should be part of SEQTraversalFlowControlNode
|
||||
std::vector<ComputationNodeBasePtr> m_nestedNodes; // nodes tucked away in this node, in evaluation order
|
||||
};
|
||||
|
|
|
@ -47,6 +47,7 @@ public:
|
|||
MarkValueNonSharable();
|
||||
m_initString = L"fromValue"; // default init is with 0; typically overwritten
|
||||
m_initValue = 0;
|
||||
m_regMultiplier = 1.0f; // enable reg in update by default
|
||||
}
|
||||
LearnableParameter(DEVICEID_TYPE deviceId, const wstring& name, const TensorShape& shape) :
|
||||
LearnableParameter(deviceId, name)
|
||||
|
@ -142,6 +143,14 @@ public:
|
|||
// called from CloneFunction(..., parameters="constant")
|
||||
virtual void FreezeParameters() override; // from IFreezable
|
||||
|
||||
// Setting the reg multiplier for a learnable node, effecting L1Reg and L2Reg both.
|
||||
void SetRegMultiplier(float regMultiplier)
|
||||
{
|
||||
m_regMultiplier = regMultiplier;
|
||||
}
|
||||
// called from SGD UpdateWeights, to adjust the reg for each node
|
||||
float GetRegMultiplier() const { return m_regMultiplier; }
|
||||
|
||||
private:
|
||||
// init parameters for deferred initialization (which happens in Validate())
|
||||
std::wstring m_initString; // if non-empty then deferred initialization is needed. Gets cleared upon completion of deferred init.
|
||||
|
@ -151,6 +160,9 @@ private:
|
|||
int m_initOutputRank;
|
||||
bool m_initOnCPUOnly;
|
||||
ElemType m_initValue;
|
||||
|
||||
// flags related to gradient update
|
||||
float m_regMultiplier; // The multiplier to adjust the L1Reg and L2Reg for Learnable node
|
||||
};
|
||||
|
||||
// -----------------------------------------------------------------------
|
||||
|
|
|
@ -8,6 +8,7 @@
|
|||
#include "ComputationNode.h"
|
||||
#include "BatchNormalizationEngine.h"
|
||||
#include "RNGHandle.h"
|
||||
#include "InputAndParamNodes.h"
|
||||
#include "CPURNGHandle.h"
|
||||
|
||||
|
||||
|
@ -2223,15 +2224,15 @@ class BatchNormalizationNode : public ComputationNodeNonLooping<ElemType>, publi
|
|||
public:
|
||||
BatchNormalizationNode(DEVICEID_TYPE deviceId, const wstring& name) :
|
||||
Base(deviceId, name), m_spatial(false), m_normTimeConst(0), m_blendTimeConst(0), m_epsilon(0), m_useCntkEngine(true),
|
||||
m_samplesSeen(0), m_imageLayoutKind(ImageLayoutKind::CHW), m_postBatchNormalization(false), m_swapNormTimeConst(0),
|
||||
m_swapBlendTimeConst(0), m_convertRunningVariancePending(false)
|
||||
m_samplesSeen(0), m_imageLayoutKind(ImageLayoutKind::CHW),
|
||||
m_convertRunningVariancePending(false)
|
||||
{
|
||||
}
|
||||
BatchNormalizationNode(DEVICEID_TYPE deviceId, const wstring& name, bool spatial, double normalizationTimeConstant, double blendTimeConstant,
|
||||
double epsilon, bool useCntkEngine, ImageLayoutKind imageLayoutKind) :
|
||||
Base(deviceId, name), m_spatial(spatial), m_normTimeConst(normalizationTimeConstant), m_blendTimeConst(blendTimeConstant),
|
||||
m_epsilon(epsilon), m_useCntkEngine(useCntkEngine), m_imageLayoutKind(imageLayoutKind), m_samplesSeen(0), m_postBatchNormalization(false),
|
||||
m_swapNormTimeConst(0), m_swapBlendTimeConst(0), m_convertRunningVariancePending(false)
|
||||
m_epsilon(epsilon), m_useCntkEngine(useCntkEngine), m_imageLayoutKind(imageLayoutKind), m_samplesSeen(0),
|
||||
m_convertRunningVariancePending(false)
|
||||
{
|
||||
}
|
||||
BatchNormalizationNode(const ScriptableObjects::IConfigRecordPtr configp) :
|
||||
|
@ -2241,9 +2242,6 @@ public:
|
|||
ImageLayoutKindFrom(configp->Get(L"imageLayout")))
|
||||
{
|
||||
AttachInputsFromConfig(configp, this->GetExpectedNumInputs());
|
||||
m_postBatchNormalization = false;
|
||||
m_swapNormTimeConst = 0;
|
||||
m_swapBlendTimeConst = 0;
|
||||
}
|
||||
|
||||
void Save(File& fstream) const override
|
||||
|
@ -2360,7 +2358,7 @@ private: // time-constant conversions
|
|||
double ComputeExpAvgFactor() const
|
||||
{
|
||||
// in inference mode, only use long-term mean and do not update running estimates
|
||||
if (!Environment().IsTraining() && !m_postBatchNormalization)
|
||||
if (!Environment().IsTraining())
|
||||
{
|
||||
if (m_samplesSeen == 0)
|
||||
RuntimeError("%ls: inference mode is used, but nothing has been trained.", NodeName().c_str());
|
||||
|
@ -2392,7 +2390,7 @@ private: // time-constant conversions
|
|||
double ComputeBlendFactor() const
|
||||
{
|
||||
// in inference mode, only use long-term mean and do not update running estimates
|
||||
if (!Environment().IsTraining() && !m_postBatchNormalization)
|
||||
if (!Environment().IsTraining())
|
||||
{
|
||||
if (m_samplesSeen == 0)
|
||||
RuntimeError("%ls: inference mode is used, but nothing has been trained.", NodeName().c_str());
|
||||
|
@ -2441,7 +2439,7 @@ public:
|
|||
// In inference-only mode, m_savedMean and m_saveInvStdDev will not be
|
||||
// produced and BackpropToNonLooping() may not be called. In
|
||||
// non-inference (training) mode, saved statistics must be produced.
|
||||
bool inferenceOnly = !Environment().IsTraining() && !m_postBatchNormalization;
|
||||
bool inferenceOnly = !Environment().IsTraining();
|
||||
m_bnEng->Forward(/*in=*/ sliceInputValue, scale, bias, // (in)
|
||||
inferenceOnly, expAvgFactor, blendFactor,
|
||||
runMean, runVariance, // (in/out) running estimates, updated from the current MB mean/variance
|
||||
|
@ -2506,14 +2504,6 @@ public:
|
|||
}
|
||||
|
||||
virtual void EndForwardProp() override
|
||||
{
|
||||
if(m_postBatchNormalization)
|
||||
m_samplesSeen += GetMBLayout()->GetActualNumSamples();
|
||||
|
||||
Base::EndForwardProp();
|
||||
}
|
||||
|
||||
virtual void EndBackprop() override
|
||||
{
|
||||
// Update samples if not locked.
|
||||
double expAvgFactor = ComputeExpAvgFactor(); // weight for the new MB statistics in the running estimate. The previous value of the running statistics is kept with weight (1-this)
|
||||
|
@ -2655,28 +2645,29 @@ public:
|
|||
m_blendTimeConst = std::numeric_limits<double>::infinity();
|
||||
}
|
||||
|
||||
// ResetStatisticsState will set the batch normal statistics into initial state
|
||||
// used for re-statistics the mean and variance of BN
|
||||
// any others use may lead undependable results, please be careful
|
||||
void ResetStatisticsState()
|
||||
{
|
||||
m_samplesSeen = 0;
|
||||
m_normTimeConst = 0;
|
||||
m_blendTimeConst = 0;
|
||||
}
|
||||
// Turn off the L1 and L2 regularization
|
||||
void DisableRegInBatchNormalization()
|
||||
{
|
||||
let scaleNode = dynamic_pointer_cast<LearnableParameter<ElemType>>(Input(1));
|
||||
let biasNode = dynamic_pointer_cast<LearnableParameter<ElemType>>(Input(2));
|
||||
scaleNode->SetRegMultiplier(0.f);
|
||||
biasNode->SetRegMultiplier(0.f);
|
||||
}
|
||||
double NormalizationTimeConstant() const { return m_normTimeConst; }
|
||||
double BlendTimeConstant() const { return m_blendTimeConst; }
|
||||
bool Spatial() const { return m_spatial; }
|
||||
double Epsilon() const { return m_epsilon; }
|
||||
bool UseCNTKEngine() const { return m_useCntkEngine; }
|
||||
|
||||
void SetPostBatchNormalizationBegin()
|
||||
{
|
||||
m_postBatchNormalization = true;
|
||||
m_samplesSeen = 0;
|
||||
m_swapNormTimeConst = m_normTimeConst;
|
||||
m_swapBlendTimeConst = m_blendTimeConst;
|
||||
m_normTimeConst = -1;
|
||||
m_blendTimeConst = 0;
|
||||
}
|
||||
void SetPostBatchNormalizationEnd()
|
||||
{
|
||||
m_postBatchNormalization = false;
|
||||
m_normTimeConst = m_swapNormTimeConst;
|
||||
m_blendTimeConst = m_swapBlendTimeConst;
|
||||
}
|
||||
|
||||
private:
|
||||
// Old versioning - do not use. Do not remove until we're sure there are no old models around.
|
||||
struct VersionInfo
|
||||
|
@ -2740,11 +2731,6 @@ private:
|
|||
|
||||
std::unique_ptr<BatchNormEngine<ElemType>> m_bnEng;
|
||||
|
||||
// post batch normalization process mark
|
||||
bool m_postBatchNormalization;
|
||||
|
||||
double m_swapNormTimeConst;
|
||||
double m_swapBlendTimeConst;
|
||||
bool m_convertRunningVariancePending;
|
||||
};
|
||||
|
||||
|
|
|
@ -120,8 +120,6 @@
|
|||
<ItemDefinitionGroup Condition="$(GpuBuild)">
|
||||
<ClCompile>
|
||||
<AdditionalIncludeDirectories>%(AdditionalIncludeDirectories);$(CudaInclude)</AdditionalIncludeDirectories>
|
||||
<BrowseInformation Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</BrowseInformation>
|
||||
<BrowseInformation Condition="'$(Configuration)|$(Platform)'=='Release_NoOpt|x64'">true</BrowseInformation>
|
||||
</ClCompile>
|
||||
<Link>
|
||||
<AdditionalLibraryDirectories>%(AdditionalLibraryDirectories);$(CudaLibPath)</AdditionalLibraryDirectories>
|
||||
|
@ -130,10 +128,6 @@
|
|||
<Command>if exist "%ProgramW6432%\NVIDIA Corporation\NVSMI" xcopy /I /D /Y "%ProgramW6432%\NVIDIA Corporation\NVSMI\nvml*.dll" "$(TargetDir)"</Command>
|
||||
<Message>Copying NVidia GDK extension DLL to target folder</Message>
|
||||
</PostBuildEvent>
|
||||
<Bscmake>
|
||||
<PreserveSbr Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</PreserveSbr>
|
||||
<PreserveSbr Condition="'$(Configuration)|$(Platform)'=='Release_NoOpt|x64'">true</PreserveSbr>
|
||||
</Bscmake>
|
||||
</ItemDefinitionGroup>
|
||||
<ItemGroup>
|
||||
<ClInclude Include="..\Common\Include\Basics.h" />
|
||||
|
|
|
@ -78,16 +78,10 @@
|
|||
<ClCompile>
|
||||
<WarningLevel>Level3</WarningLevel>
|
||||
<PreprocessorDefinitions>WIN32;NDEBUG;%(PreprocessorDefinitions)</PreprocessorDefinitions>
|
||||
<BrowseInformation Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</BrowseInformation>
|
||||
<BrowseInformation Condition="'$(Configuration)|$(Platform)'=='Release_NoOpt|x64'">true</BrowseInformation>
|
||||
</ClCompile>
|
||||
<Link>
|
||||
<GenerateDebugInformation>true</GenerateDebugInformation>
|
||||
</Link>
|
||||
<Bscmake>
|
||||
<PreserveSbr Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</PreserveSbr>
|
||||
<PreserveSbr Condition="'$(Configuration)|$(Platform)'=='Release_NoOpt|x64'">true</PreserveSbr>
|
||||
</Bscmake>
|
||||
</ItemDefinitionGroup>
|
||||
<ItemGroup>
|
||||
<ClCompile Include="CNTKException.h" />
|
||||
|
|
|
@ -147,16 +147,10 @@
|
|||
<ItemDefinitionGroup Condition="$(GpuBuild)">
|
||||
<ClCompile>
|
||||
<AdditionalIncludeDirectories>%(AdditionalIncludeDirectories);$(CudaInclude)</AdditionalIncludeDirectories>
|
||||
<BrowseInformation Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</BrowseInformation>
|
||||
<BrowseInformation Condition="'$(Configuration)|$(Platform)'=='Release_NoOpt|x64'">true</BrowseInformation>
|
||||
</ClCompile>
|
||||
<Link>
|
||||
<AdditionalLibraryDirectories>%(AdditionalLibraryDirectories);$(CudaLibPath)</AdditionalLibraryDirectories>
|
||||
</Link>
|
||||
<Bscmake>
|
||||
<PreserveSbr Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</PreserveSbr>
|
||||
<PreserveSbr Condition="'$(Configuration)|$(Platform)'=='Release_NoOpt|x64'">true</PreserveSbr>
|
||||
</Bscmake>
|
||||
</ItemDefinitionGroup>
|
||||
<ItemDefinitionGroup Condition="$(CpuOnlyBuild)">
|
||||
<ClCompile>
|
||||
|
|
|
@ -105,16 +105,10 @@ if exist "$(CuDnnDll)" xcopy /D /Y "$(CuDnnDll)" "$(OutputPath)"
|
|||
<EnableParallelCodeGeneration>true</EnableParallelCodeGeneration>
|
||||
<FloatingPointExceptions>false</FloatingPointExceptions>
|
||||
<AdditionalOptions>/d2Zi+ %(AdditionalOptions)</AdditionalOptions>
|
||||
<BrowseInformation Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</BrowseInformation>
|
||||
<BrowseInformation Condition="'$(Configuration)|$(Platform)'=='Release_NoOpt|x64'">true</BrowseInformation>
|
||||
</ClCompile>
|
||||
<CudaCompile>
|
||||
<HostDebugInfo>false</HostDebugInfo>
|
||||
</CudaCompile>
|
||||
<Bscmake>
|
||||
<PreserveSbr Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</PreserveSbr>
|
||||
<PreserveSbr Condition="'$(Configuration)|$(Platform)'=='Release_NoOpt|x64'">true</PreserveSbr>
|
||||
</Bscmake>
|
||||
</ItemDefinitionGroup>
|
||||
<ItemGroup>
|
||||
<ClInclude Include="..\Common\Include\File.h" />
|
||||
|
|
|
@ -90,8 +90,6 @@
|
|||
<OpenMPSupport>false</OpenMPSupport>
|
||||
<AdditionalOptions>/d2Zi+ %(AdditionalOptions)</AdditionalOptions>
|
||||
<TreatWarningAsError>true</TreatWarningAsError>
|
||||
<BrowseInformation Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</BrowseInformation>
|
||||
<BrowseInformation Condition="'$(Configuration)|$(Platform)'=='Release_NoOpt|x64'">true</BrowseInformation>
|
||||
</ClCompile>
|
||||
<Link>
|
||||
<SubSystem>Console</SubSystem>
|
||||
|
@ -101,10 +99,6 @@
|
|||
<AdditionalDependencies>Math.lib;Common.lib;kernel32.lib;user32.lib;gdi32.lib;winspool.lib;comdlg32.lib;advapi32.lib;shell32.lib;ole32.lib;oleaut32.lib;uuid.lib;odbc32.lib;odbccp32.lib;%(AdditionalDependencies)</AdditionalDependencies>
|
||||
<Profile>true</Profile>
|
||||
</Link>
|
||||
<Bscmake>
|
||||
<PreserveSbr Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</PreserveSbr>
|
||||
<PreserveSbr Condition="'$(Configuration)|$(Platform)'=='Release_NoOpt|x64'">true</PreserveSbr>
|
||||
</Bscmake>
|
||||
</ItemDefinitionGroup>
|
||||
<ItemGroup>
|
||||
<ClInclude Include="..\..\Common\Include\DataReader.h" />
|
||||
|
|
|
@ -83,18 +83,12 @@
|
|||
<IntrinsicFunctions>true</IntrinsicFunctions>
|
||||
<PreprocessorDefinitions>NDEBUG;%(PreprocessorDefinitions)</PreprocessorDefinitions>
|
||||
<AdditionalOptions>/d2Zi+ %(AdditionalOptions)</AdditionalOptions>
|
||||
<BrowseInformation Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</BrowseInformation>
|
||||
<BrowseInformation Condition="'$(Configuration)|$(Platform)'=='Release_NoOpt|x64'">true</BrowseInformation>
|
||||
</ClCompile>
|
||||
<Link>
|
||||
<EnableCOMDATFolding>true</EnableCOMDATFolding>
|
||||
<OptimizeReferences>true</OptimizeReferences>
|
||||
<Profile>true</Profile>
|
||||
</Link>
|
||||
<Bscmake>
|
||||
<PreserveSbr Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</PreserveSbr>
|
||||
<PreserveSbr Condition="'$(Configuration)|$(Platform)'=='Release_NoOpt|x64'">true</PreserveSbr>
|
||||
</Bscmake>
|
||||
</ItemDefinitionGroup>
|
||||
<ItemGroup>
|
||||
<ClInclude Include="..\..\Common\Include\DataReader.h" />
|
||||
|
|
|
@ -75,18 +75,12 @@
|
|||
<IntrinsicFunctions>true</IntrinsicFunctions>
|
||||
<PreprocessorDefinitions>NDEBUG;%(PreprocessorDefinitions)</PreprocessorDefinitions>
|
||||
<AdditionalOptions>/d2Zi+ %(AdditionalOptions)</AdditionalOptions>
|
||||
<BrowseInformation Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</BrowseInformation>
|
||||
<BrowseInformation Condition="'$(Configuration)|$(Platform)'=='Release_NoOpt|x64'">true</BrowseInformation>
|
||||
</ClCompile>
|
||||
<Link>
|
||||
<EnableCOMDATFolding>true</EnableCOMDATFolding>
|
||||
<OptimizeReferences>true</OptimizeReferences>
|
||||
<Profile>true</Profile>
|
||||
</Link>
|
||||
<Bscmake>
|
||||
<PreserveSbr Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</PreserveSbr>
|
||||
<PreserveSbr Condition="'$(Configuration)|$(Platform)'=='Release_NoOpt|x64'">true</PreserveSbr>
|
||||
</Bscmake>
|
||||
</ItemDefinitionGroup>
|
||||
<ItemGroup>
|
||||
<ClInclude Include="CompositeDataReader.h" />
|
||||
|
|
|
@ -88,8 +88,6 @@
|
|||
<OpenMPSupport>false</OpenMPSupport>
|
||||
<AdditionalOptions>/d2Zi+ %(AdditionalOptions)</AdditionalOptions>
|
||||
<TreatWarningAsError>true</TreatWarningAsError>
|
||||
<BrowseInformation Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</BrowseInformation>
|
||||
<BrowseInformation Condition="'$(Configuration)|$(Platform)'=='Release_NoOpt|x64'">true</BrowseInformation>
|
||||
</ClCompile>
|
||||
<Link>
|
||||
<SubSystem>Console</SubSystem>
|
||||
|
@ -99,10 +97,6 @@
|
|||
<AdditionalDependencies>Math.lib;Common.lib;kernel32.lib;user32.lib;gdi32.lib;winspool.lib;comdlg32.lib;advapi32.lib;shell32.lib;ole32.lib;oleaut32.lib;uuid.lib;odbc32.lib;odbccp32.lib;%(AdditionalDependencies)</AdditionalDependencies>
|
||||
<Profile>true</Profile>
|
||||
</Link>
|
||||
<Bscmake>
|
||||
<PreserveSbr Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</PreserveSbr>
|
||||
<PreserveSbr Condition="'$(Configuration)|$(Platform)'=='Release_NoOpt|x64'">true</PreserveSbr>
|
||||
</Bscmake>
|
||||
</ItemDefinitionGroup>
|
||||
<ItemGroup>
|
||||
<ClInclude Include="..\..\Common\Include\DataReader.h" />
|
||||
|
|
|
@ -75,18 +75,12 @@
|
|||
<IntrinsicFunctions>true</IntrinsicFunctions>
|
||||
<PreprocessorDefinitions>NDEBUG;%(PreprocessorDefinitions)</PreprocessorDefinitions>
|
||||
<AdditionalOptions>/d2Zi+ %(AdditionalOptions)</AdditionalOptions>
|
||||
<BrowseInformation Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</BrowseInformation>
|
||||
<BrowseInformation Condition="'$(Configuration)|$(Platform)'=='Release_NoOpt|x64'">true</BrowseInformation>
|
||||
</ClCompile>
|
||||
<Link>
|
||||
<EnableCOMDATFolding>true</EnableCOMDATFolding>
|
||||
<OptimizeReferences>true</OptimizeReferences>
|
||||
<Profile>true</Profile>
|
||||
</Link>
|
||||
<Bscmake>
|
||||
<PreserveSbr Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</PreserveSbr>
|
||||
<PreserveSbr Condition="'$(Configuration)|$(Platform)'=='Release_NoOpt|x64'">true</PreserveSbr>
|
||||
</Bscmake>
|
||||
</ItemDefinitionGroup>
|
||||
<ItemGroup>
|
||||
<ClInclude Include="..\..\Common\Include\DataReader.h" />
|
||||
|
|
|
@ -93,8 +93,6 @@
|
|||
<AdditionalIncludeDirectories Condition="'$(Configuration)|$(Platform)'=='Release|x64'">..\..\common\include;..\..\Math</AdditionalIncludeDirectories>
|
||||
<AdditionalIncludeDirectories Condition="'$(Configuration)|$(Platform)'=='Release_NoOpt|x64'">..\..\common\include;..\..\Math</AdditionalIncludeDirectories>
|
||||
<AdditionalIncludeDirectories Condition="'$(Configuration)|$(Platform)'=='Release_CpuOnly|x64'">..\..\common\include;..\..\Math</AdditionalIncludeDirectories>
|
||||
<BrowseInformation Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</BrowseInformation>
|
||||
<BrowseInformation Condition="'$(Configuration)|$(Platform)'=='Release_NoOpt|x64'">true</BrowseInformation>
|
||||
</ClCompile>
|
||||
<Link>
|
||||
<SubSystem>Console</SubSystem>
|
||||
|
@ -107,10 +105,6 @@
|
|||
<AdditionalLibraryDirectories Condition="'$(Configuration)|$(Platform)'=='Release_NoOpt|x64'">$(SolutionDir)$(Platform)\$(Configuration)\</AdditionalLibraryDirectories>
|
||||
<AdditionalLibraryDirectories Condition="'$(Configuration)|$(Platform)'=='Release_CpuOnly|x64'">$(SolutionDir)$(Platform)\$(Configuration)\</AdditionalLibraryDirectories>
|
||||
</Link>
|
||||
<Bscmake>
|
||||
<PreserveSbr Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</PreserveSbr>
|
||||
<PreserveSbr Condition="'$(Configuration)|$(Platform)'=='Release_NoOpt|x64'">true</PreserveSbr>
|
||||
</Bscmake>
|
||||
</ItemDefinitionGroup>
|
||||
<ItemGroup>
|
||||
<Text Include="ReadMe.txt" />
|
||||
|
|
|
@ -97,18 +97,12 @@
|
|||
<IntrinsicFunctions>true</IntrinsicFunctions>
|
||||
<PreprocessorDefinitions>NDEBUG;%(PreprocessorDefinitions)</PreprocessorDefinitions>
|
||||
<AdditionalOptions>/d2Zi+ %(AdditionalOptions)</AdditionalOptions>
|
||||
<BrowseInformation Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</BrowseInformation>
|
||||
<BrowseInformation Condition="'$(Configuration)|$(Platform)'=='Release_NoOpt|x64'">true</BrowseInformation>
|
||||
</ClCompile>
|
||||
<Link>
|
||||
<EnableCOMDATFolding>true</EnableCOMDATFolding>
|
||||
<OptimizeReferences>true</OptimizeReferences>
|
||||
<Profile>true</Profile>
|
||||
</Link>
|
||||
<Bscmake>
|
||||
<PreserveSbr Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</PreserveSbr>
|
||||
<PreserveSbr Condition="'$(Configuration)|$(Platform)'=='Release_NoOpt|x64'">true</PreserveSbr>
|
||||
</Bscmake>
|
||||
</ItemDefinitionGroup>
|
||||
<ItemGroup>
|
||||
<ClInclude Include="..\..\Common\Include\DataReader.h" />
|
||||
|
|
|
@ -88,8 +88,6 @@
|
|||
<OpenMPSupport>false</OpenMPSupport>
|
||||
<AdditionalOptions>/d2Zi+ %(AdditionalOptions)</AdditionalOptions>
|
||||
<TreatWarningAsError>true</TreatWarningAsError>
|
||||
<BrowseInformation Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</BrowseInformation>
|
||||
<BrowseInformation Condition="'$(Configuration)|$(Platform)'=='Release_NoOpt|x64'">true</BrowseInformation>
|
||||
</ClCompile>
|
||||
<Link>
|
||||
<SubSystem>Console</SubSystem>
|
||||
|
@ -99,10 +97,6 @@
|
|||
<AdditionalDependencies>Math.lib;Common.lib;kernel32.lib;user32.lib;gdi32.lib;winspool.lib;comdlg32.lib;advapi32.lib;shell32.lib;ole32.lib;oleaut32.lib;uuid.lib;odbc32.lib;odbccp32.lib;%(AdditionalDependencies)</AdditionalDependencies>
|
||||
<Profile>true</Profile>
|
||||
</Link>
|
||||
<Bscmake>
|
||||
<PreserveSbr Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</PreserveSbr>
|
||||
<PreserveSbr Condition="'$(Configuration)|$(Platform)'=='Release_NoOpt|x64'">true</PreserveSbr>
|
||||
</Bscmake>
|
||||
</ItemDefinitionGroup>
|
||||
<ItemGroup>
|
||||
<ClInclude Include="..\..\Common\Include\DataReader.h" />
|
||||
|
|
|
@ -91,8 +91,6 @@
|
|||
<OpenMPSupport>false</OpenMPSupport>
|
||||
<AdditionalOptions>/d2Zi+ %(AdditionalOptions)</AdditionalOptions>
|
||||
<TreatWarningAsError>true</TreatWarningAsError>
|
||||
<BrowseInformation Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</BrowseInformation>
|
||||
<BrowseInformation Condition="'$(Configuration)|$(Platform)'=='Release_NoOpt|x64'">true</BrowseInformation>
|
||||
</ClCompile>
|
||||
<Link>
|
||||
<SubSystem>Console</SubSystem>
|
||||
|
@ -102,10 +100,6 @@
|
|||
<AdditionalDependencies>Math.lib;Common.lib;kernel32.lib;user32.lib;gdi32.lib;winspool.lib;comdlg32.lib;advapi32.lib;shell32.lib;ole32.lib;oleaut32.lib;uuid.lib;odbc32.lib;odbccp32.lib;%(AdditionalDependencies)</AdditionalDependencies>
|
||||
<Profile>true</Profile>
|
||||
</Link>
|
||||
<Bscmake>
|
||||
<PreserveSbr Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</PreserveSbr>
|
||||
<PreserveSbr Condition="'$(Configuration)|$(Platform)'=='Release_NoOpt|x64'">true</PreserveSbr>
|
||||
</Bscmake>
|
||||
</ItemDefinitionGroup>
|
||||
<ItemGroup>
|
||||
<ClInclude Include="..\..\Common\Include\DataReader.h" />
|
||||
|
|
|
@ -88,8 +88,6 @@
|
|||
<OpenMPSupport>false</OpenMPSupport>
|
||||
<AdditionalOptions>/d2Zi+ %(AdditionalOptions)</AdditionalOptions>
|
||||
<TreatWarningAsError>true</TreatWarningAsError>
|
||||
<BrowseInformation Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</BrowseInformation>
|
||||
<BrowseInformation Condition="'$(Configuration)|$(Platform)'=='Release_NoOpt|x64'">true</BrowseInformation>
|
||||
</ClCompile>
|
||||
<Link>
|
||||
<SubSystem>Console</SubSystem>
|
||||
|
@ -99,10 +97,6 @@
|
|||
<AdditionalDependencies>Math.lib;Common.lib;kernel32.lib;user32.lib;gdi32.lib;winspool.lib;comdlg32.lib;advapi32.lib;shell32.lib;ole32.lib;oleaut32.lib;uuid.lib;odbc32.lib;odbccp32.lib;%(AdditionalDependencies)</AdditionalDependencies>
|
||||
<Profile>true</Profile>
|
||||
</Link>
|
||||
<Bscmake>
|
||||
<PreserveSbr Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</PreserveSbr>
|
||||
<PreserveSbr Condition="'$(Configuration)|$(Platform)'=='Release_NoOpt|x64'">true</PreserveSbr>
|
||||
</Bscmake>
|
||||
</ItemDefinitionGroup>
|
||||
<ItemGroup>
|
||||
<ClInclude Include="..\..\Common\Include\DataReader.h" />
|
||||
|
|
|
@ -41,13 +41,7 @@
|
|||
<ItemDefinitionGroup>
|
||||
<ClCompile>
|
||||
<AdditionalIncludeDirectories>$(SolutionDir)Source\Common\Include;$(SolutionDir)Source\Math</AdditionalIncludeDirectories>
|
||||
<BrowseInformation Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</BrowseInformation>
|
||||
<BrowseInformation Condition="'$(Configuration)|$(Platform)'=='Release_NoOpt|x64'">true</BrowseInformation>
|
||||
</ClCompile>
|
||||
<Bscmake>
|
||||
<PreserveSbr Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</PreserveSbr>
|
||||
<PreserveSbr Condition="'$(Configuration)|$(Platform)'=='Release_NoOpt|x64'">true</PreserveSbr>
|
||||
</Bscmake>
|
||||
</ItemDefinitionGroup>
|
||||
<ItemGroup>
|
||||
<ClInclude Include="ConfigUtil.h" />
|
||||
|
|
|
@ -91,8 +91,6 @@
|
|||
<OpenMPSupport>false</OpenMPSupport>
|
||||
<AdditionalOptions>/d2Zi+ %(AdditionalOptions)</AdditionalOptions>
|
||||
<TreatWarningAsError>true</TreatWarningAsError>
|
||||
<BrowseInformation Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</BrowseInformation>
|
||||
<BrowseInformation Condition="'$(Configuration)|$(Platform)'=='Release_NoOpt|x64'">true</BrowseInformation>
|
||||
</ClCompile>
|
||||
<Link>
|
||||
<SubSystem>Console</SubSystem>
|
||||
|
@ -102,10 +100,6 @@
|
|||
<AdditionalDependencies>Math.lib;Common.lib;kernel32.lib;user32.lib;gdi32.lib;winspool.lib;comdlg32.lib;advapi32.lib;shell32.lib;ole32.lib;oleaut32.lib;uuid.lib;odbc32.lib;odbccp32.lib;%(AdditionalDependencies)</AdditionalDependencies>
|
||||
<Profile>true</Profile>
|
||||
</Link>
|
||||
<Bscmake>
|
||||
<PreserveSbr Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</PreserveSbr>
|
||||
<PreserveSbr Condition="'$(Configuration)|$(Platform)'=='Release_NoOpt|x64'">true</PreserveSbr>
|
||||
</Bscmake>
|
||||
</ItemDefinitionGroup>
|
||||
<ItemGroup>
|
||||
<ClInclude Include="..\..\Common\Include\DataReader.h" />
|
||||
|
|
|
@ -90,8 +90,6 @@
|
|||
<OpenMPSupport>false</OpenMPSupport>
|
||||
<AdditionalOptions>/d2Zi+ %(AdditionalOptions)</AdditionalOptions>
|
||||
<TreatWarningAsError>true</TreatWarningAsError>
|
||||
<BrowseInformation Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</BrowseInformation>
|
||||
<BrowseInformation Condition="'$(Configuration)|$(Platform)'=='Release_NoOpt|x64'">true</BrowseInformation>
|
||||
</ClCompile>
|
||||
<Link>
|
||||
<SubSystem>Console</SubSystem>
|
||||
|
@ -101,10 +99,6 @@
|
|||
<AdditionalDependencies>Math.lib;Common.lib;kernel32.lib;user32.lib;gdi32.lib;winspool.lib;comdlg32.lib;advapi32.lib;shell32.lib;ole32.lib;oleaut32.lib;uuid.lib;odbc32.lib;odbccp32.lib;%(AdditionalDependencies)</AdditionalDependencies>
|
||||
<Profile>true</Profile>
|
||||
</Link>
|
||||
<Bscmake>
|
||||
<PreserveSbr Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</PreserveSbr>
|
||||
<PreserveSbr Condition="'$(Configuration)|$(Platform)'=='Release_NoOpt|x64'">true</PreserveSbr>
|
||||
</Bscmake>
|
||||
</ItemDefinitionGroup>
|
||||
<ItemGroup>
|
||||
<ClInclude Include="..\..\Common\Include\DataReader.h" />
|
||||
|
|
|
@ -0,0 +1,161 @@
|
|||
//
|
||||
// Copyright (c) Microsoft. All rights reserved.
|
||||
// Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
|
||||
//
|
||||
// PostComputingActions.cpp -- CNTK post statistics related actions
|
||||
//
|
||||
|
||||
#include "PostComputingActions.h"
|
||||
|
||||
#include "TrainingNodes.h"
|
||||
#include "ProgressTracing.h"
|
||||
#include "DataReaderHelpers.h"
|
||||
#include "SimpleDistGradAggregator.h"
|
||||
|
||||
#include <vector>
|
||||
|
||||
namespace Microsoft { namespace MSR{ namespace CNTK {
|
||||
|
||||
template <class ElemType>
|
||||
void PostComputingActions<ElemType>::BatchNormalizationStatistics(IDataReader * dataReader, const vector<wstring>& evalNodeNames,
|
||||
const wstring newModelPath, const size_t mbSize, const int iters)
|
||||
{
|
||||
// The mean and variance of the BN nodes are modified while the statistics are collected,
|
||||
// so the network is run in training mode. There is no backprop, so all other parameters
|
||||
// stay fixed during the computation.
|
||||
ScopedNetworkOperationMode modeGuard(m_net, NetworkOperationMode::training);
|
||||
|
||||
// BN nodes need to be processed from bottom to top, i.e. in evaluation order
|
||||
let evalNodes = m_net->GetEvalNodesWithName(evalNodeNames);
|
||||
|
||||
// find all the BN nodes by evalOrder
|
||||
std::vector<ComputationNodeBasePtr> bnNodes;
|
||||
std::set<ComputationNodeBasePtr> bnNodesLogged; // (avoid double record of batch normalization nodes)
|
||||
for (auto& evalNode : evalNodes)
|
||||
{
|
||||
for (auto& node : m_net->GetEvalOrder(evalNode))
|
||||
{
|
||||
let bnNode = dynamic_pointer_cast<BatchNormalizationNode<ElemType>>(node);
|
||||
if (bnNode)
|
||||
{
|
||||
if (bnNodesLogged.insert(node).second)
|
||||
{
|
||||
// reset the statistics states of bn nodes
|
||||
bnNode->ResetStatisticsState();
|
||||
bnNode->SetNormalizationTimeConstants(-1, bnNode->NormalizationTimeConstant(),
|
||||
0, bnNode->BlendTimeConstant());
|
||||
bnNodes.push_back(node);
|
||||
// add the BN nodes to the evaluation group, so that they become root nodes when
|
||||
// the network is re-compiled
|
||||
m_net->AddToNodeGroup(L"evaluation", bnNode);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// re-compile the network to add bn nodes as rootNodes.
|
||||
m_net->CompileNetwork();
|
||||
|
||||
// allocate memory for all bnNodes evalOrder
|
||||
m_net->AllocateAllMatrices(bnNodes, std::vector<ComputationNodeBasePtr>(), nullptr);
|
||||
|
||||
// prepare features
|
||||
auto& featureNodes = m_net->FeatureNodes();
|
||||
|
||||
StreamMinibatchInputs inputMatrices;
|
||||
for (auto& node : featureNodes)
|
||||
inputMatrices.AddInput(node->NodeName(), node->ValuePtr(), node->GetMBLayout(), node->GetSampleLayout());
|
||||
|
||||
bool useParallelTrain = (m_mpi != nullptr);
|
||||
bool useDistributedMBReading = useParallelTrain && m_enableDistributedMBReading && dataReader->SupportsDistributedMBRead();
|
||||
size_t totalEpochSize = bnNodes.size() * mbSize * iters;
|
||||
|
||||
m_net->StartEvaluateMinibatchLoop(bnNodes);
|
||||
|
||||
if (useDistributedMBReading)
|
||||
dataReader->StartDistributedMinibatchLoop(mbSize, 0, m_mpi->CurrentNodeRank(), m_mpi->NumNodesInUse(), inputMatrices.GetStreamDescriptions(), totalEpochSize);
|
||||
else
|
||||
dataReader->StartMinibatchLoop(mbSize, 0, inputMatrices.GetStreamDescriptions(), totalEpochSize);
|
||||
|
||||
for (auto& node : bnNodes)
|
||||
{
|
||||
let bnNode = static_pointer_cast<BatchNormalizationNode<ElemType>>(node);
|
||||
size_t actualMBSize = 0;
|
||||
|
||||
LOGPRINTF(stderr, "Estimating Statistics --> %ls\n", bnNode->GetName().c_str());
|
||||
|
||||
|
||||
// for each BN node, the statistics are the mean and variance averaged over several forward-prop passes;
|
||||
// each forward prop runs from the features up to the current BN node
|
||||
for (int iter = 0; iter < iters; iter++)
|
||||
{
|
||||
// while collecting the BN statistics, every minibatch read must succeed
|
||||
bool wasDataRead = DataReaderHelpers::GetMinibatchIntoNetwork<ElemType>(*dataReader, m_net,
|
||||
nullptr, useDistributedMBReading, useParallelTrain, inputMatrices, actualMBSize, m_mpi);
|
||||
|
||||
if (!wasDataRead) LogicError("DataRead Failure in batch normalization statistics");
|
||||
|
||||
ComputationNetwork::BumpEvalTimeStamp(featureNodes);
|
||||
|
||||
// forward prop till reaching the current bn node
|
||||
m_net->ForwardProp(node);
|
||||
}
|
||||
|
||||
// once the statistics are finished, the mean and variance of the BN node should be frozen.
|
||||
bnNode->FreezeParameters();
|
||||
|
||||
// Syncing during or after all iterations of a BN node is equivalent
|
||||
if (useParallelTrain)
|
||||
{
|
||||
if (m_gradHeader == nullptr)
|
||||
{
|
||||
m_gradHeader.reset(DistGradHeader::Create(evalNodes.size()), [](DistGradHeader* ptr)
|
||||
{
|
||||
DistGradHeader::Destroy(ptr);
|
||||
});
|
||||
}
|
||||
|
||||
// put the estimated mean and variance of the BN node into the vector that is aggregated over MPI
|
||||
std::vector<Matrix<ElemType>*> learnParamsValues(2, nullptr);
|
||||
|
||||
SimpleDistGradAggregator<ElemType> distGradAgg(m_mpi, false /*useAsyncAggregation*/, 0 /*syncStatsTrace*/);
|
||||
|
||||
auto runMeanParameterPtr = node->Input(3);
|
||||
auto runStdParameterPtr = node->Input(4);
|
||||
|
||||
shared_ptr<ComputationNode<ElemType>> runMeanNode = static_pointer_cast<ComputationNode<ElemType>>(runMeanParameterPtr);
|
||||
shared_ptr<ComputationNode<ElemType>> runStdNode = static_pointer_cast<ComputationNode<ElemType>>(runStdParameterPtr);
|
||||
|
||||
learnParamsValues[0] = &(runMeanNode->Value());
|
||||
learnParamsValues[1] = &(runStdNode->Value());
|
||||
|
||||
m_gradHeader->numSamples = actualMBSize ? 1 : actualMBSize;
|
||||
distGradAgg.AggregateGradients(learnParamsValues, m_gradHeader.get(), 0);
|
||||
|
||||
// get the average mean and variance across all the workers
|
||||
for (auto& parameter : learnParamsValues)
|
||||
{
|
||||
(*parameter) /= (ElemType)m_mpi->NumNodesInUse();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
dataReader->DataEnd();
|
||||
|
||||
// remove all the added BN nodes from evaluation group
|
||||
for (auto& bnNode : bnNodes)
|
||||
{
|
||||
m_net->RemoveFromNodeGroup(L"evaluation", bnNode);
|
||||
}
|
||||
|
||||
// save model
|
||||
if (!useParallelTrain || m_mpi->CurrentNodeRank() == m_mpi->MainNodeRank())
|
||||
m_net->Save(newModelPath);
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
template class PostComputingActions<float>;
|
||||
template class PostComputingActions<double>;
|
||||
|
||||
}}}
|
|
@ -0,0 +1,65 @@
|
|||
//
|
||||
// Copyright (c) Microsoft. All rights reserved.
|
||||
// Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
|
||||
//
|
||||
// PostComputingActions.h -- CNTK post statistics related actions
|
||||
//
|
||||
|
||||
#pragma once
|
||||
#include "ComputationNode.h"
|
||||
#include "ComputationNetwork.h"
|
||||
#include "MPIWrapper.h"
|
||||
#include "DataReader.h"
|
||||
#include "IDistGradAggregator.h"
|
||||
#include "DistGradHeader.h"
|
||||
|
||||
using namespace std;
|
||||
|
||||
namespace Microsoft { namespace MSR { namespace CNTK {
|
||||
|
||||
template <class ElemType>
|
||||
class IDistGradAggregator;
|
||||
|
||||
// Post statistics are normally computed between training and evaluation, to generate the statistics used by evaluation.
|
||||
// For now, the only application is estimating the mean and variance of Batch Normalization nodes after training.
|
||||
template <class ElemType>
|
||||
class PostComputingActions
|
||||
{
|
||||
public:
|
||||
PostComputingActions(ComputationNetworkPtr net, const MPIWrapperPtr& mpi, bool enableDistributedMBReading = false, const int traceLevel = 0) :
|
||||
m_net(net),
|
||||
m_traceLevel(traceLevel),
|
||||
m_mpi(mpi),
|
||||
m_distGradAgg(nullptr),
|
||||
m_gradHeader(nullptr),
|
||||
m_enableDistributedMBReading(enableDistributedMBReading)
|
||||
{
|
||||
}
|
||||
|
||||
// This function is used for estimating the mean and variance of all batch normalization nodes after training.
|
||||
// Details are on the wiki: https://github.com/Microsoft/CNTK/wiki/Post-Batch-Normalization-Statistics
|
||||
// It is grouped with evaluation because it takes place after training and involves no backprop, which makes
|
||||
// this function a kind of evaluate function.
|
||||
// In this function:
|
||||
// 1. All weights other than the BN statistics are fixed, but the network operation mode is set to training so that the statistics can be updated.
|
||||
// 2. The network model and data source are handled as in the Evaluate function, minus the parts that
|
||||
// seem useless here, like error statistics etc.
|
||||
// 3. The BN nodes in the network are found and put into a vector in evaluation order (this relates to the nestedNode vector obtained in
|
||||
// ControlFlowNetwork).
|
||||
// 4. The mean and variance are generated node by node over the BN vector (this relates to the changes of BatchNormalizationNode
|
||||
// in TrainingNodes.h, since the nodes need to "learn" mean and variance from forward prop alone).
|
||||
// 5. With multiple GPUs, the BN results are synced across all the workers and averaged.
|
||||
void BatchNormalizationStatistics(IDataReader* dataReader, const vector<wstring>& evalNodeNames, const wstring newModelPath,
|
||||
const size_t mbSize, const int iters = 30);
|
||||
|
||||
private:
|
||||
ComputationNetworkPtr m_net;
|
||||
MPIWrapperPtr m_mpi;
|
||||
bool m_enableDistributedMBReading;
|
||||
|
||||
int m_traceLevel;
|
||||
|
||||
std::shared_ptr<IDistGradAggregator<ElemType>> m_distGradAgg;
|
||||
std::shared_ptr<struct DistGradHeader> m_gradHeader;
|
||||
};
|
||||
}}}
|
|
@ -8,6 +8,7 @@
|
|||
#include "SpecialPurposeNodes.h" // for SequenceWithSoftmaxNode
|
||||
#include "DataReaderHelpers.h"
|
||||
#include "MatrixQuantizerImpl.h"
|
||||
#include "InputAndParamNodes.h"
|
||||
|
||||
#ifdef CNTK_PARALLEL_TRAINING_SUPPORT
|
||||
//static inline bool operator==(const std::pair<double,size_t>& a, double b) { assert(b==0); return a.first == b; }
|
||||
|
@ -962,24 +963,13 @@ size_t SGD<ElemType>::TrainOneEpoch(ComputationNetworkPtr net,
|
|||
EpochCriterion epochCriterionLastLogged = epochCriterion;
|
||||
vector<EpochCriterion> epochEvalErrorsLastLogged = epochEvalErrors;
|
||||
|
||||
// Now, we need to use a switch to enable/disable wk in BatchNormalization.
|
||||
// If we can determine whether wk added or not for each node, then, discard this
|
||||
// TODO: Define "wk" and say what this is for and in which context it is used.
|
||||
std::unordered_set<ComputationNodeBasePtr> batchNormalizationWeights;
|
||||
if (m_disableWkInBatchNormal) {
|
||||
for (auto& evalNode : evaluationNodes)
|
||||
// NOTE: For ResNet, the regularization in BatchNormalization should be disabled.
|
||||
if (m_disableRegInBatchNormalization) {
|
||||
let bnNodes = net->GetNodesWithType(L"BatchNormalization");
|
||||
for (auto &node : bnNodes)
|
||||
{
|
||||
shared_ptr<FlowControlNode> nestedNetwork = static_pointer_cast<FlowControlNode>(net->GetNestedNetwork(evalNode));
|
||||
for (auto& node : nestedNetwork->GetNestedNodes())
|
||||
{
|
||||
shared_ptr<BatchNormalizationNode<ElemType>> castNode =
|
||||
dynamic_pointer_cast<BatchNormalizationNode<ElemType>>(node);
|
||||
if (castNode)
|
||||
{
|
||||
batchNormalizationWeights.insert(castNode->GetInputs()[1]);
|
||||
batchNormalizationWeights.insert(castNode->GetInputs()[2]);
|
||||
}
|
||||
}
|
||||
let bnNode = dynamic_pointer_cast<BatchNormalizationNode<ElemType>>(node);
|
||||
bnNode->DisableRegInBatchNormalization();
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1207,8 +1197,8 @@ size_t SGD<ElemType>::TrainOneEpoch(ComputationNetworkPtr net,
|
|||
LogicError("%ls %ls operation has NaNs in smoothedGradient.", node->NodeName().c_str(), node->OperationName().c_str());
|
||||
#endif
|
||||
double nodeDependentLearningRatePerSample = learnRatePerSample * node->GetLearningRateMultiplier();
|
||||
double nodeDependentRegMultiplier = dynamic_pointer_cast<LearnableParameter<ElemType>>(node)->GetRegMultiplier();
|
||||
double momentumPerSample = GetMomentumPerSample(epochNumber /*BUGBUG workaround:*/, net->GetMBLayoutPtrOfNetwork()->GetNumParallelSequences());
|
||||
double l2Factor = batchNormalizationWeights.find(node) == batchNormalizationWeights.end() ? 1.0 : 0.0;
|
||||
// TODO: Check why l2Factor is not applied to L1. Bug?
|
||||
// BUGBUG (Issue #95): Access to net MBLayout can no longer be done if we have multiple input layouts
|
||||
UpdateWeights(dynamic_pointer_cast<ComputationNode<ElemType>>(node)->Value(),
|
||||
|
@ -1216,7 +1206,7 @@ size_t SGD<ElemType>::TrainOneEpoch(ComputationNetworkPtr net,
|
|||
*smoothedGradientIter, *smoothedCountIter,
|
||||
nodeDependentLearningRatePerSample, momentumPerSample,
|
||||
numSamplesInMinibatch,
|
||||
m_L2RegWeight * l2Factor, m_L1RegWeight,
|
||||
m_L2RegWeight * nodeDependentRegMultiplier, m_L1RegWeight * nodeDependentRegMultiplier,
|
||||
m_needAveMultiplier, m_useNesterovMomentum);
|
||||
node->BumpEvalTimeStamp();
|
||||
#ifdef _DEBUG
|
||||
|
@ -1999,7 +1989,7 @@ void SGD<ElemType>::AttemptUtteranceDerivativeFeatures(ComputationNetworkPtr net
|
|||
|
||||
template <class ElemType>
|
||||
void SGD<ElemType>::InitDistGradAgg(int numEvalNodes, int numGradientBits, int traceLevel)
|
||||
{
|
||||
{
|
||||
assert(GetParallelizationMethod() == ParallelizationMethod::dataParallelSGD);
|
||||
if (traceLevel > 0)
|
||||
fprintf(stderr, "Initializing dataParallelSGD for %d-bit quantization.\n", numGradientBits);
|
||||
|
@ -2141,7 +2131,6 @@ void SGD<ElemType>::UpdateWeights(Matrix<ElemType>& functionValues, Matrix<ElemT
|
|||
}
|
||||
|
||||
// protected:
|
||||
|
||||
template <class ElemType>
|
||||
void SGD<ElemType>::ClipGradient(Matrix<ElemType>& gradient, const size_t actualMBSize) const
|
||||
{
|
||||
|
@ -2621,8 +2610,7 @@ SGDParams::SGDParams(const ConfigRecordType& configSGD, size_t sizeofElemType)
|
|||
m_seqGammarCalcLMF = configSGD(L"seqGammarLMF", 14.0);
|
||||
m_seqGammarCalcbMMIFactor = configSGD(L"seqGammarBMMIFactor", 0.0);
|
||||
m_seqGammarCalcWP = configSGD(L"seqGammarWordPen", 0.0);
|
||||
|
||||
m_disableWkInBatchNormal = configSGD(L"disableWkInBatchNormal", false);
|
||||
m_disableRegInBatchNormalization = configSGD(L"disableRegInBatchNormalization", false);
|
||||
|
||||
m_dropoutRates = configSGD(L"dropoutRate", ConfigRecordType::Array(doubleargvector(vector<double>{0.0})));
|
||||
m_batchNormalizationTimeConstant = configSGD(L"batchNormalizationTimeConstant", ConfigRecordType::Array(doubleargvector(vector<double>{0})));
|
||||
|
|
|
@ -307,7 +307,10 @@ protected:
|
|||
double m_seqGammarCalcbMMIFactor;
|
||||
bool m_seqGammarCalcUsesMBR;
|
||||
|
||||
bool m_disableWkInBatchNormal; // TODO: comment?
|
||||
// decide whether should apply regularization into BatchNormalizationNode
|
||||
// true: disable Regularization
|
||||
// false: enable Regularization (default)
|
||||
bool m_disableRegInBatchNormalization;
|
||||
};
|
||||
|
||||
template <class ElemType>
|
||||
|
|
|
@ -87,8 +87,6 @@
|
|||
<ItemDefinitionGroup Condition="$(GpuBuild)">
|
||||
<ClCompile>
|
||||
<AdditionalIncludeDirectories>%(AdditionalIncludeDirectories);$(CudaInclude)</AdditionalIncludeDirectories>
|
||||
<BrowseInformation Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</BrowseInformation>
|
||||
<BrowseInformation Condition="'$(Configuration)|$(Platform)'=='Release_NoOpt|x64'">true</BrowseInformation>
|
||||
</ClCompile>
|
||||
<Link>
|
||||
<AdditionalLibraryDirectories>%(AdditionalLibraryDirectories);$(CudaLibPath)</AdditionalLibraryDirectories>
|
||||
|
@ -98,10 +96,6 @@
|
|||
<Command>if exist "%ProgramW6432%\NVIDIA Corporation\NVSMI" xcopy /I /D /Y "%ProgramW6432%\NVIDIA Corporation\NVSMI\nvml*.dll" "$(TargetDir)"</Command>
|
||||
<Message>Copying NVidia GDK extension DLL to target folder</Message>
|
||||
</PostBuildEvent>
|
||||
<Bscmake>
|
||||
<PreserveSbr Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</PreserveSbr>
|
||||
<PreserveSbr Condition="'$(Configuration)|$(Platform)'=='Release_NoOpt|x64'">true</PreserveSbr>
|
||||
</Bscmake>
|
||||
</ItemDefinitionGroup>
|
||||
<ItemGroup>
|
||||
<ClInclude Include="..\Common\CrossProcessMutex.h" />
|
||||
|
@ -139,6 +133,7 @@
|
|||
<ClInclude Include="..\ComputationNetworkLib\NonlinearityNodes.h" />
|
||||
<ClInclude Include="..\ComputationNetworkLib\RecurrentNodes.h" />
|
||||
<ClInclude Include="MASGD.h" />
|
||||
<ClInclude Include="PostComputingActions.h" />
|
||||
<ClInclude Include="SimpleDistGradAggregator.h" />
|
||||
<ClInclude Include="SimpleEvaluator.h" />
|
||||
<ClInclude Include="SimpleOutputWriter.h" />
|
||||
|
@ -147,6 +142,7 @@
|
|||
<ClInclude Include="targetver.h" />
|
||||
</ItemGroup>
|
||||
<ItemGroup>
|
||||
<ClCompile Include="PostComputingActions.cpp" />
|
||||
<ClCompile Include="Profiler.cpp" />
|
||||
<ClCompile Include="SGD.cpp" />
|
||||
<ClCompile Include="stdafx.cpp" />
|
||||
|
|
|
@ -10,6 +10,9 @@
|
|||
<ClCompile Include="SGD.cpp">
|
||||
<Filter>SGD</Filter>
|
||||
</ClCompile>
|
||||
<ClCompile Include="PostComputingActions.cpp">
|
||||
<Filter>Stat</Filter>
|
||||
</ClCompile>
|
||||
</ItemGroup>
|
||||
<ItemGroup>
|
||||
<ClInclude Include="..\Common\Include\fileutil.h">
|
||||
|
@ -131,6 +134,9 @@
|
|||
</ClInclude>
|
||||
<ClInclude Include="Criterion.h">
|
||||
<Filter>SGD</Filter>
|
||||
</ClInclude>
|
||||
<ClInclude Include="PostComputingActions.h">
|
||||
<Filter>Stat</Filter>
|
||||
</ClInclude>
|
||||
<ClInclude Include="..\Common\Include\ASGDCommon.h">
|
||||
<Filter>Parallelization</Filter>
|
||||
|
@ -173,5 +179,8 @@
|
|||
<Filter Include="Data Reading">
|
||||
<UniqueIdentifier>{b866d513-7bd0-497c-98c2-f62dbcd4cde4}</UniqueIdentifier>
|
||||
</Filter>
|
||||
<Filter Include="Stat">
|
||||
<UniqueIdentifier>{f406217f-5a11-44ca-bb34-52254dbee8af}</UniqueIdentifier>
|
||||
</Filter>
|
||||
</ItemGroup>
|
||||
</Project>
|
|
@ -52,36 +52,7 @@ public:
|
|||
{
|
||||
ScopedNetworkOperationMode modeGuard(m_net, NetworkOperationMode::inferring);
|
||||
|
||||
// determine nodes to evaluate
|
||||
std::vector<ComputationNodeBasePtr> evalNodes;
|
||||
|
||||
set<ComputationNodeBasePtr> criteriaLogged; // (keeps track ot duplicates to avoid we don't double-log critera)
|
||||
if (evalNodeNames.size() == 0)
|
||||
{
|
||||
fprintf(stderr, "evalNodeNames are not specified, using all the default evalnodes and training criterion nodes.\n");
|
||||
if (m_net->EvaluationNodes().empty() && m_net->FinalCriterionNodes().empty())
|
||||
InvalidArgument("There is no default evaluation node or training criterion specified in the network.");
|
||||
|
||||
for (const auto& node : m_net->EvaluationNodes())
|
||||
if (criteriaLogged.insert(node).second)
|
||||
evalNodes.push_back(node);
|
||||
|
||||
for (const auto& node : m_net->FinalCriterionNodes())
|
||||
if (criteriaLogged.insert(node).second)
|
||||
evalNodes.push_back(node);
|
||||
}
|
||||
else
|
||||
{
|
||||
for (int i = 0; i < evalNodeNames.size(); i++)
|
||||
{
|
||||
const auto& node = m_net->GetNodeFromName(evalNodeNames[i]);
|
||||
if (!criteriaLogged.insert(node).second)
|
||||
continue;
|
||||
if (node->GetSampleLayout().GetNumElements() != 1)
|
||||
InvalidArgument("Criterion nodes to evaluate must have dimension 1x1.");
|
||||
evalNodes.push_back(node);
|
||||
}
|
||||
}
|
||||
let evalNodes = m_net->GetEvalNodesWithName(evalNodeNames);
|
||||
|
||||
// initialize eval results
|
||||
std::vector<EpochCriterion> evalResults(evalNodes.size(), EpochCriterion(0));
|
||||
|
@ -257,154 +228,6 @@ public:
|
|||
return evalResults;
|
||||
}
|
||||
|
||||
// TODO: remove code dup w.r.t. Evaluate()
|
||||
void EvaluateBN(IDataReader* dataReader, const vector<wstring>& evalNodeNames, const wstring exportPath, const size_t mbSize, const int iters = 240, const size_t testSize = requestDataSize)
|
||||
{
|
||||
ScopedNetworkOperationMode modeGuard(m_net, NetworkOperationMode::inferring);
|
||||
|
||||
// determine nodes to evaluate
|
||||
std::vector<ComputationNodeBasePtr> evalNodes;
|
||||
|
||||
set<ComputationNodeBasePtr> criteriaLogged; // (keeps track ot duplicates to avoid we don't double-log critera)
|
||||
if (evalNodeNames.size() == 0)
|
||||
{
|
||||
fprintf(stderr, "evalNodeNames are not specified, using all the default evalnodes and training criterion nodes.\n");
|
||||
if (m_net->EvaluationNodes().empty() && m_net->FinalCriterionNodes().empty())
|
||||
InvalidArgument("There is no default evaluation node or training criterion specified in the network.");
|
||||
|
||||
for (const auto& node : m_net->EvaluationNodes())
|
||||
if (criteriaLogged.insert(node).second)
|
||||
evalNodes.push_back(node);
|
||||
|
||||
for (const auto& node : m_net->FinalCriterionNodes())
|
||||
if (criteriaLogged.insert(node).second)
|
||||
evalNodes.push_back(node);
|
||||
}
|
||||
else
|
||||
{
|
||||
for (int i = 0; i < evalNodeNames.size(); i++)
|
||||
{
|
||||
const auto& node = m_net->GetNodeFromName(evalNodeNames[i]);
|
||||
if (!criteriaLogged.insert(node).second)
|
||||
continue;
|
||||
if (node->GetSampleLayout().GetNumElements() != 1)
|
||||
InvalidArgument("Criterion nodes to evaluate must have dimension 1x1.");
|
||||
evalNodes.push_back(node);
|
||||
}
|
||||
}
|
||||
|
||||
// allocate memory for forward computation
|
||||
m_net->AllocateAllMatrices(evalNodes, {}, nullptr);
|
||||
|
||||
// prepare features and labels
|
||||
auto& featureNodes = m_net->FeatureNodes();
|
||||
auto& labelNodes = m_net->LabelNodes();
|
||||
|
||||
StreamMinibatchInputs inputMatrices;
|
||||
for (auto& node : featureNodes)
|
||||
inputMatrices.AddInput(node->NodeName(), node->ValuePtr(), node->GetMBLayout(), node->GetSampleLayout());
|
||||
for (auto& node : labelNodes)
|
||||
inputMatrices.AddInput(node->NodeName(), node->ValuePtr(), node->GetMBLayout(), node->GetSampleLayout());
|
||||
|
||||
bool useParallelTrain = (m_mpi != nullptr);
|
||||
bool useDistributedMBReading = useParallelTrain && m_enableDistributedMBReading && dataReader->SupportsDistributedMBRead();
|
||||
if (useDistributedMBReading)
|
||||
dataReader->StartDistributedMinibatchLoop(mbSize, 0, m_mpi->CurrentNodeRank(), m_mpi->NumNodesInUse(), testSize);
|
||||
else
|
||||
dataReader->StartMinibatchLoop(mbSize, 0, testSize);
|
||||
|
||||
m_net->StartEvaluateMinibatchLoop(evalNodes);
|
||||
|
||||
// Passing in two empty node lists so the dispatcher can work for the evalNodes.
|
||||
std::list<ComputationNodeBasePtr> learnableNodes;
|
||||
std::vector<ComputationNodeBasePtr> criterionNodes;
|
||||
|
||||
// First, all batch normalization nodes should be marked.
|
||||
std::vector<ComputationNodeBasePtr> batchNormalNodes;
|
||||
shared_ptr<FlowControlNode> nestedNetwork = static_pointer_cast<FlowControlNode>(m_net->GetNestedNetwork(evalNodes[0]));
|
||||
for (auto& node : nestedNetwork->GetNestedNodes())
|
||||
{
|
||||
shared_ptr<BatchNormalizationNode<ElemType>> castNode =
|
||||
dynamic_pointer_cast<BatchNormalizationNode<ElemType>>(node);
|
||||
if (castNode)
|
||||
{
|
||||
batchNormalNodes.push_back(node);
|
||||
}
|
||||
}
|
||||
|
||||
// Push all batch normalization mean and std into learn params values for mpi update
|
||||
std::vector<Matrix<ElemType>*> learnParamsValues(2, nullptr);
|
||||
|
||||
bool noMoreSamplesToProcess = false;
|
||||
for (auto& node : batchNormalNodes)
|
||||
{
|
||||
shared_ptr<BatchNormalizationNode<ElemType>> batchNode =
|
||||
static_pointer_cast<BatchNormalizationNode<ElemType>>(node);
|
||||
batchNode->SetPostBatchNormalizationBegin();
|
||||
size_t actualMBSize = 0;
|
||||
|
||||
LOGPRINTF(stderr, "Start evaluating: %ls\n", batchNode->GetName().c_str());
|
||||
|
||||
// Post batch normal iters
|
||||
for (int iter = 0; iter < iters; iter++)
|
||||
{
|
||||
bool wasDataRead = DataReaderHelpers::GetMinibatchIntoNetwork<ElemType>(*dataReader, m_net,
|
||||
nullptr, useDistributedMBReading, useParallelTrain, inputMatrices, actualMBSize, m_mpi);
|
||||
|
||||
if (!wasDataRead && (!useDistributedMBReading || noMoreSamplesToProcess))
|
||||
break;
|
||||
|
||||
// TODO should handle it, since post BN exist no samples in iters
|
||||
if (!wasDataRead)
|
||||
actualMBSize = 0;
|
||||
|
||||
// Batch Normalization Evaluate don't need to support subMinibatches
|
||||
ComputationNetwork::BumpEvalTimeStamp(featureNodes);
|
||||
ComputationNetwork::BumpEvalTimeStamp(labelNodes);
|
||||
|
||||
m_net->ForwardProp(evalNodes[0], nullptr, node);
|
||||
dataReader->DataEnd();
|
||||
}
|
||||
batchNode->SetPostBatchNormalizationEnd();
|
||||
|
||||
// Sync during or after all iters of a BN node are equivalent
|
||||
if (useParallelTrain)
|
||||
{
|
||||
if (m_gradHeader == nullptr)
|
||||
{
|
||||
m_gradHeader.reset(DistGradHeader::Create(evalNodes.size()), [](DistGradHeader* ptr)
|
||||
{
|
||||
DistGradHeader::Destroy(ptr);
|
||||
});
|
||||
}
|
||||
SimpleDistGradAggregator<ElemType> distGradAgg(m_mpi, false /*useAsyncAggregation*/, 0 /*syncStatsTrace*/);
|
||||
|
||||
auto runMeanParameterPtr = node->GetInputs()[3];
|
||||
auto runStdParameterPtr = node->GetInputs()[4];
|
||||
|
||||
shared_ptr<ComputationNode<ElemType>> runMeanNode = static_pointer_cast<ComputationNode<ElemType>>(runMeanParameterPtr);
|
||||
shared_ptr<ComputationNode<ElemType>> runStdNode = static_pointer_cast<ComputationNode<ElemType>>(runStdParameterPtr);
|
||||
|
||||
learnParamsValues[0] = &(runMeanNode->Value());
|
||||
learnParamsValues[1] = &(runStdNode->Value());
|
||||
|
||||
m_gradHeader->numSamples = actualMBSize ? 1 : actualMBSize;
|
||||
distGradAgg.AggregateGradients(learnParamsValues, m_gradHeader.get(), /*resetState =*/ false);
|
||||
|
||||
for (auto& parameter : learnParamsValues)
|
||||
{
|
||||
(*parameter) /= (ElemType)m_mpi->NumNodesInUse();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Save Model
|
||||
if (!useParallelTrain || m_mpi->CurrentNodeRank() == m_mpi->MainNodeRank())
|
||||
m_net->Save(exportPath);
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
protected:
|
||||
void DisplayEvalStatistics(const size_t startMBNum, const size_t endMBNum, const size_t numSamplesLastLogged,
|
||||
const vector<ComputationNodeBasePtr>& evalNodes,
|
||||
|
|
|
@ -41,13 +41,7 @@
|
|||
<ClCompile>
|
||||
<PreprocessorDefinitions>WIN32;_LIB;%(PreprocessorDefinitions)</PreprocessorDefinitions>
|
||||
<AdditionalIncludeDirectories>$(SolutionDir)Source\Common\Include;$(SolutionDir)Source\Math</AdditionalIncludeDirectories>
|
||||
<BrowseInformation Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</BrowseInformation>
|
||||
<BrowseInformation Condition="'$(Configuration)|$(Platform)'=='Release_NoOpt|x64'">true</BrowseInformation>
|
||||
</ClCompile>
|
||||
<Bscmake>
|
||||
<PreserveSbr Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</PreserveSbr>
|
||||
<PreserveSbr Condition="'$(Configuration)|$(Platform)'=='Release_NoOpt|x64'">true</PreserveSbr>
|
||||
</Bscmake>
|
||||
</ItemDefinitionGroup>
|
||||
<ItemDefinitionGroup Condition="$(CpuOnlyBuild)">
|
||||
<ClCompile>
|
||||
|
|
|
@ -100,8 +100,6 @@
|
|||
<FloatingPointExceptions>false</FloatingPointExceptions>
|
||||
<AdditionalOptions>/d2Zi+ %(AdditionalOptions)</AdditionalOptions>
|
||||
<RuntimeLibrary>MultiThreadedDLL</RuntimeLibrary>
|
||||
<BrowseInformation Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</BrowseInformation>
|
||||
<BrowseInformation Condition="'$(Configuration)|$(Platform)'=='Release_NoOpt|x64'">true</BrowseInformation>
|
||||
</ClCompile>
|
||||
<Link>
|
||||
<EnableCOMDATFolding>true</EnableCOMDATFolding>
|
||||
|
@ -110,10 +108,6 @@
|
|||
<ProjectReference>
|
||||
<LinkLibraryDependencies>true</LinkLibraryDependencies>
|
||||
</ProjectReference>
|
||||
<Bscmake>
|
||||
<PreserveSbr Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</PreserveSbr>
|
||||
<PreserveSbr Condition="'$(Configuration)|$(Platform)'=='Release_NoOpt|x64'">true</PreserveSbr>
|
||||
</Bscmake>
|
||||
</ItemDefinitionGroup>
|
||||
<ItemGroup>
|
||||
<ClCompile Include="..\..\..\..\Examples\Evaluation\CPPEvalClient\CPPEvalClient.cpp" />
|
||||
|
|
|
@ -89,16 +89,10 @@
|
|||
<ItemDefinitionGroup Condition="$(GpuBuild)">
|
||||
<ClCompile>
|
||||
<AdditionalIncludeDirectories>%(AdditionalIncludeDirectories)</AdditionalIncludeDirectories>
|
||||
<BrowseInformation Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</BrowseInformation>
|
||||
<BrowseInformation Condition="'$(Configuration)|$(Platform)'=='Release_NoOpt|x64'">true</BrowseInformation>
|
||||
</ClCompile>
|
||||
<Link>
|
||||
<AdditionalLibraryDirectories>%(AdditionalLibraryDirectories);$(CudaLibPath)</AdditionalLibraryDirectories>
|
||||
</Link>
|
||||
<Bscmake>
|
||||
<PreserveSbr Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</PreserveSbr>
|
||||
<PreserveSbr Condition="'$(Configuration)|$(Platform)'=='Release_NoOpt|x64'">true</PreserveSbr>
|
||||
</Bscmake>
|
||||
</ItemDefinitionGroup>
|
||||
<ItemDefinitionGroup Condition="$(CpuOnlyBuild)">
|
||||
<ClCompile>
|
||||
|
|
|
@ -87,8 +87,6 @@
|
|||
<PreprocessorDefinitions>WIN32;NDEBUG;_CONSOLE;%(PreprocessorDefinitions)</PreprocessorDefinitions>
|
||||
<SDLCheck>true</SDLCheck>
|
||||
<TreatWarningAsError>true</TreatWarningAsError>
|
||||
<BrowseInformation Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</BrowseInformation>
|
||||
<BrowseInformation Condition="'$(Configuration)|$(Platform)'=='Release_NoOpt|x64'">true</BrowseInformation>
|
||||
</ClCompile>
|
||||
<Link>
|
||||
<SubSystem>Console</SubSystem>
|
||||
|
@ -97,10 +95,6 @@
|
|||
<OptimizeReferences>true</OptimizeReferences>
|
||||
<AdditionalDependencies>Math.lib; Common.lib; %(AdditionalDependencies)</AdditionalDependencies>
|
||||
</Link>
|
||||
<Bscmake>
|
||||
<PreserveSbr Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</PreserveSbr>
|
||||
<PreserveSbr Condition="'$(Configuration)|$(Platform)'=='Release_NoOpt|x64'">true</PreserveSbr>
|
||||
</Bscmake>
|
||||
</ItemDefinitionGroup>
|
||||
<ItemDefinitionGroup Condition="$(CpuOnlyBuild)">
|
||||
<ClCompile>
|
||||
|
|
|
@ -109,17 +109,11 @@
|
|||
<ItemDefinitionGroup Condition="$(GpuBuild)">
|
||||
<ClCompile>
|
||||
<AdditionalIncludeDirectories>%(AdditionalIncludeDirectories);$(CudaInclude)</AdditionalIncludeDirectories>
|
||||
<BrowseInformation Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</BrowseInformation>
|
||||
<BrowseInformation Condition="'$(Configuration)|$(Platform)'=='Release_NoOpt|x64'">true</BrowseInformation>
|
||||
</ClCompile>
|
||||
<Link>
|
||||
<AdditionalLibraryDirectories>%(AdditionalLibraryDirectories);$(CudaLibPath)</AdditionalLibraryDirectories>
|
||||
<DelayLoadDLLs>%(DelayLoadDLLs);nvml.dll;$(CudaRuntimeDll)</DelayLoadDLLs>
|
||||
</Link>
|
||||
<Bscmake>
|
||||
<PreserveSbr Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</PreserveSbr>
|
||||
<PreserveSbr Condition="'$(Configuration)|$(Platform)'=='Release_NoOpt|x64'">true</PreserveSbr>
|
||||
</Bscmake>
|
||||
</ItemDefinitionGroup>
|
||||
<ItemDefinitionGroup Condition="$(CpuOnlyBuild)">
|
||||
<ClCompile>
|
||||
|
|
|
@ -105,13 +105,7 @@
|
|||
<ItemDefinitionGroup Condition="$(GpuBuild)">
|
||||
<ClCompile>
|
||||
<AdditionalIncludeDirectories>$(CudaToolkitIncludeDir);%(AdditionalIncludeDirectories)</AdditionalIncludeDirectories>
|
||||
<BrowseInformation Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</BrowseInformation>
|
||||
<BrowseInformation Condition="'$(Configuration)|$(Platform)'=='Release_NoOpt|x64'">true</BrowseInformation>
|
||||
</ClCompile>
|
||||
<Bscmake>
|
||||
<PreserveSbr Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</PreserveSbr>
|
||||
<PreserveSbr Condition="'$(Configuration)|$(Platform)'=='Release_NoOpt|x64'">true</PreserveSbr>
|
||||
</Bscmake>
|
||||
</ItemDefinitionGroup>
|
||||
<ImportGroup Condition="$(GpuBuild)" Label="ExtensionSettings">
|
||||
<Import Project="$(VCTargetsPath)\BuildCustomizations\CUDA $(CudaVersion).props" />
|
||||
|
|
|
@ -109,17 +109,11 @@
|
|||
<ItemDefinitionGroup Condition="$(GpuBuild)">
|
||||
<ClCompile>
|
||||
<AdditionalIncludeDirectories>%(AdditionalIncludeDirectories);$(CudaInclude)</AdditionalIncludeDirectories>
|
||||
<BrowseInformation Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</BrowseInformation>
|
||||
<BrowseInformation Condition="'$(Configuration)|$(Platform)'=='Release_NoOpt|x64'">true</BrowseInformation>
|
||||
</ClCompile>
|
||||
<Link>
|
||||
<AdditionalLibraryDirectories>%(AdditionalLibraryDirectories);$(CudaLibPath)</AdditionalLibraryDirectories>
|
||||
<DelayLoadDLLs>%(DelayLoadDLLs);nvml.dll;$(CudaRuntimeDll)</DelayLoadDLLs>
|
||||
</Link>
|
||||
<Bscmake>
|
||||
<PreserveSbr Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</PreserveSbr>
|
||||
<PreserveSbr Condition="'$(Configuration)|$(Platform)'=='Release_NoOpt|x64'">true</PreserveSbr>
|
||||
</Bscmake>
|
||||
</ItemDefinitionGroup>
|
||||
<ItemDefinitionGroup Condition="$(CpuOnlyBuild)">
|
||||
<ClCompile>
|
||||
|
|
|
@ -90,17 +90,11 @@
|
|||
<ItemDefinitionGroup Condition="$(GpuBuild)">
|
||||
<ClCompile>
|
||||
<AdditionalIncludeDirectories>%(AdditionalIncludeDirectories);$(CudaInclude)</AdditionalIncludeDirectories>
|
||||
<BrowseInformation Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</BrowseInformation>
|
||||
<BrowseInformation Condition="'$(Configuration)|$(Platform)'=='Release_NoOpt|x64'">true</BrowseInformation>
|
||||
</ClCompile>
|
||||
<Link>
|
||||
<AdditionalLibraryDirectories>%(AdditionalLibraryDirectories);$(CudaLibPath)</AdditionalLibraryDirectories>
|
||||
<DelayLoadDLLs>%(DelayLoadDLLs);nvml.dll;$(CudaRuntimeDll)</DelayLoadDLLs>
|
||||
</Link>
|
||||
<Bscmake>
|
||||
<PreserveSbr Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</PreserveSbr>
|
||||
<PreserveSbr Condition="'$(Configuration)|$(Platform)'=='Release_NoOpt|x64'">true</PreserveSbr>
|
||||
</Bscmake>
|
||||
</ItemDefinitionGroup>
|
||||
<ItemDefinitionGroup Condition="$(CpuOnlyBuild)">
|
||||
<ClCompile>
|
||||
|
|
|
@ -93,16 +93,10 @@
|
|||
<PreprocessorDefinitions>NDEBUG;%(PreprocessorDefinitions)</PreprocessorDefinitions>
|
||||
<MultiProcessorCompilation>true</MultiProcessorCompilation>
|
||||
<AdditionalOptions>/d2Zi+ %(AdditionalOptions)</AdditionalOptions>
|
||||
<BrowseInformation Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</BrowseInformation>
|
||||
<BrowseInformation Condition="'$(Configuration)|$(Platform)'=='Release_NoOpt|x64'">true</BrowseInformation>
|
||||
</ClCompile>
|
||||
<Link>
|
||||
<EnableCOMDATFolding>true</EnableCOMDATFolding>
|
||||
</Link>
|
||||
<Bscmake>
|
||||
<PreserveSbr Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</PreserveSbr>
|
||||
<PreserveSbr Condition="'$(Configuration)|$(Platform)'=='Release_NoOpt|x64'">true</PreserveSbr>
|
||||
</Bscmake>
|
||||
</ItemDefinitionGroup>
|
||||
<ItemDefinitionGroup Condition="$(CpuOnlyBuild)">
|
||||
<ClCompile>
|
||||
|
|
|
@ -97,8 +97,6 @@
|
|||
<TreatWarningAsError>true</TreatWarningAsError>
|
||||
<RuntimeLibrary Condition="'$(Configuration)|$(Platform)'=='Release|x64'">MultiThreaded</RuntimeLibrary>
|
||||
<RuntimeLibrary Condition="'$(Configuration)|$(Platform)'=='Release_NoOpt|x64'">MultiThreaded</RuntimeLibrary>
|
||||
<BrowseInformation Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</BrowseInformation>
|
||||
<BrowseInformation Condition="'$(Configuration)|$(Platform)'=='Release_NoOpt|x64'">true</BrowseInformation>
|
||||
</ClCompile>
|
||||
<Link>
|
||||
<SubSystem>Console</SubSystem>
|
||||
|
@ -107,10 +105,6 @@
|
|||
<OptimizeReferences>true</OptimizeReferences>
|
||||
<AdditionalDependencies>CNTKLibrary-2.0.lib;kernel32.lib;user32.lib;gdi32.lib;winspool.lib;comdlg32.lib;advapi32.lib;shell32.lib;ole32.lib;oleaut32.lib;uuid.lib;odbc32.lib;odbccp32.lib;%(AdditionalDependencies)</AdditionalDependencies>
|
||||
</Link>
|
||||
<Bscmake>
|
||||
<PreserveSbr Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</PreserveSbr>
|
||||
<PreserveSbr Condition="'$(Configuration)|$(Platform)'=='Release_NoOpt|x64'">true</PreserveSbr>
|
||||
</Bscmake>
|
||||
</ItemDefinitionGroup>
|
||||
<ItemDefinitionGroup Condition="$(CpuOnlyBuild)">
|
||||
<ClCompile>
|
||||
|
|
Загрузка…
Ссылка в новой задаче