This commit is contained in:
Rui Zhao (SPEECH) 2019-10-04 17:03:47 -07:00
Родитель a6190d0572
Коммит 9800670b2b
5 изменённых файлов: 628 добавлений и 298 удалений

Просмотреть файл

@ -194,8 +194,34 @@ void DoAdapt(const ConfigParameters& config)
SGD<ElemType> sgd(configSGD);
sgd.InitMPI(MPIWrapper::GetInstance());
sgd.Adapt(origModelFileName, refNodeName, dataReader.get(), cvDataReader.get(), deviceId, makeMode);
//for RNNT TS
int startEpoch = sgd.DetermineStartEpoch(makeMode);
if (startEpoch == sgd.GetMaxEpochs())
{
LOGPRINTF(stderr, "No further training is necessary.\n");
return;
}
wstring modelFileName =sgd.GetModelNameForEpoch(int(startEpoch) - 1);
bool loadNetworkFromCheckpoint = startEpoch >= 0;
if (loadNetworkFromCheckpoint)
LOGPRINTF(stderr, "\nStarting from checkpoint. Loading network from '%ls'.\n", modelFileName.c_str());
else
LOGPRINTF(stderr, "\nCreating virgin network.\n");
// determine the network-creation function
// We have several ways to create that network.
function<ComputationNetworkPtr(DEVICEID_TYPE)> createNetworkFn;
createNetworkFn = GetNetworkFactory<ConfigParameters, ElemType>(config);
// create or load from checkpoint
shared_ptr<ComputationNetwork> net = !loadNetworkFromCheckpoint ? createNetworkFn(deviceId) : ComputationNetwork::CreateFromFile<ElemType>(deviceId, modelFileName);
sgd.Adapt(net, loadNetworkFromCheckpoint, origModelFileName, refNodeName, dataReader.get(), cvDataReader.get(), deviceId, makeMode);
}
template void DoAdapt<float>(const ConfigParameters& config);

Просмотреть файл

