CNTK v2 library: Added a Seq2Seq implementation as a test using the V2 C++ API, and other related changes
This commit is contained in:
Parent: 009e53dfe7
Commit: 7a5c133edc
Makefile: 1 addition
@@ -416,6 +416,7 @@ CNTKLIBRARY_TESTS_SRC =\
	Tests/UnitTests/V2LibraryTests/LearnerTests.cpp \
	Tests/UnitTests/V2LibraryTests/FunctionTests.cpp \
	Tests/UnitTests/V2LibraryTests/SequenceClassification.cpp \
	Tests/UnitTests/V2LibraryTests/Seq2Seq.cpp \

CNTKLIBRARY_TESTS:=$(BINDIR)/v2librarytests
CNTKLIBRARY_TESTS_OBJ := $(patsubst %.cu, $(OBJDIR)/%.o, $(patsubst %.cpp, $(OBJDIR)/%.o, $(CNTKLIBRARY_TESTS_SRC)))
@@ -698,6 +698,9 @@ namespace CNTK
CNTK_API static const std::wstring StaticAxisNamePrefix;
static const size_t SentinelStaticAxisIndexValueForDynamicAxes = SIZE_MAX;

// TODO: Make this thread-safe
CNTK_API static std::unordered_set<std::wstring> s_allKnownDynamicAxisNames;

public:
///
/// Construct an Axis object denoting a static axis with the specified index.

@@ -713,7 +716,9 @@ namespace CNTK
///
explicit Axis(const std::wstring& name, bool isOrderedDynamicAxis = true)
    : m_staticAxisIdx(SentinelStaticAxisIndexValueForDynamicAxes), m_name(name), m_isOrderedDynamicAxis(isOrderedDynamicAxis)
{}
{
    RegisterAxisName(name);
}

///
/// Returns a boolean indicating if 'this' Axis corresponds to a static axis
@@ -746,6 +751,11 @@ namespace CNTK
///
CNTK_API static const Axis& DefaultBatchAxis();

///
/// Returns a new unique Dynamic axis
///
CNTK_API static Axis NewUniqueDynamicAxis(const std::wstring& axisNamePrefix, bool isOrderedDynamicAxis = true);

///
/// Name of 'this' axis
///

@@ -758,6 +768,9 @@ namespace CNTK
    : m_staticAxisIdx(SentinelStaticAxisIndexValueForDynamicAxes)
{}

private:
CNTK_API void RegisterAxisName(const std::wstring& axisName);

private:
size_t m_staticAxisIdx;
std::wstring m_name;
@@ -819,7 +832,9 @@ namespace CNTK
template <typename T>
friend struct std::hash;

public:
CNTK_API static const std::vector<Axis> DefaultInputVariableDynamicAxes;

public:
///
/// Create an 'Input' Variable.

@@ -904,6 +919,11 @@ namespace CNTK
///
CNTK_API Variable(const FunctionPtr& function);

///
/// Implicit conversion to a FunctionPtr; creates a pass through primitive function
///
CNTK_API operator FunctionPtr() const;

///
/// Default constructor for creating an invalid/null Variable instance.
/// Required for use in a std::vector container.
@@ -1063,6 +1083,70 @@ namespace CNTK
    : Variable(shape, VariableKind::Parameter, AsDataType<ElemType>(), MakeSharedObject<NDArrayView>(initValue, shape, device), true, {}, name)
{}

///
/// Create a Parameter initialized with random values drawn from a Uniform distribution in the range [-0.05, 0.05]
///
static Parameter Uniform(const NDShape& shape, DataType type = DataType::Float, unsigned long seed = 1, const DeviceDescriptor& device = DeviceDescriptor::UseDefaultDevice(), const std::wstring& name = L"")
{
    return UniformInitParameter(shape, type, 1.0/20, seed, device, name);
}

///
/// Create a Parameter initialized with random values from a Uniform distribution in the range [-sqrt(6 / fanIn), sqrt(6 / fanIn)]
///
static Parameter HeUniform(const NDShape& shape, DataType type = DataType::Float, unsigned long seed = 1, size_t fanOutRank = 1, const DeviceDescriptor& device = DeviceDescriptor::UseDefaultDevice(), const std::wstring& name = L"")
{
    NDShape fanInShape = shape.SubShape(fanOutRank);
    return UniformInitParameter(shape, type, std::sqrt(6.0/fanInShape.TotalSize()), seed, device, name);
}

///
/// Create a Parameter initialized with random values from a Uniform distribution in the range [-sqrt(6 / (fanIn + fanOut)), sqrt(6 / (fanIn + fanOut))]
///
static Parameter GlorotUniform(const NDShape& shape, DataType type = DataType::Float, unsigned long seed = 1, size_t fanOutRank = 1, const DeviceDescriptor& device = DeviceDescriptor::UseDefaultDevice(), const std::wstring& name = L"")
{
    NDShape fanOutShape = shape.SubShape(0, fanOutRank);
    NDShape fanInShape = shape.SubShape(fanOutRank);
    return UniformInitParameter(shape, type, std::sqrt(6.0 / (fanInShape.TotalSize() + fanOutShape.TotalSize())), seed, device, name);
}

///
/// Create a Parameter initialized with random values from a Uniform distribution in the range [-sqrt(3 / fanIn), sqrt(3 / fanIn)]
///
static Parameter Xavier(const NDShape& shape, DataType type = DataType::Float, unsigned long seed = 1, size_t fanOutRank = 1, const DeviceDescriptor& device = DeviceDescriptor::UseDefaultDevice(), const std::wstring& name = L"")
{
    NDShape fanInShape = shape.SubShape(fanOutRank);
    return UniformInitParameter(shape, type, std::sqrt(3.0 / fanInShape.TotalSize()), seed, device, name);
}

///
/// Create a Parameter initialized with random values drawn from a Gaussian distribution with [mean = 0, stdDev = sqrt(0.04 / fanIn)]
///
static Parameter Gaussian(const NDShape& shape, DataType type = DataType::Float, unsigned long seed = 1, size_t fanOutRank = 1, const DeviceDescriptor& device = DeviceDescriptor::UseDefaultDevice(), const std::wstring& name = L"")
{
    NDShape fanInShape = shape.SubShape(fanOutRank);
    return NormalInitParameter(shape, type, std::sqrt(0.04 / fanInShape.TotalSize()), seed, device, name);
}

///
/// Create a Parameter initialized with random values from a Gaussian distribution with [mean = 0, stdDev = sqrt(2 / fanIn)]
///
static Parameter HeNormal(const NDShape& shape, DataType type = DataType::Float, unsigned long seed = 1, size_t fanOutRank = 1, const DeviceDescriptor& device = DeviceDescriptor::UseDefaultDevice(), const std::wstring& name = L"")
{
    NDShape fanInShape = shape.SubShape(fanOutRank);
    return NormalInitParameter(shape, type, std::sqrt(2.0 / fanInShape.TotalSize()), seed, device, name);
}

///
/// Create a Parameter initialized with random values from a Gaussian distribution with [mean = 0, stdDev = sqrt(2 / (fanIn + fanOut))]
///
static Parameter GlorotNormal(const NDShape& shape, DataType type = DataType::Float, unsigned long seed = 1, size_t fanOutRank = 1, const DeviceDescriptor& device = DeviceDescriptor::UseDefaultDevice(), const std::wstring& name = L"")
{
    NDShape fanOutShape = shape.SubShape(0, fanOutRank);
    NDShape fanInShape = shape.SubShape(fanOutRank);
    return NormalInitParameter(shape, type, std::sqrt(2.0 / (fanInShape.TotalSize() + fanOutShape.TotalSize())), seed, device, name);
}

///
/// DownCast a Variable to a Parameter. Only allowed if the VariableKind is Parameter and throws an exception otherwise.
///
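The initializers above differ only in the scale they pass to UniformInitParameter/NormalInitParameter. A rough usage sketch based solely on these declarations; the dimensions and the braced NDShape construction are assumptions for illustration:

    size_t inputDim = 300, outputDim = 512;
    // Glorot-style fan-in/fan-out scaled uniform initialization for a weight matrix
    auto W = Parameter::GlorotUniform({ outputDim, inputDim });
    // Plain uniform initialization in [-0.05, 0.05] for a bias vector
    auto b = Parameter::Uniform({ outputDim }, DataType::Float, /*seed =*/ 1);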
@@ -1080,6 +1164,12 @@ namespace CNTK
{
    return Variable::Value();
}

private:

// Helper methods for Parameter construction
CNTK_API static Parameter UniformInitParameter(const NDShape& shape, DataType type, double range, unsigned long seed, const DeviceDescriptor& device, const std::wstring& name);
CNTK_API static Parameter NormalInitParameter(const NDShape& shape, DataType type, double stdDev, unsigned long seed, const DeviceDescriptor& device, const std::wstring& name);
};

// Implementation note: The Variable type is a value type and not polymorphic in nature.
@@ -1154,8 +1244,8 @@ namespace CNTK
///
/// Contruct a Placeholder with the specified NDShape
///
explicit Placeholder(const NDShape& shape, const std::wstring& name = L"")
    : Variable(shape, VariableKind::Placeholder, DataType::Unknown, nullptr, false, { Axis::DefaultDynamicAxis(), Axis::DefaultBatchAxis() }, name)
explicit Placeholder(const NDShape& shape, const std::vector<Axis>& dynamicAxes = DefaultInputVariableDynamicAxes)
    : Variable(shape, VariableKind::Placeholder, DataType::Unknown, nullptr, false, dynamicAxes, L"")
{}

///
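A Placeholder now carries the dynamic axes it will eventually be bound to instead of a name; a minimal sketch (hiddenDim is a hypothetical dimension):

    size_t hiddenDim = 256;
    auto placeholder = Placeholder({ hiddenDim });   // defaults to DefaultInputVariableDynamicAxes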
@@ -1427,16 +1517,8 @@ namespace CNTK
if (uniqueOutputs.find(outputVar) != uniqueOutputs.end())
    RuntimeError("Same variable appears multiple times in the outputs vector passed to Function constructor");

switch (outputVar.Kind())
{
case VariableKind::Output:
    m_outputs.push_back(outputVar);
    uniqueOutputs.insert(outputVar);
    break;
default:
    InvalidArgument("Function output has invalid VariableKind!");
    break;
}
m_outputs.push_back(outputVar);
uniqueOutputs.insert(outputVar);
}
}

@@ -1454,6 +1536,14 @@ namespace CNTK
///
CNTK_API FunctionPtr Negate(const Variable& operand, const std::wstring& name = L"");

///
/// Unary negation operator corresponding to the Negate operation
///
inline FunctionPtr operator-(const Variable& operand)
{
    return Negate(operand);
}

///
/// Create an instance of the CNTK built-in elementwise sigmoid operation with the specified input operand.
///
@@ -1555,11 +1645,27 @@ namespace CNTK
///
CNTK_API FunctionPtr Plus(const Variable& leftOperand, const Variable& rightOperand, const std::wstring& name = L"");

///
/// Binary addition operator corresponding to the Plus operation
///
inline FunctionPtr operator+(const Variable& leftOperand, const Variable& rightOperand)
{
    return Plus(leftOperand, rightOperand);
}

///
/// Create an instance of the CNTK built-in elementwise tensor subtraction operation with the specified input operands.
///
CNTK_API FunctionPtr Minus(const Variable& leftOperand, const Variable& rightOperand, const std::wstring& name = L"");

///
/// Binary minus operator corresponding to the Minus operation
///
inline FunctionPtr operator-(const Variable& leftOperand, const Variable& rightOperand)
{
    return Minus(leftOperand, rightOperand);
}

///
/// Create an instance of the CNTK built-in elementwise multiplication operation on specified tensor input operands.
///
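With the operator overloads above, elementwise expressions can be written directly on Variables; a small sketch (x and y are assumed to be previously constructed Variables of matching shape):

    auto sum        = x + y;   // equivalent to Plus(x, y)
    auto difference = x - y;   // equivalent to Minus(x, y)
    auto negation   = -x;      // equivalent to Negate(x)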
@@ -1733,6 +1839,11 @@ namespace CNTK
///
CNTK_API FunctionPtr Clip(const Variable& operand, const Variable& min, const Variable& max, const std::wstring& name = L"");

///
/// Create an instance of the CNTK built-in elementwise choice operation using a condition tensor for specified tensor operands.
///
CNTK_API FunctionPtr ElementSelect(const Variable& condition, const Variable& leftOperand, const Variable& rightOperand, const std::wstring& name = L"");

///
/// Create an instance of the CNTK built-in splice operation to splice together all the specified tensor operands into a single output tensor
///
@@ -1746,6 +1857,21 @@ namespace CNTK
///
CNTK_API FunctionPtr Combine(const std::vector<FunctionPtr>& operands, const std::wstring& name = L"");

namespace Sequence
{
    CNTK_API FunctionPtr IsFirst(const Variable& operand, const std::wstring& name = L"");
    CNTK_API FunctionPtr IsLast(const Variable& operand, const std::wstring& name = L"");

    CNTK_API FunctionPtr First(const Variable& operand, const std::wstring& name = L"");
    CNTK_API FunctionPtr Last(const Variable& operand, const std::wstring& name = L"");

    CNTK_API FunctionPtr Where(const Variable& condition, const std::wstring& name = L"");
    CNTK_API FunctionPtr Gather(const Variable& operand, const Variable& condition, const std::wstring& name = L"");
    CNTK_API FunctionPtr Scatter(const Variable& operand, const Variable& condition, const std::wstring& name = L"");

    CNTK_API FunctionPtr BroadcastAs(const Variable& operand, const Variable& broadcastAs, const std::wstring& name = L"");
}

///
/// Load a legacy CNTK v1 format model
///
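A rough sketch of the new Sequence ops on a sequence-valued Variable 'input' (a hypothetical Input Variable with the default sequence and batch axes):

    auto thoughtVector = Sequence::Last(input);                        // last element of each sequence
    auto isFirstFrame  = Sequence::IsFirst(input);                     // 1 at the first element of each sequence, 0 elsewhere
    auto firstOnly     = Sequence::Gather(input, isFirstFrame);        // keep only the elements where the condition is true
    auto broadcasted   = Sequence::BroadcastAs(thoughtVector, input);  // repeat a per-sequence value along input's sequence axis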
@@ -1859,9 +1985,9 @@ namespace CNTK
{
static_assert(std::is_same<T, NDShape>::value ||
              std::is_same<T, Axis>::value ||
              std::is_same<T, std::wstring>::value ||
              std::is_same<T, std::vector<DictionaryValue>>::value ||
              std::is_same<T, Dictionary>::value ||
              std::is_same<T, NDArrayView>::value,
              "Unsupported ValueType");

@@ -2279,35 +2405,45 @@ namespace CNTK
/// Create an instance of the CNTK built-in SGD learner.
///
CNTK_API LearnerPtr SGDLearner(const std::unordered_set<Parameter>& parameters,
                               const LearningRatesPerSample& learningRates);
                               const LearningRatesPerSample& learningRates,
                               double clippingThresholdPerSample = std::numeric_limits<double>::infinity(),
                               bool gradientClippingWithTruncation = true);

///
/// Create an instance of the CNTK built-in Momentum SGD learner.
///
CNTK_API LearnerPtr MomentumSGDLearner(const std::unordered_set<Parameter>& parameters,
                                       const LearningRatesPerSample& learningRates,
                                       const MomentumsPerSample& momentums);
                                       const MomentumsPerSample& momentums,
                                       double clippingThresholdPerSample = std::numeric_limits<double>::infinity(),
                                       bool gradientClippingWithTruncation = true);

///
/// Create an instance of the CNTK built-in Nesterov's accelerated SGD learner.
///
CNTK_API LearnerPtr NesterovLearner(const std::unordered_set<Parameter>& parameters,
                                    const LearningRatesPerSample& learningRates,
                                    const MomentumsPerSample& momentums);

///
/// Create an instance of the CNTK built-in AdaGrad learner.
///
CNTK_API LearnerPtr AdaGradLearner(const std::unordered_set<Parameter>& parameters,
                                   const LearningRatesPerSample& learningRates,
                                   bool needAveMultiplier = true);
                                    const MomentumsPerSample& momentums,
                                    double clippingThresholdPerSample = std::numeric_limits<double>::infinity(),
                                    bool gradientClippingWithTruncation = true);

///
/// Create an instance of the CNTK built-in FSAdaGrad (improved AdaGrad) learner.
///
CNTK_API LearnerPtr FSAdaGradLearner(const std::unordered_set<Parameter>& parameters,
                                     const LearningRatesPerSample& learningRates,
                                     const MomentumsPerSample& momentums);
                                     const MomentumsPerSample& momentums,
                                     double clippingThresholdPerSample = std::numeric_limits<double>::infinity(),
                                     bool gradientClippingWithTruncation = true);

///
/// Create an instance of the CNTK built-in AdaGrad learner.
///
CNTK_API LearnerPtr AdaGradLearner(const std::unordered_set<Parameter>& parameters,
                                   const LearningRatesPerSample& learningRates,
                                   bool needAveMultiplier = true,
                                   double clippingThresholdPerSample = std::numeric_limits<double>::infinity(),
                                   bool gradientClippingWithTruncation = true);

///
/// Create an instance of the CNTK built-in RMSProp learner.
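The new trailing parameters default to "no clipping" (an infinite per-sample threshold), so existing call sites keep their behavior; passing a finite threshold enables per-sample gradient clipping. A sketch, where 'model->Parameters()' and the learning-rate value are assumptions for illustration:

    auto learner = SGDLearner(model->Parameters(),
                              LearningRatesPerSample(0.005),
                              /*clippingThresholdPerSample =*/ 2.3,
                              /*gradientClippingWithTruncation =*/ true);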
@@ -2319,7 +2455,9 @@ namespace CNTK
                                   double dec,
                                   double max,
                                   double min,
                                   bool needAveMultiplier = true);
                                   bool needAveMultiplier = true,
                                   double clippingThresholdPerSample = std::numeric_limits<double>::infinity(),
                                   bool gradientClippingWithTruncation = true);

///
/// Trainer is the top-level abstraction responsible for the orchestration of the training of a model
@@ -2333,7 +2471,15 @@ namespace CNTK
/// Construct a Trainer to train the specified 'model' with the specified 'trainingLoss' Variable as the training criterion
/// and using the specified set of 'parameterLearners' for updating the model's parameters using computed gradients.
///
CNTK_API Trainer(const FunctionPtr& model, const Variable& trainingLoss, const std::unordered_set<LearnerPtr>& parameterLearners);
CNTK_API Trainer(const FunctionPtr& model, const FunctionPtr& lossFunction, const std::unordered_set<LearnerPtr>& parameterLearners);

///
/// Construct a Trainer to train the specified 'model' with the specified 'trainingLoss' as the training criterion,
/// the specified 'evaluationFunction' as the criterion for evaluating the trained model's quality, and using the specified set
/// of 'parameterLearners' for updating the model's parameters using computed gradients.
///
// TODO: Add overload for multiple evaluation criterion
CNTK_API Trainer(const FunctionPtr& model, const FunctionPtr& lossFunction, const FunctionPtr& evaluationFunction, const std::unordered_set<LearnerPtr>& parameterLearners);

///
/// Optimize model parameters using the specified 'arguments' minibatch of training samples.
@@ -2341,26 +2487,41 @@ namespace CNTK
///
CNTK_API bool TrainMinibatch(const std::unordered_map<Variable, ValuePtr>& arguments, const DeviceDescriptor& computeDevice = DeviceDescriptor::UseDefaultDevice());

///
/// Test the model on the specified batch of samples using the evaluation Function specified during construction of the Trainer
/// Returns the average evaluation criterion value per sample for the tested minibatch of samples
///
CNTK_API double TestMinbatch(const std::unordered_map<Variable, ValuePtr>& arguments, const DeviceDescriptor& computeDevice = DeviceDescriptor::UseDefaultDevice());

///
/// Model being trained by 'this' Trainer.
///
FunctionPtr Model() const { return m_model; }

///
/// Variable of the Trainer's model representing the training loss that is used as the optimization
/// criterion for learning the model's parameters.
/// Loss function that is used as the optimization criterion for learning the model's parameters.
///
Variable TrainingLossVariable() const { return m_trainingLossVar; }
FunctionPtr LossFunction() const { return m_lossFunction; }

///
/// Returns the Value of the training loss variable of the model corresponding to the last minibatch trained with
/// Evaluation Function that is used as for the criterion for evaluating the trained model's quality.
///
ValuePtr PreviousMinibatchTrainingLossValue() const { return m_prevMinibatchTrainingLossValue; }
FunctionPtr EvaluationFunction() const { return m_evaluationFunction; }

///
/// Returns the training loss corresponding to the last minibatch trained as a double
/// Returns the average training loss per sample for the last minibatch trained.
///
CNTK_API double PreviousMinibatchAverageTrainingLoss() const;
CNTK_API double PreviousMinibatchLossAverage() const;

///
/// Returns the average evaluation criterion value per sample for the last minibatch trained.
///
CNTK_API double PreviousMinibatchEvaluationAverage() const;

///
/// Returns the number of samples in the last minibatch trained with
///
size_t PreviousMinibatchSampleCount() const { return m_prevMinibatchNumSamples; }

///
/// Learners associated with this Trainer for updating the model's parameters using computed gradients.
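The Trainer now takes the loss and evaluation criteria as FunctionPtrs and reports per-sample averages after each minibatch. A usage sketch, where 'model', 'loss', 'evalError', 'learner' and 'minibatchData' are assumed to have been created elsewhere:

    Trainer trainer(model, loss, evalError, { learner });
    trainer.TrainMinibatch(minibatchData, DeviceDescriptor::UseDefaultDevice());
    double avgLoss  = trainer.PreviousMinibatchLossAverage();
    double avgError = trainer.PreviousMinibatchEvaluationAverage();
    size_t samples  = trainer.PreviousMinibatchSampleCount();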
@@ -2368,11 +2529,16 @@ namespace CNTK
const std::unordered_set<LearnerPtr>& ParameterLearners() const { return m_parameterLearners; }

private:
FunctionPtr m_combinedTrainingFunction;
FunctionPtr m_model;
Variable m_trainingLossVar;
ValuePtr m_prevMinibatchTrainingLossValue;
size_t m_prevMinibatchNumSamples;
FunctionPtr m_lossFunction;
FunctionPtr m_evaluationFunction;

std::unordered_set<LearnerPtr> m_parameterLearners;

size_t m_prevMinibatchNumSamples;
ValuePtr m_prevMinibatchAggregateTrainingLossValue;
ValuePtr m_prevMinibatchAggregateEvalCriterionValue;
};

///
@@ -182,11 +182,17 @@ namespace CNTK

namespace Internal
{
    // Create a new Function instance which just passes through specified list of 'operands'.
    CNTK_API FunctionPtr Combine(const std::vector<Variable>& operands, const std::wstring& name = L"");

    CNTK_API FunctionPtr IsWithin(const Variable& operand, int offset, const std::wstring& name = L"");
    CNTK_API FunctionPtr PackedIndex(const Variable& operand, const Variable& index, const std::wstring& name = L"");
    CNTK_API FunctionPtr GatherPacked(const Variable& operand, const Variable& packedIndex, const std::wstring& name = L"");
    CNTK_API FunctionPtr IsWithin(const Variable& operand, int offset, const std::wstring& name = L"");
    CNTK_API FunctionPtr ScatterPacked(const Variable& operand, const Variable& packedIndex, const Variable& condition, const std::wstring& name = L"");
    CNTK_API FunctionPtr ZeroesLike(const Variable& operand);
    CNTK_API FunctionPtr Where(const Variable& condition, const std::vector<Axis>& newDynamicAxes, const std::wstring& name = L"");
    CNTK_API FunctionPtr Gather(const Variable& operand, const Variable& condition, const std::vector<Axis>& newDynamicAxes, const std::wstring& name = L"");
    CNTK_API FunctionPtr Scatter(const Variable& operand, const Variable& condition, const std::vector<Axis>& newDynamicAxes, const std::wstring& name = L"");
    CNTK_API FunctionPtr Slice(const Variable& operand, const Axis& axis, int beginIndex, int endIndex, const std::wstring& name = L"");
    CNTK_API FunctionPtr ReduceElements(const Variable& operand, const std::wstring& reductionOpName, const Axis& axis, const std::wstring& name = L"");
}
|
|||
{
|
||||
if (node->Is<InputValueBase<ElementType>>())
|
||||
{
|
||||
auto inputNode = node->As<InputValueBase<ElementType>>();
|
||||
bool isSparse = node->Is<SparseInputValue<ElementType>>();
|
||||
if (node->HasMBLayout())
|
||||
{
|
||||
// TODO: Currently only default dynamic axis is supported
|
||||
auto inputNodeInternalDynamicAxisName = inputNode->GetRequestedDynamicAxis();
|
||||
auto inputNodeInternalDynamicAxisName = node->GetMBLayout()->GetAxisName();
|
||||
std::vector<Axis> inputVarDynamicAxes = DynamicAxesFromInternalDynamicAxisName(inputNodeInternalDynamicAxisName);
|
||||
|
||||
var = Variable(varShape, isSparse, AsDataType<ElementType>(), node->GetLearningRateMultiplier() != 0, node->GetName(), inputVarDynamicAxes);
|
||||
|
@@ -74,7 +73,7 @@ namespace CNTK
{
bool isConstant = (node->GetLearningRateMultiplier() == 0);
auto& matrix = node->As<ComputationNode<ElementType>>()->Value();
auto tensorView = new TensorView<ElementType>(std::make_shared<Matrix<ElementType>>(matrix.AsReference()), node->GetSampleLayout());
auto tensorView = new TensorView<ElementType>(std::make_shared<Matrix<ElementType>>(matrix.AsReference()), AsTensorViewShape(node->GetSampleLayout()));
NDArrayViewPtr parameterValue = MakeSharedObject<NDArrayView>(AsDataType<ElementType>(), AsDeviceDescriptor(matrix.GetDeviceId()), AsStorageFormat(matrix.GetFormat()), varShape, false, tensorView);
if (isConstant)
    var = Constant(parameterValue, node->GetName());
@@ -87,7 +86,14 @@ namespace CNTK
else
{
    // This is a non-leaf node and maps to a primitive Function
    auto placeholderVar = Placeholder(varShape);
    std::vector<Axis> varDynamicAxes;;
    if (node->HasMBLayout())
    {
        auto nodeInternalDynamicAxisName = node->GetMBLayout()->GetAxisName();
        varDynamicAxes = DynamicAxesFromInternalDynamicAxisName(nodeInternalDynamicAxisName);
    }

    auto placeholderVar = Placeholder(varShape, varDynamicAxes);
    nodeToVariableMap[node] = placeholderVar;

    std::vector<Variable> inputVars(node->GetNumInputs());
@@ -134,14 +140,9 @@ namespace CNTK
}
else if (node->OperationName() == OperationNameOf(WhereNode))
{
    auto whereNode = node->As<WhereNode<ElementType>>();
    auto internalDynamicAxisName = whereNode->DynamicAxisName();
    auto internalDynamicAxisName = node->GetMBLayout()->GetAxisName();
    std::vector<Axis> dynamicAxes = DynamicAxesFromInternalDynamicAxisName(internalDynamicAxisName);
    std::vector<std::wstring> dynamicAxesNames;
    for (auto axis : dynamicAxes)
        dynamicAxesNames.push_back(axis.Name());

    primitiveFunctionConfigParameters[PrimitiveFunction::AttributeNameNewDynamicAxes] = AsDictionaryValueVector(dynamicAxesNames);
    primitiveFunctionConfigParameters[PrimitiveFunction::AttributeNameNewDynamicAxes] = AsDictionaryValueVector(dynamicAxes);

    opType = PrimitiveOpType::Where;
}
@@ -196,6 +197,13 @@ namespace CNTK
    std::swap(inputVars[0], inputVars[1]);
    opType = PrimitiveOpType::GatherPacked;
}
else if (node->OperationName() == OperationNameOf(ScatterPackedNode))
{
    // The internal scatter node has layout as the first input and the actual source as the last operand
    // which is different from the corresponding V2 Function's ordering of the inputs
    std::swap(inputVars[0], inputVars[2]);
    opType = PrimitiveOpType::ScatterPacked;
}
else if (node->OperationName() == OperationNameOf(TimesNode))
{
    primitiveFunctionConfigParameters[PrimitiveFunction::AttributeNameOutputRank] = (size_t)node->As<TimesNode<ElementType>>()->OutputRank();
@@ -296,6 +304,8 @@ namespace CNTK

    opType = PrimitiveOpType::Clip;
}
else if (node->OperationName() == OperationNameOf(IfNode))
    opType = PrimitiveOpType::Select;
else if (node->OperationName() == OperationNameOf(RowStackNode))
{
    // Internal CNTK SliceNode uses 1 based axis indices instead of 0 based
@@ -32,6 +32,8 @@ namespace CNTK

/*static*/ const std::wstring Axis::StaticAxisNamePrefix = L"staticAxis_";

/*static*/ std::unordered_set<std::wstring> Axis::s_allKnownDynamicAxisNames;

/*static*/ const Axis& Axis::DefaultDynamicAxis()
{
    static const Axis s_defaultDynamicAxis(L"defaultDynamicAxis");

@@ -43,4 +45,22 @@ namespace CNTK
    static const Axis s_defaultBatchAxis(L"defaultBatchAxis", false);
    return s_defaultBatchAxis;
}

/*static*/ Axis Axis::NewUniqueDynamicAxis(const std::wstring& axisNamePrefix, bool isOrderedDynamicAxis /*= true*/)
{
    if (s_allKnownDynamicAxisNames.find(axisNamePrefix) == s_allKnownDynamicAxisNames.end())
        return Axis(axisNamePrefix, isOrderedDynamicAxis);

    for (size_t i = 1;; i++)
    {
        auto newDynamicAxisName = axisNamePrefix + std::to_wstring(i);
        if (s_allKnownDynamicAxisNames.find(newDynamicAxisName) == s_allKnownDynamicAxisNames.end())
            return Axis(newDynamicAxisName, isOrderedDynamicAxis);
    }
}

void Axis::RegisterAxisName(const std::wstring& axisName)
{
    s_allKnownDynamicAxisNames.insert(axisName);
}
}
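Registering every named dynamic axis lets NewUniqueDynamicAxis hand out fresh names by appending a numeric suffix; a small sketch of the behavior implemented above:

    auto a = Axis::NewUniqueDynamicAxis(L"sliceAxis");   // "sliceAxis" if the prefix is not yet in use
    auto b = Axis::NewUniqueDynamicAxis(L"sliceAxis");   // "sliceAxis1", since the prefix is now registered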
@@ -128,11 +128,11 @@ namespace CNTK
// We currently require that the inputs' dynamic axes if any match
std::vector<Axis> outputDynamicAxes;
if (op == PrimitiveOpType::Where)
    ;
    outputDynamicAxes = AsVector<Axis>(functionConfig[PrimitiveFunction::AttributeNameNewDynamicAxes].Value<std::vector<DictionaryValue>>());
else if (op == PrimitiveOpType::ScatterPacked)
    outputDynamicAxes = inputs[2].DynamicAxes();
else if ((op == PrimitiveOpType::PackedIndex) || (op == PrimitiveOpType::GatherPacked))
{
    outputDynamicAxes = inputs[1].DynamicAxes();
}
else
{
    outputDynamicAxes = inputs[0].DynamicAxes();
@@ -178,7 +178,7 @@ namespace CNTK
if (!axis1.IsStaticAxis() || !axis2.IsStaticAxis())
    LogicError("TransposeAxes operation currently does not support transposing dynamic axes");

auto transposedTensorShape = AsTensorShape(inputs[0].Shape(), true);
auto transposedTensorShape = AsTensorShape(inputs[0].Shape());
transposedTensorShape.SwapDimsInPlace(axis1.StaticAxisIndex(), axis2.StaticAxisIndex());
outputs.push_back(Variable(AsNDShape(transposedTensorShape), outputDataType, owner, outputDynamicAxes));
break;
@@ -186,12 +186,7 @@ namespace CNTK
case PrimitiveOpType::Where:
{
    assert(inputs.size() == 1);
    std::vector<Axis> newDynamicAxes;
    auto newDynamicAxesNames = AsBasicElementTypeVector<std::wstring>(functionConfig[PrimitiveFunction::AttributeNameNewDynamicAxes].Value<std::vector<DictionaryValue>>());
    for (auto axisName : newDynamicAxesNames)
        newDynamicAxes.push_back(Axis(axisName));

    outputs.push_back(Variable(UnaryElementwiseOpOutputShape(inputs[0].Shape()), outputDataType, owner, newDynamicAxes));
    outputs.push_back(Variable(UnaryElementwiseOpOutputShape(inputs[0].Shape()), outputDataType, owner, outputDynamicAxes));
    break;
}
case PrimitiveOpType::Slice:
@@ -216,7 +211,7 @@ namespace CNTK
    realEndIndex,
    inputs[0].Shape().AsString().c_str());

auto outputTensorShape = AsTensorShape(inputs[0].Shape(), true);
auto outputTensorShape = AsTensorShape(inputs[0].Shape());

// propagate as much as we can
if ((axis.StaticAxisIndex() < outputTensorShape.GetRank()) && (0 <= realBeginIndex) && (realBeginIndex <= realEndIndex) && (realEndIndex <= sliceAxisDim))
@@ -242,7 +237,7 @@ namespace CNTK
auto strides = functionConfig[PrimitiveFunction::AttributeNameStrides].Value<NDShape>();
auto lowerPad = functionConfig[PrimitiveFunction::AttributeNameLowerPad].Value<NDShape>();
auto upperPad = functionConfig[PrimitiveFunction::AttributeNameUpperPad].Value<NDShape>();
auto autoPadding = AsBasicElementTypeVector<bool>(functionConfig[PrimitiveFunction::AttributeNameAutoPadding].Value<std::vector<DictionaryValue>>());
auto autoPadding = AsVector<bool>(functionConfig[PrimitiveFunction::AttributeNameAutoPadding].Value<std::vector<DictionaryValue>>());
outputs.push_back(Variable(ConvolutionOpOutputShape(inputs[0].Shape(), poolingWindowsShape, { 1 }, strides, { true }, autoPadding, lowerPad, upperPad, false), outputDataType, owner, outputDynamicAxes));
break;
}
@@ -291,8 +286,8 @@ namespace CNTK
auto strides = functionConfig[PrimitiveFunction::AttributeNameStrides].Value<NDShape>();
auto lowerPad = functionConfig[PrimitiveFunction::AttributeNameLowerPad].Value<NDShape>();
auto upperPad = functionConfig[PrimitiveFunction::AttributeNameUpperPad].Value<NDShape>();
auto sharing = AsBasicElementTypeVector<bool>(functionConfig[PrimitiveFunction::AttributeNameSharing].Value<std::vector<DictionaryValue>>());
auto autoPadding = AsBasicElementTypeVector<bool>(functionConfig[PrimitiveFunction::AttributeNameAutoPadding].Value<std::vector<DictionaryValue>>());
auto sharing = AsVector<bool>(functionConfig[PrimitiveFunction::AttributeNameSharing].Value<std::vector<DictionaryValue>>());
auto autoPadding = AsVector<bool>(functionConfig[PrimitiveFunction::AttributeNameAutoPadding].Value<std::vector<DictionaryValue>>());
bool transpose = functionConfig[PrimitiveFunction::AttributeNameTranspose].Value<bool>();
if (inputs[0].Shape().NumAxes() < inputs[1].Shape().NumAxes())
    InvalidArgument("The convolution map should have at least as many axes as the shape of the input it operates on!");
@@ -333,10 +328,6 @@ namespace CNTK
if (!initialStateVar.IsConstant() || (initialStateVar.Shape().NumAxes() > 0))
    LogicError("Currently PastValue/FutureValue Function only supports scalar initial state");

// TODO: We currently only support input operand with 1 static axis for PastValue/FutureValue
if (inputOperandVar.Shape().NumAxes() > 1)
    LogicError("Currently PastValue/FutureValue Function only supports input operand with <= 1 static axis");

// TODO: We currently only support input operand with 1 dynamic axis for PastValue/FutureValue
if (inputOperandVar.DynamicAxes().size() != 2)
    LogicError("Currently PastValue/FutureValue Function only supports input operand with with 2 dynamic axis (1 sequence-axis and 1 batch-axis)");
@@ -382,10 +373,25 @@ namespace CNTK
    outputs.push_back(Variable(outputShape, outputDataType, owner, outputDynamicAxes));
    break;
}
case PrimitiveOpType::ScatterPacked:
{
    if (inputs[0].DynamicAxes().empty() || inputs[1].DynamicAxes().empty() || inputs[2].DynamicAxes().empty())
        InvalidArgument("ScatterPacked requires all its operands to have dynamic axes");

    if (inputs[1].Shape().NumAxes() != 1)
        InvalidArgument("ScatterPacked requires the packedIndex operand to be a scalar sequence");

    outputs.push_back(Variable(inputs[0].Shape(), outputDataType, owner, outputDynamicAxes));
    break;
}
case PrimitiveOpType::Clip:
    assert(inputs.size() == 3);
    outputs.push_back(Variable(UnaryElementwiseOpOutputShape(inputs[0].Shape()), outputDataType, owner, outputDynamicAxes));
    break;
case PrimitiveOpType::Select:
    assert(inputs.size() == 3);
    outputs.push_back(Variable(NaryElementwiseOpOutputShape(op, { inputs[0].Shape(), inputs[1].Shape(), inputs[2].Shape() }), outputDataType, owner, outputDynamicAxes));
    break;
case PrimitiveOpType::Splice:
{
    assert(inputs.size() >= 2);
@@ -405,8 +411,8 @@ namespace CNTK
// Note: The no sequence axis corresponds to a special case where there is no sequence axis (i.e. has been reduced over)
// and the special name is used to identify this when loading back a model saved in CNTK v1 format. This will not really be needed
// when the new CNTK v2 model serialization format is ready.
/*static*/ const std::wstring CompositeFunction::InternalDefaultDynamicAxisName = L"";
/*static*/ const std::wstring CompositeFunction::InternalNoSequenceAxisName = L"noSequenceAxis";
/*static*/ const std::wstring CompositeFunction::InternalDefaultDynamicAxisName = L"*";
/*static*/ const std::wstring CompositeFunction::InternalNoSequenceAxisName = L"__noSequenceAxis";

// Replace any PlaceHolder Variables in the graph of Functions underlying 'this' CompositeFunction. All PlaceHolder variables
// should have been replaced before performing any Forward compute of 'this' Function.
@@ -486,7 +492,7 @@ namespace CNTK
// Construct the dynamic axis name to be used internally for the CNTK InputNodes
std::wstring internalDynamicAxisName = InternalDynamicAxisNameFromDynamicAxes(dynamicAxes);

if (!internalDynamicAxisName.empty())
if (!internalDynamicAxisName.empty() && !network->NodeNameExists(internalDynamicAxisName))
    network->AddNodeToNetAndAttachInputs(New<DynamicAxisNode<ElementType>>(network->GetDeviceId(), internalDynamicAxisName), {});

if (IsSparseInput(variable))
@@ -524,60 +530,60 @@ namespace CNTK

auto functionName = primitiveFunction->Name();
auto& functionConfig = primitiveFunction->FunctionConfig();
auto functionInputs = primitiveFunction->Inputs();
PrimitiveOpType op = primitiveFunction->OpType();

switch (op)
{
case PrimitiveOpType::Negate:
    computationNodePtr = builder.Negate(inputNodes[0], functionName);
    break;
case PrimitiveOpType::Sigmoid:
    computationNodePtr = builder.Sigmoid(inputNodes[0], functionName);
    break;
case PrimitiveOpType::Tanh:
    computationNodePtr = builder.Tanh(inputNodes[0], functionName);
    break;
case PrimitiveOpType::ReLU:
    computationNodePtr = builder.RectifiedLinear(inputNodes[0], functionName);
    break;
case PrimitiveOpType::Exp:
    computationNodePtr = builder.Exp(inputNodes[0], functionName);
    break;
case PrimitiveOpType::Log:
    computationNodePtr = builder.Log(inputNodes[0], functionName);
    break;
case PrimitiveOpType::Sqrt:
    computationNodePtr = builder.Sqrt(inputNodes[0], functionName);
    break;
case PrimitiveOpType::Floor:
    computationNodePtr = builder.Floor(inputNodes[0], functionName);
    break;
case PrimitiveOpType::Abs:
    computationNodePtr = builder.Abs(inputNodes[0], functionName);
    break;
case PrimitiveOpType::Reciprocal:
    computationNodePtr = builder.Reciprocal(inputNodes[0], functionName);
    break;
case PrimitiveOpType::Softmax:
    computationNodePtr = builder.Softmax(inputNodes[0], functionName);
    break;
case PrimitiveOpType::Hardmax:
    computationNodePtr = builder.Hardmax(inputNodes[0], functionName);
    break;
case PrimitiveOpType::TransposeAxes:
{
    auto axis1 = functionConfig[PrimitiveFunction::AttributeNameAxis1].Value<Axis>();
    auto axis2 = functionConfig[PrimitiveFunction::AttributeNameAxis2].Value<Axis>();

    // The axis ids passed to the internal CNTK TransposeDimensionsNode are 1 based instead of 0 based
    computationNodePtr = New<TransposeDimensionsNode<ElementType>>(network->GetDeviceId(), functionName, AsCNTKInternalAxisIdx(axis1), AsCNTKInternalAxisIdx(axis2));
    network->AddNodeToNetAndAttachInputs(computationNodePtr, { inputNodes[0] });
    break;
}
case PrimitiveOpType::Where:
{
    auto dynamicAxes = variable.DynamicAxes();
    auto internalCNTKWhereNodeDynamicAxisName = InternalDynamicAxisNameFromDynamicAxes(dynamicAxes);
    computationNodePtr = New<WhereNode<ElementType>>(network->GetDeviceId(), functionName, internalCNTKWhereNodeDynamicAxisName);
    network->AddNodeToNetAndAttachInputs(computationNodePtr, { inputNodes[0] });
@@ -604,141 +610,144 @@ namespace CNTK
case PrimitiveOpType::Reshape:
{
    auto newShape = functionConfig[PrimitiveFunction::AttributeNameNewShape].Value<NDShape>();
    computationNodePtr = builder.Reshape(inputNodes[0], AsTensorShape(newShape, true /*preserveRank*/), functionName);
    computationNodePtr = builder.Reshape(inputNodes[0], AsTensorShape(newShape), functionName);
    break;
}
case PrimitiveOpType::Pooling:
{
    PoolingType poolingType = (PoolingType)(functionConfig[PrimitiveFunction::AttributeNamePoolingType].Value<size_t>());
    auto poolingWindowsShape = functionConfig[PrimitiveFunction::AttributeNamePoolingWindowShape].Value<NDShape>();
    auto strides = functionConfig[PrimitiveFunction::AttributeNameStrides].Value<NDShape>();
    auto lowerPad = functionConfig[PrimitiveFunction::AttributeNameLowerPad].Value<NDShape>();
    auto upperPad = functionConfig[PrimitiveFunction::AttributeNameUpperPad].Value<NDShape>();
    auto autoPadding = AsBasicElementTypeVector<bool>(functionConfig[PrimitiveFunction::AttributeNameAutoPadding].Value<std::vector<DictionaryValue>>());
    computationNodePtr = builder.Pooling(inputNodes[0], AsCNTKPoolKind(poolingType), AsTensorShape(poolingWindowsShape, true), AsTensorShape(strides, true), autoPadding, AsTensorShape(lowerPad, true), AsTensorShape(upperPad, true), ImageLayoutKind::CHW, functionName);
    auto autoPadding = AsVector<bool>(functionConfig[PrimitiveFunction::AttributeNameAutoPadding].Value<std::vector<DictionaryValue>>());
    computationNodePtr = builder.Pooling(inputNodes[0], AsCNTKPoolKind(poolingType), AsTensorShape(poolingWindowsShape), AsTensorShape(strides), autoPadding, AsTensorShape(lowerPad), AsTensorShape(upperPad), ImageLayoutKind::CHW, functionName);
    break;
}
case PrimitiveOpType::SumAll:
    computationNodePtr = builder.Sum(inputNodes[0], functionName);
    break;
case PrimitiveOpType::Plus:
    computationNodePtr = builder.Plus(inputNodes[0], inputNodes[1], functionName);
    break;
case PrimitiveOpType::Minus:
    computationNodePtr = builder.Minus(inputNodes[0], inputNodes[1], functionName);
    break;
case PrimitiveOpType::ElementTimes:
    computationNodePtr = builder.ElementTimes(inputNodes[0], inputNodes[1], functionName);
    break;
case PrimitiveOpType::Equal:
    computationNodePtr = builder.Equal(inputNodes[0], inputNodes[1], functionName);
    break;
case PrimitiveOpType::NotEqual:
    computationNodePtr = builder.NotEqual(inputNodes[0], inputNodes[1], functionName);
    break;
case PrimitiveOpType::Less:
    computationNodePtr = builder.Less(inputNodes[0], inputNodes[1], functionName);
    break;
case PrimitiveOpType::LessEqual:
    computationNodePtr = builder.LessEqual(inputNodes[0], inputNodes[1], functionName);
    break;
case PrimitiveOpType::Greater:
    computationNodePtr = builder.Greater(inputNodes[0], inputNodes[1], functionName);
    break;
case PrimitiveOpType::GreaterEqual:
    computationNodePtr = builder.GreaterEqual(inputNodes[0], inputNodes[1], functionName);
    break;
case PrimitiveOpType::Times:
{
    size_t outputRank = functionConfig[PrimitiveFunction::AttributeNameOutputRank].Value<size_t>();
    computationNodePtr = builder.Times(inputNodes[0], inputNodes[1], outputRank, functionName);
    break;
}
case PrimitiveOpType::TransposeTimes:
{
    size_t outputRank = functionConfig[PrimitiveFunction::AttributeNameOutputRank].Value<size_t>();
    computationNodePtr = network->AddNodeToNetAndAttachInputs(New<TransposeTimesNode<ElementType>>(network->GetDeviceId(), functionName, outputRank), { inputNodes[0], inputNodes[1] });
    break;
}
case PrimitiveOpType::Convolution:
{
    NDShape outputMapCount, kernelShape;
    std::tie(outputMapCount, kernelShape) = GetConvolutionOutputMapCountAndKernelShape(functionInputs[0].Shape(), functionInputs[1].Shape());
    auto strides = functionConfig[PrimitiveFunction::AttributeNameStrides].Value<NDShape>();
    auto lowerPad = functionConfig[PrimitiveFunction::AttributeNameLowerPad].Value<NDShape>();
    auto upperPad = functionConfig[PrimitiveFunction::AttributeNameUpperPad].Value<NDShape>();
    auto sharing = AsBasicElementTypeVector<bool>(functionConfig[PrimitiveFunction::AttributeNameSharing].Value<std::vector<DictionaryValue>>());
    auto autoPadding = AsBasicElementTypeVector<bool>(functionConfig[PrimitiveFunction::AttributeNameAutoPadding].Value<std::vector<DictionaryValue>>());
    auto sharing = AsVector<bool>(functionConfig[PrimitiveFunction::AttributeNameSharing].Value<std::vector<DictionaryValue>>());
    auto autoPadding = AsVector<bool>(functionConfig[PrimitiveFunction::AttributeNameAutoPadding].Value<std::vector<DictionaryValue>>());
    auto transpose = functionConfig[PrimitiveFunction::AttributeNameTranspose].Value<bool>();
    auto maxTempMemSizeInSamples = functionConfig[PrimitiveFunction::AttributeNameMaxTempMemSizeInSamples].Value<size_t>();
    computationNodePtr = builder.Convolution(inputNodes[0], inputNodes[1], AsTensorShape(kernelShape, true), AsTensorShape(outputMapCount, true), AsTensorShape(strides, true), sharing, autoPadding, AsTensorShape(lowerPad, true), AsTensorShape(upperPad, true), transpose, ImageLayoutKind::CHW, maxTempMemSizeInSamples, functionName);
    computationNodePtr = builder.Convolution(inputNodes[0], inputNodes[1], AsTensorShape(kernelShape), AsTensorShape(outputMapCount), AsTensorShape(strides), sharing, autoPadding, AsTensorShape(lowerPad), AsTensorShape(upperPad), transpose, ImageLayoutKind::CHW, maxTempMemSizeInSamples, functionName);
    break;
}
case PrimitiveOpType::SquaredError:
    computationNodePtr = builder.SquareError(inputNodes[0], inputNodes[1], functionName);
    break;
case PrimitiveOpType::CrossEntropyWithSoftmax:
    computationNodePtr = builder.CrossEntropyWithSoftmax(inputNodes[1], inputNodes[0], functionName);
    break;
case PrimitiveOpType::ClassificationError:
    computationNodePtr = builder.ClassificationError(inputNodes[1], inputNodes[0], functionName);
    break;
case PrimitiveOpType::PastValue:
case PrimitiveOpType::FutureValue:
{
    Variable inputOperandVar = functionInputs[0];
    Variable initialStateVar = functionInputs[1];

    // Get the intial state of the PastValue/FutureValue operation
    ElementType initStateValue;
    NDArrayView tempView({}, &initStateValue, 1, DeviceDescriptor::CPUDevice());
    tempView.CopyFrom(*Constant(initialStateVar).Value());

    size_t offset = primitiveFunction->FunctionConfig()[PrimitiveFunction::AttributeNameOffset].Value<size_t>();
    if (op == PrimitiveOpType::PastValue)
        computationNodePtr = builder.PastValue(inputNodes[0], (float)initStateValue, inputOperandVar.Shape().TotalSize(), offset, functionName);
    else
        computationNodePtr = builder.FutureValue(inputNodes[0], (float)initStateValue, inputOperandVar.Shape().TotalSize(), offset, functionName);

    break;
}
case PrimitiveOpType::ReduceElements:
{
    auto reductionAxis = functionConfig[PrimitiveFunction::AttributeNameAxis].Value<Axis>();
    auto reductionOpName = functionConfig[PrimitiveFunction::AttributeNameReductionOpName].Value<std::wstring>();
    computationNodePtr = network->AddNodeToNetAndAttachInputs(New<ReduceElementsNode<ElementType>>(network->GetDeviceId(), functionName, reductionOpName, AsCNTKInternalAxisIdx(reductionAxis)), { inputNodes[0] });
    break;
}
case PrimitiveOpType::BatchNormalization:
{
    auto spatial = functionConfig[PrimitiveFunction::AttributeNameSpatial].Value<bool>();
    auto normalizationTimeConstant = functionConfig[PrimitiveFunction::AttributeNameNormalizationTimeConstant].Value<double>();
    auto blendTimeConstant = functionConfig[PrimitiveFunction::AttributeNameBlendTimeConstant].Value<double>();
    auto epsilon = functionConfig[PrimitiveFunction::AttributeNameEpsilon].Value<double>();
    auto useCuDNNEngine = functionConfig[PrimitiveFunction::AttributeNameUseCuDNNEngine].Value<bool>();
    computationNodePtr = builder.BatchNormalization(inputNodes[0], inputNodes[1], inputNodes[2], inputNodes[3], inputNodes[4], spatial, normalizationTimeConstant, blendTimeConstant, epsilon, !useCuDNNEngine, ImageLayoutKind::CHW, functionName);
    break;
}
case PrimitiveOpType::Combine:
    // This operation is just a no-op and is a means to combine multiple functions to create a single Function
    // whose outputs are a union of the outputs of the Functions being combined.
    computationNodePtr = variableToNodeMap[variable];
    break;
case PrimitiveOpType::PackedIndex:
    computationNodePtr = New<PackedIndexNode<ElementType>>(network->GetDeviceId(), functionName);
    network->AddNodeToNetAndAttachInputs(computationNodePtr, { inputNodes[0], inputNodes[1] });
    break;
case PrimitiveOpType::GatherPacked:
    computationNodePtr = New<GatherPackedNode<ElementType>>(network->GetDeviceId(), functionName);
    network->AddNodeToNetAndAttachInputs(computationNodePtr, { inputNodes[1], inputNodes[0] });
    break;
case PrimitiveOpType::ScatterPacked:
    computationNodePtr = New<ScatterPackedNode<ElementType>>(network->GetDeviceId(), functionName);
    network->AddNodeToNetAndAttachInputs(computationNodePtr, { inputNodes[2], inputNodes[1], inputNodes[0] });
    break;
case PrimitiveOpType::Clip:
    computationNodePtr = builder.Clip(inputNodes[1], inputNodes[2], inputNodes[0], functionName);
    break;
case PrimitiveOpType::Select:
    computationNodePtr = builder.If(inputNodes[0], inputNodes[1], inputNodes[2], functionName);
    break;
case PrimitiveOpType::Splice:
{
    Axis spliceAxis = functionConfig[PrimitiveFunction::AttributeNameAxis].Value<Axis>();

@@ -750,12 +759,12 @@ namespace CNTK
    inputNodesBasePtrs.push_back(inputNode);

    network->AddNodeToNetAndAttachInputs(computationNodePtr, inputNodesBasePtrs);
    break;
}
default:
    LogicError("Specified op %s not yet supported", PrimitiveOpTypeName(op));
    break;
}

return computationNodePtr;
}
@@ -881,7 +890,7 @@ namespace CNTK
auto outputShape = outputVar.Shape();
auto computationNodeSampleLayout = computationNodePtr->GetSampleLayout();
if (((outputShape.NumAxes() == 0) && (computationNodeSampleLayout[0] != 1)) ||
    ((outputShape.NumAxes() != 0) && (computationNodeSampleLayout != AsTensorShape(outputShape)) && (computationNodeSampleLayout != AsTensorShape(outputShape, true))))
    ((outputShape.NumAxes() != 0) && (computationNodeSampleLayout != AsTensorViewShape(outputShape)) && (computationNodeSampleLayout != AsTensorShape(outputShape))))
{
    LogicError("The output Variable shape %s does not match the SampleLayout shape %s of the corresponding ComputationNode in the network", AsString(outputShape).c_str(), ((std::string)computationNodeSampleLayout).c_str());
}
@@ -1026,7 +1035,7 @@ namespace CNTK
if ((layout == nullptr) || (layout->GetNumTimeSteps() == 1) || (layout->GetNumSequences() == 1))
{
    // Just create a view over the existing matrix itself
    auto tensorView = new TensorView<ElementType>(std::make_shared<Matrix<ElementType>>(matrix.AsReference()), AsTensorShape(valueDataShape));
    auto tensorView = new TensorView<ElementType>(std::make_shared<Matrix<ElementType>>(matrix.AsReference()), AsTensorViewShape(valueDataShape));
    auto data = MakeSharedObject<NDArrayView>(AsDataType<ElementType>(), AsDeviceDescriptor(matrix.GetDeviceId()), AsStorageFormat(matrix.GetFormat()), valueDataShape, readOnly, tensorView);
    return MakeSharedObject<Value>(data);
}
@@ -1085,7 +1094,7 @@ namespace CNTK
    }
}

auto tensorView = new TensorView<ElementType>(shuffledMatrixData, AsTensorShape(valueDataShape));
auto tensorView = new TensorView<ElementType>(shuffledMatrixData, AsTensorViewShape(valueDataShape));
auto data = MakeSharedObject<NDArrayView>(AsDataType<ElementType>(), AsDeviceDescriptor(matrix.GetDeviceId()), AsStorageFormat(shuffledMatrixData->GetFormat()), valueDataShape, readOnly, tensorView);
return MakeSharedObject<Value>(data, mask);
}
@@ -1114,21 +1123,16 @@ namespace CNTK
auto& nodeData = computationNode->As<ComputationNode<ElementType>>()->Value();

// Switch the node matrix to the right matrix type
nodeData.SwitchToMatrixType(CNTKMatrixAndMBLayout.first->GetMatrixType(), CNTKMatrixAndMBLayout.first->GetFormat(), false);
nodeData.AssignValuesOf(*CNTKMatrixAndMBLayout.first);
computationNode->GetMBLayout()->CopyFrom(layout);
}

void CompositeFunction::PopulateNetworkInputs(const std::unordered_map<Variable, ValuePtr>& arguments)
{
    auto functionArguments = this->Arguments();
    std::vector<ComputationNodeBasePtr> inputNodes;
    for (auto argument : functionArguments)
    for (auto argumentValuePair : arguments)
    {
        // Ensure we have values for all arguments of the function
        if (arguments.find(argument) == arguments.end())
            InvalidArgument("Value not specified for required Function Argument");

        auto argument = argumentValuePair.first;
        auto argumentComputationNode = m_variableToNodeMap[argument];
        inputNodes.push_back(argumentComputationNode);

@@ -1300,6 +1304,7 @@ namespace CNTK
list<ComputationNodeBasePtr> dropoutNodes = m_computationNetwork->GetNodesWithType(OperationNameOf(DropoutNode));
for (auto& nodeIter : dropoutNodes)
nodeIter->SetEvalTimeStampOutdatedWrtAll();

std::unordered_set<Variable> functionOutputs(this->Outputs().begin(), this->Outputs().end());
std::vector<ComputationNodeBasePtr> outputsToEvaluate;

@@ -1316,10 +1321,15 @@ namespace CNTK
// The 'outputsToRetainBackwardStateFor' nodes also need to be evaluated if not already specified in 'outputs'
for (auto rootVarForBackprop : outputsToRetainBackwardStateFor)
{
if (functionOutputs.find(rootVarForBackprop) == functionOutputs.end())
InvalidArgument("Requested outputs to retain backward state for is not an Output of the Function");

if (outputs.find(rootVarForBackprop) == outputs.end())
outputsToEvaluate.push_back(m_variableToNodeMap[rootVarForBackprop]);
}

// TODO: Verify that values were supplied for all inputs that requested outputs depend on

ScopedNetworkOperationMode modeGuard(m_computationNetwork, outputsToRetainBackwardStateFor.empty() ? NetworkOperationMode::inferring : NetworkOperationMode::training);

m_computationNetwork->ForwardProp(outputsToEvaluate);
@@ -1467,14 +1477,19 @@ namespace CNTK
}
FunctionPtr Slice(const Variable& operand, const Axis& axis, int beginIndex, int endIndex, const std::wstring& name /*= L""*/)
{
if ((endIndex - beginIndex) <= 0)
InvalidArgument("CNTK::Slice: endIndex (%d) - beginIndex (%d) must be a positive number", endIndex, beginIndex);

if (axis == Axis::DefaultBatchAxis())
LogicError("Slice is currently unsupported along the batch axis");

if (axis.IsStaticAxis())
{
if ((endIndex - beginIndex) <= 0)
InvalidArgument("CNTK::Slice: endIndex (%d) - beginIndex (%d) must be a positive number", endIndex, beginIndex);

return Internal::Slice(operand, axis, beginIndex, endIndex, name);
}

if ((beginIndex == 0) && (endIndex == 0))
return operand;

auto operandAxes = operand.DynamicAxes();
auto findAxis = std::find(operandAxes.begin(), operandAxes.end(), axis);

@@ -1504,7 +1519,7 @@ namespace CNTK
if (operandAxis == axis)
{
// If we are selecting just one frame from the dynamic axis, we can remove that axis
if ((endIndex - beginIndex) > 1)
if ((endIndex - beginIndex) != 1)
newDynamicAxes.push_back(CompositeFunction::NextAutoGeneratedDynamicAxis());
}
else
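Illustrative usage sketch (not part of this commit) of how the reworked Slice overload is expected to behave on static versus dynamic axes; the input variable and its dimension are assumptions for the example:

// Hypothetical example; 'features' is an assumed Input variable with the default dynamic axes.
Variable features({ 100 }, DataType::Float, L"features");
auto staticSlice = Slice(features, Axis(0), 0, 50);                      // static axis: routed through Internal::Slice
auto firstFrame  = Slice(features, Axis::DefaultDynamicAxis(), 0, 1);    // exactly one frame: the sequence axis is dropped
auto lastFrame   = Slice(features, Axis::DefaultDynamicAxis(), -1, 0);   // last frame, as used by Sequence::Last further below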
@@ -1746,6 +1761,12 @@ namespace CNTK
return CompositeFunction::Create(MakeSharedObject<PrimitiveFunction>(PrimitiveOpType::Clip, std::vector<Variable>({ operand, min, max }), Dictionary(), name), name);
}

FunctionPtr ElementSelect(const Variable& condition, const Variable& leftOperand, const Variable& rightOperand, const std::wstring& name /*= L""*/)
{
// TODO: If the condition is a scalar constant, we can just pass-through the appropriate operand
return CompositeFunction::Create(MakeSharedObject<PrimitiveFunction>(PrimitiveOpType::Select, std::vector<Variable>({ condition, leftOperand, rightOperand }), Dictionary(), name), name);
}

FunctionPtr Splice(const std::vector<Variable>& operands, size_t axis, const std::wstring& name /*= L""*/)
{
auto additionalProperties = Dictionary();
@@ -1753,25 +1774,116 @@ namespace CNTK

return CompositeFunction::Create(MakeSharedObject<PrimitiveFunction>(PrimitiveOpType::Splice, operands, std::move(additionalProperties), name), name);
}

FunctionPtr Combine(const std::vector<FunctionPtr>& operands, const std::wstring& name/* = L""*/)
{
std::unordered_set<FunctionPtr> uniqueOperands;
std::vector<Variable> inputs;
for (auto operand : operands)
{
if (uniqueOperands.find(operand) != uniqueOperands.end())
LogicError("All function operands specified to Combine must be unique");

uniqueOperands.insert(operand);
auto currentFunctionOutputs = operand->Outputs();
std::copy(currentFunctionOutputs.begin(), currentFunctionOutputs.end(), std::back_inserter(inputs));
}

return CompositeFunction::Create(MakeSharedObject<PrimitiveFunction>(PrimitiveOpType::Combine, inputs, Dictionary(), name), name);
return Internal::Combine(inputs);
}

namespace Sequence
{
void VerifyIsSequence(const Variable& operand)
{
// The operand must have at least one dynamic axis and its first dynamic axis must be ordered
if (operand.DynamicAxes().empty() || !operand.DynamicAxes()[0].IsOrdered())
InvalidArgument("A sequence function can only be applied on operands with at least one dynamic axis and whose first dynamic axis is ordered");
}

FunctionPtr IsFirst(const Variable& operand, const std::wstring& name /*= L""*/)
{
VerifyIsSequence(operand);
return Internal::IsWithin(operand, 1);
}

FunctionPtr IsLast(const Variable& operand, const std::wstring& name /*= L""*/)
{
VerifyIsSequence(operand);
return Internal::IsWithin(operand, -1);
}

FunctionPtr First(const Variable& operand, const std::wstring& name /*= L""*/)
{
VerifyIsSequence(operand);
return Slice(operand, operand.DynamicAxes()[0], 0, 1);
}

FunctionPtr Last(const Variable& operand, const std::wstring& name /*= L""*/)
{
VerifyIsSequence(operand);
return Slice(operand, operand.DynamicAxes()[0], -1, 0);
}

std::vector<Axis> WhereOpDynamicAxes(const Variable& operand)
{
VerifyIsSequence(operand);

std::vector<Axis> newDynamicAxes = { Axis::NewUniqueDynamicAxis(L"whereNodeDynamicAxis") };
for (size_t i = 1; i < operand.DynamicAxes().size(); ++i)
newDynamicAxes.push_back(operand.DynamicAxes()[i]);

return newDynamicAxes;
}

FunctionPtr Where(const Variable& condition, const std::wstring& name /*= L""*/)
{
return Internal::Where(condition, WhereOpDynamicAxes(condition), name);
}

FunctionPtr Gather(const Variable& operand, const Variable& condition, const std::wstring& name /*= L""*/)
{
return Internal::Gather(operand, condition, WhereOpDynamicAxes(condition), name);
}

FunctionPtr Scatter(const Variable& operand, const Variable& condition, const std::wstring& name /*= L""*/)
{
return Internal::Scatter(operand, condition, WhereOpDynamicAxes(condition), name);
}

FunctionPtr BroadcastAs(const Variable& operand, const Variable& broadcastAs, const std::wstring& name /*= L""*/)
{
auto dataPadded = Internal::Scatter(operand, Sequence::IsFirst(broadcastAs), broadcastAs.DynamicAxes());
auto placeHolderOutput = Placeholder(operand.Shape(), broadcastAs.DynamicAxes());
auto output = ElementSelect(Sequence::IsFirst(dataPadded), dataPadded, PastValue(placeHolderOutput, ScalarConstant(operand.GetDataType(), 0.0f), 1), name);
return output->ReplacePlaceholders({ { placeHolderOutput, output } });
}
}

namespace Internal
{
FunctionPtr Combine(const std::vector<Variable>& operands, const std::wstring& name /*= L""*/)
{
std::unordered_set<Variable> uniqueOperands;
for (auto operand : operands)
{
if (uniqueOperands.find(operand) != uniqueOperands.end())
LogicError("All operands specified to Combine must be unique");

uniqueOperands.insert(operand);
}

return CompositeFunction::Create(MakeSharedObject<PrimitiveFunction>(PrimitiveOpType::Combine, operands, Dictionary(), name), name);
}

FunctionPtr IsWithin(const Variable& operand, int offset, const std::wstring& name /*= L""*/)
{
Sequence::VerifyIsSequence(operand);

if (offset == 0)
InvalidArgument("CNTK::Sequence::IsWithin: The offset must be positive");

if (offset > 0)
return PastValue(Internal::ZeroesLike(operand), ScalarConstant(operand.GetDataType(), 1.0f), offset, name);
else
return FutureValue(Internal::ZeroesLike(operand), ScalarConstant(operand.GetDataType(), 1.0f), -offset, name);
}

FunctionPtr PackedIndex(const Variable& operand, const Variable& index, const std::wstring& name /*= L""*/)
{
return BinaryOp(PrimitiveOpType::PackedIndex, operand, index, Dictionary(), name);
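A hedged sketch of how the new Sequence operations are meant to compose; the variable and its shape are assumptions and this is not code taken from the commit:

// Hypothetical example; 'sequenceInput' is an assumed Input variable carrying the default sequence and batch axes.
Variable sequenceInput({ 300 }, DataType::Float, L"sequenceInput");
auto isFirst   = Sequence::IsFirst(sequenceInput);                 // 1 at the first step of each sequence, 0 elsewhere
auto firstStep = Sequence::First(sequenceInput);                   // Slice(..., 0, 1) along the sequence axis
auto selected  = Sequence::Gather(sequenceInput, isFirst);         // keeps only the steps where the condition is non-zero
auto repeated  = Sequence::BroadcastAs(firstStep, sequenceInput);  // replicates the per-sequence value across sequenceInput's steps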
|
@ -1782,34 +1894,36 @@ namespace CNTK
|
|||
return BinaryOp(PrimitiveOpType::GatherPacked, operand, packedIndex, Dictionary(), name);
|
||||
}
|
||||
|
||||
FunctionPtr ScatterPacked(const Variable& operand, const Variable& packedIndex, const Variable& condition, const std::wstring& name /*= L""*/)
|
||||
{
|
||||
return CompositeFunction::Create(MakeSharedObject<PrimitiveFunction>(PrimitiveOpType::ScatterPacked, std::vector<Variable>({ operand, packedIndex, condition }), Dictionary(), name), name);
|
||||
}
|
||||
|
||||
FunctionPtr ZeroesLike(const Variable& operand)
|
||||
{
|
||||
if (operand.Shape().NumAxes() > 1)
|
||||
LogicError("ZerosLike currently does not support operands with more than 1 static axes");
|
||||
LogicError("Internal::ZeroesLike: Currently only 1D inputs are supported!");
|
||||
|
||||
auto rowSliceFunc = Internal::Slice(operand, Axis(0), 0, 1);
|
||||
return Minus(rowSliceFunc, rowSliceFunc);
|
||||
}
|
||||
|
||||
FunctionPtr IsWithin(const Variable& operand, int offset, const std::wstring& name /*= L""*/)
|
||||
{
|
||||
if (offset == 0)
|
||||
InvalidArgument("Internal::CNTK::IsWithin: The offset must be positive");
|
||||
|
||||
if (offset > 0)
|
||||
return PastValue(ZeroesLike(operand), ScalarConstant(operand.GetDataType(), 1.0f), offset, name);
|
||||
if (operand.IsSparse())
|
||||
{
|
||||
if (operand.GetDataType() == DataType::Float)
|
||||
return Times(Constant({1, operand.Shape()[0]}, 0.0f), operand);
|
||||
else if (operand.GetDataType() == DataType::Double)
|
||||
return Times(Constant({ 1, operand.Shape()[0] }, 0.0), operand);
|
||||
else
|
||||
LogicError("Unsupported DataType %s", DataTypeName(operand.GetDataType()));
|
||||
}
|
||||
else
|
||||
return FutureValue(ZeroesLike(operand), ScalarConstant(operand.GetDataType(), 1.0f), -offset, name);
|
||||
{
|
||||
auto rowSliceFunc = Internal::Slice(operand, Axis(0), 0, 1);
|
||||
return Minus(rowSliceFunc, rowSliceFunc);
|
||||
}
|
||||
}
|
||||
|
||||
FunctionPtr Where(const Variable& condition, const std::vector<Axis>& newDynamicAxes, const std::wstring& name /*= L""*/)
|
||||
{
|
||||
auto additionalProperties = Dictionary();
|
||||
std::vector<std::wstring> newDynamicAxesNames;
|
||||
for (auto axis : newDynamicAxes)
|
||||
newDynamicAxesNames.push_back(axis.Name());
|
||||
|
||||
additionalProperties[PrimitiveFunction::AttributeNameNewDynamicAxes] = AsDictionaryValueVector(newDynamicAxesNames);
|
||||
additionalProperties[PrimitiveFunction::AttributeNameNewDynamicAxes] = AsDictionaryValueVector(newDynamicAxes);
|
||||
return UnaryOp(PrimitiveOpType::Where, condition, std::move(additionalProperties), name);
|
||||
}
|
||||
|
||||
|
@ -1818,6 +1932,11 @@ namespace CNTK
|
|||
return Internal::GatherPacked(operand, Internal::PackedIndex(operand, Where(condition, newDynamicAxes)));
|
||||
}
|
||||
|
||||
FunctionPtr Scatter(const Variable& operand, const Variable& condition, const std::vector<Axis>& newDynamicAxes, const std::wstring& name /*= L""*/)
|
||||
{
|
||||
return Internal::ScatterPacked(operand, Internal::PackedIndex(operand, Where(condition, newDynamicAxes)), condition);
|
||||
}
|
||||
|
||||
FunctionPtr Slice(const Variable& operand, const Axis& axis, int beginIndex, int endIndex, const std::wstring& name /*= L""*/)
|
||||
{
|
||||
auto additionalProperties = Dictionary();
|
||||
|
|
|
@@ -46,6 +46,7 @@ namespace CNTK
GreaterEqual,
PackedIndex,
GatherPacked,
ScatterPacked,
Times,
TransposeTimes,
Convolution,

@@ -57,6 +58,7 @@ namespace CNTK
ReduceElements,
BatchNormalization,
Clip,
Select,
Splice,
Combine,
};

@@ -77,7 +79,7 @@ namespace CNTK
{
inline const char* PrimitiveOpTypeName(PrimitiveOpType opType)
{
static std::unordered_map<PrimitiveOpType, const char*> primitiveOpNames = {
static const std::unordered_map<PrimitiveOpType, const char*> primitiveOpNames = {
{ PrimitiveOpType::Negate, "Negate" },
{ PrimitiveOpType::Sigmoid, "Sigmoid" },
{ PrimitiveOpType::Tanh, "Tanh" },

@@ -108,6 +110,7 @@ namespace CNTK
{ PrimitiveOpType::GreaterEqual, "GreaterEqual" },
{ PrimitiveOpType::PackedIndex, "PackedIndex" },
{ PrimitiveOpType::GatherPacked, "GatherPacked" },
{ PrimitiveOpType::ScatterPacked, "ScatterPacked" },
{ PrimitiveOpType::Times, "Times" },
{ PrimitiveOpType::TransposeTimes, "TransposeTimes" },
{ PrimitiveOpType::Convolution, "Convolution" },

@@ -119,6 +122,7 @@ namespace CNTK
{ PrimitiveOpType::ReduceElements, "ReduceElements" },
{ PrimitiveOpType::BatchNormalization, "BatchNormalization" },
{ PrimitiveOpType::Clip, "Clip" },
{ PrimitiveOpType::Select, "Select" },
{ PrimitiveOpType::Splice, "Splice" },
{ PrimitiveOpType::Combine, "Combine" }
};

@@ -288,9 +292,9 @@ namespace CNTK
{
if ((leftOperandShape[i] == NDShape::InferredDimension) && (rightOperandShape[i] == NDShape::InferredDimension))
outputDims[i] = NDShape::InferredDimension;
else if (leftOperandShape[i] == NDShape::InferredDimension)
else if ((leftOperandShape[i] == NDShape::InferredDimension) || (leftOperandShape[i] == 1))
outputDims[i] = rightOperandShape[i];
else if (rightOperandShape[i] == NDShape::InferredDimension)
else if ((rightOperandShape[i] == NDShape::InferredDimension) || (rightOperandShape[i] == 1))
outputDims[i] = leftOperandShape[i];
else
{
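For clarity, a standalone sketch of the broadcast rule this change introduces for a single axis (a dimension of 1 now broadcasts like an inferred dimension); this is an illustrative helper under assumed values, not the library code:

#include <cstddef>
#include <stdexcept>

// Stand-in for NDShape::InferredDimension (an assumption for this sketch).
static const size_t InferredDim = static_cast<size_t>(-1);

size_t BroadcastDim(size_t left, size_t right)
{
    if ((left == InferredDim) && (right == InferredDim)) return InferredDim;
    if ((left == InferredDim) || (left == 1))  return right;   // a 1 broadcasts up to the other operand's dimension
    if ((right == InferredDim) || (right == 1)) return left;
    if (left != right) throw std::runtime_error("Incompatible dimensions for elementwise op");
    return left;
}

// BroadcastDim(3, 1) == 3, BroadcastDim(1, 4) == 4, BroadcastDim(3, 4) throws.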
@@ -308,6 +312,18 @@ namespace CNTK
return NDShape(std::move(outputDims));
}

static NDShape NaryElementwiseOpOutputShape(PrimitiveOpType op, const std::vector<NDShape>& operandShapes, bool broadcastAllowed = true)
{
assert(!operandShapes.empty());

// TODO: Is this logic of transitively constructing the output shape from the operands correct?
NDShape outputShape = {};
for (auto& operandShape : operandShapes)
outputShape = BinaryElementwiseOpOutputShape(op, outputShape, operandShape, broadcastAllowed);

return outputShape;
}

static NDShape TimesOpOutputShape(const NDShape& leftOperandShape, const NDShape& rightOperandShape, size_t outputRank)
{
if (outputRank == 0)

@@ -362,7 +378,7 @@ namespace CNTK
else
computeOutputShapeFunc = &Microsoft::MSR::CNTK::ConvolveGeometry::ComputeInputShape;

return AsNDShape(computeOutputShapeFunc(AsTensorShape(operandShape, true), AsTensorShape(kernelShape, true), AsTensorShape(outputMapCount, true), AsTensorShape(strides, true), sharing, autoPad, AsTensorShape(lowerPad, true), AsTensorShape(upperPad, true)));
return AsNDShape(computeOutputShapeFunc(AsTensorShape(operandShape), AsTensorShape(kernelShape), AsTensorShape(outputMapCount), AsTensorShape(strides), sharing, autoPad, AsTensorShape(lowerPad), AsTensorShape(upperPad)));
}

// TODO: Reconcile this with the ComputationNode::Validate functionality in core CNTK to avoid duplication of inference logic

@@ -533,9 +549,9 @@ namespace CNTK
inline std::vector<CNTK::Axis> DynamicAxesFromInternalDynamicAxisName(const std::wstring& internalDynamicAxisName)
{
std::vector<CNTK::Axis> inputVarDynamicAxes;
if (internalDynamicAxisName == CNTK::CompositeFunction::InternalDefaultDynamicAxisName)
if (internalDynamicAxisName.substr(0, CNTK::CompositeFunction::InternalDefaultDynamicAxisName.length()) == CNTK::CompositeFunction::InternalDefaultDynamicAxisName)
inputVarDynamicAxes = { CNTK::Axis::DefaultDynamicAxis(), CNTK::Axis::DefaultBatchAxis() };
else if (internalDynamicAxisName == CNTK::CompositeFunction::InternalNoSequenceAxisName)
else if (internalDynamicAxisName.substr(0, CNTK::CompositeFunction::InternalNoSequenceAxisName.length()) == CNTK::CompositeFunction::InternalNoSequenceAxisName)
inputVarDynamicAxes = { CNTK::Axis::DefaultBatchAxis() };
else
inputVarDynamicAxes = { CNTK::Axis(internalDynamicAxisName), CNTK::Axis::DefaultBatchAxis() };
@@ -3,6 +3,7 @@
// Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
//

#include "stdafx.h"
#include "Learner.h"
#include "TensorView.h"
#include "Utils.h"

@@ -155,12 +156,17 @@ namespace CNTK

LearnerBase::LearnerBase(const unordered_set<Parameter>& parameters,
const LearningRatesPerSample& learningRates,
bool allocateSmoothGradients /* = true */)
bool allocateSmoothGradients /* = true */,
double clippingThresholdPerSample /*= std::numeric_limits<double>::infinity()*/,
bool gradientClippingWithTruncation /*= true*/)
: Learner(parameters),
m_learningRates(learningRates),
m_sampleCount(0),
m_minibatchCount(0)
{
m_additionalOptions.gradientClippingThresholdPerSample = clippingThresholdPerSample;
m_additionalOptions.gradientClippingWithTruncation = gradientClippingWithTruncation;

for (const auto& parameter : parameters)
{
if (!allocateSmoothGradients)
@@ -356,8 +362,10 @@ namespace CNTK

LearnerAdaGrad::LearnerAdaGrad(const unordered_set<Parameter>& parameters,
const LearningRatesPerSample& learningRates,
bool needAveMultiplier)
: LearnerBase(parameters, learningRates),
bool needAveMultiplier,
double clippingThresholdPerSample /*= std::numeric_limits<double>::infinity()*/,
bool gradientClippingWithTruncation /*= true*/)
: LearnerBase(parameters, learningRates, true, clippingThresholdPerSample, gradientClippingWithTruncation),
m_needAveMultiplier(needAveMultiplier)
{
}

@@ -385,8 +393,10 @@ namespace CNTK

LearnerFSAdaGrad::LearnerFSAdaGrad(const unordered_set<Parameter>& parameters,
const LearningRatesPerSample& learningRates,
const MomentumsPerSample& momentums)
: LearnerMomentumSGD(parameters, learningRates, momentums, /*allocateSmoothGradients*/ false)
const MomentumsPerSample& momentums,
double clippingThresholdPerSample /*= std::numeric_limits<double>::infinity()*/,
bool gradientClippingWithTruncation /*= true*/)
: LearnerMomentumSGD(parameters, learningRates, momentums, /*allocateSmoothGradients*/ false, clippingThresholdPerSample, gradientClippingWithTruncation)
{
for (const auto& parameter : parameters)
{

@@ -417,10 +427,11 @@ namespace CNTK
}

LearnerRMSProp::LearnerRMSProp(const unordered_set<Parameter>& parameters, const LearningRatesPerSample& learningRates,
double gamma, double inc, double dec, double max, double min, bool needAveMultiplier)
: LearnerBase(parameters, learningRates, /*allocateSmoothGradients*/ false),
m_gamma(gamma), m_inc(inc), m_dec(dec), m_max(max), m_min(min),
m_needAveMultiplier(needAveMultiplier)
double gamma, double inc, double dec, double max, double min, bool needAveMultiplier,
double clippingThresholdPerSample /*= std::numeric_limits<double>::infinity()*/,
bool gradientClippingWithTruncation /*= true*/)
: LearnerBase(parameters, learningRates, /*allocateSmoothGradients*/ false, clippingThresholdPerSample, gradientClippingWithTruncation),
m_gamma(gamma), m_inc(inc), m_dec(dec), m_max(max), m_min(min), m_needAveMultiplier(needAveMultiplier)
{
for (const auto& parameter : parameters)
{
@@ -467,35 +478,56 @@ namespace CNTK
template shared_ptr<Matrix<float>> LearnerBase::GetWritableMatrix<float>(const NDArrayViewPtr& arrayView);
template shared_ptr<Matrix<double>> LearnerBase::GetWritableMatrix<double>(const NDArrayViewPtr& arrayView);

LearnerPtr SGDLearner(const unordered_set<Parameter>& parameters, const LearningRatesPerSample& learningRates)
LearnerPtr SGDLearner(const unordered_set<Parameter>& parameters,
const LearningRatesPerSample& learningRates,
double clippingThresholdPerSample /*= std::numeric_limits<double>::infinity()*/,
bool gradientClippingWithTruncation /*= true*/)
{
return MakeSharedObject<LearnerSGD>(parameters, learningRates);
return MakeSharedObject<LearnerSGD>(parameters, learningRates, true, clippingThresholdPerSample, gradientClippingWithTruncation);
}

LearnerPtr MomentumSGDLearner(const unordered_set<Parameter>& parameters, const LearningRatesPerSample& learningRates, const MomentumsPerSample& momentums)
LearnerPtr MomentumSGDLearner(const unordered_set<Parameter>& parameters,
const LearningRatesPerSample& learningRates,
const MomentumsPerSample& momentums,
double clippingThresholdPerSample /*= std::numeric_limits<double>::infinity()*/,
bool gradientClippingWithTruncation /*= true*/)
{
return MakeSharedObject<LearnerMomentumSGD>(parameters, learningRates, momentums);
return MakeSharedObject<LearnerMomentumSGD>(parameters, learningRates, momentums, true, clippingThresholdPerSample, gradientClippingWithTruncation);
}

LearnerPtr NesterovLearner(const unordered_set<Parameter>& parameters, const LearningRatesPerSample& learningRates, const MomentumsPerSample& momentums)
LearnerPtr NesterovLearner(const unordered_set<Parameter>& parameters,
const LearningRatesPerSample& learningRates,
const MomentumsPerSample& momentums,
double clippingThresholdPerSample /*= std::numeric_limits<double>::infinity()*/,
bool gradientClippingWithTruncation /*= true*/)
{
return MakeSharedObject<LearnerNesterov>(parameters, learningRates, momentums);
return MakeSharedObject<LearnerNesterov>(parameters, learningRates, momentums, clippingThresholdPerSample, gradientClippingWithTruncation);
}

LearnerPtr AdaGradLearner(const unordered_set<Parameter>& parameters, const LearningRatesPerSample& learningRates, bool needAveMultiplier)
LearnerPtr FSAdaGradLearner(const unordered_set<Parameter>& parameters,
const LearningRatesPerSample& learningRates,
const MomentumsPerSample& momentums,
double clippingThresholdPerSample /*= std::numeric_limits<double>::infinity()*/,
bool gradientClippingWithTruncation /*= true*/)
{
return MakeSharedObject<LearnerAdaGrad>(parameters, learningRates, needAveMultiplier);
return MakeSharedObject<LearnerFSAdaGrad>(parameters, learningRates, momentums, clippingThresholdPerSample, gradientClippingWithTruncation);
}

LearnerPtr FSAdaGradLearner(const unordered_set<Parameter>& parameters, const LearningRatesPerSample& learningRates, const MomentumsPerSample& momentums)
LearnerPtr AdaGradLearner(const unordered_set<Parameter>& parameters,
const LearningRatesPerSample& learningRates,
bool needAveMultiplier /*= true*/,
double clippingThresholdPerSample /*= std::numeric_limits<double>::infinity()*/,
bool gradientClippingWithTruncation /*= true*/)
{
return MakeSharedObject<LearnerFSAdaGrad>(parameters, learningRates, momentums);
return MakeSharedObject<LearnerAdaGrad>(parameters, learningRates, needAveMultiplier, clippingThresholdPerSample, gradientClippingWithTruncation);
}

LearnerPtr RMSPropLearner(const unordered_set<Parameter>& parameters, const LearningRatesPerSample& learningRates,
double gamma, double inc, double dec, double max, double min,
bool needAveMultiplier)
bool needAveMultiplier /*= true*/,
double clippingThresholdPerSample /*= std::numeric_limits<double>::infinity()*/,
bool gradientClippingWithTruncation /*= true*/)
{
return MakeSharedObject<LearnerRMSProp>(parameters, learningRates, gamma, inc, dec, max, min, needAveMultiplier);
return MakeSharedObject<LearnerRMSProp>(parameters, learningRates, gamma, inc, dec, max, min, needAveMultiplier, clippingThresholdPerSample, gradientClippingWithTruncation);
}
}
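A usage sketch of the widened learner factory signatures; the model variable and the numeric values below are assumptions for the example, not values from the commit:

// Hypothetical example; 'model' is an assumed, already constructed FunctionPtr.
double learningRatePerSample = 0.0078125;
double clippingThresholdPerSample = 2.3;     // default is std::numeric_limits<double>::infinity() (no clipping)
bool gradientClippingWithTruncation = true;

auto sgd = SGDLearner(model->Parameters(), learningRatePerSample,
                      clippingThresholdPerSample, gradientClippingWithTruncation);
auto momentumSgd = MomentumSGDLearner(model->Parameters(), learningRatePerSample, /*momentums=*/ 0.9,
                                      clippingThresholdPerSample, gradientClippingWithTruncation);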
@@ -36,7 +36,9 @@ namespace CNTK
protected:
LearnerBase(const std::unordered_set<Parameter>& parameters,
const LearningRatesPerSample& learningRates,
bool allocateSmoothGradients = true);
bool allocateSmoothGradients = true,
double clippingThresholdPerSample = std::numeric_limits<double>::infinity(),
bool gradientClippingWithTruncation = true);

virtual void Update(const Parameter& parameter, const NDArrayViewPtr& gradientValue, const NDArrayViewPtr& smoothedGradientValue, size_t trainingSampleCount) const = 0;

@@ -104,11 +106,13 @@ namespace CNTK
public:
LearnerSGD(const std::unordered_set<Parameter>& parameters,
const LearningRatesPerSample& learningRates,
bool allocateSmoothGradients = true)
: LearnerBase(parameters, learningRates, allocateSmoothGradients),
bool allocateSmoothGradients = true,
double clippingThresholdPerSample = std::numeric_limits<double>::infinity(),
bool gradientClippingWithTruncation = true)
: LearnerBase(parameters, learningRates, allocateSmoothGradients, clippingThresholdPerSample, gradientClippingWithTruncation),
m_momentums(0.0),
m_useNesterovAcceleration(false)
{ }
{}

protected:

@@ -129,8 +133,10 @@ namespace CNTK
LearnerMomentumSGD(const std::unordered_set<Parameter>& parameters,
const LearningRatesPerSample& learningRates,
const MomentumsPerSample& momentums,
bool allocateSmoothGradients = true)
: LearnerSGD(parameters, learningRates, allocateSmoothGradients)
bool allocateSmoothGradients = true,
double clippingThresholdPerSample = std::numeric_limits<double>::infinity(),
bool gradientClippingWithTruncation = true)
: LearnerSGD(parameters, learningRates, allocateSmoothGradients, clippingThresholdPerSample, gradientClippingWithTruncation)
{
m_momentums = momentums;
}

@@ -143,8 +149,10 @@ namespace CNTK

LearnerNesterov(const std::unordered_set<Parameter>& parameters,
const LearningRatesPerSample& learningRates,
const MomentumsPerSample& momentums)
: LearnerMomentumSGD(parameters, learningRates, momentums)
const MomentumsPerSample& momentums,
double clippingThresholdPerSample = std::numeric_limits<double>::infinity(),
bool gradientClippingWithTruncation = true)
: LearnerMomentumSGD(parameters, learningRates, momentums, true, clippingThresholdPerSample, gradientClippingWithTruncation)
{
m_useNesterovAcceleration = true;
}

@@ -156,7 +164,9 @@ namespace CNTK

LearnerAdaGrad(const std::unordered_set<Parameter>& parameters,
const LearningRatesPerSample& learningRates,
bool needAveMultiplier);
bool needAveMultiplier,
double clippingThresholdPerSample = std::numeric_limits<double>::infinity(),
bool gradientClippingWithTruncation = true);

protected:
bool m_needAveMultiplier;

@@ -173,7 +183,9 @@ namespace CNTK

LearnerFSAdaGrad(const std::unordered_set<Parameter>& parameters,
const LearningRatesPerSample& learningRates,
const MomentumsPerSample& momentums);
const MomentumsPerSample& momentums,
double clippingThresholdPerSample = std::numeric_limits<double>::infinity(),
bool gradientClippingWithTruncation = true);

protected:

@@ -190,7 +202,9 @@ namespace CNTK
LearnerRMSProp(const std::unordered_set<Parameter>& parameters,
const LearningRatesPerSample& learningRates,
double gamma, double inc, double dec, double max, double min,
bool needAveMultiplier);
bool needAveMultiplier,
double clippingThresholdPerSample = std::numeric_limits<double>::infinity(),
bool gradientClippingWithTruncation = true);

protected:
@@ -29,7 +29,7 @@ namespace CNTK

auto matrixDims = GetMatrixDimensions(viewShape);
std::shared_ptr<Matrix<ElementType>> matrix = std::make_shared<Matrix<ElementType>>(matrixDims.first, matrixDims.second, (ElementType*)dataBuffer, AsCNTKImplDeviceId(device), matrixFlagDontOwnBuffer);
return new TensorView<ElementType>(matrix, AsTensorShape(viewShape));
return new TensorView<ElementType>(matrix, AsTensorViewShape(viewShape));
}

static void* AllocateTensorView(CNTK::DataType dataType,

@@ -61,7 +61,7 @@ namespace CNTK
AsCNTKImplDeviceId(device),
IsSparseStorageFormat(storageType) ? MatrixType::SPARSE : MatrixType::DENSE,
AsCNTKImplMatrixFormat(storageType));
return new TensorView<ElementType>(matrix, AsTensorShape(viewShape));
return new TensorView<ElementType>(matrix, AsTensorViewShape(viewShape));
}

static void* AllocateTensorView(CNTK::DataType dataType,

@@ -320,7 +320,7 @@ namespace CNTK
{
auto matrixDims = GetMatrixDimensions(shape);
auto randomNormalMatrix = std::make_shared<Matrix<ElementType>>(Matrix<ElementType>::RandomGaussian(matrixDims.first, matrixDims.second, AsCNTKImplDeviceId(device), (ElementType)mean, (ElementType)stdDev, seed));
auto tensorView = new TensorView<ElementType>(randomNormalMatrix, AsTensorShape(shape));
auto tensorView = new TensorView<ElementType>(randomNormalMatrix, AsTensorViewShape(shape));

return MakeSharedObject<NDArrayView>(AsDataType<ElementType>(), device, StorageFormat::Dense, shape, false, tensorView);
}

@@ -330,7 +330,7 @@ namespace CNTK
{
auto matrixDims = GetMatrixDimensions(shape);
auto randomUniformMatrix = std::make_shared<Matrix<ElementType>>(Matrix<ElementType>::RandomUniform(matrixDims.first, matrixDims.second, AsCNTKImplDeviceId(device), (ElementType)rangeBegin, (ElementType)rangeEnd, seed));
auto tensorView = new TensorView<ElementType>(randomUniformMatrix, AsTensorShape(shape));
auto tensorView = new TensorView<ElementType>(randomUniformMatrix, AsTensorViewShape(shape));

return MakeSharedObject<NDArrayView>(AsDataType<ElementType>(), device, StorageFormat::Dense, shape, false, tensorView);
}
@@ -9,10 +9,12 @@

namespace CNTK
{
Trainer::Trainer(const FunctionPtr& model, const Variable& trainingLoss, const std::unordered_set<LearnerPtr>& parameterLearners)
: m_model(model), m_trainingLossVar(trainingLoss), m_parameterLearners(parameterLearners), m_prevMinibatchNumSamples(1)
Trainer::Trainer(const FunctionPtr& model, const FunctionPtr& lossFunction, const FunctionPtr& evaluationFunction, const std::unordered_set<LearnerPtr>& parameterLearners)
: m_model(model), m_lossFunction(lossFunction), m_evaluationFunction(evaluationFunction), m_parameterLearners(parameterLearners), m_prevMinibatchNumSamples(1)
{
auto modelParameters = model->Parameters();
m_combinedTrainingFunction = Combine({ model, lossFunction, evaluationFunction });

auto modelParameters = m_combinedTrainingFunction->Parameters();
std::unordered_set<Parameter> learnerParameters;
for (const auto& learner : parameterLearners)
{
@@ -29,36 +31,97 @@ namespace CNTK
InvalidArgument("Trainer ctor: Union of the parameters covered by the specified parameterLearners should match the specified model's parameters");
}

Trainer::Trainer(const FunctionPtr& model, const FunctionPtr& lossFunction, const std::unordered_set<LearnerPtr>& parameterLearners)
: Trainer(model, lossFunction, nullptr, parameterLearners)
{}

static double GetScalarValue(const ValuePtr& value)
{
if (value->Mask())
LogicError("Scalar Value object cannot have an associated mask");

auto scalarData = value->Data();
if (scalarData->Shape().TotalSize() != 1)
LogicError("Scalar Value object has a size > 1");

double scalar = std::numeric_limits<double>::quiet_NaN();
NDArrayViewPtr cpuData;
if (scalarData->Device() == DeviceDescriptor::CPUDevice())
cpuData = scalarData;
else
{
cpuData = std::make_shared<NDArrayView>(scalarData->GetDataType(), scalarData->Shape(), CNTK::DeviceDescriptor::CPUDevice());
cpuData->CopyFrom(*scalarData);
}

if (scalarData->GetDataType() == DataType::Float)
scalar = *(cpuData->DataBuffer<float>());
else if (scalarData->GetDataType() == DataType::Double)
scalar = *(cpuData->DataBuffer<double>());
else
LogicError("Unsupported DataType of training loss value");

return scalar;
}

static size_t GetSampleCountFromArguments(const Variable& evalOrLossArgument, const std::unordered_map<Variable, ValuePtr>& arguments)
{
// Find the argument whose dynamic axes match the criterion operation's dynamic axes (i.e. label dynamic axes)
// Then we determine the actual number of samples contributing to the training loss from the argument's Value object
auto argumentIter = std::find_if(arguments.begin(), arguments.end(), [evalOrLossArgument](const std::pair<Variable, ValuePtr>& currentPair) {
return (currentPair.first.DynamicAxes() == evalOrLossArgument.DynamicAxes());
});

auto argumentValue = argumentIter->second;
auto argumentVar = argumentIter->first;
auto argumentDataShape = argumentValue->Data()->Shape();
auto mask = argumentValue->Mask();
size_t numMaskedSamples = (mask != nullptr) ? mask->MaskedCount() : 0;
size_t numSamplesInDataArrayView = argumentDataShape.SubShape(argumentVar.Shape().NumAxes()).TotalSize();
if (numMaskedSamples > numSamplesInDataArrayView)
LogicError("Number of masked values cannot exceed the number of samples that the Value object's Data NDArrayView can hold");

return (numSamplesInDataArrayView - numMaskedSamples);
}
double Trainer::TestMinbatch(const std::unordered_map<Variable, ValuePtr>& arguments, const DeviceDescriptor& computeDevice /*= DeviceDescriptor::UseDefaultDevice()*/)
{
if (!m_evaluationFunction)
InvalidArgument("Trainer::TestMinbatch: Cannot test when no evaluation function was specified during 'this' trainer's construction");

// TODO: Should we refactor this code that is somewhat similar to the prologue of the TrainMinibatch function
std::unordered_map<Variable, ValuePtr> outputs = { { m_evaluationFunction, nullptr } };
m_combinedTrainingFunction->Forward(arguments, outputs, computeDevice);

auto sampleCount = GetSampleCountFromArguments(*(m_evaluationFunction->Arguments().begin()), arguments);
return (GetScalarValue(outputs[m_evaluationFunction]) / sampleCount);
}
bool Trainer::TrainMinibatch(const std::unordered_map<Variable, ValuePtr>& arguments, const DeviceDescriptor& computeDevice /*= DeviceDescriptor::UseDefaultDevice()*/)
{
std::unordered_map<Variable, ValuePtr> outputs = { { m_trainingLossVar, nullptr } };
auto backPropSate = m_model->Forward(arguments, outputs, computeDevice, { m_trainingLossVar });
m_prevMinibatchTrainingLossValue = outputs.begin()->second;
std::unordered_map<Variable, ValuePtr> outputs = { { m_lossFunction, nullptr } };
if (m_evaluationFunction)
outputs.insert({ m_evaluationFunction, nullptr });

ValuePtr rootGradientValue = MakeSharedObject<Value>(MakeSharedObject<NDArrayView>(m_trainingLossVar.GetDataType(), outputs.at(m_trainingLossVar)->Data()->Shape(), computeDevice), outputs.at(m_trainingLossVar)->Mask());
if (m_trainingLossVar.GetDataType() == DataType::Float)
auto backPropSate = m_combinedTrainingFunction->Forward(arguments, outputs, computeDevice, { m_lossFunction });
m_prevMinibatchAggregateTrainingLossValue = outputs[m_lossFunction];
if (m_evaluationFunction)
m_prevMinibatchAggregateEvalCriterionValue = outputs[m_evaluationFunction];

ValuePtr rootGradientValue = MakeSharedObject<Value>(MakeSharedObject<NDArrayView>(m_lossFunction->Output().GetDataType(), m_prevMinibatchAggregateTrainingLossValue->Data()->Shape(), computeDevice), outputs.at(m_lossFunction)->Mask());
if (m_lossFunction->Output().GetDataType() == DataType::Float)
rootGradientValue->Data()->SetValue(1.0f);
else
rootGradientValue->Data()->SetValue(1.0);

auto modelParameters = m_model->Parameters();
auto modelParameters = m_combinedTrainingFunction->Parameters();
std::unordered_map<Variable, ValuePtr> parameterGradients;
for (const auto& parameter : modelParameters)
parameterGradients[parameter] = nullptr;

m_model->Backward(backPropSate, { { m_trainingLossVar, rootGradientValue } }, parameterGradients);
m_combinedTrainingFunction->Backward(backPropSate, { { m_lossFunction, rootGradientValue } }, parameterGradients);

auto trainingLossArgument = *(m_trainingLossVar.Owner()->Arguments().begin());

// Find the argument whose dynamic axes match the criterion operation's dynamic axes (i.e. label dynamic axes)
// Then we determine the actual number of samples contributing to the training loss from the argument's Value object
auto argumentValue = std::find_if(arguments.begin(), arguments.end(), [trainingLossArgument](const std::pair<Variable, ValuePtr>& currentPair) {
return (currentPair.first.DynamicAxes() == trainingLossArgument.DynamicAxes());
})->second;
auto argumentData = argumentValue->Data();
auto argumentDataShape = argumentData->Shape();
auto mask = argumentValue->Mask();
m_prevMinibatchNumSamples = argumentDataShape[argumentDataShape.NumAxes() - 1] - ((mask != nullptr) ? mask->MaskedCount() : 0);
m_prevMinibatchNumSamples = GetSampleCountFromArguments(*(m_lossFunction->Arguments().begin()), arguments);

bool anyUpdatesPerformed = false;
for (auto learner : m_parameterLearners)
@@ -79,27 +142,16 @@ namespace CNTK
return anyUpdatesPerformed;
}

double Trainer::PreviousMinibatchAverageTrainingLoss() const
double Trainer::PreviousMinibatchLossAverage() const
{
double trainLossValue = std::numeric_limits<double>::quiet_NaN();
auto prevMBTrainingLossValue = PreviousMinibatchTrainingLossValue()->Data();
return (GetScalarValue(m_prevMinibatchAggregateTrainingLossValue) / m_prevMinibatchNumSamples);
}

NDArrayViewPtr cpuTrainLossValue;
if (prevMBTrainingLossValue->Device() == DeviceDescriptor::CPUDevice())
cpuTrainLossValue = prevMBTrainingLossValue;
else
{
cpuTrainLossValue = std::make_shared<NDArrayView>(prevMBTrainingLossValue->GetDataType(), prevMBTrainingLossValue->Shape(), CNTK::DeviceDescriptor::CPUDevice());
cpuTrainLossValue->CopyFrom(*prevMBTrainingLossValue);
}
double Trainer::PreviousMinibatchEvaluationAverage() const
{
if (!m_evaluationFunction)
InvalidArgument("Trainer::PreviousMinibatchEvaluationAverage: Cannot get evaluation criterion value when no evaluation function was specified during 'this' trainer's construction");

if (prevMBTrainingLossValue->GetDataType() == DataType::Float)
trainLossValue = *(cpuTrainLossValue->DataBuffer<float>());
else if (prevMBTrainingLossValue->GetDataType() == DataType::Double)
trainLossValue = *(cpuTrainLossValue->DataBuffer<double>());
else
LogicError("Unsupported DataType of training loss value");

return (trainLossValue / m_prevMinibatchNumSamples);
return (GetScalarValue(m_prevMinibatchAggregateEvalCriterionValue) / m_prevMinibatchNumSamples);
}
}
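A usage sketch of the reworked Trainer (mirroring the CifarResNet test further below); the data Variables and Value objects are assumptions for the example:

// Hypothetical example; 'classifierOutput', 'trainingLoss', 'prediction', 'features', 'labels',
// 'featuresValue' and 'labelsValue' are assumed to exist.
Trainer trainer(classifierOutput, trainingLoss, prediction,
                { SGDLearner(classifierOutput->Parameters(), 0.0078125) });

std::unordered_map<Variable, ValuePtr> arguments = { { features, featuresValue }, { labels, labelsValue } };
trainer.TrainMinibatch(arguments, DeviceDescriptor::CPUDevice());

double avgLoss = trainer.PreviousMinibatchLossAverage();        // aggregate loss divided by the number of samples
double avgEval = trainer.PreviousMinibatchEvaluationAverage();  // requires an evaluation function at construction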
@@ -601,12 +601,14 @@ namespace CNTK
}

template void DictionaryValue::AllocateDataPtr<NDShape>(const NDShape& value);
template void DictionaryValue::AllocateDataPtr<Axis>(const Axis& value);
template void DictionaryValue::AllocateDataPtr<vector<DictionaryValue>>(const vector<DictionaryValue>& value);
template void DictionaryValue::AllocateDataPtr<wstring>(const wstring& value);
template void DictionaryValue::AllocateDataPtr<Dictionary>(const Dictionary& value);
template void DictionaryValue::AllocateDataPtr<NDArrayView>(const NDArrayView& value);

template void DictionaryValue::FreePtrAsType<NDShape>();
template void DictionaryValue::FreePtrAsType<Axis>();
template void DictionaryValue::FreePtrAsType<vector<DictionaryValue>>();
template void DictionaryValue::FreePtrAsType<wstring>();
template void DictionaryValue::FreePtrAsType<Dictionary>();
@@ -119,14 +119,14 @@ namespace CNTK
}
}

inline Microsoft::MSR::CNTK::TensorShape AsTensorShape(const NDShape& viewShape, bool preserveRank = false)
inline Microsoft::MSR::CNTK::TensorShape AsTensorShape(const NDShape& viewShape)
{
const size_t maxNumAxesSupportedByTensorView = 12;
if (viewShape.NumAxes() > maxNumAxesSupportedByTensorView)
LogicError("The number of requested axes exceeds the currently supported limit");

// TensorShape is required to be at least 2D
size_t minRankSize = preserveRank ? viewShape.NumAxes() : 2;
// TensorShape is required to be at least 1D
size_t minRankSize = 1;
Microsoft::MSR::CNTK::SmallVector<size_t> tensorViewShape(std::max<size_t>(minRankSize, viewShape.NumAxes()));
for (size_t i = 0; i < tensorViewShape.size(); ++i)
tensorViewShape[i] = (i < viewShape.NumAxes()) ? viewShape[i] : 1;

@@ -134,6 +134,17 @@ namespace CNTK
return tensorViewShape;
}

inline Microsoft::MSR::CNTK::TensorShape AsTensorViewShape(const Microsoft::MSR::CNTK::TensorShape& viewShape)
{
// For TensorView shapes we pad the TensorShape to be at least rank 2
return viewShape.PadRank(std::max<size_t>(2, viewShape.GetRank()));
}

inline Microsoft::MSR::CNTK::TensorShape AsTensorViewShape(const NDShape& viewShape)
{
return AsTensorViewShape(AsTensorShape(viewShape));
}

inline std::string AsString(const NDShape& shape)
{
std::string shapeString = "[";

@@ -242,35 +253,39 @@ namespace CNTK
}

template <typename T>
inline std::vector<DictionaryValue> AsDictionaryValueVector(const std::vector<T>& basicElementTypeVector)
inline std::vector<DictionaryValue> AsDictionaryValueVector(const std::vector<T>& elementVector)
{
static_assert(std::is_same<T, bool>::value ||
std::is_same<T, size_t>::value ||
std::is_same<T, float>::value ||
std::is_same<T, double>::value ||
std::is_same<T, std::wstring>::value, "Unsupported ValueType");
std::is_same<T, Axis>::value ||
std::is_same<T, std::wstring>::value,
"Unsupported ValueType");

std::vector<DictionaryValue> dictionaryValueVector;
for (auto value : basicElementTypeVector)
for (auto value : elementVector)
dictionaryValueVector.push_back(value);

return dictionaryValueVector;
}

template <typename T>
inline std::vector<T> AsBasicElementTypeVector(const std::vector<DictionaryValue>& dictionaryValueVector)
inline std::vector<T> AsVector(const std::vector<DictionaryValue>& dictionaryValueVector)
{
static_assert(std::is_same<T, bool>::value ||
std::is_same<T, size_t>::value ||
std::is_same<T, float>::value ||
std::is_same<T, double>::value ||
std::is_same<T, std::wstring>::value, "Unsupported ValueType");
std::is_same<T, size_t>::value ||
std::is_same<T, float>::value ||
std::is_same<T, double>::value ||
std::is_same<T, Axis>::value ||
std::is_same<T, std::wstring>::value,
"Unsupported ValueType");

std::vector<T> basicElementTypeVector;
std::vector<T> elementVector;
for (auto value : dictionaryValueVector)
basicElementTypeVector.push_back(value.Value<T>());
elementVector.push_back(value.Value<T>());

return basicElementTypeVector;
return elementVector;
}

inline PoolingType AsPoolingType(Microsoft::MSR::CNTK::PoolKind cntkPoolingKind)
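The intended split between the two shape helpers, as a hedged illustration (the shapes are chosen for the example, not taken from test code in this commit):

// Illustrative expectations only.
//   AsTensorShape(NDShape({ 5 }))         -> TensorShape [5]      (now padded only to rank 1)
//   AsTensorViewShape(NDShape({ 5 }))     -> TensorShape [5 x 1]  (TensorView construction still needs rank >= 2)
//   AsTensorViewShape(NDShape({ 3, 4 }))  -> TensorShape [3 x 4]  (already rank 2, unchanged)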
@@ -3,6 +3,7 @@
// Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
//

#include "stdafx.h"
#include "CNTKLibrary.h"

namespace CNTK
@@ -3,7 +3,10 @@
// Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
//

#include "stdafx.h"
#include "CNTKLibrary.h"
#include "Utils.h"
#include "Function.h"

namespace CNTK
{

@@ -16,6 +19,44 @@ namespace CNTK

FunctionPtr Variable::Owner() const
{
return m_dataFields->m_ownerFunction->shared_from_this();
if (m_dataFields->m_ownerFunction != nullptr)
return m_dataFields->m_ownerFunction->shared_from_this();
else
return nullptr;
}

Variable::operator FunctionPtr() const
{
auto varOwner = Owner();
if (varOwner)
return CompositeFunction::Create(varOwner, varOwner->Name());
else
return Internal::Combine({ *this });
}

/*static*/ Parameter Parameter::UniformInitParameter(const NDShape& shape, DataType type, double range, unsigned long seed, const DeviceDescriptor& device, const std::wstring& name)
{
switch (type)
{
case DataType::Float:
return Parameter(NDArrayView::RandomUniform<float>(shape, -range, range, seed, device), name);
case DataType::Double:
return Parameter(NDArrayView::RandomUniform<double>(shape, -range, range, seed, device), name);
default:
InvalidArgument("Parameter construction: Unsupported DataType %s", DataTypeName(type));
}
}

/*static*/ Parameter Parameter::NormalInitParameter(const NDShape& shape, DataType type, double stdDev, unsigned long seed, const DeviceDescriptor& device, const std::wstring& name)
{
switch (type)
{
case DataType::Float:
return Parameter(NDArrayView::RandomNormal<float>(shape, 0, stdDev, seed, device), name);
case DataType::Double:
return Parameter(NDArrayView::RandomNormal<double>(shape, 0, stdDev, seed, device), name);
default:
InvalidArgument("Parameter construction: Unsupported DataType %s", DataTypeName(type));
}
}
}
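A hedged sketch of the new helpers in use; the shapes and names are assumptions for the example:

// Hypothetical example.
auto W = Parameter::Uniform({ 512, 300 });               // uniform in [-0.05, 0.05], DataType::Float by default
Variable input({ 300 }, DataType::Float, L"input");
auto projection = Times(W, input);                       // 'projection' is a FunctionPtr

Variable projectionOutput = projection;                   // a single-output Function converts to its output Variable
FunctionPtr backToFunction = projectionOutput;            // the new implicit conversion wraps the owner Function
                                                           // (or Internal::Combine({ ... }) when there is no owner)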
@@ -1239,7 +1239,11 @@ void Matrix<ElemType>::AssignValuesOf(const Matrix<ElemType>& deepCopyFrom)
DISPATCH_MATRIX_ON_FLAG(&deepCopyFrom, nullptr,
{ m_GPUMatrix->SetValue(deepCopyFrom.GetNumRows(), deepCopyFrom.GetNumCols(), this->GetDeviceId(), deepCopyFrom.m_CPUMatrix->Data()); },
{ m_GPUMatrix->SetValue(*deepCopyFrom.m_GPUMatrix); },
{ LogicError("AssignValuesOf: Assigning a CPUSparseMatrix to a GPUMatrix is not yet implemented."); },//{ m_GPUMatrix->SetValue(*deepCopyFrom.m_CPUSparseMatrix); },
{
CPUMatrix<ElemType> tempCPUDenseMatrix(deepCopyFrom.GetNumRows(), deepCopyFrom.GetNumCols());
deepCopyFrom.m_CPUSparseMatrix->AssignColumnSliceToDense(tempCPUDenseMatrix, 0, deepCopyFrom.GetNumCols());
m_GPUMatrix->SetValue(deepCopyFrom.GetNumRows(), deepCopyFrom.GetNumCols(), this->GetDeviceId(), tempCPUDenseMatrix.Data());
},//{ m_GPUMatrix->SetValue(*deepCopyFrom.m_CPUSparseMatrix); },
{ LogicError("AssignValuesOf: Assigning a GPUSparseMatrix to a GPUMatrix is not yet implemented."); });//{ m_GPUMatrix->SetValue(*deepCopyFrom.m_GPUSparseMatrix); });
},
{
@@ -132,24 +132,28 @@ void TrainResNetCifarClassifer(const DeviceDescriptor& device, bool testSaveAndR
const size_t numOutputClasses = labelStreamInfo.m_sampleLayout[0];

Variable imageInput(inputImageShape, imageStreamInfo.m_elementType, L"Images");
auto classifierOutputFunction = ResNetClassifier(imageInput, numOutputClasses, device, L"classifierOutput");
Variable classifierOutput = classifierOutputFunction;
auto classifierOutput = ResNetClassifier(imageInput, numOutputClasses, device, L"classifierOutput");

auto labelsVar = Variable({ numOutputClasses }, labelStreamInfo.m_elementType, L"Labels");

auto trainingLossFunction = CrossEntropyWithSoftmax(classifierOutputFunction, labelsVar, L"lossFunction");
Variable trainingLoss = trainingLossFunction;
auto predictionFunction = ClassificationError(classifierOutputFunction, labelsVar, L"predictionError");
Variable prediction = predictionFunction;

auto imageClassifier = Combine({ trainingLossFunction, predictionFunction, classifierOutputFunction }, L"ImageClassifier");
auto trainingLoss = CrossEntropyWithSoftmax(classifierOutput, labelsVar, L"lossFunction");
auto prediction = ClassificationError(classifierOutput, labelsVar, L"predictionError");

if (testSaveAndReLoad)
SaveAndReloadModel<float>(imageClassifier, { &imageInput, &labelsVar, &trainingLoss, &prediction, &classifierOutput }, device);
{
Variable classifierOutputVar = classifierOutput;
Variable trainingLossVar = trainingLoss;
Variable predictionVar = prediction;
auto imageClassifier = Combine({ trainingLoss, prediction, classifierOutput }, L"ImageClassifier");
SaveAndReloadModel<float>(imageClassifier, { &imageInput, &labelsVar, &trainingLossVar, &predictionVar, &classifierOutputVar }, device);

trainingLoss = trainingLossVar;
prediction = predictionVar;
classifierOutput = classifierOutputVar;
}

double learningRatePerSample = 0.0078125;
Trainer trainer(classifierOutput, trainingLoss, prediction, { SGDLearner(classifierOutput->Parameters(), learningRatePerSample) });

Trainer trainer(imageClassifier, trainingLoss, { SGDLearner(imageClassifier->Parameters(), learningRatePerSample) });
const size_t minibatchSize = 32;
size_t numMinibatchesToTrain = 100;
size_t outputFrequencyInMinibatches = 20;

@@ -157,12 +161,7 @@ void TrainResNetCifarClassifer(const DeviceDescriptor& device, bool testSaveAndR
{
auto minibatchData = minibatchSource->GetNextMinibatch(minibatchSize, device);
trainer.TrainMinibatch({ { imageInput, minibatchData[imageStreamInfo].m_data }, { labelsVar, minibatchData[labelStreamInfo].m_data } }, device);

if ((i % outputFrequencyInMinibatches) == 0)
{
double trainLossValue = trainer.PreviousMinibatchAverageTrainingLoss();
printf("Minibatch %d: CrossEntropy loss = %.8g\n", (int)i, trainLossValue);
}
PrintTrainingProgress(trainer, i, outputFrequencyInMinibatches);
}
}
@ -116,114 +116,82 @@ inline CNTK::FunctionPtr FullyConnectedDNNLayer(CNTK::Variable input, size_t out
|
|||
template <typename ElementType>
|
||||
std::pair<CNTK::FunctionPtr, CNTK::FunctionPtr> LSTMPCellWithSelfStabilization(CNTK::Variable input, CNTK::Variable prevOutput, CNTK::Variable prevCellState, const CNTK::DeviceDescriptor& device)
|
||||
{
|
||||
assert(input.Shape().NumAxes() == 1);
|
||||
size_t inputDim = input.Shape()[0];
|
||||
if ((input.Shape().NumAxes() != 1) || (prevOutput.Shape().NumAxes() != 1) || (prevCellState.Shape().NumAxes() != 1))
|
||||
LogicError("The LSTM implementation in the test library currently only supports 1D inputs and outputs");
|
||||
|
||||
auto stabilize = [](const Variable& x) {
|
||||
float scalarConstant = 4.0f;
|
||||
auto f = Constant({}, scalarConstant);
|
||||
auto fInv = Constant({}, 1.0f / scalarConstant);
|
||||
|
||||
auto beta = ElementTimes(fInv, Log(Constant({}, 1.0f) + Exp(ElementTimes(f, Parameter({}, 0.99537863f /* 1/f*ln (e^f-1) */)))));
|
||||
return ElementTimes(beta, x);
|
||||
};

    size_t inputDim = input.Shape()[0];
    size_t outputDim = prevOutput.Shape()[0];
    size_t cellDim = prevCellState.Shape()[0];

    auto createBiasParam = [device](size_t dim) {
        return CNTK::Parameter({ dim }, (ElementType)0.0, device);
    };

    unsigned long seed = 1;
    auto createProjectionParam = [device, &seed](size_t outputDim, size_t inputDim) {
        return CNTK::Parameter(CNTK::NDArrayView::RandomUniform<ElementType>({ outputDim, inputDim }, -0.5, 0.5, seed++, device));
    };

    auto Wxo = CNTK::Parameter(CNTK::NDArrayView::RandomUniform<ElementType>({ cellDim, inputDim }, -0.5, 0.5, seed++, device));
    auto Wxi = CNTK::Parameter(CNTK::NDArrayView::RandomUniform<ElementType>({ cellDim, inputDim }, -0.5, 0.5, seed++, device));
    auto Wxf = CNTK::Parameter(CNTK::NDArrayView::RandomUniform<ElementType>({ cellDim, inputDim }, -0.5, 0.5, seed++, device));
    auto Wxc = CNTK::Parameter(CNTK::NDArrayView::RandomUniform<ElementType>({ cellDim, inputDim }, -0.5, 0.5, seed++, device));
    auto createDiagWeightParam = [device, &seed](size_t dim) {
        return CNTK::Parameter(CNTK::NDArrayView::RandomUniform<ElementType>({ dim }, -0.5, 0.5, seed++, device));
    };

    auto Bo = CNTK::Parameter({ cellDim }, (ElementType)0.0, device);
    auto Bc = CNTK::Parameter({ cellDim }, (ElementType)0.0, device);
    auto Bi = CNTK::Parameter({ cellDim }, (ElementType)0.0, device);
    auto Bf = CNTK::Parameter({ cellDim }, (ElementType)0.0, device);
    auto stabilizedPrevOutput = stabilize(prevOutput);
    auto stabilizedPrevCellState = stabilize(prevCellState);

    auto Whi = CNTK::Parameter(CNTK::NDArrayView::RandomUniform<ElementType>({ cellDim, outputDim }, -0.5, 0.5, seed++, device));
    auto Wci = CNTK::Parameter(CNTK::NDArrayView::RandomUniform<ElementType>({ cellDim }, -0.5, 0.5, seed++, device));
    auto projectInput = [input, cellDim, inputDim, createBiasParam, createProjectionParam]() {
        return createBiasParam(cellDim) + Times(createProjectionParam(cellDim, inputDim), input);
    };

    auto Whf = CNTK::Parameter(CNTK::NDArrayView::RandomUniform<ElementType>({ cellDim, outputDim }, -0.5, 0.5, seed++, device));
    auto Wcf = CNTK::Parameter(CNTK::NDArrayView::RandomUniform<ElementType>({ cellDim }, -0.5, 0.5, seed++, device));
    // Input gate
    auto it = Sigmoid(projectInput() + Times(createProjectionParam(cellDim, outputDim), stabilizedPrevOutput) + ElementTimes(createDiagWeightParam(cellDim), stabilizedPrevCellState));
    auto bit = ElementTimes(it, Tanh(projectInput() + Times(createProjectionParam(cellDim, outputDim), stabilizedPrevOutput)));

    auto Who = CNTK::Parameter(CNTK::NDArrayView::RandomUniform<ElementType>({ cellDim, outputDim }, -0.5, 0.5, seed++, device));
    auto Wco = CNTK::Parameter(CNTK::NDArrayView::RandomUniform<ElementType>({ cellDim }, -0.5, 0.5, seed++, device));
    // Forget-me-not gate
    auto ft = Sigmoid(projectInput() + Times(createProjectionParam(cellDim, outputDim), stabilizedPrevOutput) + ElementTimes(createDiagWeightParam(cellDim), stabilizedPrevCellState));
    auto bft = ElementTimes(ft, prevCellState);

    auto Whc = CNTK::Parameter(CNTK::NDArrayView::RandomUniform<ElementType>({ cellDim, outputDim }, -0.5, 0.5, seed++, device));
    auto ct = bft + bit;

    auto Wmr = CNTK::Parameter(CNTK::NDArrayView::RandomUniform<ElementType>({ outputDim, cellDim }, -0.5, 0.5, seed++, device));
    // Output gate
    auto ot = Sigmoid(projectInput() + Times(createProjectionParam(cellDim, outputDim), stabilizedPrevOutput) + ElementTimes(createDiagWeightParam(cellDim), stabilize(ct)));
    auto ht = ElementTimes(ot, Tanh(ct));

    // Stabilization by routing input through an extra scalar parameter
    auto sWxo = CNTK::Parameter({}, (ElementType)0.0, device);
    auto sWxi = CNTK::Parameter({}, (ElementType)0.0, device);
    auto sWxf = CNTK::Parameter({}, (ElementType)0.0, device);
    auto sWxc = CNTK::Parameter({}, (ElementType)0.0, device);
    auto c = ct;
    auto h = (outputDim != cellDim) ? Times(createProjectionParam(outputDim, cellDim), stabilize(ht)) : ht;

    auto sWhi = CNTK::Parameter({}, (ElementType)0.0, device);
    auto sWci = CNTK::Parameter({}, (ElementType)0.0, device);

    auto sWhf = CNTK::Parameter({}, (ElementType)0.0, device);
    auto sWcf = CNTK::Parameter({}, (ElementType)0.0, device);
    auto sWho = CNTK::Parameter({}, (ElementType)0.0, device);
    auto sWco = CNTK::Parameter({}, (ElementType)0.0, device);
    auto sWhc = CNTK::Parameter({}, (ElementType)0.0, device);

    auto sWmr = CNTK::Parameter({}, (ElementType)0.0, device);

    auto expsWxo = CNTK::Exp(sWxo);
    auto expsWxi = CNTK::Exp(sWxi);
    auto expsWxf = CNTK::Exp(sWxf);
    auto expsWxc = CNTK::Exp(sWxc);

    auto expsWhi = CNTK::Exp(sWhi);
    auto expsWci = CNTK::Exp(sWci);

    auto expsWhf = CNTK::Exp(sWhf);
    auto expsWcf = CNTK::Exp(sWcf);
    auto expsWho = CNTK::Exp(sWho);
    auto expsWco = CNTK::Exp(sWco);
    auto expsWhc = CNTK::Exp(sWhc);

    auto expsWmr = CNTK::Exp(sWmr);

    auto Wxix = CNTK::Times(Wxi, CNTK::ElementTimes(expsWxi, input));
    auto Whidh = CNTK::Times(Whi, CNTK::ElementTimes(expsWhi, prevOutput));
    auto Wcidc = CNTK::ElementTimes(Wci, CNTK::ElementTimes(expsWci, prevCellState));

    auto it = CNTK::Sigmoid(CNTK::Plus(CNTK::Plus(CNTK::Plus(Wxix, Bi), Whidh), Wcidc));

    auto Wxcx = CNTK::Times(Wxc, CNTK::ElementTimes(expsWxc, input));
    auto Whcdh = CNTK::Times(Whc, CNTK::ElementTimes(expsWhc, prevOutput));
    auto bit = CNTK::ElementTimes(it, CNTK::Tanh(CNTK::Plus(Wxcx, CNTK::Plus(Whcdh, Bc))));

    auto Wxfx = CNTK::Times(Wxf, CNTK::ElementTimes(expsWxf, input));
    auto Whfdh = CNTK::Times(Whf, CNTK::ElementTimes(expsWhf, prevOutput));
    auto Wcfdc = CNTK::ElementTimes(Wcf, CNTK::ElementTimes(expsWcf, prevCellState));

    auto ft = CNTK::Sigmoid(CNTK::Plus(CNTK::Plus(CNTK::Plus(Wxfx, Bf), Whfdh), Wcfdc));

    auto bft = CNTK::ElementTimes(ft, prevCellState);

    auto ct = CNTK::Plus(bft, bit);

    auto Wxox = CNTK::Times(Wxo, CNTK::ElementTimes(expsWxo, input));
    auto Whodh = CNTK::Times(Who, CNTK::ElementTimes(expsWho, prevOutput));
    auto Wcoct = CNTK::ElementTimes(Wco, CNTK::ElementTimes(expsWco, ct));

    auto ot = CNTK::Sigmoid(CNTK::Plus(CNTK::Plus(CNTK::Plus(Wxox, Bo), Whodh), Wcoct));

    auto mt = CNTK::ElementTimes(ot, Tanh(ct));

    return{ CNTK::Times(Wmr, CNTK::ElementTimes(expsWmr, mt)), ct };

    return{ h, c };
}

template <typename ElementType>
CNTK::FunctionPtr LSTMPComponentWithSelfStabilization(CNTK::Variable input, size_t outputDim, size_t cellDim, const CNTK::DeviceDescriptor& device)
std::pair<CNTK::FunctionPtr, CNTK::FunctionPtr> LSTMPComponentWithSelfStabilization(CNTK::Variable input,
                                                                                    size_t outputDim,
                                                                                    size_t cellDim,
                                                                                    const std::function<CNTK::FunctionPtr(const CNTK::Variable&)>& recurrenceHookH,
                                                                                    const std::function<CNTK::FunctionPtr(const CNTK::Variable&)>& recurrenceHookC,
                                                                                    const CNTK::DeviceDescriptor& device)
{
    auto dh = CNTK::Placeholder({ outputDim });
    auto dc = CNTK::Placeholder({ cellDim });
    auto dh = CNTK::Placeholder({ outputDim }, input.DynamicAxes());
    auto dc = CNTK::Placeholder({ cellDim }, input.DynamicAxes());

    auto LSTMCell = LSTMPCellWithSelfStabilization<ElementType>(input, dh, dc, device);

    auto actualDh = CNTK::PastValue(LSTMCell.first, CNTK::Constant({}, (ElementType)0.0, device), 1);
    auto actualDc = CNTK::PastValue(LSTMCell.second, CNTK::Constant({}, (ElementType)0.0, device), 1);
    auto actualDh = recurrenceHookH(LSTMCell.first);
    auto actualDc = recurrenceHookC(LSTMCell.second);

    // Form the recurrence loop by replacing the dh and dc placeholders with the actualDh and actualDc
    return LSTMCell.first->ReplacePlaceholders({ { dh, actualDh }, { dc, actualDc } });
    LSTMCell.first->ReplacePlaceholders({ { dh, actualDh }, { dc, actualDc } });

    return { LSTMCell.first, LSTMCell.second };
}
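
// Usage sketch (illustrative only; 'features', the dimensions and 'device' are placeholders, not part of this file):
//     auto hook = std::bind(PastValue, std::placeholders::_1, CNTK::Constant({}, 0.0f), 1, L"");
//     auto lstm = LSTMPComponentWithSelfStabilization<float>(features, /*outputDim*/ 512, /*cellDim*/ 512, hook, hook, device);
//     CNTK::FunctionPtr output = lstm.first; // lstm.second is the cell-state output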

inline CNTK::MinibatchSourcePtr CreateTextMinibatchSource(const std::wstring& filePath,

@@ -355,4 +323,14 @@ inline void OpenStream(std::fstream& stream, const std::wstring& filename, bool
    stream.open(wtocharpath(filename.c_str()).c_str(), mode);
#endif
    stream.exceptions(std::ios_base::failbit | std::ios_base::badbit);
}
}
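
// Prints the average training loss and evaluation criterion of the most recently processed minibatch.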
inline void PrintTrainingProgress(const CNTK::Trainer& trainer, size_t minibatchIdx, size_t outputFrequencyInMinibatches)
{
    if ((minibatchIdx % outputFrequencyInMinibatches) == 0)
    {
        double trainLossValue = trainer.PreviousMinibatchLossAverage();
        double evaluationValue = trainer.PreviousMinibatchEvaluationAverage();
        printf("Minibatch %d: CrossEntropy loss = %.8g, Evaluation criterion = %.8g\n", (int)minibatchIdx, trainLossValue, evaluationValue);
    }
}

@@ -13,6 +13,7 @@ void FunctionTests();
void TrainLSTMSequenceClassifer();
void SerializationTests();
void LearnerTests();
void TrainSequenceToSequenceTranslator();

int main()
{

@@ -29,6 +30,8 @@ int main()

    TestCifarResnet();
    TrainLSTMSequenceClassifer();

    TrainSequenceToSequenceTranslator();

    fprintf(stderr, "\nCNTKv2Library tests: Passed\n");
    fflush(stderr);

@@ -10,10 +10,13 @@ static unsigned long seed = 1;
template <typename ElementType>
FunctionPtr LSTMNet(Variable features, size_t cellDim, size_t hiddenDim, size_t numOutputClasses, size_t numLSTMLayers, const DeviceDescriptor& device, const std::wstring& outputName)
{
    using namespace std::placeholders;

    assert(numLSTMLayers >= 1);
    auto classifierRoot = LSTMPComponentWithSelfStabilization<ElementType>(features, hiddenDim, cellDim, device);
    for (size_t i = 1; i < numLSTMLayers; ++i) {
        classifierRoot = LSTMPComponentWithSelfStabilization<ElementType>(classifierRoot, hiddenDim, cellDim, device);
    FunctionPtr classifierRoot = features;
    auto pastValueRecurrenceHook = std::bind(PastValue, _1, CNTK::Constant({}, (ElementType)0.0), 1, L"");
    for (size_t i = 0; i < numLSTMLayers; ++i) {
        classifierRoot = LSTMPComponentWithSelfStabilization<ElementType>(classifierRoot, hiddenDim, cellDim, pastValueRecurrenceHook, pastValueRecurrenceHook, device).first;
    }
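
    // Each pass of the loop stacks one LSTM layer on the previous layer's output (.first of the returned pair);
    // the cell-state output (.second) is not used here.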

    auto W = Parameter(NDArrayView::RandomUniform<ElementType>({ numOutputClasses, hiddenDim }, -0.5, 0.5, seed++, device));

@@ -0,0 +1,187 @@
#include "CNTKLibrary.h"
#include <functional>
#include "Common.h"

using namespace CNTK;

using namespace std::placeholders;

inline CNTK::MinibatchSourcePtr CreateSeq2SeqMinibatchSource(const std::wstring& filePath, size_t inputVocabSize, size_t labelsVocabSize)
{
    CNTK::Dictionary inputStreamConfig;
    inputStreamConfig[L"dim"] = inputVocabSize;
    inputStreamConfig[L"format"] = L"sparse";
    inputStreamConfig[L"alias"] = L"S0";

    CNTK::Dictionary labelsStreamConfig;
    labelsStreamConfig[L"dim"] = labelsVocabSize;
    labelsStreamConfig[L"format"] = L"sparse";
    labelsStreamConfig[L"alias"] = L"S1";

    CNTK::Dictionary inputStreamsConfig;
    inputStreamsConfig[L"rawInput"] = inputStreamConfig;
    inputStreamsConfig[L"rawLabels"] = labelsStreamConfig;

    CNTK::Dictionary deserializerConfiguration;
    deserializerConfiguration[L"type"] = L"CNTKTextFormatDeserializer";
    deserializerConfiguration[L"file"] = filePath;
    deserializerConfiguration[L"input"] = inputStreamsConfig;
    deserializerConfiguration[L"skipSequenceIds"] = L"false";
    deserializerConfiguration[L"maxErrors"] = (size_t)100;
    deserializerConfiguration[L"traceLevel"] = (size_t)1;
    deserializerConfiguration[L"chunkSizeInBytes"] = (size_t)30000000;

    CNTK::Dictionary minibatchSourceConfiguration;
    minibatchSourceConfiguration[L"epochSize"] = (size_t)2000;
    minibatchSourceConfiguration[L"deserializers"] = std::vector<CNTK::DictionaryValue>({ deserializerConfiguration });

    return CreateCompositeMinibatchSource(minibatchSourceConfiguration);
}
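
// Note: the S0/S1 aliases above are expected to match the sequence tags used in the CNTK text-format (CTF) input
// file, e.g. a line of the form "0 |S0 3:1 |S1 12:1" (illustrative sample only, not taken from the data set);
// "rawInput"/"rawLabels" are the stream names looked up via MinibatchSource::StreamInfo further below.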

void TrainSequenceToSequenceTranslator(const DeviceDescriptor& device, bool useSparseInputs, bool testSaveAndReLoad)
{
    using namespace std::placeholders;

    const size_t inputVocabDim = 69;
    const size_t labelVocabDim = 69;

    const size_t hiddenDim = 512;
    const size_t numLayers = 2;

    const size_t embeddingDim = 300;
    const size_t inputEmbeddingDim = std::min(inputVocabDim, embeddingDim);
    const size_t labelEmbeddingDim = std::min(labelVocabDim, embeddingDim);

    /* Inputs */
    std::vector<Axis> inputDynamicAxes = { Axis(L"inputAxis"), Axis::DefaultBatchAxis() };
    auto rawInput = Variable({ inputVocabDim }, useSparseInputs /*isSparse*/, DataType::Float, L"rawInput", inputDynamicAxes);

    std::vector<Axis> labelDynamicAxes = { Axis(L"labelAxis"), Axis::DefaultBatchAxis() };
    auto rawLabels = Variable({ labelVocabDim }, useSparseInputs /*isSparse*/, DataType::Float, L"rawLabels", labelDynamicAxes);

    FunctionPtr inputSequence = rawInput;

    // Drop the sentence start token from the label, for decoder training
    auto labelSequence = Slice(rawLabels, labelDynamicAxes[0], 1, 0);
    auto labelSentenceStart = Sequence::First(rawLabels);

    auto isFirstLabel = Sequence::IsFirst(labelSequence);

    bool forceEmbedding = useSparseInputs;

    /* Embeddings */
    auto inputEmbeddingWeights = Parameter(NDArrayView::RandomUniform<float>({ inputEmbeddingDim, inputVocabDim }, -0.05, 0.05, 1, device));
    auto labelEmbeddingWeights = Parameter(NDArrayView::RandomUniform<float>({ labelEmbeddingDim, labelVocabDim }, -0.05, 0.05, 1, device));

    auto inputEmbedding = (!forceEmbedding && (inputVocabDim <= inputEmbeddingDim)) ? inputSequence : Times(inputEmbeddingWeights, inputSequence);
    auto labelEmbedding = (!forceEmbedding && (labelVocabDim <= labelEmbeddingDim)) ? labelSequence : Times(labelEmbeddingWeights, labelSequence);
    auto labelSentenceStartEmbedding = (!forceEmbedding && (labelVocabDim <= labelEmbeddingDim)) ? labelSentenceStart : Times(labelEmbeddingWeights, labelSentenceStart);
    auto labelSentenceStartEmbeddedScattered = Sequence::Scatter(labelSentenceStartEmbedding, isFirstLabel);
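
    // Sequence::Scatter places the embedded sentence-start token at the positions where isFirstLabel is 1
    // (the first step of each label sequence) and zeros elsewhere, so it can be combined with the shifted
    // label history via ElementSelect below.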

    auto stabilize = [](const Variable& x) {
        float scalarConstant = 4.0f;
        auto f = Constant({}, scalarConstant);
        auto fInv = Constant({}, 1.0f / scalarConstant);

        auto beta = ElementTimes(fInv, Log(Constant({}, 1.0f) + Exp(ElementTimes(f, Parameter({}, 0.99537863f /* 1/f*ln (e^f-1) */)))));
        return ElementTimes(beta, x);
    };

    /* Encoder */
    auto encoderOutputH = stabilize(inputEmbedding);
    FunctionPtr encoderOutputC;
    auto futureValueRecurrenceHook = std::bind(FutureValue, _1, CNTK::Constant({}, 0.0f), 1, L"");
    for (size_t i = 0; i < numLayers; ++i)
        std::tie(encoderOutputH, encoderOutputC) = LSTMPComponentWithSelfStabilization<float>(encoderOutputH, hiddenDim, hiddenDim, futureValueRecurrenceHook, futureValueRecurrenceHook, device);

    auto thoughtVectorH = Sequence::First(encoderOutputH);
    auto thoughtVectorC = Sequence::First(encoderOutputC);
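
    // The encoder recurrence is wired with FutureValue, so it runs right-to-left over the input; the first
    // element of its output therefore summarizes the entire input sequence and serves as the "thought vector"
    // that conditions the decoder.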

    auto thoughtVectorBroadcastH = Sequence::BroadcastAs(thoughtVectorH, labelEmbedding);
    auto thoughtVectorBroadcastC = Sequence::BroadcastAs(thoughtVectorC, labelEmbedding);

    /* Decoder */
    auto decoderHistoryFromGroundTruth = labelEmbedding;
    auto decoderInput = ElementSelect(isFirstLabel, labelSentenceStartEmbeddedScattered, PastValue(decoderHistoryFromGroundTruth, Constant({}, 0.0f), 1));
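
    // Teacher forcing: at the first decoder step the input is the embedded sentence-start token; at every later
    // step it is the previous ground-truth label (via PastValue) rather than the decoder's own prediction.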

    auto decoderOutputH = stabilize(decoderInput);
    FunctionPtr decoderOutputC;
    auto pastValueRecurrenceHook = std::bind(PastValue, _1, CNTK::Constant({}, 0.0f), 1, L"");
    for (size_t i = 0; i < numLayers; ++i)
    {
        std::function<FunctionPtr(const Variable&)> recurrenceHookH, recurrenceHookC;
        if (i == 0)
        {
            recurrenceHookH = pastValueRecurrenceHook;
            recurrenceHookC = pastValueRecurrenceHook;
        }
        else
        {
            auto isFirst = Sequence::IsFirst(labelEmbedding);
            recurrenceHookH = [labelEmbedding, thoughtVectorBroadcastH, isFirst](const Variable& operand) {
                return ElementSelect(isFirst, thoughtVectorBroadcastH, PastValue(operand, CNTK::Constant({}, 0.0f), 1, L""));
            };

            recurrenceHookC = [labelEmbedding, thoughtVectorBroadcastC, isFirst](const Variable& operand) {
                return ElementSelect(isFirst, thoughtVectorBroadcastC, PastValue(operand, CNTK::Constant({}, 0.0f), 1, L""));
            };
        }

        std::tie(decoderOutputH, decoderOutputC) = LSTMPComponentWithSelfStabilization<float>(decoderOutputH, hiddenDim, hiddenDim, recurrenceHookH, recurrenceHookC, device);
    }

    auto decoderOutput = decoderOutputH;
    auto decoderDim = hiddenDim;

    /* Softmax output layer */
    auto outputLayerProjWeights = Parameter(NDArrayView::RandomUniform<float>({ labelVocabDim, decoderDim }, -0.05, 0.05, 1, device));
    auto biasWeights = Parameter({ labelVocabDim }, 0.0f, device);

    auto z = Plus(Times(outputLayerProjWeights, stabilize(decoderOutput)), biasWeights, L"classifierOutput");
    auto ce = CrossEntropyWithSoftmax(z, labelSequence, L"lossFunction");
    auto errs = ClassificationError(z, labelSequence, L"classificationError");

    if (testSaveAndReLoad)
    {
        Variable zVar = z;
        Variable ceVar = ce;
        Variable errsVar = errs;
        auto seq2seqModel = Combine({ ce, errs, z }, L"seq2seqModel");
        SaveAndReloadModel<float>(seq2seqModel, { &rawInput, &rawLabels, &ceVar, &errsVar, &zVar }, device);

        z = zVar;
        ce = ceVar;
        errs = errsVar;
    }
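
    // Save/reload round trip: the outputs are captured as Variables before serialization and rebound afterwards,
    // so the rest of the test operates on the reloaded graph's equivalents of z, ce and errs.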

    auto minibatchSource = CreateSeq2SeqMinibatchSource(L"cmudict-0.7b.train-dev-20-21.bsf.ctf.2", inputVocabDim, labelVocabDim);
    auto rawInputStreamInfo = minibatchSource->StreamInfo(L"rawInput");
    auto rawLabelsStreamInfo = minibatchSource->StreamInfo(L"rawLabels");

    double learningRatePerSample = 0.007;
    size_t momentumTimeConstant = 1100;
    double momentumPerSample = std::exp(-1.0 / momentumTimeConstant);
    double clippingThresholdPerSample = 2.3;
    bool gradientClippingWithTruncation = true;
    Trainer trainer(z, ce, errs, { MomentumSGDLearner(z->Parameters(), learningRatePerSample, momentumPerSample, clippingThresholdPerSample, gradientClippingWithTruncation) });
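
    // momentumPerSample = exp(-1/1100) ~= 0.99909, i.e. per-sample momentum expressed through a time constant
    // of 1100 samples; gradients are additionally clipped at 2.3 per sample (with truncation).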

    size_t outputFrequencyInMinibatches = 1;
    size_t minibatchSize = 72;
    for (size_t i = 0; true; i++)
    {
        auto minibatchData = minibatchSource->GetNextMinibatch(minibatchSize, device);
        if (minibatchData.empty())
            break;

        trainer.TrainMinibatch({ { rawInput, minibatchData[rawInputStreamInfo].m_data }, { rawLabels, minibatchData[rawLabelsStreamInfo].m_data } }, device);
        PrintTrainingProgress(trainer, i, outputFrequencyInMinibatches);
    }
}

void TrainSequenceToSequenceTranslator()
{
    // TODO: Also test with sparse input variables in the graph

    TrainSequenceToSequenceTranslator(DeviceDescriptor::GPUDevice(0), false, true);
    TrainSequenceToSequenceTranslator(DeviceDescriptor::CPUDevice(), false, false);
}

@@ -15,16 +15,12 @@ FunctionPtr Embedding(const Variable& input, size_t embeddingDim, const DeviceDe
    return Times(embeddingParameters, input);
}

FunctionPtr SelectLast(const Variable& operand)
{
    return Slice(operand, Axis::DefaultDynamicAxis(), -1, 0);
}

FunctionPtr LSTMSequenceClassiferNet(const Variable& input, size_t numOutputClasses, size_t embeddingDim, size_t LSTMDim, size_t cellDim, const DeviceDescriptor& device, const std::wstring& outputName)
{
    auto embeddingFunction = Embedding(input, embeddingDim, device);
    auto LSTMFunction = LSTMPComponentWithSelfStabilization<float>(embeddingFunction, LSTMDim, cellDim, device);
    auto thoughtVectorFunction = SelectLast(LSTMFunction);
    auto pastValueRecurrenceHook = std::bind(PastValue, _1, CNTK::Constant({}, 0.0f), 1, L"");
    auto LSTMFunction = LSTMPComponentWithSelfStabilization<float>(embeddingFunction, LSTMDim, cellDim, pastValueRecurrenceHook, pastValueRecurrenceHook, device).first;
    auto thoughtVectorFunction = Sequence::Last(LSTMFunction);

    return FullyConnectedLinearLayer(thoughtVectorFunction, numOutputClasses, device, outputName);
}

@@ -38,27 +34,34 @@ void TrainLSTMSequenceClassifer(const DeviceDescriptor& device, bool testSaveAnd
    const size_t numOutputClasses = 5;

    Variable features({ inputDim }, true /*isSparse*/, DataType::Float, L"features");
    auto classifierOutputFunction = LSTMSequenceClassiferNet(features, numOutputClasses, embeddingDim, hiddenDim, cellDim, device, L"classifierOutput");
    Variable classifierOutput = classifierOutputFunction;
    auto classifierOutput = LSTMSequenceClassiferNet(features, numOutputClasses, embeddingDim, hiddenDim, cellDim, device, L"classifierOutput");

    Variable labels({ numOutputClasses }, DataType::Float, L"labels", { Axis::DefaultBatchAxis() });
    auto trainingLossFunction = CNTK::CrossEntropyWithSoftmax(classifierOutput, labels, L"lossFunction");
    Variable trainingLoss = trainingLossFunction;
    auto predictionFunction = CNTK::ClassificationError(classifierOutput, labels, L"classificationError");
    Variable prediction = predictionFunction;

    auto oneHiddenLayerClassifier = CNTK::Combine({ trainingLoss.Owner(), prediction.Owner(), classifierOutput.Owner() }, L"classifierModel");
    auto trainingLoss = CNTK::CrossEntropyWithSoftmax(classifierOutput, labels, L"lossFunction");
    auto prediction = CNTK::ClassificationError(classifierOutput, labels, L"classificationError");

    if (testSaveAndReLoad)
        SaveAndReloadModel<float>(oneHiddenLayerClassifier, { &features, &labels, &trainingLoss, &prediction, &classifierOutput }, device);
    {
        Variable classifierOutputVar = classifierOutput;
        Variable trainingLossVar = trainingLoss;
        Variable predictionVar = prediction;
        auto oneHiddenLayerClassifier = CNTK::Combine({ trainingLoss, prediction, classifierOutput }, L"classifierModel");
        SaveAndReloadModel<float>(oneHiddenLayerClassifier, { &features, &labels, &trainingLossVar, &predictionVar, &classifierOutputVar }, device);

        classifierOutput = classifierOutputVar;
        trainingLoss = trainingLossVar;
        prediction = predictionVar;
    }

    auto minibatchSource = CreateTextMinibatchSource(L"Train.ctf", inputDim, numOutputClasses, 0, true, false, L"x", L"y");
    const size_t minibatchSize = 200;

    auto featureStreamInfo = minibatchSource->StreamInfo(features);
    auto labelStreamInfo = minibatchSource->StreamInfo(labels);

    double learningRatePerSample = 0.0005;
    Trainer trainer(oneHiddenLayerClassifier, trainingLoss, { SGDLearner(oneHiddenLayerClassifier->Parameters(), learningRatePerSample) });
    Trainer trainer(classifierOutput, trainingLoss, prediction, { SGDLearner(classifierOutput->Parameters(), learningRatePerSample) });

    size_t outputFrequencyInMinibatches = 1;
    for (size_t i = 0; true; i++)
    {

@@ -67,12 +70,7 @@ void TrainLSTMSequenceClassifer(const DeviceDescriptor& device, bool testSaveAnd
            break;

        trainer.TrainMinibatch({ { features, minibatchData[featureStreamInfo].m_data }, { labels, minibatchData[labelStreamInfo].m_data } }, device);

        if ((i % outputFrequencyInMinibatches) == 0)
        {
            double trainLossValue = trainer.PreviousMinibatchAverageTrainingLoss();
            printf("Minibatch %d: CrossEntropy loss = %.8g\n", (int)i, trainLossValue);
        }
        PrintTrainingProgress(trainer, i, outputFrequencyInMinibatches);
    }
}

@@ -35,28 +35,34 @@ void TrainSimpleFeedForwardClassifer(const DeviceDescriptor& device)

    auto outputTimesParam = Parameter(NDArrayView::RandomUniform<float>({ numOutputClasses, hiddenLayerDim }, -0.05, 0.05, 1, device));
    auto outputBiasParam = Parameter(NDArrayView::RandomUniform<float>({ numOutputClasses }, -0.05, 0.05, 1, device));
    classifierOutput = Plus(outputBiasParam, Times(outputTimesParam, classifierOutput));
    classifierOutput = Plus(outputBiasParam, Times(outputTimesParam, classifierOutput), L"classifierOutput");

    Variable labels({ numOutputClasses }, DataType::Float, L"labels");
    auto trainingLoss = CNTK::CrossEntropyWithSoftmax(classifierOutput, labels, L"lossFunction");
    auto prediction = CNTK::ClassificationError(classifierOutput, labels, L"classificationError");

    auto oneHiddenLayerClassifier = CNTK::Combine({ trainingLoss, prediction, classifierOutput }, L"classifierModel");
    // Test save and reload of model
    {
        Variable classifierOutputVar = classifierOutput;
        Variable trainingLossVar = trainingLoss;
        Variable predictionVar = prediction;
        auto combinedNet = Combine({ trainingLoss, prediction, classifierOutput }, L"feedForwardClassifier");
        SaveAndReloadModel<float>(combinedNet, { &input, &labels, &trainingLossVar, &predictionVar, &classifierOutputVar }, device);

        classifierOutput = classifierOutputVar;
        trainingLoss = trainingLossVar;
        prediction = predictionVar;
    }

    double learningRatePerSample = 0.02;
    minibatchSource = CreateTextMinibatchSource(L"SimpleDataTrain_cntk_text.txt", (size_t)2, (size_t)2, SIZE_MAX);
    Trainer trainer(oneHiddenLayerClassifier, trainingLoss, { SGDLearner(oneHiddenLayerClassifier->Parameters(), learningRatePerSample) });
    Trainer trainer(classifierOutput, trainingLoss, prediction, { SGDLearner(classifierOutput->Parameters(), learningRatePerSample) });
    size_t outputFrequencyInMinibatches = 20;
    for (size_t i = 0; i < numMinibatchesToTrain; ++i)
    {
        auto minibatchData = minibatchSource->GetNextMinibatch(minibatchSize, device);
        trainer.TrainMinibatch({ { input, minibatchData[*featureStreamInfo].m_data }, { labels, minibatchData[*labelStreamInfo].m_data } }, device);

        if ((i % outputFrequencyInMinibatches) == 0)
        {
            double trainLossValue = trainer.PreviousMinibatchAverageTrainingLoss();
            printf("Minibatch %d: CrossEntropy loss = %.8g\n", (int)i, trainLossValue);
        }
        PrintTrainingProgress(trainer, i, outputFrequencyInMinibatches);
    }
}

@@ -71,13 +77,24 @@ void TrainMNISTClassifier(const DeviceDescriptor& device)
    auto classifierOutput = FullyConnectedDNNLayer(scaledInput, hiddenLayerDim, device, std::bind(Sigmoid, _1, L""));
    auto outputTimesParam = Parameter(NDArrayView::RandomUniform<float>({ numOutputClasses, hiddenLayerDim }, -0.05, 0.05, 1, device));
    auto outputBiasParam = Parameter(NDArrayView::RandomUniform<float>({ numOutputClasses }, -0.05, 0.05, 1, device));
    classifierOutput = Plus(outputBiasParam, Times(outputTimesParam, classifierOutput));
    classifierOutput = Plus(outputBiasParam, Times(outputTimesParam, classifierOutput), L"classifierOutput");

    Variable labels({ numOutputClasses }, DataType::Float, L"labels");
    auto trainingLoss = CNTK::CrossEntropyWithSoftmax(classifierOutput, labels, L"lossFunction");
    auto prediction = CNTK::ClassificationError(classifierOutput, labels, L"classificationError");

    auto oneHiddenLayerClassifier = CNTK::Combine({ trainingLoss, prediction, classifierOutput }, L"classifierModel");
    // Test save and reload of model
    {
        Variable classifierOutputVar = classifierOutput;
        Variable trainingLossVar = trainingLoss;
        Variable predictionVar = prediction;
        auto combinedNet = Combine({ trainingLoss, prediction, classifierOutput }, L"MNISTClassifier");
        SaveAndReloadModel<float>(combinedNet, { &input, &labels, &trainingLossVar, &predictionVar, &classifierOutputVar }, device);

        classifierOutput = classifierOutputVar;
        trainingLoss = trainingLossVar;
        prediction = predictionVar;
    }

    const size_t minibatchSize = 32;
    const size_t numSamplesPerSweep = 60000;

@@ -91,18 +108,14 @@ void TrainMNISTClassifier(const DeviceDescriptor& device)
    auto labelStreamInfo = std::find_if(streamInfos.begin(), streamInfos.end(), [](const StreamInformation& streamInfo) { return (streamInfo.m_name == L"labels"); });

    double learningRatePerSample = 0.003125;
    Trainer trainer(oneHiddenLayerClassifier, trainingLoss, { SGDLearner(oneHiddenLayerClassifier->Parameters(), learningRatePerSample) });
    Trainer trainer(classifierOutput, trainingLoss, prediction, { SGDLearner(classifierOutput->Parameters(), learningRatePerSample) });

    size_t outputFrequencyInMinibatches = 20;
    for (size_t i = 0; i < numMinibatchesToTrain; ++i)
    {
        auto minibatchData = minibatchSource->GetNextMinibatch(minibatchSize, device);
        trainer.TrainMinibatch({ { input, minibatchData[*featureStreamInfo].m_data }, { labels, minibatchData[*labelStreamInfo].m_data } }, device);

        if ((i % outputFrequencyInMinibatches) == 0)
        {
            double trainLossValue = trainer.PreviousMinibatchAverageTrainingLoss();
            printf("Minibatch %d: CrossEntropy loss = %.8g\n", (int)i, trainLossValue);
        }
        PrintTrainingProgress(trainer, i, outputFrequencyInMinibatches);
    }
}

@@ -111,6 +111,7 @@
  <ItemGroup>
    <ClCompile Include="CifarResNet.cpp" />
    <ClCompile Include="LearnerTests.cpp" />
    <ClCompile Include="Seq2Seq.cpp" />
    <ClCompile Include="SerializationTests.cpp" />
    <ClCompile Include="FeedForwardTests.cpp" />
    <ClCompile Include="FunctionTests.cpp" />

@@ -48,6 +48,9 @@
    <ClCompile Include="LearnerTests.cpp">
      <Filter>Source Files</Filter>
    </ClCompile>
    <ClCompile Include="Seq2Seq.cpp">
      <Filter>Source Files</Filter>
    </ClCompile>
  </ItemGroup>
  <ItemGroup>
    <ClInclude Include="Common.h">