CNTK v2 library: Expose 18 additional operators in the v2 API

Amit Agarwal 2016-07-22 13:46:19 -07:00, committed by Amit
Parent f3dec438d6
Commit a632393c11
3 changed files: 459 additions and 144 deletions

View file

@@ -1331,10 +1331,74 @@ namespace CNTK
};
///
/// Create an instance of the CNTK built-in matrix multiplication operation with the specified input operands.
/// TODO: Specify the constraints on the shapes of the operands.
/// Create an instance of the CNTK built-in elementwise negate operation with the specified input operand.
///
CNTK_API FunctionPtr Times(const Variable& leftOperand, const Variable& rightOperand, const std::wstring& name = L"");
CNTK_API FunctionPtr Negate(const Variable& operand, const std::wstring& name = L"");
///
/// Create an instance of the CNTK built-in elementwise sigmoid operation with the specified input operand.
///
CNTK_API FunctionPtr Sigmoid(const Variable& operand, const std::wstring& name = L"");
///
/// Create an instance of the CNTK built-in elementwise tanh operation with the specified input operand.
///
CNTK_API FunctionPtr Tanh(const Variable& operand, const std::wstring& name = L"");
///
/// Create an instance of the CNTK built-in elementwise linear rectifier operation with the specified input operand.
///
CNTK_API FunctionPtr ReLU(const Variable& operand, const std::wstring& name = L"");
///
/// Create an instance of the CNTK built-in elementwise exp operation with the specified input operand.
///
CNTK_API FunctionPtr Exp(const Variable& operand, const std::wstring& name = L"");
///
/// Create an instance of the CNTK built-in elementwise log operation with the specified input operand.
///
CNTK_API FunctionPtr Log(const Variable& operand, const std::wstring& name = L"");
///
/// Create an instance of the CNTK built-in elementwise square operation with the specified input operand.
///
CNTK_API FunctionPtr Square(const Variable& operand, const std::wstring& name = L"");
///
/// Create an instance of the CNTK built-in elementwise square-root operation with the specified input operand.
///
CNTK_API FunctionPtr Sqrt(const Variable& operand, const std::wstring& name = L"");
///
/// Create an instance of the CNTK built-in elementwise round operation with the specified input operand.
///
CNTK_API FunctionPtr Round(const Variable& operand, const std::wstring& name = L"");
///
/// Create an instance of the CNTK built-in elementwise floor operation with the specified input operand.
///
CNTK_API FunctionPtr Floor(const Variable& operand, const std::wstring& name = L"");
///
/// Create an instance of the CNTK built-in elementwise ceil operation with the specified input operand.
///
CNTK_API FunctionPtr Ceil(const Variable& operand, const std::wstring& name = L"");
///
/// Create an instance of the CNTK built-in elementwise abs operation with the specified input operand.
///
CNTK_API FunctionPtr Abs(const Variable& operand, const std::wstring& name = L"");
///
/// Create an instance of the CNTK built-in elementwise reciprocal operation with the specified input operand.
///
CNTK_API FunctionPtr Reciprocal(const Variable& operand, const std::wstring& name = L"");
///
/// Create an instance of the CNTK built-in softmax operation on specified tensor input operand
///
CNTK_API FunctionPtr Softmax(const Variable& operand, const std::wstring& name = L"");
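///
/// A minimal usage sketch: each factory returns a single-output FunctionPtr, which
/// converts implicitly to a Variable, so the ops compose directly (the derived-op
/// implementations later in this commit rely on the same conversion). The scalar
/// Constant construction mirrors the one used by Round():
///
///     Variable c = Constant(NDShape({}), -2.5f);        // scalar-shaped constant
///     FunctionPtr f = Sigmoid(Abs(c), L"sigmoidOfAbs"); // sigmoid(|c|)
///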
///
/// Create an instance of the CNTK built-in elementwise tensor addition operation with the specified input operands.
@@ -1342,30 +1406,71 @@ namespace CNTK
CNTK_API FunctionPtr Plus(const Variable& leftOperand, const Variable& rightOperand, const std::wstring& name = L"");
///
/// Create an instance of the CNTK built-in elementwise sigmoid operation with the specified input operand.
/// Create an instance of the CNTK built-in elementwise tensor subtraction operation with the specified input operands.
///
CNTK_API FunctionPtr Sigmoid(const Variable& operand, const std::wstring& name = L"");
CNTK_API FunctionPtr Minus(const Variable& leftOperand, const Variable& rightOperand, const std::wstring& name = L"");
///
/// Create an instance of the CNTK built-in elementwise tanh operation with the specified input operand.
/// Create an instance of the CNTK built-in elementwise multiplication operation on specified tensor input operands.
///
CNTK_API FunctionPtr Tanh(const Variable& operand, const std::wstring& name = L"");
CNTK_API FunctionPtr ElementTimes(const Variable& leftOperand, const Variable& rightOperand, const std::wstring& name = L"");
///
/// Create an instance of the CNTK built-in elementwise division operation on specified tensor input operands.
///
CNTK_API FunctionPtr ElementDivide(const Variable& leftOperand, const Variable& rightOperand, const std::wstring& name = L"");
///
/// Create an instance of the CNTK built-in elementwise equality comparison operation on specified tensor input operands.
///
CNTK_API FunctionPtr Equal(const Variable& leftOperand, const Variable& rightOperand, const std::wstring& name = L"");
///
/// Create an instance of the CNTK built-in elementwise not-equal comparison operation on specified tensor input operands.
///
CNTK_API FunctionPtr NotEqual(const Variable& leftOperand, const Variable& rightOperand, const std::wstring& name = L"");
///
/// Create an instance of the CNTK built-in elementwise less than comparison operation on specified tensor input operands.
///
CNTK_API FunctionPtr Less(const Variable& leftOperand, const Variable& rightOperand, const std::wstring& name = L"");
///
/// Create an instance of the CNTK built-in elementwise less than or equal to comparison operation on specified tensor input operands.
///
CNTK_API FunctionPtr LessEqual(const Variable& leftOperand, const Variable& rightOperand, const std::wstring& name = L"");
///
/// Create an instance of the CNTK built-in elementwise greater than comparison operation on specified tensor input operands.
///
CNTK_API FunctionPtr Greater(const Variable& leftOperand, const Variable& rightOperand, const std::wstring& name = L"");
///
/// Create an instance of the CNTK built-in elementwise greater than or equal to comparison operation on specified tensor input operands.
///
CNTK_API FunctionPtr GreaterEqual(const Variable& leftOperand, const Variable& rightOperand, const std::wstring& name = L"");
///
/// Create an instance of the CNTK built-in matrix multiplication operation with the specified input operands.
/// TODO: Specify the constraints on the shapes of the operands.
///
CNTK_API FunctionPtr Times(const Variable& leftOperand, const Variable& rightOperand, const std::wstring& name = L"");
///
/// Create an instance of the CNTK built-in operation to compute squared-error for specified input operands.
///
CNTK_API FunctionPtr SquaredError(const Variable& prediction, const Variable& targets, const std::wstring& name = L"");
///
/// Create an instance of the CNTK built-in operation to compute cross-entropy with softmax for specified input operands.
///
CNTK_API FunctionPtr CrossEntropyWithSoftmax(const Variable& output, const Variable& labels, const std::wstring& name = L"");
CNTK_API FunctionPtr CrossEntropyWithSoftmax(const Variable& prediction, const Variable& labels, const std::wstring& name = L"");
///
/// Create an instance of the CNTK built-in operation for computing the classification prediction error for specified operands.
///
CNTK_API FunctionPtr ClassificationError(const Variable& prediction, const Variable& labels, const std::wstring& name = L"");
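///
/// All three criteria above take their operands as (prediction, targets/labels), in that
/// order. A sketch of the usual pairing on a single network output; 'prediction' (a 1D
/// output) and 'labels' (one-hot targets) are hypothetical Variables created elsewhere:
///
///     FunctionPtr loss = CrossEntropyWithSoftmax(prediction, labels, L"loss");
///     FunctionPtr errs = ClassificationError(prediction, labels, L"errs");
///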
///
/// Create an instance of the CNTK built-in elementwise exp operation with the specified input operand.
///
CNTK_API FunctionPtr Exp(const Variable& operand, const std::wstring& name = L"");
///
/// Create an instance of the CNTK built-in operation for getting the past value along the lone dynamic axis of the specified operand.
/// Throws an exception if the operand has more than one dynamic axis.
@@ -1380,11 +1485,6 @@ namespace CNTK
///
CNTK_API FunctionPtr FutureValue(const Variable& initialState, const Variable& operand, size_t stepSize, const std::wstring& name = L"");
///
/// Create an instance of the CNTK built-in elementwise multiplication operation on specified tensor input operands.
///
CNTK_API FunctionPtr ElementTimes(const Variable& leftOperand, const Variable& rightOperand, const std::wstring& name = L"");
///
/// Create an instance of the CNTK built-in sum reduction operation on specified tensor input operand along all the axes
///

View file

@@ -51,8 +51,8 @@ namespace CNTK
// Placeholders can be replaced incrementally - i.e. not all placeholders need to be replaced in one go.
// The only requirement is that they must all be replaced before making any 'Forward' calls on the Function instance.
/*virtual*/ void Function::ReplacePlaceholders(const std::unordered_map<Placeholder, Variable>& placeholderReplacements,
std::unordered_set<const Function*>& visitedFunctions,
std::unordered_set<Placeholder>& replacedPlaceholders)
std::unordered_set<const Function*>& visitedFunctions,
std::unordered_set<Placeholder>& replacedPlaceholders)
{
visitedFunctions.insert(this);
@@ -75,8 +75,8 @@ namespace CNTK
// Replace any PlaceHolder Variables in the graph of Functions underlying 'this' CompositeFunction. All PlaceHolder variables
// should have been replaced before performing any Forward compute of 'this' Function.
/*virtual*/ void CompositeFunction::ReplacePlaceholders(const std::unordered_map<Placeholder, Variable>& placeholderReplacements,
std::unordered_set<const Function*>& visitedFunctions,
std::unordered_set<Placeholder>& replacedPlaceholders)
std::unordered_set<const Function*>& visitedFunctions,
std::unordered_set<Placeholder>& replacedPlaceholders)
{
RootFunction()->ReplacePlaceholders(placeholderReplacements, visitedFunctions, replacedPlaceholders);
@@ -101,10 +101,10 @@ namespace CNTK
// top level 'variable'
template <typename ElementType>
/*static*/ ComputationNodeBasePtr CompositeFunction::GetNode(const Variable& variable,
Microsoft::MSR::CNTK::ComputationNetworkPtr& network,
ComputationNetworkBuilder<ElementType>& builder,
std::unordered_map<Variable, ComputationNodeBasePtr>& variableToNodeMap,
std::unordered_map<Variable, bool>& isVariableRootMap)
Microsoft::MSR::CNTK::ComputationNetworkPtr& network,
ComputationNetworkBuilder<ElementType>& builder,
std::unordered_map<Variable, ComputationNodeBasePtr>& variableToNodeMap,
std::unordered_map<Variable, bool>& isVariableRootMap)
{
auto iter = variableToNodeMap.find(variable);
if (iter != variableToNodeMap.end())
@@ -152,10 +152,10 @@ namespace CNTK
template <typename ElementType>
/*static*/ ComputationNodeBasePtr CompositeFunction::GetOutputVariableNode(const Variable& variable,
Microsoft::MSR::CNTK::ComputationNetworkPtr& network,
ComputationNetworkBuilder<ElementType>& builder,
std::unordered_map<Variable, ComputationNodeBasePtr>& variableToNodeMap,
std::unordered_map<Variable, bool>& isVariableRootMap)
Microsoft::MSR::CNTK::ComputationNetworkPtr& network,
ComputationNetworkBuilder<ElementType>& builder,
std::unordered_map<Variable, ComputationNodeBasePtr>& variableToNodeMap,
std::unordered_map<Variable, bool>& isVariableRootMap)
{
assert(variable.IsOutput());
@@ -180,12 +180,8 @@ namespace CNTK
PrimitiveOpType op = primitiveFunction->OpType();
switch (op)
{
case PrimitiveOpType::Plus:
computationNodePtr = builder.Plus(input0Node, input1Node, function->Name());
break;
case PrimitiveOpType::Times:
// TODO: The output rank of the times operation is currently hardcoded to 1
computationNodePtr = builder.Times(input0Node, input1Node, 1, function->Name());
case PrimitiveOpType::Negate:
computationNodePtr = builder.Negate(input0Node, function->Name());
break;
case PrimitiveOpType::Sigmoid:
computationNodePtr = builder.Sigmoid(input0Node, function->Name());
@@ -193,15 +189,73 @@ namespace CNTK
case PrimitiveOpType::Tanh:
computationNodePtr = builder.Tanh(input0Node, function->Name());
break;
case PrimitiveOpType::ReLU:
computationNodePtr = builder.RectifiedLinear(input0Node, function->Name());
break;
case PrimitiveOpType::Exp:
computationNodePtr = builder.Exp(input0Node, function->Name());
break;
case PrimitiveOpType::Log:
computationNodePtr = builder.Log(input0Node, function->Name());
break;
case PrimitiveOpType::Sqrt:
computationNodePtr = builder.Sqrt(input0Node, function->Name());
break;
case PrimitiveOpType::Floor:
computationNodePtr = builder.Floor(input0Node, function->Name());
break;
case PrimitiveOpType::Abs:
computationNodePtr = builder.Abs(input0Node, function->Name());
break;
case PrimitiveOpType::Reciprocal:
computationNodePtr = builder.Reciprocal(input0Node, function->Name());
break;
case PrimitiveOpType::Softmax:
if (functionInputs[0].Shape().NumAxes() > 1)
InvalidArgument("Softmax operation can only be applied to a 1D input");
computationNodePtr = builder.Softmax(input0Node, function->Name());
break;
case PrimitiveOpType::Plus:
computationNodePtr = builder.Plus(input0Node, input1Node, function->Name());
break;
case PrimitiveOpType::Minus:
computationNodePtr = builder.Minus(input0Node, input1Node, function->Name());
break;
case PrimitiveOpType::ElementTimes:
computationNodePtr = builder.ElementTimes(input0Node, input1Node, function->Name());
break;
case PrimitiveOpType::Equal:
computationNodePtr = builder.Equal(input0Node, input1Node, function->Name());
break;
case PrimitiveOpType::NotEqual:
computationNodePtr = builder.NotEqual(input0Node, input1Node, function->Name());
break;
case PrimitiveOpType::Less:
computationNodePtr = builder.Less(input0Node, input1Node, function->Name());
break;
case PrimitiveOpType::LessEqual:
computationNodePtr = builder.LessEqual(input0Node, input1Node, function->Name());
break;
case PrimitiveOpType::Greater:
computationNodePtr = builder.Greater(input0Node, input1Node, function->Name());
break;
case PrimitiveOpType::GreaterEqual:
computationNodePtr = builder.GreaterEqual(input0Node, input1Node, function->Name());
break;
case PrimitiveOpType::Times:
// TODO: The output rank of the times operation is currently hardcoded to 1
computationNodePtr = builder.Times(input0Node, input1Node, 1, function->Name());
break;
case PrimitiveOpType::SquaredError:
computationNodePtr = builder.SquareError(input0Node, input1Node, function->Name());
break;
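// Note: the legacy ComputationNetwork nodes for the two criteria below take
// (labels, prediction), so the v2 operands (prediction, labels) are wired in
// swapped order.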
case PrimitiveOpType::CrossEntropyWithSoftmax:
computationNodePtr = builder.CrossEntropyWithSoftmax(input1Node, input0Node, function->Name());
break;
case PrimitiveOpType::ClassificationError:
computationNodePtr = builder.ErrorPrediction(input1Node, input0Node, function->Name());
break;
case PrimitiveOpType::Exp:
computationNodePtr = builder.Exp(input0Node, function->Name());
break;
case PrimitiveOpType::PastValue:
case PrimitiveOpType::FutureValue:
{
@@ -231,9 +285,6 @@ namespace CNTK
break;
}
case PrimitiveOpType::ElementTimes:
computationNodePtr = builder.ElementTimes(input0Node, input1Node, function->Name());
break;
case PrimitiveOpType::ReduceSum:
{
// TODO: Use the new ReduceElements node instead of the legacy SumElements node for reduction. Currently ReduceElements has incorrect MBLayout inference.
@@ -409,7 +460,7 @@ namespace CNTK
layout->AddSequence(0, 0, 0, maxNumTimeSteps);
}
return{ matrixData , layout};
return{ matrixData, layout };
}
else
{
@@ -444,7 +495,7 @@ namespace CNTK
// The data needs to be rearranged since CNTK requires sequences to be interleaved across timesteps
std::vector<MBLayout::SequenceInfo> sequences;
for (size_t i = 0; i < numSequences; ++i)
sequences.push_back({ i, SIZE_MAX, 0, sequenceLengths[i]});
sequences.push_back({ i, SIZE_MAX, 0, sequenceLengths[i] });
auto layout = std::make_shared<MBLayout>();
std::vector<std::pair<size_t, size_t>> placement;
@@ -458,10 +509,10 @@ namespace CNTK
// Now generate the gather indices
auto matrixData = std::make_shared<Matrix<ElementType>>(var.Shape().TotalSize(),
layout->GetNumCols(),
AsCNTKImplDeviceId(value->Data()->Device()),
value->Data()->IsSparse() ? MatrixType::SPARSE : MatrixType::DENSE,
AsCNTKImplMatrixFormat(value->Data()->GetStorageFormat()));
layout->GetNumCols(),
AsCNTKImplDeviceId(value->Data()->Device()),
value->Data()->IsSparse() ? MatrixType::SPARSE : MatrixType::DENSE,
AsCNTKImplMatrixFormat(value->Data()->GetStorageFormat()));
std::vector<size_t> sequencesShorterThanLongestSequence;
for (size_t i = 0; i < numSequences; ++i)
@@ -772,9 +823,9 @@ namespace CNTK
}
/*virtual*/ BackPropStatePtr CompositeFunction::Forward(const std::unordered_map<Variable, const ValuePtr>& arguments,
std::unordered_map<Variable, ValuePtr>& outputs,
const DeviceDescriptor& computeDevice,
const std::unordered_set<Variable>& outputsToRetainBackwardStateFor)
std::unordered_map<Variable, ValuePtr>& outputs,
const DeviceDescriptor& computeDevice,
const std::unordered_set<Variable>& outputsToRetainBackwardStateFor)
{
// TODO: How about zero argument functions?
// TODO: We need a better way to determine the ElementType for the network
@@ -819,8 +870,8 @@ namespace CNTK
}
/*virtual*/ void CompositeFunction::Backward(const BackPropStatePtr& state,
const std::unordered_map<Variable, const ValuePtr>& rootGradientValues,
std::unordered_map<Variable, ValuePtr>& backPropagatedGradientValuesForInputs)
const std::unordered_map<Variable, const ValuePtr>& rootGradientValues,
std::unordered_map<Variable, ValuePtr>& backPropagatedGradientValuesForInputs)
{
auto backpropState = dynamic_cast<const CNTKBackPropState*>(state.get());
if (backpropState == nullptr)
@@ -829,7 +880,7 @@ namespace CNTK
// TODO: Support multiple concurrent backprop states
if (backpropState->EvalTimeStamp().second != m_variableToNodeMap[backpropState->EvalTimeStamp().first]->GetEvalTimeStamp())
LogicError("The specified backprop state specified cannot be used for backpropagation as the Function's internal state was modified by subsequent Forward calls to the function."
"This is not a user error but a shortcoming of the current implementation where multiple independent backprop states are not simultaneously supported");
"This is not a user error but a shortcoming of the current implementation where multiple independent backprop states are not simultaneously supported");
if (rootGradientValues.size() > 1)
LogicError("Currently gradient backprop from only one of the Function Outputs is supported");
@@ -852,24 +903,179 @@ namespace CNTK
// TODO: How to deal with the specified 'computeDevice'
}
FunctionPtr Times(const Variable& leftOperand, const Variable& rightOperand, const std::wstring& name/* = L""*/)
FunctionPtr UnaryOp(PrimitiveOpType op, const Variable& operand, Dictionary&& opConfig, const std::wstring& name)
{
return CompositeFunction::Create(MakeSharedObject<PrimitiveFunction>(PrimitiveOpType::Times, std::vector<Variable>({ leftOperand, rightOperand }), Dictionary(), name), name);
return CompositeFunction::Create(MakeSharedObject<PrimitiveFunction>(op, std::vector<Variable>({ operand }), std::move(opConfig), name), name);
}
FunctionPtr Plus(const Variable& leftOperand, const Variable& rightOperand, const std::wstring& name/* = L""*/)
FunctionPtr Negate(const Variable& operand, const std::wstring& name/* = L""*/)
{
return CompositeFunction::Create(MakeSharedObject<PrimitiveFunction>(PrimitiveOpType::Plus, std::vector<Variable>({ leftOperand, rightOperand }), Dictionary(), name), name);
return UnaryOp(PrimitiveOpType::Negate, operand, Dictionary(), name);
}
FunctionPtr Sigmoid(const Variable& operand, const std::wstring& name/* = L""*/)
{
return CompositeFunction::Create(MakeSharedObject<PrimitiveFunction>(PrimitiveOpType::Sigmoid, std::vector<Variable>({ operand }), Dictionary(), name), name);
return UnaryOp(PrimitiveOpType::Sigmoid, operand, Dictionary(), name);
}
FunctionPtr Tanh(const Variable& operand, const std::wstring& name/* = L""*/)
{
return CompositeFunction::Create(MakeSharedObject<PrimitiveFunction>(PrimitiveOpType::Tanh, std::vector<Variable>({ operand }), Dictionary(), name), name);
return UnaryOp(PrimitiveOpType::Tanh, operand, Dictionary(), name);
}
FunctionPtr ReLU(const Variable& operand, const std::wstring& name/* = L""*/)
{
return UnaryOp(PrimitiveOpType::ReLU, operand, Dictionary(), name);
}
FunctionPtr Exp(const Variable& operand, const std::wstring& name/* = L""*/)
{
return UnaryOp(PrimitiveOpType::Exp, operand, Dictionary(), name);
}
FunctionPtr Log(const Variable& operand, const std::wstring& name/* = L""*/)
{
return UnaryOp(PrimitiveOpType::Log, operand, Dictionary(), name);
}
FunctionPtr Square(const Variable& operand, const std::wstring& name/* = L""*/)
{
return ElementTimes(operand, operand, name);
}
FunctionPtr Sqrt(const Variable& operand, const std::wstring& name/* = L""*/)
{
return UnaryOp(PrimitiveOpType::Sqrt, operand, Dictionary(), name);
}
FunctionPtr Round(const Variable& operand, const std::wstring& name/* = L""*/)
{
return Floor(Plus(operand, Constant(NDShape({}), 0.5f)), name);
}
FunctionPtr Floor(const Variable& operand, const std::wstring& name/* = L""*/)
{
return UnaryOp(PrimitiveOpType::Floor, operand, Dictionary(), name);
}
FunctionPtr Ceil(const Variable& operand, const std::wstring& name/* = L""*/)
{
return Negate(Floor(Negate(operand)), name);
}
FunctionPtr Abs(const Variable& operand, const std::wstring& name/* = L""*/)
{
return UnaryOp(PrimitiveOpType::Abs, operand, Dictionary(), name);
}
FunctionPtr Reciprocal(const Variable& operand, const std::wstring& name/* = L""*/)
{
return UnaryOp(PrimitiveOpType::Reciprocal, operand, Dictionary(), name);
}
FunctionPtr Softmax(const Variable& operand, const std::wstring& name/* = L""*/)
{
return UnaryOp(PrimitiveOpType::Softmax, operand, Dictionary(), name);
}
FunctionPtr BinaryOp(PrimitiveOpType op, const Variable& leftOperand, const Variable& rightOperand, Dictionary&& opConfig, const std::wstring& name)
{
return CompositeFunction::Create(MakeSharedObject<PrimitiveFunction>(op, std::vector<Variable>({ leftOperand, rightOperand }), std::move(opConfig), name), name);
}
FunctionPtr Plus(const Variable& leftOperand, const Variable& rightOperand, const std::wstring& name/* = L""*/)
{
return BinaryOp(PrimitiveOpType::Plus, leftOperand, rightOperand, Dictionary(), name);
}
FunctionPtr Minus(const Variable& leftOperand, const Variable& rightOperand, const std::wstring& name/* = L""*/)
{
return BinaryOp(PrimitiveOpType::Minus, leftOperand, rightOperand, Dictionary(), name);
}
FunctionPtr ElementTimes(const Variable& leftOperand, const Variable& rightOperand, const std::wstring& name/* = L""*/)
{
return BinaryOp(PrimitiveOpType::ElementTimes, leftOperand, rightOperand, Dictionary(), name);
}
FunctionPtr ElementDivide(const Variable& leftOperand, const Variable& rightOperand, const std::wstring& name/* = L""*/)
{
return ElementTimes(leftOperand, Reciprocal(rightOperand), name);
}
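// ElementDivide completes the set of ops derived from other factories rather than
// mapped to dedicated ComputationNetwork nodes; the identities used are:
//     Square(x)           = x (.*) x
//     Round(x)            = Floor(x + 0.5)
//     Ceil(x)             = -Floor(-x)
//     ElementDivide(x, y) = x (.*) Reciprocal(y)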
FunctionPtr Equal(const Variable& leftOperand, const Variable& rightOperand, const std::wstring& name/* = L""*/)
{
return BinaryOp(PrimitiveOpType::Equal, leftOperand, rightOperand, Dictionary(), name);
}
FunctionPtr NotEqual(const Variable& leftOperand, const Variable& rightOperand, const std::wstring& name/* = L""*/)
{
return BinaryOp(PrimitiveOpType::NotEqual, leftOperand, rightOperand, Dictionary(), name);
}
FunctionPtr Less(const Variable& leftOperand, const Variable& rightOperand, const std::wstring& name/* = L""*/)
{
return BinaryOp(PrimitiveOpType::Less, leftOperand, rightOperand, Dictionary(), name);
}
FunctionPtr LessEqual(const Variable& leftOperand, const Variable& rightOperand, const std::wstring& name/* = L""*/)
{
return BinaryOp(PrimitiveOpType::LessEqual, leftOperand, rightOperand, Dictionary(), name);
}
FunctionPtr Greater(const Variable& leftOperand, const Variable& rightOperand, const std::wstring& name/* = L""*/)
{
return BinaryOp(PrimitiveOpType::Greater, leftOperand, rightOperand, Dictionary(), name);
}
FunctionPtr GreaterEqual(const Variable& leftOperand, const Variable& rightOperand, const std::wstring& name/* = L""*/)
{
return BinaryOp(PrimitiveOpType::GreaterEqual, leftOperand, rightOperand, Dictionary(), name);
}
FunctionPtr Times(const Variable& leftOperand, const Variable& rightOperand, const std::wstring& name/* = L""*/)
{
return BinaryOp(PrimitiveOpType::Times, leftOperand, rightOperand, Dictionary(), name);
}
FunctionPtr SquaredError(const Variable& prediction, const Variable& targets, const std::wstring& name/* = L""*/)
{
return BinaryOp(PrimitiveOpType::SquaredError, prediction, targets, Dictionary(), name);
}
FunctionPtr CrossEntropyWithSoftmax(const Variable& prediction, const Variable& labels, const std::wstring& name/* = L""*/)
{
return BinaryOp(PrimitiveOpType::CrossEntropyWithSoftmax, prediction, labels, Dictionary(), name);
}
FunctionPtr ClassificationError(const Variable& prediction, const Variable& labels, const std::wstring& name/* = L""*/)
{
return BinaryOp(PrimitiveOpType::ClassificationError, prediction, labels, Dictionary(), name);
}
FunctionPtr PastValue(const Variable& initialState, const Variable& operand, size_t stepSize, const std::wstring& name/* = L""*/)
{
if (operand.DynamicAxes().size() != 1)
InvalidArgument("PastValue overload that does not explicitly specify a dynamic axis can only be used for operands with exactly one dynamic axis");
auto additionalProperties = Dictionary();
additionalProperties[L"stepSize"] = DictionaryValue(stepSize);
return BinaryOp(PrimitiveOpType::PastValue, initialState, operand, std::move(additionalProperties), name);
}
FunctionPtr FutureValue(const Variable& initialState, const Variable& operand, size_t stepSize, const std::wstring& name/* = L""*/)
{
if (operand.DynamicAxes().size() != 1)
InvalidArgument("FutureValue overload that does not explicitly specify a dynamic axis can only be used for operands with exactly one dynamic axis");
auto additionalProperties = Dictionary();
additionalProperties[L"stepSize"] = DictionaryValue(stepSize);
return BinaryOp(PrimitiveOpType::FutureValue, initialState, operand, std::move(additionalProperties), name);
}
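// A sketch of the delay ops (the step size travels in the op-config Dictionary rather
// than as an operand). 'x' stands for a hypothetical Variable with exactly one dynamic
// axis, per the checks above; its creation is elided:
//
//     Variable initialState = Constant(NDShape({}), 0.0f); // value beyond the boundary
//     FunctionPtr xPrev = PastValue(initialState, x, /*stepSize =*/ 1, L"xPrev");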
FunctionPtr ReduceSum(const Variable& operand, const std::wstring& name/* = L""*/)
{
return UnaryOp(PrimitiveOpType::ReduceSum, operand, Dictionary(), name);
}
FunctionPtr Combine(const std::initializer_list<FunctionPtr>& operands, const std::wstring& name/* = L""*/)
@@ -888,49 +1094,4 @@ namespace CNTK
return CompositeFunction::Create(MakeSharedObject<PrimitiveFunction>(PrimitiveOpType::Combine, inputs, Dictionary(), name), name);
}
FunctionPtr CrossEntropyWithSoftmax(const Variable& output, const Variable& labels, const std::wstring& name/* = L""*/)
{
return CompositeFunction::Create(MakeSharedObject<PrimitiveFunction>(PrimitiveOpType::CrossEntropyWithSoftmax, std::vector<Variable>({ output, labels }), Dictionary(), name), name);
}
FunctionPtr ClassificationError(const Variable& prediction, const Variable& labels, const std::wstring& name/* = L""*/)
{
return CompositeFunction::Create(MakeSharedObject<PrimitiveFunction>(PrimitiveOpType::ClassificationError, std::vector<Variable>({ prediction, labels }), Dictionary(), name), name);
}
FunctionPtr Exp(const Variable& operand, const std::wstring& name/* = L""*/)
{
return CompositeFunction::Create(MakeSharedObject<PrimitiveFunction>(PrimitiveOpType::Exp, std::vector<Variable>({ operand }), Dictionary(), name), name);
}
FunctionPtr PastValue(const Variable& initialState, const Variable& operand, size_t stepSize, const std::wstring& name/* = L""*/)
{
if (operand.DynamicAxes().size() != 1)
InvalidArgument("PastValue overload that does not explicitly specify a dynamic axis can only be used for operands with exactly one dynamic axis");
auto additionalProperties = Dictionary();
additionalProperties[L"stepSize"] = DictionaryValue(stepSize);
return CompositeFunction::Create(MakeSharedObject<PrimitiveFunction>(PrimitiveOpType::PastValue, std::vector<Variable>({ initialState, operand }), std::move(additionalProperties), name), name);
}
FunctionPtr FutureValue(const Variable& initialState, const Variable& operand, size_t stepSize, const std::wstring& name/* = L""*/)
{
if (operand.DynamicAxes().size() != 1)
InvalidArgument("FutureValue overload that does not explicitly specify a dynamic axis can only be used for operands with exactly one dynamic axis");
auto additionalProperties = Dictionary();
additionalProperties[L"stepSize"] = DictionaryValue(stepSize);
return CompositeFunction::Create(MakeSharedObject<PrimitiveFunction>(PrimitiveOpType::FutureValue, std::vector<Variable>({ initialState, operand }), std::move(additionalProperties), name), name);
}
FunctionPtr ElementTimes(const Variable& leftOperand, const Variable& rightOperand, const std::wstring& name/* = L""*/)
{
return CompositeFunction::Create(MakeSharedObject<PrimitiveFunction>(PrimitiveOpType::ElementTimes, std::vector<Variable>({ leftOperand, rightOperand }), Dictionary(), name), name);
}
FunctionPtr ReduceSum(const Variable& operand, const std::wstring& name/* = L""*/)
{
return CompositeFunction::Create(MakeSharedObject<PrimitiveFunction>(PrimitiveOpType::ReduceSum, std::vector<Variable>({ operand }), Dictionary(), name), name);
}
}

View file

@@ -13,51 +13,89 @@
namespace CNTK
{
enum class PrimitiveOpType
enum class PrimitiveOpType : unsigned int
{
Plus,
Times,
Negate,
Sigmoid,
Tanh,
Combine,
ReLU,
Exp,
Log,
Sqrt,
Floor,
Abs,
Reciprocal,
Softmax,
Plus,
Minus,
ElementTimes,
Equal,
NotEqual,
Less,
LessEqual,
Greater,
GreaterEqual,
Times,
SquaredError,
CrossEntropyWithSoftmax,
ClassificationError,
Exp,
PastValue,
FutureValue,
ElementTimes,
ReduceSum
ReduceSum,
Combine,
};
}
namespace std
{
template <> struct hash<CNTK::PrimitiveOpType>
{
size_t operator()(const CNTK::PrimitiveOpType& x) const
{
return std::hash<unsigned int>()((unsigned int)x);
}
};
}
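// The std::hash specialization above is needed because std::hash for enum class types
// was only mandated in C++14 (LWG 2148); with it in scope, PrimitiveOpType can key the
// standard unordered containers, e.g. the name table below, or:
//
//     std::unordered_set<CNTK::PrimitiveOpType> ops{ CNTK::PrimitiveOpType::Plus };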
namespace CNTK
{
inline const char* PrimitiveOpTypeName(PrimitiveOpType opType)
{
// TODO: Put these in table form
if (opType == PrimitiveOpType::Plus)
return "Plus";
else if (opType == PrimitiveOpType::Times)
return "Times";
else if (opType == PrimitiveOpType::Sigmoid)
return "Sigmoid";
else if (opType == PrimitiveOpType::Tanh)
return "Tanh";
else if (opType == PrimitiveOpType::Combine)
return "Combine";
else if (opType == PrimitiveOpType::CrossEntropyWithSoftmax)
return "CrossEntropyWithSoftmax";
else if (opType == PrimitiveOpType::ClassificationError)
return "ClassificationError";
else if (opType == PrimitiveOpType::Exp)
return "Exp";
else if (opType == PrimitiveOpType::PastValue)
return "PastValue";
else if (opType == PrimitiveOpType::FutureValue)
return "FutureValue";
else if (opType == PrimitiveOpType::ElementTimes)
return "ElementTimes";
else if (opType == PrimitiveOpType::ReduceSum)
return "ReduceSum";
else
static std::unordered_map<PrimitiveOpType, const char*> primitiveOpNames = {
{ PrimitiveOpType::Negate, "Negate" },
{ PrimitiveOpType::Sigmoid, "Sigmoid" },
{ PrimitiveOpType::Tanh, "Tanh" },
{ PrimitiveOpType::ReLU, "ReLU" },
{ PrimitiveOpType::Exp, "Exp" },
{ PrimitiveOpType::Log, "Log" },
{ PrimitiveOpType::Sqrt, "Sqrt" },
{ PrimitiveOpType::Floor, "Floor" },
{ PrimitiveOpType::Abs, "Abs" },
{ PrimitiveOpType::Reciprocal, "Reciprocal" },
{ PrimitiveOpType::Softmax, "Softmax" },
{ PrimitiveOpType::Plus, "Plus" },
{ PrimitiveOpType::Minus, "Minus" },
{ PrimitiveOpType::ElementTimes, "ElementTimes" },
{ PrimitiveOpType::Equal, "Equal" },
{ PrimitiveOpType::NotEqual, "NotEqual" },
{ PrimitiveOpType::Less, "Less" },
{ PrimitiveOpType::LessEqual, "LessEqual" },
{ PrimitiveOpType::Greater, "Greater" },
{ PrimitiveOpType::GreaterEqual, "GreaterEqual" },
{ PrimitiveOpType::Times, "Times" },
{ PrimitiveOpType::SquaredError, "SquaredError" },
{ PrimitiveOpType::CrossEntropyWithSoftmax, "CrossEntropyWithSoftmax" },
{ PrimitiveOpType::ClassificationError, "ClassificationError" },
{ PrimitiveOpType::PastValue, "PastValue" },
{ PrimitiveOpType::FutureValue, "FutureValue" },
{ PrimitiveOpType::ReduceSum, "ReduceSum" },
{ PrimitiveOpType::Combine, "Combine" }
};
if (primitiveOpNames.find(opType) == primitiveOpNames.end())
LogicError("Unknown PrimitiveOpType");
return primitiveOpNames.find(opType)->second;
}
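// Lookup is then a single map access, e.g.:
//     const char* opName = PrimitiveOpTypeName(PrimitiveOpType::ReLU); // "ReLU"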
class PrimitiveFunction final : public Function
@@ -195,19 +233,29 @@ namespace CNTK
switch (op)
{
case PrimitiveOpType::Negate:
case PrimitiveOpType::Sigmoid:
case PrimitiveOpType::Tanh:
case PrimitiveOpType::ReLU:
case PrimitiveOpType::Exp:
case PrimitiveOpType::Log:
case PrimitiveOpType::Sqrt:
case PrimitiveOpType::Floor:
case PrimitiveOpType::Abs:
case PrimitiveOpType::Reciprocal:
case PrimitiveOpType::Softmax:
assert(inputs.size() == 1);
outputs.push_back(Variable(UnaryElementwiseOpOutputShape(inputs[0].Shape()), outputDataType, owner, outputDynamicAxes));
break;
case PrimitiveOpType::PastValue:
case PrimitiveOpType::FutureValue:
assert(inputs.size() == 2);
outputs.push_back(Variable(UnaryElementwiseOpOutputShape(inputs[1].Shape()), outputDataType, owner, outputDynamicAxes));
break;
case PrimitiveOpType::Plus:
case PrimitiveOpType::Minus:
case PrimitiveOpType::ElementTimes:
case PrimitiveOpType::Equal:
case PrimitiveOpType::NotEqual:
case PrimitiveOpType::Less:
case PrimitiveOpType::LessEqual:
case PrimitiveOpType::Greater:
case PrimitiveOpType::GreaterEqual:
assert(inputs.size() == 2);
outputs.push_back(Variable(BinaryElementwiseOpOutputShape(op, inputs[0].Shape(), inputs[1].Shape()), outputDataType, owner, outputDynamicAxes));
break;
@@ -215,6 +263,7 @@ namespace CNTK
assert(inputs.size() == 2);
outputs.push_back(Variable(TimesOpOutputShape(inputs[0].Shape(), inputs[1].Shape()), outputDataType, owner, outputDynamicAxes));
break;
case PrimitiveOpType::SquaredError:
case PrimitiveOpType::CrossEntropyWithSoftmax:
case PrimitiveOpType::ClassificationError:
{
@@ -235,6 +284,11 @@ namespace CNTK
outputs.push_back(Variable(ReductionOpOutputShape(op, predictionShape, reductionAxes), outputDataType, owner, {}));
break;
}
case PrimitiveOpType::PastValue:
case PrimitiveOpType::FutureValue:
assert(inputs.size() == 2);
outputs.push_back(Variable(UnaryElementwiseOpOutputShape(inputs[1].Shape()), outputDataType, owner, outputDynamicAxes));
break;
case PrimitiveOpType::ReduceSum:
{
assert(inputs.size() == 1);