Merge branch 'master' into ebarsoum/ImageHandsOn

Checkin...
This commit is contained in:
Emad Barsoum 2016-11-01 23:00:26 -07:00
Родитель ee4b0fdba4 7f7f90fbb4
Коммит 8f3d78430e
9 изменённых файлов: 80 добавлений и 51 удалений

@ -1 +1 @@
Subproject commit 26475afc2945db5be61494dfbb542ba058f9b862
Subproject commit 6535b08760744c890a88e4c934352ae7fb6b6e30

Просмотреть файл

@ -3623,6 +3623,9 @@ namespace CNTK
// Optionally overridable method to restore state pertaining this distributed training method from a previous checkpoint
CNTK_API virtual void RestoreFromCheckpoint(const Dictionary& checkpoint) = 0;
// Return the distributed communicator used in the distributed trainer
CNTK_API virtual DistributedCommunicatorPtr GetCommunicator() = 0;
virtual ~DistributedTrainer() {}
};

Просмотреть файл

@ -30,6 +30,11 @@ namespace CNTK
void RestoreFromCheckpoint(const Dictionary& checkpoint) override;
private:
DistributedCommunicatorPtr GetCommunicator() override
{
return m_communicator;
}
DistributedCommunicatorPtr m_communicator;
bool m_useAsyncBufferedParameterUpdate;
};

Просмотреть файл

