From 9800670b2bf5f244ed83c8c40e38aa2760e65e8b Mon Sep 17 00:00:00 2001 From: "Rui Zhao (SPEECH)" Date: Fri, 4 Oct 2019 17:03:47 -0700 Subject: [PATCH] TS v1 --- Source/ActionsLib/TrainActions.cpp | 28 +- .../ComputationNetwork.h | 209 ++++++ .../LinearAlgebraNodes.h | 2 +- Source/SGDLib/SGD.cpp | 677 ++++++++++-------- Source/SGDLib/SGD.h | 10 +- 5 files changed, 628 insertions(+), 298 deletions(-) diff --git a/Source/ActionsLib/TrainActions.cpp b/Source/ActionsLib/TrainActions.cpp index e3b75654e..332c86aae 100644 --- a/Source/ActionsLib/TrainActions.cpp +++ b/Source/ActionsLib/TrainActions.cpp @@ -194,8 +194,34 @@ void DoAdapt(const ConfigParameters& config) SGD sgd(configSGD); + + sgd.InitMPI(MPIWrapper::GetInstance()); - sgd.Adapt(origModelFileName, refNodeName, dataReader.get(), cvDataReader.get(), deviceId, makeMode); + + //for RNNT TS + int startEpoch = sgd.DetermineStartEpoch(makeMode); + if (startEpoch == sgd.GetMaxEpochs()) + { + LOGPRINTF(stderr, "No further training is necessary.\n"); + return; + } + wstring modelFileName =sgd.GetModelNameForEpoch(int(startEpoch) - 1); + bool loadNetworkFromCheckpoint = startEpoch >= 0; + if (loadNetworkFromCheckpoint) + LOGPRINTF(stderr, "\nStarting from checkpoint. Loading network from '%ls'.\n", modelFileName.c_str()); + else + LOGPRINTF(stderr, "\nCreating virgin network.\n"); + + // determine the network-creation function + // We have several ways to create that network. + function createNetworkFn; + + createNetworkFn = GetNetworkFactory(config); + + // create or load from checkpoint + shared_ptr net = !loadNetworkFromCheckpoint ? createNetworkFn(deviceId) : ComputationNetwork::CreateFromFile(deviceId, modelFileName); + + sgd.Adapt(net, loadNetworkFromCheckpoint, origModelFileName, refNodeName, dataReader.get(), cvDataReader.get(), deviceId, makeMode); } template void DoAdapt(const ConfigParameters& config); diff --git a/Source/ComputationNetworkLib/ComputationNetwork.h b/Source/ComputationNetworkLib/ComputationNetwork.h index 0ce7bc565..c53ad7dac 100644 --- a/Source/ComputationNetworkLib/ComputationNetwork.h +++ b/Source/ComputationNetworkLib/ComputationNetwork.h @@ -220,6 +220,215 @@ public: } } + + //decoding for RNNT + template + void RNNT_decode_greedy(const std::vector& outputNodeNames, Matrix& encodeInputMatrix, MBLayout& encodeMBLayout, + Matrix& decodeInputMatrix, MBLayout& decodeMBLayout,vector> &outputlabels, float groundTruthWeight /*mt19937_64 randGen*/) + { + if (outputNodeNames.size() == 0) + fprintf(stderr, "OutputNodeNames are not specified, using the default outputnodes.\n"); + std::vector outputNodes = OutputNodesByName(outputNodeNames); + //AllocateAllMatrices({}, outputNodes, nullptr); + + //encoder related nodes + std::vector encodeOutputNodeNames(outputNodeNames.begin(), outputNodeNames.begin() + 1); + std::vector encodeOutputNodes = OutputNodesByName(encodeOutputNodeNames); + std::vector encodeInputNodes = InputNodesForOutputs(encodeOutputNodeNames); + //StreamMinibatchInputs encodeInputMatrices = DataReaderHelpers::RetrieveInputMatrices(encodeInputNodes); + StartEvaluateMinibatchLoop(encodeOutputNodes[0]); + + //prediction related nodes + std::vector decodeOutputNodeNames(outputNodeNames.begin() + 1, outputNodeNames.begin() + 2); + std::vector decodeOutputNodes = OutputNodesByName(decodeOutputNodeNames); + std::vector decodeinputNodes = InputNodesForOutputs(decodeOutputNodeNames); + //StreamMinibatchInputs decodeinputMatrices = DataReaderHelpers::RetrieveInputMatrices(decodeinputNodes); + + //joint nodes + 
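+            // assumed layout of outputNodeNames, as used below: [0]=encoder output, [1]=prediction-network output,
+            // [2]=joint Plus node, [3]=joint nonlinearity (PlusTrans), [4]=final projection weight Wm, [5]=bias bm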
ComputationNodeBasePtr PlusNode = GetNodeFromName(outputNodeNames[2]); + ComputationNodeBasePtr PlusTransNode = GetNodeFromName(outputNodeNames[3]); + ComputationNodeBasePtr WmNode = GetNodeFromName(outputNodeNames[4]); + ComputationNodeBasePtr bmNode = GetNodeFromName(outputNodeNames[5]); + std::vector Plusnodes, Plustransnodes; + Plusnodes.push_back(PlusNode); + Plustransnodes.push_back(PlusTransNode); + + //start eval + StartEvaluateMinibatchLoop(decodeOutputNodes[0]); + //auto lminput = decodeinputMatrices.begin(); + size_t deviceid = decodeInputMatrix.GetDeviceId(); + std::map outputMatrices; + Matrix encodeOutput(deviceid); + Matrix decodeOutput(deviceid), Wm(deviceid), bm(deviceid), tempMatrix(deviceid); + Matrix greedyOutput(deviceid), greedyOutputMax(deviceid); + Matrix sumofENandDE(deviceid), maxIdx(deviceid), maxVal(deviceid); + Matrix lmin(deviceid); + Wm.SetValue(*(&dynamic_pointer_cast>(WmNode)->Value())); + bm.SetValue(*(&dynamic_pointer_cast>(bmNode)->Value())); + const size_t numIterationsBeforePrintingProgress = 100; + + //get MBlayer of encoder input + size_t numParallelSequences = encodeMBLayout.GetNumParallelSequences(); + const auto numSequences = encodeMBLayout.GetNumSequences(); + //get frame number, phone number and output label number + const size_t numRows = encodeInputMatrix.GetNumRows(); + const size_t numCols = encodeInputMatrix.GetNumCols(); + + //size_t maxFrameNum = numCols / numParallelSequences; + + std::vector uttFrameBeginIdx; + // the frame number of each utterance. The size of this vector = the number of all utterances in this minibatch + std::vector uttFrameNum; + // map from utterance ID to minibatch channel ID. We need this because each channel may contain more than one utterance. + std::vector uttFrameToChanInd; + //size_t totalcol = 0; + + uttFrameNum.clear(); + uttFrameToChanInd.clear(); + uttFrameBeginIdx.clear(); + + uttFrameNum.reserve(numSequences); + uttFrameToChanInd.reserve(numSequences); + uttFrameBeginIdx.reserve(numSequences); + + //get utt information, such as channel map id and utt begin frame, utt frame num, utt phone num for frame and phone respectively.... 
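+            // these per-utterance indices are later used to address frame t of utterance u in the packed encoder output:
+            // column = (t + uttFrameBeginIdx[u]) * numParallelSequences + uttFrameToChanInd[u]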
+ size_t seqId = 0; //frame + size_t totalframenum = 0; + for (const auto& seq : encodeMBLayout.GetAllSequences()) + { + if (seq.seqId == GAP_SEQUENCE_ID) + { + continue; + } + assert(seq.seqId == seqId); + seqId++; + uttFrameToChanInd.push_back(seq.s); + size_t numFrames = seq.GetNumTimeSteps(); + uttFrameBeginIdx.push_back(seq.tBegin); + uttFrameNum.push_back(numFrames); + totalframenum += numFrames; + } + + //resize output + outputlabels.resize(numSequences); + + // forward prop encoder + ComputationNetwork::BumpEvalTimeStamp(encodeInputNodes); + ForwardProp(encodeOutputNodes[0]); + encodeOutput.SetValue(*(&dynamic_pointer_cast>(encodeOutputNodes[0])->Value())); + + size_t vocabSize = bm.GetNumRows(); + size_t blankId = vocabSize - 1; + + for (size_t uttID = 0; uttID < numSequences; uttID ++) + { + + lmin.Resize(vocabSize, 1); + lmin.SetValue(0.0); + lmin(blankId, 0) = 1; + decodeMBLayout.Init(1, 1); + std::swap(decodeInputMatrix, lmin); + decodeMBLayout.AddSequence(NEW_SEQUENCE_ID, 0, 0, 2000); + ComputationNetwork::BumpEvalTimeStamp(decodeinputNodes); + ForwardProp(decodeOutputNodes[0]); + greedyOutputMax.Resize(vocabSize, 2000); + size_t lmt = 0; + for (size_t t = 0; t < uttFrameNum[uttID]; t++) + { + + decodeOutput.SetValue(*(&dynamic_pointer_cast>(decodeOutputNodes[0])->Value())); + //auto edNode = PlusNode->As>(); + size_t tinMB = (t + uttFrameBeginIdx[uttID]) * numParallelSequences + uttFrameToChanInd[uttID]; + sumofENandDE.AssignSumOf(encodeOutput.ColumnSlice(tinMB, 1), decodeOutput); + + + //sumofENandDE.AssignSumOf(encodeOutput.ColumnSlice(t, 1), decodeOutput); + (&dynamic_pointer_cast>(PlusNode)->Value())->SetValue(sumofENandDE); + ComputationNetwork::BumpEvalTimeStamp(Plusnodes); + auto PlusMBlayout = PlusNode->GetMBLayout(); + PlusMBlayout->Init(1, 1); + PlusMBlayout->AddSequence(NEW_SEQUENCE_ID, 0, 0, 1); + ForwardPropFromTo(Plusnodes, Plustransnodes); + decodeOutput.SetValue(*(&dynamic_pointer_cast>(PlusTransNode)->Value())); + tempMatrix.AssignProductOf(Wm, true, decodeOutput, false); + decodeOutput.AssignSumOf(tempMatrix, bm); + decodeOutput.VectorMax(maxIdx, maxVal, true); + size_t maxId = (size_t)(maxIdx.Get00Element()); + if (maxId != blankId) + { + outputlabels[uttID].push_back(maxId); + lmin.Resize(vocabSize, 1); + lmin.SetValue(0.0); + lmin(maxId, 0) = 1.0; + greedyOutputMax.SetColumn(lmin, lmt); + std::swap(decodeInputMatrix, lmin); + decodeMBLayout.Init(1, 1); + decodeMBLayout.AddSequence(NEW_SEQUENCE_ID, 0, -1 - lmt, 1999 - lmt); + ComputationNetwork::BumpEvalTimeStamp(decodeinputNodes); + //DataReaderHelpers::NotifyChangedNodes(m_net, decodeinputMatrices); + ForwardProp(decodeOutputNodes[0]); + lmt++; + } + } + + if (lmt == 0) + { + outputlabels[uttID].push_back(blankId); + } + //break; + } + + //decode + + //make new MBLayout for decoder input + MBLayoutPtr newdecodeMBLayout = make_shared(); + std::vector> placement; + std::vector sequences; + for (size_t i = 0; i < outputlabels.size(); ++i) + sequences.push_back({i, SIZE_MAX, 0, outputlabels[i].size()}); + + std::vector rowAllocations; + newdecodeMBLayout->InitAsPackedSequences(sequences, placement, rowAllocations); + + decodeInputMatrix.Resize(vocabSize, newdecodeMBLayout->GetNumCols() * newdecodeMBLayout->GetNumParallelSequences()); + decodeInputMatrix.SetValue(0.0f); + //fill the decoder input + const auto& sequenceInfos = newdecodeMBLayout->GetAllSequences(); + for (int i = 0; i < sequenceInfos.size(); ++i) + { + const auto& sequenceInfo = sequenceInfos[i]; + // skip gaps + if (sequenceInfo.seqId == 
GAP_SEQUENCE_ID) + { + continue; + } + + //const auto& sequence = batch[sequenceInfo.seqId]; + size_t numSamples = outputlabels[i].size(); + assert(numSamples == sequenceInfo.GetNumTimeSteps()); + + + // Iterate over all samples in the sequence, keep track of the sample offset (which is especially + // important for sparse input, where offset == number of preceding nnz elements). + for (size_t sampleIndex = 0;sampleIndex < numSamples; ++sampleIndex) + { + // Compute the offset into the destination buffer, using the layout information + // to get the column index corresponding to the given sample. + size_t destinationOffset = newdecodeMBLayout->GetColumnIndex(sequenceInfo, sampleIndex) ; + // verify that there's enough space left in the buffer to fit a full sample. + //assert(destinationOffset <= buffer.m_size - sampleSize); + //auto* destination = bufferPtr + destinationOffset; + decodeInputMatrix.SetValue(outputlabels[i][sampleIndex], destinationOffset, 1.0f); + } + } + + //copy the new MBLayout + decodeMBLayout.CopyFrom(newdecodeMBLayout); + + // clean up + } + + static void BumpEvalTimeStamp(const std::vector& nodes); void ResetEvalTimeStamps(); void SetEvalTimeStampsOutdatedWithRegardToAll(); diff --git a/Source/ComputationNetworkLib/LinearAlgebraNodes.h b/Source/ComputationNetworkLib/LinearAlgebraNodes.h index 1e380c60b..178d53727 100755 --- a/Source/ComputationNetworkLib/LinearAlgebraNodes.h +++ b/Source/ComputationNetworkLib/LinearAlgebraNodes.h @@ -220,7 +220,7 @@ public: virtual void Save(File& fstream) const override { Base::Save(fstream); - fstream << m_combineMode; + //fstream << m_combineMode; } virtual void Load(File& fstream, size_t modelVersion) override diff --git a/Source/SGDLib/SGD.cpp b/Source/SGDLib/SGD.cpp index ea3015229..9a8b320df 100644 --- a/Source/SGDLib/SGD.cpp +++ b/Source/SGDLib/SGD.cpp @@ -90,7 +90,7 @@ void SGD::Train(shared_ptr net, DEVICEID_TYPE devi // ----------------------------------------------------------------------- template -void SGD::Adapt(wstring origModelFileName, wstring refNodeName, +void SGD::Adapt(shared_ptr net, bool networkLoadedFromCheckpoint, wstring origModelFileName, wstring refNodeName, IDataReader* trainSetDataReader, IDataReader* validationSetDataReader, const DEVICEID_TYPE deviceId, const bool makeMode) @@ -102,7 +102,7 @@ void SGD::Adapt(wstring origModelFileName, wstring refNodeName, return; } - ComputationNetworkPtr net; + /*ComputationNetworkPtr net; bool networkLoadedFromCheckpoint = false; if (startEpoch >= 0) { @@ -115,17 +115,22 @@ void SGD::Adapt(wstring origModelFileName, wstring refNodeName, { LOGPRINTF(stderr, "Load Network From the original model file %ls.\n", origModelFileName.c_str()); net = ComputationNetwork::CreateFromFile(deviceId, origModelFileName); - } + }*/ startEpoch = max(startEpoch, 0); ComputationNetworkPtr refNet; m_needAdaptRegularization = m_adaptationRegType != AdaptationRegType::None && m_adaptationRegWeight > 0; - if (m_needAdaptRegularization) + if (m_needAdaptRegularization && m_adaptationRegType != AdaptationRegType::TS) { LOGPRINTF(stderr, "Load reference Network From the original model file %ls.\n", origModelFileName.c_str()); refNet = ComputationNetwork::CreateFromFile(deviceId, origModelFileName); } + else if (m_needAdaptRegularization && m_adaptationRegType == AdaptationRegType::TS) + { + refNet = make_shared(deviceId); + refNet->Read(origModelFileName); + } ComputationNodeBasePtr refNode; if (m_needAdaptRegularization && m_adaptationRegType == AdaptationRegType::KL) @@ -649,39 +654,89 @@ 
void SGD::TrainOrAdaptModel(int startEpoch, ComputationNetworkPtr net, EpochCriterion epochCriterion; // criterion values are returned in this std::vector epochEvalErrors(evaluationNodes.size()); - totalMBsSeen += TrainOneEpoch(net, - refNet, - refNode, - i, - m_epochSize, - trainSetDataReader, - learnRatePerSample, - chosenMinibatchSize, - featureNodes, - labelNodes, - criterionNodes, - evaluationNodes, - inputMatrices, - learnableNodes, smoothedGradients, smoothedCounts, - epochCriterion, epochEvalErrors, - "", SIZE_MAX, totalMBsSeen, tensorBoardWriter, startEpoch); - totalTrainingSamplesSeen += epochCriterion.second; // aggregate #training samples, for logging purposes only - timer.Stop(); - double epochTime = timer.ElapsedSeconds(); + //RNNT TS + StreamMinibatchInputs decodeinputMatrices, encodeInputMatrices; + vector outputNodeNamesVector; + if (m_needAdaptRegularization && m_adaptationRegType == AdaptationRegType::TS && refNet) + { + + + if (m_outputNodeNames.size() > 0) + { + // clear out current list of outputNodes + while (!refNet->OutputNodes().empty()) + refNet->RemoveFromNodeGroup(L"output", refNet->OutputNodes().back()); + // and insert the desired nodes instead + for (wstring name : m_outputNodeNames) + { + if (!refNet->NodeNameExists(name)) + { + fprintf(stderr, "PatchOutputNodes: No node named '%ls'; skipping\n", name.c_str()); + continue; + } + outputNodeNamesVector.push_back(name); + let& node = refNet->GetNodeFromName(name); + refNet->AddToNodeGroup(L"output", node); + } + } + refNet->CompileNetwork(); + + //PatchOutputNodes(refNet, m_outputNodeNames, outputNodeNamesVector); + - if (m_useEvalCriterionControlLR && epochEvalErrors.size() > 0) - lrControlCriterion = epochEvalErrors[0].Average(); - else - lrControlCriterion = epochCriterion.Average(); + std::vector outputNodes = refNet->OutputNodesByName(outputNodeNamesVector); + refNet->AllocateAllMatrices({}, outputNodes, nullptr); - LOGPRINTF(stderr, "Finished Epoch[%2d of %d]: [Training] ", i + 1, (int) m_maxEpochs); - epochCriterion.LogCriterion(criterionNodes[0]->NodeName()); + std::vector encodeOutputNodeNames(outputNodeNamesVector.begin(), outputNodeNamesVector.begin() + 1); + std::vector encodeOutputNodes = refNet->OutputNodesByName(encodeOutputNodeNames); + std::vector encodeInputNodes = refNet->InputNodesForOutputs(encodeOutputNodeNames); + encodeInputMatrices = DataReaderHelpers::RetrieveInputMatrices(encodeInputNodes); - m_lastFinishedEpochTrainLoss = epochCriterion.Average(); - for (size_t j = 0; j < epochEvalErrors.size(); j++) - epochEvalErrors[j].LogCriterion(evaluationNodes[j]->NodeName()); - fprintf(stderr, "totalSamplesSeen = %zu; learningRatePerSample = %.8g; epochTime=%.6gs\n", totalTrainingSamplesSeen, learnRatePerSample, epochTime); + //get decode input matrix + std::vector decodeOutputNodeNames(outputNodeNamesVector.begin() + 1, outputNodeNamesVector.begin() + 2); + std::vector decodeOutputNodes = refNet->OutputNodesByName(decodeOutputNodeNames); + std::vector decodeinputNodes = refNet->InputNodesForOutputs(decodeOutputNodeNames); + decodeinputMatrices = DataReaderHelpers::RetrieveInputMatrices(decodeinputNodes); + + //DataReaderHelpers:: + + //StreamBatch batch; + } + totalMBsSeen += TrainOneEpoch(net, + refNet, + refNode, + i, + m_epochSize, + trainSetDataReader, + learnRatePerSample, + chosenMinibatchSize, + featureNodes, + labelNodes, + criterionNodes, + evaluationNodes, + inputMatrices, + learnableNodes, smoothedGradients, smoothedCounts, + epochCriterion, epochEvalErrors, + "", SIZE_MAX, 
totalMBsSeen, tensorBoardWriter, startEpoch, + outputNodeNamesVector, &encodeInputMatrices, &decodeinputMatrices); + totalTrainingSamplesSeen += epochCriterion.second; // aggregate #training samples, for logging purposes only + + timer.Stop(); + double epochTime = timer.ElapsedSeconds(); + + if (m_useEvalCriterionControlLR && epochEvalErrors.size() > 0) + lrControlCriterion = epochEvalErrors[0].Average(); + else + lrControlCriterion = epochCriterion.Average(); + + LOGPRINTF(stderr, "Finished Epoch[%2d of %d]: [Training] ", i + 1, (int) m_maxEpochs); + epochCriterion.LogCriterion(criterionNodes[0]->NodeName()); + + m_lastFinishedEpochTrainLoss = epochCriterion.Average(); + for (size_t j = 0; j < epochEvalErrors.size(); j++) + epochEvalErrors[j].LogCriterion(evaluationNodes[j]->NodeName()); + fprintf(stderr, "totalSamplesSeen = %zu; learningRatePerSample = %.8g; epochTime=%.6gs\n", totalTrainingSamplesSeen, learnRatePerSample, epochTime); #if 0 // TODO: This was only printed if >1 eval criterion. Why? Needed? LOGPRINTF(stderr, "Finished Epoch[%2d of %d]: Criterion Node [%ls] Per Sample = %.8g\n", @@ -694,283 +749,283 @@ void SGD::TrainOrAdaptModel(int startEpoch, ComputationNetworkPtr net, } #endif + if (tensorBoardWriter) + { + tensorBoardWriter->WriteValue(L"summary/" + criterionNodes[0]->NodeName(), (float) epochCriterion.Average(), i + 1); + for (size_t j = 0; j < epochEvalErrors.size(); j++) + { + tensorBoardWriter->WriteValue(L"summary/" + evaluationNodes[0]->NodeName(), (float) epochEvalErrors[j].Average(), i + 1); + } + + tensorBoardWriter->Flush(); + } + + if (validationSetDataReader != trainSetDataReader && validationSetDataReader != nullptr) + { + // TODO(dataASGD) making evaluator becoming nondistributed one when using ASGD, since Multiverso has another background thread using MPI. + // Making the evaluation serial (non-distributed) will slowdown training especially when validation set is large. + SimpleEvaluator evalforvalidation(net, UsingAsyncGradientAggregation(i + 1) ? nullptr : m_mpi, m_enableDistributedMBReading); + vector cvSetTrainAndEvalNodes; + if (criterionNodes.size() > 0) + { + cvSetTrainAndEvalNodes.push_back(criterionNodes[0]->NodeName()); + } + for (let node : evaluationNodes) + { + cvSetTrainAndEvalNodes.push_back(node->NodeName()); + } + + // BUGBUG: We should not use the training MB size. The training MB size is constrained by both convergence and memory. Eval is only constrained by memory. + let vScore = evalforvalidation.Evaluate(validationSetDataReader, cvSetTrainAndEvalNodes, UsingAsyncGradientAggregation(i + 1) ? m_mbSize[i] / m_mpi->NumNodesInUse() : m_mbSize[i]); + LOGPRINTF(stderr, "Finished Epoch[%2d of %d]: [Validate] ", i + 1, (int) m_maxEpochs); + for (size_t k = 0; k < vScore.size() /*&& k < 2*/; k++) + vScore[k].LogCriterion(cvSetTrainAndEvalNodes[k], /*addSemicolon=*/k + 1 < vScore.size()); + //fprintf(stderr, "%s %ls = %.8f * %d", k ? 
";" : "", cvSetTrainAndEvalNodes[k].c_str(), vScore[k].Average(), (int)vScore[k].second); + fprintf(stderr, "\n"); + if (tensorBoardWriter) { - tensorBoardWriter->WriteValue(L"summary/" + criterionNodes[0]->NodeName(), (float) epochCriterion.Average(), i + 1); - for (size_t j = 0; j < epochEvalErrors.size(); j++) + for (size_t k = 0; k < vScore.size(); k++) { - tensorBoardWriter->WriteValue(L"summary/" + evaluationNodes[0]->NodeName(), (float) epochEvalErrors[j].Average(), i + 1); + tensorBoardWriter->WriteValue(L"summary/test_" + cvSetTrainAndEvalNodes[k], (float) vScore[k].Average(), i + 1); } tensorBoardWriter->Flush(); } - if (validationSetDataReader != trainSetDataReader && validationSetDataReader != nullptr) + if (m_saveBestModelPerCriterion) { - // TODO(dataASGD) making evaluator becoming nondistributed one when using ASGD, since Multiverso has another background thread using MPI. - // Making the evaluation serial (non-distributed) will slowdown training especially when validation set is large. - SimpleEvaluator evalforvalidation(net, UsingAsyncGradientAggregation(i + 1) ? nullptr : m_mpi, m_enableDistributedMBReading); - vector cvSetTrainAndEvalNodes; - if (criterionNodes.size() > 0) - { - cvSetTrainAndEvalNodes.push_back(criterionNodes[0]->NodeName()); - } - for (let node : evaluationNodes) - { - cvSetTrainAndEvalNodes.push_back(node->NodeName()); - } - - // BUGBUG: We should not use the training MB size. The training MB size is constrained by both convergence and memory. Eval is only constrained by memory. - let vScore = evalforvalidation.Evaluate(validationSetDataReader, cvSetTrainAndEvalNodes, UsingAsyncGradientAggregation(i + 1) ? m_mbSize[i] / m_mpi->NumNodesInUse() : m_mbSize[i]); - LOGPRINTF(stderr, "Finished Epoch[%2d of %d]: [Validate] ", i + 1, (int) m_maxEpochs); - for (size_t k = 0; k < vScore.size() /*&& k < 2*/; k++) - vScore[k].LogCriterion(cvSetTrainAndEvalNodes[k], /*addSemicolon=*/k + 1 < vScore.size()); - //fprintf(stderr, "%s %ls = %.8f * %d", k ? ";" : "", cvSetTrainAndEvalNodes[k].c_str(), vScore[k].Average(), (int)vScore[k].second); - fprintf(stderr, "\n"); - - if (tensorBoardWriter) - { - for (size_t k = 0; k < vScore.size(); k++) - { - tensorBoardWriter->WriteValue(L"summary/test_" + cvSetTrainAndEvalNodes[k], (float) vScore[k].Average(), i + 1); - } - - tensorBoardWriter->Flush(); - } - - if (m_saveBestModelPerCriterion) - { - // Loops through criteria (i.e. score) and updates the best one if smaller value is found. - UpdateBestEpochs(vScore, cvSetTrainAndEvalNodes, i, m_criteriaBestEpoch); - } - - if (m_useCVSetControlLRIfCVExists) - { - if (m_useEvalCriterionControlLR && vScore.size() > 1) - lrControlCriterion = vScore[1].Average(); // use the first of possibly multiple eval criteria - else - lrControlCriterion = vScore[0].Average(); // the first one is the training criterion - } + // Loops through criteria (i.e. score) and updates the best one if smaller value is found. 
+ UpdateBestEpochs(vScore, cvSetTrainAndEvalNodes, i, m_criteriaBestEpoch); } - // broadcast epochCriterion to make sure each processor will have the same learning rate schedule - if ((GetParallelizationMethod() == ParallelizationMethod::modelAveragingSGD || - GetParallelizationMethod() == ParallelizationMethod::blockMomentumSGD) && - (m_mpi->NumNodesInUse() > 1)) + if (m_useCVSetControlLRIfCVExists) { - m_mpi->Bcast(&epochCriterion.first, 1, m_mpi->MainNodeRank()); - m_mpi->Bcast(&epochCriterion.second, 1, m_mpi->MainNodeRank()); - m_mpi->Bcast(&lrControlCriterion, 1, m_mpi->MainNodeRank()); - } - - bool loadedPrevModel = false; - size_t epochsSinceLastLearnRateAdjust = i % m_learnRateAdjustInterval + 1; - if (avgCriterion == numeric_limits::infinity()) - { - avgCriterion = lrControlCriterion; - } - else - { - avgCriterion = ((epochsSinceLastLearnRateAdjust - 1 - epochsNotCountedInAvgCriterion) * - avgCriterion + - lrControlCriterion) / - (epochsSinceLastLearnRateAdjust - epochsNotCountedInAvgCriterion); - } - - if (m_autoLearnRateSearchType == LearningRateSearchAlgorithm::AdjustAfterEpoch && - m_learningRatesParam.size() <= i && epochsSinceLastLearnRateAdjust == m_learnRateAdjustInterval) - { - if (std::isnan(avgCriterion) || (prevCriterion - avgCriterion < 0 && prevCriterion != numeric_limits::infinity())) - { - if (m_loadBestModel) - { - // roll back - auto bestModelPath = GetModelNameForEpoch(i - m_learnRateAdjustInterval); - LOGPRINTF(stderr, "Loading (rolling back to) previous model with best training-criterion value: %ls.\n", bestModelPath.c_str()); - net->RereadPersistableParameters(bestModelPath); - LoadCheckPointInfo(i - m_learnRateAdjustInterval, - /*out*/ totalTrainingSamplesSeen, - /*out*/ learnRatePerSample, - smoothedGradients, - smoothedCounts, - /*out*/ prevCriterion, - /*out*/ m_prevChosenMinibatchSize); - loadedPrevModel = true; - } - } - - if (m_continueReduce) - { - if (std::isnan(avgCriterion) || - (prevCriterion - avgCriterion <= m_reduceLearnRateIfImproveLessThan * prevCriterion && - prevCriterion != numeric_limits::infinity())) - { - if (learnRateReduced == false) - { - learnRateReduced = true; - } - else - { - // In case of parallel training only the main node should we saving the model to prevent - // the parallel training nodes from colliding to write the same file - if ((m_mpi == nullptr) || m_mpi->IsMainNode()) - net->Save(GetModelNameForEpoch(i, true)); - - LOGPRINTF(stderr, "Finished training and saved final model\n\n"); - break; - } - } - - if (learnRateReduced) - { - learnRatePerSample *= m_learnRateDecreaseFactor; - LOGPRINTF(stderr, "learnRatePerSample reduced to %.8g\n", learnRatePerSample); - } - } + if (m_useEvalCriterionControlLR && vScore.size() > 1) + lrControlCriterion = vScore[1].Average(); // use the first of possibly multiple eval criteria else - { - if (std::isnan(avgCriterion) || - (prevCriterion - avgCriterion <= m_reduceLearnRateIfImproveLessThan * prevCriterion && - prevCriterion != numeric_limits::infinity())) - { - - learnRatePerSample *= m_learnRateDecreaseFactor; - LOGPRINTF(stderr, "learnRatePerSample reduced to %.8g\n", learnRatePerSample); - } - else if (prevCriterion - avgCriterion > m_increaseLearnRateIfImproveMoreThan * prevCriterion && - prevCriterion != numeric_limits::infinity()) - { - learnRatePerSample *= m_learnRateIncreaseFactor; - LOGPRINTF(stderr, "learnRatePerSample increased to %.8g\n", learnRatePerSample); - } - } - } - else - { - if (std::isnan(avgCriterion)) - RuntimeError("The training criterion is not a number 
(NAN)."); - } - - // not loading previous values then set them - if (!loadedPrevModel && epochsSinceLastLearnRateAdjust == m_learnRateAdjustInterval) - { - prevCriterion = avgCriterion; - epochsNotCountedInAvgCriterion = 0; - } - - // Synchronize all ranks before proceeding to ensure that - // nobody tries reading the checkpoint file at the same time - // as rank 0 deleting it below - SynchronizeWorkers(); - - // Persist model and check-point info - if ((m_mpi == nullptr) || m_mpi->IsMainNode()) - { - if (loadedPrevModel) - { - // If previous best model is loaded, we will first remove epochs that lead to worse results - for (int j = 1; j < m_learnRateAdjustInterval; j++) - { - int epochToDelete = i - j; - LOGPRINTF(stderr, "SGD: removing model and checkpoint files for epoch %d after rollback to epoch %lu\n", epochToDelete + 1, (unsigned long) (i - m_learnRateAdjustInterval) + 1); // report 1 based epoch number - _wunlink(GetModelNameForEpoch(epochToDelete).c_str()); - _wunlink(GetCheckPointFileNameForEpoch(epochToDelete).c_str()); - } - - // Set i back to the loaded model - i -= m_learnRateAdjustInterval; - LOGPRINTF(stderr, "SGD: revoke back to and update checkpoint file for epoch %d\n", i + 1); // report 1 based epoch number - SaveCheckPointInfo( - i, - totalTrainingSamplesSeen, - learnRatePerSample, - smoothedGradients, - smoothedCounts, - prevCriterion, - chosenMinibatchSize); - } - else - { - SaveCheckPointInfo( - i, - totalTrainingSamplesSeen, - learnRatePerSample, - smoothedGradients, - smoothedCounts, - prevCriterion, - chosenMinibatchSize); - auto modelName = GetModelNameForEpoch(i); - if (m_traceLevel > 0) - LOGPRINTF(stderr, "SGD: Saving checkpoint model '%ls'\n", modelName.c_str()); - net->Save(modelName); - if (!m_keepCheckPointFiles) - { - // delete previous checkpoint file to save space - if (m_autoLearnRateSearchType == LearningRateSearchAlgorithm::AdjustAfterEpoch && m_loadBestModel) - { - if (epochsSinceLastLearnRateAdjust != 1) - { - _wunlink(GetCheckPointFileNameForEpoch(i - 1).c_str()); - } - if (epochsSinceLastLearnRateAdjust == m_learnRateAdjustInterval) - { - _wunlink(GetCheckPointFileNameForEpoch(i - m_learnRateAdjustInterval).c_str()); - } - } - else - { - _wunlink(GetCheckPointFileNameForEpoch(i - 1).c_str()); - } - } - } - } - else - { - if (loadedPrevModel) - { - // Set i back to the loaded model - i -= m_learnRateAdjustInterval; - } - } - - if (learnRatePerSample < 1e-12) - { - LOGPRINTF(stderr, "learnRate per sample is reduced to %.8g which is below 1e-12. stop training.\n", - learnRatePerSample); + lrControlCriterion = vScore[0].Average(); // the first one is the training criterion } } - // --- END OF MAIN EPOCH LOOP - // Check if we need to save best model per criterion and this is the main node as well. - if (m_saveBestModelPerCriterion && ((m_mpi == nullptr) || m_mpi->IsMainNode())) + // broadcast epochCriterion to make sure each processor will have the same learning rate schedule + if ((GetParallelizationMethod() == ParallelizationMethod::modelAveragingSGD || + GetParallelizationMethod() == ParallelizationMethod::blockMomentumSGD) && + (m_mpi->NumNodesInUse() > 1)) { - // For each criterion copies the best epoch to the new file with criterion name appended. 
- CopyBestEpochs(m_criteriaBestEpoch, *this, m_maxEpochs - 1); + m_mpi->Bcast(&epochCriterion.first, 1, m_mpi->MainNodeRank()); + m_mpi->Bcast(&epochCriterion.second, 1, m_mpi->MainNodeRank()); + m_mpi->Bcast(&lrControlCriterion, 1, m_mpi->MainNodeRank()); + } + + bool loadedPrevModel = false; + size_t epochsSinceLastLearnRateAdjust = i % m_learnRateAdjustInterval + 1; + if (avgCriterion == numeric_limits::infinity()) + { + avgCriterion = lrControlCriterion; + } + else + { + avgCriterion = ((epochsSinceLastLearnRateAdjust - 1 - epochsNotCountedInAvgCriterion) * + avgCriterion + + lrControlCriterion) / + (epochsSinceLastLearnRateAdjust - epochsNotCountedInAvgCriterion); + } + + if (m_autoLearnRateSearchType == LearningRateSearchAlgorithm::AdjustAfterEpoch && + m_learningRatesParam.size() <= i && epochsSinceLastLearnRateAdjust == m_learnRateAdjustInterval) + { + if (std::isnan(avgCriterion) || (prevCriterion - avgCriterion < 0 && prevCriterion != numeric_limits::infinity())) + { + if (m_loadBestModel) + { + // roll back + auto bestModelPath = GetModelNameForEpoch(i - m_learnRateAdjustInterval); + LOGPRINTF(stderr, "Loading (rolling back to) previous model with best training-criterion value: %ls.\n", bestModelPath.c_str()); + net->RereadPersistableParameters(bestModelPath); + LoadCheckPointInfo(i - m_learnRateAdjustInterval, + /*out*/ totalTrainingSamplesSeen, + /*out*/ learnRatePerSample, + smoothedGradients, + smoothedCounts, + /*out*/ prevCriterion, + /*out*/ m_prevChosenMinibatchSize); + loadedPrevModel = true; + } + } + + if (m_continueReduce) + { + if (std::isnan(avgCriterion) || + (prevCriterion - avgCriterion <= m_reduceLearnRateIfImproveLessThan * prevCriterion && + prevCriterion != numeric_limits::infinity())) + { + if (learnRateReduced == false) + { + learnRateReduced = true; + } + else + { + // In case of parallel training only the main node should we saving the model to prevent + // the parallel training nodes from colliding to write the same file + if ((m_mpi == nullptr) || m_mpi->IsMainNode()) + net->Save(GetModelNameForEpoch(i, true)); + + LOGPRINTF(stderr, "Finished training and saved final model\n\n"); + break; + } + } + + if (learnRateReduced) + { + learnRatePerSample *= m_learnRateDecreaseFactor; + LOGPRINTF(stderr, "learnRatePerSample reduced to %.8g\n", learnRatePerSample); + } + } + else + { + if (std::isnan(avgCriterion) || + (prevCriterion - avgCriterion <= m_reduceLearnRateIfImproveLessThan * prevCriterion && + prevCriterion != numeric_limits::infinity())) + { + + learnRatePerSample *= m_learnRateDecreaseFactor; + LOGPRINTF(stderr, "learnRatePerSample reduced to %.8g\n", learnRatePerSample); + } + else if (prevCriterion - avgCriterion > m_increaseLearnRateIfImproveMoreThan * prevCriterion && + prevCriterion != numeric_limits::infinity()) + { + learnRatePerSample *= m_learnRateIncreaseFactor; + LOGPRINTF(stderr, "learnRatePerSample increased to %.8g\n", learnRatePerSample); + } + } + } + else + { + if (std::isnan(avgCriterion)) + RuntimeError("The training criterion is not a number (NAN)."); + } + + // not loading previous values then set them + if (!loadedPrevModel && epochsSinceLastLearnRateAdjust == m_learnRateAdjustInterval) + { + prevCriterion = avgCriterion; + epochsNotCountedInAvgCriterion = 0; } // Synchronize all ranks before proceeding to ensure that - // rank 0 has finished writing the model file - // TODO[DataASGD]: should othet other rank waiting in async-mode + // nobody tries reading the checkpoint file at the same time + // as rank 0 deleting it below 
SynchronizeWorkers(); - // progress tracing for compute cluster management - ProgressTracing::TraceProgressPercentage(m_maxEpochs, 0.0, true); - ProgressTracing::TraceTrainLoss(m_lastFinishedEpochTrainLoss); - - // since we linked feature nodes. we need to remove it from the deletion - if (m_needAdaptRegularization && m_adaptationRegType == AdaptationRegType::KL && refNode != nullptr) + // Persist model and check-point info + if ((m_mpi == nullptr) || m_mpi->IsMainNode()) { - for (size_t i = 0; i < refFeatureNodes.size(); i++) + if (loadedPrevModel) { - // note we need to handle deletion carefully - refNet->ReplaceNode(refFeatureNodes[i]->NodeName(), refFeatureNodes[i]); + // If previous best model is loaded, we will first remove epochs that lead to worse results + for (int j = 1; j < m_learnRateAdjustInterval; j++) + { + int epochToDelete = i - j; + LOGPRINTF(stderr, "SGD: removing model and checkpoint files for epoch %d after rollback to epoch %lu\n", epochToDelete + 1, (unsigned long) (i - m_learnRateAdjustInterval) + 1); // report 1 based epoch number + _wunlink(GetModelNameForEpoch(epochToDelete).c_str()); + _wunlink(GetCheckPointFileNameForEpoch(epochToDelete).c_str()); + } + + // Set i back to the loaded model + i -= m_learnRateAdjustInterval; + LOGPRINTF(stderr, "SGD: revoke back to and update checkpoint file for epoch %d\n", i + 1); // report 1 based epoch number + SaveCheckPointInfo( + i, + totalTrainingSamplesSeen, + learnRatePerSample, + smoothedGradients, + smoothedCounts, + prevCriterion, + chosenMinibatchSize); + } + else + { + SaveCheckPointInfo( + i, + totalTrainingSamplesSeen, + learnRatePerSample, + smoothedGradients, + smoothedCounts, + prevCriterion, + chosenMinibatchSize); + auto modelName = GetModelNameForEpoch(i); + if (m_traceLevel > 0) + LOGPRINTF(stderr, "SGD: Saving checkpoint model '%ls'\n", modelName.c_str()); + net->Save(modelName); + if (!m_keepCheckPointFiles) + { + // delete previous checkpoint file to save space + if (m_autoLearnRateSearchType == LearningRateSearchAlgorithm::AdjustAfterEpoch && m_loadBestModel) + { + if (epochsSinceLastLearnRateAdjust != 1) + { + _wunlink(GetCheckPointFileNameForEpoch(i - 1).c_str()); + } + if (epochsSinceLastLearnRateAdjust == m_learnRateAdjustInterval) + { + _wunlink(GetCheckPointFileNameForEpoch(i - m_learnRateAdjustInterval).c_str()); + } + } + else + { + _wunlink(GetCheckPointFileNameForEpoch(i - 1).c_str()); + } + } + } + } + else + { + if (loadedPrevModel) + { + // Set i back to the loaded model + i -= m_learnRateAdjustInterval; } } - delete inputMatrices; - if (m_parallelizationMethod == ParallelizationMethod::dataParallelASGD) - m_pASGDHelper.reset(); + if (learnRatePerSample < 1e-12) + { + LOGPRINTF(stderr, "learnRate per sample is reduced to %.8g which is below 1e-12. stop training.\n", + learnRatePerSample); + } } +// --- END OF MAIN EPOCH LOOP + +// Check if we need to save best model per criterion and this is the main node as well. +if (m_saveBestModelPerCriterion && ((m_mpi == nullptr) || m_mpi->IsMainNode())) +{ + // For each criterion copies the best epoch to the new file with criterion name appended. 
+ CopyBestEpochs(m_criteriaBestEpoch, *this, m_maxEpochs - 1); +} + +// Synchronize all ranks before proceeding to ensure that +// rank 0 has finished writing the model file +// TODO[DataASGD]: should othet other rank waiting in async-mode +SynchronizeWorkers(); + +// progress tracing for compute cluster management +ProgressTracing::TraceProgressPercentage(m_maxEpochs, 0.0, true); +ProgressTracing::TraceTrainLoss(m_lastFinishedEpochTrainLoss); + +// since we linked feature nodes. we need to remove it from the deletion +if (m_needAdaptRegularization && m_adaptationRegType == AdaptationRegType::KL && refNode != nullptr) +{ + for (size_t i = 0; i < refFeatureNodes.size(); i++) + { + // note we need to handle deletion carefully + refNet->ReplaceNode(refFeatureNodes[i]->NodeName(), refFeatureNodes[i]); + } +} + +delete inputMatrices; +if (m_parallelizationMethod == ParallelizationMethod::dataParallelASGD) + m_pASGDHelper.reset(); +} // namespace CNTK // ----------------------------------------------------------------------- // TrainOneEpoch() -- train one epoch @@ -998,9 +1053,15 @@ size_t SGD::TrainOneEpoch(ComputationNetworkPtr net, const size_t maxNumberOfSamples, const size_t totalMBsSeenBefore, ::CNTK::Internal::TensorBoardFileWriterPtr tensorBoardWriter, - const int startEpoch) + const int startEpoch, + const std::vector& outputNodeNamesVector, + StreamMinibatchInputs* encodeInputMatrices, + StreamMinibatchInputs* decodeinputMatrices) { PROFILE_SCOPE(profilerEvtMainEpoch); + //for schedule sampling + std::random_device rd; + std::mt19937_64 randGen{rd()}; ScopedNetworkOperationMode modeGuard(net, NetworkOperationMode::training); @@ -1165,6 +1226,14 @@ size_t SGD::TrainOneEpoch(ComputationNetworkPtr net, bool noMoreSamplesToProcess = false; bool isFirstMinibatch = true; + + //RNNT TS + //RNNT TS + if (m_needAdaptRegularization && m_adaptationRegType == AdaptationRegType::TS && refNet) + { + + + } //Microsoft::MSR::CNTK::StartProfiler(); for (;;) { @@ -1234,6 +1303,26 @@ size_t SGD::TrainOneEpoch(ComputationNetworkPtr net, dynamic_pointer_cast>(labelNodes[0])->Value()); } + //RNNT TS + if (m_needAdaptRegularization && m_adaptationRegType == AdaptationRegType::TS && refNet) + { + auto reffeainput = (*encodeInputMatrices).begin(); + auto feainput = (*inputMatrices).begin(); + reffeainput->second.pMBLayout->CopyFrom(feainput->second.pMBLayout); + auto encodeMBLayout = feainput->second.pMBLayout; + reffeainput->second.GetMatrix().AssignRowSliceValuesOf(feainput->second.GetMatrix(), 0, 240); + + auto lminput = (*decodeinputMatrices).begin(); + auto decodeMBLayout = lminput->second.pMBLayout; + + reffeainput->second.GetMatrix().AssignRowSliceValuesOf(feainput->second.GetMatrix(), 0, 240); + + vector> outputlabels; + refNet->RNNT_decode_greedy(outputNodeNamesVector, reffeainput->second.GetMatrix(), *encodeMBLayout, reffeainput->second.GetMatrix(), *decodeMBLayout, outputlabels, 1.0); + //DataReaderHelpers:: + + //StreamBatch batch; + } // do forward and back propagation // We optionally break the minibatch into sub-minibatches. 
@@ -1256,7 +1345,6 @@ size_t SGD::TrainOneEpoch(ComputationNetworkPtr net, // may be changed and need to be recomputed when gradient and function value share the same matrix net->ForwardProp(forwardPropRoots); // the bulk of this evaluation is reused in ComputeGradient() below - // =========================================================== // backprop // =========================================================== @@ -1401,7 +1489,6 @@ size_t SGD::TrainOneEpoch(ComputationNetworkPtr net, double p_norm = dynamic_pointer_cast>(node)->Gradient().FrobeniusNorm(); //long m = (long) GetNumElements(); totalNorm += p_norm * p_norm; - } } totalNorm = sqrt(totalNorm); @@ -1807,7 +1894,6 @@ bool SGD::PreCompute(ComputationNetworkPtr net, net->ForwardProp(nodes); numItersSinceLastPrintOfProgress = ProgressTracing::TraceFakeProgress(numIterationsBeforePrintingProgress, numItersSinceLastPrintOfProgress); - } // finalize @@ -2218,12 +2304,12 @@ void SGD::TrainOneMiniEpochAndReloadModel(ComputationNetworkPtr net, std::string prefixMsg, const size_t maxNumOfSamples) { - TrainOneEpoch(net, refNet, refNode, epochNumber, epochSize, - trainSetDataReader, learnRatePerSample, minibatchSize, featureNodes, - labelNodes, criterionNodes, evaluationNodes, - inputMatrices, learnableNodes, smoothedGradients, smoothedCounts, - /*out*/ epochCriterion, /*out*/ epochEvalErrors, - " " + prefixMsg, maxNumOfSamples); // indent log msg by 2 (that is 1 more than the Finished message below) + //TrainOneEpoch(net, refNet, refNode, epochNumber, epochSize, + // trainSetDataReader, learnRatePerSample, minibatchSize, featureNodes, + // labelNodes, criterionNodes, evaluationNodes, + // inputMatrices, learnableNodes, smoothedGradients, smoothedCounts, + // /*out*/ epochCriterion, /*out*/ epochEvalErrors, + // " " + prefixMsg, maxNumOfSamples); // indent log msg by 2 (that is 1 more than the Finished message below) LOGPRINTF(stderr, " Finished Mini-Epoch[%d]: ", (int) epochNumber + 1); epochCriterion.LogCriterion(criterionNodes[0]->NodeName()); @@ -2856,6 +2942,8 @@ static AdaptationRegType ParseAdaptationRegType(const wstring& s) return AdaptationRegType::None; else if (EqualCI(s, L"kl") || EqualCI(s, L"klReg")) return AdaptationRegType::KL; + else if (EqualCI(s, L"ts")) + return AdaptationRegType::TS; else InvalidArgument("ParseAdaptationRegType: Invalid Adaptation Regularization Type. Valid values are (none | kl)"); } @@ -3069,6 +3157,9 @@ SGDParams::SGDParams(const ConfigRecordType& configSGD, size_t sizeofElemType) m_doGradientCheck = configSGD(L"gradientcheck", false); m_gradientCheckSigDigit = configSGD(L"sigFigs", 6.0); // TODO: why is this a double? 
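+    // OutputNodeNames is expected to list the teacher-network nodes consumed by RNNT_decode_greedy
+    // (encoder output, prediction-network output, joint Plus, joint nonlinearity, Wm, bm)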
+ //RNNT TS + m_outputNodeNames = configSGD(L"OutputNodeNames", ConfigArray("")); + if (m_doGradientCheck && sizeofElemType != sizeof(double)) { LogicError("Gradient check needs to use precision = 'double'."); @@ -3374,6 +3465,6 @@ void SGDParams::InitializeAndCheckBlockMomentumSGDParameters() // register SGD<> with the ScriptableObject system ScriptableObjects::ConfigurableRuntimeTypeRegister::AddFloatDouble, SGD> registerSGDOptimizer(L"SGDOptimizer"); -} // namespace CNTK } // namespace MSR } // namespace Microsoft +} // namespace Microsoft diff --git a/Source/SGDLib/SGD.h b/Source/SGDLib/SGD.h index e7dc76aec..8b1199ec3 100644 --- a/Source/SGDLib/SGD.h +++ b/Source/SGDLib/SGD.h @@ -47,7 +47,8 @@ enum class LearningRateSearchAlgorithm : int enum class AdaptationRegType : int { None, - KL + KL, + TS }; enum class GradientsUpdateType : int @@ -266,6 +267,7 @@ protected: bool m_doGradientCheck; double m_gradientCheckSigDigit; + ConfigArray m_outputNodeNames; bool m_doUnitTest; bool m_useAllDataForPreComputedNode; @@ -386,7 +388,7 @@ public: void Train(shared_ptr net, DEVICEID_TYPE deviceId, IDataReader* trainSetDataReader, IDataReader* validationSetDataReader, int startEpoch, bool loadNetworkFromCheckpoint); - void Adapt(wstring origModelFileName, wstring refNodeName, + void Adapt(shared_ptr net, bool networkLoadedFromCheckpoint, wstring origModelFileName, wstring refNodeName, IDataReader* trainSetDataReader, IDataReader* validationSetDataReader, const DEVICEID_TYPE deviceID, const bool makeMode = true); @@ -512,7 +514,9 @@ protected: const size_t maxNumberOfSamples = SIZE_MAX, const size_t totalMBsSeenBefore = 0, ::CNTK::Internal::TensorBoardFileWriterPtr tensorBoardWriter = nullptr, - const int startEpoch = 0); + const int startEpoch = 0, const std::vector& outputNodeNamesVector = <>, + StreamMinibatchInputs* encodeInputMatrices = NULL, + StreamMinibatchInputs* decodeinputMatrices = NULL); void InitDistGradAgg(int numEvalNodes, int numGradientBits, int deviceId, int traceLevel); void InitModelAggregationHandler(int traceLevel, DEVICEID_TYPE devID);
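For reference, the per-utterance teacher-label generation that RNNT_decode_greedy performs boils down to the greedy RNN-T loop sketched below. This is a minimal, self-contained illustration of the algorithm only, assuming the same conventions as the patch (last vocabulary index is blank, at most one non-blank emission per encoder frame, a single blank emitted when nothing else is); the Vec type and the encoderFrame/predictStep/jointStep hooks are hypothetical stand-ins for the CNTK node evaluations, not real CNTK APIs.

#include <algorithm>
#include <cstdio>
#include <functional>
#include <vector>

using Vec = std::vector<float>;
using Labels = std::vector<size_t>;

// Greedy RNN-T decode with at most one non-blank emission per encoder frame,
// mirroring the loop structure of RNNT_decode_greedy in the patch.
Labels GreedyDecodeUtterance(size_t numFrames, size_t vocabSize,
                             const std::function<Vec(size_t)>& encoderFrame,               // encoder output at frame t
                             const std::function<Vec(size_t)>& predictStep,                // prediction net given last emitted label
                             const std::function<Vec(const Vec&, const Vec&)>& jointStep)  // joint net -> vocab-sized logits
{
    const size_t blankId = vocabSize - 1;   // the patch treats the last row as blank
    Labels labels;
    Vec pred = predictStep(blankId);        // prime the prediction network with blank

    for (size_t t = 0; t < numFrames; ++t)
    {
        Vec logits = jointStep(encoderFrame(t), pred);
        size_t best = std::max_element(logits.begin(), logits.end()) - logits.begin();
        if (best != blankId)
        {
            labels.push_back(best);         // emit label and advance the prediction network
            pred = predictStep(best);
        }
        // on blank, keep the prediction state and move on to the next frame
    }
    if (labels.empty())
        labels.push_back(blankId);          // the patch emits a single blank for empty hypotheses
    return labels;
}

int main()
{
    // Dummy 4-class model: the encoder always prefers class 1, so one label is emitted per frame.
    const size_t vocab = 4;
    auto enc  = [](size_t) { return Vec{0.1f, 0.9f, 0.0f, 0.0f}; };
    auto prd  = [](size_t) { return Vec{0.0f, 0.0f, 0.0f, 0.0f}; };
    auto join = [](const Vec& e, const Vec& p) {
        Vec out(e.size());
        for (size_t i = 0; i < e.size(); ++i) out[i] = e[i] + p[i];  // "plus" joint, as in the patch
        return out;
    };
    Labels y = GreedyDecodeUtterance(3, vocab, enc, prd, join);
    for (size_t l : y) std::printf("%zu ", l);
    std::printf("\n");
    return 0;
}

In the patch, the label sequences produced this way are then repacked as one-hot columns into the decoder input matrix and a fresh MBLayout via InitAsPackedSequences, replacing the original decoder input for the subsequent training step.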