@ -220,6 +220,215 @@ public:
}
}
//decoding for RNNT
template <class ElemType>
void RNNT_decode_greedy(const std::vector<std::wstring>& outputNodeNames, Matrix<ElemType>& encodeInputMatrix, MBLayout& encodeMBLayout,
Matrix<ElemType>& decodeInputMatrix, MBLayout& decodeMBLayout,vector<vector<size_t>> &outputlabels, float groundTruthWeight /*mt19937_64 randGen*/)
{
if (outputNodeNames.size() == 0)
fprintf(stderr, "OutputNodeNames are not specified, using the default outputnodes.\n");
std::vector<ComputationNodeBasePtr> outputNodes = OutputNodesByName(outputNodeNames);
//AllocateAllMatrices({}, outputNodes, nullptr);
//encoder related nodes
std::vector<std::wstring> encodeOutputNodeNames(outputNodeNames.begin(), outputNodeNames.begin() + 1);
std::vector<ComputationNodeBasePtr> encodeOutputNodes = OutputNodesByName(encodeOutputNodeNames);
std::vector<ComputationNodeBasePtr> encodeInputNodes = InputNodesForOutputs(encodeOutputNodeNames);
//StreamMinibatchInputs encodeInputMatrices = DataReaderHelpers::RetrieveInputMatrices(encodeInputNodes);
StartEvaluateMinibatchLoop(encodeOutputNodes[0]);
//prediction related nodes
std::vector<std::wstring> decodeOutputNodeNames(outputNodeNames.begin() + 1, outputNodeNames.begin() + 2);
std::vector<ComputationNodeBasePtr> decodeOutputNodes = OutputNodesByName(decodeOutputNodeNames);
std::vector<ComputationNodeBasePtr> decodeinputNodes = InputNodesForOutputs(decodeOutputNodeNames);
//StreamMinibatchInputs decodeinputMatrices = DataReaderHelpers::RetrieveInputMatrices(decodeinputNodes);
//joint nodes
ComputationNodeBasePtr PlusNode = GetNodeFromName(outputNodeNames[2]);
ComputationNodeBasePtr PlusTransNode = GetNodeFromName(outputNodeNames[3]);
ComputationNodeBasePtr WmNode = GetNodeFromName(outputNodeNames[4]);
ComputationNodeBasePtr bmNode = GetNodeFromName(outputNodeNames[5]);
std::vector<ComputationNodeBasePtr> Plusnodes, Plustransnodes;
Plusnodes.push_back(PlusNode);
Plustransnodes.push_back(PlusTransNode);
//start eval
StartEvaluateMinibatchLoop(decodeOutputNodes[0]);
//auto lminput = decodeinputMatrices.begin();
size_t deviceid = decodeInputMatrix.GetDeviceId();
std::map<std::wstring, void*, nocase_compare> outputMatrices;
Matrix<ElemType> encodeOutput(deviceid);
Matrix<ElemType> decodeOutput(deviceid), Wm(deviceid), bm(deviceid), tempMatrix(deviceid);
Matrix<ElemType> greedyOutput(deviceid), greedyOutputMax(deviceid);
Matrix<ElemType> sumofENandDE(deviceid), maxIdx(deviceid), maxVal(deviceid);
Matrix<ElemType> lmin(deviceid);
Wm.SetValue(*(&dynamic_pointer_cast<ComputationNode<ElemType>>(WmNode)->Value()));
bm.SetValue(*(&dynamic_pointer_cast<ComputationNode<ElemType>>(bmNode)->Value()));
const size_t numIterationsBeforePrintingProgress = 100;
//get MBlayer of encoder input
size_t numParallelSequences = encodeMBLayout.GetNumParallelSequences();
const auto numSequences = encodeMBLayout.GetNumSequences();
//get frame number, phone number and output label number
const size_t numRows = encodeInputMatrix.GetNumRows();
const size_t numCols = encodeInputMatrix.GetNumCols();
//size_t maxFrameNum = numCols / numParallelSequences;
std::vector<size_t> uttFrameBeginIdx;
// the frame number of each utterance. The size of this vector = the number of all utterances in this minibatch
std::vector<size_t> uttFrameNum;
// map from utterance ID to minibatch channel ID. We need this because each channel may contain more than one utterance.
std::vector<size_t> uttFrameToChanInd;
//size_t totalcol = 0;
uttFrameNum.clear();
uttFrameToChanInd.clear();
uttFrameBeginIdx.clear();
uttFrameNum.reserve(numSequences);
uttFrameToChanInd.reserve(numSequences);
uttFrameBeginIdx.reserve(numSequences);
//get utt information, such as channel map id and utt begin frame, utt frame num, utt phone num for frame and phone respectively....
size_t seqId = 0; //frame
size_t totalframenum = 0;
for (const auto& seq : encodeMBLayout.GetAllSequences())
{
if (seq.seqId == GAP_SEQUENCE_ID)
{
continue;
}
assert(seq.seqId == seqId);
seqId++;
uttFrameToChanInd.push_back(seq.s);
size_t numFrames = seq.GetNumTimeSteps();
uttFrameBeginIdx.push_back(seq.tBegin);
uttFrameNum.push_back(numFrames);
totalframenum += numFrames;
}
//resize output
outputlabels.resize(numSequences);
// forward prop encoder
ComputationNetwork::BumpEvalTimeStamp(encodeInputNodes);
ForwardProp(encodeOutputNodes[0]);
encodeOutput.SetValue(*(&dynamic_pointer_cast<ComputationNode<ElemType>>(encodeOutputNodes[0])->Value()));
size_t vocabSize = bm.GetNumRows();
size_t blankId = vocabSize - 1;
for (size_t uttID = 0; uttID < numSequences; uttID ++)
{
lmin.Resize(vocabSize, 1);
lmin.SetValue(0.0);
lmin(blankId, 0) = 1;
decodeMBLayout.Init(1, 1);
std::swap(decodeInputMatrix, lmin);
decodeMBLayout.AddSequence(NEW_SEQUENCE_ID, 0, 0, 2000);
ComputationNetwork::BumpEvalTimeStamp(decodeinputNodes);
ForwardProp(decodeOutputNodes[0]);
greedyOutputMax.Resize(vocabSize, 2000);
size_t lmt = 0;
for (size_t t = 0; t < uttFrameNum[uttID]; t++)
{
decodeOutput.SetValue(*(&dynamic_pointer_cast<ComputationNode<ElemType>>(decodeOutputNodes[0])->Value()));
//auto edNode = PlusNode->As<PlusBroadcastNode<ElemType>>();
size_t tinMB = (t + uttFrameBeginIdx[uttID]) * numParallelSequences + uttFrameToChanInd[uttID];
sumofENandDE.AssignSumOf(encodeOutput.ColumnSlice(tinMB, 1), decodeOutput);
//sumofENandDE.AssignSumOf(encodeOutput.ColumnSlice(t, 1), decodeOutput);
(&dynamic_pointer_cast<ComputationNode<ElemType>>(PlusNode)->Value())->SetValue(sumofENandDE);
ComputationNetwork::BumpEvalTimeStamp(Plusnodes);
auto PlusMBlayout = PlusNode->GetMBLayout();
PlusMBlayout->Init(1, 1);
PlusMBlayout->AddSequence(NEW_SEQUENCE_ID, 0, 0, 1);
ForwardPropFromTo(Plusnodes, Plustransnodes);
decodeOutput.SetValue(*(&dynamic_pointer_cast<ComputationNode<ElemType>>(PlusTransNode)->Value()));
tempMatrix.AssignProductOf(Wm, true, decodeOutput, false);
decodeOutput.AssignSumOf(tempMatrix, bm);
decodeOutput.VectorMax(maxIdx, maxVal, true);
size_t maxId = (size_t)(maxIdx.Get00Element());
if (maxId != blankId)
{
outputlabels[uttID].push_back(maxId);
lmin.Resize(vocabSize, 1);
lmin.SetValue(0.0);
lmin(maxId, 0) = 1.0;
greedyOutputMax.SetColumn(lmin, lmt);
std::swap(decodeInputMatrix, lmin);
decodeMBLayout.Init(1, 1);
decodeMBLayout.AddSequence(NEW_SEQUENCE_ID, 0, -1 - lmt, 1999 - lmt);
ComputationNetwork::BumpEvalTimeStamp(decodeinputNodes);
//DataReaderHelpers::NotifyChangedNodes<ElemType>(m_net, decodeinputMatrices);
ForwardProp(decodeOutputNodes[0]);
lmt++;
}
}
if (lmt == 0)
{
outputlabels[uttID].push_back(blankId);
}
//break;
}
//decode
//make new MBLayout for decoder input
MBLayoutPtr newdecodeMBLayout = make_shared<MBLayout>();
std::vector<std::pair<size_t, size_t>> placement;
std::vector<MBLayout::SequenceInfo> sequences;
for (size_t i = 0; i < outputlabels.size(); ++i)
sequences.push_back({i, SIZE_MAX, 0, outputlabels[i].size()});
std::vector<size_t> rowAllocations;
newdecodeMBLayout->InitAsPackedSequences(sequences, placement, rowAllocations);
decodeInputMatrix.Resize(vocabSize, newdecodeMBLayout->GetNumCols() * newdecodeMBLayout->GetNumParallelSequences());
decodeInputMatrix.SetValue(0.0f);
//fill the decoder input
const auto& sequenceInfos = newdecodeMBLayout->GetAllSequences();
for (int i = 0; i < sequenceInfos.size(); ++i)
{
const auto& sequenceInfo = sequenceInfos[i];
// skip gaps
if (sequenceInfo.seqId == GAP_SEQUENCE_ID)
{
continue;
}
//const auto& sequence = batch[sequenceInfo.seqId];
size_t numSamples = outputlabels[i].size();
assert(numSamples == sequenceInfo.GetNumTimeSteps());
// Iterate over all samples in the sequence, keep track of the sample offset (which is especially
// important for sparse input, where offset == number of preceding nnz elements).
for (size_t sampleIndex = 0;sampleIndex < numSamples; ++sampleIndex)
{
// Compute the offset into the destination buffer, using the layout information
// to get the column index corresponding to the given sample.
size_t destinationOffset = newdecodeMBLayout->GetColumnIndex(sequenceInfo, sampleIndex) ;
// verify that there's enough space left in the buffer to fit a full sample.
//assert(destinationOffset <= buffer.m_size - sampleSize);
//auto* destination = bufferPtr + destinationOffset;
decodeInputMatrix.SetValue(outputlabels[i][sampleIndex], destinationOffset, 1.0f);
}
}
//copy the new MBLayout
decodeMBLayout.CopyFrom(newdecodeMBLayout);
// clean up
}
static void BumpEvalTimeStamp(const std::vector<ComputationNodeBasePtr>& nodes);
void ResetEvalTimeStamps();
void SetEvalTimeStampsOutdatedWithRegardToAll();

Просмотреть файл

@ -220,7 +220,7 @@ public:
// Serializes this node to fstream: first the base-class state via Base::Save, then m_combineMode.
// NOTE(review): the write of m_combineMode appears twice below, once live and once commented out.
// In this diff view that presumably means the commit disabled the write -- confirm which line is
// actually in effect, and make sure Load() reads (or skips) m_combineMode consistently with it.
virtual void Save(File& fstream) const override
{
Base::Save(fstream);
fstream << m_combineMode;
//fstream << m_combineMode;
}
virtual void Load(File& fstream, size_t modelVersion) override

Просмотреть файл

