SGD::Train() now takes a lambda to create the virgin network instead of a IComputationNetBuilder. As a consequence, it now loads the model directly without going through a net builder. Why should it? We can now remove the load function from the network builder classes (once all Train()-like functions have been updated to do the same).
This commit is contained in:
Родитель
f32860c400
Коммит
517d409c95
|
@ -835,12 +835,11 @@ public:
|
|||
template <class ConfigRecordType, typename ElemType>
|
||||
void DoTrain(const ConfigRecordType & config)
|
||||
{
|
||||
const ConfigRecordType & configSGD(config(L"SGD", ConfigRecordType::Record()));
|
||||
bool makeMode = config(L"makeMode", true);
|
||||
DEVICEID_TYPE deviceId = DeviceFromConfig(config);
|
||||
|
||||
shared_ptr<IComputationNetBuilder<ElemType>> netBuilder;// = GetCreateNetworkFunction(config);
|
||||
shared_ptr<IComputationNetBuilder<ElemType>> netBuilder;
|
||||
|
||||
// TODO: turn the netBuilder into a lambda
|
||||
if (config.Exists(L"createNetwork"))
|
||||
{
|
||||
netBuilder = make_shared<BrainScriptNetworkBuilder<ElemType>>(config);
|
||||
|
@ -866,6 +865,13 @@ void DoTrain(const ConfigRecordType & config)
|
|||
RuntimeError("No network builder found in the config file. NDLNetworkBuilder or SimpleNetworkBuilde must be specified");
|
||||
}
|
||||
|
||||
// network creation is handled by a lambda, which we define here
|
||||
function<ComputationNetworkPtr(DEVICEID_TYPE)> createNetworkFn = [netBuilder](DEVICEID_TYPE deviceId)
|
||||
{
|
||||
ComputationNetwork * net = netBuilder->BuildNetworkFromDescription();
|
||||
return shared_ptr<ComputationNetwork>(net);
|
||||
};
|
||||
|
||||
// BUGBUG: inconsistency with BrainScript: old config passes a config dict, whereas BrainScript creates the object right away
|
||||
const ConfigRecordType & readerConfig(config(L"reader", ConfigRecordType::Record()));
|
||||
//readerConfig.Insert("traceLevel", config(L"traceLevel", "0")); // TODO: fix this by making this an optional arg; or if this should not be inherited, then by disabling it
|
||||
|
@ -879,9 +885,10 @@ void DoTrain(const ConfigRecordType & config)
|
|||
cvDataReader = unique_ptr<DataReader<ElemType> >{ new DataReader<ElemType>(cvReaderConfig) };
|
||||
}
|
||||
|
||||
const ConfigRecordType & configSGD(config(L"SGD", ConfigRecordType::Record()));
|
||||
SGD<ElemType> sgd(SGDParams(configSGD, (ElemType)0));
|
||||
|
||||
sgd.Train(netBuilder.get(), dataReader.get(), cvDataReader.get(), makeMode);
|
||||
sgd.Train(createNetworkFn, deviceId, dataReader.get(), cvDataReader.get(), makeMode);
|
||||
}
|
||||
|
||||
namespace Microsoft { namespace MSR { namespace ScriptableObjects {
|
||||
|
|
|
@ -398,8 +398,9 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
|
||||
fstream.GetMarker(FileMarker::fileMarkerEndSection, L"ECN");
|
||||
|
||||
//some internal values in the nodes are computed during validation
|
||||
// some internal values in the nodes are computed during validation
|
||||
ValidateNetwork(false, bAllowNoCriterionNode);
|
||||
ResetEvalTimeStamp();
|
||||
}
|
||||
|
||||
// -----------------------------------------------------------------------
|
||||
|
|
|
@ -984,5 +984,6 @@ private: // TODO: make all private that can be made private
|
|||
// TODO: does this apply to anything else besides temporary node-internal intermediate results? What, for example?
|
||||
MatrixPool m_matrixPool;
|
||||
};
|
||||
typedef shared_ptr<ComputationNetwork> ComputationNetworkPtr;
|
||||
|
||||
}}}
|
||||
|
|
|
@ -6,7 +6,7 @@
|
|||
#include "SGD.h"
|
||||
#include "DataReaderHelpers.h"
|
||||
#include "AllReduceDistGradAggregator.h"
|
||||
#include "ProgressTracing.h"
|
||||
#include "ProgressTracing.h"
|
||||
|
||||
#include <map>
|
||||
|
||||
|
@ -513,12 +513,12 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
}
|
||||
|
||||
template<class ElemType>
|
||||
void SGD<ElemType>::Train(IComputationNetBuilder<ElemType>* netBuilder,
|
||||
void SGD<ElemType>::Train(function<ComputationNetworkPtr(DEVICEID_TYPE)> createNetworkFn, DEVICEID_TYPE deviceId,
|
||||
IDataReader<ElemType>* trainSetDataReader,
|
||||
IDataReader<ElemType>* validationSetDataReader,
|
||||
const bool makeMode)
|
||||
{
|
||||
if (netBuilder == nullptr || trainSetDataReader == nullptr)
|
||||
if (trainSetDataReader == nullptr)
|
||||
{
|
||||
InvalidArgument("netBuilder and trainSetDataReader should not be null.\n");
|
||||
}
|
||||
|
@ -533,8 +533,15 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
if (startEpoch >= 0)
|
||||
fprintf(stderr, "Starting from checkpoint. Load Network From File %ls.\n", modelFileName.c_str());
|
||||
|
||||
ComputationNetwork* net = startEpoch < 0 ? netBuilder->BuildNetworkFromDescription() :
|
||||
netBuilder->LoadNetworkFromFile(modelFileName);
|
||||
shared_ptr<ComputationNetwork> net;
|
||||
if (startEpoch < 0)
|
||||
net = createNetworkFn(deviceId);
|
||||
else
|
||||
{
|
||||
net = make_shared<ComputationNetwork>(deviceId);
|
||||
net->LoadFromFile<ElemType>(modelFileName, FileOptions::fileOptionsBinary, false/*bAllowNoCriterionNode*/, nullptr/*anotherNetwork*/);
|
||||
}
|
||||
|
||||
// TODO: BUGBUG: if not starting from checkpoint, need to synchronize initial model
|
||||
// strategy should be to run the initializer above on mpiRank==0, and then broadcast parameters.
|
||||
|
||||
|
@ -2596,7 +2603,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
template class SGD<float>;
|
||||
template class SGD<double>;
|
||||
|
||||
// register ComputationNode with the ScriptableObject system
|
||||
ScriptableObjects::ConfigurableRuntimeTypeRegister::Add<SGDParams> registerComputationNode(L"SGDParams");
|
||||
// register ComputationNode with the ScriptableObject system
|
||||
ScriptableObjects::ConfigurableRuntimeTypeRegister::Add<SGDParams> registerComputationNode(L"SGDParams");
|
||||
|
||||
}}}
|
||||
|
|
|
@ -275,7 +275,7 @@ public:
|
|||
void SequenceTrain(IComputationNetBuilder<ElemType>* netBuilder, wstring origModelFileName,
|
||||
IDataReader<ElemType>* trainSetDataReader, IDataReader<ElemType>* validationSetDataReader,
|
||||
const DEVICEID_TYPE deviceID, const bool makeMode = true);
|
||||
void Train(IComputationNetBuilder<ElemType>* netBuilder,
|
||||
void Train(function<ComputationNetworkPtr(DEVICEID_TYPE)> createNetworkFn, DEVICEID_TYPE deviceId,
|
||||
IDataReader<ElemType>* trainSetDataReader,
|
||||
IDataReader<ElemType>* validationSetDataReader,
|
||||
const bool makeMode = true);
|
||||
|
|
Загрузка…
Ссылка в новой задаче