@ -222,6 +222,19 @@ namespace CNTK
}
void Trainer::SaveCheckpoint(const std::wstring& modelFilePath, bool usinglegacyModelFormat)
{
bool shouldSave = true;
if (m_distributedTrainer != nullptr)
{
// all workers need to sync up before saving model to avoid write-after-read hazard
// i.e. one worker is in the middle of reading a checkpoint while another overwrites
m_distributedTrainer->GetCommunicator()->Barrier();
// for distributed training, only save checkpoint at worker 0
shouldSave = m_distributedTrainer->GetCommunicator()->CurrentWorker().IsMain();
}
if (shouldSave)
{
m_combinedTrainingFunction->SaveModel(modelFilePath, usinglegacyModelFormat);
@ -241,6 +254,14 @@ namespace CNTK
ckpStream->flush();
}
if (m_distributedTrainer != nullptr)
{
// all workers need to sync up after saving model to avoid read-after-write hazard
// i.e. one worker is in the middle of write while another tries to read
m_distributedTrainer->GetCommunicator()->Barrier();
}
}
void Trainer::RestoreFromCheckpoint(const std::wstring& modelFilePath)
{
// Restore the model's parameters

Просмотреть файл

@ -49,7 +49,8 @@ BOOL APIENTRY DllMain(HMODULE /*hModule*/,
}
break;
case DLL_PROCESS_DETACH:
_CrtSetReportHook2(_CRT_RPTHOOK_REMOVE, HandleDebugAssert);
// DLL_PROCESS_DETACH may have race condition with code page unload
//_CrtSetReportHook2(_CRT_RPTHOOK_REMOVE, HandleDebugAssert);
break;
#else
case DLL_PROCESS_ATTACH:

Просмотреть файл

@ -57,15 +57,15 @@ void TrainSimpleDistributedFeedForwardClassifer(const DeviceDescriptor& device,
std::unordered_map<StreamInformation, std::pair<NDArrayViewPtr, NDArrayViewPtr>> inputMeansAndInvStdDevs = { { featureStreamInfo, { nullptr, nullptr } } };
ComputeInputPerDimMeansAndInvStdDevs(minibatchSource, inputMeansAndInvStdDevs);
auto nonLinearity = std::bind(Sigmoid, _1, L"");
auto nonLinearity = std::bind(Sigmoid, _1, L"Sigmoid");
auto input = InputVariable({ inputDim }, DataType::Float, L"features");
auto normalizedinput = PerDimMeanVarianceNormalize(input, inputMeansAndInvStdDevs[featureStreamInfo].first, inputMeansAndInvStdDevs[featureStreamInfo].second);
auto classifierOutput = FullyConnectedDNNLayer(normalizedinput, hiddenLayerDim, device, nonLinearity);
auto classifierOutput = FullyConnectedDNNLayer(normalizedinput, hiddenLayerDim, device, nonLinearity, std::wstring(L"FullyConnectedInput") );
for (size_t i = 1; i < numHiddenLayers; ++i)
classifierOutput = FullyConnectedDNNLayer(classifierOutput, hiddenLayerDim, device, nonLinearity);
classifierOutput = FullyConnectedDNNLayer(classifierOutput, hiddenLayerDim, device, nonLinearity, std::wstring(L"FullyConnectedHidden"));
auto outputTimesParam = Parameter(NDArrayView::RandomUniform<float>({ numOutputClasses, hiddenLayerDim }, -0.05, 0.05, 1, device));
auto outputBiasParam = Parameter(NDArrayView::RandomUniform<float>({ numOutputClasses }, -0.05, 0.05, 1, device));
auto outputTimesParam = Parameter(NDArrayView::RandomUniform<float>({ numOutputClasses, hiddenLayerDim }, -0.05, 0.05, 1, device), L"outputTimesParam");
auto outputBiasParam = Parameter(NDArrayView::RandomUniform<float>({ numOutputClasses }, -0.05, 0.05, 1, device), L"outputBiasParam");
classifierOutput = Plus(outputBiasParam, Times(outputTimesParam, classifierOutput), L"classifierOutput");
auto labels = InputVariable({ numOutputClasses }, DataType::Float, L"labels");

Просмотреть файл

@ -157,16 +157,16 @@ inline CNTK::FunctionPtr FullyConnectedLinearLayer(CNTK::Variable input, size_t
assert(input.Shape().Rank() == 1);
size_t inputDim = input.Shape()[0];
auto timesParam = CNTK::Parameter({ outputDim, inputDim }, CNTK::DataType::Float, CNTK::GlorotUniformInitializer(), device);
auto timesFunction = CNTK::Times(timesParam, input);
auto timesParam = CNTK::Parameter({ outputDim, inputDim }, CNTK::DataType::Float, CNTK::GlorotUniformInitializer(), device, L"timesParam");
auto timesFunction = CNTK::Times(timesParam, input, L"times");
auto plusParam = CNTK::Parameter({ outputDim }, 0.0f, device);
auto plusParam = CNTK::Parameter({ outputDim }, 0.0f, device, L"plusParam");
return CNTK::Plus(plusParam, timesFunction, outputName);
}
inline CNTK::FunctionPtr FullyConnectedDNNLayer(CNTK::Variable input, size_t outputDim, const CNTK::DeviceDescriptor& device, const std::function<CNTK::FunctionPtr(const CNTK::FunctionPtr&)>& nonLinearity)
inline CNTK::FunctionPtr FullyConnectedDNNLayer(CNTK::Variable input, size_t outputDim, const CNTK::DeviceDescriptor& device, const std::function<CNTK::FunctionPtr(const CNTK::FunctionPtr&)>& nonLinearity, const std::wstring& outputName = L"")
{
return nonLinearity(FullyConnectedLinearLayer(input, outputDim, device));
return nonLinearity(FullyConnectedLinearLayer(input, outputDim, device, outputName));
}
inline CNTK::FunctionPtr FullyConnectedFeedForwardClassifierNet(CNTK::Variable input,

Просмотреть файл

@ -109,7 +109,6 @@ void TestNDArrayView(size_t numAxes, const DeviceDescriptor& device)
// Test readonliness
auto errorMsg = "Was incorrectly able to get a writable buffer pointer from a readonly view";
// Should not be able to get the WritableDataBuffer for a read-only view
VerifyException([&aliasView]() {
ElementType* aliasViewBuffer = aliasView->WritableDataBuffer<ElementType>();

Различия файлов скрыты, потому что одна или несколько строк слишком длинны