CNTK v2 library: Fixed a bug in counting samples for computing criterion value
This commit is contained in:
Parent
ec89875fd9
Commit
762b4c2880
Makefile | 1 +
@@ -422,6 +422,7 @@ CNTKLIBRARY_TESTS_SRC =\
    Tests/UnitTests/V2LibraryTests/FunctionTests.cpp \
    Tests/UnitTests/V2LibraryTests/SequenceClassification.cpp \
    Tests/UnitTests/V2LibraryTests/Seq2Seq.cpp \
    Tests/UnitTests/V2LibraryTests/TruncatedLSTMAcousticModel.cpp \
    Examples/Evaluation/CPPEvalV2Client/EvalMultithreads.cpp \

CNTKLIBRARY_TESTS:=$(BINDIR)/v2librarytests
@@ -785,6 +785,14 @@ namespace CNTK
        ///
        virtual bool IsReadOnly() const { return m_data->IsReadOnly(); }

        ///
        /// Returns the number of masked/invalid values
        ///
        virtual size_t MaskedCount() const
        {
            return m_mask ? m_mask->MaskedCount() : 0;
        }

        ///
        /// Returns the NDArrayView object corresponding to the data contents of 'this value object.
        ///
@@ -2606,6 +2614,8 @@ namespace CNTK
    ///
    class Learner : public std::enable_shared_from_this<Learner>
    {
        static const std::wstring LearningRateAttributeName;

    public:
        //
        // Method to update the parameters associated with this learner. By returning false, this method indicates that
@@ -2623,25 +2633,38 @@ namespace CNTK
        ///
        // TODO: move the following two methods into ISerializable interface, make
        // Learner (and all other entities that need checkpointing capability) implement it.
        CNTK_API virtual Dictionary GetCheckpointState() const { return Dictionary(); }
        CNTK_API virtual Dictionary GetCheckpointState() const
        {
            Dictionary baseCheckpointState;
            baseCheckpointState[LearningRateAttributeName] = m_learningRate;

            return baseCheckpointState;
        }

        ///
        /// Optionally overridable method to restore the learner's state from a previous checkpoint.
        ///
        CNTK_API virtual void RestoreFromCheckpoint(const Dictionary& /*checkpoint*/) {}
        CNTK_API virtual void RestoreFromCheckpoint(const Dictionary& checkpoint)
        {
            if (checkpoint.Contains(LearningRateAttributeName))
                m_learningRate = checkpoint[LearningRateAttributeName].Value<double>();
        }

        ///
        /// Destruct this Learner.
        ///
        virtual ~Learner() {}

        CNTK_API virtual void ResetLearningRate(double learningRate) { m_learningRate = learningRate; }
        CNTK_API virtual double LearningRate() const { return m_learningRate; }

    protected:
        Learner(const std::vector<Parameter>& parameters)
            : m_parameters(parameters.begin(), parameters.end())
        Learner(const std::vector<Parameter>& parameters, double learningRate)
            : m_parameters(parameters.begin(), parameters.end()), m_learningRate(learningRate)
        {}

        std::unordered_set<Parameter> m_parameters;

        double m_learningRate;
    };

    ///
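The Learner base class now owns the learning rate and serializes it through GetCheckpointState / RestoreFromCheckpoint. A minimal sketch of what a derived learner is expected to do with that contract; the MySgdLearner class and its m_smoothing field are hypothetical, only the base-class calls come from the header above:

    #include "CNTKLibrary.h"

    // Hypothetical learner illustrating the new checkpoint contract of CNTK::Learner.
    class MySgdLearner final : public CNTK::Learner
    {
    public:
        MySgdLearner(const std::vector<CNTK::Parameter>& parameters, double learningRate)
            : Learner(parameters, learningRate) {}

        bool Update(const std::unordered_map<CNTK::Parameter, CNTK::NDArrayViewPtr>& gradientValues, size_t trainingSampleCount) override
        {
            // Apply plain SGD updates here using LearningRate(); omitted in this sketch.
            return trainingSampleCount > 0;
        }

        CNTK::Dictionary GetCheckpointState() const override
        {
            auto checkpoint = Learner::GetCheckpointState(); // captures the current learning rate
            checkpoint[L"smoothing"] = m_smoothing;          // hypothetical learner-specific state
            return checkpoint;
        }

        void RestoreFromCheckpoint(const CNTK::Dictionary& checkpoint) override
        {
            Learner::RestoreFromCheckpoint(checkpoint);      // restores the learning rate
            if (checkpoint.Contains(L"smoothing"))
                m_smoothing = checkpoint[L"smoothing"].Value<double>();
        }

    private:
        double m_smoothing = 0.0;
    };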
@@ -2876,7 +2899,9 @@ namespace CNTK
        FunctionPtr m_combinedTrainingFunction;
        FunctionPtr m_model;
        FunctionPtr m_lossFunction;
        FunctionPtr m_aggregatedLossFunction;
        FunctionPtr m_evaluationFunction;
        FunctionPtr m_aggregatedEvaluationFunction;

        std::unordered_set<LearnerPtr> m_parameterLearners;
@@ -3039,4 +3064,17 @@ namespace CNTK
    CNTK_API void ComputeInputPerDimMeansAndInvStdDevs(const MinibatchSourcePtr& minibatchSource,
                                                       std::unordered_map<StreamInformation, std::pair<NDArrayViewPtr, NDArrayViewPtr>>& computedMeanAndVariances,
                                                       const DeviceDescriptor& device = DeviceDescriptor::CPUDevice());

    ///
    /// Set the process-wide setting for maximum number of CPU threads to be used by any individual compute operation
    /// Note that this is a per compute operation limit and if the user performs multiple compute operations concurrently
    /// by launching multiple threads and performing a compute operation inside, it will result in each of those concurrently
    /// executing operations to use the specified number of CPU threads limit.
    ///
    CNTK_API void SetMaxNumCPUThreads(size_t numCPUThreads);

    ///
    /// Returns the current process-wide setting for maximum number of CPU threads to be used by any individual compute operation
    ///
    CNTK_API size_t GetMaxNumCPUThreads();
}
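A minimal usage sketch of the thread-limit API declared above; the value 4 is arbitrary:

    #include "CNTKLibrary.h"
    #include <cstdio>

    int main()
    {
        // Cap each individual compute operation in this process at 4 CPU threads.
        // Operations issued concurrently from several user threads may each use up to this many.
        CNTK::SetMaxNumCPUThreads(4);

        std::printf("Max CPU threads per compute operation: %zu\n", CNTK::GetMaxNumCPUThreads());
        return 0;
    }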
@@ -8,6 +8,8 @@
#include "BestGpu.h"
#include <mutex>
#include <algorithm>
#include <CPUMatrix.h> // For CPUMatrix::SetNumThreads
#include <thread>

namespace CNTK
{
@@ -166,4 +168,16 @@ namespace CNTK
    {
        s_uniqueDynamicAxisNames.RegisterAxisName(axisName);
    }

    std::atomic<size_t> s_maxNumCPUThreads(std::thread::hardware_concurrency());
    void SetMaxNumCPUThreads(size_t numCPUThreads)
    {
        s_maxNumCPUThreads.store(numCPUThreads);
        Microsoft::MSR::CNTK::CPUMatrix<float>::SetNumThreads((int)numCPUThreads);
    }

    size_t GetMaxNumCPUThreads()
    {
        return s_maxNumCPUThreads.load();
    }
}
@@ -1686,18 +1686,27 @@ namespace CNTK
        }

        ValuePtr nodeValue;
        auto layout = computationNode->GetMBLayout();
        switch (var.GetDataType())
        {
        case DataType::Float:
            nodeValue = GetValueObjectFromCNTKImplMatrixAndMBLayout<float>(var,
                                                                           getGradient ? computationNode->As<ComputationNode<float>>()->Gradient() : computationNode->As<ComputationNode<float>>()->Value(),
                                                                           computationNode->GetMBLayout());
        {
            auto& matrix = getGradient ? computationNode->As<ComputationNode<float>>()->Gradient() : computationNode->As<ComputationNode<float>>()->Value();
            if (varValue == nullptr)
                nodeValue = MakeSharedObject<PackedValue>(var.Shape(), std::make_shared<Matrix<float>>(matrix.AsReference()), layout, /*readOnly =*/ false);
            else
                nodeValue = GetValueObjectFromCNTKImplMatrixAndMBLayout<float>(var, matrix, layout);
            break;
        }
        case DataType::Double:
            nodeValue = GetValueObjectFromCNTKImplMatrixAndMBLayout<double>(var,
                                                                            getGradient ? computationNode->As<ComputationNode<double>>()->Gradient() : computationNode->As<ComputationNode<double>>()->Value(),
                                                                            computationNode->GetMBLayout());
        {
            auto& matrix = getGradient ? computationNode->As<ComputationNode<double>>()->Gradient() : computationNode->As<ComputationNode<double>>()->Value();
            if (varValue == nullptr)
                nodeValue = MakeSharedObject<PackedValue>(var.Shape(), std::make_shared<Matrix<double>>(matrix.AsReference()), layout, /*readOnly =*/ false);
            else
                nodeValue = GetValueObjectFromCNTKImplMatrixAndMBLayout<double>(var, matrix, layout);
            break;
        }
        default:
            LogicError("Unsupported DataType %s", DataTypeName(var.GetDataType()));
            break;
@@ -2102,17 +2111,19 @@ namespace CNTK

    FunctionPtr SquaredError(const Variable& prediction, const Variable& targets, const std::wstring& name/* = L""*/)
    {
        return BinaryOp(PrimitiveOpType::SquaredError, prediction, targets, Dictionary(), name);
        auto difference = Minus(prediction, targets);
        auto squaredDifference = ElementTimes(difference, difference);
        return Internal::ReduceElements(squaredDifference, PrimitiveFunction::InternalSumReductionOpName, Axis::AllStaticAxes(), name);
    }

    FunctionPtr CrossEntropyWithSoftmax(const Variable& prediction, const Variable& labels, const std::wstring& name/* = L""*/)
    {
        return ReduceSum(Minus(ReduceLogSum(prediction, Axis(0)), TransposeTimes(labels, prediction)), name);
        return Minus(ReduceLogSum(prediction, Axis(0)), TransposeTimes(labels, prediction), name);
    }

    FunctionPtr ClassificationError(const Variable& prediction, const Variable& labels, const std::wstring& name/* = L""*/)
    {
        return ReduceSum(Minus(Constant::Scalar(prediction.GetDataType(), 1.0), TransposeTimes(labels, Hardmax(prediction))), name);
        return Minus(Constant::Scalar(prediction.GetDataType(), 1.0), TransposeTimes(labels, Hardmax(prediction)), name);
    }

    FunctionPtr PastValue(const Variable& operand, const Variable& initialState, size_t offset, const std::wstring& name)
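With this change the built-in criteria return per-sample values instead of a pre-reduced sum, and the tests further down wrap them in ReduceSum explicitly. A minimal sketch of the calling pattern; classifierOutput and labels stand in for any suitable network output and label variable:

    using namespace CNTK;

    // Per-sample criterion functions as returned by the library after this commit.
    auto trainingLoss = CrossEntropyWithSoftmax(classifierOutput, labels, L"lossFunction");
    auto prediction = ClassificationError(classifierOutput, labels, L"classificationError");

    // Where a single aggregate value per minibatch is needed (e.g. for logging),
    // the reduction is now applied by the caller, mirroring what the Trainer does
    // internally with its aggregated loss/evaluation functions.
    auto aggregateLoss = ReduceSum(trainingLoss);
    auto aggregateError = ReduceSum(prediction);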
@@ -26,6 +26,9 @@ using namespace std;

namespace CNTK
{
    /*static*/ const std::wstring Learner::LearningRateAttributeName = L"learningRate";
    /*static*/ const std::wstring LearnerBase::WasLearningRateResetAttributeName = L"wasLearningRateReset";

    template <typename ElementType>
    /*static*/ shared_ptr<const Matrix<ElementType>> LearnerBase::GetMatrix(const NDArrayViewPtr& arrayView)
    {
@@ -141,7 +144,7 @@ namespace CNTK
        // L1 regularizer with proximal gradient descent method
        if (m_additionalOptions.l1RegularizationWeight > 0)
        {
            auto learningRate = ElementType(m_learningRates[m_sampleCount]);
            auto learningRate = ElementType(LearningRate());
            // multiply by actualMBSize so that it's invariant to minibatch size since learning rate is per sample
            auto weight = ElementType(learningRate * m_additionalOptions.l1RegularizationWeight * actualMBSize);
            parameterValue->GetWritableMatrix<ElementType>()->InplaceSoftThreshold(weight);
@@ -159,8 +162,9 @@ namespace CNTK
                            bool allocateSmoothGradients /* = true */,
                            double clippingThresholdPerSample /*= std::numeric_limits<double>::infinity()*/,
                            bool gradientClippingWithTruncation /*= true*/)
        : Learner(parameters),
          m_learningRates(learningRates),
        : Learner(parameters, learningRates[0]),
          m_wasLearningRateReset(false),
          m_learningRateSchedule(learningRates),
          m_sampleCount(0),
          m_minibatchCount(0)
    {
@@ -225,7 +229,7 @@ namespace CNTK
#endif

#if DUMPOUTPUT
        auto learningRate = ElementType(m_learningRates[m_sampleCount]);
        auto learningRate = ElementType(LearningRate());
        auto momentum = ElementType(MomentumPerMB(m_momentums[m_sampleCount], trainingSampleCount));
        LOGPRINTF(stderr, "learnRatePerSample=%0.8f, momentum=%0.8f, actualMBSize=%ld\n",
                  learningRate, momentum, trainingSampleCount);
@@ -280,6 +284,9 @@ namespace CNTK
        checkpoint[L"sampleCount"] = m_sampleCount;
        checkpoint[L"minibatchCount"] = m_minibatchCount;

        if (m_wasLearningRateReset)
            checkpoint[WasLearningRateResetAttributeName] = m_wasLearningRateReset;

        // TODO: should we also save learning rate schedule into the checkpoint?
        // If that is the case, need to be able to override this method in subclasses
        // and save momentum schedule as well.
@@ -294,11 +301,19 @@ namespace CNTK
            const auto& smoothedGradientValue = m_smoothedGradientValues.at(parameter);
            checkpoint[parameter.Uid()] = *smoothedGradientValue;
        }

        // Add the base Learner's checkpoint state
        auto baseCheckpointState = Learner::GetCheckpointState();
        checkpoint.Add(baseCheckpointState);

        return checkpoint;
    }

    /*virtual*/ void LearnerBase::RestoreFromCheckpoint(const Dictionary& checkpoint) /*override*/
    {
        // Restore the base learner's checkpoint state
        Learner::RestoreFromCheckpoint(checkpoint);

        m_sampleCount = checkpoint[L"sampleCount"].Value<size_t>();
        m_minibatchCount = checkpoint[L"minibatchCount"].Value<size_t>();

@@ -309,6 +324,9 @@ namespace CNTK
            LogicError("Unsupported checkpoint version.");
        }

        if (checkpoint.Contains(WasLearningRateResetAttributeName))
            m_wasLearningRateReset = checkpoint[WasLearningRateResetAttributeName].Value<bool>();

        for (const auto& parameter : Parameters())
        {
            if (!checkpoint.Contains(parameter.Uid()))
@@ -348,7 +366,7 @@ namespace CNTK
        const auto& gradientMatrix = GetWritableMatrix<ElementType>(gradientValue);
        const auto& parameterMatrix = GetWritableMatrix<ElementType>(parameterValue);

        auto learningRate = ElementType(m_learningRates[m_sampleCount]);
        auto learningRate = ElementType(LearningRate());
        auto momentum = ElementType(MomentumPerMB(m_momentums[m_sampleCount], trainingSampleCount));

        // TODO: break up the NormalGrad into 3 different functions, each with its own set of parameters
@@ -382,7 +400,7 @@ namespace CNTK
        const auto& gradientMatrix = GetWritableMatrix<ElementType>(gradientValue);
        const auto& parameterMatrix = GetWritableMatrix<ElementType>(parameterValue);

        auto learningRate = ElementType(m_learningRates[m_sampleCount]);
        auto learningRate = ElementType(LearningRate());

        auto aveMultiplier = smoothedGradientMatrix->Adagrad(*gradientMatrix, m_needAveMultiplier);
        Matrix<ElementType>::ScaleAndAdd(ElementType(-learningRate / aveMultiplier), *gradientMatrix, *parameterMatrix);
@@ -418,7 +436,7 @@ namespace CNTK
        const auto& gradientMatrix = GetWritableMatrix<ElementType>(gradientValue);
        const auto& parameterMatrix = GetWritableMatrix<ElementType>(parameterValue);

        auto learningRate = m_learningRates[m_sampleCount];
        auto learningRate = LearningRate();
        auto momentum = MomentumPerMB(m_momentums[m_sampleCount], trainingSampleCount);

        const double targetAdagradAvDenom = 0.0025; // 1/400 magic constant
@@ -469,7 +487,7 @@ namespace CNTK
        const auto& gradientMatrix = GetWritableMatrix<ElementType>(gradientValue);
        const auto& parameterMatrix = GetWritableMatrix<ElementType>(parameterValue);

        auto learningRate = ElementType(m_learningRates[m_sampleCount]);
        auto learningRate = ElementType(LearningRate());

        auto aveMultiplier = smoothedGradientMatrix->RmsProp(*gradientMatrix,
                                                             ElementType(m_gamma), ElementType(m_inc),
@@ -26,6 +26,8 @@ namespace CNTK
    // and adds a few pre-/postprocessing methods (which are invoked before and after the update).
    class LearnerBase : public Learner
    {
        static const std::wstring WasLearningRateResetAttributeName;

    public:
        virtual bool Update(const std::unordered_map<Parameter, NDArrayViewPtr>& gradientValues, size_t trainingSampleCount) override final;

@@ -33,6 +35,20 @@ namespace CNTK

        virtual void RestoreFromCheckpoint(const Dictionary& checkpoint) override final;

        virtual void ResetLearningRate(double learningRate) override final
        {
            m_wasLearningRateReset = true;
            Learner::ResetLearningRate(learningRate);
        }

        virtual double LearningRate() const override final
        {
            if (m_wasLearningRateReset)
                return Learner::LearningRate();
            else
                return m_learningRateSchedule[m_sampleCount];
        }

    protected:
        LearnerBase(const std::vector<Parameter>& parameters,
                    const LearningRatesPerSample& learningRates,
@@ -44,7 +60,8 @@ namespace CNTK

        std::string LearnerType() const;

        LearningRatesPerSample m_learningRates;
        bool m_wasLearningRateReset;
        LearningRatesPerSample m_learningRateSchedule;

        AdditionalLearningOptions m_additionalOptions;

@@ -13,7 +13,24 @@ namespace CNTK
    Trainer::Trainer(const FunctionPtr& model, const FunctionPtr& lossFunction, const FunctionPtr& evaluationFunction, const std::unordered_set<LearnerPtr>& parameterLearners)
        : m_model(model), m_lossFunction(lossFunction), m_evaluationFunction(evaluationFunction), m_parameterLearners(parameterLearners), m_prevMinibatchNumSamples(1)
    {
        m_combinedTrainingFunction = Combine({ model, lossFunction, evaluationFunction });
        if (m_lossFunction->Output().DynamicAxes().empty())
            InvalidArgument("The loss function specified in the Trainer constructor must correspond to minibatch data and have dynamic axes");

        if (m_evaluationFunction && m_evaluationFunction->Output().DynamicAxes().empty())
            InvalidArgument("The evaluation function specified in the Trainer constructor must correspond to minibatch data and have dynamic axes");

        m_aggregatedLossFunction = ReduceSum(lossFunction);
        if (m_evaluationFunction)
            m_aggregatedEvaluationFunction = ReduceSum(m_evaluationFunction);

        std::vector<FunctionPtr> combinedFunctionArgs = { m_model, m_aggregatedLossFunction, m_lossFunction };
        if (m_evaluationFunction)
        {
            combinedFunctionArgs.push_back(m_aggregatedEvaluationFunction);
            combinedFunctionArgs.push_back(m_evaluationFunction);
        }

        m_combinedTrainingFunction = Combine(combinedFunctionArgs);

        auto modelParameters = m_combinedTrainingFunction->Parameters();
        std::unordered_set<Parameter> learnerParameters;
@@ -66,20 +83,11 @@ namespace CNTK
        return scalar;
    }

    static size_t GetSampleCountFromArguments(const Variable& evalOrLossArgument, const std::unordered_map<Variable, ValuePtr>& arguments)
    static size_t GetSampleCount(const Variable& var, const ValuePtr& value)
    {
        // Find the argument whose dynamic axes match the criterion operation's dynamic axes (i.e. label dynamic axes)
        // Then we determine the actual number of samples contributing to the training loss from the argument's Value object
        auto argumentIter = std::find_if(arguments.begin(), arguments.end(), [evalOrLossArgument](const std::pair<Variable, ValuePtr>& currentPair) {
            return (currentPair.first.DynamicAxes() == evalOrLossArgument.DynamicAxes());
        });

        auto argumentValue = argumentIter->second;
        auto argumentVar = argumentIter->first;
        auto argumentDataShape = argumentValue->Shape();
        auto mask = argumentValue->Mask();
        size_t numMaskedSamples = (mask != nullptr) ? mask->MaskedCount() : 0;
        size_t numSamplesInDataArrayView = argumentDataShape.SubShape(argumentVar.Shape().Rank()).TotalSize();
        auto valueDataShape = value->Shape();
        size_t numMaskedSamples = value->MaskedCount();
        size_t numSamplesInDataArrayView = valueDataShape.SubShape(var.Shape().Rank()).TotalSize();
        if (numMaskedSamples > numSamplesInDataArrayView)
            LogicError("Number of masked values cannot exceed the number of samples that the Value object's Data NDArrayView can hold");

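This helper is the heart of the fix: the sample count is now read off the criterion's own output Value (number of sample slots in its data view minus the masked entries) rather than being inferred from whichever input argument happened to share the criterion's dynamic axes. A sketch of the resulting arithmetic with hypothetical numbers; the function presumably returns the difference of the two counts:

    // Hypothetical counts for one minibatch, following the quantities computed above.
    size_t numSamplesInDataArrayView = 10; // value->Shape().SubShape(var.Shape().Rank()).TotalSize()
    size_t numMaskedSamples = 2;           // value->MaskedCount(): padding slots in the packed layout

    size_t sampleCount = numSamplesInDataArrayView - numMaskedSamples; // 8 samples actually contribute

    // The per-sample criterion the Trainer reports is the aggregate divided by this count,
    // e.g. TestMinibatch returns GetScalarValue(outputs[m_aggregatedEvaluationFunction]) / sampleCount.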
@@ -88,15 +96,15 @@ namespace CNTK

    double Trainer::TestMinibatch(const std::unordered_map<Variable, ValuePtr>& arguments, const DeviceDescriptor& computeDevice /*= DeviceDescriptor::UseDefaultDevice()*/)
    {
        if (!m_evaluationFunction)
        if (!m_aggregatedEvaluationFunction)
            InvalidArgument("Trainer::TestMinibatch: Cannot test when no evaluation function was specified during 'this' trainer's construction");

        // TODO: Should we refactor this code that is somewhat similar to the prologue of the TrainMinibatch function
        std::unordered_map<Variable, ValuePtr> outputs = { { m_evaluationFunction, nullptr } };
        std::unordered_map<Variable, ValuePtr> outputs = { { m_aggregatedEvaluationFunction, nullptr }, {m_evaluationFunction, nullptr} };
        m_combinedTrainingFunction->Forward(arguments, outputs, computeDevice);

        auto sampleCount = GetSampleCountFromArguments(*(m_evaluationFunction->Arguments().begin()), arguments);
        return (GetScalarValue(outputs[m_evaluationFunction]) / sampleCount);
        auto sampleCount = GetSampleCount(m_evaluationFunction, outputs[m_evaluationFunction]);
        return (GetScalarValue(outputs[m_aggregatedEvaluationFunction]) / sampleCount);
    }

    bool Trainer::TrainMinibatch(const std::unordered_map<Variable, ValuePtr>& arguments, const DeviceDescriptor& computeDevice /*= DeviceDescriptor::UseDefaultDevice()*/)
@@ -107,16 +115,16 @@ namespace CNTK

    bool Trainer::TrainMinibatch(const std::unordered_map<Variable, ValuePtr>& arguments, std::unordered_map<Variable, ValuePtr>& outputsToFetch, const DeviceDescriptor& computeDevice /*= DeviceDescriptor::UseDefaultDevice()*/)
    {
        std::unordered_map<Variable, ValuePtr> outputs = { { m_lossFunction, nullptr } };
        if (m_evaluationFunction)
            outputs.insert({ m_evaluationFunction, nullptr });
        std::unordered_map<Variable, ValuePtr> outputs = { { m_aggregatedLossFunction, nullptr }, { m_lossFunction, nullptr } };
        if (m_aggregatedEvaluationFunction)
            outputs.insert({ m_aggregatedEvaluationFunction, nullptr });

        outputs.insert(outputsToFetch.begin(), outputsToFetch.end());

        auto backPropSate = m_combinedTrainingFunction->Forward(arguments, outputs, computeDevice, { m_lossFunction });
        m_prevMinibatchAggregateTrainingLossValue = outputs[m_lossFunction];
        if (m_evaluationFunction)
            m_prevMinibatchAggregateEvalCriterionValue = outputs[m_evaluationFunction];
        auto backPropSate = m_combinedTrainingFunction->Forward(arguments, outputs, computeDevice, { m_aggregatedLossFunction });
        m_prevMinibatchAggregateTrainingLossValue = outputs[m_aggregatedLossFunction];
        if (m_aggregatedEvaluationFunction)
            m_prevMinibatchAggregateEvalCriterionValue = outputs[m_aggregatedEvaluationFunction];

        for (auto outputToFetch : outputsToFetch)
        {
@@ -124,8 +132,8 @@ namespace CNTK
                outputsToFetch[outputToFetch.first] = outputs[outputToFetch.first];
        }

        ValuePtr rootGradientValue = MakeSharedObject<Value>(MakeSharedObject<NDArrayView>(m_lossFunction->Output().GetDataType(), m_prevMinibatchAggregateTrainingLossValue->Shape(), computeDevice), outputs.at(m_lossFunction)->Mask());
        if (m_lossFunction->Output().GetDataType() == DataType::Float)
        ValuePtr rootGradientValue = MakeSharedObject<Value>(MakeSharedObject<NDArrayView>(m_aggregatedLossFunction->Output().GetDataType(), m_prevMinibatchAggregateTrainingLossValue->Shape(), computeDevice), outputs.at(m_aggregatedLossFunction)->Mask());
        if (m_aggregatedLossFunction->Output().GetDataType() == DataType::Float)
            rootGradientValue->Data()->SetValue(1.0f);
        else
            rootGradientValue->Data()->SetValue(1.0);
@@ -135,9 +143,9 @@ namespace CNTK
        for (const auto& parameter : modelParameters)
            parameterGradients[parameter] = nullptr;

        m_combinedTrainingFunction->Backward(backPropSate, { { m_lossFunction, rootGradientValue } }, parameterGradients);
        m_combinedTrainingFunction->Backward(backPropSate, { { m_aggregatedLossFunction, rootGradientValue } }, parameterGradients);

        m_prevMinibatchNumSamples = GetSampleCountFromArguments(*(m_lossFunction->Arguments().begin()), arguments);
        m_prevMinibatchNumSamples = GetSampleCount(m_lossFunction, outputs[m_lossFunction]);

        bool anyUpdatesPerformed = false;
        for (auto learner : m_parameterLearners)
@@ -186,7 +186,7 @@ namespace CNTK

    void PackedValue::Unpack() const
    {
        if (Internal::IsAutomaticUnpackingOfPackedValuesDisabled())
        if (m_packedDataLayout && (m_packedDataLayout->GetNumTimeSteps() != 1) && (m_packedDataLayout->GetNumSequences() != 1) && Internal::IsAutomaticUnpackingOfPackedValuesDisabled())
            LogicError("PackedValue::Unpack: Automatic unpacking of PackedValue objects is disabled");

        if (m_isPacked)
@@ -14,13 +14,16 @@ namespace CNTK
{
    class PackedValue final : public Value
    {
        template <typename T, typename ...CtorArgTypes>
        friend inline std::shared_ptr<T> MakeSharedObject(CtorArgTypes&& ...ctorArgs);

    public:
        template <typename ElementType>
        PackedValue(const NDShape& sampleShape, const std::shared_ptr<Microsoft::MSR::CNTK::Matrix<ElementType>>& packedDataMatrix, const std::shared_ptr<Microsoft::MSR::CNTK::MBLayout>& packedDataLayout, bool isReadOnly)
            : Value(nullptr), m_isPacked(true), m_sampleShape(sampleShape), m_packedData(nullptr), m_packedDataLayout(packedDataLayout), m_isReadOnly(isReadOnly)
        {
            NDShape packedMatrixShape({ packedDataMatrix->GetNumRows(), packedDataMatrix->GetNumCols() });
            auto tensorView = new TensorView<ElementType>(packedDataMatrix, AsTensorViewShape(packedMatrixShape));
            auto tensorView = new Microsoft::MSR::CNTK::TensorView<ElementType>(packedDataMatrix, AsTensorViewShape(packedMatrixShape));
            m_packedData = MakeSharedObject<NDArrayView>(AsDataType<ElementType>(), AsDeviceDescriptor(packedDataMatrix->GetDeviceId()), AsStorageFormat(packedDataMatrix->GetFormat()), packedMatrixShape, m_isReadOnly, tensorView);

            // Determine unpacked shape
@@ -37,6 +40,15 @@ namespace CNTK
        StorageFormat GetStorageFormat() const override { return m_isPacked? m_packedData->GetStorageFormat() : Value::GetStorageFormat(); }
        bool IsReadOnly() const override { return m_isPacked ? m_packedData->IsReadOnly() : Value::IsReadOnly(); }

        size_t MaskedCount() const override
        {
            if (m_isPacked)
                // Compute the number of masked samples after the data will be unpacked
                return m_packedDataLayout ? ((m_packedDataLayout->GetNumTimeSteps() * m_packedDataLayout->GetNumSequences()) - m_packedDataLayout->GetActualNumSamples()) : 0;
            else
                return Value::MaskedCount();
        }

        NDArrayViewPtr Data() const override
        {
            Unpack();
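In the packed case the masked count is derived from the MBLayout alone, so the data never has to be unpacked just to count padding. A worked instance with hypothetical layout numbers, say two parallel sequences of lengths 3 and 5 padded to 5 time steps:

    // Hypothetical MBLayout figures for illustration only.
    size_t numTimeSteps = 5;    // m_packedDataLayout->GetNumTimeSteps()
    size_t numSequences = 2;    // m_packedDataLayout->GetNumSequences()
    size_t actualSamples = 8;   // m_packedDataLayout->GetActualNumSamples(): 3 + 5 real frames

    size_t maskedCount = (numTimeSteps * numSequences) - actualSamples; // 10 slots - 8 frames = 2 masked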
@@ -51,7 +63,18 @@ namespace CNTK

        ValuePtr DeepClone(bool /*readOnly = false*/) const override
        {
            LogicError("DeepClone is currently unsupported for PackedValue objects");
            if (m_isPacked)
            {
                std::shared_ptr<Microsoft::MSR::CNTK::MBLayout> packedLayoutCopy;
                if (m_packedDataLayout)
                {
                    packedLayoutCopy = std::make_shared<Microsoft::MSR::CNTK::MBLayout>();
                    packedLayoutCopy->CopyFrom(m_packedDataLayout);
                }
                return MakeSharedObject<PackedValue>(m_sampleShape, m_packedData->DeepClone(), packedLayoutCopy, m_isReadOnly);
            }
            else
                return Value::DeepClone();
        }

        ValuePtr Alias(bool /*readOnly = false*/) const override
@@ -73,6 +96,16 @@ namespace CNTK
            return { m_packedData->GetMatrix<ElementType>(), m_packedDataLayout };
        }

    private:
        PackedValue(const NDShape& sampleShape, const NDArrayViewPtr& packedData, const std::shared_ptr<Microsoft::MSR::CNTK::MBLayout>& packedDataLayout, bool isReadOnly)
            : Value(nullptr), m_isPacked(true), m_sampleShape(sampleShape), m_packedData(packedData), m_packedDataLayout(packedDataLayout), m_isReadOnly(isReadOnly)
        {
            // Determine unpacked shape
            m_unpackedShape = sampleShape;
            if (packedDataLayout)
                m_unpackedShape = m_unpackedShape.AppendShape({ packedDataLayout->GetNumTimeSteps(), packedDataLayout->GetNumSequences() });
        }

    private:
        bool m_isReadOnly;
        NDShape m_sampleShape;
@@ -78,10 +78,13 @@ namespace CNTK
        assert(!m_valueInitializer);
        assert(!m_valueInitializationDevice);

        auto filterRank = (int)initializationConfig[FilterRankAttributeName].Value<size_t>();
        auto outputRank = (int)initializationConfig[OutputRankAttributeName].Value<size_t>();
        if ((filterRank + outputRank) > m_shape.Rank())
            InvalidArgument("Sum of filter rank (%d) and output rank (%d) of the parameter initializer cannot exceed the Parameter's rank", filterRank, outputRank, (int)m_shape.Rank());
        if (initializationConfig.Contains(FilterRankAttributeName))
        {
            auto filterRank = (int)initializationConfig[FilterRankAttributeName].Value<size_t>();
            auto outputRank = (int)initializationConfig[OutputRankAttributeName].Value<size_t>();
            if ((filterRank + outputRank) > m_shape.Rank())
                InvalidArgument("Sum of filter rank (%d) and output rank (%d) of the parameter initializer cannot exceed the Parameter's rank(%d)", filterRank, outputRank, (int)m_shape.Rank());
        }

        m_valueInitializer.reset(new ParameterInitializer(initializationConfig));
        m_valueInitializationDevice.reset(new DeviceDescriptor(device));
@@ -23,6 +23,7 @@ cp -R $DataSourceDir/CIFAR/v0/cifar-10-batches-py $DataDir || exit $?
cp -R $TEST_DIR/../../../../Examples/Other/Simple2d/Data/SimpleDataTrain_cntk_text.txt $DataDir || exit $?
cp -R $TEST_DIR/../../Text/SequenceClassification/Data/Train.ctf $DataDir || exit $?
cp -R $TEST_DIR/../../../../Examples/SequenceToSequence/CMUDict/Data/cmudict-0.7b.train-dev-20-21.ctf $DataDir || exit $?
cp -R $TEST_DIR/../../../../Examples/Speech/AN4/Data/* $DataDir || exit $?

pushd $DataDir
@@ -161,7 +161,7 @@ void TrainResNetCifarClassifer(const DeviceDescriptor& device, bool testSaveAndR
    }
}

void TestCifarResnet()
void TrainCifarResnet()
{
#ifndef CPUONLY
    TrainResNetCifarClassifer(DeviceDescriptor::GPUDevice(0), true /*testSaveAndReLoad*/);
@@ -137,11 +137,11 @@ std::pair<CNTK::FunctionPtr, CNTK::FunctionPtr> LSTMPCellWithSelfStabilization(C

    unsigned long seed = 1;
    auto createProjectionParam = [device, &seed](size_t outputDim, size_t inputDim) {
        return CNTK::Parameter({ outputDim, inputDim }, AsDataType<ElementType>(), UniformInitializer(1, seed++), device);
        return CNTK::Parameter({ outputDim, inputDim }, CNTK::AsDataType<ElementType>(), CNTK::UniformInitializer(1, seed++), device);
    };

    auto createDiagWeightParam = [device, &seed](size_t dim) {
        return CNTK::Parameter({ dim }, AsDataType<ElementType>(), UniformInitializer(1, seed++), device);
        return CNTK::Parameter({ dim }, CNTK::AsDataType<ElementType>(), CNTK::UniformInitializer(1, seed++), device);
    };

    auto stabilizedPrevOutput = Stabilize<ElementType>(prevOutput, device);
@@ -156,7 +156,7 @@ std::pair<CNTK::FunctionPtr, CNTK::FunctionPtr> LSTMPCellWithSelfStabilization(C
    auto bit = CNTK::ElementTimes(it, CNTK::Tanh(projectInput() + CNTK::Times(createProjectionParam(cellDim, outputDim), stabilizedPrevOutput)));

    // Forget-me-not gate
    auto ft = CNTK::Sigmoid(projectInput() + CNTK::Times(createProjectionParam(cellDim, outputDim), stabilizedPrevOutput) + ElementTimes(createDiagWeightParam(cellDim), stabilizedPrevCellState));
    auto ft = CNTK::Sigmoid(projectInput() + CNTK::Times(createProjectionParam(cellDim, outputDim), stabilizedPrevOutput) + CNTK::ElementTimes(createDiagWeightParam(cellDim), stabilizedPrevCellState));
    auto bft = CNTK::ElementTimes(ft, prevCellState);

    auto ct = bft + bit;
@@ -36,8 +36,8 @@ void TestFeedForwardNetworkCreation(const DeviceDescriptor& device, bool testSav
    auto classifierOutput = FullyConnectedFeedForwardClassifierNet(inputVar, numOutputClasses, hiddenLayersDim, numHiddenLayers, device, std::bind(Sigmoid, _1, L""), L"classifierOutput");

    auto labelsVar = InputVariable({ numOutputClasses }, DataType::Float, L"Labels");
    auto trainingLoss = CNTK::CrossEntropyWithSoftmax(classifierOutput, labelsVar, L"LossFunction");
    auto prediction = CNTK::ClassificationError(classifierOutput, labelsVar, L"ClassificationError");
    auto trainingLoss = ReduceSum(CNTK::CrossEntropyWithSoftmax(classifierOutput, labelsVar), L"LossFunction");
    auto prediction = ReduceSum(CNTK::ClassificationError(classifierOutput, labelsVar), L"ClassificationError");

    auto ffNet = CNTK::Combine({ trainingLoss, prediction, classifierOutput }, L"ClassifierModel");

@@ -8,7 +8,7 @@ void TensorTests();
void FeedForwardTests();
void RecurrentFunctionTests();
void TrainerTests();
void TestCifarResnet();
void TrainCifarResnet();
void FunctionTests();
void TrainLSTMSequenceClassifer();
void SerializationTests();
@@ -34,7 +34,7 @@ int main()
    LearnerTests();

    TrainerTests();
    TestCifarResnet();
    TrainCifarResnet();
    TrainLSTMSequenceClassifer();

    TrainSequenceToSequenceTranslator();
@@ -41,8 +41,8 @@ void TestRecurrentNetworkCreation(const DeviceDescriptor& device, bool testSaveA
    auto classifierOutput = LSTMNet<ElementType>(features, cellDim, hiddenDim, numOutputClasses, numLSTMLayers, device, L"classifierOutput");

    auto labelsVar = InputVariable({ numOutputClasses }, AsDataType<ElementType>(), L"labels");
    auto trainingLoss = CrossEntropyWithSoftmax(classifierOutput, labelsVar, L"lossFunction");
    auto prediction = ClassificationError(classifierOutput, labelsVar, L"classificationError");
    auto trainingLoss = ReduceSum(CrossEntropyWithSoftmax(classifierOutput, labelsVar), L"lossFunction");
    auto prediction = ReduceSum(ClassificationError(classifierOutput, labelsVar), L"classificationError");

    auto LSTMClassifier = Combine({ trainingLoss, prediction, classifierOutput }, L"LSTMClassifier");

@@ -76,8 +76,122 @@ void TrainLSTMSequenceClassifer(const DeviceDescriptor& device, bool testSaveAnd
    }
}

void TestLearningRateControl(const DeviceDescriptor& device)
{
    const size_t inputDim = 2000;
    const size_t cellDim = 25;
    const size_t hiddenDim = 25;
    const size_t embeddingDim = 50;
    const size_t numOutputClasses = 5;

    auto features = InputVariable({ inputDim }, true /*isSparse*/, DataType::Float, L"features");
    auto classifierOutput = LSTMSequenceClassiferNet(features, numOutputClasses, embeddingDim, hiddenDim, cellDim, device, L"classifierOutput");

    auto labels = InputVariable({ numOutputClasses }, DataType::Float, L"labels", { Axis::DefaultBatchAxis() });
    auto trainingLoss = CNTK::CrossEntropyWithSoftmax(classifierOutput, labels, L"lossFunction");
    auto prediction = CNTK::ClassificationError(classifierOutput, labels, L"classificationError");

    auto minibatchSource = TextFormatMinibatchSource(L"Train.ctf", { { L"features", inputDim, true, L"x" }, { L"labels", numOutputClasses, false, L"y" } }, 0);
    auto featureStreamInfo = minibatchSource->StreamInfo(features);
    auto labelStreamInfo = minibatchSource->StreamInfo(labels);

    const size_t minibatchSize = 200;
    auto minibatchData = minibatchSource->GetNextMinibatch(minibatchSize, device);
    auto actualMBSize = minibatchData[labelStreamInfo].m_numSamples;

    LearningRatesPerSample learningRateSchedule({ { 2, 0.0005 }, { 2, 0.00025 } }, actualMBSize);
    auto learner = SGDLearner(classifierOutput->Parameters(), learningRateSchedule);
    Trainer trainer(classifierOutput, trainingLoss, prediction, { learner });

    if (learner->LearningRate() != 0.0005)
        throw std::runtime_error("Learner::LearningRate does not match expectation");

    trainer.TrainMinibatch({ { features, minibatchData[featureStreamInfo].m_data }, { labels, minibatchData[labelStreamInfo].m_data } }, device);
    if (learner->LearningRate() != 0.0005)
        throw std::runtime_error("Learner::LearningRate does not match expectation");

    const wchar_t* modelFile = L"seq2seq.model";
    trainer.SaveCheckpoint(modelFile);

    trainer.TrainMinibatch({ { features, minibatchData[featureStreamInfo].m_data }, { labels, minibatchData[labelStreamInfo].m_data } }, device);
    auto MB2Loss = trainer.PreviousMinibatchLossAverage();
    if (learner->LearningRate() != 0.00025)
        throw std::runtime_error("Learner::LearningRate does not match expectation");

    trainer.TrainMinibatch({ { features, minibatchData[featureStreamInfo].m_data }, { labels, minibatchData[labelStreamInfo].m_data } }, device);
    auto MB3Loss = trainer.PreviousMinibatchLossAverage();
    if (learner->LearningRate() != 0.00025)
        throw std::runtime_error("Learner::LearningRate does not match expectation");

    trainer.RestoreFromCheckpoint(modelFile);
    if (learner->LearningRate() != 0.0005)
        throw std::runtime_error("Learner::LearningRate does not match expectation");

    trainer.TrainMinibatch({ { features, minibatchData[featureStreamInfo].m_data }, { labels, minibatchData[labelStreamInfo].m_data } }, device);
    auto postRestoreMB2Loss = trainer.PreviousMinibatchLossAverage();
    if (postRestoreMB2Loss != MB2Loss)
        throw std::runtime_error("Post checkpoint restoration training loss does not match expectation");

    if (learner->LearningRate() != 0.00025)
        throw std::runtime_error("Learner::LearningRate does not match expectation");

    trainer.TrainMinibatch({ { features, minibatchData[featureStreamInfo].m_data }, { labels, minibatchData[labelStreamInfo].m_data } }, device);
    auto postRestoreMB3Loss = trainer.PreviousMinibatchLossAverage();
    if (postRestoreMB3Loss != MB3Loss)
        throw std::runtime_error("Post checkpoint restoration training loss does not match expectation");

    trainer.RestoreFromCheckpoint(modelFile);
    if (learner->LearningRate() != 0.0005)
        throw std::runtime_error("Learner::LearningRate does not match expectation");

    learner->ResetLearningRate(0.0004);
    if (learner->LearningRate() != 0.0004)
        throw std::runtime_error("Learner::LearningRate does not match expectation");

    trainer.SaveCheckpoint(modelFile);
    trainer.TrainMinibatch({ { features, minibatchData[featureStreamInfo].m_data }, { labels, minibatchData[labelStreamInfo].m_data } }, device);
    postRestoreMB2Loss = trainer.PreviousMinibatchLossAverage();
    if (postRestoreMB2Loss != MB2Loss)
        throw std::runtime_error("Post checkpoint restoration training loss does not match expectation");

    if (learner->LearningRate() != 0.0004)
        throw std::runtime_error("Learner::LearningRate does not match expectation");

    trainer.TrainMinibatch({ { features, minibatchData[featureStreamInfo].m_data }, { labels, minibatchData[labelStreamInfo].m_data } }, device);
    postRestoreMB3Loss = trainer.PreviousMinibatchLossAverage();
    if (postRestoreMB3Loss == MB3Loss)
        throw std::runtime_error("Post checkpoint restoration training loss does not match expectation");

    if (learner->LearningRate() != 0.0004)
        throw std::runtime_error("Learner::LearningRate does not match expectation");

    trainer.RestoreFromCheckpoint(modelFile);
    if (learner->LearningRate() != 0.0004)
        throw std::runtime_error("Learner::LearningRate does not match expectation");

    trainer.TrainMinibatch({ { features, minibatchData[featureStreamInfo].m_data }, { labels, minibatchData[labelStreamInfo].m_data } }, device);
    postRestoreMB2Loss = trainer.PreviousMinibatchLossAverage();
    if (postRestoreMB2Loss != MB2Loss)
        throw std::runtime_error("Post checkpoint restoration training loss does not match expectation");

    if (learner->LearningRate() != 0.0004)
        throw std::runtime_error("Learner::LearningRate does not match expectation");

    trainer.TrainMinibatch({ { features, minibatchData[featureStreamInfo].m_data }, { labels, minibatchData[labelStreamInfo].m_data } }, device);
    postRestoreMB3Loss = trainer.PreviousMinibatchLossAverage();
    if (postRestoreMB3Loss == MB3Loss)
        throw std::runtime_error("Post checkpoint restoration training loss does not match expectation");

    if (learner->LearningRate() != 0.0004)
        throw std::runtime_error("Learner::LearningRate does not match expectation");
}

void TrainLSTMSequenceClassifer()
{
#ifndef CPUONLY
    TestLearningRateControl(DeviceDescriptor::GPUDevice(0));
#endif

#ifndef CPUONLY
    TrainLSTMSequenceClassifer(DeviceDescriptor::GPUDevice(0), true);
#endif
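The schedule used in TestLearningRateControl, { { 2, 0.0005 }, { 2, 0.00025 } } with unit actualMBSize, is indexed by the cumulative sample count: the first two units' worth of samples train at 0.0005 and the following ones at 0.00025, which is why the test expects the rate to drop only after the second minibatch. A small sketch of the lookup, assuming a hypothetical unit of 200 samples:

    using namespace CNTK;

    const size_t unit = 200; // hypothetical; the test uses the actual size of the first minibatch
    LearningRatesPerSample schedule({ { 2, 0.0005 }, { 2, 0.00025 } }, unit);

    // The schedule is looked up with the number of samples seen so far,
    // exactly as LearnerBase::LearningRate() does with m_sampleCount.
    double atStart = schedule[0];                    // 0.0005
    double afterOneMinibatch = schedule[unit];       // 0.0005  (still inside the first 2 units)
    double afterTwoMinibatches = schedule[2 * unit]; // 0.00025 (second entry takes over)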
@@ -95,21 +95,23 @@ void TrainTruncatedLSTMAcousticModelClassifer(const DeviceDescriptor& device, bo
        prediction = predictionVar;
    }

    const size_t numTrainingSamples = 20480;
    const size_t numTrainingSamples = 81920;
    const size_t truncationLength = 20;
    Dictionary truncatedModeConfig;
    truncatedModeConfig[L"truncated"] = true;
    truncatedModeConfig[L"truncationLength"] = truncationLength;
    minibatchSource = CreateMinibatchSource(baseFeaturesDim, numOutputClasses, truncatedModeConfig, numTrainingSamples);

    const size_t numberParallelSequencesPerMB = 1;
    const size_t numberParallelSequencesPerMB = 32;
    const size_t minibatchSize = truncationLength * numberParallelSequencesPerMB;

    featureStreamInfo = minibatchSource->StreamInfo(features);
    auto labelStreamInfo = minibatchSource->StreamInfo(labels);

    double learningRatePerSample = 0.000781;
    auto learner = MomentumSGDLearner(classifierOutput->Parameters(), learningRatePerSample, 0.0);
    size_t momentumTimeConstant = 6074;
    double momentumPerSample = std::exp(-1.0 / momentumTimeConstant);
    auto learner = MomentumSGDLearner(classifierOutput->Parameters(), learningRatePerSample, momentumPerSample);
    Trainer trainer(classifierOutput, trainingLoss, prediction, {learner});

    size_t outputFrequencyInMinibatches = 1;
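The truncated-LSTM test now derives its per-sample momentum from a time constant instead of hard-coding zero; a quick, self-contained check of that arithmetic (the constant 6074 comes from the hunk above):

    #include <cmath>
    #include <cstdio>

    int main()
    {
        const size_t momentumTimeConstant = 6074;
        // Per-sample momentum corresponding to the time constant: exp(-1/6074) is roughly 0.999835.
        const double momentumPerSample = std::exp(-1.0 / momentumTimeConstant);
        std::printf("momentumPerSample = %.6f\n", momentumPerSample);
        return 0;
    }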