diff --git a/MachineLearning/CNTK/CNTK.cpp b/MachineLearning/CNTK/CNTK.cpp index a025e74ea..6ccbe8854 100644 --- a/MachineLearning/CNTK/CNTK.cpp +++ b/MachineLearning/CNTK/CNTK.cpp @@ -835,12 +835,11 @@ public: template void DoTrain(const ConfigRecordType & config) { - const ConfigRecordType & configSGD(config(L"SGD", ConfigRecordType::Record())); bool makeMode = config(L"makeMode", true); + DEVICEID_TYPE deviceId = DeviceFromConfig(config); - shared_ptr> netBuilder;// = GetCreateNetworkFunction(config); + shared_ptr> netBuilder; - // TODO: turn the netBuilder into a lambda if (config.Exists(L"createNetwork")) { netBuilder = make_shared>(config); @@ -866,6 +865,13 @@ void DoTrain(const ConfigRecordType & config) RuntimeError("No network builder found in the config file. NDLNetworkBuilder or SimpleNetworkBuilde must be specified"); } + // network creation is handled by a lambda, which we define here + function createNetworkFn = [netBuilder](DEVICEID_TYPE deviceId) + { + ComputationNetwork * net = netBuilder->BuildNetworkFromDescription(); + return shared_ptr(net); + }; + // BUGBUG: inconsistency with BrainScript: old config passes a config dict, whereas BrainScript creates the object right away const ConfigRecordType & readerConfig(config(L"reader", ConfigRecordType::Record())); //readerConfig.Insert("traceLevel", config(L"traceLevel", "0")); // TODO: fix this by making this an optional arg; or if this should not be inherited, then by disabling it @@ -879,9 +885,10 @@ void DoTrain(const ConfigRecordType & config) cvDataReader = unique_ptr >{ new DataReader(cvReaderConfig) }; } + const ConfigRecordType & configSGD(config(L"SGD", ConfigRecordType::Record())); SGD sgd(SGDParams(configSGD, (ElemType)0)); - sgd.Train(netBuilder.get(), dataReader.get(), cvDataReader.get(), makeMode); + sgd.Train(createNetworkFn, deviceId, dataReader.get(), cvDataReader.get(), makeMode); } namespace Microsoft { namespace MSR { namespace ScriptableObjects { diff --git a/MachineLearning/CNTKComputationNetworkLib/ComputationNetwork.cpp b/MachineLearning/CNTKComputationNetworkLib/ComputationNetwork.cpp index 0a470d19f..105668259 100644 --- a/MachineLearning/CNTKComputationNetworkLib/ComputationNetwork.cpp +++ b/MachineLearning/CNTKComputationNetworkLib/ComputationNetwork.cpp @@ -398,8 +398,9 @@ namespace Microsoft { namespace MSR { namespace CNTK { fstream.GetMarker(FileMarker::fileMarkerEndSection, L"ECN"); - //some internal values in the nodes are computed during validation + // some internal values in the nodes are computed during validation ValidateNetwork(false, bAllowNoCriterionNode); + ResetEvalTimeStamp(); } // ----------------------------------------------------------------------- diff --git a/MachineLearning/CNTKComputationNetworkLib/ComputationNetwork.h b/MachineLearning/CNTKComputationNetworkLib/ComputationNetwork.h index f171936f2..e68d9e308 100644 --- a/MachineLearning/CNTKComputationNetworkLib/ComputationNetwork.h +++ b/MachineLearning/CNTKComputationNetworkLib/ComputationNetwork.h @@ -984,5 +984,6 @@ private: // TODO: make all private that can be made private // TODO: does this apply to anything else besides temporary node-internal intermediate results? What, for example? MatrixPool m_matrixPool; }; +typedef shared_ptr ComputationNetworkPtr; }}} diff --git a/MachineLearning/CNTKSGDLib/SGD.cpp b/MachineLearning/CNTKSGDLib/SGD.cpp index da066e9a9..666ee00c9 100644 --- a/MachineLearning/CNTKSGDLib/SGD.cpp +++ b/MachineLearning/CNTKSGDLib/SGD.cpp @@ -6,7 +6,7 @@ #include "SGD.h" #include "DataReaderHelpers.h" #include "AllReduceDistGradAggregator.h" -#include "ProgressTracing.h" +#include "ProgressTracing.h" #include @@ -513,12 +513,12 @@ namespace Microsoft { namespace MSR { namespace CNTK { } template - void SGD::Train(IComputationNetBuilder* netBuilder, + void SGD::Train(function createNetworkFn, DEVICEID_TYPE deviceId, IDataReader* trainSetDataReader, IDataReader* validationSetDataReader, const bool makeMode) { - if (netBuilder == nullptr || trainSetDataReader == nullptr) + if (trainSetDataReader == nullptr) { InvalidArgument("netBuilder and trainSetDataReader should not be null.\n"); } @@ -533,8 +533,15 @@ namespace Microsoft { namespace MSR { namespace CNTK { if (startEpoch >= 0) fprintf(stderr, "Starting from checkpoint. Load Network From File %ls.\n", modelFileName.c_str()); - ComputationNetwork* net = startEpoch < 0 ? netBuilder->BuildNetworkFromDescription() : - netBuilder->LoadNetworkFromFile(modelFileName); + shared_ptr net; + if (startEpoch < 0) + net = createNetworkFn(deviceId); + else + { + net = make_shared(deviceId); + net->LoadFromFile(modelFileName, FileOptions::fileOptionsBinary, false/*bAllowNoCriterionNode*/, nullptr/*anotherNetwork*/); + } + // TODO: BUGBUG: if not starting from checkpoint, need to synchronize initial model // strategy should be to run the initializer above on mpiRank==0, and then broadcast parameters. @@ -2596,7 +2603,7 @@ namespace Microsoft { namespace MSR { namespace CNTK { template class SGD; template class SGD; - // register ComputationNode with the ScriptableObject system - ScriptableObjects::ConfigurableRuntimeTypeRegister::Add registerComputationNode(L"SGDParams"); + // register ComputationNode with the ScriptableObject system + ScriptableObjects::ConfigurableRuntimeTypeRegister::Add registerComputationNode(L"SGDParams"); }}} diff --git a/MachineLearning/CNTKSGDLib/SGD.h b/MachineLearning/CNTKSGDLib/SGD.h index aaf174d19..f22d7dfdb 100644 --- a/MachineLearning/CNTKSGDLib/SGD.h +++ b/MachineLearning/CNTKSGDLib/SGD.h @@ -275,7 +275,7 @@ public: void SequenceTrain(IComputationNetBuilder* netBuilder, wstring origModelFileName, IDataReader* trainSetDataReader, IDataReader* validationSetDataReader, const DEVICEID_TYPE deviceID, const bool makeMode = true); - void Train(IComputationNetBuilder* netBuilder, + void Train(function createNetworkFn, DEVICEID_TYPE deviceId, IDataReader* trainSetDataReader, IDataReader* validationSetDataReader, const bool makeMode = true);