@ -90,7 +90,7 @@ void SGD<ElemType>::Train(shared_ptr<ComputationNetwork> net, DEVICEID_TYPE devi
// -----------------------------------------------------------------------
template <class ElemType>
void SGD<ElemType>::Adapt(wstring origModelFileName, wstring refNodeName,
void SGD<ElemType>::Adapt(shared_ptr<ComputationNetwork> net, bool networkLoadedFromCheckpoint, wstring origModelFileName, wstring refNodeName,
IDataReader* trainSetDataReader,
IDataReader* validationSetDataReader,
const DEVICEID_TYPE deviceId, const bool makeMode)
@ -102,7 +102,7 @@ void SGD<ElemType>::Adapt(wstring origModelFileName, wstring refNodeName,
return;
}
ComputationNetworkPtr net;
/*ComputationNetworkPtr net;
bool networkLoadedFromCheckpoint = false;
if (startEpoch >= 0)
{
@ -115,17 +115,22 @@ void SGD<ElemType>::Adapt(wstring origModelFileName, wstring refNodeName,
{
LOGPRINTF(stderr, "Load Network From the original model file %ls.\n", origModelFileName.c_str());
net = ComputationNetwork::CreateFromFile<ElemType>(deviceId, origModelFileName);
}
}*/
startEpoch = max(startEpoch, 0);
ComputationNetworkPtr refNet;
m_needAdaptRegularization = m_adaptationRegType != AdaptationRegType::None && m_adaptationRegWeight > 0;
if (m_needAdaptRegularization)
if (m_needAdaptRegularization && m_adaptationRegType != AdaptationRegType::TS)
{
LOGPRINTF(stderr, "Load reference Network From the original model file %ls.\n", origModelFileName.c_str());
refNet = ComputationNetwork::CreateFromFile<ElemType>(deviceId, origModelFileName);
}
else if (m_needAdaptRegularization && m_adaptationRegType == AdaptationRegType::TS)
{
refNet = make_shared<ComputationNetwork>(deviceId);
refNet->Read<ElemType>(origModelFileName);
}
ComputationNodeBasePtr refNode;
if (m_needAdaptRegularization && m_adaptationRegType == AdaptationRegType::KL)
@ -649,39 +654,89 @@ void SGD<ElemType>::TrainOrAdaptModel(int startEpoch, ComputationNetworkPtr net,
EpochCriterion epochCriterion; // criterion values are returned in this
std::vector<EpochCriterion> epochEvalErrors(evaluationNodes.size());
totalMBsSeen += TrainOneEpoch(net,
refNet,
refNode,
i,
m_epochSize,
trainSetDataReader,
learnRatePerSample,
chosenMinibatchSize,
featureNodes,
labelNodes,
criterionNodes,
evaluationNodes,
inputMatrices,
learnableNodes, smoothedGradients, smoothedCounts,
epochCriterion, epochEvalErrors,
"", SIZE_MAX, totalMBsSeen, tensorBoardWriter, startEpoch);
totalTrainingSamplesSeen += epochCriterion.second; // aggregate #training samples, for logging purposes only
timer.Stop();
double epochTime = timer.ElapsedSeconds();
//RNNT TS
StreamMinibatchInputs decodeinputMatrices, encodeInputMatrices;
vector<wstring> outputNodeNamesVector;
if (m_needAdaptRegularization && m_adaptationRegType == AdaptationRegType::TS && refNet)
{
if (m_outputNodeNames.size() > 0)
{
// clear out current list of outputNodes
while (!refNet->OutputNodes().empty())
refNet->RemoveFromNodeGroup(L"output", refNet->OutputNodes().back());
// and insert the desired nodes instead
for (wstring name : m_outputNodeNames)
{
if (!refNet->NodeNameExists(name))
{
fprintf(stderr, "PatchOutputNodes: No node named '%ls'; skipping\n", name.c_str());
continue;
}
outputNodeNamesVector.push_back(name);
let& node = refNet->GetNodeFromName(name);
refNet->AddToNodeGroup(L"output", node);
}
}
refNet->CompileNetwork();
//PatchOutputNodes(refNet, m_outputNodeNames, outputNodeNamesVector);
if (m_useEvalCriterionControlLR && epochEvalErrors.size() > 0)
lrControlCriterion = epochEvalErrors[0].Average();
else
lrControlCriterion = epochCriterion.Average();
std::vector<ComputationNodeBasePtr> outputNodes = refNet->OutputNodesByName(outputNodeNamesVector);
refNet->AllocateAllMatrices({}, outputNodes, nullptr);
LOGPRINTF(stderr, "Finished Epoch[%2d of %d]: [Training] ", i + 1, (int) m_maxEpochs);
epochCriterion.LogCriterion(criterionNodes[0]->NodeName());
std::vector<std::wstring> encodeOutputNodeNames(outputNodeNamesVector.begin(), outputNodeNamesVector.begin() + 1);
std::vector<ComputationNodeBasePtr> encodeOutputNodes = refNet->OutputNodesByName(encodeOutputNodeNames);
std::vector<ComputationNodeBasePtr> encodeInputNodes = refNet->InputNodesForOutputs(encodeOutputNodeNames);
encodeInputMatrices = DataReaderHelpers::RetrieveInputMatrices(encodeInputNodes);
m_lastFinishedEpochTrainLoss = epochCriterion.Average();
for (size_t j = 0; j < epochEvalErrors.size(); j++)
epochEvalErrors[j].LogCriterion(evaluationNodes[j]->NodeName());
fprintf(stderr, "totalSamplesSeen = %zu; learningRatePerSample = %.8g; epochTime=%.6gs\n", totalTrainingSamplesSeen, learnRatePerSample, epochTime);
//get decode input matrix
std::vector<std::wstring> decodeOutputNodeNames(outputNodeNamesVector.begin() + 1, outputNodeNamesVector.begin() + 2);
std::vector<ComputationNodeBasePtr> decodeOutputNodes = refNet->OutputNodesByName(decodeOutputNodeNames);
std::vector<ComputationNodeBasePtr> decodeinputNodes = refNet->InputNodesForOutputs(decodeOutputNodeNames);
decodeinputMatrices = DataReaderHelpers::RetrieveInputMatrices(decodeinputNodes);
//DataReaderHelpers::
//StreamBatch batch;
}
totalMBsSeen += TrainOneEpoch(net,
refNet,
refNode,
i,
m_epochSize,
trainSetDataReader,
learnRatePerSample,
chosenMinibatchSize,
featureNodes,
labelNodes,
criterionNodes,
evaluationNodes,
inputMatrices,
learnableNodes, smoothedGradients, smoothedCounts,
epochCriterion, epochEvalErrors,
"", SIZE_MAX, totalMBsSeen, tensorBoardWriter, startEpoch,
outputNodeNamesVector, &encodeInputMatrices, &decodeinputMatrices);
totalTrainingSamplesSeen += epochCriterion.second; // aggregate #training samples, for logging purposes only
timer.Stop();
double epochTime = timer.ElapsedSeconds();
if (m_useEvalCriterionControlLR && epochEvalErrors.size() > 0)
lrControlCriterion = epochEvalErrors[0].Average();
else
lrControlCriterion = epochCriterion.Average();
LOGPRINTF(stderr, "Finished Epoch[%2d of %d]: [Training] ", i + 1, (int) m_maxEpochs);
epochCriterion.LogCriterion(criterionNodes[0]->NodeName());
m_lastFinishedEpochTrainLoss = epochCriterion.Average();
for (size_t j = 0; j < epochEvalErrors.size(); j++)
epochEvalErrors[j].LogCriterion(evaluationNodes[j]->NodeName());
fprintf(stderr, "totalSamplesSeen = %zu; learningRatePerSample = %.8g; epochTime=%.6gs\n", totalTrainingSamplesSeen, learnRatePerSample, epochTime);
#if 0
// TODO: This was only printed if >1 eval criterion. Why? Needed?
LOGPRINTF(stderr, "Finished Epoch[%2d of %d]: Criterion Node [%ls] Per Sample = %.8g\n",
@ -694,283 +749,283 @@ void SGD<ElemType>::TrainOrAdaptModel(int startEpoch, ComputationNetworkPtr net,
}
#endif
if (tensorBoardWriter)
{
tensorBoardWriter->WriteValue(L"summary/" + criterionNodes[0]->NodeName(), (float) epochCriterion.Average(), i + 1);
for (size_t j = 0; j < epochEvalErrors.size(); j++)
{
tensorBoardWriter->WriteValue(L"summary/" + evaluationNodes[0]->NodeName(), (float) epochEvalErrors[j].Average(), i + 1);
}
tensorBoardWriter->Flush();
}
if (validationSetDataReader != trainSetDataReader && validationSetDataReader != nullptr)
{
// TODO(dataASGD) making evaluator becoming nondistributed one when using ASGD, since Multiverso has another background thread using MPI.
// Making the evaluation serial (non-distributed) will slowdown training especially when validation set is large.
SimpleEvaluator<ElemType> evalforvalidation(net, UsingAsyncGradientAggregation(i + 1) ? nullptr : m_mpi, m_enableDistributedMBReading);
vector<wstring> cvSetTrainAndEvalNodes;
if (criterionNodes.size() > 0)
{
cvSetTrainAndEvalNodes.push_back(criterionNodes[0]->NodeName());
}
for (let node : evaluationNodes)
{
cvSetTrainAndEvalNodes.push_back(node->NodeName());
}
// BUGBUG: We should not use the training MB size. The training MB size is constrained by both convergence and memory. Eval is only constrained by memory.
let vScore = evalforvalidation.Evaluate(validationSetDataReader, cvSetTrainAndEvalNodes, UsingAsyncGradientAggregation(i + 1) ? m_mbSize[i] / m_mpi->NumNodesInUse() : m_mbSize[i]);
LOGPRINTF(stderr, "Finished Epoch[%2d of %d]: [Validate] ", i + 1, (int) m_maxEpochs);
for (size_t k = 0; k < vScore.size() /*&& k < 2*/; k++)
vScore[k].LogCriterion(cvSetTrainAndEvalNodes[k], /*addSemicolon=*/k + 1 < vScore.size());
//fprintf(stderr, "%s %ls = %.8f * %d", k ? ";" : "", cvSetTrainAndEvalNodes[k].c_str(), vScore[k].Average(), (int)vScore[k].second);
fprintf(stderr, "\n");
if (tensorBoardWriter)
{
tensorBoardWriter->WriteValue(L"summary/" + criterionNodes[0]->NodeName(), (float) epochCriterion.Average(), i + 1);
for (size_t j = 0; j < epochEvalErrors.size(); j++)
for (size_t k = 0; k < vScore.size(); k++)
{
tensorBoardWriter->WriteValue(L"summary/" + evaluationNodes[0]->NodeName(), (float) epochEvalErrors[j].Average(), i + 1);
tensorBoardWriter->WriteValue(L"summary/test_" + cvSetTrainAndEvalNodes[k], (float) vScore[k].Average(), i + 1);
}
tensorBoardWriter->Flush();
}
if (validationSetDataReader != trainSetDataReader && validationSetDataReader != nullptr)
if (m_saveBestModelPerCriterion)
{
// TODO(dataASGD) making evaluator becoming nondistributed one when using ASGD, since Multiverso has another background thread using MPI.
// Making the evaluation serial (non-distributed) will slowdown training especially when validation set is large.
SimpleEvaluator<ElemType> evalforvalidation(net, UsingAsyncGradientAggregation(i + 1) ? nullptr : m_mpi, m_enableDistributedMBReading);
vector<wstring> cvSetTrainAndEvalNodes;
if (criterionNodes.size() > 0)
{
cvSetTrainAndEvalNodes.push_back(criterionNodes[0]->NodeName());
}
for (let node : evaluationNodes)
{
cvSetTrainAndEvalNodes.push_back(node->NodeName());
}
// BUGBUG: We should not use the training MB size. The training MB size is constrained by both convergence and memory. Eval is only constrained by memory.
let vScore = evalforvalidation.Evaluate(validationSetDataReader, cvSetTrainAndEvalNodes, UsingAsyncGradientAggregation(i + 1) ? m_mbSize[i] / m_mpi->NumNodesInUse() : m_mbSize[i]);
LOGPRINTF(stderr, "Finished Epoch[%2d of %d]: [Validate] ", i + 1, (int) m_maxEpochs);
for (size_t k = 0; k < vScore.size() /*&& k < 2*/; k++)
vScore[k].LogCriterion(cvSetTrainAndEvalNodes[k], /*addSemicolon=*/k + 1 < vScore.size());
//fprintf(stderr, "%s %ls = %.8f * %d", k ? ";" : "", cvSetTrainAndEvalNodes[k].c_str(), vScore[k].Average(), (int)vScore[k].second);
fprintf(stderr, "\n");
if (tensorBoardWriter)
{
for (size_t k = 0; k < vScore.size(); k++)
{
tensorBoardWriter->WriteValue(L"summary/test_" + cvSetTrainAndEvalNodes[k], (float) vScore[k].Average(), i + 1);
}
tensorBoardWriter->Flush();
}
if (m_saveBestModelPerCriterion)
{
// Loops through criteria (i.e. score) and updates the best one if smaller value is found.
UpdateBestEpochs(vScore, cvSetTrainAndEvalNodes, i, m_criteriaBestEpoch);
}
if (m_useCVSetControlLRIfCVExists)
{
if (m_useEvalCriterionControlLR && vScore.size() > 1)
lrControlCriterion = vScore[1].Average(); // use the first of possibly multiple eval criteria
else
lrControlCriterion = vScore[0].Average(); // the first one is the training criterion
}
// Loops through criteria (i.e. score) and updates the best one if smaller value is found.
UpdateBestEpochs(vScore, cvSetTrainAndEvalNodes, i, m_criteriaBestEpoch);
}
// broadcast epochCriterion to make sure each processor will have the same learning rate schedule
if ((GetParallelizationMethod() == ParallelizationMethod::modelAveragingSGD ||
GetParallelizationMethod() == ParallelizationMethod::blockMomentumSGD) &&
(m_mpi->NumNodesInUse() > 1))
if (m_useCVSetControlLRIfCVExists)
{
m_mpi->Bcast(&epochCriterion.first, 1, m_mpi->MainNodeRank());
m_mpi->Bcast(&epochCriterion.second, 1, m_mpi->MainNodeRank());
m_mpi->Bcast(&lrControlCriterion, 1, m_mpi->MainNodeRank());
}
bool loadedPrevModel = false;
size_t epochsSinceLastLearnRateAdjust = i % m_learnRateAdjustInterval + 1;
if (avgCriterion == numeric_limits<double>::infinity())
{
avgCriterion = lrControlCriterion;
}
else
{
avgCriterion = ((epochsSinceLastLearnRateAdjust - 1 - epochsNotCountedInAvgCriterion) *
avgCriterion +
lrControlCriterion) /
(epochsSinceLastLearnRateAdjust - epochsNotCountedInAvgCriterion);
}
if (m_autoLearnRateSearchType == LearningRateSearchAlgorithm::AdjustAfterEpoch &&
m_learningRatesParam.size() <= i && epochsSinceLastLearnRateAdjust == m_learnRateAdjustInterval)
{
if (std::isnan(avgCriterion) || (prevCriterion - avgCriterion < 0 && prevCriterion != numeric_limits<double>::infinity()))
{
if (m_loadBestModel)
{
// roll back
auto bestModelPath = GetModelNameForEpoch(i - m_learnRateAdjustInterval);
LOGPRINTF(stderr, "Loading (rolling back to) previous model with best training-criterion value: %ls.\n", bestModelPath.c_str());
net->RereadPersistableParameters<ElemType>(bestModelPath);
LoadCheckPointInfo(i - m_learnRateAdjustInterval,
/*out*/ totalTrainingSamplesSeen,
/*out*/ learnRatePerSample,
smoothedGradients,
smoothedCounts,
/*out*/ prevCriterion,
/*out*/ m_prevChosenMinibatchSize);
loadedPrevModel = true;
}
}
if (m_continueReduce)
{
if (std::isnan(avgCriterion) ||
(prevCriterion - avgCriterion <= m_reduceLearnRateIfImproveLessThan * prevCriterion &&
prevCriterion != numeric_limits<double>::infinity()))
{
if (learnRateReduced == false)
{
learnRateReduced = true;
}
else
{
// In case of parallel training only the main node should we saving the model to prevent
// the parallel training nodes from colliding to write the same file
if ((m_mpi == nullptr) || m_mpi->IsMainNode())
net->Save(GetModelNameForEpoch(i, true));
LOGPRINTF(stderr, "Finished training and saved final model\n\n");
break;
}
}
if (learnRateReduced)
{
learnRatePerSample *= m_learnRateDecreaseFactor;
LOGPRINTF(stderr, "learnRatePerSample reduced to %.8g\n", learnRatePerSample);
}
}
if (m_useEvalCriterionControlLR && vScore.size() > 1)
lrControlCriterion = vScore[1].Average(); // use the first of possibly multiple eval criteria
else
{
if (std::isnan(avgCriterion) ||
(prevCriterion - avgCriterion <= m_reduceLearnRateIfImproveLessThan * prevCriterion &&
prevCriterion != numeric_limits<double>::infinity()))
{
learnRatePerSample *= m_learnRateDecreaseFactor;
LOGPRINTF(stderr, "learnRatePerSample reduced to %.8g\n", learnRatePerSample);
}
else if (prevCriterion - avgCriterion > m_increaseLearnRateIfImproveMoreThan * prevCriterion &&
prevCriterion != numeric_limits<double>::infinity())
{
learnRatePerSample *= m_learnRateIncreaseFactor;
LOGPRINTF(stderr, "learnRatePerSample increased to %.8g\n", learnRatePerSample);
}
}
}
else
{
if (std::isnan(avgCriterion))
RuntimeError("The training criterion is not a number (NAN).");
}
// not loading previous values then set them
if (!loadedPrevModel && epochsSinceLastLearnRateAdjust == m_learnRateAdjustInterval)
{
prevCriterion = avgCriterion;
epochsNotCountedInAvgCriterion = 0;
}
// Synchronize all ranks before proceeding to ensure that
// nobody tries reading the checkpoint file at the same time
// as rank 0 deleting it below
SynchronizeWorkers();
// Persist model and check-point info
if ((m_mpi == nullptr) || m_mpi->IsMainNode())
{
if (loadedPrevModel)
{
// If previous best model is loaded, we will first remove epochs that lead to worse results
for (int j = 1; j < m_learnRateAdjustInterval; j++)
{
int epochToDelete = i - j;
LOGPRINTF(stderr, "SGD: removing model and checkpoint files for epoch %d after rollback to epoch %lu\n", epochToDelete + 1, (unsigned long) (i - m_learnRateAdjustInterval) + 1); // report 1 based epoch number
_wunlink(GetModelNameForEpoch(epochToDelete).c_str());
_wunlink(GetCheckPointFileNameForEpoch(epochToDelete).c_str());
}
// Set i back to the loaded model
i -= m_learnRateAdjustInterval;
LOGPRINTF(stderr, "SGD: revoke back to and update checkpoint file for epoch %d\n", i + 1); // report 1 based epoch number
SaveCheckPointInfo(
i,
totalTrainingSamplesSeen,
learnRatePerSample,
smoothedGradients,
smoothedCounts,
prevCriterion,
chosenMinibatchSize);
}
else
{
SaveCheckPointInfo(
i,
totalTrainingSamplesSeen,
learnRatePerSample,
smoothedGradients,
smoothedCounts,
prevCriterion,
chosenMinibatchSize);
auto modelName = GetModelNameForEpoch(i);
if (m_traceLevel > 0)
LOGPRINTF(stderr, "SGD: Saving checkpoint model '%ls'\n", modelName.c_str());
net->Save(modelName);
if (!m_keepCheckPointFiles)
{
// delete previous checkpoint file to save space
if (m_autoLearnRateSearchType == LearningRateSearchAlgorithm::AdjustAfterEpoch && m_loadBestModel)
{
if (epochsSinceLastLearnRateAdjust != 1)
{
_wunlink(GetCheckPointFileNameForEpoch(i - 1).c_str());
}
if (epochsSinceLastLearnRateAdjust == m_learnRateAdjustInterval)
{
_wunlink(GetCheckPointFileNameForEpoch(i - m_learnRateAdjustInterval).c_str());
}
}
else
{
_wunlink(GetCheckPointFileNameForEpoch(i - 1).c_str());
}
}
}
}
else
{
if (loadedPrevModel)
{
// Set i back to the loaded model
i -= m_learnRateAdjustInterval;
}
}
if (learnRatePerSample < 1e-12)
{
LOGPRINTF(stderr, "learnRate per sample is reduced to %.8g which is below 1e-12. stop training.\n",
learnRatePerSample);
lrControlCriterion = vScore[0].Average(); // the first one is the training criterion
}
}
// --- END OF MAIN EPOCH LOOP
// Check if we need to save best model per criterion and this is the main node as well.
if (m_saveBestModelPerCriterion && ((m_mpi == nullptr) || m_mpi->IsMainNode()))
// broadcast epochCriterion to make sure each processor will have the same learning rate schedule
if ((GetParallelizationMethod() == ParallelizationMethod::modelAveragingSGD ||
GetParallelizationMethod() == ParallelizationMethod::blockMomentumSGD) &&
(m_mpi->NumNodesInUse() > 1))
{
// For each criterion copies the best epoch to the new file with criterion name appended.
CopyBestEpochs(m_criteriaBestEpoch, *this, m_maxEpochs - 1);
m_mpi->Bcast(&epochCriterion.first, 1, m_mpi->MainNodeRank());
m_mpi->Bcast(&epochCriterion.second, 1, m_mpi->MainNodeRank());
m_mpi->Bcast(&lrControlCriterion, 1, m_mpi->MainNodeRank());
}
bool loadedPrevModel = false;
size_t epochsSinceLastLearnRateAdjust = i % m_learnRateAdjustInterval + 1;
if (avgCriterion == numeric_limits<double>::infinity())
{
avgCriterion = lrControlCriterion;
}
else
{
avgCriterion = ((epochsSinceLastLearnRateAdjust - 1 - epochsNotCountedInAvgCriterion) *
avgCriterion +
lrControlCriterion) /
(epochsSinceLastLearnRateAdjust - epochsNotCountedInAvgCriterion);
}
if (m_autoLearnRateSearchType == LearningRateSearchAlgorithm::AdjustAfterEpoch &&
m_learningRatesParam.size() <= i && epochsSinceLastLearnRateAdjust == m_learnRateAdjustInterval)
{
if (std::isnan(avgCriterion) || (prevCriterion - avgCriterion < 0 && prevCriterion != numeric_limits<double>::infinity()))
{
if (m_loadBestModel)
{
// roll back
auto bestModelPath = GetModelNameForEpoch(i - m_learnRateAdjustInterval);
LOGPRINTF(stderr, "Loading (rolling back to) previous model with best training-criterion value: %ls.\n", bestModelPath.c_str());
net->RereadPersistableParameters<ElemType>(bestModelPath);
LoadCheckPointInfo(i - m_learnRateAdjustInterval,
/*out*/ totalTrainingSamplesSeen,
/*out*/ learnRatePerSample,
smoothedGradients,
smoothedCounts,
/*out*/ prevCriterion,
/*out*/ m_prevChosenMinibatchSize);
loadedPrevModel = true;
}
}
if (m_continueReduce)
{
if (std::isnan(avgCriterion) ||
(prevCriterion - avgCriterion <= m_reduceLearnRateIfImproveLessThan * prevCriterion &&
prevCriterion != numeric_limits<double>::infinity()))
{
if (learnRateReduced == false)
{
learnRateReduced = true;
}
else
{
// In case of parallel training only the main node should we saving the model to prevent
// the parallel training nodes from colliding to write the same file
if ((m_mpi == nullptr) || m_mpi->IsMainNode())
net->Save(GetModelNameForEpoch(i, true));
LOGPRINTF(stderr, "Finished training and saved final model\n\n");
break;
}
}
if (learnRateReduced)
{
learnRatePerSample *= m_learnRateDecreaseFactor;
LOGPRINTF(stderr, "learnRatePerSample reduced to %.8g\n", learnRatePerSample);
}
}
else
{
if (std::isnan(avgCriterion) ||
(prevCriterion - avgCriterion <= m_reduceLearnRateIfImproveLessThan * prevCriterion &&
prevCriterion != numeric_limits<double>::infinity()))
{
learnRatePerSample *= m_learnRateDecreaseFactor;
LOGPRINTF(stderr, "learnRatePerSample reduced to %.8g\n", learnRatePerSample);
}
else if (prevCriterion - avgCriterion > m_increaseLearnRateIfImproveMoreThan * prevCriterion &&
prevCriterion != numeric_limits<double>::infinity())
{
learnRatePerSample *= m_learnRateIncreaseFactor;
LOGPRINTF(stderr, "learnRatePerSample increased to %.8g\n", learnRatePerSample);
}
}
}
else
{
if (std::isnan(avgCriterion))
RuntimeError("The training criterion is not a number (NAN).");
}
// not loading previous values then set them
if (!loadedPrevModel && epochsSinceLastLearnRateAdjust == m_learnRateAdjustInterval)
{
prevCriterion = avgCriterion;
epochsNotCountedInAvgCriterion = 0;
}
// Synchronize all ranks before proceeding to ensure that
// rank 0 has finished writing the model file
// TODO[DataASGD]: should othet other rank waiting in async-mode
// nobody tries reading the checkpoint file at the same time
// as rank 0 deleting it below
SynchronizeWorkers();
// progress tracing for compute cluster management
ProgressTracing::TraceProgressPercentage(m_maxEpochs, 0.0, true);
ProgressTracing::TraceTrainLoss(m_lastFinishedEpochTrainLoss);
// since we linked feature nodes. we need to remove it from the deletion
if (m_needAdaptRegularization && m_adaptationRegType == AdaptationRegType::KL && refNode != nullptr)
// Persist model and check-point info
if ((m_mpi == nullptr) || m_mpi->IsMainNode())
{
for (size_t i = 0; i < refFeatureNodes.size(); i++)
if (loadedPrevModel)
{
// note we need to handle deletion carefully
refNet->ReplaceNode(refFeatureNodes[i]->NodeName(), refFeatureNodes[i]);
// If previous best model is loaded, we will first remove epochs that lead to worse results
for (int j = 1; j < m_learnRateAdjustInterval; j++)
{
int epochToDelete = i - j;
LOGPRINTF(stderr, "SGD: removing model and checkpoint files for epoch %d after rollback to epoch %lu\n", epochToDelete + 1, (unsigned long) (i - m_learnRateAdjustInterval) + 1); // report 1 based epoch number
_wunlink(GetModelNameForEpoch(epochToDelete).c_str());
_wunlink(GetCheckPointFileNameForEpoch(epochToDelete).c_str());
}
// Set i back to the loaded model
i -= m_learnRateAdjustInterval;
LOGPRINTF(stderr, "SGD: revoke back to and update checkpoint file for epoch %d\n", i + 1); // report 1 based epoch number
SaveCheckPointInfo(
i,
totalTrainingSamplesSeen,
learnRatePerSample,
smoothedGradients,
smoothedCounts,
prevCriterion,
chosenMinibatchSize);
}
else
{
SaveCheckPointInfo(
i,
totalTrainingSamplesSeen,
learnRatePerSample,
smoothedGradients,
smoothedCounts,
prevCriterion,
chosenMinibatchSize);
auto modelName = GetModelNameForEpoch(i);
if (m_traceLevel > 0)
LOGPRINTF(stderr, "SGD: Saving checkpoint model '%ls'\n", modelName.c_str());
net->Save(modelName);
if (!m_keepCheckPointFiles)
{
// delete previous checkpoint file to save space
if (m_autoLearnRateSearchType == LearningRateSearchAlgorithm::AdjustAfterEpoch && m_loadBestModel)
{
if (epochsSinceLastLearnRateAdjust != 1)
{
_wunlink(GetCheckPointFileNameForEpoch(i - 1).c_str());
}
if (epochsSinceLastLearnRateAdjust == m_learnRateAdjustInterval)
{
_wunlink(GetCheckPointFileNameForEpoch(i - m_learnRateAdjustInterval).c_str());
}
}
else
{
_wunlink(GetCheckPointFileNameForEpoch(i - 1).c_str());
}
}
}
}
else
{
if (loadedPrevModel)
{
// Set i back to the loaded model
i -= m_learnRateAdjustInterval;
}
}
delete inputMatrices;
if (m_parallelizationMethod == ParallelizationMethod::dataParallelASGD)
m_pASGDHelper.reset();
if (learnRatePerSample < 1e-12)
{
LOGPRINTF(stderr, "learnRate per sample is reduced to %.8g which is below 1e-12. stop training.\n",
learnRatePerSample);
}
}
// --- END OF MAIN EPOCH LOOP
// Check if we need to save best model per criterion and this is the main node as well.
if (m_saveBestModelPerCriterion && ((m_mpi == nullptr) || m_mpi->IsMainNode()))
{
// For each criterion copies the best epoch to the new file with criterion name appended.
CopyBestEpochs(m_criteriaBestEpoch, *this, m_maxEpochs - 1);
}
// Synchronize all ranks before proceeding to ensure that
// rank 0 has finished writing the model file
// TODO[DataASGD]: should the other ranks wait here in async mode?
SynchronizeWorkers();
// progress tracing for compute cluster management
ProgressTracing::TraceProgressPercentage(m_maxEpochs, 0.0, true);
ProgressTracing::TraceTrainLoss(m_lastFinishedEpochTrainLoss);
// since we linked feature nodes. we need to remove it from the deletion
if (m_needAdaptRegularization && m_adaptationRegType == AdaptationRegType::KL && refNode != nullptr)
{
for (size_t i = 0; i < refFeatureNodes.size(); i++)
{
// note we need to handle deletion carefully
refNet->ReplaceNode(refFeatureNodes[i]->NodeName(), refFeatureNodes[i]);
}
}
delete inputMatrices;
if (m_parallelizationMethod == ParallelizationMethod::dataParallelASGD)
m_pASGDHelper.reset();
} // namespace CNTK
// -----------------------------------------------------------------------
// TrainOneEpoch() -- train one epoch
@ -998,9 +1053,15 @@ size_t SGD<ElemType>::TrainOneEpoch(ComputationNetworkPtr net,
const size_t maxNumberOfSamples,
const size_t totalMBsSeenBefore,
::CNTK::Internal::TensorBoardFileWriterPtr tensorBoardWriter,
const int startEpoch)
const int startEpoch,
const std::vector<std::wstring>& outputNodeNamesVector,
StreamMinibatchInputs* encodeInputMatrices,
StreamMinibatchInputs* decodeinputMatrices)
{
PROFILE_SCOPE(profilerEvtMainEpoch);
//for schedule sampling
std::random_device rd;
std::mt19937_64 randGen{rd()};
ScopedNetworkOperationMode modeGuard(net, NetworkOperationMode::training);
@ -1165,6 +1226,14 @@ size_t SGD<ElemType>::TrainOneEpoch(ComputationNetworkPtr net,
bool noMoreSamplesToProcess = false;
bool isFirstMinibatch = true;
//RNNT TS
//RNNT TS
if (m_needAdaptRegularization && m_adaptationRegType == AdaptationRegType::TS && refNet)
{
}
//Microsoft::MSR::CNTK::StartProfiler();
for (;;)
{
@ -1234,6 +1303,26 @@ size_t SGD<ElemType>::TrainOneEpoch(ComputationNetworkPtr net,
dynamic_pointer_cast<ComputationNode<ElemType>>(labelNodes[0])->Value());
}
// RNNT TS: when adaptation regularization of type TS is enabled and a reference
// network is supplied, run the reference network's greedy RNNT decoder on the
// current minibatch.
// NOTE(review): encodeInputMatrices / decodeinputMatrices default to NULL in the
// TrainOneEpoch declaration and are dereferenced here without a null check.
if (m_needAdaptRegularization && m_adaptationRegType == AdaptationRegType::TS && refNet)
{
// Only the first entry of each input collection is used.
auto reffeainput = (*encodeInputMatrices).begin();
auto feainput = (*inputMatrices).begin();
// Mirror the training input's minibatch layout into the reference encoder input.
reffeainput->second.pMBLayout->CopyFrom(feainput->second.pMBLayout);
auto encodeMBLayout = feainput->second.pMBLayout;
// Copy a row slice of the training features into the reference input.
// NOTE(review): 240 is hard-coded; presumably (startRow = 0, numRows = 240), i.e. the
// feature dimension -- confirm against the Matrix API and the model configuration.
reffeainput->second.GetMatrix<ElemType>().AssignRowSliceValuesOf(feainput->second.GetMatrix<ElemType>(), 0, 240);
auto lminput = (*decodeinputMatrices).begin();
auto decodeMBLayout = lminput->second.pMBLayout;
// NOTE(review): this statement is identical to the copy a few lines above; it may be
// redundant, or it may have been meant to fill lminput's matrix -- confirm.
reffeainput->second.GetMatrix<ElemType>().AssignRowSliceValuesOf(feainput->second.GetMatrix<ElemType>(), 0, 240);
vector<vector<size_t>> outputlabels;
// NOTE(review): the same matrix (reffeainput) is passed for both the encoder-input and
// decoder-input parameters; lminput contributes only its layout. The trailing 1.0 is the
// groundTruthWeight argument. outputlabels is not read again within this block.
refNet->RNNT_decode_greedy(outputNodeNamesVector, reffeainput->second.GetMatrix<ElemType>(), *encodeMBLayout, reffeainput->second.GetMatrix<ElemType>(), *decodeMBLayout, outputlabels, 1.0);
//DataReaderHelpers::
//StreamBatch batch;
}
// do forward and back propagation
// We optionally break the minibatch into sub-minibatches.
@ -1256,7 +1345,6 @@ size_t SGD<ElemType>::TrainOneEpoch(ComputationNetworkPtr net,
// may be changed and need to be recomputed when gradient and function value share the same matrix
net->ForwardProp(forwardPropRoots); // the bulk of this evaluation is reused in ComputeGradient() below
// ===========================================================
// backprop
// ===========================================================
@ -1401,7 +1489,6 @@ size_t SGD<ElemType>::TrainOneEpoch(ComputationNetworkPtr net,
double p_norm = dynamic_pointer_cast<ComputationNode<ElemType>>(node)->Gradient().FrobeniusNorm();
//long m = (long) GetNumElements();
totalNorm += p_norm * p_norm;
}
}
totalNorm = sqrt(totalNorm);
@ -1807,7 +1894,6 @@ bool SGD<ElemType>::PreCompute(ComputationNetworkPtr net,
net->ForwardProp(nodes);
numItersSinceLastPrintOfProgress = ProgressTracing::TraceFakeProgress(numIterationsBeforePrintingProgress, numItersSinceLastPrintOfProgress);
}
// finalize
@ -2218,12 +2304,12 @@ void SGD<ElemType>::TrainOneMiniEpochAndReloadModel(ComputationNetworkPtr net,
std::string prefixMsg,
const size_t maxNumOfSamples)
{
TrainOneEpoch(net, refNet, refNode, epochNumber, epochSize,
trainSetDataReader, learnRatePerSample, minibatchSize, featureNodes,
labelNodes, criterionNodes, evaluationNodes,
inputMatrices, learnableNodes, smoothedGradients, smoothedCounts,
/*out*/ epochCriterion, /*out*/ epochEvalErrors,
" " + prefixMsg, maxNumOfSamples); // indent log msg by 2 (that is 1 more than the Finished message below)
//TrainOneEpoch(net, refNet, refNode, epochNumber, epochSize,
// trainSetDataReader, learnRatePerSample, minibatchSize, featureNodes,
// labelNodes, criterionNodes, evaluationNodes,
// inputMatrices, learnableNodes, smoothedGradients, smoothedCounts,
// /*out*/ epochCriterion, /*out*/ epochEvalErrors,
// " " + prefixMsg, maxNumOfSamples); // indent log msg by 2 (that is 1 more than the Finished message below)
LOGPRINTF(stderr, " Finished Mini-Epoch[%d]: ", (int) epochNumber + 1);
epochCriterion.LogCriterion(criterionNodes[0]->NodeName());
@ -2856,6 +2942,8 @@ static AdaptationRegType ParseAdaptationRegType(const wstring& s)
return AdaptationRegType::None;
else if (EqualCI(s, L"kl") || EqualCI(s, L"klReg"))
return AdaptationRegType::KL;
else if (EqualCI(s, L"ts"))
return AdaptationRegType::TS;
else
// No known spelling matched. The message lists every accepted family of values,
// including "ts", which is handled by the branch directly above.
InvalidArgument("ParseAdaptationRegType: Invalid Adaptation Regularization Type. Valid values are (none | kl | ts)");
}
@ -3069,6 +3157,9 @@ SGDParams::SGDParams(const ConfigRecordType& configSGD, size_t sizeofElemType)
m_doGradientCheck = configSGD(L"gradientcheck", false);
m_gradientCheckSigDigit = configSGD(L"sigFigs", 6.0); // TODO: why is this a double?
//RNNT TS
m_outputNodeNames = configSGD(L"OutputNodeNames", ConfigArray(""));
if (m_doGradientCheck && sizeofElemType != sizeof(double))
{
LogicError("Gradient check needs to use precision = 'double'.");
@ -3374,6 +3465,6 @@ void SGDParams::InitializeAndCheckBlockMomentumSGDParameters()
// register SGD<> with the ScriptableObject system
ScriptableObjects::ConfigurableRuntimeTypeRegister::AddFloatDouble<SGD<float>, SGD<double>> registerSGDOptimizer(L"SGDOptimizer");
} // namespace CNTK
} // namespace MSR
} // namespace Microsoft
} // namespace Microsoft

