CNTK v2 library: Expose 18 additional operators in the v2 API
This commit is contained in:
Родитель
f3dec438d6
Коммит
a632393c11
|
@ -1331,10 +1331,74 @@ namespace CNTK
|
|||
};
|
||||
|
||||
///
|
||||
/// Create an instance of the CNTK built-in matrix multiplication operation with the specified input operands.
|
||||
/// TODO: Specify the constraints on the shapes of the operands.
|
||||
/// Create an instance of the CNTK built-in elementwise negate operation with the specified input operand.
|
||||
///
|
||||
CNTK_API FunctionPtr Times(const Variable& leftOperand, const Variable& rightOperand, const std::wstring& name = L"");
|
||||
CNTK_API FunctionPtr Negate(const Variable& operand, const std::wstring& name = L"");
|
||||
|
||||
///
|
||||
/// Create an instance of the CNTK built-in elementwise sigmoid operation with the specified input operand.
|
||||
///
|
||||
CNTK_API FunctionPtr Sigmoid(const Variable& operand, const std::wstring& name = L"");
|
||||
|
||||
///
|
||||
/// Create an instance of the CNTK built-in elementwise tanh operation with the specified input operand.
|
||||
///
|
||||
CNTK_API FunctionPtr Tanh(const Variable& operand, const std::wstring& name = L"");
|
||||
|
||||
///
|
||||
/// Create an instance of the CNTK built-in elementwise linear rectifier operation with the specified input operand.
|
||||
///
|
||||
CNTK_API FunctionPtr ReLU(const Variable& operand, const std::wstring& name = L"");
|
||||
|
||||
///
|
||||
/// Create an instance of the CNTK built-in elementwise exp operation with the specified input operand.
|
||||
///
|
||||
CNTK_API FunctionPtr Exp(const Variable& operand, const std::wstring& name = L"");
|
||||
|
||||
///
|
||||
/// Create an instance of the CNTK built-in elementwise log operation with the specified input operand.
|
||||
///
|
||||
CNTK_API FunctionPtr Log(const Variable& operand, const std::wstring& name = L"");
|
||||
|
||||
///
|
||||
/// Create an instance of the CNTK built-in elementwise square operation with the specified input operand.
|
||||
///
|
||||
CNTK_API FunctionPtr Square(const Variable& operand, const std::wstring& name = L"");
|
||||
|
||||
///
|
||||
/// Create an instance of the CNTK built-in elementwise square-root operation with the specified input operand.
|
||||
///
|
||||
CNTK_API FunctionPtr Sqrt(const Variable& operand, const std::wstring& name = L"");
|
||||
|
||||
///
|
||||
/// Create an instance of the CNTK built-in elementwise round operation with the specified input operand.
|
||||
///
|
||||
CNTK_API FunctionPtr Round(const Variable& operand, const std::wstring& name = L"");
|
||||
|
||||
///
|
||||
/// Create an instance of the CNTK built-in elementwise floor operation with the specified input operand.
|
||||
///
|
||||
CNTK_API FunctionPtr Floor(const Variable& operand, const std::wstring& name = L"");
|
||||
|
||||
///
|
||||
/// Create an instance of the CNTK built-in elementwise ceil operation with the specified input operand.
|
||||
///
|
||||
CNTK_API FunctionPtr Ceil(const Variable& operand, const std::wstring& name = L"");
|
||||
|
||||
///
|
||||
/// Create an instance of the CNTK built-in elementwise abs operation with the specified input operand.
|
||||
///
|
||||
CNTK_API FunctionPtr Abs(const Variable& operand, const std::wstring& name = L"");
|
||||
|
||||
///
|
||||
/// Create an instance of the CNTK built-in elementwise reciprocal operation with the specified input operand.
|
||||
///
|
||||
CNTK_API FunctionPtr Reciprocal(const Variable& operand, const std::wstring& name = L"");
|
||||
|
||||
///
|
||||
/// Create an instance of the CNTK built-in softmax operation on specified tensor input operand
|
||||
///
|
||||
CNTK_API FunctionPtr Softmax(const Variable& operand, const std::wstring& name = L"");
|
||||
|
||||
///
|
||||
/// Create an instance of the CNTK built-in elementwise tensor addition operation with the specified input operands.
|
||||
|
@ -1342,30 +1406,71 @@ namespace CNTK
|
|||
CNTK_API FunctionPtr Plus(const Variable& leftOperand, const Variable& rightOperand, const std::wstring& name = L"");
|
||||
|
||||
///
|
||||
/// Create an instance of the CNTK built-in elementwise sigmoid operation with the specified input operand.
|
||||
/// Create an instance of the CNTK built-in elementwise tensor subtraction operation with the specified input operands.
|
||||
///
|
||||
CNTK_API FunctionPtr Sigmoid(const Variable& operand, const std::wstring& name = L"");
|
||||
|
||||
CNTK_API FunctionPtr Minus(const Variable& leftOperand, const Variable& rightOperand, const std::wstring& name = L"");
|
||||
|
||||
///
|
||||
/// Create an instance of the CNTK built-in elementwise tanh operation with the specified input operand.
|
||||
/// Create an instance of the CNTK built-in elementwise multiplication operation on specified tensor input operands.
|
||||
///
|
||||
CNTK_API FunctionPtr Tanh(const Variable& operand, const std::wstring& name = L"");
|
||||
CNTK_API FunctionPtr ElementTimes(const Variable& leftOperand, const Variable& rightOperand, const std::wstring& name = L"");
|
||||
|
||||
///
|
||||
/// Create an instance of the CNTK built-in elementwise division operation on specified tensor input operands.
|
||||
///
|
||||
CNTK_API FunctionPtr ElementDivide(const Variable& leftOperand, const Variable& rightOperand, const std::wstring& name = L"");
|
||||
|
||||
///
|
||||
/// Create an instance of the CNTK built-in elementwise equality comparison operation on specified tensor input operands.
|
||||
///
|
||||
CNTK_API FunctionPtr Equal(const Variable& leftOperand, const Variable& rightOperand, const std::wstring& name = L"");
|
||||
|
||||
///
|
||||
/// Create an instance of the CNTK built-in elementwise not-equal comparison operation on specified tensor input operands.
|
||||
///
|
||||
CNTK_API FunctionPtr NotEqual(const Variable& leftOperand, const Variable& rightOperand, const std::wstring& name = L"");
|
||||
|
||||
///
|
||||
/// Create an instance of the CNTK built-in elementwise less than comparison operation on specified tensor input operands.
|
||||
///
|
||||
CNTK_API FunctionPtr Less(const Variable& leftOperand, const Variable& rightOperand, const std::wstring& name = L"");
|
||||
|
||||
///
|
||||
/// Create an instance of the CNTK built-in elementwise less than or equal to comparison operation on specified tensor input operands.
|
||||
///
|
||||
CNTK_API FunctionPtr LessEqual(const Variable& leftOperand, const Variable& rightOperand, const std::wstring& name = L"");
|
||||
|
||||
///
|
||||
/// Create an instance of the CNTK built-in elementwise greater than comparison operation on specified tensor input operands.
|
||||
///
|
||||
CNTK_API FunctionPtr Greater(const Variable& leftOperand, const Variable& rightOperand, const std::wstring& name = L"");
|
||||
|
||||
///
|
||||
/// Create an instance of the CNTK built-in elementwise greater than or equal to comparison operation on specified tensor input operands.
|
||||
///
|
||||
CNTK_API FunctionPtr GreaterEqual(const Variable& leftOperand, const Variable& rightOperand, const std::wstring& name = L"");
|
||||
|
||||
///
|
||||
/// Create an instance of the CNTK built-in matrix multiplication operation with the specified input operands.
|
||||
/// TODO: Specify the constraints on the shapes of the operands.
|
||||
///
|
||||
CNTK_API FunctionPtr Times(const Variable& leftOperand, const Variable& rightOperand, const std::wstring& name = L"");
|
||||
|
||||
///
|
||||
/// Create an instance of the CNTK built-in operation to compute squared-error for specified input operands.
|
||||
///
|
||||
CNTK_API FunctionPtr SquaredError(const Variable& prediction, const Variable& targets, const std::wstring& name = L"");
|
||||
|
||||
///
|
||||
/// Create an instance of the CNTK built-in operation to compute cross-entropy with softmax for specified input operands.
|
||||
///
|
||||
CNTK_API FunctionPtr CrossEntropyWithSoftmax(const Variable& output, const Variable& labels, const std::wstring& name = L"");
|
||||
CNTK_API FunctionPtr CrossEntropyWithSoftmax(const Variable& prediction, const Variable& labels, const std::wstring& name = L"");
|
||||
|
||||
///
|
||||
/// Create an instance of the CNTK built-in operation for computing the classification prediction error for specified operands.
|
||||
///
|
||||
CNTK_API FunctionPtr ClassificationError(const Variable& prediction, const Variable& labels, const std::wstring& name = L"");
|
||||
|
||||
///
|
||||
/// Create an instance of the CNTK built-in elementwise exp operation with the specified input operand.
|
||||
///
|
||||
CNTK_API FunctionPtr Exp(const Variable& operand, const std::wstring& name = L"");
|
||||
|
||||
///
|
||||
/// Create an instance of the CNTK built-in operation for getting the past value along the lone dynamic axis of the specified operand.
|
||||
/// Throws an exception of the operand has more than one dynamic axis.
|
||||
|
@ -1380,11 +1485,6 @@ namespace CNTK
|
|||
///
|
||||
CNTK_API FunctionPtr FutureValue(const Variable& initialState, const Variable& operand, size_t stepSize, const std::wstring& name = L"");
|
||||
|
||||
///
|
||||
/// Create an instance of the CNTK built-in elementwise multiplication operation on specified tensor input operands.
|
||||
///
|
||||
CNTK_API FunctionPtr ElementTimes(const Variable& leftOperand, const Variable& rightOperand, const std::wstring& name = L"");
|
||||
|
||||
///
|
||||
/// Create an instance of the CNTK built-in sum reduction operation on specified tensor input operand along all the axes
|
||||
///
|
||||
|
|
|
@ -51,8 +51,8 @@ namespace CNTK
|
|||
// Placeholders can be replaced incrementally - i.e. not all placeholders need to replaced in one go.
|
||||
// The only requirement is that they must all be replaced before making any 'Forward' calls on the Function instance.
|
||||
/*virtual*/ void Function::ReplacePlaceholders(const std::unordered_map<Placeholder, Variable>& placeholderReplacements,
|
||||
std::unordered_set<const Function*>& visitedFunctions,
|
||||
std::unordered_set<Placeholder>& replacedPlaceholders)
|
||||
std::unordered_set<const Function*>& visitedFunctions,
|
||||
std::unordered_set<Placeholder>& replacedPlaceholders)
|
||||
{
|
||||
visitedFunctions.insert(this);
|
||||
|
||||
|
@ -75,8 +75,8 @@ namespace CNTK
|
|||
// Replace any PlaceHolder Variables in the graph of Functions underlying 'this' CompositeFunction. All PlaceHolder variables
|
||||
// should have been replaced before performing any Forward compute of 'this' Function.
|
||||
/*virtual*/ void CompositeFunction::ReplacePlaceholders(const std::unordered_map<Placeholder, Variable>& placeholderReplacements,
|
||||
std::unordered_set<const Function*>& visitedFunctions,
|
||||
std::unordered_set<Placeholder>& replacedPlaceholders)
|
||||
std::unordered_set<const Function*>& visitedFunctions,
|
||||
std::unordered_set<Placeholder>& replacedPlaceholders)
|
||||
{
|
||||
RootFunction()->ReplacePlaceholders(placeholderReplacements, visitedFunctions, replacedPlaceholders);
|
||||
|
||||
|
@ -101,10 +101,10 @@ namespace CNTK
|
|||
// top level 'variable'
|
||||
template <typename ElementType>
|
||||
/*static*/ ComputationNodeBasePtr CompositeFunction::GetNode(const Variable& variable,
|
||||
Microsoft::MSR::CNTK::ComputationNetworkPtr& network,
|
||||
ComputationNetworkBuilder<ElementType>& builder,
|
||||
std::unordered_map<Variable, ComputationNodeBasePtr>& variableToNodeMap,
|
||||
std::unordered_map<Variable, bool>& isVariableRootMap)
|
||||
Microsoft::MSR::CNTK::ComputationNetworkPtr& network,
|
||||
ComputationNetworkBuilder<ElementType>& builder,
|
||||
std::unordered_map<Variable, ComputationNodeBasePtr>& variableToNodeMap,
|
||||
std::unordered_map<Variable, bool>& isVariableRootMap)
|
||||
{
|
||||
auto iter = variableToNodeMap.find(variable);
|
||||
if (iter != variableToNodeMap.end())
|
||||
|
@ -152,10 +152,10 @@ namespace CNTK
|
|||
|
||||
template <typename ElementType>
|
||||
/*static*/ ComputationNodeBasePtr CompositeFunction::GetOutputVariableNode(const Variable& variable,
|
||||
Microsoft::MSR::CNTK::ComputationNetworkPtr& network,
|
||||
ComputationNetworkBuilder<ElementType>& builder,
|
||||
std::unordered_map<Variable, ComputationNodeBasePtr>& variableToNodeMap,
|
||||
std::unordered_map<Variable, bool>& isVariableRootMap)
|
||||
Microsoft::MSR::CNTK::ComputationNetworkPtr& network,
|
||||
ComputationNetworkBuilder<ElementType>& builder,
|
||||
std::unordered_map<Variable, ComputationNodeBasePtr>& variableToNodeMap,
|
||||
std::unordered_map<Variable, bool>& isVariableRootMap)
|
||||
{
|
||||
assert(variable.IsOutput());
|
||||
|
||||
|
@ -180,12 +180,8 @@ namespace CNTK
|
|||
PrimitiveOpType op = primitiveFunction->OpType();
|
||||
switch (op)
|
||||
{
|
||||
case PrimitiveOpType::Plus:
|
||||
computationNodePtr = builder.Plus(input0Node, input1Node, function->Name());
|
||||
break;
|
||||
case PrimitiveOpType::Times:
|
||||
// TODO: The output rank of the times operation is currently hardcoded to 1
|
||||
computationNodePtr = builder.Times(input0Node, input1Node, 1, function->Name());
|
||||
case PrimitiveOpType::Negate:
|
||||
computationNodePtr = builder.Negate(input0Node, function->Name());
|
||||
break;
|
||||
case PrimitiveOpType::Sigmoid:
|
||||
computationNodePtr = builder.Sigmoid(input0Node, function->Name());
|
||||
|
@ -193,15 +189,73 @@ namespace CNTK
|
|||
case PrimitiveOpType::Tanh:
|
||||
computationNodePtr = builder.Tanh(input0Node, function->Name());
|
||||
break;
|
||||
case PrimitiveOpType::ReLU:
|
||||
computationNodePtr = builder.RectifiedLinear(input0Node, function->Name());
|
||||
break;
|
||||
case PrimitiveOpType::Exp:
|
||||
computationNodePtr = builder.Exp(input0Node, function->Name());
|
||||
break;
|
||||
case PrimitiveOpType::Log:
|
||||
computationNodePtr = builder.Log(input0Node, function->Name());
|
||||
break;
|
||||
case PrimitiveOpType::Sqrt:
|
||||
computationNodePtr = builder.Sqrt(input0Node, function->Name());
|
||||
break;
|
||||
case PrimitiveOpType::Floor:
|
||||
computationNodePtr = builder.Floor(input0Node, function->Name());
|
||||
break;
|
||||
case PrimitiveOpType::Abs:
|
||||
computationNodePtr = builder.Abs(input0Node, function->Name());
|
||||
break;
|
||||
case PrimitiveOpType::Reciprocal:
|
||||
computationNodePtr = builder.Reciprocal(input0Node, function->Name());
|
||||
break;
|
||||
case PrimitiveOpType::Softmax:
|
||||
if (functionInputs[0].Shape().NumAxes() > 1)
|
||||
InvalidArgument("Softmax operation can only be applied to a 1D input");
|
||||
|
||||
computationNodePtr = builder.Softmax(input0Node, function->Name());
|
||||
break;
|
||||
case PrimitiveOpType::Plus:
|
||||
computationNodePtr = builder.Plus(input0Node, input1Node, function->Name());
|
||||
break;
|
||||
case PrimitiveOpType::Minus:
|
||||
computationNodePtr = builder.Minus(input0Node, input1Node, function->Name());
|
||||
break;
|
||||
case PrimitiveOpType::ElementTimes:
|
||||
computationNodePtr = builder.ElementTimes(input0Node, input1Node, function->Name());
|
||||
break;
|
||||
case PrimitiveOpType::Equal:
|
||||
computationNodePtr = builder.Equal(input0Node, input1Node, function->Name());
|
||||
break;
|
||||
case PrimitiveOpType::NotEqual:
|
||||
computationNodePtr = builder.NotEqual(input0Node, input1Node, function->Name());
|
||||
break;
|
||||
case PrimitiveOpType::Less:
|
||||
computationNodePtr = builder.Less(input0Node, input1Node, function->Name());
|
||||
break;
|
||||
case PrimitiveOpType::LessEqual:
|
||||
computationNodePtr = builder.LessEqual(input0Node, input1Node, function->Name());
|
||||
break;
|
||||
case PrimitiveOpType::Greater:
|
||||
computationNodePtr = builder.Greater(input0Node, input1Node, function->Name());
|
||||
break;
|
||||
case PrimitiveOpType::GreaterEqual:
|
||||
computationNodePtr = builder.GreaterEqual(input0Node, input1Node, function->Name());
|
||||
break;
|
||||
case PrimitiveOpType::Times:
|
||||
// TODO: The output rank of the times operation is currently hardcoded to 1
|
||||
computationNodePtr = builder.Times(input0Node, input1Node, 1, function->Name());
|
||||
break;
|
||||
case PrimitiveOpType::SquaredError:
|
||||
computationNodePtr = builder.SquareError(input0Node, input1Node, function->Name());
|
||||
break;
|
||||
case PrimitiveOpType::CrossEntropyWithSoftmax:
|
||||
computationNodePtr = builder.CrossEntropyWithSoftmax(input1Node, input0Node, function->Name());
|
||||
break;
|
||||
case PrimitiveOpType::ClassificationError:
|
||||
computationNodePtr = builder.ErrorPrediction(input1Node, input0Node, function->Name());
|
||||
break;
|
||||
case PrimitiveOpType::Exp:
|
||||
computationNodePtr = builder.Exp(input0Node, function->Name());
|
||||
break;
|
||||
case PrimitiveOpType::PastValue:
|
||||
case PrimitiveOpType::FutureValue:
|
||||
{
|
||||
|
@ -231,9 +285,6 @@ namespace CNTK
|
|||
|
||||
break;
|
||||
}
|
||||
case PrimitiveOpType::ElementTimes:
|
||||
computationNodePtr = builder.ElementTimes(input0Node, input1Node, function->Name());
|
||||
break;
|
||||
case PrimitiveOpType::ReduceSum:
|
||||
{
|
||||
// TODO: Use the new ReduceElements node instead of the legacy SumElements node for reduction. Currently ReduceElements has incorrect MBLayout inference.
|
||||
|
@ -409,7 +460,7 @@ namespace CNTK
|
|||
layout->AddSequence(0, 0, 0, maxNumTimeSteps);
|
||||
}
|
||||
|
||||
return{ matrixData , layout};
|
||||
return{ matrixData, layout };
|
||||
}
|
||||
else
|
||||
{
|
||||
|
@ -444,7 +495,7 @@ namespace CNTK
|
|||
// The data needs to be rearranged since CNTK requires sequences to be interleaved across timesteps
|
||||
std::vector<MBLayout::SequenceInfo> sequences;
|
||||
for (size_t i = 0; i < numSequences; ++i)
|
||||
sequences.push_back({ i, SIZE_MAX, 0, sequenceLengths[i]});
|
||||
sequences.push_back({ i, SIZE_MAX, 0, sequenceLengths[i] });
|
||||
|
||||
auto layout = std::make_shared<MBLayout>();
|
||||
std::vector<std::pair<size_t, size_t>> placement;
|
||||
|
@ -458,10 +509,10 @@ namespace CNTK
|
|||
|
||||
// Now generate the gather indices
|
||||
auto matrixData = std::make_shared<Matrix<ElementType>>(var.Shape().TotalSize(),
|
||||
layout->GetNumCols(),
|
||||
AsCNTKImplDeviceId(value->Data()->Device()),
|
||||
value->Data()->IsSparse() ? MatrixType::SPARSE : MatrixType::DENSE,
|
||||
AsCNTKImplMatrixFormat(value->Data()->GetStorageFormat()));
|
||||
layout->GetNumCols(),
|
||||
AsCNTKImplDeviceId(value->Data()->Device()),
|
||||
value->Data()->IsSparse() ? MatrixType::SPARSE : MatrixType::DENSE,
|
||||
AsCNTKImplMatrixFormat(value->Data()->GetStorageFormat()));
|
||||
|
||||
std::vector<size_t> sequencesShorterThanLongestSequence;
|
||||
for (size_t i = 0; i < numSequences; ++i)
|
||||
|
@ -772,9 +823,9 @@ namespace CNTK
|
|||
}
|
||||
|
||||
/*virtual*/ BackPropStatePtr CompositeFunction::Forward(const std::unordered_map<Variable, const ValuePtr>& arguments,
|
||||
std::unordered_map<Variable, ValuePtr>& outputs,
|
||||
const DeviceDescriptor& computeDevice,
|
||||
const std::unordered_set<Variable>& outputsToRetainBackwardStateFor)
|
||||
std::unordered_map<Variable, ValuePtr>& outputs,
|
||||
const DeviceDescriptor& computeDevice,
|
||||
const std::unordered_set<Variable>& outputsToRetainBackwardStateFor)
|
||||
{
|
||||
// TODO: How about zero argument functions?
|
||||
// TODO: We need a better way to determine the ElementType for the network
|
||||
|
@ -819,8 +870,8 @@ namespace CNTK
|
|||
}
|
||||
|
||||
/*virtual*/ void CompositeFunction::Backward(const BackPropStatePtr& state,
|
||||
const std::unordered_map<Variable, const ValuePtr>& rootGradientValues,
|
||||
std::unordered_map<Variable, ValuePtr>& backPropagatedGradientValuesForInputs)
|
||||
const std::unordered_map<Variable, const ValuePtr>& rootGradientValues,
|
||||
std::unordered_map<Variable, ValuePtr>& backPropagatedGradientValuesForInputs)
|
||||
{
|
||||
auto backpropState = dynamic_cast<const CNTKBackPropState*>(state.get());
|
||||
if (backpropState == nullptr)
|
||||
|
@ -829,7 +880,7 @@ namespace CNTK
|
|||
// TODO: Support multiple concurrent backprop states
|
||||
if (backpropState->EvalTimeStamp().second != m_variableToNodeMap[backpropState->EvalTimeStamp().first]->GetEvalTimeStamp())
|
||||
LogicError("The specified backprop state specified cannot be used for backpropagation as the Function's internal state was modified by subsequent Forward calls to the function."
|
||||
"This is not a user error but a shortcoming of the current implementation where multiple independent backprop states are not simultaneously supported");
|
||||
"This is not a user error but a shortcoming of the current implementation where multiple independent backprop states are not simultaneously supported");
|
||||
|
||||
if (rootGradientValues.size() > 1)
|
||||
LogicError("Currently gradient backprop from only one of the Function Outputs is supported");
|
||||
|
@ -852,24 +903,179 @@ namespace CNTK
|
|||
// TODO: How to deal with the specified 'computeDevice'
|
||||
}
|
||||
|
||||
FunctionPtr Times(const Variable& leftOperand, const Variable& rightOperand, const std::wstring& name/* = L""*/)
|
||||
FunctionPtr UnaryOp(PrimitiveOpType op, const Variable& operand, Dictionary&& opConfig, const std::wstring& name)
|
||||
{
|
||||
return CompositeFunction::Create(MakeSharedObject<PrimitiveFunction>(PrimitiveOpType::Times, std::vector<Variable>({ leftOperand, rightOperand }), Dictionary(), name), name);
|
||||
return CompositeFunction::Create(MakeSharedObject<PrimitiveFunction>(op, std::vector<Variable>({ operand }), std::move(opConfig), name), name);
|
||||
}
|
||||
|
||||
FunctionPtr Plus(const Variable& leftOperand, const Variable& rightOperand, const std::wstring& name/* = L""*/)
|
||||
FunctionPtr Negate(const Variable& operand, const std::wstring& name/* = L""*/)
|
||||
{
|
||||
return CompositeFunction::Create(MakeSharedObject<PrimitiveFunction>(PrimitiveOpType::Plus, std::vector<Variable>({ leftOperand, rightOperand }), Dictionary(), name), name);
|
||||
return UnaryOp(PrimitiveOpType::Negate, operand, Dictionary(), name);
|
||||
}
|
||||
|
||||
FunctionPtr Sigmoid(const Variable& operand, const std::wstring& name/* = L""*/)
|
||||
{
|
||||
return CompositeFunction::Create(MakeSharedObject<PrimitiveFunction>(PrimitiveOpType::Sigmoid, std::vector<Variable>({ operand }), Dictionary(), name), name);
|
||||
return UnaryOp(PrimitiveOpType::Sigmoid, operand, Dictionary(), name);
|
||||
}
|
||||
|
||||
FunctionPtr Tanh(const Variable& operand, const std::wstring& name/* = L""*/)
|
||||
{
|
||||
return CompositeFunction::Create(MakeSharedObject<PrimitiveFunction>(PrimitiveOpType::Tanh, std::vector<Variable>({ operand }), Dictionary(), name), name);
|
||||
return UnaryOp(PrimitiveOpType::Tanh, operand, Dictionary(), name);
|
||||
}
|
||||
|
||||
FunctionPtr ReLU(const Variable& operand, const std::wstring& name/* = L""*/)
|
||||
{
|
||||
return UnaryOp(PrimitiveOpType::ReLU, operand, Dictionary(), name);
|
||||
}
|
||||
|
||||
FunctionPtr Exp(const Variable& operand, const std::wstring& name/* = L""*/)
|
||||
{
|
||||
return UnaryOp(PrimitiveOpType::Exp, operand, Dictionary(), name);
|
||||
}
|
||||
|
||||
FunctionPtr Log(const Variable& operand, const std::wstring& name/* = L""*/)
|
||||
{
|
||||
return UnaryOp(PrimitiveOpType::Log, operand, Dictionary(), name);
|
||||
}
|
||||
|
||||
FunctionPtr Square(const Variable& operand, const std::wstring& name/* = L""*/)
|
||||
{
|
||||
return ElementTimes(operand, operand, name);
|
||||
}
|
||||
|
||||
FunctionPtr Sqrt(const Variable& operand, const std::wstring& name/* = L""*/)
|
||||
{
|
||||
return UnaryOp(PrimitiveOpType::Sqrt, operand, Dictionary(), name);
|
||||
}
|
||||
|
||||
FunctionPtr Round(const Variable& operand, const std::wstring& name/* = L""*/)
|
||||
{
|
||||
return Floor(Plus(operand, Constant(NDShape({}), 0.5f)), name);
|
||||
}
|
||||
|
||||
FunctionPtr Floor(const Variable& operand, const std::wstring& name/* = L""*/)
|
||||
{
|
||||
return UnaryOp(PrimitiveOpType::Floor, operand, Dictionary(), name);
|
||||
}
|
||||
|
||||
FunctionPtr Ceil(const Variable& operand, const std::wstring& name/* = L""*/)
|
||||
{
|
||||
return Negate(Floor(Negate(operand)), name);
|
||||
}
|
||||
|
||||
FunctionPtr Abs(const Variable& operand, const std::wstring& name/* = L""*/)
|
||||
{
|
||||
return UnaryOp(PrimitiveOpType::Abs, operand, Dictionary(), name);
|
||||
}
|
||||
|
||||
FunctionPtr Reciprocal(const Variable& operand, const std::wstring& name/* = L""*/)
|
||||
{
|
||||
return UnaryOp(PrimitiveOpType::Reciprocal, operand, Dictionary(), name);
|
||||
}
|
||||
|
||||
FunctionPtr Softmax(const Variable& operand, const std::wstring& name/* = L""*/)
|
||||
{
|
||||
return UnaryOp(PrimitiveOpType::Softmax, operand, Dictionary(), name);
|
||||
}
|
||||
|
||||
FunctionPtr BinaryOp(PrimitiveOpType op, const Variable& leftOperand, const Variable& rightOperand, Dictionary&& opConfig, const std::wstring& name)
|
||||
{
|
||||
return CompositeFunction::Create(MakeSharedObject<PrimitiveFunction>(op, std::vector<Variable>({ leftOperand, rightOperand }), std::move(opConfig), name), name);
|
||||
}
|
||||
|
||||
FunctionPtr Plus(const Variable& leftOperand, const Variable& rightOperand, const std::wstring& name/* = L""*/)
|
||||
{
|
||||
return BinaryOp(PrimitiveOpType::Plus, leftOperand, rightOperand, Dictionary(), name);
|
||||
}
|
||||
|
||||
FunctionPtr Minus(const Variable& leftOperand, const Variable& rightOperand, const std::wstring& name/* = L""*/)
|
||||
{
|
||||
return BinaryOp(PrimitiveOpType::Minus, leftOperand, rightOperand, Dictionary(), name);
|
||||
}
|
||||
|
||||
FunctionPtr ElementTimes(const Variable& leftOperand, const Variable& rightOperand, const std::wstring& name/* = L""*/)
|
||||
{
|
||||
return BinaryOp(PrimitiveOpType::ElementTimes, leftOperand, rightOperand, Dictionary(), name);
|
||||
}
|
||||
|
||||
FunctionPtr ElementDivide(const Variable& leftOperand, const Variable& rightOperand, const std::wstring& name/* = L""*/)
|
||||
{
|
||||
return ElementTimes(leftOperand, Reciprocal(rightOperand), name);
|
||||
}
|
||||
|
||||
FunctionPtr Equal(const Variable& leftOperand, const Variable& rightOperand, const std::wstring& name/* = L""*/)
|
||||
{
|
||||
return BinaryOp(PrimitiveOpType::Equal, leftOperand, rightOperand, Dictionary(), name);
|
||||
}
|
||||
|
||||
FunctionPtr NotEqual(const Variable& leftOperand, const Variable& rightOperand, const std::wstring& name/* = L""*/)
|
||||
{
|
||||
return BinaryOp(PrimitiveOpType::NotEqual, leftOperand, rightOperand, Dictionary(), name);
|
||||
}
|
||||
|
||||
FunctionPtr Less(const Variable& leftOperand, const Variable& rightOperand, const std::wstring& name/* = L""*/)
|
||||
{
|
||||
return BinaryOp(PrimitiveOpType::Less, leftOperand, rightOperand, Dictionary(), name);
|
||||
}
|
||||
|
||||
FunctionPtr LessEqual(const Variable& leftOperand, const Variable& rightOperand, const std::wstring& name/* = L""*/)
|
||||
{
|
||||
return BinaryOp(PrimitiveOpType::LessEqual, leftOperand, rightOperand, Dictionary(), name);
|
||||
}
|
||||
|
||||
FunctionPtr Greater(const Variable& leftOperand, const Variable& rightOperand, const std::wstring& name/* = L""*/)
|
||||
{
|
||||
return BinaryOp(PrimitiveOpType::Greater, leftOperand, rightOperand, Dictionary(), name);
|
||||
}
|
||||
|
||||
FunctionPtr GreaterEqual(const Variable& leftOperand, const Variable& rightOperand, const std::wstring& name/* = L""*/)
|
||||
{
|
||||
return BinaryOp(PrimitiveOpType::GreaterEqual, leftOperand, rightOperand, Dictionary(), name);
|
||||
}
|
||||
|
||||
FunctionPtr Times(const Variable& leftOperand, const Variable& rightOperand, const std::wstring& name/* = L""*/)
|
||||
{
|
||||
return BinaryOp(PrimitiveOpType::Times, leftOperand, rightOperand, Dictionary(), name);
|
||||
}
|
||||
|
||||
FunctionPtr SquaredError(const Variable& prediction, const Variable& targets, const std::wstring& name/* = L""*/)
|
||||
{
|
||||
return BinaryOp(PrimitiveOpType::SquaredError, prediction, targets, Dictionary(), name);
|
||||
}
|
||||
|
||||
FunctionPtr CrossEntropyWithSoftmax(const Variable& prediction, const Variable& labels, const std::wstring& name/* = L""*/)
|
||||
{
|
||||
return BinaryOp(PrimitiveOpType::CrossEntropyWithSoftmax, prediction, labels, Dictionary(), name);
|
||||
}
|
||||
|
||||
FunctionPtr ClassificationError(const Variable& prediction, const Variable& labels, const std::wstring& name/* = L""*/)
|
||||
{
|
||||
return BinaryOp(PrimitiveOpType::ClassificationError, prediction, labels, Dictionary(), name);
|
||||
}
|
||||
|
||||
FunctionPtr PastValue(const Variable& initialState, const Variable& operand, size_t stepSize, const std::wstring& name/* = L""*/)
|
||||
{
|
||||
if (operand.DynamicAxes().size() != 1)
|
||||
InvalidArgument("PastValue overload that does not explicitly specify a dynamic axis can only be used for operands with exactly one dynamic axis");
|
||||
|
||||
auto additionalProperties = Dictionary();
|
||||
additionalProperties[L"stepSize"] = DictionaryValue(stepSize);
|
||||
return BinaryOp(PrimitiveOpType::PastValue, initialState, operand, std::move(additionalProperties), name);
|
||||
}
|
||||
|
||||
FunctionPtr FutureValue(const Variable& initialState, const Variable& operand, size_t stepSize, const std::wstring& name/* = L""*/)
|
||||
{
|
||||
if (operand.DynamicAxes().size() != 1)
|
||||
InvalidArgument("FutureValue overload that does not explicitly specify a dynamic axis can only be used for operands with exactly one dynamic axis");
|
||||
|
||||
auto additionalProperties = Dictionary();
|
||||
additionalProperties[L"stepSize"] = DictionaryValue(stepSize);
|
||||
return BinaryOp(PrimitiveOpType::FutureValue, initialState, operand, std::move(additionalProperties), name);
|
||||
}
|
||||
|
||||
FunctionPtr ReduceSum(const Variable& operand, const std::wstring& name/* = L""*/)
|
||||
{
|
||||
return UnaryOp(PrimitiveOpType::ReduceSum, operand, Dictionary(), name);
|
||||
}
|
||||
|
||||
FunctionPtr Combine(const std::initializer_list<FunctionPtr>& operands, const std::wstring& name/* = L""*/)
|
||||
|
@ -888,49 +1094,4 @@ namespace CNTK
|
|||
|
||||
return CompositeFunction::Create(MakeSharedObject<PrimitiveFunction>(PrimitiveOpType::Combine, inputs, Dictionary(), name), name);
|
||||
}
|
||||
|
||||
FunctionPtr CrossEntropyWithSoftmax(const Variable& output, const Variable& labels, const std::wstring& name/* = L""*/)
|
||||
{
|
||||
return CompositeFunction::Create(MakeSharedObject<PrimitiveFunction>(PrimitiveOpType::CrossEntropyWithSoftmax, std::vector<Variable>({ output, labels }), Dictionary(), name), name);
|
||||
}
|
||||
|
||||
FunctionPtr ClassificationError(const Variable& prediction, const Variable& labels, const std::wstring& name/* = L""*/)
|
||||
{
|
||||
return CompositeFunction::Create(MakeSharedObject<PrimitiveFunction>(PrimitiveOpType::ClassificationError, std::vector<Variable>({ prediction, labels }), Dictionary(), name), name);
|
||||
}
|
||||
|
||||
FunctionPtr Exp(const Variable& operand, const std::wstring& name/* = L""*/)
|
||||
{
|
||||
return CompositeFunction::Create(MakeSharedObject<PrimitiveFunction>(PrimitiveOpType::Exp, std::vector<Variable>({ operand }), Dictionary(), name), name);
|
||||
}
|
||||
|
||||
FunctionPtr PastValue(const Variable& initialState, const Variable& operand, size_t stepSize, const std::wstring& name/* = L""*/)
|
||||
{
|
||||
if (operand.DynamicAxes().size() != 1)
|
||||
InvalidArgument("PastValue overload that does not explicitly specify a dynamic axis can only be used for operands with exactly one dynamic axis");
|
||||
|
||||
auto additionalProperties = Dictionary();
|
||||
additionalProperties[L"stepSize"] = DictionaryValue(stepSize);
|
||||
return CompositeFunction::Create(MakeSharedObject<PrimitiveFunction>(PrimitiveOpType::PastValue, std::vector<Variable>({ initialState, operand }), std::move(additionalProperties), name), name);
|
||||
}
|
||||
|
||||
FunctionPtr FutureValue(const Variable& initialState, const Variable& operand, size_t stepSize, const std::wstring& name/* = L""*/)
|
||||
{
|
||||
if (operand.DynamicAxes().size() != 1)
|
||||
InvalidArgument("FutureValue overload that does not explicitly specify a dynamic axis can only be used for operands with exactly one dynamic axis");
|
||||
|
||||
auto additionalProperties = Dictionary();
|
||||
additionalProperties[L"stepSize"] = DictionaryValue(stepSize);
|
||||
return CompositeFunction::Create(MakeSharedObject<PrimitiveFunction>(PrimitiveOpType::FutureValue, std::vector<Variable>({ initialState, operand }), std::move(additionalProperties), name), name);
|
||||
}
|
||||
|
||||
FunctionPtr ElementTimes(const Variable& leftOperand, const Variable& rightOperand, const std::wstring& name/* = L""*/)
|
||||
{
|
||||
return CompositeFunction::Create(MakeSharedObject<PrimitiveFunction>(PrimitiveOpType::ElementTimes, std::vector<Variable>({ leftOperand, rightOperand }), Dictionary(), name), name);
|
||||
}
|
||||
|
||||
FunctionPtr ReduceSum(const Variable& operand, const std::wstring& name/* = L""*/)
|
||||
{
|
||||
return CompositeFunction::Create(MakeSharedObject<PrimitiveFunction>(PrimitiveOpType::ReduceSum, std::vector<Variable>({ operand }), Dictionary(), name), name);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -13,51 +13,89 @@
|
|||
|
||||
namespace CNTK
|
||||
{
|
||||
enum class PrimitiveOpType
|
||||
enum class PrimitiveOpType : unsigned int
|
||||
{
|
||||
Plus,
|
||||
Times,
|
||||
Negate,
|
||||
Sigmoid,
|
||||
Tanh,
|
||||
Combine,
|
||||
ReLU,
|
||||
Exp,
|
||||
Log,
|
||||
Sqrt,
|
||||
Floor,
|
||||
Abs,
|
||||
Reciprocal,
|
||||
Softmax,
|
||||
Plus,
|
||||
Minus,
|
||||
ElementTimes,
|
||||
Equal,
|
||||
NotEqual,
|
||||
Less,
|
||||
LessEqual,
|
||||
Greater,
|
||||
GreaterEqual,
|
||||
Times,
|
||||
SquaredError,
|
||||
CrossEntropyWithSoftmax,
|
||||
ClassificationError,
|
||||
Exp,
|
||||
PastValue,
|
||||
FutureValue,
|
||||
ElementTimes,
|
||||
ReduceSum
|
||||
ReduceSum,
|
||||
Combine,
|
||||
};
|
||||
}
|
||||
|
||||
namespace std
|
||||
{
|
||||
template <> struct hash<CNTK::PrimitiveOpType>
|
||||
{
|
||||
size_t operator()(const CNTK::PrimitiveOpType& x) const
|
||||
{
|
||||
return std::hash<unsigned int>()((unsigned int)x);
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
namespace CNTK
|
||||
{
|
||||
inline const char* PrimitiveOpTypeName(PrimitiveOpType opType)
|
||||
{
|
||||
// TODO: Put these in table form
|
||||
if (opType == PrimitiveOpType::Plus)
|
||||
return "Plus";
|
||||
else if (opType == PrimitiveOpType::Times)
|
||||
return "Times";
|
||||
else if (opType == PrimitiveOpType::Sigmoid)
|
||||
return "Sigmoid";
|
||||
else if (opType == PrimitiveOpType::Tanh)
|
||||
return "Tanh";
|
||||
else if (opType == PrimitiveOpType::Combine)
|
||||
return "Combine";
|
||||
else if (opType == PrimitiveOpType::CrossEntropyWithSoftmax)
|
||||
return "CrossEntropyWithSoftmax";
|
||||
else if (opType == PrimitiveOpType::ClassificationError)
|
||||
return "ClassificationError";
|
||||
else if (opType == PrimitiveOpType::Exp)
|
||||
return "Exp";
|
||||
else if (opType == PrimitiveOpType::PastValue)
|
||||
return "PastValue";
|
||||
else if (opType == PrimitiveOpType::FutureValue)
|
||||
return "FutureValue";
|
||||
else if (opType == PrimitiveOpType::ElementTimes)
|
||||
return "ElementTimes";
|
||||
else if (opType == PrimitiveOpType::ReduceSum)
|
||||
return "ReduceSum";
|
||||
else
|
||||
static std::unordered_map<PrimitiveOpType, const char*> primitiveOpNames = {
|
||||
{ PrimitiveOpType::Negate, "Negate" },
|
||||
{ PrimitiveOpType::Sigmoid, "Sigmoid" },
|
||||
{ PrimitiveOpType::Tanh, "Tanh" },
|
||||
{ PrimitiveOpType::ReLU, "ReLU" },
|
||||
{ PrimitiveOpType::Exp, "Exp" },
|
||||
{ PrimitiveOpType::Log, "Log" },
|
||||
{ PrimitiveOpType::Sqrt, "Sqrt" },
|
||||
{ PrimitiveOpType::Floor, "Floor" },
|
||||
{ PrimitiveOpType::Abs, "Abs" },
|
||||
{ PrimitiveOpType::Reciprocal, "Reciprocal" },
|
||||
{ PrimitiveOpType::Softmax, "Softmax" },
|
||||
{ PrimitiveOpType::Plus, "Plus" },
|
||||
{ PrimitiveOpType::Minus, "Minus" },
|
||||
{ PrimitiveOpType::ElementTimes, "ElementTimes" },
|
||||
{ PrimitiveOpType::Equal, "Equal" },
|
||||
{ PrimitiveOpType::NotEqual, "NotEqual" },
|
||||
{ PrimitiveOpType::Less, "Less" },
|
||||
{ PrimitiveOpType::LessEqual, "LessEqual" },
|
||||
{ PrimitiveOpType::Greater, "Greater" },
|
||||
{ PrimitiveOpType::GreaterEqual, "GreaterEqual" },
|
||||
{ PrimitiveOpType::Times, "Times" },
|
||||
{ PrimitiveOpType::SquaredError, "SquaredError" },
|
||||
{ PrimitiveOpType::CrossEntropyWithSoftmax, "CrossEntropyWithSoftmax" },
|
||||
{ PrimitiveOpType::ClassificationError, "ClassificationError" },
|
||||
{ PrimitiveOpType::PastValue, "PastValue" },
|
||||
{ PrimitiveOpType::FutureValue, "FutureValue" },
|
||||
{ PrimitiveOpType::ReduceSum, "ReduceSum" },
|
||||
{ PrimitiveOpType::Combine, "Combine" }
|
||||
};
|
||||
|
||||
if (primitiveOpNames.find(opType) == primitiveOpNames.end())
|
||||
LogicError("Unknown PrimitiveOpType");
|
||||
|
||||
return primitiveOpNames.find(opType)->second;
|
||||
}
|
||||
|
||||
class PrimitiveFunction final : public Function
|
||||
|
@ -195,19 +233,29 @@ namespace CNTK
|
|||
|
||||
switch (op)
|
||||
{
|
||||
case PrimitiveOpType::Negate:
|
||||
case PrimitiveOpType::Sigmoid:
|
||||
case PrimitiveOpType::Tanh:
|
||||
case PrimitiveOpType::ReLU:
|
||||
case PrimitiveOpType::Exp:
|
||||
case PrimitiveOpType::Log:
|
||||
case PrimitiveOpType::Sqrt:
|
||||
case PrimitiveOpType::Floor:
|
||||
case PrimitiveOpType::Abs:
|
||||
case PrimitiveOpType::Reciprocal:
|
||||
case PrimitiveOpType::Softmax:
|
||||
assert(inputs.size() == 1);
|
||||
outputs.push_back(Variable(UnaryElementwiseOpOutputShape(inputs[0].Shape()), outputDataType, owner, outputDynamicAxes));
|
||||
break;
|
||||
case PrimitiveOpType::PastValue:
|
||||
case PrimitiveOpType::FutureValue:
|
||||
assert(inputs.size() == 2);
|
||||
outputs.push_back(Variable(UnaryElementwiseOpOutputShape(inputs[1].Shape()), outputDataType, owner, outputDynamicAxes));
|
||||
break;
|
||||
case PrimitiveOpType::Plus:
|
||||
case PrimitiveOpType::Minus:
|
||||
case PrimitiveOpType::ElementTimes:
|
||||
case PrimitiveOpType::Equal:
|
||||
case PrimitiveOpType::NotEqual:
|
||||
case PrimitiveOpType::Less:
|
||||
case PrimitiveOpType::LessEqual:
|
||||
case PrimitiveOpType::Greater:
|
||||
case PrimitiveOpType::GreaterEqual:
|
||||
assert(inputs.size() == 2);
|
||||
outputs.push_back(Variable(BinaryElementwiseOpOutputShape(op, inputs[0].Shape(), inputs[1].Shape()), outputDataType, owner, outputDynamicAxes));
|
||||
break;
|
||||
|
@ -215,6 +263,7 @@ namespace CNTK
|
|||
assert(inputs.size() == 2);
|
||||
outputs.push_back(Variable(TimesOpOutputShape(inputs[0].Shape(), inputs[1].Shape()), outputDataType, owner, outputDynamicAxes));
|
||||
break;
|
||||
case PrimitiveOpType::SquaredError:
|
||||
case PrimitiveOpType::CrossEntropyWithSoftmax:
|
||||
case PrimitiveOpType::ClassificationError:
|
||||
{
|
||||
|
@ -235,6 +284,11 @@ namespace CNTK
|
|||
outputs.push_back(Variable(ReductionOpOutputShape(op, predictionShape, reductionAxes), outputDataType, owner, {}));
|
||||
break;
|
||||
}
|
||||
case PrimitiveOpType::PastValue:
|
||||
case PrimitiveOpType::FutureValue:
|
||||
assert(inputs.size() == 2);
|
||||
outputs.push_back(Variable(UnaryElementwiseOpOutputShape(inputs[1].Shape()), outputDataType, owner, outputDynamicAxes));
|
||||
break;
|
||||
case PrimitiveOpType::ReduceSum:
|
||||
{
|
||||
assert(inputs.size() == 1);
|
||||
|
|
Загрузка…
Ссылка в новой задаче