CNTK v2 library: Added Seq2Seq implementation as a test using V2 C++ API and other related changes

This commit is contained in:
Amit Agarwal 2016-08-30 22:50:25 -07:00
Parent 009e53dfe7
Commit 7a5c133edc
25 changed files: 1171 additions and 487 deletions

View file

@@ -416,6 +416,7 @@ CNTKLIBRARY_TESTS_SRC =\
 Tests/UnitTests/V2LibraryTests/LearnerTests.cpp \
 Tests/UnitTests/V2LibraryTests/FunctionTests.cpp \
 Tests/UnitTests/V2LibraryTests/SequenceClassification.cpp \
+Tests/UnitTests/V2LibraryTests/Seq2Seq.cpp \
 CNTKLIBRARY_TESTS:=$(BINDIR)/v2librarytests
 CNTKLIBRARY_TESTS_OBJ := $(patsubst %.cu, $(OBJDIR)/%.o, $(patsubst %.cpp, $(OBJDIR)/%.o, $(CNTKLIBRARY_TESTS_SRC)))

View file

@@ -698,6 +698,9 @@ namespace CNTK
 CNTK_API static const std::wstring StaticAxisNamePrefix;
 static const size_t SentinelStaticAxisIndexValueForDynamicAxes = SIZE_MAX;

+// TODO: Make this thread-safe
+CNTK_API static std::unordered_set<std::wstring> s_allKnownDynamicAxisNames;

 public:
 ///
 /// Construct an Axis object denoting a static axis with the specified index.
@@ -713,7 +716,9 @@ namespace CNTK
 ///
 explicit Axis(const std::wstring& name, bool isOrderedDynamicAxis = true)
     : m_staticAxisIdx(SentinelStaticAxisIndexValueForDynamicAxes), m_name(name), m_isOrderedDynamicAxis(isOrderedDynamicAxis)
-{}
+{
+    RegisterAxisName(name);
+}

 ///
 /// Returns a boolean indicating if 'this' Axis corresponds to a static axis
@@ -746,6 +751,11 @@ namespace CNTK
 ///
 CNTK_API static const Axis& DefaultBatchAxis();

+///
+/// Returns a new unique Dynamic axis
+///
+CNTK_API static Axis NewUniqueDynamicAxis(const std::wstring& axisNamePrefix, bool isOrderedDynamicAxis = true);

 ///
 /// Name of 'this' axis
 ///
@@ -758,6 +768,9 @@ namespace CNTK
     : m_staticAxisIdx(SentinelStaticAxisIndexValueForDynamicAxes)
 {}

+private:
+    CNTK_API void RegisterAxisName(const std::wstring& axisName);

 private:
     size_t m_staticAxisIdx;
     std::wstring m_name;
@@ -819,7 +832,9 @@ namespace CNTK
 template <typename T>
 friend struct std::hash;

+public:
 CNTK_API static const std::vector<Axis> DefaultInputVariableDynamicAxes;

 public:
 ///
 /// Create an 'Input' Variable.
@@ -904,6 +919,11 @@ namespace CNTK
 ///
 CNTK_API Variable(const FunctionPtr& function);

+///
+/// Implicit conversion to a FunctionPtr; creates a pass through primitive function
+///
+CNTK_API operator FunctionPtr() const;

 ///
 /// Default constructor for creating an invalid/null Variable instance.
 /// Required for use in a std::vector container.
@@ -1063,6 +1083,70 @@ namespace CNTK
     : Variable(shape, VariableKind::Parameter, AsDataType<ElemType>(), MakeSharedObject<NDArrayView>(initValue, shape, device), true, {}, name)
 {}

+///
+/// Create a Parameter initialized with random values drawn from a Uniform distribution in the range [-0.05, 0.05]
+///
+static Parameter Uniform(const NDShape& shape, DataType type = DataType::Float, unsigned long seed = 1, const DeviceDescriptor& device = DeviceDescriptor::UseDefaultDevice(), const std::wstring& name = L"")
+{
+    return UniformInitParameter(shape, type, 1.0/20, seed, device, name);
+}
+
+///
+/// Create a Parameter initialized with random values from a Uniform distribution in the range [-sqrt(6 / fanIn), sqrt(6 / fanIn)]
+///
+static Parameter HeUniform(const NDShape& shape, DataType type = DataType::Float, unsigned long seed = 1, size_t fanOutRank = 1, const DeviceDescriptor& device = DeviceDescriptor::UseDefaultDevice(), const std::wstring& name = L"")
+{
+    NDShape fanInShape = shape.SubShape(fanOutRank);
+    return UniformInitParameter(shape, type, std::sqrt(6.0/fanInShape.TotalSize()), seed, device, name);
+}
+
+///
+/// Create a Parameter initialized with random values from a Uniform distribution in the range [-sqrt(6 / (fanIn + fanOut)), sqrt(6 / (fanIn + fanOut))]
+///
+static Parameter GlorotUniform(const NDShape& shape, DataType type = DataType::Float, unsigned long seed = 1, size_t fanOutRank = 1, const DeviceDescriptor& device = DeviceDescriptor::UseDefaultDevice(), const std::wstring& name = L"")
+{
+    NDShape fanOutShape = shape.SubShape(0, fanOutRank);
+    NDShape fanInShape = shape.SubShape(fanOutRank);
+    return UniformInitParameter(shape, type, std::sqrt(6.0 / (fanInShape.TotalSize() + fanOutShape.TotalSize())), seed, device, name);
+}
+
+///
+/// Create a Parameter initialized with random values from a Uniform distribution in the range [-sqrt(3 / fanIn), sqrt(3 / fanIn)]
+///
+static Parameter Xavier(const NDShape& shape, DataType type = DataType::Float, unsigned long seed = 1, size_t fanOutRank = 1, const DeviceDescriptor& device = DeviceDescriptor::UseDefaultDevice(), const std::wstring& name = L"")
+{
+    NDShape fanInShape = shape.SubShape(fanOutRank);
+    return UniformInitParameter(shape, type, std::sqrt(3.0 / fanInShape.TotalSize()), seed, device, name);
+}
+
+///
+/// Create a Parameter initialized with random values drawn from a Gaussian distribution with [mean = 0, stdDev = sqrt(0.04 / fanIn)]
+///
+static Parameter Gaussian(const NDShape& shape, DataType type = DataType::Float, unsigned long seed = 1, size_t fanOutRank = 1, const DeviceDescriptor& device = DeviceDescriptor::UseDefaultDevice(), const std::wstring& name = L"")
+{
+    NDShape fanInShape = shape.SubShape(fanOutRank);
+    return NormalInitParameter(shape, type, std::sqrt(0.04 / fanInShape.TotalSize()), seed, device, name);
+}
+
+///
+/// Create a Parameter initialized with random values from a Gaussian distribution with [mean = 0, stdDev = sqrt(2 / fanIn)]
+///
+static Parameter HeNormal(const NDShape& shape, DataType type = DataType::Float, unsigned long seed = 1, size_t fanOutRank = 1, const DeviceDescriptor& device = DeviceDescriptor::UseDefaultDevice(), const std::wstring& name = L"")
+{
+    NDShape fanInShape = shape.SubShape(fanOutRank);
+    return NormalInitParameter(shape, type, std::sqrt(2.0 / fanInShape.TotalSize()), seed, device, name);
+}
+
+///
+/// Create a Parameter initialized with random values from a Gaussian distribution with [mean = 0, stdDev = sqrt(2 / (fanIn + fanOut))]
+///
+static Parameter GlorotNormal(const NDShape& shape, DataType type = DataType::Float, unsigned long seed = 1, size_t fanOutRank = 1, const DeviceDescriptor& device = DeviceDescriptor::UseDefaultDevice(), const std::wstring& name = L"")
+{
+    NDShape fanOutShape = shape.SubShape(0, fanOutRank);
+    NDShape fanInShape = shape.SubShape(fanOutRank);
+    return NormalInitParameter(shape, type, std::sqrt(2.0 / (fanInShape.TotalSize() + fanOutShape.TotalSize())), seed, device, name);
+}

 ///
 /// DownCast a Variable to a Parameter. Only allowed if the VariableKind is Parameter and throws an exception otherwise.
 ///
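Editorial sketch, not part of the diff: how the new Parameter initialization factories might be used. The shapes, seeds, and variable names below are illustrative assumptions, not values taken from this commit.

#include "CNTKLibrary.h"
using namespace CNTK;

void CreateInitializedParameters()
{
    // Glorot/Xavier-style uniform initialization scaled by fan-in and fan-out of the weight shape
    Parameter embeddingWeights = Parameter::GlorotUniform({ 300, 10000 } /* { embeddingDim, vocabSize }, illustrative */);

    // Gaussian initialization with stdDev = sqrt(0.04 / fanIn); seed fixed for reproducibility
    Parameter projectionWeights = Parameter::Gaussian({ 128, 300 }, DataType::Float, /*seed=*/1);

    // Plain uniform initialization in [-0.05, 0.05]
    Parameter bias = Parameter::Uniform({ 128 });
}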
@@ -1080,6 +1164,12 @@ namespace CNTK
 {
     return Variable::Value();
 }

+private:
+    // Helper methods for Parameter construction
+    CNTK_API static Parameter UniformInitParameter(const NDShape& shape, DataType type, double range, unsigned long seed, const DeviceDescriptor& device, const std::wstring& name);
+    CNTK_API static Parameter NormalInitParameter(const NDShape& shape, DataType type, double stdDev, unsigned long seed, const DeviceDescriptor& device, const std::wstring& name);
 };

 // Implementation note: The Variable type is a value type and not polymorphic in nature.
@@ -1154,8 +1244,8 @@ namespace CNTK
 ///
 /// Construct a Placeholder with the specified NDShape
 ///
-explicit Placeholder(const NDShape& shape, const std::wstring& name = L"")
-    : Variable(shape, VariableKind::Placeholder, DataType::Unknown, nullptr, false, { Axis::DefaultDynamicAxis(), Axis::DefaultBatchAxis() }, name)
+explicit Placeholder(const NDShape& shape, const std::vector<Axis>& dynamicAxes = DefaultInputVariableDynamicAxes)
+    : Variable(shape, VariableKind::Placeholder, DataType::Unknown, nullptr, false, dynamicAxes, L"")
 {}

 ///
@@ -1427,16 +1517,8 @@ namespace CNTK
 if (uniqueOutputs.find(outputVar) != uniqueOutputs.end())
     RuntimeError("Same variable appears multiple times in the outputs vector passed to Function constructor");

-switch (outputVar.Kind())
-{
-case VariableKind::Output:
-    m_outputs.push_back(outputVar);
-    uniqueOutputs.insert(outputVar);
-    break;
-default:
-    InvalidArgument("Function output has invalid VariableKind!");
-    break;
-}
+m_outputs.push_back(outputVar);
+uniqueOutputs.insert(outputVar);
 }
 }
@@ -1454,6 +1536,14 @@ namespace CNTK
 ///
 CNTK_API FunctionPtr Negate(const Variable& operand, const std::wstring& name = L"");

+///
+/// Unary negation operator corresponding to the Negate operation
+///
+inline FunctionPtr operator-(const Variable& operand)
+{
+    return Negate(operand);
+}

 ///
 /// Create an instance of the CNTK built-in elementwise sigmoid operation with the specified input operand.
 ///
@@ -1555,11 +1645,27 @@ namespace CNTK
 ///
 CNTK_API FunctionPtr Plus(const Variable& leftOperand, const Variable& rightOperand, const std::wstring& name = L"");

+///
+/// Binary addition operator corresponding to the Plus operation
+///
+inline FunctionPtr operator+(const Variable& leftOperand, const Variable& rightOperand)
+{
+    return Plus(leftOperand, rightOperand);
+}

 ///
 /// Create an instance of the CNTK built-in elementwise tensor subtraction operation with the specified input operands.
 ///
 CNTK_API FunctionPtr Minus(const Variable& leftOperand, const Variable& rightOperand, const std::wstring& name = L"");

+///
+/// Binary minus operator corresponding to the Minus operation
+///
+inline FunctionPtr operator-(const Variable& leftOperand, const Variable& rightOperand)
+{
+    return Minus(leftOperand, rightOperand);
+}

 ///
 /// Create an instance of the CNTK built-in elementwise multiplication operation on specified tensor input operands.
 ///
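Editorial sketch, not part of the diff: with the new operator overloads and the existing Variable/FunctionPtr conversions, small expression graphs can be written infix. Times and the function/variable names below are assumptions for illustration.

#include "CNTKLibrary.h"
using namespace CNTK;

FunctionPtr AffineWithNegation(const Variable& input, const Parameter& weights, const Parameter& bias)
{
    Variable affine = Times(weights, input);   // the FunctionPtr result is implicitly wrapped back into a Variable
    FunctionPtr shifted = affine + bias;       // operator+ dispatches to Plus
    return -Variable(shifted);                 // unary operator- dispatches to Negate
}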
@@ -1733,6 +1839,11 @@ namespace CNTK
 ///
 CNTK_API FunctionPtr Clip(const Variable& operand, const Variable& min, const Variable& max, const std::wstring& name = L"");

+///
+/// Create an instance of the CNTK built-in elementwise choice operation using a condition tensor for specified tensor operands.
+///
+CNTK_API FunctionPtr ElementSelect(const Variable& condition, const Variable& leftOperand, const Variable& rightOperand, const std::wstring& name = L"");

 ///
 /// Create an instance of the CNTK built-in splice operation to splice together all the specified tensor operands into a single output tensor
 ///
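Editorial sketch, not part of the diff: ElementSelect acts as an elementwise conditional. Here it is paired with the built-in Greater comparison (assumed available in the same header) to clamp values above a ceiling; names are illustrative.

#include "CNTKLibrary.h"
using namespace CNTK;

FunctionPtr ClampAbove(const Variable& x, const Variable& ceiling)
{
    FunctionPtr isAbove = Greater(x, ceiling);   // 1 where x > ceiling, else 0
    return ElementSelect(isAbove, ceiling, x);   // pick ceiling where the condition holds, x otherwise
}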
@@ -1746,6 +1857,21 @@ namespace CNTK
 ///
 CNTK_API FunctionPtr Combine(const std::vector<FunctionPtr>& operands, const std::wstring& name = L"");

+namespace Sequence
+{
+    CNTK_API FunctionPtr IsFirst(const Variable& operand, const std::wstring& name = L"");
+    CNTK_API FunctionPtr IsLast(const Variable& operand, const std::wstring& name = L"");
+
+    CNTK_API FunctionPtr First(const Variable& operand, const std::wstring& name = L"");
+    CNTK_API FunctionPtr Last(const Variable& operand, const std::wstring& name = L"");
+
+    CNTK_API FunctionPtr Where(const Variable& condition, const std::wstring& name = L"");
+    CNTK_API FunctionPtr Gather(const Variable& operand, const Variable& condition, const std::wstring& name = L"");
+    CNTK_API FunctionPtr Scatter(const Variable& operand, const Variable& condition, const std::wstring& name = L"");
+
+    CNTK_API FunctionPtr BroadcastAs(const Variable& operand, const Variable& broadcastAs, const std::wstring& name = L"");
+}

 ///
 /// Load a legacy CNTK v1 format model
 ///
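Editorial sketch, not part of the diff: one way the new Sequence ops could be combined in a sequence-to-sequence setup, taking the final encoder step and broadcasting it along the decoder's dynamic axes. The encoder/decoder variable names are illustrative assumptions.

#include "CNTKLibrary.h"
using namespace CNTK;

FunctionPtr BroadcastThoughtVector(const Variable& encoderOutput, const Variable& decoderInput)
{
    // Last step of the encoder sequence (along its sequence axis)
    FunctionPtr thoughtVector = Sequence::Last(encoderOutput);

    // Repeat that single step along decoderInput's dynamic axes so it can be combined per decoder step
    return Sequence::BroadcastAs(thoughtVector, decoderInput);
}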
@@ -1859,9 +1985,9 @@
 {
     static_assert(std::is_same<T, NDShape>::value ||
                   std::is_same<T, Axis>::value ||
                   std::is_same<T, std::wstring>::value ||
                   std::is_same<T, std::vector<DictionaryValue>>::value ||
                   std::is_same<T, Dictionary>::value ||
                   std::is_same<T, NDArrayView>::value,
                   "Unsupported ValueType");
@@ -2279,35 +2405,45 @@
 /// Create an instance of the CNTK built-in SGD learner.
 ///
 CNTK_API LearnerPtr SGDLearner(const std::unordered_set<Parameter>& parameters,
-                               const LearningRatesPerSample& learningRates);
+                               const LearningRatesPerSample& learningRates,
+                               double clippingThresholdPerSample = std::numeric_limits<double>::infinity(),
+                               bool gradientClippingWithTruncation = true);

 ///
 /// Create an instance of the CNTK built-in Momentum SGD learner.
 ///
 CNTK_API LearnerPtr MomentumSGDLearner(const std::unordered_set<Parameter>& parameters,
                                        const LearningRatesPerSample& learningRates,
-                                       const MomentumsPerSample& momentums);
+                                       const MomentumsPerSample& momentums,
+                                       double clippingThresholdPerSample = std::numeric_limits<double>::infinity(),
+                                       bool gradientClippingWithTruncation = true);

 ///
 /// Create an instance of the CNTK built-in Nesterov's accelerated SGD learner.
 ///
 CNTK_API LearnerPtr NesterovLearner(const std::unordered_set<Parameter>& parameters,
                                     const LearningRatesPerSample& learningRates,
-                                    const MomentumsPerSample& momentums);
+                                    const MomentumsPerSample& momentums,
+                                    double clippingThresholdPerSample = std::numeric_limits<double>::infinity(),
+                                    bool gradientClippingWithTruncation = true);

-///
-/// Create an instance of the CNTK built-in AdaGrad learner.
-///
-CNTK_API LearnerPtr AdaGradLearner(const std::unordered_set<Parameter>& parameters,
-                                   const LearningRatesPerSample& learningRates,
-                                   bool needAveMultiplier = true);

 ///
 /// Create an instance of the CNTK built-in FSAdaGrad (improved AdaGrad) learner.
 ///
 CNTK_API LearnerPtr FSAdaGradLearner(const std::unordered_set<Parameter>& parameters,
                                      const LearningRatesPerSample& learningRates,
-                                     const MomentumsPerSample& momentums);
+                                     const MomentumsPerSample& momentums,
+                                     double clippingThresholdPerSample = std::numeric_limits<double>::infinity(),
+                                     bool gradientClippingWithTruncation = true);

+///
+/// Create an instance of the CNTK built-in AdaGrad learner.
+///
+CNTK_API LearnerPtr AdaGradLearner(const std::unordered_set<Parameter>& parameters,
+                                   const LearningRatesPerSample& learningRates,
+                                   bool needAveMultiplier = true,
+                                   double clippingThresholdPerSample = std::numeric_limits<double>::infinity(),
+                                   bool gradientClippingWithTruncation = true);

 ///
 /// Create an instance of the CNTK built-in RMSProp learner.
@@ -2319,7 +2455,9 @@
 double dec,
 double max,
 double min,
-bool needAveMultiplier = true);
+bool needAveMultiplier = true,
+double clippingThresholdPerSample = std::numeric_limits<double>::infinity(),
+bool gradientClippingWithTruncation = true);

 ///
 /// Trainer is the top-level abstraction responsible for the orchestration of the training of a model
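Editorial sketch, not part of the diff: creating a learner with the new per-sample gradient-clipping arguments. The rate and momentum values are illustrative, and the schedule types are assumed to be constructible from a single per-sample value.

#include "CNTKLibrary.h"
#include <unordered_set>
using namespace CNTK;

LearnerPtr MakeClippedMomentumLearner(const std::unordered_set<Parameter>& parameters)
{
    LearningRatesPerSample learningRates = 0.0005;   // assumption: implicit construction from a double
    MomentumsPerSample momentums = 0.99;             // assumption: implicit construction from a double
    return MomentumSGDLearner(parameters, learningRates, momentums,
                              /*clippingThresholdPerSample=*/ 2.3,
                              /*gradientClippingWithTruncation=*/ true);
}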
@@ -2333,7 +2471,15 @@
 /// Construct a Trainer to train the specified 'model' with the specified 'trainingLoss' Variable as the training criterion
 /// and using the specified set of 'parameterLearners' for updating the model's parameters using computed gradients.
 ///
-CNTK_API Trainer(const FunctionPtr& model, const Variable& trainingLoss, const std::unordered_set<LearnerPtr>& parameterLearners);
+CNTK_API Trainer(const FunctionPtr& model, const FunctionPtr& lossFunction, const std::unordered_set<LearnerPtr>& parameterLearners);

+///
+/// Construct a Trainer to train the specified 'model' with the specified 'trainingLoss' as the training criterion,
+/// the specified 'evaluationFunction' as the criterion for evaluating the trained model's quality, and using the specified set
+/// of 'parameterLearners' for updating the model's parameters using computed gradients.
+///
+// TODO: Add overload for multiple evaluation criterion
+CNTK_API Trainer(const FunctionPtr& model, const FunctionPtr& lossFunction, const FunctionPtr& evaluationFunction, const std::unordered_set<LearnerPtr>& parameterLearners);

 ///
 /// Optimize model parameters using the specified 'arguments' minibatch of training samples.
@@ -2341,26 +2487,41 @@
 ///
 CNTK_API bool TrainMinibatch(const std::unordered_map<Variable, ValuePtr>& arguments, const DeviceDescriptor& computeDevice = DeviceDescriptor::UseDefaultDevice());

+///
+/// Test the model on the specified batch of samples using the evaluation Function specified during construction of the Trainer
+/// Returns the average evaluation criterion value per sample for the tested minibatch of samples
+///
+CNTK_API double TestMinbatch(const std::unordered_map<Variable, ValuePtr>& arguments, const DeviceDescriptor& computeDevice = DeviceDescriptor::UseDefaultDevice());

 ///
 /// Model being trained by 'this' Trainer.
 ///
 FunctionPtr Model() const { return m_model; }

 ///
-/// Variable of the Trainer's model representing the training loss that is used as the optimization
-/// criterion for learning the model's parameters.
+/// Loss function that is used as the optimization criterion for learning the model's parameters.
 ///
-Variable TrainingLossVariable() const { return m_trainingLossVar; }
+FunctionPtr LossFunction() const { return m_lossFunction; }

 ///
-/// Returns the Value of the training loss variable of the model corresponding to the last minibatch trained with
+/// Evaluation Function that is used as the criterion for evaluating the trained model's quality.
 ///
-ValuePtr PreviousMinibatchTrainingLossValue() const { return m_prevMinibatchTrainingLossValue; }
+FunctionPtr EvaluationFunction() const { return m_evaluationFunction; }

 ///
-/// Returns the training loss corresponding to the last minibatch trained as a double
+/// Returns the average training loss per sample for the last minibatch trained.
 ///
-CNTK_API double PreviousMinibatchAverageTrainingLoss() const;
+CNTK_API double PreviousMinibatchLossAverage() const;

+///
+/// Returns the average evaluation criterion value per sample for the last minibatch trained.
+///
+CNTK_API double PreviousMinibatchEvaluationAverage() const;

+///
+/// Returns the number of samples in the last minibatch trained with
+///
+size_t PreviousMinibatchSampleCount() const { return m_prevMinibatchNumSamples; }

 ///
 /// Learners associated with this Trainer for updating the model's parameters using computed gradients.
@@ -2368,11 +2529,16 @@
 const std::unordered_set<LearnerPtr>& ParameterLearners() const { return m_parameterLearners; }

 private:
+    FunctionPtr m_combinedTrainingFunction;
     FunctionPtr m_model;
-    Variable m_trainingLossVar;
-    ValuePtr m_prevMinibatchTrainingLossValue;
-    size_t m_prevMinibatchNumSamples;
+    FunctionPtr m_lossFunction;
+    FunctionPtr m_evaluationFunction;

     std::unordered_set<LearnerPtr> m_parameterLearners;

+    size_t m_prevMinibatchNumSamples;
+    ValuePtr m_prevMinibatchAggregateTrainingLossValue;
+    ValuePtr m_prevMinibatchAggregateEvalCriterionValue;
 };
/// ///
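Editorial sketch, not part of the diff: driving one training step through the reworked Trainer, using the new (model, loss, evaluation) constructor and the per-sample average accessors. All function and variable names below are illustrative assumptions.

#include "CNTKLibrary.h"
#include <cstdio>
#include <unordered_map>
#include <unordered_set>
using namespace CNTK;

void RunOneMinibatch(const FunctionPtr& model, const FunctionPtr& loss, const FunctionPtr& evalCriterion,
                     const std::unordered_set<LearnerPtr>& learners,
                     const std::unordered_map<Variable, ValuePtr>& minibatchArguments)
{
    Trainer trainer(model, loss, evalCriterion, learners);
    trainer.TrainMinibatch(minibatchArguments, DeviceDescriptor::UseDefaultDevice());

    // The Trainer now tracks aggregate loss/eval values and the sample count per minibatch
    std::fprintf(stderr, "samples=%zu avgLoss=%.5f avgEval=%.5f\n",
                 trainer.PreviousMinibatchSampleCount(),
                 trainer.PreviousMinibatchLossAverage(),
                 trainer.PreviousMinibatchEvaluationAverage());
}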

View file

@@ -182,11 +182,17 @@
 namespace Internal
 {
+    // Create a new Function instance which just passes through specified list of 'operands'.
+    CNTK_API FunctionPtr Combine(const std::vector<Variable>& operands, const std::wstring& name = L"");
+
+    CNTK_API FunctionPtr IsWithin(const Variable& operand, int offset, const std::wstring& name = L"");
     CNTK_API FunctionPtr PackedIndex(const Variable& operand, const Variable& index, const std::wstring& name = L"");
     CNTK_API FunctionPtr GatherPacked(const Variable& operand, const Variable& packedIndex, const std::wstring& name = L"");
-    CNTK_API FunctionPtr IsWithin(const Variable& operand, int offset, const std::wstring& name = L"");
+    CNTK_API FunctionPtr ScatterPacked(const Variable& operand, const Variable& packedIndex, const Variable& condition, const std::wstring& name = L"");
+    CNTK_API FunctionPtr ZeroesLike(const Variable& operand);
     CNTK_API FunctionPtr Where(const Variable& condition, const std::vector<Axis>& newDynamicAxes, const std::wstring& name = L"");
     CNTK_API FunctionPtr Gather(const Variable& operand, const Variable& condition, const std::vector<Axis>& newDynamicAxes, const std::wstring& name = L"");
+    CNTK_API FunctionPtr Scatter(const Variable& operand, const Variable& condition, const std::vector<Axis>& newDynamicAxes, const std::wstring& name = L"");
     CNTK_API FunctionPtr Slice(const Variable& operand, const Axis& axis, int beginIndex, int endIndex, const std::wstring& name = L"");
     CNTK_API FunctionPtr ReduceElements(const Variable& operand, const std::wstring& reductionOpName, const Axis& axis, const std::wstring& name = L"");
 }

View file

@@ -54,12 +54,11 @@
 {
 if (node->Is<InputValueBase<ElementType>>())
 {
-    auto inputNode = node->As<InputValueBase<ElementType>>();
     bool isSparse = node->Is<SparseInputValue<ElementType>>();
     if (node->HasMBLayout())
     {
         // TODO: Currently only default dynamic axis is supported
-        auto inputNodeInternalDynamicAxisName = inputNode->GetRequestedDynamicAxis();
+        auto inputNodeInternalDynamicAxisName = node->GetMBLayout()->GetAxisName();
         std::vector<Axis> inputVarDynamicAxes = DynamicAxesFromInternalDynamicAxisName(inputNodeInternalDynamicAxisName);
         var = Variable(varShape, isSparse, AsDataType<ElementType>(), node->GetLearningRateMultiplier() != 0, node->GetName(), inputVarDynamicAxes);
@@ -74,7 +73,7 @@
 {
 bool isConstant = (node->GetLearningRateMultiplier() == 0);
 auto& matrix = node->As<ComputationNode<ElementType>>()->Value();
-auto tensorView = new TensorView<ElementType>(std::make_shared<Matrix<ElementType>>(matrix.AsReference()), node->GetSampleLayout());
+auto tensorView = new TensorView<ElementType>(std::make_shared<Matrix<ElementType>>(matrix.AsReference()), AsTensorViewShape(node->GetSampleLayout()));
 NDArrayViewPtr parameterValue = MakeSharedObject<NDArrayView>(AsDataType<ElementType>(), AsDeviceDescriptor(matrix.GetDeviceId()), AsStorageFormat(matrix.GetFormat()), varShape, false, tensorView);
 if (isConstant)
     var = Constant(parameterValue, node->GetName());
@@ -87,7 +86,14 @@
 else
 {
     // This is a non-leaf node and maps to a primitive Function
-    auto placeholderVar = Placeholder(varShape);
+    std::vector<Axis> varDynamicAxes;
+    if (node->HasMBLayout())
+    {
+        auto nodeInternalDynamicAxisName = node->GetMBLayout()->GetAxisName();
+        varDynamicAxes = DynamicAxesFromInternalDynamicAxisName(nodeInternalDynamicAxisName);
+    }
+
+    auto placeholderVar = Placeholder(varShape, varDynamicAxes);
     nodeToVariableMap[node] = placeholderVar;

     std::vector<Variable> inputVars(node->GetNumInputs());
@@ -134,14 +140,9 @@
 }
 else if (node->OperationName() == OperationNameOf(WhereNode))
 {
-    auto whereNode = node->As<WhereNode<ElementType>>();
-    auto internalDynamicAxisName = whereNode->DynamicAxisName();
+    auto internalDynamicAxisName = node->GetMBLayout()->GetAxisName();
     std::vector<Axis> dynamicAxes = DynamicAxesFromInternalDynamicAxisName(internalDynamicAxisName);
-    std::vector<std::wstring> dynamicAxesNames;
-    for (auto axis : dynamicAxes)
-        dynamicAxesNames.push_back(axis.Name());
-    primitiveFunctionConfigParameters[PrimitiveFunction::AttributeNameNewDynamicAxes] = AsDictionaryValueVector(dynamicAxesNames);
+    primitiveFunctionConfigParameters[PrimitiveFunction::AttributeNameNewDynamicAxes] = AsDictionaryValueVector(dynamicAxes);

     opType = PrimitiveOpType::Where;
 }
@@ -196,6 +197,13 @@
 std::swap(inputVars[0], inputVars[1]);
 opType = PrimitiveOpType::GatherPacked;
 }
+else if (node->OperationName() == OperationNameOf(ScatterPackedNode))
+{
+    // The internal scatter node has layout as the first input and the actual source as the last operand
+    // which is different from the corresponding V2 Function's ordering of the inputs
+    std::swap(inputVars[0], inputVars[2]);
+    opType = PrimitiveOpType::ScatterPacked;
+}
 else if (node->OperationName() == OperationNameOf(TimesNode))
 {
     primitiveFunctionConfigParameters[PrimitiveFunction::AttributeNameOutputRank] = (size_t)node->As<TimesNode<ElementType>>()->OutputRank();
@@ -296,6 +304,8 @@
 opType = PrimitiveOpType::Clip;
 }
+else if (node->OperationName() == OperationNameOf(IfNode))
+    opType = PrimitiveOpType::Select;
 else if (node->OperationName() == OperationNameOf(RowStackNode))
 {
     // Internal CNTK SliceNode uses 1 based axis indices instead of 0 based

View file

@@ -32,6 +32,8 @@
 /*static*/ const std::wstring Axis::StaticAxisNamePrefix = L"staticAxis_";

+/*static*/ std::unordered_set<std::wstring> Axis::s_allKnownDynamicAxisNames;

 /*static*/ const Axis& Axis::DefaultDynamicAxis()
 {
     static const Axis s_defaultDynamicAxis(L"defaultDynamicAxis");
@@ -43,4 +45,22 @@
     static const Axis s_defaultBatchAxis(L"defaultBatchAxis", false);
     return s_defaultBatchAxis;
 }

+/*static*/ Axis Axis::NewUniqueDynamicAxis(const std::wstring& axisNamePrefix, bool isOrderedDynamicAxis /*= true*/)
+{
+    if (s_allKnownDynamicAxisNames.find(axisNamePrefix) == s_allKnownDynamicAxisNames.end())
+        return Axis(axisNamePrefix, isOrderedDynamicAxis);
+
+    for (size_t i = 1;; i++)
+    {
+        auto newDynamicAxisName = axisNamePrefix + std::to_wstring(i);
+        if (s_allKnownDynamicAxisNames.find(newDynamicAxisName) == s_allKnownDynamicAxisNames.end())
+            return Axis(newDynamicAxisName, isOrderedDynamicAxis);
+    }
+}
+
+void Axis::RegisterAxisName(const std::wstring& axisName)
+{
+    s_allKnownDynamicAxisNames.insert(axisName);
+}
 }
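Editorial sketch, not part of the diff: what the new NewUniqueDynamicAxis/RegisterAxisName machinery provides, e.g. for giving encoder and decoder sequences distinct axes in a Seq2Seq setup. The axis name prefixes are illustrative.

#include "CNTKLibrary.h"
using namespace CNTK;

void CreateSeq2SeqAxes()
{
    Axis sourceAxis = Axis::NewUniqueDynamicAxis(L"sourceAxis");
    Axis targetAxis = Axis::NewUniqueDynamicAxis(L"targetAxis");

    // Requesting the same prefix again yields a distinct name (e.g. L"sourceAxis1"),
    // because every dynamic Axis constructed by name is recorded in s_allKnownDynamicAxisNames.
    Axis anotherSourceAxis = Axis::NewUniqueDynamicAxis(L"sourceAxis");
}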

View file

@@ -128,11 +128,11 @@
 // We currently require that the inputs' dynamic axes if any match
 std::vector<Axis> outputDynamicAxes;
 if (op == PrimitiveOpType::Where)
-    ;
+    outputDynamicAxes = AsVector<Axis>(functionConfig[PrimitiveFunction::AttributeNameNewDynamicAxes].Value<std::vector<DictionaryValue>>());
+else if (op == PrimitiveOpType::ScatterPacked)
+    outputDynamicAxes = inputs[2].DynamicAxes();
 else if ((op == PrimitiveOpType::PackedIndex) || (op == PrimitiveOpType::GatherPacked))
-{
     outputDynamicAxes = inputs[1].DynamicAxes();
-}
 else
 {
     outputDynamicAxes = inputs[0].DynamicAxes();
@@ -178,7 +178,7 @@
 if (!axis1.IsStaticAxis() || !axis2.IsStaticAxis())
     LogicError("TransposeAxes operation currently does not support transposing dynamic axes");

-auto transposedTensorShape = AsTensorShape(inputs[0].Shape(), true);
+auto transposedTensorShape = AsTensorShape(inputs[0].Shape());
 transposedTensorShape.SwapDimsInPlace(axis1.StaticAxisIndex(), axis2.StaticAxisIndex());
 outputs.push_back(Variable(AsNDShape(transposedTensorShape), outputDataType, owner, outputDynamicAxes));
 break;
@@ -186,12 +186,7 @@
 case PrimitiveOpType::Where:
 {
     assert(inputs.size() == 1);
-    std::vector<Axis> newDynamicAxes;
-    auto newDynamicAxesNames = AsBasicElementTypeVector<std::wstring>(functionConfig[PrimitiveFunction::AttributeNameNewDynamicAxes].Value<std::vector<DictionaryValue>>());
-    for (auto axisName : newDynamicAxesNames)
-        newDynamicAxes.push_back(Axis(axisName));
-    outputs.push_back(Variable(UnaryElementwiseOpOutputShape(inputs[0].Shape()), outputDataType, owner, newDynamicAxes));
+    outputs.push_back(Variable(UnaryElementwiseOpOutputShape(inputs[0].Shape()), outputDataType, owner, outputDynamicAxes));
     break;
 }
 case PrimitiveOpType::Slice:
@@ -216,7 +211,7 @@
 realEndIndex,
 inputs[0].Shape().AsString().c_str());

-auto outputTensorShape = AsTensorShape(inputs[0].Shape(), true);
+auto outputTensorShape = AsTensorShape(inputs[0].Shape());

 // propagate as much as we can
 if ((axis.StaticAxisIndex() < outputTensorShape.GetRank()) && (0 <= realBeginIndex) && (realBeginIndex <= realEndIndex) && (realEndIndex <= sliceAxisDim))
@@ -242,7 +237,7 @@
 auto strides = functionConfig[PrimitiveFunction::AttributeNameStrides].Value<NDShape>();
 auto lowerPad = functionConfig[PrimitiveFunction::AttributeNameLowerPad].Value<NDShape>();
 auto upperPad = functionConfig[PrimitiveFunction::AttributeNameUpperPad].Value<NDShape>();
-auto autoPadding = AsBasicElementTypeVector<bool>(functionConfig[PrimitiveFunction::AttributeNameAutoPadding].Value<std::vector<DictionaryValue>>());
+auto autoPadding = AsVector<bool>(functionConfig[PrimitiveFunction::AttributeNameAutoPadding].Value<std::vector<DictionaryValue>>());
 outputs.push_back(Variable(ConvolutionOpOutputShape(inputs[0].Shape(), poolingWindowsShape, { 1 }, strides, { true }, autoPadding, lowerPad, upperPad, false), outputDataType, owner, outputDynamicAxes));
 break;
 }
@@ -291,8 +286,8 @@
 auto strides = functionConfig[PrimitiveFunction::AttributeNameStrides].Value<NDShape>();
 auto lowerPad = functionConfig[PrimitiveFunction::AttributeNameLowerPad].Value<NDShape>();
 auto upperPad = functionConfig[PrimitiveFunction::AttributeNameUpperPad].Value<NDShape>();
-auto sharing = AsBasicElementTypeVector<bool>(functionConfig[PrimitiveFunction::AttributeNameSharing].Value<std::vector<DictionaryValue>>());
-auto autoPadding = AsBasicElementTypeVector<bool>(functionConfig[PrimitiveFunction::AttributeNameAutoPadding].Value<std::vector<DictionaryValue>>());
+auto sharing = AsVector<bool>(functionConfig[PrimitiveFunction::AttributeNameSharing].Value<std::vector<DictionaryValue>>());
+auto autoPadding = AsVector<bool>(functionConfig[PrimitiveFunction::AttributeNameAutoPadding].Value<std::vector<DictionaryValue>>());
 bool transpose = functionConfig[PrimitiveFunction::AttributeNameTranspose].Value<bool>();
 if (inputs[0].Shape().NumAxes() < inputs[1].Shape().NumAxes())
     InvalidArgument("The convolution map should have at least as many axes as the shape of the input it operates on!");
@@ -333,10 +328,6 @@
 if (!initialStateVar.IsConstant() || (initialStateVar.Shape().NumAxes() > 0))
     LogicError("Currently PastValue/FutureValue Function only supports scalar initial state");

-// TODO: We currently only support input operand with 1 static axis for PastValue/FutureValue
-if (inputOperandVar.Shape().NumAxes() > 1)
-    LogicError("Currently PastValue/FutureValue Function only supports input operand with <= 1 static axis");

 // TODO: We currently only support input operand with 1 dynamic axis for PastValue/FutureValue
 if (inputOperandVar.DynamicAxes().size() != 2)
     LogicError("Currently PastValue/FutureValue Function only supports input operand with 2 dynamic axes (1 sequence-axis and 1 batch-axis)");
@@ -382,10 +373,25 @@
 outputs.push_back(Variable(outputShape, outputDataType, owner, outputDynamicAxes));
 break;
 }
+case PrimitiveOpType::ScatterPacked:
+{
+    if (inputs[0].DynamicAxes().empty() || inputs[1].DynamicAxes().empty() || inputs[2].DynamicAxes().empty())
+        InvalidArgument("ScatterPacked requires all its operands to have dynamic axes");
+
+    if (inputs[1].Shape().NumAxes() != 1)
+        InvalidArgument("ScatterPacked requires the packedIndex operand to be a scalar sequence");
+
+    outputs.push_back(Variable(inputs[0].Shape(), outputDataType, owner, outputDynamicAxes));
+    break;
+}
 case PrimitiveOpType::Clip:
     assert(inputs.size() == 3);
     outputs.push_back(Variable(UnaryElementwiseOpOutputShape(inputs[0].Shape()), outputDataType, owner, outputDynamicAxes));
     break;
+case PrimitiveOpType::Select:
+    assert(inputs.size() == 3);
+    outputs.push_back(Variable(NaryElementwiseOpOutputShape(op, { inputs[0].Shape(), inputs[1].Shape(), inputs[2].Shape() }), outputDataType, owner, outputDynamicAxes));
+    break;
 case PrimitiveOpType::Splice:
 {
     assert(inputs.size() >= 2);
@@ -405,8 +411,8 @@
 // Note: The no sequence axis corresponds to a special case where there is no sequence axis (i.e. has been reduced over)
 // and the special name is used to identify this when loading back a model saved in CNTK v1 format. This will not really be needed
 // when the new CNTK v2 model serialization format is ready.
-/*static*/ const std::wstring CompositeFunction::InternalDefaultDynamicAxisName = L"";
+/*static*/ const std::wstring CompositeFunction::InternalDefaultDynamicAxisName = L"*";
-/*static*/ const std::wstring CompositeFunction::InternalNoSequenceAxisName = L"noSequenceAxis";
+/*static*/ const std::wstring CompositeFunction::InternalNoSequenceAxisName = L"__noSequenceAxis";

 // Replace any PlaceHolder Variables in the graph of Functions underlying 'this' CompositeFunction. All PlaceHolder variables
 // should have been replaced before performing any Forward compute of 'this' Function.
@@ -486,7 +492,7 @@
 // Construct the dynamic axis name to be used internally for the CNTK InputNodes
 std::wstring internalDynamicAxisName = InternalDynamicAxisNameFromDynamicAxes(dynamicAxes);

-if (!internalDynamicAxisName.empty())
+if (!internalDynamicAxisName.empty() && !network->NodeNameExists(internalDynamicAxisName))
     network->AddNodeToNetAndAttachInputs(New<DynamicAxisNode<ElementType>>(network->GetDeviceId(), internalDynamicAxisName), {});

 if (IsSparseInput(variable))
@@ -524,60 +530,60 @@
 auto functionName = primitiveFunction->Name();
 auto& functionConfig = primitiveFunction->FunctionConfig();
 auto functionInputs = primitiveFunction->Inputs();
 PrimitiveOpType op = primitiveFunction->OpType();

 switch (op)
 {
 case PrimitiveOpType::Negate:
     computationNodePtr = builder.Negate(inputNodes[0], functionName);
     break;
 case PrimitiveOpType::Sigmoid:
     computationNodePtr = builder.Sigmoid(inputNodes[0], functionName);
     break;
 case PrimitiveOpType::Tanh:
     computationNodePtr = builder.Tanh(inputNodes[0], functionName);
     break;
 case PrimitiveOpType::ReLU:
     computationNodePtr = builder.RectifiedLinear(inputNodes[0], functionName);
     break;
 case PrimitiveOpType::Exp:
     computationNodePtr = builder.Exp(inputNodes[0], functionName);
     break;
 case PrimitiveOpType::Log:
     computationNodePtr = builder.Log(inputNodes[0], functionName);
     break;
 case PrimitiveOpType::Sqrt:
     computationNodePtr = builder.Sqrt(inputNodes[0], functionName);
     break;
 case PrimitiveOpType::Floor:
     computationNodePtr = builder.Floor(inputNodes[0], functionName);
     break;
 case PrimitiveOpType::Abs:
     computationNodePtr = builder.Abs(inputNodes[0], functionName);
     break;
 case PrimitiveOpType::Reciprocal:
     computationNodePtr = builder.Reciprocal(inputNodes[0], functionName);
     break;
 case PrimitiveOpType::Softmax:
     computationNodePtr = builder.Softmax(inputNodes[0], functionName);
     break;
 case PrimitiveOpType::Hardmax:
     computationNodePtr = builder.Hardmax(inputNodes[0], functionName);
     break;
 case PrimitiveOpType::TransposeAxes:
 {
     auto axis1 = functionConfig[PrimitiveFunction::AttributeNameAxis1].Value<Axis>();
     auto axis2 = functionConfig[PrimitiveFunction::AttributeNameAxis2].Value<Axis>();

     // The axis ids passed to the internal CNTK TransposeDimensionsNode are 1 based instead of 0 based
     computationNodePtr = New<TransposeDimensionsNode<ElementType>>(network->GetDeviceId(), functionName, AsCNTKInternalAxisIdx(axis1), AsCNTKInternalAxisIdx(axis2));
     network->AddNodeToNetAndAttachInputs(computationNodePtr, { inputNodes[0] });
     break;
 }
 case PrimitiveOpType::Where:
 {
     auto dynamicAxes = variable.DynamicAxes();
     auto internalCNTKWhereNodeDynamicAxisName = InternalDynamicAxisNameFromDynamicAxes(dynamicAxes);
     computationNodePtr = New<WhereNode<ElementType>>(network->GetDeviceId(), functionName, internalCNTKWhereNodeDynamicAxisName);
     network->AddNodeToNetAndAttachInputs(computationNodePtr, { inputNodes[0] });
@@ -604,141 +610,144 @@
 case PrimitiveOpType::Reshape:
 {
     auto newShape = functionConfig[PrimitiveFunction::AttributeNameNewShape].Value<NDShape>();
-    computationNodePtr = builder.Reshape(inputNodes[0], AsTensorShape(newShape, true /*preserveRank*/), functionName);
+    computationNodePtr = builder.Reshape(inputNodes[0], AsTensorShape(newShape), functionName);
     break;
 }
 case PrimitiveOpType::Pooling:
 {
     PoolingType poolingType = (PoolingType)(functionConfig[PrimitiveFunction::AttributeNamePoolingType].Value<size_t>());
     auto poolingWindowsShape = functionConfig[PrimitiveFunction::AttributeNamePoolingWindowShape].Value<NDShape>();
     auto strides = functionConfig[PrimitiveFunction::AttributeNameStrides].Value<NDShape>();
     auto lowerPad = functionConfig[PrimitiveFunction::AttributeNameLowerPad].Value<NDShape>();
     auto upperPad = functionConfig[PrimitiveFunction::AttributeNameUpperPad].Value<NDShape>();
-    auto autoPadding = AsBasicElementTypeVector<bool>(functionConfig[PrimitiveFunction::AttributeNameAutoPadding].Value<std::vector<DictionaryValue>>());
-    computationNodePtr = builder.Pooling(inputNodes[0], AsCNTKPoolKind(poolingType), AsTensorShape(poolingWindowsShape, true), AsTensorShape(strides, true), autoPadding, AsTensorShape(lowerPad, true), AsTensorShape(upperPad, true), ImageLayoutKind::CHW, functionName);
+    auto autoPadding = AsVector<bool>(functionConfig[PrimitiveFunction::AttributeNameAutoPadding].Value<std::vector<DictionaryValue>>());
+    computationNodePtr = builder.Pooling(inputNodes[0], AsCNTKPoolKind(poolingType), AsTensorShape(poolingWindowsShape), AsTensorShape(strides), autoPadding, AsTensorShape(lowerPad), AsTensorShape(upperPad), ImageLayoutKind::CHW, functionName);
     break;
 }
 case PrimitiveOpType::SumAll:
     computationNodePtr = builder.Sum(inputNodes[0], functionName);
     break;
 case PrimitiveOpType::Plus:
     computationNodePtr = builder.Plus(inputNodes[0], inputNodes[1], functionName);
     break;
 case PrimitiveOpType::Minus:
     computationNodePtr = builder.Minus(inputNodes[0], inputNodes[1], functionName);
     break;
 case PrimitiveOpType::ElementTimes:
     computationNodePtr = builder.ElementTimes(inputNodes[0], inputNodes[1], functionName);
     break;
 case PrimitiveOpType::Equal:
     computationNodePtr = builder.Equal(inputNodes[0], inputNodes[1], functionName);
     break;
 case PrimitiveOpType::NotEqual:
     computationNodePtr = builder.NotEqual(inputNodes[0], inputNodes[1], functionName);
     break;
 case PrimitiveOpType::Less:
     computationNodePtr = builder.Less(inputNodes[0], inputNodes[1], functionName);
     break;
 case PrimitiveOpType::LessEqual:
     computationNodePtr = builder.LessEqual(inputNodes[0], inputNodes[1], functionName);
     break;
 case PrimitiveOpType::Greater:
     computationNodePtr = builder.Greater(inputNodes[0], inputNodes[1], functionName);
     break;
 case PrimitiveOpType::GreaterEqual:
     computationNodePtr = builder.GreaterEqual(inputNodes[0], inputNodes[1], functionName);
     break;
 case PrimitiveOpType::Times:
 {
     size_t outputRank = functionConfig[PrimitiveFunction::AttributeNameOutputRank].Value<size_t>();
     computationNodePtr = builder.Times(inputNodes[0], inputNodes[1], outputRank, functionName);
     break;
 }
 case PrimitiveOpType::TransposeTimes:
 {
     size_t outputRank = functionConfig[PrimitiveFunction::AttributeNameOutputRank].Value<size_t>();
     computationNodePtr = network->AddNodeToNetAndAttachInputs(New<TransposeTimesNode<ElementType>>(network->GetDeviceId(), functionName, outputRank), { inputNodes[0], inputNodes[1] });
     break;
 }
 case PrimitiveOpType::Convolution:
 {
     NDShape outputMapCount, kernelShape;
     std::tie(outputMapCount, kernelShape) = GetConvolutionOutputMapCountAndKernelShape(functionInputs[0].Shape(), functionInputs[1].Shape());
     auto strides = functionConfig[PrimitiveFunction::AttributeNameStrides].Value<NDShape>();
     auto lowerPad = functionConfig[PrimitiveFunction::AttributeNameLowerPad].Value<NDShape>();
     auto upperPad = functionConfig[PrimitiveFunction::AttributeNameUpperPad].Value<NDShape>();
-    auto sharing = AsBasicElementTypeVector<bool>(functionConfig[PrimitiveFunction::AttributeNameSharing].Value<std::vector<DictionaryValue>>());
-    auto autoPadding = AsBasicElementTypeVector<bool>(functionConfig[PrimitiveFunction::AttributeNameAutoPadding].Value<std::vector<DictionaryValue>>());
+    auto sharing = AsVector<bool>(functionConfig[PrimitiveFunction::AttributeNameSharing].Value<std::vector<DictionaryValue>>());
+    auto autoPadding = AsVector<bool>(functionConfig[PrimitiveFunction::AttributeNameAutoPadding].Value<std::vector<DictionaryValue>>());
     auto transpose = functionConfig[PrimitiveFunction::AttributeNameTranspose].Value<bool>();
     auto maxTempMemSizeInSamples = functionConfig[PrimitiveFunction::AttributeNameMaxTempMemSizeInSamples].Value<size_t>();
-    computationNodePtr = builder.Convolution(inputNodes[0], inputNodes[1], AsTensorShape(kernelShape, true), AsTensorShape(outputMapCount, true), AsTensorShape(strides, true), sharing, autoPadding, AsTensorShape(lowerPad, true), AsTensorShape(upperPad, true), transpose, ImageLayoutKind::CHW, maxTempMemSizeInSamples, functionName);
+    computationNodePtr = builder.Convolution(inputNodes[0], inputNodes[1], AsTensorShape(kernelShape), AsTensorShape(outputMapCount), AsTensorShape(strides), sharing, autoPadding, AsTensorShape(lowerPad), AsTensorShape(upperPad), transpose, ImageLayoutKind::CHW, maxTempMemSizeInSamples, functionName);
     break;
 }
case PrimitiveOpType::SquaredError: case PrimitiveOpType::SquaredError:
computationNodePtr = builder.SquareError(inputNodes[0], inputNodes[1], functionName); computationNodePtr = builder.SquareError(inputNodes[0], inputNodes[1], functionName);
break; break;
case PrimitiveOpType::CrossEntropyWithSoftmax: case PrimitiveOpType::CrossEntropyWithSoftmax:
computationNodePtr = builder.CrossEntropyWithSoftmax(inputNodes[1], inputNodes[0], functionName); computationNodePtr = builder.CrossEntropyWithSoftmax(inputNodes[1], inputNodes[0], functionName);
break; break;
case PrimitiveOpType::ClassificationError: case PrimitiveOpType::ClassificationError:
computationNodePtr = builder.ClassificationError(inputNodes[1], inputNodes[0], functionName); computationNodePtr = builder.ClassificationError(inputNodes[1], inputNodes[0], functionName);
break; break;
case PrimitiveOpType::PastValue: case PrimitiveOpType::PastValue:
case PrimitiveOpType::FutureValue: case PrimitiveOpType::FutureValue:
{ {
Variable inputOperandVar = functionInputs[0]; Variable inputOperandVar = functionInputs[0];
Variable initialStateVar = functionInputs[1]; Variable initialStateVar = functionInputs[1];
// Get the initial state of the PastValue/FutureValue operation // Get the initial state of the PastValue/FutureValue operation
ElementType initStateValue; ElementType initStateValue;
NDArrayView tempView({}, &initStateValue, 1, DeviceDescriptor::CPUDevice()); NDArrayView tempView({}, &initStateValue, 1, DeviceDescriptor::CPUDevice());
tempView.CopyFrom(*Constant(initialStateVar).Value()); tempView.CopyFrom(*Constant(initialStateVar).Value());
size_t offset = primitiveFunction->FunctionConfig()[PrimitiveFunction::AttributeNameOffset].Value<size_t>(); size_t offset = primitiveFunction->FunctionConfig()[PrimitiveFunction::AttributeNameOffset].Value<size_t>();
if (op == PrimitiveOpType::PastValue) if (op == PrimitiveOpType::PastValue)
computationNodePtr = builder.PastValue(inputNodes[0], (float)initStateValue, inputOperandVar.Shape().TotalSize(), offset, functionName); computationNodePtr = builder.PastValue(inputNodes[0], (float)initStateValue, inputOperandVar.Shape().TotalSize(), offset, functionName);
else else
computationNodePtr = builder.FutureValue(inputNodes[0], (float)initStateValue, inputOperandVar.Shape().TotalSize(), offset, functionName); computationNodePtr = builder.FutureValue(inputNodes[0], (float)initStateValue, inputOperandVar.Shape().TotalSize(), offset, functionName);
break; break;
} }
case PrimitiveOpType::ReduceElements: case PrimitiveOpType::ReduceElements:
{ {
auto reductionAxis = functionConfig[PrimitiveFunction::AttributeNameAxis].Value<Axis>(); auto reductionAxis = functionConfig[PrimitiveFunction::AttributeNameAxis].Value<Axis>();
auto reductionOpName = functionConfig[PrimitiveFunction::AttributeNameReductionOpName].Value<std::wstring>(); auto reductionOpName = functionConfig[PrimitiveFunction::AttributeNameReductionOpName].Value<std::wstring>();
computationNodePtr = network->AddNodeToNetAndAttachInputs(New<ReduceElementsNode<ElementType>>(network->GetDeviceId(), functionName, reductionOpName, AsCNTKInternalAxisIdx(reductionAxis)), { inputNodes[0] }); computationNodePtr = network->AddNodeToNetAndAttachInputs(New<ReduceElementsNode<ElementType>>(network->GetDeviceId(), functionName, reductionOpName, AsCNTKInternalAxisIdx(reductionAxis)), { inputNodes[0] });
break; break;
} }
case PrimitiveOpType::BatchNormalization: case PrimitiveOpType::BatchNormalization:
{ {
auto spatial = functionConfig[PrimitiveFunction::AttributeNameSpatial].Value<bool>(); auto spatial = functionConfig[PrimitiveFunction::AttributeNameSpatial].Value<bool>();
auto normalizationTimeConstant = functionConfig[PrimitiveFunction::AttributeNameNormalizationTimeConstant].Value<double>(); auto normalizationTimeConstant = functionConfig[PrimitiveFunction::AttributeNameNormalizationTimeConstant].Value<double>();
auto blendTimeConstant = functionConfig[PrimitiveFunction::AttributeNameBlendTimeConstant].Value<double>(); auto blendTimeConstant = functionConfig[PrimitiveFunction::AttributeNameBlendTimeConstant].Value<double>();
auto epsilon = functionConfig[PrimitiveFunction::AttributeNameEpsilon].Value<double>(); auto epsilon = functionConfig[PrimitiveFunction::AttributeNameEpsilon].Value<double>();
auto useCuDNNEngine = functionConfig[PrimitiveFunction::AttributeNameUseCuDNNEngine].Value<bool>(); auto useCuDNNEngine = functionConfig[PrimitiveFunction::AttributeNameUseCuDNNEngine].Value<bool>();
computationNodePtr = builder.BatchNormalization(inputNodes[0], inputNodes[1], inputNodes[2], inputNodes[3], inputNodes[4], spatial, normalizationTimeConstant, blendTimeConstant, epsilon, !useCuDNNEngine, ImageLayoutKind::CHW, functionName); computationNodePtr = builder.BatchNormalization(inputNodes[0], inputNodes[1], inputNodes[2], inputNodes[3], inputNodes[4], spatial, normalizationTimeConstant, blendTimeConstant, epsilon, !useCuDNNEngine, ImageLayoutKind::CHW, functionName);
break;
}
case PrimitiveOpType::Combine:
// This operation is just a no-op and is a means to combine multiple functions to create a single Function
// whose outputs are a union of the outputs of the Functions being combined.
computationNodePtr = variableToNodeMap[variable];
break;
case PrimitiveOpType::PackedIndex:
computationNodePtr = New<PackedIndexNode<ElementType>>(network->GetDeviceId(), functionName);
network->AddNodeToNetAndAttachInputs(computationNodePtr, { inputNodes[0], inputNodes[1] });
break;
case PrimitiveOpType::GatherPacked:
computationNodePtr = New<GatherPackedNode<ElementType>>(network->GetDeviceId(), functionName);
network->AddNodeToNetAndAttachInputs(computationNodePtr, { inputNodes[1], inputNodes[0] });
break;
case PrimitiveOpType::Clip:
{
computationNodePtr = builder.Clip(inputNodes[1], inputNodes[2], inputNodes[0], functionName);
break; break;
} }
case PrimitiveOpType::Combine:
// This operation is just a no-op and is a means to combine multiple functions to create a single Function
// whose outputs are a union of the outputs of the Functions being combined.
computationNodePtr = variableToNodeMap[variable];
break;
case PrimitiveOpType::PackedIndex:
computationNodePtr = New<PackedIndexNode<ElementType>>(network->GetDeviceId(), functionName);
network->AddNodeToNetAndAttachInputs(computationNodePtr, { inputNodes[0], inputNodes[1] });
break;
case PrimitiveOpType::GatherPacked:
computationNodePtr = New<GatherPackedNode<ElementType>>(network->GetDeviceId(), functionName);
network->AddNodeToNetAndAttachInputs(computationNodePtr, { inputNodes[1], inputNodes[0] });
break;
case PrimitiveOpType::ScatterPacked:
computationNodePtr = New<ScatterPackedNode<ElementType>>(network->GetDeviceId(), functionName);
network->AddNodeToNetAndAttachInputs(computationNodePtr, { inputNodes[2], inputNodes[1], inputNodes[0] });
break;
case PrimitiveOpType::Clip:
computationNodePtr = builder.Clip(inputNodes[1], inputNodes[2], inputNodes[0], functionName);
break;
case PrimitiveOpType::Select:
computationNodePtr = builder.If(inputNodes[0], inputNodes[1], inputNodes[2], functionName);
break;
case PrimitiveOpType::Splice: case PrimitiveOpType::Splice:
{ {
Axis spliceAxis = functionConfig[PrimitiveFunction::AttributeNameAxis].Value<Axis>(); Axis spliceAxis = functionConfig[PrimitiveFunction::AttributeNameAxis].Value<Axis>();
@ -750,12 +759,12 @@ namespace CNTK
inputNodesBasePtrs.push_back(inputNode); inputNodesBasePtrs.push_back(inputNode);
network->AddNodeToNetAndAttachInputs(computationNodePtr, inputNodesBasePtrs); network->AddNodeToNetAndAttachInputs(computationNodePtr, inputNodesBasePtrs);
break; break;
} }
default: default:
LogicError("Specified op %s not yet supported", PrimitiveOpTypeName(op)); LogicError("Specified op %s not yet supported", PrimitiveOpTypeName(op));
break; break;
} }
return computationNodePtr; return computationNodePtr;
} }
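The PastValue/FutureValue lowering in the switch above is the node that the V2 recurrence pattern ultimately maps to. As a rough sketch (not part of this change; 'x' and the running-sum wiring are hypothetical, and the elementwise Plus op from the same API is assumed), a one-step recurrence is expressed with a Placeholder that is later replaced by the recurrence's own output:

// Sketch only: 'x' is assumed to be a sequence-valued Variable defined elsewhere.
auto placeholder = Placeholder(x.Shape(), x.DynamicAxes());
// y(t) = x(t) + y(t-1); PastValue reads the placeholder one step back and uses a
// scalar 0 as the initial state for the first step of each sequence.
auto accum = Plus(x, PastValue(placeholder, ScalarConstant(x.GetDataType(), 0.0f), 1));
// Close the loop by substituting the function's own output for the placeholder.
auto runningSum = accum->ReplacePlaceholders({ { placeholder, accum } });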
@ -881,7 +890,7 @@ namespace CNTK
auto outputShape = outputVar.Shape(); auto outputShape = outputVar.Shape();
auto computationNodeSampleLayout = computationNodePtr->GetSampleLayout(); auto computationNodeSampleLayout = computationNodePtr->GetSampleLayout();
if (((outputShape.NumAxes() == 0) && (computationNodeSampleLayout[0] != 1)) || if (((outputShape.NumAxes() == 0) && (computationNodeSampleLayout[0] != 1)) ||
((outputShape.NumAxes() != 0) && (computationNodeSampleLayout != AsTensorShape(outputShape)) && (computationNodeSampleLayout != AsTensorShape(outputShape, true)))) ((outputShape.NumAxes() != 0) && (computationNodeSampleLayout != AsTensorViewShape(outputShape)) && (computationNodeSampleLayout != AsTensorShape(outputShape))))
{ {
LogicError("The output Variable shape %s does not match the SampleLayout shape %s of the corresponding ComputationNode in the network", AsString(outputShape).c_str(), ((std::string)computationNodeSampleLayout).c_str()); LogicError("The output Variable shape %s does not match the SampleLayout shape %s of the corresponding ComputationNode in the network", AsString(outputShape).c_str(), ((std::string)computationNodeSampleLayout).c_str());
} }
@ -1026,7 +1035,7 @@ namespace CNTK
if ((layout == nullptr) || (layout->GetNumTimeSteps() == 1) || (layout->GetNumSequences() == 1)) if ((layout == nullptr) || (layout->GetNumTimeSteps() == 1) || (layout->GetNumSequences() == 1))
{ {
// Just create a view over the existing matrix itself // Just create a view over the existing matrix itself
auto tensorView = new TensorView<ElementType>(std::make_shared<Matrix<ElementType>>(matrix.AsReference()), AsTensorShape(valueDataShape)); auto tensorView = new TensorView<ElementType>(std::make_shared<Matrix<ElementType>>(matrix.AsReference()), AsTensorViewShape(valueDataShape));
auto data = MakeSharedObject<NDArrayView>(AsDataType<ElementType>(), AsDeviceDescriptor(matrix.GetDeviceId()), AsStorageFormat(matrix.GetFormat()), valueDataShape, readOnly, tensorView); auto data = MakeSharedObject<NDArrayView>(AsDataType<ElementType>(), AsDeviceDescriptor(matrix.GetDeviceId()), AsStorageFormat(matrix.GetFormat()), valueDataShape, readOnly, tensorView);
return MakeSharedObject<Value>(data); return MakeSharedObject<Value>(data);
} }
@ -1085,7 +1094,7 @@ namespace CNTK
} }
} }
auto tensorView = new TensorView<ElementType>(shuffledMatrixData, AsTensorShape(valueDataShape)); auto tensorView = new TensorView<ElementType>(shuffledMatrixData, AsTensorViewShape(valueDataShape));
auto data = MakeSharedObject<NDArrayView>(AsDataType<ElementType>(), AsDeviceDescriptor(matrix.GetDeviceId()), AsStorageFormat(shuffledMatrixData->GetFormat()), valueDataShape, readOnly, tensorView); auto data = MakeSharedObject<NDArrayView>(AsDataType<ElementType>(), AsDeviceDescriptor(matrix.GetDeviceId()), AsStorageFormat(shuffledMatrixData->GetFormat()), valueDataShape, readOnly, tensorView);
return MakeSharedObject<Value>(data, mask); return MakeSharedObject<Value>(data, mask);
} }
@ -1114,21 +1123,16 @@ namespace CNTK
auto& nodeData = computationNode->As<ComputationNode<ElementType>>()->Value(); auto& nodeData = computationNode->As<ComputationNode<ElementType>>()->Value();
// Switch the node matrix to the right matrix type // Switch the node matrix to the right matrix type
nodeData.SwitchToMatrixType(CNTKMatrixAndMBLayout.first->GetMatrixType(), CNTKMatrixAndMBLayout.first->GetFormat(), false);
nodeData.AssignValuesOf(*CNTKMatrixAndMBLayout.first); nodeData.AssignValuesOf(*CNTKMatrixAndMBLayout.first);
computationNode->GetMBLayout()->CopyFrom(layout); computationNode->GetMBLayout()->CopyFrom(layout);
} }
void CompositeFunction::PopulateNetworkInputs(const std::unordered_map<Variable, ValuePtr>& arguments) void CompositeFunction::PopulateNetworkInputs(const std::unordered_map<Variable, ValuePtr>& arguments)
{ {
auto functionArguments = this->Arguments();
std::vector<ComputationNodeBasePtr> inputNodes; std::vector<ComputationNodeBasePtr> inputNodes;
for (auto argument : functionArguments) for (auto argumentValuePair : arguments)
{ {
// Ensure we have values for all arguments of the function auto argument = argumentValuePair.first;
if (arguments.find(argument) == arguments.end())
InvalidArgument("Value not specified for required Function Argument");
auto argumentComputationNode = m_variableToNodeMap[argument]; auto argumentComputationNode = m_variableToNodeMap[argument];
inputNodes.push_back(argumentComputationNode); inputNodes.push_back(argumentComputationNode);
@ -1300,6 +1304,7 @@ namespace CNTK
list<ComputationNodeBasePtr> dropoutNodes = m_computationNetwork->GetNodesWithType(OperationNameOf(DropoutNode)); list<ComputationNodeBasePtr> dropoutNodes = m_computationNetwork->GetNodesWithType(OperationNameOf(DropoutNode));
for (auto& nodeIter : dropoutNodes) for (auto& nodeIter : dropoutNodes)
nodeIter->SetEvalTimeStampOutdatedWrtAll(); nodeIter->SetEvalTimeStampOutdatedWrtAll();
std::unordered_set<Variable> functionOutputs(this->Outputs().begin(), this->Outputs().end()); std::unordered_set<Variable> functionOutputs(this->Outputs().begin(), this->Outputs().end());
std::vector<ComputationNodeBasePtr> outputsToEvaluate; std::vector<ComputationNodeBasePtr> outputsToEvaluate;
@ -1316,10 +1321,15 @@ namespace CNTK
// The 'outputsToRetainBackwardStateFor' nodes also need to be evaluated if not already specified in 'outputs' // The 'outputsToRetainBackwardStateFor' nodes also need to be evaluated if not already specified in 'outputs'
for (auto rootVarForBackprop : outputsToRetainBackwardStateFor) for (auto rootVarForBackprop : outputsToRetainBackwardStateFor)
{ {
if (functionOutputs.find(rootVarForBackprop) == functionOutputs.end())
InvalidArgument("Requested outputs to retain backward state for is not an Ouptut of the Function");
if (outputs.find(rootVarForBackprop) == outputs.end()) if (outputs.find(rootVarForBackprop) == outputs.end())
outputsToEvaluate.push_back(m_variableToNodeMap[rootVarForBackprop]); outputsToEvaluate.push_back(m_variableToNodeMap[rootVarForBackprop]);
} }
// TODO: Verify that values were supplied for all inputs that requested outputs depend on
ScopedNetworkOperationMode modeGuard(m_computationNetwork, outputsToRetainBackwardStateFor.empty() ? NetworkOperationMode::inferring : NetworkOperationMode::training); ScopedNetworkOperationMode modeGuard(m_computationNetwork, outputsToRetainBackwardStateFor.empty() ? NetworkOperationMode::inferring : NetworkOperationMode::training);
m_computationNetwork->ForwardProp(outputsToEvaluate); m_computationNetwork->ForwardProp(outputsToEvaluate);
@ -1467,14 +1477,19 @@ namespace CNTK
} }
FunctionPtr Slice(const Variable& operand, const Axis& axis, int beginIndex, int endIndex, const std::wstring& name /*= L""*/) FunctionPtr Slice(const Variable& operand, const Axis& axis, int beginIndex, int endIndex, const std::wstring& name /*= L""*/)
{ {
if ((endIndex - beginIndex) <= 0)
InvalidArgument("CNTK::Slice: endIndex (%d) - beginIndex (%d) must be a positive number", endIndex, beginIndex);
if (axis == Axis::DefaultBatchAxis()) if (axis == Axis::DefaultBatchAxis())
LogicError("Slice is currently unsupported along the batch axis"); LogicError("Slice is currently unsupported along the batch axis");
if (axis.IsStaticAxis()) if (axis.IsStaticAxis())
{
if ((endIndex - beginIndex) <= 0)
InvalidArgument("CNTK::Slice: endIndex (%d) - beginIndex (%d) must be a positive number", endIndex, beginIndex);
return Internal::Slice(operand, axis, beginIndex, endIndex, name); return Internal::Slice(operand, axis, beginIndex, endIndex, name);
}
if ((beginIndex == 0) && (endIndex == 0))
return operand;
auto operandAxes = operand.DynamicAxes(); auto operandAxes = operand.DynamicAxes();
auto findAxis = std::find(operandAxes.begin(), operandAxes.end(), axis); auto findAxis = std::find(operandAxes.begin(), operandAxes.end(), axis);
@ -1504,7 +1519,7 @@ namespace CNTK
if (operandAxis == axis) if (operandAxis == axis)
{ {
// If we are selecting just one frame from the dynamic axis, we can remove that axis // If we are selecting just one frame from the dynamic axis, we can remove that axis
if ((endIndex - beginIndex) > 1) if ((endIndex - beginIndex) != 1)
newDynamicAxes.push_back(CompositeFunction::NextAutoGeneratedDynamicAxis()); newDynamicAxes.push_back(CompositeFunction::NextAutoGeneratedDynamicAxis());
} }
else else
@ -1746,6 +1761,12 @@ namespace CNTK
return CompositeFunction::Create(MakeSharedObject<PrimitiveFunction>(PrimitiveOpType::Clip, std::vector<Variable>({ operand, min, max }), Dictionary(), name), name); return CompositeFunction::Create(MakeSharedObject<PrimitiveFunction>(PrimitiveOpType::Clip, std::vector<Variable>({ operand, min, max }), Dictionary(), name), name);
} }
FunctionPtr ElementSelect(const Variable& condition, const Variable& leftOperand, const Variable& rightOperand, const std::wstring& name /*= L""*/)
{
// TODO: If the condition is a scalar constant, we can just pass-through the appropriate operand
return CompositeFunction::Create(MakeSharedObject<PrimitiveFunction>(PrimitiveOpType::Select, std::vector<Variable>({ condition, leftOperand, rightOperand }), Dictionary(), name), name);
}
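A hedged usage sketch of the new ElementSelect op (the variables are hypothetical): elementwise, the result takes the second operand where the condition is non-zero and the third operand elsewhere, mirroring the builder.If lowering added above.

// Sketch: 'mask', 'a' and 'b' are assumed to be Variables of compatible shape.
auto selected = ElementSelect(mask, a, b, L"select");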
FunctionPtr Splice(const std::vector<Variable>& operands, size_t axis, const std::wstring& name /*= L""*/) FunctionPtr Splice(const std::vector<Variable>& operands, size_t axis, const std::wstring& name /*= L""*/)
{ {
auto additionalProperties = Dictionary(); auto additionalProperties = Dictionary();
@ -1753,25 +1774,116 @@ namespace CNTK
return CompositeFunction::Create(MakeSharedObject<PrimitiveFunction>(PrimitiveOpType::Splice, operands, std::move(additionalProperties), name), name); return CompositeFunction::Create(MakeSharedObject<PrimitiveFunction>(PrimitiveOpType::Splice, operands, std::move(additionalProperties), name), name);
} }
FunctionPtr Combine(const std::vector<FunctionPtr>& operands, const std::wstring& name/* = L""*/) FunctionPtr Combine(const std::vector<FunctionPtr>& operands, const std::wstring& name/* = L""*/)
{ {
std::unordered_set<FunctionPtr> uniqueOperands;
std::vector<Variable> inputs; std::vector<Variable> inputs;
for (auto operand : operands) for (auto operand : operands)
{ {
if (uniqueOperands.find(operand) != uniqueOperands.end())
LogicError("All function operands specified to Combine must be unique");
uniqueOperands.insert(operand);
auto currentFunctionOutputs = operand->Outputs(); auto currentFunctionOutputs = operand->Outputs();
std::copy(currentFunctionOutputs.begin(), currentFunctionOutputs.end(), std::back_inserter(inputs)); std::copy(currentFunctionOutputs.begin(), currentFunctionOutputs.end(), std::back_inserter(inputs));
} }
return CompositeFunction::Create(MakeSharedObject<PrimitiveFunction>(PrimitiveOpType::Combine, inputs, Dictionary(), name), name); return Internal::Combine(inputs);
}
namespace Sequence
{
void VerifyIsSequence(const Variable& operand)
{
// The operand must have at least one dynamic axis and its first dynamic axis must be ordered
if (operand.DynamicAxes().empty() || !operand.DynamicAxes()[0].IsOrdered())
InvalidArgument("A sequence function can only be applied on operands with at least one dynamic axis and whose first dynamic axis is ordered");
}
FunctionPtr IsFirst(const Variable& operand, const std::wstring& name /*= L""*/)
{
VerifyIsSequence(operand);
return Internal::IsWithin(operand, 1);
}
FunctionPtr IsLast(const Variable& operand, const std::wstring& name /*= L""*/)
{
VerifyIsSequence(operand);
return Internal::IsWithin(operand, -1);
}
FunctionPtr First(const Variable& operand, const std::wstring& name /*= L""*/)
{
VerifyIsSequence(operand);
return Slice(operand, operand.DynamicAxes()[0], 0, 1);
}
FunctionPtr Last(const Variable& operand, const std::wstring& name /*= L""*/)
{
VerifyIsSequence(operand);
return Slice(operand, operand.DynamicAxes()[0], -1, 0);
}
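For the Seq2Seq test this commit adds, Sequence::Last is the natural way to pull out an encoder's final state (the "thought vector"). A minimal sketch, assuming 'encoderOutput' is a FunctionPtr whose output carries a sequence (ordered dynamic) axis:

// Sketch: slice out only the final step of the encoder's output sequence.
auto thoughtVector = Sequence::Last(encoderOutput->Output());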
std::vector<Axis> WhereOpDynamicAxes(const Variable& operand)
{
VerifyIsSequence(operand);
std::vector<Axis> newDynamicAxes = { Axis::NewUniqueDynamicAxis(L"whereNodeDynamicAxis") };
for (size_t i = 1; i < operand.DynamicAxes().size(); ++i)
newDynamicAxes.push_back(operand.DynamicAxes()[i]);
return newDynamicAxes;
}
FunctionPtr Where(const Variable& condition, const std::wstring& name /*= L""*/)
{
return Internal::Where(condition, WhereOpDynamicAxes(condition), name);
}
FunctionPtr Gather(const Variable& operand, const Variable& condition, const std::wstring& name /*= L""*/)
{
return Internal::Gather(operand, condition, WhereOpDynamicAxes(condition), name);
}
FunctionPtr Scatter(const Variable& operand, const Variable& condition, const std::wstring& name /*= L""*/)
{
return Internal::Scatter(operand, condition, WhereOpDynamicAxes(condition), name);
}
FunctionPtr BroadcastAs(const Variable& operand, const Variable& broadcastAs, const std::wstring& name /*= L""*/)
{
auto dataPadded = Internal::Scatter(operand, Sequence::IsFirst(broadcastAs), broadcastAs.DynamicAxes());
auto placeHolderOutput = Placeholder(operand.Shape(), broadcastAs.DynamicAxes());
auto output = ElementSelect(Sequence::IsFirst(dataPadded), dataPadded, PastValue(placeHolderOutput, ScalarConstant(operand.GetDataType(), 0.0f), 1), name);
return output->ReplacePlaceholders({ { placeHolderOutput, output } });
}
} }
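BroadcastAs above works by scattering the operand onto the first frame of broadcastAs's layout and then carrying it forward with a one-step PastValue recurrence, so every subsequent frame of the target sequence inherits the same value. A hedged usage sketch (names are illustrative), replicating an encoder's final state across every decoder step as a Seq2Seq decoder typically needs:

// Sketch: 'thoughtVector' and 'decoderInput' are hypothetical Variables/Functions.
auto broadcastThought = Sequence::BroadcastAs(thoughtVector, decoderInput);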
namespace Internal namespace Internal
{ {
FunctionPtr Combine(const std::vector<Variable>& operands, const std::wstring& name /*= L""*/)
{
std::unordered_set<Variable> uniqueOperands;
for (auto operand : operands)
{
if (uniqueOperands.find(operand) != uniqueOperands.end())
LogicError("All operands specified to Combine must be unique");
uniqueOperands.insert(operand);
}
return CompositeFunction::Create(MakeSharedObject<PrimitiveFunction>(PrimitiveOpType::Combine, operands, Dictionary(), name), name);
}
FunctionPtr IsWithin(const Variable& operand, int offset, const std::wstring& name /*= L""*/)
{
Sequence::VerifyIsSequence(operand);
if (offset == 0)
InvalidArgument("CNTK::Sequence::IsWithin: The offset must be positive");
if (offset > 0)
return PastValue(Internal::ZeroesLike(operand), ScalarConstant(operand.GetDataType(), 1.0f), offset, name);
else
return FutureValue(Internal::ZeroesLike(operand), ScalarConstant(operand.GetDataType(), 1.0f), -offset, name);
}
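To see why IsWithin yields IsFirst/IsLast above, a brief trace under the assumption of a length-4 sequence:

// Hypothetical trace for a sequence of length 4:
//   ZeroesLike(operand)                           : 0 0 0 0
//   offset ==  1 -> PastValue(zeros,  init=1, 1)  : 1 0 0 0   (only the first frame sees the initial state) == Sequence::IsFirst
//   offset == -1 -> FutureValue(zeros, init=1, 1) : 0 0 0 1   (only the last frame sees the initial state)  == Sequence::IsLast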
FunctionPtr PackedIndex(const Variable& operand, const Variable& index, const std::wstring& name /*= L""*/) FunctionPtr PackedIndex(const Variable& operand, const Variable& index, const std::wstring& name /*= L""*/)
{ {
return BinaryOp(PrimitiveOpType::PackedIndex, operand, index, Dictionary(), name); return BinaryOp(PrimitiveOpType::PackedIndex, operand, index, Dictionary(), name);
@ -1782,34 +1894,36 @@ namespace CNTK
return BinaryOp(PrimitiveOpType::GatherPacked, operand, packedIndex, Dictionary(), name); return BinaryOp(PrimitiveOpType::GatherPacked, operand, packedIndex, Dictionary(), name);
} }
FunctionPtr ScatterPacked(const Variable& operand, const Variable& packedIndex, const Variable& condition, const std::wstring& name /*= L""*/)
{
return CompositeFunction::Create(MakeSharedObject<PrimitiveFunction>(PrimitiveOpType::ScatterPacked, std::vector<Variable>({ operand, packedIndex, condition }), Dictionary(), name), name);
}
FunctionPtr ZeroesLike(const Variable& operand) FunctionPtr ZeroesLike(const Variable& operand)
{ {
if (operand.Shape().NumAxes() > 1) if (operand.Shape().NumAxes() > 1)
LogicError("ZerosLike currently does not support operands with more than 1 static axes"); LogicError("Internal::ZeroesLike: Currently only 1D inputs are supported!");
auto rowSliceFunc = Internal::Slice(operand, Axis(0), 0, 1); if (operand.IsSparse())
return Minus(rowSliceFunc, rowSliceFunc); {
} if (operand.GetDataType() == DataType::Float)
return Times(Constant({1, operand.Shape()[0]}, 0.0f), operand);
FunctionPtr IsWithin(const Variable& operand, int offset, const std::wstring& name /*= L""*/) else if (operand.GetDataType() == DataType::Double)
{ return Times(Constant({ 1, operand.Shape()[0] }, 0.0), operand);
if (offset == 0) else
InvalidArgument("Internal::CNTK::IsWithin: The offset must be positive"); LogicError("Unsupported DataType %s", DataTypeName(operand.GetDataType()));
}
if (offset > 0)
return PastValue(ZeroesLike(operand), ScalarConstant(operand.GetDataType(), 1.0f), offset, name);
else else
return FutureValue(ZeroesLike(operand), ScalarConstant(operand.GetDataType(), 1.0f), -offset, name); {
auto rowSliceFunc = Internal::Slice(operand, Axis(0), 0, 1);
return Minus(rowSliceFunc, rowSliceFunc);
}
} }
FunctionPtr Where(const Variable& condition, const std::vector<Axis>& newDynamicAxes, const std::wstring& name /*= L""*/) FunctionPtr Where(const Variable& condition, const std::vector<Axis>& newDynamicAxes, const std::wstring& name /*= L""*/)
{ {
auto additionalProperties = Dictionary(); auto additionalProperties = Dictionary();
std::vector<std::wstring> newDynamicAxesNames; additionalProperties[PrimitiveFunction::AttributeNameNewDynamicAxes] = AsDictionaryValueVector(newDynamicAxes);
for (auto axis : newDynamicAxes)
newDynamicAxesNames.push_back(axis.Name());
additionalProperties[PrimitiveFunction::AttributeNameNewDynamicAxes] = AsDictionaryValueVector(newDynamicAxesNames);
return UnaryOp(PrimitiveOpType::Where, condition, std::move(additionalProperties), name); return UnaryOp(PrimitiveOpType::Where, condition, std::move(additionalProperties), name);
} }
@ -1818,6 +1932,11 @@ namespace CNTK
return Internal::GatherPacked(operand, Internal::PackedIndex(operand, Where(condition, newDynamicAxes))); return Internal::GatherPacked(operand, Internal::PackedIndex(operand, Where(condition, newDynamicAxes)));
} }
FunctionPtr Scatter(const Variable& operand, const Variable& condition, const std::vector<Axis>& newDynamicAxes, const std::wstring& name /*= L""*/)
{
return Internal::ScatterPacked(operand, Internal::PackedIndex(operand, Where(condition, newDynamicAxes)), condition);
}
FunctionPtr Slice(const Variable& operand, const Axis& axis, int beginIndex, int endIndex, const std::wstring& name /*= L""*/) FunctionPtr Slice(const Variable& operand, const Axis& axis, int beginIndex, int endIndex, const std::wstring& name /*= L""*/)
{ {
auto additionalProperties = Dictionary(); auto additionalProperties = Dictionary();
@ -46,6 +46,7 @@ namespace CNTK
GreaterEqual, GreaterEqual,
PackedIndex, PackedIndex,
GatherPacked, GatherPacked,
ScatterPacked,
Times, Times,
TransposeTimes, TransposeTimes,
Convolution, Convolution,
@ -57,6 +58,7 @@ namespace CNTK
ReduceElements, ReduceElements,
BatchNormalization, BatchNormalization,
Clip, Clip,
Select,
Splice, Splice,
Combine, Combine,
}; };
@ -77,7 +79,7 @@ namespace CNTK
{ {
inline const char* PrimitiveOpTypeName(PrimitiveOpType opType) inline const char* PrimitiveOpTypeName(PrimitiveOpType opType)
{ {
static std::unordered_map<PrimitiveOpType, const char*> primitiveOpNames = { static const std::unordered_map<PrimitiveOpType, const char*> primitiveOpNames = {
{ PrimitiveOpType::Negate, "Negate" }, { PrimitiveOpType::Negate, "Negate" },
{ PrimitiveOpType::Sigmoid, "Sigmoid" }, { PrimitiveOpType::Sigmoid, "Sigmoid" },
{ PrimitiveOpType::Tanh, "Tanh" }, { PrimitiveOpType::Tanh, "Tanh" },
@ -108,6 +110,7 @@ namespace CNTK
{ PrimitiveOpType::GreaterEqual, "GreaterEqual" }, { PrimitiveOpType::GreaterEqual, "GreaterEqual" },
{ PrimitiveOpType::PackedIndex, "PackedIndex" }, { PrimitiveOpType::PackedIndex, "PackedIndex" },
{ PrimitiveOpType::GatherPacked, "GatherPacked" }, { PrimitiveOpType::GatherPacked, "GatherPacked" },
{ PrimitiveOpType::ScatterPacked, "ScatterPacked" },
{ PrimitiveOpType::Times, "Times" }, { PrimitiveOpType::Times, "Times" },
{ PrimitiveOpType::TransposeTimes, "TransposeTimes" }, { PrimitiveOpType::TransposeTimes, "TransposeTimes" },
{ PrimitiveOpType::Convolution, "Convolution" }, { PrimitiveOpType::Convolution, "Convolution" },
@ -119,6 +122,7 @@ namespace CNTK
{ PrimitiveOpType::ReduceElements, "ReduceElements" }, { PrimitiveOpType::ReduceElements, "ReduceElements" },
{ PrimitiveOpType::BatchNormalization, "BatchNormalization" }, { PrimitiveOpType::BatchNormalization, "BatchNormalization" },
{ PrimitiveOpType::Clip, "Clip" }, { PrimitiveOpType::Clip, "Clip" },
{ PrimitiveOpType::Select, "Select" },
{ PrimitiveOpType::Splice, "Splice" }, { PrimitiveOpType::Splice, "Splice" },
{ PrimitiveOpType::Combine, "Combine" } { PrimitiveOpType::Combine, "Combine" }
}; };
@ -288,9 +292,9 @@ namespace CNTK
{ {
if ((leftOperandShape[i] == NDShape::InferredDimension) && (rightOperandShape[i] == NDShape::InferredDimension)) if ((leftOperandShape[i] == NDShape::InferredDimension) && (rightOperandShape[i] == NDShape::InferredDimension))
outputDims[i] = NDShape::InferredDimension; outputDims[i] = NDShape::InferredDimension;
else if (leftOperandShape[i] == NDShape::InferredDimension) else if ((leftOperandShape[i] == NDShape::InferredDimension) || (leftOperandShape[i] == 1))
outputDims[i] = rightOperandShape[i]; outputDims[i] = rightOperandShape[i];
else if (rightOperandShape[i] == NDShape::InferredDimension) else if ((rightOperandShape[i] == NDShape::InferredDimension) || (rightOperandShape[i] == 1))
outputDims[i] = leftOperandShape[i]; outputDims[i] = leftOperandShape[i];
else else
{ {
@ -308,6 +312,18 @@ namespace CNTK
return NDShape(std::move(outputDims)); return NDShape(std::move(outputDims));
} }
static NDShape NaryElementwiseOpOutputShape(PrimitiveOpType op, const std::vector<NDShape>& operandShapes, bool broadcastAllowed = true)
{
assert(!operandShapes.empty());
// TODO: Is this logic of transitively constructing the output shape from the operands correct?
NDShape outputShape = {};
for (auto& operandShape : operandShapes)
outputShape = BinaryElementwiseOpOutputShape(op, outputShape, operandShape, broadcastAllowed);
return outputShape;
}
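A small worked example of the relaxed elementwise broadcasting rule above (illustrative shapes, not from the source):

// leftOperandShape  = { 3, 1, 5 }
// rightOperandShape = { 3, 4, 5 }
// dim 0: 3 vs 3 -> 3;  dim 1: left is 1, so take right -> 4;  dim 2: 5 vs 5 -> 5
// BinaryElementwiseOpOutputShape(...) = { 3, 4, 5 }
// NaryElementwiseOpOutputShape folds the operand shapes pairwise through this rule, starting from the scalar shape {}.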
static NDShape TimesOpOutputShape(const NDShape& leftOperandShape, const NDShape& rightOperandShape, size_t outputRank) static NDShape TimesOpOutputShape(const NDShape& leftOperandShape, const NDShape& rightOperandShape, size_t outputRank)
{ {
if (outputRank == 0) if (outputRank == 0)
@ -362,7 +378,7 @@ namespace CNTK
else else
computeOutputShapeFunc = &Microsoft::MSR::CNTK::ConvolveGeometry::ComputeInputShape; computeOutputShapeFunc = &Microsoft::MSR::CNTK::ConvolveGeometry::ComputeInputShape;
return AsNDShape(computeOutputShapeFunc(AsTensorShape(operandShape, true), AsTensorShape(kernelShape, true), AsTensorShape(outputMapCount, true), AsTensorShape(strides, true), sharing, autoPad, AsTensorShape(lowerPad, true), AsTensorShape(upperPad, true))); return AsNDShape(computeOutputShapeFunc(AsTensorShape(operandShape), AsTensorShape(kernelShape), AsTensorShape(outputMapCount), AsTensorShape(strides), sharing, autoPad, AsTensorShape(lowerPad), AsTensorShape(upperPad)));
} }
// TODO: Reconcile this with the ComputationNode::Validate functionality in core CNTK to avoid duplication of inference logic // TODO: Reconcile this with the ComputationNode::Validate functionality in core CNTK to avoid duplication of inference logic
@ -533,9 +549,9 @@ namespace CNTK
inline std::vector<CNTK::Axis> DynamicAxesFromInternalDynamicAxisName(const std::wstring& internalDynamicAxisName) inline std::vector<CNTK::Axis> DynamicAxesFromInternalDynamicAxisName(const std::wstring& internalDynamicAxisName)
{ {
std::vector<CNTK::Axis> inputVarDynamicAxes; std::vector<CNTK::Axis> inputVarDynamicAxes;
if (internalDynamicAxisName == CNTK::CompositeFunction::InternalDefaultDynamicAxisName) if (internalDynamicAxisName.substr(0, CNTK::CompositeFunction::InternalDefaultDynamicAxisName.length()) == CNTK::CompositeFunction::InternalDefaultDynamicAxisName)
inputVarDynamicAxes = { CNTK::Axis::DefaultDynamicAxis(), CNTK::Axis::DefaultBatchAxis() }; inputVarDynamicAxes = { CNTK::Axis::DefaultDynamicAxis(), CNTK::Axis::DefaultBatchAxis() };
else if (internalDynamicAxisName == CNTK::CompositeFunction::InternalNoSequenceAxisName) else if (internalDynamicAxisName.substr(0, CNTK::CompositeFunction::InternalNoSequenceAxisName.length()) == CNTK::CompositeFunction::InternalNoSequenceAxisName)
inputVarDynamicAxes = { CNTK::Axis::DefaultBatchAxis() }; inputVarDynamicAxes = { CNTK::Axis::DefaultBatchAxis() };
else else
inputVarDynamicAxes = { CNTK::Axis(internalDynamicAxisName), CNTK::Axis::DefaultBatchAxis() }; inputVarDynamicAxes = { CNTK::Axis(internalDynamicAxisName), CNTK::Axis::DefaultBatchAxis() };
@ -3,6 +3,7 @@
// Licensed under the MIT license. See LICENSE.md file in the project root for full license information. // Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
// //
#include "stdafx.h"
#include "Learner.h" #include "Learner.h"
#include "TensorView.h" #include "TensorView.h"
#include "Utils.h" #include "Utils.h"
@ -155,12 +156,17 @@ namespace CNTK
LearnerBase::LearnerBase(const unordered_set<Parameter>& parameters, LearnerBase::LearnerBase(const unordered_set<Parameter>& parameters,
const LearningRatesPerSample& learningRates, const LearningRatesPerSample& learningRates,
bool allocateSmoothGradients /* = true */) bool allocateSmoothGradients /* = true */,
double clippingThresholdPerSample /*= std::numeric_limits<double>::infinity()*/,
bool gradientClippingWithTruncation /*= true*/)
: Learner(parameters), : Learner(parameters),
m_learningRates(learningRates), m_learningRates(learningRates),
m_sampleCount(0), m_sampleCount(0),
m_minibatchCount(0) m_minibatchCount(0)
{ {
m_additionalOptions.gradientClippingThresholdPerSample = clippingThresholdPerSample;
m_additionalOptions.gradientClippingWithTruncation = gradientClippingWithTruncation;
for (const auto& parameter : parameters) for (const auto& parameter : parameters)
{ {
if (!allocateSmoothGradients) if (!allocateSmoothGradients)
@ -356,8 +362,10 @@ namespace CNTK
LearnerAdaGrad::LearnerAdaGrad(const unordered_set<Parameter>& parameters, LearnerAdaGrad::LearnerAdaGrad(const unordered_set<Parameter>& parameters,
const LearningRatesPerSample& learningRates, const LearningRatesPerSample& learningRates,
bool needAveMultiplier) bool needAveMultiplier,
: LearnerBase(parameters, learningRates), double clippingThresholdPerSample /*= std::numeric_limits<double>::infinity()*/,
bool gradientClippingWithTruncation /*= true*/)
: LearnerBase(parameters, learningRates, true, clippingThresholdPerSample, gradientClippingWithTruncation),
m_needAveMultiplier(needAveMultiplier) m_needAveMultiplier(needAveMultiplier)
{ {
} }
@ -385,8 +393,10 @@ namespace CNTK
LearnerFSAdaGrad::LearnerFSAdaGrad(const unordered_set<Parameter>& parameters, LearnerFSAdaGrad::LearnerFSAdaGrad(const unordered_set<Parameter>& parameters,
const LearningRatesPerSample& learningRates, const LearningRatesPerSample& learningRates,
const MomentumsPerSample& momentums) const MomentumsPerSample& momentums,
: LearnerMomentumSGD(parameters, learningRates, momentums, /*allocateSmoothGradients*/ false) double clippingThresholdPerSample /*= std::numeric_limits<double>::infinity()*/,
bool gradientClippingWithTruncation /*= true*/)
: LearnerMomentumSGD(parameters, learningRates, momentums, /*allocateSmoothGradients*/ false, clippingThresholdPerSample, gradientClippingWithTruncation)
{ {
for (const auto& parameter : parameters) for (const auto& parameter : parameters)
{ {
@ -417,10 +427,11 @@ namespace CNTK
} }
LearnerRMSProp::LearnerRMSProp(const unordered_set<Parameter>& parameters, const LearningRatesPerSample& learningRates, LearnerRMSProp::LearnerRMSProp(const unordered_set<Parameter>& parameters, const LearningRatesPerSample& learningRates,
double gamma, double inc, double dec, double max, double min, bool needAveMultiplier) double gamma, double inc, double dec, double max, double min, bool needAveMultiplier,
: LearnerBase(parameters, learningRates, /*allocateSmoothGradients*/ false), double clippingThresholdPerSample /*= std::numeric_limits<double>::infinity()*/,
m_gamma(gamma), m_inc(inc), m_dec(dec), m_max(max), m_min(min), bool gradientClippingWithTruncation /*= true*/)
m_needAveMultiplier(needAveMultiplier) : LearnerBase(parameters, learningRates, /*allocateSmoothGradients*/ false, clippingThresholdPerSample, gradientClippingWithTruncation),
m_gamma(gamma), m_inc(inc), m_dec(dec), m_max(max), m_min(min), m_needAveMultiplier(needAveMultiplier)
{ {
for (const auto& parameter : parameters) for (const auto& parameter : parameters)
{ {
@ -467,35 +478,56 @@ namespace CNTK
template shared_ptr<Matrix<float>> LearnerBase::GetWritableMatrix<float>(const NDArrayViewPtr& arrayView); template shared_ptr<Matrix<float>> LearnerBase::GetWritableMatrix<float>(const NDArrayViewPtr& arrayView);
template shared_ptr<Matrix<double>> LearnerBase::GetWritableMatrix<double>(const NDArrayViewPtr& arrayView); template shared_ptr<Matrix<double>> LearnerBase::GetWritableMatrix<double>(const NDArrayViewPtr& arrayView);
LearnerPtr SGDLearner(const unordered_set<Parameter>& parameters, const LearningRatesPerSample& learningRates) LearnerPtr SGDLearner(const unordered_set<Parameter>& parameters,
const LearningRatesPerSample& learningRates,
double clippingThresholdPerSample /*= std::numeric_limits<double>::infinity()*/,
bool gradientClippingWithTruncation /*= true*/)
{ {
return MakeSharedObject<LearnerSGD>(parameters, learningRates); return MakeSharedObject<LearnerSGD>(parameters, learningRates, true, clippingThresholdPerSample, gradientClippingWithTruncation);
} }
LearnerPtr MomentumSGDLearner(const unordered_set<Parameter>& parameters, const LearningRatesPerSample& learningRates, const MomentumsPerSample& momentums) LearnerPtr MomentumSGDLearner(const unordered_set<Parameter>& parameters,
const LearningRatesPerSample& learningRates,
const MomentumsPerSample& momentums,
double clippingThresholdPerSample /*= std::numeric_limits<double>::infinity()*/,
bool gradientClippingWithTruncation /*= true*/)
{ {
return MakeSharedObject<LearnerMomentumSGD>(parameters, learningRates, momentums); return MakeSharedObject<LearnerMomentumSGD>(parameters, learningRates, momentums, true, clippingThresholdPerSample, gradientClippingWithTruncation);
} }
LearnerPtr NesterovLearner(const unordered_set<Parameter>& parameters, const LearningRatesPerSample& learningRates, const MomentumsPerSample& momentums) LearnerPtr NesterovLearner(const unordered_set<Parameter>& parameters,
const LearningRatesPerSample& learningRates,
const MomentumsPerSample& momentums,
double clippingThresholdPerSample /*= std::numeric_limits<double>::infinity()*/,
bool gradientClippingWithTruncation /*= true*/)
{ {
return MakeSharedObject<LearnerNesterov>(parameters, learningRates, momentums); return MakeSharedObject<LearnerNesterov>(parameters, learningRates, momentums, clippingThresholdPerSample, gradientClippingWithTruncation);
} }
LearnerPtr AdaGradLearner(const unordered_set<Parameter>& parameters, const LearningRatesPerSample& learningRates, bool needAveMultiplier) LearnerPtr FSAdaGradLearner(const unordered_set<Parameter>& parameters,
const LearningRatesPerSample& learningRates,
const MomentumsPerSample& momentums,
double clippingThresholdPerSample /*= std::numeric_limits<double>::infinity()*/,
bool gradientClippingWithTruncation /*= true*/)
{ {
return MakeSharedObject<LearnerAdaGrad>(parameters, learningRates, needAveMultiplier); return MakeSharedObject<LearnerFSAdaGrad>(parameters, learningRates, momentums, clippingThresholdPerSample, gradientClippingWithTruncation);
} }
LearnerPtr FSAdaGradLearner(const unordered_set<Parameter>& parameters, const LearningRatesPerSample& learningRates, const MomentumsPerSample& momentums) LearnerPtr AdaGradLearner(const unordered_set<Parameter>& parameters,
const LearningRatesPerSample& learningRates,
bool needAveMultiplier /*= true*/,
double clippingThresholdPerSample /*= std::numeric_limits<double>::infinity()*/,
bool gradientClippingWithTruncation /*= true*/)
{ {
return MakeSharedObject<LearnerFSAdaGrad>(parameters, learningRates, momentums); return MakeSharedObject<LearnerAdaGrad>(parameters, learningRates, needAveMultiplier, clippingThresholdPerSample, gradientClippingWithTruncation);
} }
LearnerPtr RMSPropLearner(const unordered_set<Parameter>& parameters, const LearningRatesPerSample& learningRates, LearnerPtr RMSPropLearner(const unordered_set<Parameter>& parameters, const LearningRatesPerSample& learningRates,
double gamma, double inc, double dec, double max, double min, double gamma, double inc, double dec, double max, double min,
bool needAveMultiplier) bool needAveMultiplier /*= true*/,
double clippingThresholdPerSample /*= std::numeric_limits<double>::infinity()*/,
bool gradientClippingWithTruncation /*= true*/)
{ {
return MakeSharedObject<LearnerRMSProp>(parameters, learningRates, gamma, inc, dec, max, min, needAveMultiplier); return MakeSharedObject<LearnerRMSProp>(parameters, learningRates, gamma, inc, dec, max, min, needAveMultiplier, clippingThresholdPerSample, gradientClippingWithTruncation);
} }
} }
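A hedged construction sketch for the extended learner factories (this assumes LearningRatesPerSample and MomentumsPerSample can be built from a single per-sample value and that Function::Parameters() yields the parameter set these factories expect; 'model' is hypothetical):

// Sketch: SGD with a per-sample gradient-clipping threshold of 2.3 and truncation enabled.
LearnerPtr sgd = SGDLearner(model->Parameters(), LearningRatesPerSample(0.005), 2.3, true);
// Momentum SGD with the same clipping settings.
LearnerPtr momentumSgd = MomentumSGDLearner(model->Parameters(), LearningRatesPerSample(0.005), MomentumsPerSample(0.9), 2.3, true);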
@ -36,7 +36,9 @@ namespace CNTK
protected: protected:
LearnerBase(const std::unordered_set<Parameter>& parameters, LearnerBase(const std::unordered_set<Parameter>& parameters,
const LearningRatesPerSample& learningRates, const LearningRatesPerSample& learningRates,
bool allocateSmoothGradients = true); bool allocateSmoothGradients = true,
double clippingThresholdPerSample = std::numeric_limits<double>::infinity(),
bool gradientClippingWithTruncation = true);
virtual void Update(const Parameter& parameter, const NDArrayViewPtr& gradientValue, const NDArrayViewPtr& smoothedGradientValue, size_t trainingSampleCount) const = 0; virtual void Update(const Parameter& parameter, const NDArrayViewPtr& gradientValue, const NDArrayViewPtr& smoothedGradientValue, size_t trainingSampleCount) const = 0;
@ -104,11 +106,13 @@ namespace CNTK
public: public:
LearnerSGD(const std::unordered_set<Parameter>& parameters, LearnerSGD(const std::unordered_set<Parameter>& parameters,
const LearningRatesPerSample& learningRates, const LearningRatesPerSample& learningRates,
bool allocateSmoothGradients = true) bool allocateSmoothGradients = true,
: LearnerBase(parameters, learningRates, allocateSmoothGradients), double clippingThresholdPerSample = std::numeric_limits<double>::infinity(),
bool gradientClippingWithTruncation = true)
: LearnerBase(parameters, learningRates, allocateSmoothGradients, clippingThresholdPerSample, gradientClippingWithTruncation),
m_momentums(0.0), m_momentums(0.0),
m_useNesterovAcceleration(false) m_useNesterovAcceleration(false)
{ } {}
protected: protected:
@ -129,8 +133,10 @@ namespace CNTK
LearnerMomentumSGD(const std::unordered_set<Parameter>& parameters, LearnerMomentumSGD(const std::unordered_set<Parameter>& parameters,
const LearningRatesPerSample& learningRates, const LearningRatesPerSample& learningRates,
const MomentumsPerSample& momentums, const MomentumsPerSample& momentums,
bool allocateSmoothGradients = true) bool allocateSmoothGradients = true,
: LearnerSGD(parameters, learningRates, allocateSmoothGradients) double clippingThresholdPerSample = std::numeric_limits<double>::infinity(),
bool gradientClippingWithTruncation = true)
: LearnerSGD(parameters, learningRates, allocateSmoothGradients, clippingThresholdPerSample, gradientClippingWithTruncation)
{ {
m_momentums = momentums; m_momentums = momentums;
} }
@ -143,8 +149,10 @@ namespace CNTK
LearnerNesterov(const std::unordered_set<Parameter>& parameters, LearnerNesterov(const std::unordered_set<Parameter>& parameters,
const LearningRatesPerSample& learningRates, const LearningRatesPerSample& learningRates,
const MomentumsPerSample& momentums) const MomentumsPerSample& momentums,
: LearnerMomentumSGD(parameters, learningRates, momentums) double clippingThresholdPerSample = std::numeric_limits<double>::infinity(),
bool gradientClippingWithTruncation = true)
: LearnerMomentumSGD(parameters, learningRates, momentums, true, clippingThresholdPerSample, gradientClippingWithTruncation)
{ {
m_useNesterovAcceleration = true; m_useNesterovAcceleration = true;
} }
@ -156,7 +164,9 @@ namespace CNTK
LearnerAdaGrad(const std::unordered_set<Parameter>& parameters, LearnerAdaGrad(const std::unordered_set<Parameter>& parameters,
const LearningRatesPerSample& learningRates, const LearningRatesPerSample& learningRates,
bool needAveMultiplier); bool needAveMultiplier,
double clippingThresholdPerSample = std::numeric_limits<double>::infinity(),
bool gradientClippingWithTruncation = true);
protected: protected:
bool m_needAveMultiplier; bool m_needAveMultiplier;
@ -173,7 +183,9 @@ namespace CNTK
LearnerFSAdaGrad(const std::unordered_set<Parameter>& parameters, LearnerFSAdaGrad(const std::unordered_set<Parameter>& parameters,
const LearningRatesPerSample& learningRates, const LearningRatesPerSample& learningRates,
const MomentumsPerSample& momentums); const MomentumsPerSample& momentums,
double clippingThresholdPerSample = std::numeric_limits<double>::infinity(),
bool gradientClippingWithTruncation = true);
protected: protected:
@ -190,7 +202,9 @@ namespace CNTK
LearnerRMSProp(const std::unordered_set<Parameter>& parameters, LearnerRMSProp(const std::unordered_set<Parameter>& parameters,
const LearningRatesPerSample& learningRates, const LearningRatesPerSample& learningRates,
double gamma, double inc, double dec, double max, double min, double gamma, double inc, double dec, double max, double min,
bool needAveMultiplier); bool needAveMultiplier,
double clippingThresholdPerSample = std::numeric_limits<double>::infinity(),
bool gradientClippingWithTruncation = true);
protected: protected:
@ -29,7 +29,7 @@ namespace CNTK
auto matrixDims = GetMatrixDimensions(viewShape); auto matrixDims = GetMatrixDimensions(viewShape);
std::shared_ptr<Matrix<ElementType>> matrix = std::make_shared<Matrix<ElementType>>(matrixDims.first, matrixDims.second, (ElementType*)dataBuffer, AsCNTKImplDeviceId(device), matrixFlagDontOwnBuffer); std::shared_ptr<Matrix<ElementType>> matrix = std::make_shared<Matrix<ElementType>>(matrixDims.first, matrixDims.second, (ElementType*)dataBuffer, AsCNTKImplDeviceId(device), matrixFlagDontOwnBuffer);
return new TensorView<ElementType>(matrix, AsTensorShape(viewShape)); return new TensorView<ElementType>(matrix, AsTensorViewShape(viewShape));
} }
static void* AllocateTensorView(CNTK::DataType dataType, static void* AllocateTensorView(CNTK::DataType dataType,
@ -61,7 +61,7 @@ namespace CNTK
AsCNTKImplDeviceId(device), AsCNTKImplDeviceId(device),
IsSparseStorageFormat(storageType) ? MatrixType::SPARSE : MatrixType::DENSE, IsSparseStorageFormat(storageType) ? MatrixType::SPARSE : MatrixType::DENSE,
AsCNTKImplMatrixFormat(storageType)); AsCNTKImplMatrixFormat(storageType));
return new TensorView<ElementType>(matrix, AsTensorShape(viewShape)); return new TensorView<ElementType>(matrix, AsTensorViewShape(viewShape));
} }
static void* AllocateTensorView(CNTK::DataType dataType, static void* AllocateTensorView(CNTK::DataType dataType,
@ -320,7 +320,7 @@ namespace CNTK
{ {
auto matrixDims = GetMatrixDimensions(shape); auto matrixDims = GetMatrixDimensions(shape);
auto randomNormalMatrix = std::make_shared<Matrix<ElementType>>(Matrix<ElementType>::RandomGaussian(matrixDims.first, matrixDims.second, AsCNTKImplDeviceId(device), (ElementType)mean, (ElementType)stdDev, seed)); auto randomNormalMatrix = std::make_shared<Matrix<ElementType>>(Matrix<ElementType>::RandomGaussian(matrixDims.first, matrixDims.second, AsCNTKImplDeviceId(device), (ElementType)mean, (ElementType)stdDev, seed));
auto tensorView = new TensorView<ElementType>(randomNormalMatrix, AsTensorShape(shape)); auto tensorView = new TensorView<ElementType>(randomNormalMatrix, AsTensorViewShape(shape));
return MakeSharedObject<NDArrayView>(AsDataType<ElementType>(), device, StorageFormat::Dense, shape, false, tensorView); return MakeSharedObject<NDArrayView>(AsDataType<ElementType>(), device, StorageFormat::Dense, shape, false, tensorView);
} }
@ -330,7 +330,7 @@ namespace CNTK
{ {
auto matrixDims = GetMatrixDimensions(shape); auto matrixDims = GetMatrixDimensions(shape);
auto randomUniformMatrix = std::make_shared<Matrix<ElementType>>(Matrix<ElementType>::RandomUniform(matrixDims.first, matrixDims.second, AsCNTKImplDeviceId(device), (ElementType)rangeBegin, (ElementType)rangeEnd, seed)); auto randomUniformMatrix = std::make_shared<Matrix<ElementType>>(Matrix<ElementType>::RandomUniform(matrixDims.first, matrixDims.second, AsCNTKImplDeviceId(device), (ElementType)rangeBegin, (ElementType)rangeEnd, seed));
auto tensorView = new TensorView<ElementType>(randomUniformMatrix, AsTensorShape(shape)); auto tensorView = new TensorView<ElementType>(randomUniformMatrix, AsTensorViewShape(shape));
return MakeSharedObject<NDArrayView>(AsDataType<ElementType>(), device, StorageFormat::Dense, shape, false, tensorView); return MakeSharedObject<NDArrayView>(AsDataType<ElementType>(), device, StorageFormat::Dense, shape, false, tensorView);
} }
@ -9,10 +9,12 @@
namespace CNTK namespace CNTK
{ {
Trainer::Trainer(const FunctionPtr& model, const Variable& trainingLoss, const std::unordered_set<LearnerPtr>& parameterLearners) Trainer::Trainer(const FunctionPtr& model, const FunctionPtr& lossFunction, const FunctionPtr& evaluationFunction, const std::unordered_set<LearnerPtr>& parameterLearners)
: m_model(model), m_trainingLossVar(trainingLoss), m_parameterLearners(parameterLearners), m_prevMinibatchNumSamples(1) : m_model(model), m_lossFunction(lossFunction), m_evaluationFunction(evaluationFunction), m_parameterLearners(parameterLearners), m_prevMinibatchNumSamples(1)
{ {
auto modelParameters = model->Parameters(); m_combinedTrainingFunction = Combine({ model, lossFunction, evaluationFunction });
auto modelParameters = m_combinedTrainingFunction->Parameters();
std::unordered_set<Parameter> learnerParameters; std::unordered_set<Parameter> learnerParameters;
for (const auto& learner : parameterLearners) for (const auto& learner : parameterLearners)
{ {
@ -29,36 +31,97 @@ namespace CNTK
InvalidArgument("Trainer ctor: Union of the parameters covered by the specified parameterLearners should match the specified model's parameters"); InvalidArgument("Trainer ctor: Union of the parameters covered by the specified parameterLearners should match the specified model's parameters");
} }
Trainer::Trainer(const FunctionPtr& model, const FunctionPtr& lossFunction, const std::unordered_set<LearnerPtr>& parameterLearners)
: Trainer(model, lossFunction, nullptr, parameterLearners)
{}
static double GetScalarValue(const ValuePtr& value)
{
if (value->Mask())
LogicError("Scalar Value object cannot have an associated mask");
auto scalarData = value->Data();
if (scalarData->Shape().TotalSize() != 1)
LogicError("Scalar Value object's has a size > 1");
double scalar = std::numeric_limits<double>::quiet_NaN();
NDArrayViewPtr cpuData;
if (scalarData->Device() == DeviceDescriptor::CPUDevice())
cpuData = scalarData;
else
{
cpuData = std::make_shared<NDArrayView>(scalarData->GetDataType(), scalarData->Shape(), CNTK::DeviceDescriptor::CPUDevice());
cpuData->CopyFrom(*scalarData);
}
if (scalarData->GetDataType() == DataType::Float)
scalar = *(cpuData->DataBuffer<float>());
else if (scalarData->GetDataType() == DataType::Double)
scalar = *(cpuData->DataBuffer<double>());
else
LogicError("Unsupported DataType of training loss value");
return scalar;
}
static size_t GetSampleCountFromArguments(const Variable& evalOrLossArgument, const std::unordered_map<Variable, ValuePtr>& arguments)
{
// Find the argument whose dynamic axes match the criterion operation's dynamic axes (i.e. label dynamic axes)
// Then we determine the actual number of samples contributing to the training loss from the argument's Value object
auto argumentIter = std::find_if(arguments.begin(), arguments.end(), [evalOrLossArgument](const std::pair<Variable, ValuePtr>& currentPair) {
return (currentPair.first.DynamicAxes() == evalOrLossArgument.DynamicAxes());
});
auto argumentValue = argumentIter->second;
auto argumentVar = argumentIter->first;
auto argumentDataShape = argumentValue->Data()->Shape();
auto mask = argumentValue->Mask();
size_t numMaskedSamples = (mask != nullptr) ? mask->MaskedCount() : 0;
size_t numSamplesInDataArrayView = argumentDataShape.SubShape(argumentVar.Shape().NumAxes()).TotalSize();
if (numMaskedSamples > numSamplesInDataArrayView)
LogicError("Number of masked values cannot exceed the number of samples that the Value object's Data NDArrayView can hold");
return (numSamplesInDataArrayView - numMaskedSamples);
}
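A brief worked example of the sample-count computation above (illustrative numbers):

// Suppose the matched argument has static shape { 5000 } (a one-hot vocabulary) and its
// Value data shape is { 5000, 20, 8 } (20 time steps x 8 sequences), with 35 masked entries:
//   numSamplesInDataArrayView = SubShape(1).TotalSize() = 20 * 8 = 160
//   numMaskedSamples          = 35
//   returned sample count     = 160 - 35 = 125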
double Trainer::TestMinbatch(const std::unordered_map<Variable, ValuePtr>& arguments, const DeviceDescriptor& computeDevice /*= DeviceDescriptor::UseDefaultDevice()*/)
{
if (!m_evaluationFunction)
InvalidArgument("Trainer::TestMinbatch: Cannot test when no evaluation function was specified during 'this' trainer's construction");
// TODO: Should we refactor this code that is somewhat similar to the prologue of the TrainMinibatch function
std::unordered_map<Variable, ValuePtr> outputs = { { m_evaluationFunction, nullptr } };
m_combinedTrainingFunction->Forward(arguments, outputs, computeDevice);
auto sampleCount = GetSampleCountFromArguments(*(m_evaluationFunction->Arguments().begin()), arguments);
return (GetScalarValue(outputs[m_evaluationFunction]) / sampleCount);
}
bool Trainer::TrainMinibatch(const std::unordered_map<Variable, ValuePtr>& arguments, const DeviceDescriptor& computeDevice /*= DeviceDescriptor::UseDefaultDevice()*/) bool Trainer::TrainMinibatch(const std::unordered_map<Variable, ValuePtr>& arguments, const DeviceDescriptor& computeDevice /*= DeviceDescriptor::UseDefaultDevice()*/)
{ {
std::unordered_map<Variable, ValuePtr> outputs = { { m_trainingLossVar, nullptr } }; std::unordered_map<Variable, ValuePtr> outputs = { { m_lossFunction, nullptr } };
auto backPropSate = m_model->Forward(arguments, outputs, computeDevice, { m_trainingLossVar }); if (m_evaluationFunction)
m_prevMinibatchTrainingLossValue = outputs.begin()->second; outputs.insert({ m_evaluationFunction, nullptr });
ValuePtr rootGradientValue = MakeSharedObject<Value>(MakeSharedObject<NDArrayView>(m_trainingLossVar.GetDataType(), outputs.at(m_trainingLossVar)->Data()->Shape(), computeDevice), outputs.at(m_trainingLossVar)->Mask()); auto backPropSate = m_combinedTrainingFunction->Forward(arguments, outputs, computeDevice, { m_lossFunction });
if (m_trainingLossVar.GetDataType() == DataType::Float) m_prevMinibatchAggregateTrainingLossValue = outputs[m_lossFunction];
if (m_evaluationFunction)
m_prevMinibatchAggregateEvalCriterionValue = outputs[m_evaluationFunction];
ValuePtr rootGradientValue = MakeSharedObject<Value>(MakeSharedObject<NDArrayView>(m_lossFunction->Output().GetDataType(), m_prevMinibatchAggregateTrainingLossValue->Data()->Shape(), computeDevice), outputs.at(m_lossFunction)->Mask());
if (m_lossFunction->Output().GetDataType() == DataType::Float)
rootGradientValue->Data()->SetValue(1.0f); rootGradientValue->Data()->SetValue(1.0f);
else else
rootGradientValue->Data()->SetValue(1.0); rootGradientValue->Data()->SetValue(1.0);
auto modelParameters = m_model->Parameters(); auto modelParameters = m_combinedTrainingFunction->Parameters();
std::unordered_map<Variable, ValuePtr> parameterGradients; std::unordered_map<Variable, ValuePtr> parameterGradients;
for (const auto& parameter : modelParameters) for (const auto& parameter : modelParameters)
parameterGradients[parameter] = nullptr; parameterGradients[parameter] = nullptr;
m_model->Backward(backPropSate, { { m_trainingLossVar, rootGradientValue } }, parameterGradients); m_combinedTrainingFunction->Backward(backPropSate, { { m_lossFunction, rootGradientValue } }, parameterGradients);
auto trainingLossArgument = *(m_trainingLossVar.Owner()->Arguments().begin()); m_prevMinibatchNumSamples = GetSampleCountFromArguments(*(m_lossFunction->Arguments().begin()), arguments);
// Find the argument whose dynamic axes match the criterion operation's dynamic axes (i.e. label dynamic axes)
// Then we determine the actual number of samples contributing to the training loss from the argument's Value object
auto argumentValue = std::find_if(arguments.begin(), arguments.end(), [trainingLossArgument](const std::pair<Variable, ValuePtr>& currentPair) {
return (currentPair.first.DynamicAxes() == trainingLossArgument.DynamicAxes());
})->second;
auto argumentData = argumentValue->Data();
auto argumentDataShape = argumentData->Shape();
auto mask = argumentValue->Mask();
m_prevMinibatchNumSamples = argumentDataShape[argumentDataShape.NumAxes() - 1] - ((mask != nullptr) ? mask->MaskedCount() : 0);
bool anyUpdatesPerformed = false; bool anyUpdatesPerformed = false;
for (auto learner : m_parameterLearners) for (auto learner : m_parameterLearners)
@ -79,27 +142,16 @@ namespace CNTK
return anyUpdatesPerformed; return anyUpdatesPerformed;
} }
double Trainer::PreviousMinibatchAverageTrainingLoss() const double Trainer::PreviousMinibatchLossAverage() const
{ {
double trainLossValue = std::numeric_limits<double>::quiet_NaN(); return (GetScalarValue(m_prevMinibatchAggregateTrainingLossValue) / m_prevMinibatchNumSamples);
auto prevMBTrainingLossValue = PreviousMinibatchTrainingLossValue()->Data(); }
NDArrayViewPtr cpuTrainLossValue; double Trainer::PreviousMinibatchEvaluationAverage() const
if (prevMBTrainingLossValue->Device() == DeviceDescriptor::CPUDevice()) {
cpuTrainLossValue = prevMBTrainingLossValue; if (!m_evaluationFunction)
else InvalidArgument("Trainer::PreviousMinibatchEvaluationAverage: Cannot get evaluation criterion value when no evaluation function was specified during 'this' trainer's construction");
{
cpuTrainLossValue = std::make_shared<NDArrayView>(prevMBTrainingLossValue->GetDataType(), prevMBTrainingLossValue->Shape(), CNTK::DeviceDescriptor::CPUDevice());
cpuTrainLossValue->CopyFrom(*prevMBTrainingLossValue);
}
if (prevMBTrainingLossValue->GetDataType() == DataType::Float) return (GetScalarValue(m_prevMinibatchAggregateEvalCriterionValue) / m_prevMinibatchNumSamples);
trainLossValue = *(cpuTrainLossValue->DataBuffer<float>());
else if (prevMBTrainingLossValue->GetDataType() == DataType::Double)
trainLossValue = *(cpuTrainLossValue->DataBuffer<double>());
else
LogicError("Unsupported DataType of training loss value");
return (trainLossValue / m_prevMinibatchNumSamples);
} }
} }
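Both new accessors divide an aggregate criterion Value by m_prevMinibatchNumSamples; the device copy and scalar read that used to be inlined above is presumably what the GetScalarValue helper now performs. A minimal sketch of such a helper, mirroring the removed code (the actual signature in the library may differ):

// Sketch of a scalar-extraction helper equivalent to the removed inline code; not necessarily the library's exact implementation.
static double GetScalarValue(const ValuePtr& value)
{
    double scalar = std::numeric_limits<double>::quiet_NaN();
    auto data = value->Data();

    // Bring the single-element buffer to the CPU if it lives on an accelerator device
    NDArrayViewPtr cpuData;
    if (data->Device() == DeviceDescriptor::CPUDevice())
        cpuData = data;
    else
    {
        cpuData = std::make_shared<NDArrayView>(data->GetDataType(), data->Shape(), DeviceDescriptor::CPUDevice());
        cpuData->CopyFrom(*data);
    }

    if (data->GetDataType() == DataType::Float)
        scalar = *(cpuData->DataBuffer<float>());
    else if (data->GetDataType() == DataType::Double)
        scalar = *(cpuData->DataBuffer<double>());
    else
        LogicError("Unsupported DataType of criterion value");

    return scalar;
}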

View file

@ -601,12 +601,14 @@ namespace CNTK
} }
template void DictionaryValue::AllocateDataPtr<NDShape>(const NDShape& value); template void DictionaryValue::AllocateDataPtr<NDShape>(const NDShape& value);
template void DictionaryValue::AllocateDataPtr<Axis>(const Axis& value);
template void DictionaryValue::AllocateDataPtr<vector<DictionaryValue>>(const vector<DictionaryValue>& value); template void DictionaryValue::AllocateDataPtr<vector<DictionaryValue>>(const vector<DictionaryValue>& value);
template void DictionaryValue::AllocateDataPtr<wstring>(const wstring& value); template void DictionaryValue::AllocateDataPtr<wstring>(const wstring& value);
template void DictionaryValue::AllocateDataPtr<Dictionary>(const Dictionary& value); template void DictionaryValue::AllocateDataPtr<Dictionary>(const Dictionary& value);
template void DictionaryValue::AllocateDataPtr<NDArrayView>(const NDArrayView& value); template void DictionaryValue::AllocateDataPtr<NDArrayView>(const NDArrayView& value);
template void DictionaryValue::FreePtrAsType<NDShape>(); template void DictionaryValue::FreePtrAsType<NDShape>();
template void DictionaryValue::FreePtrAsType<Axis>();
template void DictionaryValue::FreePtrAsType<vector<DictionaryValue>>(); template void DictionaryValue::FreePtrAsType<vector<DictionaryValue>>();
template void DictionaryValue::FreePtrAsType<wstring>(); template void DictionaryValue::FreePtrAsType<wstring>();
template void DictionaryValue::FreePtrAsType<Dictionary>(); template void DictionaryValue::FreePtrAsType<Dictionary>();

View file

@ -119,14 +119,14 @@ namespace CNTK
} }
} }
inline Microsoft::MSR::CNTK::TensorShape AsTensorShape(const NDShape& viewShape, bool preserveRank = false) inline Microsoft::MSR::CNTK::TensorShape AsTensorShape(const NDShape& viewShape)
{ {
const size_t maxNumAxesSupportedByTensorView = 12; const size_t maxNumAxesSupportedByTensorView = 12;
if (viewShape.NumAxes() > maxNumAxesSupportedByTensorView) if (viewShape.NumAxes() > maxNumAxesSupportedByTensorView)
LogicError("The number of requested axes exceeds the currently supported limit"); LogicError("The number of requested axes exceeds the currently supported limit");
// TensorShape is required to be at least 2D // TensorShape is required to be at least 1D
size_t minRankSize = preserveRank ? viewShape.NumAxes() : 2; size_t minRankSize = 1;
Microsoft::MSR::CNTK::SmallVector<size_t> tensorViewShape(std::max<size_t>(minRankSize, viewShape.NumAxes())); Microsoft::MSR::CNTK::SmallVector<size_t> tensorViewShape(std::max<size_t>(minRankSize, viewShape.NumAxes()));
for (size_t i = 0; i < tensorViewShape.size(); ++i) for (size_t i = 0; i < tensorViewShape.size(); ++i)
tensorViewShape[i] = (i < viewShape.NumAxes()) ? viewShape[i] : 1; tensorViewShape[i] = (i < viewShape.NumAxes()) ? viewShape[i] : 1;
@ -134,6 +134,17 @@ namespace CNTK
return tensorViewShape; return tensorViewShape;
} }
inline Microsoft::MSR::CNTK::TensorShape AsTensorViewShape(const Microsoft::MSR::CNTK::TensorShape& viewShape)
{
// For TensorView shapes we pad the TensorShape to be at least rank 2
return viewShape.PadRank(std::max<size_t>(2, viewShape.GetRank()));
}
inline Microsoft::MSR::CNTK::TensorShape AsTensorViewShape(const NDShape& viewShape)
{
return AsTensorViewShape(AsTensorShape(viewShape));
}
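The padding behaviour can be read off directly: AsTensorShape now only guarantees rank 1, while AsTensorViewShape pads to at least rank 2 by appending singleton dimensions. Illustrative examples:

// NDShape({ 5 })    -> AsTensorShape: [5]       AsTensorViewShape: [5 x 1]
// NDShape({ 3, 4 }) -> AsTensorShape: [3 x 4]   AsTensorViewShape: [3 x 4]
// NDShape({})       -> AsTensorShape: [1]       AsTensorViewShape: [1 x 1]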
inline std::string AsString(const NDShape& shape) inline std::string AsString(const NDShape& shape)
{ {
std::string shapeString = "["; std::string shapeString = "[";
@ -242,35 +253,39 @@ namespace CNTK
} }
template <typename T> template <typename T>
inline std::vector<DictionaryValue> AsDictionaryValueVector(const std::vector<T>& basicElementTypeVector) inline std::vector<DictionaryValue> AsDictionaryValueVector(const std::vector<T>& elementVector)
{ {
static_assert(std::is_same<T, bool>::value || static_assert(std::is_same<T, bool>::value ||
std::is_same<T, size_t>::value || std::is_same<T, size_t>::value ||
std::is_same<T, float>::value || std::is_same<T, float>::value ||
std::is_same<T, double>::value || std::is_same<T, double>::value ||
std::is_same<T, std::wstring>::value, "Unsupported ValueType"); std::is_same<T, Axis>::value ||
std::is_same<T, std::wstring>::value,
"Unsupported ValueType");
std::vector<DictionaryValue> dictionaryValueVector; std::vector<DictionaryValue> dictionaryValueVector;
for (auto value : basicElementTypeVector) for (auto value : elementVector)
dictionaryValueVector.push_back(value); dictionaryValueVector.push_back(value);
return dictionaryValueVector; return dictionaryValueVector;
} }
template <typename T> template <typename T>
inline std::vector<T> AsBasicElementTypeVector(const std::vector<DictionaryValue>& dictionaryValueVector) inline std::vector<T> AsVector(const std::vector<DictionaryValue>& dictionaryValueVector)
{ {
static_assert(std::is_same<T, bool>::value || static_assert(std::is_same<T, bool>::value ||
std::is_same<T, size_t>::value || std::is_same<T, size_t>::value ||
std::is_same<T, float>::value || std::is_same<T, float>::value ||
std::is_same<T, double>::value || std::is_same<T, double>::value ||
std::is_same<T, std::wstring>::value, "Unsupported ValueType"); std::is_same<T, Axis>::value ||
std::is_same<T, std::wstring>::value,
"Unsupported ValueType");
std::vector<T> basicElementTypeVector; std::vector<T> elementVector;
for (auto value : dictionaryValueVector) for (auto value : dictionaryValueVector)
basicElementTypeVector.push_back(value.Value<T>()); elementVector.push_back(value.Value<T>());
return basicElementTypeVector; return elementVector;
} }
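With Axis now accepted by these helpers (matching the new DictionaryValue instantiations for Axis earlier in this change), dynamic-axis lists can round-trip through DictionaryValue during serialization. A hedged usage sketch, with variable names chosen purely for exposition and assuming DictionaryValue accepts Axis as introduced here:

std::vector<Axis> dynamicAxes = { Axis(L"inputAxis"), Axis::DefaultBatchAxis() };

// std::vector<Axis> -> std::vector<DictionaryValue>, e.g. for storing in a Dictionary
std::vector<DictionaryValue> serialized = AsDictionaryValueVector(dynamicAxes);

// ... and back
std::vector<Axis> restored = AsVector<Axis>(serialized);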
inline PoolingType AsPoolingType(Microsoft::MSR::CNTK::PoolKind cntkPoolingKind) inline PoolingType AsPoolingType(Microsoft::MSR::CNTK::PoolKind cntkPoolingKind)

View file

@ -3,6 +3,7 @@
// Licensed under the MIT license. See LICENSE.md file in the project root for full license information. // Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
// //
#include "stdafx.h"
#include "CNTKLibrary.h" #include "CNTKLibrary.h"
namespace CNTK namespace CNTK

View file

@ -3,7 +3,10 @@
// Licensed under the MIT license. See LICENSE.md file in the project root for full license information. // Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
// //
#include "stdafx.h"
#include "CNTKLibrary.h" #include "CNTKLibrary.h"
#include "Utils.h"
#include "Function.h"
namespace CNTK namespace CNTK
{ {
@ -16,6 +19,44 @@ namespace CNTK
FunctionPtr Variable::Owner() const FunctionPtr Variable::Owner() const
{ {
return m_dataFields->m_ownerFunction->shared_from_this(); if (m_dataFields->m_ownerFunction != nullptr)
return m_dataFields->m_ownerFunction->shared_from_this();
else
return nullptr;
}
Variable::operator FunctionPtr() const
{
auto varOwner = Owner();
if (varOwner)
return CompositeFunction::Create(varOwner, varOwner->Name());
else
return Internal::Combine({ *this });
}
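The new implicit conversion lets a Variable be handed to any API that expects a FunctionPtr: a variable with an owner is wrapped into a composite over that owner, and an ownerless variable becomes a pass-through Combine of itself. A small hedged usage sketch, where prediction and labels stand in for any existing variables:

// Illustrative: a single-output Function can flow through Variable and back.
FunctionPtr loss = CrossEntropyWithSoftmax(prediction, labels, L"lossFunction");
Variable lossVar = loss;          // FunctionPtr -> Variable (the function's sole output)
FunctionPtr lossFunc = lossVar;   // Variable -> FunctionPtr via the new implicit conversion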
/*static*/ Parameter Parameter::UniformInitParameter(const NDShape& shape, DataType type, double range, unsigned long seed, const DeviceDescriptor& device, const std::wstring& name)
{
switch (type)
{
case DataType::Float:
return Parameter(NDArrayView::RandomUniform<float>(shape, -range, range, seed, device), name);
case DataType::Double:
return Parameter(NDArrayView::RandomUniform<double>(shape, -range, range, seed, device), name);
default:
InvalidArgument("Parameter construction: Unsupported DataType %s", DataTypeName(type));
}
}
/*static*/ Parameter Parameter::NormalInitParameter(const NDShape& shape, DataType type, double stdDev, unsigned long seed, const DeviceDescriptor& device, const std::wstring& name)
{
switch (type)
{
case DataType::Float:
return Parameter(NDArrayView::RandomNormal<float>(shape, 0, stdDev, seed, device), name);
case DataType::Double:
return Parameter(NDArrayView::RandomNormal<double>(shape, 0, stdDev, seed, device), name);
default:
InvalidArgument("Parameter construction: Unsupported DataType %s", DataTypeName(type));
}
} }
} }

View file

@ -1239,7 +1239,11 @@ void Matrix<ElemType>::AssignValuesOf(const Matrix<ElemType>& deepCopyFrom)
DISPATCH_MATRIX_ON_FLAG(&deepCopyFrom, nullptr, DISPATCH_MATRIX_ON_FLAG(&deepCopyFrom, nullptr,
{ m_GPUMatrix->SetValue(deepCopyFrom.GetNumRows(), deepCopyFrom.GetNumCols(), this->GetDeviceId(), deepCopyFrom.m_CPUMatrix->Data()); }, { m_GPUMatrix->SetValue(deepCopyFrom.GetNumRows(), deepCopyFrom.GetNumCols(), this->GetDeviceId(), deepCopyFrom.m_CPUMatrix->Data()); },
{ m_GPUMatrix->SetValue(*deepCopyFrom.m_GPUMatrix); }, { m_GPUMatrix->SetValue(*deepCopyFrom.m_GPUMatrix); },
{ LogicError("AssignValuesOf: Assigning a CPUSparseMatrix to a GPUMatrix is not yet implemented."); },//{ m_GPUMatrix->SetValue(*deepCopyFrom.m_CPUSparseMatrix); }, {
CPUMatrix<ElemType> tempCPUDenseMatrix(deepCopyFrom.GetNumRows(), deepCopyFrom.GetNumCols());
deepCopyFrom.m_CPUSparseMatrix->AssignColumnSliceToDense(tempCPUDenseMatrix, 0, deepCopyFrom.GetNumCols());
m_GPUMatrix->SetValue(deepCopyFrom.GetNumRows(), deepCopyFrom.GetNumCols(), this->GetDeviceId(), tempCPUDenseMatrix.Data());
},//{ m_GPUMatrix->SetValue(*deepCopyFrom.m_CPUSparseMatrix); },
{ LogicError("AssignValuesOf: Assigning a GPUSparseMatrix to a GPUMatrix is not yet implemented."); });//{ m_GPUMatrix->SetValue(*deepCopyFrom.m_GPUSparseMatrix); }); { LogicError("AssignValuesOf: Assigning a GPUSparseMatrix to a GPUMatrix is not yet implemented."); });//{ m_GPUMatrix->SetValue(*deepCopyFrom.m_GPUSparseMatrix); });
}, },
{ {

View file

@ -132,24 +132,28 @@ void TrainResNetCifarClassifer(const DeviceDescriptor& device, bool testSaveAndR
const size_t numOutputClasses = labelStreamInfo.m_sampleLayout[0]; const size_t numOutputClasses = labelStreamInfo.m_sampleLayout[0];
Variable imageInput(inputImageShape, imageStreamInfo.m_elementType, L"Images"); Variable imageInput(inputImageShape, imageStreamInfo.m_elementType, L"Images");
auto classifierOutputFunction = ResNetClassifier(imageInput, numOutputClasses, device, L"classifierOutput"); auto classifierOutput = ResNetClassifier(imageInput, numOutputClasses, device, L"classifierOutput");
Variable classifierOutput = classifierOutputFunction;
auto labelsVar = Variable({ numOutputClasses }, labelStreamInfo.m_elementType, L"Labels"); auto labelsVar = Variable({ numOutputClasses }, labelStreamInfo.m_elementType, L"Labels");
auto trainingLoss = CrossEntropyWithSoftmax(classifierOutput, labelsVar, L"lossFunction");
auto trainingLossFunction = CrossEntropyWithSoftmax(classifierOutputFunction, labelsVar, L"lossFunction"); auto prediction = ClassificationError(classifierOutput, labelsVar, L"predictionError");
Variable trainingLoss = trainingLossFunction;
auto predictionFunction = ClassificationError(classifierOutputFunction, labelsVar, L"predictionError");
Variable prediction = predictionFunction;
auto imageClassifier = Combine({ trainingLossFunction, predictionFunction, classifierOutputFunction }, L"ImageClassifier");
if (testSaveAndReLoad) if (testSaveAndReLoad)
SaveAndReloadModel<float>(imageClassifier, { &imageInput, &labelsVar, &trainingLoss, &prediction, &classifierOutput }, device); {
Variable classifierOutputVar = classifierOutput;
Variable trainingLossVar = trainingLoss;
Variable predictionVar = prediction;
auto imageClassifier = Combine({ trainingLoss, prediction, classifierOutput }, L"ImageClassifier");
SaveAndReloadModel<float>(imageClassifier, { &imageInput, &labelsVar, &trainingLossVar, &predictionVar, &classifierOutputVar }, device);
trainingLoss = trainingLossVar;
prediction = predictionVar;
classifierOutput = classifierOutputVar;
}
double learningRatePerSample = 0.0078125; double learningRatePerSample = 0.0078125;
Trainer trainer(classifierOutput, trainingLoss, prediction, { SGDLearner(classifierOutput->Parameters(), learningRatePerSample) });
Trainer trainer(imageClassifier, trainingLoss, { SGDLearner(imageClassifier->Parameters(), learningRatePerSample) });
const size_t minibatchSize = 32; const size_t minibatchSize = 32;
size_t numMinibatchesToTrain = 100; size_t numMinibatchesToTrain = 100;
size_t outputFrequencyInMinibatches = 20; size_t outputFrequencyInMinibatches = 20;
@ -157,12 +161,7 @@ void TrainResNetCifarClassifer(const DeviceDescriptor& device, bool testSaveAndR
{ {
auto minibatchData = minibatchSource->GetNextMinibatch(minibatchSize, device); auto minibatchData = minibatchSource->GetNextMinibatch(minibatchSize, device);
trainer.TrainMinibatch({ { imageInput, minibatchData[imageStreamInfo].m_data }, { labelsVar, minibatchData[labelStreamInfo].m_data } }, device); trainer.TrainMinibatch({ { imageInput, minibatchData[imageStreamInfo].m_data }, { labelsVar, minibatchData[labelStreamInfo].m_data } }, device);
PrintTrainingProgress(trainer, i, outputFrequencyInMinibatches);
if ((i % outputFrequencyInMinibatches) == 0)
{
double trainLossValue = trainer.PreviousMinibatchAverageTrainingLoss();
printf("Minibatch %d: CrossEntropy loss = %.8g\n", (int)i, trainLossValue);
}
} }
} }

View file

@ -116,114 +116,82 @@ inline CNTK::FunctionPtr FullyConnectedDNNLayer(CNTK::Variable input, size_t out
template <typename ElementType> template <typename ElementType>
std::pair<CNTK::FunctionPtr, CNTK::FunctionPtr> LSTMPCellWithSelfStabilization(CNTK::Variable input, CNTK::Variable prevOutput, CNTK::Variable prevCellState, const CNTK::DeviceDescriptor& device) std::pair<CNTK::FunctionPtr, CNTK::FunctionPtr> LSTMPCellWithSelfStabilization(CNTK::Variable input, CNTK::Variable prevOutput, CNTK::Variable prevCellState, const CNTK::DeviceDescriptor& device)
{ {
assert(input.Shape().NumAxes() == 1); if ((input.Shape().NumAxes() != 1) || (prevOutput.Shape().NumAxes() != 1) || (prevCellState.Shape().NumAxes() != 1))
size_t inputDim = input.Shape()[0]; LogicError("The LSTM implementation in the test library currently only supports 1D inputs and outputs");
auto stabilize = [](const Variable& x) {
float scalarConstant = 4.0f;
auto f = Constant({}, scalarConstant);
auto fInv = Constant({}, 1.0f / scalarConstant);
auto beta = ElementTimes(fInv, Log(Constant({}, 1.0f) + Exp(ElementTimes(f, Parameter({}, 0.99537863f /* 1/f*ln (e^f-1) */)))));
return ElementTimes(beta, x);
};
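The self-stabilizer multiplies its input by a learned, always-positive scale. With f = 4 and a learnable scalar b initialized to (1/f)·ln(e^f − 1) ≈ 0.99537863, the scale is a softplus that starts out at exactly 1:

\[
\beta = \tfrac{1}{f}\,\ln\!\bigl(1 + e^{f b}\bigr), \qquad
\mathrm{stabilize}(x) = \beta \, x, \qquad
b_0 = \tfrac{1}{f}\ln\!\bigl(e^{f} - 1\bigr) \;\Rightarrow\; \beta_0 = 1 .
\]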
size_t inputDim = input.Shape()[0];
size_t outputDim = prevOutput.Shape()[0]; size_t outputDim = prevOutput.Shape()[0];
size_t cellDim = prevCellState.Shape()[0]; size_t cellDim = prevCellState.Shape()[0];
auto createBiasParam = [device](size_t dim) {
return CNTK::Parameter({ dim }, (ElementType)0.0, device);
};
unsigned long seed = 1; unsigned long seed = 1;
auto createProjectionParam = [device, &seed](size_t outputDim, size_t inputDim) {
return CNTK::Parameter(CNTK::NDArrayView::RandomUniform<ElementType>({ outputDim, inputDim }, -0.5, 0.5, seed++, device));
};
auto Wxo = CNTK::Parameter(CNTK::NDArrayView::RandomUniform<ElementType>({ cellDim, inputDim }, -0.5, 0.5, seed++, device)); auto createDiagWeightParam = [device, &seed](size_t dim) {
auto Wxi = CNTK::Parameter(CNTK::NDArrayView::RandomUniform<ElementType>({ cellDim, inputDim }, -0.5, 0.5, seed++, device)); return CNTK::Parameter(CNTK::NDArrayView::RandomUniform<ElementType>({ dim }, -0.5, 0.5, seed++, device));
auto Wxf = CNTK::Parameter(CNTK::NDArrayView::RandomUniform<ElementType>({ cellDim, inputDim }, -0.5, 0.5, seed++, device)); };
auto Wxc = CNTK::Parameter(CNTK::NDArrayView::RandomUniform<ElementType>({ cellDim, inputDim }, -0.5, 0.5, seed++, device));
auto Bo = CNTK::Parameter({ cellDim }, (ElementType)0.0, device); auto stabilizedPrevOutput = stabilize(prevOutput);
auto Bc = CNTK::Parameter({ cellDim }, (ElementType)0.0, device); auto stabilizedPrevCellState = stabilize(prevCellState);
auto Bi = CNTK::Parameter({ cellDim }, (ElementType)0.0, device);
auto Bf = CNTK::Parameter({ cellDim }, (ElementType)0.0, device);
auto Whi = CNTK::Parameter(CNTK::NDArrayView::RandomUniform<ElementType>({ cellDim, outputDim }, -0.5, 0.5, seed++, device)); auto projectInput = [input, cellDim, inputDim, createBiasParam, createProjectionParam]() {
auto Wci = CNTK::Parameter(CNTK::NDArrayView::RandomUniform<ElementType>({ cellDim }, -0.5, 0.5, seed++, device)); return createBiasParam(cellDim) + Times(createProjectionParam(cellDim, inputDim), input);
};
auto Whf = CNTK::Parameter(CNTK::NDArrayView::RandomUniform<ElementType>({ cellDim, outputDim }, -0.5, 0.5, seed++, device)); // Input gate
auto Wcf = CNTK::Parameter(CNTK::NDArrayView::RandomUniform<ElementType>({ cellDim }, -0.5, 0.5, seed++, device)); auto it = Sigmoid(projectInput() + Times(createProjectionParam(cellDim, outputDim), stabilizedPrevOutput) + ElementTimes(createDiagWeightParam(cellDim), stabilizedPrevCellState));
auto bit = ElementTimes(it, Tanh(projectInput() + Times(createProjectionParam(cellDim, outputDim), stabilizedPrevOutput)));
auto Who = CNTK::Parameter(CNTK::NDArrayView::RandomUniform<ElementType>({ cellDim, outputDim }, -0.5, 0.5, seed++, device)); // Forget-me-not gate
auto Wco = CNTK::Parameter(CNTK::NDArrayView::RandomUniform<ElementType>({ cellDim }, -0.5, 0.5, seed++, device)); auto ft = Sigmoid(projectInput() + Times(createProjectionParam(cellDim, outputDim), stabilizedPrevOutput) + ElementTimes(createDiagWeightParam(cellDim), stabilizedPrevCellState));
auto bft = ElementTimes(ft, prevCellState);
auto Whc = CNTK::Parameter(CNTK::NDArrayView::RandomUniform<ElementType>({ cellDim, outputDim }, -0.5, 0.5, seed++, device)); auto ct = bft + bit;
auto Wmr = CNTK::Parameter(CNTK::NDArrayView::RandomUniform<ElementType>({ outputDim, cellDim }, -0.5, 0.5, seed++, device)); // Output gate
auto ot = Sigmoid(projectInput() + Times(createProjectionParam(cellDim, outputDim), stabilizedPrevOutput) + ElementTimes(createDiagWeightParam(cellDim), stabilize(ct)));
auto ht = ElementTimes(ot, Tanh(ct));
// Stabilization by routing input through an extra scalar parameter auto c = ct;
auto sWxo = CNTK::Parameter({}, (ElementType)0.0, device); auto h = (outputDim != cellDim) ? Times(createProjectionParam(outputDim, cellDim), stabilize(ht)) : ht;
auto sWxi = CNTK::Parameter({}, (ElementType)0.0, device);
auto sWxf = CNTK::Parameter({}, (ElementType)0.0, device);
auto sWxc = CNTK::Parameter({}, (ElementType)0.0, device);
auto sWhi = CNTK::Parameter({}, (ElementType)0.0, device); return{ h, c };
auto sWci = CNTK::Parameter({}, (ElementType)0.0, device);
auto sWhf = CNTK::Parameter({}, (ElementType)0.0, device);
auto sWcf = CNTK::Parameter({}, (ElementType)0.0, device);
auto sWho = CNTK::Parameter({}, (ElementType)0.0, device);
auto sWco = CNTK::Parameter({}, (ElementType)0.0, device);
auto sWhc = CNTK::Parameter({}, (ElementType)0.0, device);
auto sWmr = CNTK::Parameter({}, (ElementType)0.0, device);
auto expsWxo = CNTK::Exp(sWxo);
auto expsWxi = CNTK::Exp(sWxi);
auto expsWxf = CNTK::Exp(sWxf);
auto expsWxc = CNTK::Exp(sWxc);
auto expsWhi = CNTK::Exp(sWhi);
auto expsWci = CNTK::Exp(sWci);
auto expsWhf = CNTK::Exp(sWhf);
auto expsWcf = CNTK::Exp(sWcf);
auto expsWho = CNTK::Exp(sWho);
auto expsWco = CNTK::Exp(sWco);
auto expsWhc = CNTK::Exp(sWhc);
auto expsWmr = CNTK::Exp(sWmr);
auto Wxix = CNTK::Times(Wxi, CNTK::ElementTimes(expsWxi, input));
auto Whidh = CNTK::Times(Whi, CNTK::ElementTimes(expsWhi, prevOutput));
auto Wcidc = CNTK::ElementTimes(Wci, CNTK::ElementTimes(expsWci, prevCellState));
auto it = CNTK::Sigmoid(CNTK::Plus(CNTK::Plus(CNTK::Plus(Wxix, Bi), Whidh), Wcidc));
auto Wxcx = CNTK::Times(Wxc, CNTK::ElementTimes(expsWxc, input));
auto Whcdh = CNTK::Times(Whc, CNTK::ElementTimes(expsWhc, prevOutput));
auto bit = CNTK::ElementTimes(it, CNTK::Tanh(CNTK::Plus(Wxcx, CNTK::Plus(Whcdh, Bc))));
auto Wxfx = CNTK::Times(Wxf, CNTK::ElementTimes(expsWxf, input));
auto Whfdh = CNTK::Times(Whf, CNTK::ElementTimes(expsWhf, prevOutput));
auto Wcfdc = CNTK::ElementTimes(Wcf, CNTK::ElementTimes(expsWcf, prevCellState));
auto ft = CNTK::Sigmoid(CNTK::Plus(CNTK::Plus(CNTK::Plus(Wxfx, Bf), Whfdh), Wcfdc));
auto bft = CNTK::ElementTimes(ft, prevCellState);
auto ct = CNTK::Plus(bft, bit);
auto Wxox = CNTK::Times(Wxo, CNTK::ElementTimes(expsWxo, input));
auto Whodh = CNTK::Times(Who, CNTK::ElementTimes(expsWho, prevOutput));
auto Wcoct = CNTK::ElementTimes(Wco, CNTK::ElementTimes(expsWco, ct));
auto ot = CNTK::Sigmoid(CNTK::Plus(CNTK::Plus(CNTK::Plus(Wxox, Bo), Whodh), Wcoct));
auto mt = CNTK::ElementTimes(ot, Tanh(ct));
return{ CNTK::Times(Wmr, CNTK::ElementTimes(expsWmr, mt)), ct };
} }
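Written out, the refactored cell is a peephole LSTM whose recurrent and cell-state inputs pass through the self-stabilizer S(·); each gate draws its own bias, input projection, recurrent projection, and diagonal peephole weight from the createBiasParam/createProjectionParam/createDiagWeightParam factories:

\[
\begin{aligned}
i_t &= \sigma\bigl(W_{xi} x_t + b_i + W_{hi}\,S(h_{t-1}) + w_{ci} \odot S(c_{t-1})\bigr) \\
\tilde{c}_t &= \tanh\bigl(W_{xc} x_t + b_c + W_{hc}\,S(h_{t-1})\bigr) \\
f_t &= \sigma\bigl(W_{xf} x_t + b_f + W_{hf}\,S(h_{t-1}) + w_{cf} \odot S(c_{t-1})\bigr) \\
c_t &= f_t \odot c_{t-1} + i_t \odot \tilde{c}_t \\
o_t &= \sigma\bigl(W_{xo} x_t + b_o + W_{ho}\,S(h_{t-1}) + w_{co} \odot S(c_t)\bigr) \\
h_t &= \begin{cases} W_{mr}\,S\bigl(o_t \odot \tanh(c_t)\bigr) & \text{if outputDim} \neq \text{cellDim} \\ o_t \odot \tanh(c_t) & \text{otherwise} \end{cases}
\end{aligned}
\]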
template <typename ElementType> template <typename ElementType>
CNTK::FunctionPtr LSTMPComponentWithSelfStabilization(CNTK::Variable input, size_t outputDim, size_t cellDim, const CNTK::DeviceDescriptor& device) std::pair<CNTK::FunctionPtr, CNTK::FunctionPtr> LSTMPComponentWithSelfStabilization(CNTK::Variable input,
size_t outputDim,
size_t cellDim,
const std::function<CNTK::FunctionPtr(const CNTK::Variable&)>& recurrenceHookH,
const std::function<CNTK::FunctionPtr(const CNTK::Variable&)>& recurrenceHookC,
const CNTK::DeviceDescriptor& device)
{ {
auto dh = CNTK::Placeholder({ outputDim }); auto dh = CNTK::Placeholder({ outputDim }, input.DynamicAxes());
auto dc = CNTK::Placeholder({ cellDim }); auto dc = CNTK::Placeholder({ cellDim }, input.DynamicAxes());
auto LSTMCell = LSTMPCellWithSelfStabilization<ElementType>(input, dh, dc, device); auto LSTMCell = LSTMPCellWithSelfStabilization<ElementType>(input, dh, dc, device);
auto actualDh = CNTK::PastValue(LSTMCell.first, CNTK::Constant({}, (ElementType)0.0, device), 1); auto actualDh = recurrenceHookH(LSTMCell.first);
auto actualDc = CNTK::PastValue(LSTMCell.second, CNTK::Constant({}, (ElementType)0.0, device), 1); auto actualDc = recurrenceHookC(LSTMCell.second);
// Form the recurrence loop by replacing the dh and dc placeholders with the actualDh and actualDc // Form the recurrence loop by replacing the dh and dc placeholders with the actualDh and actualDc
return LSTMCell.first->ReplacePlaceholders({ { dh, actualDh }, { dc, actualDc } }); LSTMCell.first->ReplacePlaceholders({ { dh, actualDh }, { dc, actualDc } });
return { LSTMCell.first, LSTMCell.second };
} }
inline CNTK::MinibatchSourcePtr CreateTextMinibatchSource(const std::wstring& filePath, inline CNTK::MinibatchSourcePtr CreateTextMinibatchSource(const std::wstring& filePath,
@ -355,4 +323,14 @@ inline void OpenStream(std::fstream& stream, const std::wstring& filename, bool
stream.open(wtocharpath(filename.c_str()).c_str(), mode); stream.open(wtocharpath(filename.c_str()).c_str(), mode);
#endif #endif
stream.exceptions(std::ios_base::failbit | std::ios_base::badbit); stream.exceptions(std::ios_base::failbit | std::ios_base::badbit);
} }
inline void PrintTrainingProgress(const CNTK::Trainer& trainer, size_t minibatchIdx, size_t outputFrequencyInMinibatches)
{
if ((minibatchIdx % outputFrequencyInMinibatches) == 0)
{
double trainLossValue = trainer.PreviousMinibatchLossAverage();
double evaluationValue = trainer.PreviousMinibatchEvaluationAverage();
printf("Minibatch %d: CrossEntropy loss = %.8g, Evaluation criterion = %.8g\n", (int)minibatchIdx, trainLossValue, evaluationValue);
}
}

View file

@ -13,6 +13,7 @@ void FunctionTests();
void TrainLSTMSequenceClassifer(); void TrainLSTMSequenceClassifer();
void SerializationTests(); void SerializationTests();
void LearnerTests(); void LearnerTests();
void TrainSequenceToSequenceTranslator();
int main() int main()
{ {
@ -29,6 +30,8 @@ int main()
TestCifarResnet(); TestCifarResnet();
TrainLSTMSequenceClassifer(); TrainLSTMSequenceClassifer();
TrainSequenceToSequenceTranslator();
fprintf(stderr, "\nCNTKv2Library tests: Passed\n"); fprintf(stderr, "\nCNTKv2Library tests: Passed\n");
fflush(stderr); fflush(stderr);

View file

@ -10,10 +10,13 @@ static unsigned long seed = 1;
template <typename ElementType> template <typename ElementType>
FunctionPtr LSTMNet(Variable features, size_t cellDim, size_t hiddenDim, size_t numOutputClasses, size_t numLSTMLayers, const DeviceDescriptor& device, const std::wstring& outputName) FunctionPtr LSTMNet(Variable features, size_t cellDim, size_t hiddenDim, size_t numOutputClasses, size_t numLSTMLayers, const DeviceDescriptor& device, const std::wstring& outputName)
{ {
using namespace std::placeholders;
assert(numLSTMLayers >= 1); assert(numLSTMLayers >= 1);
auto classifierRoot = LSTMPComponentWithSelfStabilization<ElementType>(features, hiddenDim, cellDim, device); FunctionPtr classifierRoot = features;
for (size_t i = 1; i < numLSTMLayers; ++i) { auto pastValueRecurrenceHook = std::bind(PastValue, _1, CNTK::Constant({}, (ElementType)0.0), 1, L"");
classifierRoot = LSTMPComponentWithSelfStabilization<ElementType>(classifierRoot, hiddenDim, cellDim, device); for (size_t i = 0; i < numLSTMLayers; ++i) {
classifierRoot = LSTMPComponentWithSelfStabilization<ElementType>(classifierRoot, hiddenDim, cellDim, pastValueRecurrenceHook, pastValueRecurrenceHook, device).first;
} }
auto W = Parameter(NDArrayView::RandomUniform<ElementType>({ numOutputClasses, hiddenDim }, -0.5, 0.5, seed++, device)); auto W = Parameter(NDArrayView::RandomUniform<ElementType>({ numOutputClasses, hiddenDim }, -0.5, 0.5, seed++, device));

View file

@ -0,0 +1,187 @@
#include "CNTKLibrary.h"
#include <functional>
#include "Common.h"
using namespace CNTK;
using namespace std::placeholders;
inline CNTK::MinibatchSourcePtr CreateSeq2SeqMinibatchSource(const std::wstring& filePath, size_t inputVocabSize, size_t labelsVocabSize)
{
CNTK::Dictionary inputStreamConfig;
inputStreamConfig[L"dim"] = inputVocabSize;
inputStreamConfig[L"format"] = L"sparse";
inputStreamConfig[L"alias"] = L"S0";
CNTK::Dictionary labelsStreamConfig;
labelsStreamConfig[L"dim"] = labelsVocabSize;
labelsStreamConfig[L"format"] = L"sparse";
labelsStreamConfig[L"alias"] = L"S1";
CNTK::Dictionary inputStreamsConfig;
inputStreamsConfig[L"rawInput"] = inputStreamConfig;
inputStreamsConfig[L"rawLabels"] = labelsStreamConfig;
CNTK::Dictionary deserializerConfiguration;
deserializerConfiguration[L"type"] = L"CNTKTextFormatDeserializer";
deserializerConfiguration[L"file"] = filePath;
deserializerConfiguration[L"input"] = inputStreamsConfig;
deserializerConfiguration[L"skipSequenceIds"] = L"false";
deserializerConfiguration[L"maxErrors"] = (size_t)100;
deserializerConfiguration[L"traceLevel"] = (size_t)1;
deserializerConfiguration[L"chunkSizeInBytes"] = (size_t)30000000;
CNTK::Dictionary minibatchSourceConfiguration;
minibatchSourceConfiguration[L"epochSize"] = (size_t)2000;
minibatchSourceConfiguration[L"deserializers"] = std::vector<CNTK::DictionaryValue>({ deserializerConfiguration });
return CreateCompositeMinibatchSource(minibatchSourceConfiguration);
}
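The deserializer configuration above expects CNTK text-format (CTF) input with two sparse streams aliased S0 (rawInput) and S1 (rawLabels), one one-hot index per token. Lines of such a file look roughly like the following; the indices are illustrative only and are not taken from the actual cmudict data set:

0 |S0 3:1 |S1 12:1
0 |S0 4:1 |S1 27:1
1 |S0 7:1 |S1 30:1
1 |S0 2:1 |S1 15:1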
void TrainSequenceToSequenceTranslator(const DeviceDescriptor& device, bool useSparseInputs, bool testSaveAndReLoad)
{
using namespace std::placeholders;
const size_t inputVocabDim = 69;
const size_t labelVocabDim = 69;
const size_t hiddenDim = 512;
const size_t numLayers = 2;
const size_t embeddingDim = 300;
const size_t inputEmbeddingDim = std::min(inputVocabDim, embeddingDim);
const size_t labelEmbeddingDim = std::min(labelVocabDim, embeddingDim);
/* Inputs */
std::vector<Axis> inputDynamicAxes = { Axis(L"inputAxis"), Axis::DefaultBatchAxis() };
auto rawInput = Variable({ inputVocabDim }, useSparseInputs /*isSparse*/, DataType::Float, L"rawInput", inputDynamicAxes);
std::vector<Axis> labelDynamicAxes = { Axis(L"labelAxis"), Axis::DefaultBatchAxis() };
auto rawLabels = Variable({ labelVocabDim }, useSparseInputs /*isSparse*/, DataType::Float, L"rawLabels", labelDynamicAxes);
FunctionPtr inputSequence = rawInput;
// Drop the sentence start token from the label, for decoder training
auto labelSequence = Slice(rawLabels, labelDynamicAxes[0], 1, 0);
auto labelSentenceStart = Sequence::First(rawLabels);
auto isFirstLabel = Sequence::IsFirst(labelSequence);
bool forceEmbedding = useSparseInputs;
/* Embeddings */
auto inputEmbeddingWeights = Parameter(NDArrayView::RandomUniform<float>({ inputEmbeddingDim, inputVocabDim }, -0.05, 0.05, 1, device));
auto labelEmbeddingWeights = Parameter(NDArrayView::RandomUniform<float>({ labelEmbeddingDim, labelVocabDim }, -0.05, 0.05, 1, device));
auto inputEmbedding = (!forceEmbedding && (inputVocabDim <= inputEmbeddingDim)) ? inputSequence : Times(inputEmbeddingWeights, inputSequence);
auto labelEmbedding = (!forceEmbedding && (labelVocabDim <= labelEmbeddingDim)) ? labelSequence : Times(labelEmbeddingWeights, labelSequence);
auto labelSentenceStartEmbedding = (!forceEmbedding && (labelVocabDim <= labelEmbeddingDim)) ? labelSentenceStart : Times(labelEmbeddingWeights, labelSentenceStart);
auto labelSentenceStartEmbeddedScattered = Sequence::Scatter(labelSentenceStartEmbedding, isFirstLabel);
auto stabilize = [](const Variable& x) {
float scalarConstant = 4.0f;
auto f = Constant({}, scalarConstant);
auto fInv = Constant({}, 1.0f/scalarConstant);
auto beta = ElementTimes(fInv, Log(Constant({}, 1.0f) + Exp(ElementTimes(f, Parameter({}, 0.99537863f /* 1/f*ln (e^f-1) */)))));
return ElementTimes(beta, x);
};
/* Encoder */
auto encoderOutputH = stabilize(inputEmbedding);
FunctionPtr encoderOutputC;
auto futureValueRecurrenceHook = std::bind(FutureValue, _1, CNTK::Constant({}, 0.0f), 1, L"");
for (size_t i = 0; i < numLayers; ++i)
std::tie(encoderOutputH, encoderOutputC) = LSTMPComponentWithSelfStabilization<float>(encoderOutputH, hiddenDim, hiddenDim, futureValueRecurrenceHook, futureValueRecurrenceHook, device);
auto thoughtVectorH = Sequence::First(encoderOutputH);
auto thoughtVectorC = Sequence::First(encoderOutputC);
auto thoughtVectorBroadcastH = Sequence::BroadcastAs(thoughtVectorH, labelEmbedding);
auto thoughtVectorBroadcastC = Sequence::BroadcastAs(thoughtVectorC, labelEmbedding);
/* Decoder */
auto decoderHistoryFromGroundTruth = labelEmbedding;
auto decoderInput = ElementSelect(isFirstLabel, labelSentenceStartEmbeddedScattered, PastValue(decoderHistoryFromGroundTruth, Constant({}, 0.0f), 1));
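This is teacher forcing: at the first step of every label sequence the decoder is fed the (scattered) sentence-start embedding, and at every later step it is fed the embedding of the previous ground-truth label rather than its own prediction. In effect,

\[
x^{\mathrm{dec}}_t =
\begin{cases}
E_{\mathrm{label}}\,\langle s \rangle & t = 0 \\
E_{\mathrm{label}}\, y_{t-1} & t > 0 .
\end{cases}
\]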
auto decoderOutputH = stabilize(decoderInput);
FunctionPtr decoderOutputC;
auto pastValueRecurrenceHook = std::bind(PastValue, _1, CNTK::Constant({}, 0.0f), 1, L"");
for (size_t i = 0; i < numLayers; ++i)
{
std::function<FunctionPtr(const Variable&)> recurrenceHookH, recurrenceHookC;
if (i == 0)
{
recurrenceHookH = pastValueRecurrenceHook;
recurrenceHookC = pastValueRecurrenceHook;
}
else
{
auto isFirst = Sequence::IsFirst(labelEmbedding);
recurrenceHookH = [labelEmbedding, thoughtVectorBroadcastH, isFirst](const Variable& operand) {
return ElementSelect(isFirst, thoughtVectorBroadcastH, PastValue(operand, CNTK::Constant({}, 0.0f), 1, L""));
};
recurrenceHookC = [labelEmbedding, thoughtVectorBroadcastC, isFirst](const Variable& operand) {
return ElementSelect(isFirst, thoughtVectorBroadcastC, PastValue(operand, CNTK::Constant({}, 0.0f), 1, L""));
};
}
std::tie(decoderOutputH, decoderOutputC) = LSTMPComponentWithSelfStabilization<float>(decoderOutputH, hiddenDim, hiddenDim, recurrenceHookH, recurrenceHookC, device);
}
auto decoderOutput = decoderOutputH;
auto decoderDim = hiddenDim;
/* Softmax output layer */
auto outputLayerProjWeights = Parameter(NDArrayView::RandomUniform<float>({ labelVocabDim, decoderDim }, -0.05, 0.05, 1, device));
auto biasWeights = Parameter({ labelVocabDim }, 0.0f, device);
auto z = Plus(Times(outputLayerProjWeights, stabilize(decoderOutput)), biasWeights, L"classifierOutput");
auto ce = CrossEntropyWithSoftmax(z, labelSequence, L"lossFunction");
auto errs = ClassificationError(z, labelSequence, L"classificationError");
if (testSaveAndReLoad)
{
Variable zVar = z;
Variable ceVar = ce;
Variable errsVar = errs;
auto seq2seqModel = Combine({ ce, errs, z }, L"seq2seqModel");
SaveAndReloadModel<float>(seq2seqModel, { &rawInput, &rawLabels, &ceVar, &errsVar, &zVar }, device);
z = zVar;
ce = ceVar;
errs = errsVar;
}
auto minibatchSource = CreateSeq2SeqMinibatchSource(L"cmudict-0.7b.train-dev-20-21.bsf.ctf.2", inputVocabDim, labelVocabDim);
auto rawInputStreamInfo = minibatchSource->StreamInfo(L"rawInput");
auto rawLabelsStreamInfo = minibatchSource->StreamInfo(L"rawLabels");
double learningRatePerSample = 0.007;
size_t momentumTimeConstant = 1100;
double momentumPerSample = std::exp(-1.0 / momentumTimeConstant);
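For reference, a momentum time constant of 1100 samples corresponds to a per-sample momentum of \( e^{-1/1100} \approx 0.999091 \), i.e. a gradient sample's contribution decays to 1/e after roughly 1100 samples.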
double clippingThresholdPerSample = 2.3;
bool gradientClippingWithTruncation = true;
Trainer trainer(z, ce, errs, { MomentumSGDLearner(z->Parameters(), learningRatePerSample, momentumPerSample, clippingThresholdPerSample, gradientClippingWithTruncation) });
size_t outputFrequencyInMinibatches = 1;
size_t minibatchSize = 72;
for (size_t i = 0; true; i++)
{
auto minibatchData = minibatchSource->GetNextMinibatch(minibatchSize, device);
if (minibatchData.empty())
break;
trainer.TrainMinibatch({ { rawInput, minibatchData[rawInputStreamInfo].m_data }, { rawLabels, minibatchData[rawLabelsStreamInfo].m_data } }, device);
PrintTrainingProgress(trainer, i, outputFrequencyInMinibatches);
}
}
void TrainSequenceToSequenceTranslator()
{
// TODO: Also test with sparse input variables in the graph
TrainSequenceToSequenceTranslator(DeviceDescriptor::GPUDevice(0), false, true);
TrainSequenceToSequenceTranslator(DeviceDescriptor::CPUDevice(), false, false);
}

View file

@ -15,16 +15,12 @@ FunctionPtr Embedding(const Variable& input, size_t embeddingDim, const DeviceDe
return Times(embeddingParameters, input); return Times(embeddingParameters, input);
} }
FunctionPtr SelectLast(const Variable& operand)
{
return Slice(operand, Axis::DefaultDynamicAxis(), -1, 0);
}
FunctionPtr LSTMSequenceClassiferNet(const Variable& input, size_t numOutputClasses, size_t embeddingDim, size_t LSTMDim, size_t cellDim, const DeviceDescriptor& device, const std::wstring& outputName) FunctionPtr LSTMSequenceClassiferNet(const Variable& input, size_t numOutputClasses, size_t embeddingDim, size_t LSTMDim, size_t cellDim, const DeviceDescriptor& device, const std::wstring& outputName)
{ {
auto embeddingFunction = Embedding(input, embeddingDim, device); auto embeddingFunction = Embedding(input, embeddingDim, device);
auto LSTMFunction = LSTMPComponentWithSelfStabilization<float>(embeddingFunction, LSTMDim, cellDim, device); auto pastValueRecurrenceHook = std::bind(PastValue, _1, CNTK::Constant({}, 0.0f), 1, L"");
auto thoughtVectorFunction = SelectLast(LSTMFunction); auto LSTMFunction = LSTMPComponentWithSelfStabilization<float>(embeddingFunction, LSTMDim, cellDim, pastValueRecurrenceHook, pastValueRecurrenceHook, device).first;
auto thoughtVectorFunction = Sequence::Last(LSTMFunction);
return FullyConnectedLinearLayer(thoughtVectorFunction, numOutputClasses, device, outputName); return FullyConnectedLinearLayer(thoughtVectorFunction, numOutputClasses, device, outputName);
} }
@ -38,27 +34,34 @@ void TrainLSTMSequenceClassifer(const DeviceDescriptor& device, bool testSaveAnd
const size_t numOutputClasses = 5; const size_t numOutputClasses = 5;
Variable features({ inputDim }, true /*isSparse*/, DataType::Float, L"features"); Variable features({ inputDim }, true /*isSparse*/, DataType::Float, L"features");
auto classifierOutputFunction = LSTMSequenceClassiferNet(features, numOutputClasses, embeddingDim, hiddenDim, cellDim, device, L"classifierOutput"); auto classifierOutput = LSTMSequenceClassiferNet(features, numOutputClasses, embeddingDim, hiddenDim, cellDim, device, L"classifierOutput");
Variable classifierOutput = classifierOutputFunction;
Variable labels({ numOutputClasses }, DataType::Float, L"labels", { Axis::DefaultBatchAxis() }); Variable labels({ numOutputClasses }, DataType::Float, L"labels", { Axis::DefaultBatchAxis() });
auto trainingLossFunction = CNTK::CrossEntropyWithSoftmax(classifierOutput, labels, L"lossFunction"); auto trainingLoss = CNTK::CrossEntropyWithSoftmax(classifierOutput, labels, L"lossFunction");
Variable trainingLoss = trainingLossFunction; auto prediction = CNTK::ClassificationError(classifierOutput, labels, L"classificationError");
auto predictionFunction = CNTK::ClassificationError(classifierOutput, labels, L"classificationError");
Variable prediction = predictionFunction;
auto oneHiddenLayerClassifier = CNTK::Combine({ trainingLoss.Owner(), prediction.Owner(), classifierOutput.Owner() }, L"classifierModel");
if (testSaveAndReLoad) if (testSaveAndReLoad)
SaveAndReloadModel<float>(oneHiddenLayerClassifier, { &features, &labels, &trainingLoss, &prediction, &classifierOutput }, device); {
Variable classifierOutputVar = classifierOutput;
Variable trainingLossVar = trainingLoss;
Variable predictionVar = prediction;
auto oneHiddenLayerClassifier = CNTK::Combine({ trainingLoss, prediction, classifierOutput }, L"classifierModel");
SaveAndReloadModel<float>(oneHiddenLayerClassifier, { &features, &labels, &trainingLossVar, &predictionVar, &classifierOutputVar }, device);
classifierOutput = classifierOutputVar;
trainingLoss = trainingLossVar;
prediction = predictionVar;
}
auto minibatchSource = CreateTextMinibatchSource(L"Train.ctf", inputDim, numOutputClasses, 0, true, false, L"x", L"y"); auto minibatchSource = CreateTextMinibatchSource(L"Train.ctf", inputDim, numOutputClasses, 0, true, false, L"x", L"y");
const size_t minibatchSize = 200; const size_t minibatchSize = 200;
auto featureStreamInfo = minibatchSource->StreamInfo(features); auto featureStreamInfo = minibatchSource->StreamInfo(features);
auto labelStreamInfo = minibatchSource->StreamInfo(labels); auto labelStreamInfo = minibatchSource->StreamInfo(labels);
double learningRatePerSample = 0.0005; double learningRatePerSample = 0.0005;
Trainer trainer(oneHiddenLayerClassifier, trainingLoss, { SGDLearner(oneHiddenLayerClassifier->Parameters(), learningRatePerSample) }); Trainer trainer(classifierOutput, trainingLoss, prediction, { SGDLearner(classifierOutput->Parameters(), learningRatePerSample) });
size_t outputFrequencyInMinibatches = 1; size_t outputFrequencyInMinibatches = 1;
for (size_t i = 0; true; i++) for (size_t i = 0; true; i++)
{ {
@ -67,12 +70,7 @@ void TrainLSTMSequenceClassifer(const DeviceDescriptor& device, bool testSaveAnd
break; break;
trainer.TrainMinibatch({ { features, minibatchData[featureStreamInfo].m_data }, { labels, minibatchData[labelStreamInfo].m_data } }, device); trainer.TrainMinibatch({ { features, minibatchData[featureStreamInfo].m_data }, { labels, minibatchData[labelStreamInfo].m_data } }, device);
PrintTrainingProgress(trainer, i, outputFrequencyInMinibatches);
if ((i % outputFrequencyInMinibatches) == 0)
{
double trainLossValue = trainer.PreviousMinibatchAverageTrainingLoss();
printf("Minibatch %d: CrossEntropy loss = %.8g\n", (int)i, trainLossValue);
}
} }
} }

View file

@ -35,28 +35,34 @@ void TrainSimpleFeedForwardClassifer(const DeviceDescriptor& device)
auto outputTimesParam = Parameter(NDArrayView::RandomUniform<float>({ numOutputClasses, hiddenLayerDim }, -0.05, 0.05, 1, device)); auto outputTimesParam = Parameter(NDArrayView::RandomUniform<float>({ numOutputClasses, hiddenLayerDim }, -0.05, 0.05, 1, device));
auto outputBiasParam = Parameter(NDArrayView::RandomUniform<float>({ numOutputClasses }, -0.05, 0.05, 1, device)); auto outputBiasParam = Parameter(NDArrayView::RandomUniform<float>({ numOutputClasses }, -0.05, 0.05, 1, device));
classifierOutput = Plus(outputBiasParam, Times(outputTimesParam, classifierOutput)); classifierOutput = Plus(outputBiasParam, Times(outputTimesParam, classifierOutput), L"classifierOutput");
Variable labels({ numOutputClasses }, DataType::Float, L"labels"); Variable labels({ numOutputClasses }, DataType::Float, L"labels");
auto trainingLoss = CNTK::CrossEntropyWithSoftmax(classifierOutput, labels, L"lossFunction"); auto trainingLoss = CNTK::CrossEntropyWithSoftmax(classifierOutput, labels, L"lossFunction");
auto prediction = CNTK::ClassificationError(classifierOutput, labels, L"classificationError"); auto prediction = CNTK::ClassificationError(classifierOutput, labels, L"classificationError");
auto oneHiddenLayerClassifier = CNTK::Combine({ trainingLoss, prediction, classifierOutput }, L"classifierModel"); // Test save and reload of model
{
Variable classifierOutputVar = classifierOutput;
Variable trainingLossVar = trainingLoss;
Variable predictionVar = prediction;
auto combinedNet = Combine({ trainingLoss, prediction, classifierOutput }, L"feedForwardClassifier");
SaveAndReloadModel<float>(combinedNet, { &input, &labels, &trainingLossVar, &predictionVar, &classifierOutputVar }, device);
classifierOutput = classifierOutputVar;
trainingLoss = trainingLossVar;
prediction = predictionVar;
}
double learningRatePerSample = 0.02; double learningRatePerSample = 0.02;
minibatchSource = CreateTextMinibatchSource(L"SimpleDataTrain_cntk_text.txt", (size_t)2, (size_t)2, SIZE_MAX); minibatchSource = CreateTextMinibatchSource(L"SimpleDataTrain_cntk_text.txt", (size_t)2, (size_t)2, SIZE_MAX);
Trainer trainer(oneHiddenLayerClassifier, trainingLoss, { SGDLearner(oneHiddenLayerClassifier->Parameters(), learningRatePerSample) }); Trainer trainer(classifierOutput, trainingLoss, prediction, { SGDLearner(classifierOutput->Parameters(), learningRatePerSample) });
size_t outputFrequencyInMinibatches = 20; size_t outputFrequencyInMinibatches = 20;
for (size_t i = 0; i < numMinibatchesToTrain; ++i) for (size_t i = 0; i < numMinibatchesToTrain; ++i)
{ {
auto minibatchData = minibatchSource->GetNextMinibatch(minibatchSize, device); auto minibatchData = minibatchSource->GetNextMinibatch(minibatchSize, device);
trainer.TrainMinibatch({ { input, minibatchData[*featureStreamInfo].m_data }, { labels, minibatchData[*labelStreamInfo].m_data } }, device); trainer.TrainMinibatch({ { input, minibatchData[*featureStreamInfo].m_data }, { labels, minibatchData[*labelStreamInfo].m_data } }, device);
PrintTrainingProgress(trainer, i, outputFrequencyInMinibatches);
if ((i % outputFrequencyInMinibatches) == 0)
{
double trainLossValue = trainer.PreviousMinibatchAverageTrainingLoss();
printf("Minibatch %d: CrossEntropy loss = %.8g\n", (int)i, trainLossValue);
}
} }
} }
@ -71,13 +77,24 @@ void TrainMNISTClassifier(const DeviceDescriptor& device)
auto classifierOutput = FullyConnectedDNNLayer(scaledInput, hiddenLayerDim, device, std::bind(Sigmoid, _1, L"")); auto classifierOutput = FullyConnectedDNNLayer(scaledInput, hiddenLayerDim, device, std::bind(Sigmoid, _1, L""));
auto outputTimesParam = Parameter(NDArrayView::RandomUniform<float>({ numOutputClasses, hiddenLayerDim }, -0.05, 0.05, 1, device)); auto outputTimesParam = Parameter(NDArrayView::RandomUniform<float>({ numOutputClasses, hiddenLayerDim }, -0.05, 0.05, 1, device));
auto outputBiasParam = Parameter(NDArrayView::RandomUniform<float>({ numOutputClasses }, -0.05, 0.05, 1, device)); auto outputBiasParam = Parameter(NDArrayView::RandomUniform<float>({ numOutputClasses }, -0.05, 0.05, 1, device));
classifierOutput = Plus(outputBiasParam, Times(outputTimesParam, classifierOutput)); classifierOutput = Plus(outputBiasParam, Times(outputTimesParam, classifierOutput), L"classifierOutput");
Variable labels({ numOutputClasses }, DataType::Float, L"labels"); Variable labels({ numOutputClasses }, DataType::Float, L"labels");
auto trainingLoss = CNTK::CrossEntropyWithSoftmax(classifierOutput, labels, L"lossFunction"); auto trainingLoss = CNTK::CrossEntropyWithSoftmax(classifierOutput, labels, L"lossFunction");
auto prediction = CNTK::ClassificationError(classifierOutput, labels, L"classificationError"); auto prediction = CNTK::ClassificationError(classifierOutput, labels, L"classificationError");
auto oneHiddenLayerClassifier = CNTK::Combine({ trainingLoss, prediction, classifierOutput }, L"classifierModel"); // Test save and reload of model
{
Variable classifierOutputVar = classifierOutput;
Variable trainingLossVar = trainingLoss;
Variable predictionVar = prediction;
auto combinedNet = Combine({ trainingLoss, prediction, classifierOutput }, L"MNISTClassifier");
SaveAndReloadModel<float>(combinedNet, { &input, &labels, &trainingLossVar, &predictionVar, &classifierOutputVar }, device);
classifierOutput = classifierOutputVar;
trainingLoss = trainingLossVar;
prediction = predictionVar;
}
const size_t minibatchSize = 32; const size_t minibatchSize = 32;
const size_t numSamplesPerSweep = 60000; const size_t numSamplesPerSweep = 60000;
@ -91,18 +108,14 @@ void TrainMNISTClassifier(const DeviceDescriptor& device)
auto labelStreamInfo = std::find_if(streamInfos.begin(), streamInfos.end(), [](const StreamInformation& streamInfo) { return (streamInfo.m_name == L"labels"); }); auto labelStreamInfo = std::find_if(streamInfos.begin(), streamInfos.end(), [](const StreamInformation& streamInfo) { return (streamInfo.m_name == L"labels"); });
double learningRatePerSample = 0.003125; double learningRatePerSample = 0.003125;
Trainer trainer(oneHiddenLayerClassifier, trainingLoss, { SGDLearner(oneHiddenLayerClassifier->Parameters(), learningRatePerSample) }); Trainer trainer(classifierOutput, trainingLoss, prediction, { SGDLearner(classifierOutput->Parameters(), learningRatePerSample) });
size_t outputFrequencyInMinibatches = 20; size_t outputFrequencyInMinibatches = 20;
for (size_t i = 0; i < numMinibatchesToTrain; ++i) for (size_t i = 0; i < numMinibatchesToTrain; ++i)
{ {
auto minibatchData = minibatchSource->GetNextMinibatch(minibatchSize, device); auto minibatchData = minibatchSource->GetNextMinibatch(minibatchSize, device);
trainer.TrainMinibatch({ { input, minibatchData[*featureStreamInfo].m_data }, { labels, minibatchData[*labelStreamInfo].m_data } }, device); trainer.TrainMinibatch({ { input, minibatchData[*featureStreamInfo].m_data }, { labels, minibatchData[*labelStreamInfo].m_data } }, device);
PrintTrainingProgress(trainer, i, outputFrequencyInMinibatches);
if ((i % outputFrequencyInMinibatches) == 0)
{
double trainLossValue = trainer.PreviousMinibatchAverageTrainingLoss();
printf("Minibatch %d: CrossEntropy loss = %.8g\n", (int)i, trainLossValue);
}
} }
} }

View file

@ -111,6 +111,7 @@
<ItemGroup> <ItemGroup>
<ClCompile Include="CifarResNet.cpp" /> <ClCompile Include="CifarResNet.cpp" />
<ClCompile Include="LearnerTests.cpp" /> <ClCompile Include="LearnerTests.cpp" />
<ClCompile Include="Seq2Seq.cpp" />
<ClCompile Include="SerializationTests.cpp" /> <ClCompile Include="SerializationTests.cpp" />
<ClCompile Include="FeedForwardTests.cpp" /> <ClCompile Include="FeedForwardTests.cpp" />
<ClCompile Include="FunctionTests.cpp" /> <ClCompile Include="FunctionTests.cpp" />

View file

@ -48,6 +48,9 @@
<ClCompile Include="LearnerTests.cpp"> <ClCompile Include="LearnerTests.cpp">
<Filter>Source Files</Filter> <Filter>Source Files</Filter>
</ClCompile> </ClCompile>
<ClCompile Include="Seq2Seq.cpp">
<Filter>Source Files</Filter>
</ClCompile>
</ItemGroup> </ItemGroup>
<ItemGroup> <ItemGroup>
<ClInclude Include="Common.h"> <ClInclude Include="Common.h">