Merge branch 'master' into ebarsoum/ImageHandsOn
Checkin...
Commit 8f3d78430e
@@ -1 +1 @@
-Subproject commit 26475afc2945db5be61494dfbb542ba058f9b862
+Subproject commit 6535b08760744c890a88e4c934352ae7fb6b6e30
@@ -3623,6 +3623,9 @@ namespace CNTK
         // Optionally overridable method to restore state pertaining this distributed training method from a previous checkpoint
         CNTK_API virtual void RestoreFromCheckpoint(const Dictionary& checkpoint) = 0;

+        // Return the distributed communicator used in the distributed trainer
+        CNTK_API virtual DistributedCommunicatorPtr GetCommunicator() = 0;
+
         virtual ~DistributedTrainer() {}
     };

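Note: a minimal sketch (not part of this commit) of how a caller can use the new GetCommunicator accessor; it assumes the usual DistributedTrainerPtr shared-pointer alias for DistributedTrainer:

    #include "CNTKLibrary.h"

    // Block until every worker reaches this point -- the same pattern
    // Trainer::SaveCheckpoint adopts further down in this commit.
    void SyncAllWorkers(const CNTK::DistributedTrainerPtr& distributedTrainer)
    {
        if (distributedTrainer != nullptr)
            distributedTrainer->GetCommunicator()->Barrier();
    }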
@@ -30,6 +30,11 @@ namespace CNTK
        void RestoreFromCheckpoint(const Dictionary& checkpoint) override;

    private:
+        DistributedCommunicatorPtr GetCommunicator() override
+        {
+            return m_communicator;
+        }
+
        DistributedCommunicatorPtr m_communicator;
        bool m_useAsyncBufferedParameterUpdate;
    };
@@ -223,22 +223,43 @@ namespace CNTK

    void Trainer::SaveCheckpoint(const std::wstring& modelFilePath, bool usinglegacyModelFormat)
    {
-        m_combinedTrainingFunction->SaveModel(modelFilePath, usinglegacyModelFormat);
-
-        vector<DictionaryValue> learnerStates;
-
-        for (const auto& learner : m_parameterLearners)
+        bool shouldSave = true;
+        if (m_distributedTrainer != nullptr)
        {
-            // TODO: add DictionaryValue(T&&)
-            learnerStates.push_back(DictionaryValue(learner->Serialize()));
+            // all workers need to sync up before saving model to avoid write-after-read hazard
+            // i.e. one worker is in the middle of reading a checkpoint while another overwrites
+            m_distributedTrainer->GetCommunicator()->Barrier();
+
+            // for distributed training, only save checkpoint at worker 0
+            shouldSave = m_distributedTrainer->GetCommunicator()->CurrentWorker().IsMain();
        }

-        std::wstring trainerStateCheckpointFilePath = GetTrainerStateCheckpointFilePath(modelFilePath);
-        auto ckpStream = GetFstream(trainerStateCheckpointFilePath, false);
-        // TODO: this will create an extra copy of all learner states,
-        // add DictionaryValue ctor that takes an rvalue!
-        *ckpStream << DictionaryValue(learnerStates);
-        ckpStream->flush();
+        if (shouldSave)
+        {
+            m_combinedTrainingFunction->SaveModel(modelFilePath, usinglegacyModelFormat);
+
+            vector<DictionaryValue> learnerStates;
+
+            for (const auto& learner : m_parameterLearners)
+            {
+                // TODO: add DictionaryValue(T&&)
+                learnerStates.push_back(DictionaryValue(learner->Serialize()));
+            }
+
+            std::wstring trainerStateCheckpointFilePath = GetTrainerStateCheckpointFilePath(modelFilePath);
+            auto ckpStream = GetFstream(trainerStateCheckpointFilePath, false);
+            // TODO: this will create an extra copy of all learner states,
+            // add DictionaryValue ctor that takes an rvalue!
+            *ckpStream << DictionaryValue(learnerStates);
+            ckpStream->flush();
+        }
+
+        if (m_distributedTrainer != nullptr)
+        {
+            // all workers need to sync up after saving model to avoid read-after-write hazard
+            // i.e. one worker is in the middle of write while another tries to read
+            m_distributedTrainer->GetCommunicator()->Barrier();
+        }
    }

    void Trainer::RestoreFromCheckpoint(const std::wstring& modelFilePath)
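Note: the save protocol above is barrier / write-on-main / barrier. The first barrier ensures no worker is still reading the old checkpoint when worker 0 overwrites it (write-after-read); the second ensures no worker starts restoring before the write completes (read-after-write). For illustration only, a standalone sketch of the same protocol in raw MPI (the commit itself goes through the DistributedCommunicator abstraction, not MPI directly):

    #include <mpi.h>
    #include <fstream>
    #include <string>

    void SaveCheckpointAllWorkers(const std::string& path, const std::string& state)
    {
        int rank = 0;
        MPI_Comm_rank(MPI_COMM_WORLD, &rank);

        // 1) No worker may still be reading the old checkpoint (write-after-read hazard).
        MPI_Barrier(MPI_COMM_WORLD);

        // 2) Only the main worker writes, so workers never race on the same file.
        if (rank == 0)
        {
            std::ofstream out(path, std::ios::binary | std::ios::trunc);
            out << state;
        }

        // 3) No worker may start restoring before the write completes (read-after-write hazard).
        MPI_Barrier(MPI_COMM_WORLD);
    }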
@@ -49,7 +49,8 @@ BOOL APIENTRY DllMain(HMODULE /*hModule*/,
        }
        break;
    case DLL_PROCESS_DETACH:
-        _CrtSetReportHook2(_CRT_RPTHOOK_REMOVE, HandleDebugAssert);
+        // DLL_PROCESS_DETACH may have race condition with code page unload
+        //_CrtSetReportHook2(_CRT_RPTHOOK_REMOVE, HandleDebugAssert);
        break;
#else
    case DLL_PROCESS_ATTACH:
@@ -57,15 +57,15 @@ void TrainSimpleDistributedFeedForwardClassifer(const DeviceDescriptor& device,
    std::unordered_map<StreamInformation, std::pair<NDArrayViewPtr, NDArrayViewPtr>> inputMeansAndInvStdDevs = { { featureStreamInfo, { nullptr, nullptr } } };
    ComputeInputPerDimMeansAndInvStdDevs(minibatchSource, inputMeansAndInvStdDevs);

-    auto nonLinearity = std::bind(Sigmoid, _1, L"");
+    auto nonLinearity = std::bind(Sigmoid, _1, L"Sigmoid");
    auto input = InputVariable({ inputDim }, DataType::Float, L"features");
    auto normalizedinput = PerDimMeanVarianceNormalize(input, inputMeansAndInvStdDevs[featureStreamInfo].first, inputMeansAndInvStdDevs[featureStreamInfo].second);
-    auto classifierOutput = FullyConnectedDNNLayer(normalizedinput, hiddenLayerDim, device, nonLinearity);
+    auto classifierOutput = FullyConnectedDNNLayer(normalizedinput, hiddenLayerDim, device, nonLinearity, std::wstring(L"FullyConnectedInput"));
    for (size_t i = 1; i < numHiddenLayers; ++i)
-        classifierOutput = FullyConnectedDNNLayer(classifierOutput, hiddenLayerDim, device, nonLinearity);
+        classifierOutput = FullyConnectedDNNLayer(classifierOutput, hiddenLayerDim, device, nonLinearity, std::wstring(L"FullyConnectedHidden"));

-    auto outputTimesParam = Parameter(NDArrayView::RandomUniform<float>({ numOutputClasses, hiddenLayerDim }, -0.05, 0.05, 1, device));
-    auto outputBiasParam = Parameter(NDArrayView::RandomUniform<float>({ numOutputClasses }, -0.05, 0.05, 1, device));
+    auto outputTimesParam = Parameter(NDArrayView::RandomUniform<float>({ numOutputClasses, hiddenLayerDim }, -0.05, 0.05, 1, device), L"outputTimesParam");
+    auto outputBiasParam = Parameter(NDArrayView::RandomUniform<float>({ numOutputClasses }, -0.05, 0.05, 1, device), L"outputBiasParam");
    classifierOutput = Plus(outputBiasParam, Times(outputTimesParam, classifierOutput), L"classifierOutput");

    auto labels = InputVariable({ numOutputClasses }, DataType::Float, L"labels");
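Note: the names added here make the graph nodes addressable after a round trip through SaveModel/LoadModel. A hedged sketch of the retrieval side, assuming the V2 API's Function::FindByName lookup (not shown in this commit):

    #include "CNTKLibrary.h"
    #include <stdexcept>

    CNTK::FunctionPtr GetClassifierOutput(const CNTK::FunctionPtr& model)
    {
        // "classifierOutput" is the name assigned via Plus(..., L"classifierOutput") above.
        auto node = model->FindByName(L"classifierOutput");
        if (node == nullptr)
            throw std::runtime_error("node 'classifierOutput' not found in model");
        return node;
    }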
@@ -157,16 +157,16 @@ inline CNTK::FunctionPtr FullyConnectedLinearLayer(CNTK::Variable input, size_t
    assert(input.Shape().Rank() == 1);
    size_t inputDim = input.Shape()[0];

-    auto timesParam = CNTK::Parameter({ outputDim, inputDim }, CNTK::DataType::Float, CNTK::GlorotUniformInitializer(), device);
-    auto timesFunction = CNTK::Times(timesParam, input);
+    auto timesParam = CNTK::Parameter({ outputDim, inputDim }, CNTK::DataType::Float, CNTK::GlorotUniformInitializer(), device, L"timesParam");
+    auto timesFunction = CNTK::Times(timesParam, input, L"times");

-    auto plusParam = CNTK::Parameter({ outputDim }, 0.0f, device);
+    auto plusParam = CNTK::Parameter({ outputDim }, 0.0f, device, L"plusParam");
    return CNTK::Plus(plusParam, timesFunction, outputName);
}

-inline CNTK::FunctionPtr FullyConnectedDNNLayer(CNTK::Variable input, size_t outputDim, const CNTK::DeviceDescriptor& device, const std::function<CNTK::FunctionPtr(const CNTK::FunctionPtr&)>& nonLinearity)
+inline CNTK::FunctionPtr FullyConnectedDNNLayer(CNTK::Variable input, size_t outputDim, const CNTK::DeviceDescriptor& device, const std::function<CNTK::FunctionPtr(const CNTK::FunctionPtr&)>& nonLinearity, const std::wstring& outputName = L"")
{
-    return nonLinearity(FullyConnectedLinearLayer(input, outputDim, device));
+    return nonLinearity(FullyConnectedLinearLayer(input, outputDim, device, outputName));
}

inline CNTK::FunctionPtr FullyConnectedFeedForwardClassifierNet(CNTK::Variable input,
@@ -109,7 +109,6 @@ void TestNDArrayView(size_t numAxes, const DeviceDescriptor& device)

    // Test readonliness
-    auto errorMsg = "Was incorrectly able to get a writable buffer pointer from a readonly view";

    // Should not be able to get the WritableDataBuffer for a read-only view
    VerifyException([&aliasView]() {
        ElementType* aliasViewBuffer = aliasView->WritableDataBuffer<ElementType>();
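Note: the deleted errorMsg string was unused; the test delegates failure reporting to VerifyException. For readers without the test suite's headers, a sketch of what such a helper typically looks like (the real implementation lives in the V2 library tests' common code and may differ):

    #include <stdexcept>
    #include <string>

    template <typename TFunc>
    void VerifyException(const TFunc& functionToTest, const std::string& errorMessage)
    {
        bool exceptionRaised = false;
        try
        {
            functionToTest();
        }
        catch (const std::exception&)
        {
            exceptionRaised = true; // the call threw, as the test expects
        }

        if (!exceptionRaised)
            throw std::runtime_error(errorMessage); // fail: no exception was raised
    }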
File diff suppressed because one or more lines are too long