Просмотреть файл

@ -47,7 +47,8 @@ enum class LearningRateSearchAlgorithm : int
// Kind of regularization applied while adapting a model against a reference network.
// The configuration string is mapped onto these values by ParseAdaptationRegType.
enum class AdaptationRegType : int
{
    None, // no regularization type selected
    KL,   // chosen by "kl" / "klReg"
    TS    // chosen by "ts"; the SGD code labels this path "RNNT TS"
};
enum class GradientsUpdateType : int
@ -266,6 +267,7 @@ protected:
bool m_doGradientCheck;
double m_gradientCheckSigDigit;
ConfigArray m_outputNodeNames;
bool m_doUnitTest;
bool m_useAllDataForPreComputedNode;
@ -386,7 +388,7 @@ public:
void Train(shared_ptr<ComputationNetwork> net, DEVICEID_TYPE deviceId,
IDataReader* trainSetDataReader,
IDataReader* validationSetDataReader, int startEpoch, bool loadNetworkFromCheckpoint);
void Adapt(wstring origModelFileName, wstring refNodeName,
void Adapt(shared_ptr<ComputationNetwork> net, bool networkLoadedFromCheckpoint, wstring origModelFileName, wstring refNodeName,
IDataReader* trainSetDataReader,
IDataReader* validationSetDataReader,
const DEVICEID_TYPE deviceID, const bool makeMode = true);
@ -512,7 +514,9 @@ protected:
const size_t maxNumberOfSamples = SIZE_MAX,
const size_t totalMBsSeenBefore = 0,
::CNTK::Internal::TensorBoardFileWriterPtr tensorBoardWriter = nullptr,
const int startEpoch = 0);
const int startEpoch = 0, const std::vector<std::wstring>& outputNodeNamesVector = <>,
StreamMinibatchInputs* encodeInputMatrices = NULL,
StreamMinibatchInputs* decodeinputMatrices = NULL);
void InitDistGradAgg(int numEvalNodes, int numGradientBits, int deviceId, int traceLevel);
void InitModelAggregationHandler(int traceLevel, DEVICEID_TYPE devID);