SGD::Train() now takes a lambda that creates the virgin network instead of an IComputationNetBuilder. As a consequence, it now loads checkpoint models directly rather than going through a net builder (there is no reason it should). This means the load function can be removed from the network builder classes once all Train()-like functions have been updated to do the same.

Frank Seide 2015-11-21 01:09:18 -08:00
Parent f32860c400
Commit 517d409c95
5 changed files with 29 additions and 13 deletions
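
For orientation, here is the new calling pattern stitched together in one place from the hunks below. This is a sketch, not code from the commit: configSGD, deviceId, netBuilder, dataReader, cvDataReader, and makeMode are assumed to be set up earlier in DoTrain() exactly as the diff shows; only the lambda shape and the new Train() signature are taken from the change itself.

// Hypothetical consolidation of the new API usage (see the DoTrain() and SGD<ElemType>::Train() hunks below).
// The lambda builds the "virgin" (untrained) network; SGD::Train() decides whether to invoke it
// or to load a checkpoint directly via ComputationNetwork::LoadFromFile(), bypassing the net builder.
function<ComputationNetworkPtr(DEVICEID_TYPE)> createNetworkFn = [netBuilder](DEVICEID_TYPE deviceId)
{
    return shared_ptr<ComputationNetwork>(netBuilder->BuildNetworkFromDescription());
};

SGD<ElemType> sgd(SGDParams(configSGD, (ElemType)0));
sgd.Train(createNetworkFn, deviceId, dataReader.get(), cvDataReader.get(), makeMode);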

View file

@@ -835,12 +835,11 @@ public:
template <class ConfigRecordType, typename ElemType>
void DoTrain(const ConfigRecordType & config)
{
- const ConfigRecordType & configSGD(config(L"SGD", ConfigRecordType::Record()));
bool makeMode = config(L"makeMode", true);
DEVICEID_TYPE deviceId = DeviceFromConfig(config);
- shared_ptr<IComputationNetBuilder<ElemType>> netBuilder;// = GetCreateNetworkFunction(config);
+ shared_ptr<IComputationNetBuilder<ElemType>> netBuilder;
// TODO: turn the netBuilder into a lambda
if (config.Exists(L"createNetwork"))
{
netBuilder = make_shared<BrainScriptNetworkBuilder<ElemType>>(config);
@@ -866,6 +865,13 @@ void DoTrain(const ConfigRecordType & config)
RuntimeError("No network builder found in the config file. NDLNetworkBuilder or SimpleNetworkBuilde must be specified");
}
+ // network creation is handled by a lambda, which we define here
+ function<ComputationNetworkPtr(DEVICEID_TYPE)> createNetworkFn = [netBuilder](DEVICEID_TYPE deviceId)
+ {
+     ComputationNetwork * net = netBuilder->BuildNetworkFromDescription();
+     return shared_ptr<ComputationNetwork>(net);
+ };
// BUGBUG: inconsistency with BrainScript: old config passes a config dict, whereas BrainScript creates the object right away
const ConfigRecordType & readerConfig(config(L"reader", ConfigRecordType::Record()));
//readerConfig.Insert("traceLevel", config(L"traceLevel", "0")); // TODO: fix this by making this an optional arg; or if this should not be inherited, then by disabling it
@@ -879,9 +885,10 @@ void DoTrain(const ConfigRecordType & config)
cvDataReader = unique_ptr<DataReader<ElemType> >{ new DataReader<ElemType>(cvReaderConfig) };
}
+ const ConfigRecordType & configSGD(config(L"SGD", ConfigRecordType::Record()));
SGD<ElemType> sgd(SGDParams(configSGD, (ElemType)0));
- sgd.Train(netBuilder.get(), dataReader.get(), cvDataReader.get(), makeMode);
+ sgd.Train(createNetworkFn, deviceId, dataReader.get(), cvDataReader.get(), makeMode);
}
namespace Microsoft { namespace MSR { namespace ScriptableObjects {

View file

@@ -398,8 +398,9 @@ namespace Microsoft { namespace MSR { namespace CNTK {
fstream.GetMarker(FileMarker::fileMarkerEndSection, L"ECN");
- //some internal values in the nodes are computed during validation
+ // some internal values in the nodes are computed during validation
ValidateNetwork(false, bAllowNoCriterionNode);
ResetEvalTimeStamp();
}
// -----------------------------------------------------------------------

View file

@@ -984,5 +984,6 @@ private: // TODO: make all private that can be made private
// TODO: does this apply to anything else besides temporary node-internal intermediate results? What, for example?
MatrixPool m_matrixPool;
};
+ typedef shared_ptr<ComputationNetwork> ComputationNetworkPtr;
}}}

View file

@@ -6,7 +6,7 @@
#include "SGD.h"
#include "DataReaderHelpers.h"
#include "AllReduceDistGradAggregator.h"
#include "ProgressTracing.h"
#include "ProgressTracing.h"
#include <map>
@@ -513,12 +513,12 @@ namespace Microsoft { namespace MSR { namespace CNTK {
}
template<class ElemType>
- void SGD<ElemType>::Train(IComputationNetBuilder<ElemType>* netBuilder,
+ void SGD<ElemType>::Train(function<ComputationNetworkPtr(DEVICEID_TYPE)> createNetworkFn, DEVICEID_TYPE deviceId,
IDataReader<ElemType>* trainSetDataReader,
IDataReader<ElemType>* validationSetDataReader,
const bool makeMode)
{
- if (netBuilder == nullptr || trainSetDataReader == nullptr)
+ if (trainSetDataReader == nullptr)
{
InvalidArgument("netBuilder and trainSetDataReader should not be null.\n");
}
@@ -533,8 +533,15 @@ namespace Microsoft { namespace MSR { namespace CNTK {
if (startEpoch >= 0)
fprintf(stderr, "Starting from checkpoint. Load Network From File %ls.\n", modelFileName.c_str());
- ComputationNetwork* net = startEpoch < 0 ? netBuilder->BuildNetworkFromDescription() :
-     netBuilder->LoadNetworkFromFile(modelFileName);
+ shared_ptr<ComputationNetwork> net;
+ if (startEpoch < 0)
+     net = createNetworkFn(deviceId);
+ else
+ {
+     net = make_shared<ComputationNetwork>(deviceId);
+     net->LoadFromFile<ElemType>(modelFileName, FileOptions::fileOptionsBinary, false/*bAllowNoCriterionNode*/, nullptr/*anotherNetwork*/);
+ }
// TODO: BUGBUG: if not starting from checkpoint, need to synchronize initial model
// strategy should be to run the initializer above on mpiRank==0, and then broadcast parameters.
@@ -2596,7 +2603,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
template class SGD<float>;
template class SGD<double>;
- // register ComputationNode with the ScriptableObject system
- ScriptableObjects::ConfigurableRuntimeTypeRegister::Add<SGDParams> registerComputationNode(L"SGDParams");
+ // register ComputationNode with the ScriptableObject system
+ ScriptableObjects::ConfigurableRuntimeTypeRegister::Add<SGDParams> registerComputationNode(L"SGDParams");
}}}

View file

@@ -275,7 +275,7 @@ public:
void SequenceTrain(IComputationNetBuilder<ElemType>* netBuilder, wstring origModelFileName,
IDataReader<ElemType>* trainSetDataReader, IDataReader<ElemType>* validationSetDataReader,
const DEVICEID_TYPE deviceID, const bool makeMode = true);
- void Train(IComputationNetBuilder<ElemType>* netBuilder,
+ void Train(function<ComputationNetworkPtr(DEVICEID_TYPE)> createNetworkFn, DEVICEID_TYPE deviceId,
IDataReader<ElemType>* trainSetDataReader,
IDataReader<ElemType>* validationSetDataReader,
const bool makeMode = true);