Merge branch 'master' into ebarsoum/ImageHandsOn
Checkin...
Commit 8f3d78430e
@@ -1 +1 @@
-Subproject commit 26475afc2945db5be61494dfbb542ba058f9b862
+Subproject commit 6535b08760744c890a88e4c934352ae7fb6b6e30
@@ -3623,6 +3623,9 @@ namespace CNTK
         // Optionally overridable method to restore state pertaining to this distributed training method from a previous checkpoint
         CNTK_API virtual void RestoreFromCheckpoint(const Dictionary& checkpoint) = 0;
 
+        // Return the distributed communicator used by the distributed trainer
+        CNTK_API virtual DistributedCommunicatorPtr GetCommunicator() = 0;
+
         virtual ~DistributedTrainer() {}
     };
 
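The new pure-virtual GetCommunicator() gives generic trainer code a way to reach the underlying DistributedCommunicator for collective operations without knowing which concrete distributed trainer is in use. A minimal usage sketch — the helper name is hypothetical, and the DistributedTrainerPtr shared-pointer alias is assumed; Barrier() is the call this commit itself relies on below:

    // Hypothetical helper, for illustration only:
    void SyncAllWorkers(const DistributedTrainerPtr& distributedTrainer)
    {
        if (distributedTrainer != nullptr)
            distributedTrainer->GetCommunicator()->Barrier(); // block until every worker arrives
    }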
@@ -30,6 +30,11 @@ namespace CNTK
         void RestoreFromCheckpoint(const Dictionary& checkpoint) override;
 
     private:
+        DistributedCommunicatorPtr GetCommunicator() override
+        {
+            return m_communicator;
+        }
+
         DistributedCommunicatorPtr m_communicator;
         bool m_useAsyncBufferedParameterUpdate;
     };
 
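The override simply hands back the communicator this trainer was constructed with. A hedged construction sketch, assuming the MPICommunicator() and CreateDataParallelDistributedTrainer(...) factory functions of this API version:

    // Illustration only: build a data-parallel distributed trainer and query its communicator.
    DistributedCommunicatorPtr communicator = MPICommunicator();
    DistributedTrainerPtr distributedTrainer =
        CreateDataParallelDistributedTrainer(communicator, /*useAsyncBufferedParameterUpdate=*/ false);
    size_t numWorkers = distributedTrainer->GetCommunicator()->Workers().size(); // all participating workers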
@@ -223,22 +223,43 @@ namespace CNTK
 
     void Trainer::SaveCheckpoint(const std::wstring& modelFilePath, bool usinglegacyModelFormat)
     {
-        m_combinedTrainingFunction->SaveModel(modelFilePath, usinglegacyModelFormat);
-        vector<DictionaryValue> learnerStates;
-
-        for (const auto& learner : m_parameterLearners)
+        bool shouldSave = true;
+        if (m_distributedTrainer != nullptr)
         {
-            // TODO: add DictionaryValue(T&&)
-            learnerStates.push_back(DictionaryValue(learner->Serialize()));
+            // all workers need to sync up before saving the model to avoid a write-after-read hazard,
+            // i.e. one worker is in the middle of reading a checkpoint while another overwrites it
+            m_distributedTrainer->GetCommunicator()->Barrier();
+
+            // for distributed training, only save the checkpoint at worker 0
+            shouldSave = m_distributedTrainer->GetCommunicator()->CurrentWorker().IsMain();
         }
 
+        if (shouldSave)
+        {
+            m_combinedTrainingFunction->SaveModel(modelFilePath, usinglegacyModelFormat);
+
+            vector<DictionaryValue> learnerStates;
+
+            for (const auto& learner : m_parameterLearners)
+            {
+                // TODO: add DictionaryValue(T&&)
+                learnerStates.push_back(DictionaryValue(learner->Serialize()));
+            }
+
         std::wstring trainerStateCheckpointFilePath = GetTrainerStateCheckpointFilePath(modelFilePath);
         auto ckpStream = GetFstream(trainerStateCheckpointFilePath, false);
         // TODO: this will create an extra copy of all learner states;
         // add a DictionaryValue ctor that takes an rvalue!
         *ckpStream << DictionaryValue(learnerStates);
         ckpStream->flush();
+        }
+
+        if (m_distributedTrainer != nullptr)
+        {
+            // all workers need to sync up after saving the model to avoid a read-after-write hazard,
+            // i.e. one worker is in the middle of a write while another tries to read
+            m_distributedTrainer->GetCommunicator()->Barrier();
+        }
     }
 
     void Trainer::RestoreFromCheckpoint(const std::wstring& modelFilePath)
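The rewritten save path is a standard rank-0 checkpoint protocol: synchronize, let only the main worker write, then synchronize again so no worker can race ahead and read a half-written checkpoint. Distilled to its skeleton — WriteCheckpointFiles() is a hypothetical stand-in for the SaveModel and learner-state writes above:

    // Pre-save barrier: no worker may still be reading the old checkpoint files.
    communicator->Barrier();
    if (communicator->CurrentWorker().IsMain())
        WriteCheckpointFiles();  // only worker 0 touches the files
    // Post-save barrier: no worker may start reading before the write completes.
    communicator->Barrier();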
@@ -49,7 +49,8 @@ BOOL APIENTRY DllMain(HMODULE /*hModule*/,
         }
         break;
     case DLL_PROCESS_DETACH:
-        _CrtSetReportHook2(_CRT_RPTHOOK_REMOVE, HandleDebugAssert);
+        // DLL_PROCESS_DETACH may have a race condition with code page unload
+        //_CrtSetReportHook2(_CRT_RPTHOOK_REMOVE, HandleDebugAssert);
         break;
 #else
     case DLL_PROCESS_ATTACH:
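Commenting out the hook removal trades a harmless leftover registration for not calling into a possibly already-unmapped code page during process teardown. The matching install presumably sits on the attach side, along the lines of the standard CRT pattern below (a sketch; the actual attach code is not shown in this hunk):

    case DLL_PROCESS_ATTACH:
        // Install the debug-assert report hook when the DLL loads.
        _CrtSetReportHook2(_CRT_RPTHOOK_INSTALL, HandleDebugAssert);
        break;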
@@ -57,15 +57,15 @@ void TrainSimpleDistributedFeedForwardClassifer(const DeviceDescriptor& device,
     std::unordered_map<StreamInformation, std::pair<NDArrayViewPtr, NDArrayViewPtr>> inputMeansAndInvStdDevs = { { featureStreamInfo, { nullptr, nullptr } } };
     ComputeInputPerDimMeansAndInvStdDevs(minibatchSource, inputMeansAndInvStdDevs);
 
-    auto nonLinearity = std::bind(Sigmoid, _1, L"");
+    auto nonLinearity = std::bind(Sigmoid, _1, L"Sigmoid");
     auto input = InputVariable({ inputDim }, DataType::Float, L"features");
     auto normalizedinput = PerDimMeanVarianceNormalize(input, inputMeansAndInvStdDevs[featureStreamInfo].first, inputMeansAndInvStdDevs[featureStreamInfo].second);
-    auto classifierOutput = FullyConnectedDNNLayer(normalizedinput, hiddenLayerDim, device, nonLinearity);
+    auto classifierOutput = FullyConnectedDNNLayer(normalizedinput, hiddenLayerDim, device, nonLinearity, std::wstring(L"FullyConnectedInput"));
     for (size_t i = 1; i < numHiddenLayers; ++i)
-        classifierOutput = FullyConnectedDNNLayer(classifierOutput, hiddenLayerDim, device, nonLinearity);
+        classifierOutput = FullyConnectedDNNLayer(classifierOutput, hiddenLayerDim, device, nonLinearity, std::wstring(L"FullyConnectedHidden"));
 
-    auto outputTimesParam = Parameter(NDArrayView::RandomUniform<float>({ numOutputClasses, hiddenLayerDim }, -0.05, 0.05, 1, device));
-    auto outputBiasParam = Parameter(NDArrayView::RandomUniform<float>({ numOutputClasses }, -0.05, 0.05, 1, device));
+    auto outputTimesParam = Parameter(NDArrayView::RandomUniform<float>({ numOutputClasses, hiddenLayerDim }, -0.05, 0.05, 1, device), L"outputTimesParam");
+    auto outputBiasParam = Parameter(NDArrayView::RandomUniform<float>({ numOutputClasses }, -0.05, 0.05, 1, device), L"outputBiasParam");
     classifierOutput = Plus(outputBiasParam, Times(outputTimesParam, classifierOutput), L"classifierOutput");
 
     auto labels = InputVariable({ numOutputClasses }, DataType::Float, L"labels");
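Giving the layers and parameters explicit names makes the built graph addressable later, for example after reloading a model. One way to exploit this, sketched using only Function::Parameters() and Variable::Name() — the helper itself is hypothetical and assumes CNTKLibrary.h plus namespace CNTK:

    // Illustration: locate a named parameter once the network is built.
    inline Parameter FindParameterByName(const FunctionPtr& model, const std::wstring& name)
    {
        for (const auto& p : model->Parameters()) // every learnable parameter in the graph
            if (p.Name() == name)
                return p;
        throw std::runtime_error("parameter not found"); // needs <stdexcept>
    }

    auto outputBias = FindParameterByName(classifierOutput, L"outputBiasParam");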
@@ -157,16 +157,16 @@ inline CNTK::FunctionPtr FullyConnectedLinearLayer(CNTK::Variable input, size_t
     assert(input.Shape().Rank() == 1);
     size_t inputDim = input.Shape()[0];
 
-    auto timesParam = CNTK::Parameter({ outputDim, inputDim }, CNTK::DataType::Float, CNTK::GlorotUniformInitializer(), device);
-    auto timesFunction = CNTK::Times(timesParam, input);
+    auto timesParam = CNTK::Parameter({ outputDim, inputDim }, CNTK::DataType::Float, CNTK::GlorotUniformInitializer(), device, L"timesParam");
+    auto timesFunction = CNTK::Times(timesParam, input, L"times");
 
-    auto plusParam = CNTK::Parameter({ outputDim }, 0.0f, device);
+    auto plusParam = CNTK::Parameter({ outputDim }, 0.0f, device, L"plusParam");
     return CNTK::Plus(plusParam, timesFunction, outputName);
 }
 
-inline CNTK::FunctionPtr FullyConnectedDNNLayer(CNTK::Variable input, size_t outputDim, const CNTK::DeviceDescriptor& device, const std::function<CNTK::FunctionPtr(const CNTK::FunctionPtr&)>& nonLinearity)
+inline CNTK::FunctionPtr FullyConnectedDNNLayer(CNTK::Variable input, size_t outputDim, const CNTK::DeviceDescriptor& device, const std::function<CNTK::FunctionPtr(const CNTK::FunctionPtr&)>& nonLinearity, const std::wstring& outputName = L"")
 {
-    return nonLinearity(FullyConnectedLinearLayer(input, outputDim, device));
+    return nonLinearity(FullyConnectedLinearLayer(input, outputDim, device, outputName));
 }
 
 inline CNTK::FunctionPtr FullyConnectedFeedForwardClassifierNet(CNTK::Variable input,
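Because the new outputName parameter defaults to L"", every pre-existing call site of FullyConnectedDNNLayer still compiles unchanged; only callers that want named nodes pass the extra argument. Both forms are valid after this change (the dimensions below are placeholders):

    auto unnamed = FullyConnectedDNNLayer(input, hiddenLayerDim, device, nonLinearity);
    auto named   = FullyConnectedDNNLayer(input, hiddenLayerDim, device, nonLinearity, L"FullyConnectedHidden");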
@@ -109,7 +109,6 @@ void TestNDArrayView(size_t numAxes, const DeviceDescriptor& device)
 
     // Test readonliness
     auto errorMsg = "Was incorrectly able to get a writable buffer pointer from a readonly view";
-
     // Should not be able to get the WritableDataBuffer for a read-only view
     VerifyException([&aliasView]() {
         ElementType* aliasViewBuffer = aliasView->WritableDataBuffer<ElementType>();
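VerifyException asserts that the lambda throws: requesting a writable buffer from a read-only NDArrayView must fail. The test utility itself is not shown in this diff; a minimal sketch of such a helper, using only the standard library (needs <stdexcept> and <string>), could look like:

    // Run `func` and fail the test unless it throws.
    template <typename Callable>
    void VerifyException(const Callable& func, const std::string& errorMsg)
    {
        bool threw = false;
        try { func(); }
        catch (...) { threw = true; } // any exception type counts as a pass
        if (!threw)
            throw std::runtime_error(errorMsg);
    }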
Some file diffs are hidden because one or more lines are too long.