add support for node Abs; take CR comments into account
This commit is contained in:
Родитель
4b3561f06c
Коммит
a6a80cded9
|
@ -79,6 +79,7 @@ L"ParameterTensor(dims, learningRateMultiplier = 1.0, init = 'uniform'/*|fixedVa
|
|||
#ifdef COMING_SOON
|
||||
TernaryStandardNode(CRF, labelVectorSequence, positionDependenScoreVectorSequence, transitionScores) // TODO: better names
|
||||
#endif
|
||||
UnaryStandardNode(Abs, x)
|
||||
QuaternaryStandardNode(ClassBasedCrossEntropyWithSoftmax, labelClassDescriptorVectorSequence, mainInputInfo, mainWeight, classLogProbsBeforeSoftmax)
|
||||
// BUGBUG: the commented-out ones are not mentioned in the CNTK book, nor are their parameters documented in the source code
|
||||
BinaryStandardNode(ColumnElementTimes, aVectorSequence, anotherVectorSequence)
|
||||
|
|
|
@ -149,7 +149,8 @@ bool CheckFunction(std::string& p_nodeType, bool* allowUndeterminedVariable)
|
|||
|
||||
wstring nodeType = msra::strfun::utf16(p_nodeType);
|
||||
bool ret = false;
|
||||
if (EqualInsensitive(nodeType, OperationNameOf(AveragePoolingNode))) ret = true;
|
||||
if (EqualInsensitive(nodeType, OperationNameOf(AbsNode))) ret = true;
|
||||
else if (EqualInsensitive(nodeType, OperationNameOf(AveragePoolingNode))) ret = true;
|
||||
else if (EqualInsensitive(nodeType, OperationNameOf(BatchNormalizationNode))) ret = true;
|
||||
#ifdef COMING_SOON
|
||||
else if (EqualInsensitive(nodeType, OperationNameOf(CRFNode), L"CRF")) ret = true;
|
||||
|
|
|
@ -37,7 +37,8 @@ static shared_ptr<ComputationNode<ElemType>> CreateStandardNode(const std::wstri
|
|||
if (nodeType == OperationNameOf(CRFNode)) return New<CRFNode<ElemType>>(forward<_Types>(_Args)...);
|
||||
else
|
||||
#endif
|
||||
if (nodeType == OperationNameOf(ClassBasedCrossEntropyWithSoftmaxNode))return New<ClassBasedCrossEntropyWithSoftmaxNode<ElemType>>(forward<_Types>(_Args)...);
|
||||
if (nodeType == OperationNameOf(AbsNode)) return New<AbsNode<ElemType>>(forward<_Types>(_Args)...);
|
||||
else if (nodeType == OperationNameOf(ClassBasedCrossEntropyWithSoftmaxNode))return New<ClassBasedCrossEntropyWithSoftmaxNode<ElemType>>(forward<_Types>(_Args)...);
|
||||
else if (nodeType == OperationNameOf(CosDistanceNode)) return New<CosDistanceNode<ElemType>>(forward<_Types>(_Args)...);
|
||||
else if (nodeType == OperationNameOf(CosDistanceWithNegativeSamplesNode)) return New<CosDistanceWithNegativeSamplesNode<ElemType>>(forward<_Types>(_Args)...);
|
||||
else if (nodeType == OperationNameOf(CosineNode)) return New<CosineNode<ElemType>>(forward<_Types>(_Args)...);
|
||||
|
|
|
@ -77,6 +77,7 @@ public:
|
|||
#ifdef COMING_SOON
|
||||
ComputationNodePtr CRF(const ComputationNodePtr label, const ComputationNodePtr postDepScore, const ComputationNodePtr transition_score, const std::wstring nodeName = L"");
|
||||
#endif
|
||||
ComputationNodePtr Abs(const ComputationNodePtr a, const std::wstring nodeName = L"");
|
||||
ComputationNodePtr ClassCrossEntropyWithSoftmax(const ComputationNodePtr label, const ComputationNodePtr prediction, const ComputationNodePtr input_weight, const ComputationNodePtr cls_log_post_prob, const std::wstring nodeName = L"");
|
||||
ComputationNodePtr Cos(const ComputationNodePtr a, const std::wstring nodeName = L"");
|
||||
ComputationNodePtr CosDistance(const ComputationNodePtr a, const ComputationNodePtr b, const std::wstring nodeName = L"");
|
||||
|
|
|
@ -117,6 +117,7 @@ DeclareUnaryElementWiseWithOpCodeNode(RectifiedLinear, LinearRectifier, Elementw
|
|||
DeclareUnaryElementWiseWithOpCodeNode(Log, Log, ElementwiseProductWithLogDerivativeFromOutput, true);
|
||||
DeclareUnaryElementWiseWithOpCodeNode(Exp, Exp, ElementwiseProduct, true);
|
||||
DeclareUnaryElementWiseWithOpCodeNode(Cosine, Cosine, ElementwiseProductWithCosDerivative, false);
|
||||
DeclareUnaryElementWiseWithOpCodeNode(Abs, Abs, ElementwiseProductWithAbsDerivative, false);
|
||||
|
||||
#pragma pop_macro("DeclareUnaryElementWiseWithOpCodeNode")
|
||||
|
||||
|
|
|
@ -111,6 +111,7 @@ enum ElementWiseOperator
|
|||
opElementwiseProductWithLinearRectifierDerivativeFromOutput,
|
||||
opElementwiseProductWithLogDerivativeFromOutput,
|
||||
opElementwiseProductWithCosDerivative,
|
||||
opElementwiseProductWithAbsDerivative,
|
||||
opSqrOfDifference,
|
||||
// binary ops for indexing
|
||||
// opIndex,
|
||||
|
@ -160,8 +161,9 @@ enum ElementWiseOperator
|
|||
Macro(ElementwiseProductWithTanhDerivativeFromOutput); \
|
||||
Macro(ElementwiseProductWithLinearRectifierDerivativeFromOutput); \
|
||||
Macro(ElementwiseProductWithLogDerivativeFromOutput); \
|
||||
Macro(ElementwiseProductWithCosDerivative); \
|
||||
Macro(SqrOfDifference); \
|
||||
Macro(ElementwiseProductWithCosDerivative); \
|
||||
Macro(ElementwiseProductWithAbsDerivative); \
|
||||
Macro(SqrOfDifference); \
|
||||
//Macro(Index);
|
||||
|
||||
#define ForAllTernaryOps(Macro) \
|
||||
|
|
|
@ -96,6 +96,14 @@ DECL ElemType LinearRectifierDerivative(ElemType z)
|
|||
return z > 0 ? (ElemType) 1 : 0;
|
||||
}
|
||||
|
||||
template <class ElemType>
|
||||
DECL ElemType Sgn(ElemType z)
|
||||
{
|
||||
if (z > 0.0) return 1.0;
|
||||
if (z < 0.0) return -1.0;
|
||||
return z;
|
||||
}
|
||||
|
||||
template <class ElemType>
|
||||
DECL ElemType Sqr(ElemType z)
|
||||
{
|
||||
|
@ -227,6 +235,7 @@ DefBinaryOp(ElementwiseProductWithTanhDerivativeFromOutput, a*(1 - b * b));
|
|||
DefBinaryOp(ElementwiseProductWithLinearRectifierDerivativeFromOutput, b > 0 ? a : 0);
|
||||
DefBinaryOp(ElementwiseProductWithLogDerivativeFromOutput, a* exp_(-b));
|
||||
DefBinaryOp(ElementwiseProductWithCosDerivative, a * -sin_(b)); // note: b = input for cos()
|
||||
// Op name must be spelled "Elementwise..." so that it matches the enum value
// opElementwiseProductWithAbsDerivative and the Macro(ElementwiseProductWithAbsDerivative) list entry.
DefBinaryOp(ElementwiseProductWithAbsDerivative, a * Sgn(b)); // note: b = input for abs()
|
||||
DefBinaryOp(SqrOfDifference, Sqr(a - b));
|
||||
//DefBinaryOp(Index, IndexElement(a, b, i)); // note: this one uses the third argument
|
||||
|
||||
|
|
Загрузка…
Ссылка в новой задаче