add support for node Abs; take CR comments into account

This commit is contained in:
William Darling 2016-03-07 14:58:00 +01:00
Родитель 4b3561f06c
Коммит a6a80cded9
7 изменённых файлов: 20 добавлений и 4 удалений

Просмотреть файл

@ -79,6 +79,7 @@ L"ParameterTensor(dims, learningRateMultiplier = 1.0, init = 'uniform'/*|fixedVa
#ifdef COMING_SOON
TernaryStandardNode(CRF, labelVectorSequence, positionDependenScoreVectorSequence, transitionScores) // TODO: better names
#endif
UnaryStandardNode(Abs, x)
QuaternaryStandardNode(ClassBasedCrossEntropyWithSoftmax, labelClassDescriptorVectorSequence, mainInputInfo, mainWeight, classLogProbsBeforeSoftmax)
// BUGBUG: the commented-out ones are not mentioned in the CNTK book, nor are their parameters documented in the source code
BinaryStandardNode(ColumnElementTimes, aVectorSequence, anotherVectorSequence)

Просмотреть файл

@ -149,7 +149,8 @@ bool CheckFunction(std::string& p_nodeType, bool* allowUndeterminedVariable)
wstring nodeType = msra::strfun::utf16(p_nodeType);
bool ret = false;
if (EqualInsensitive(nodeType, OperationNameOf(AveragePoolingNode))) ret = true;
if (EqualInsensitive(nodeType, OperationNameOf(AbsNode))) ret = true;
else if (EqualInsensitive(nodeType, OperationNameOf(AveragePoolingNode))) ret = true;
else if (EqualInsensitive(nodeType, OperationNameOf(BatchNormalizationNode))) ret = true;
#ifdef COMING_SOON
else if (EqualInsensitive(nodeType, OperationNameOf(CRFNode), L"CRF")) ret = true;

Просмотреть файл

@ -37,7 +37,8 @@ static shared_ptr<ComputationNode<ElemType>> CreateStandardNode(const std::wstri
if (nodeType == OperationNameOf(CRFNode)) return New<CRFNode<ElemType>>(forward<_Types>(_Args)...);
else
#endif
if (nodeType == OperationNameOf(ClassBasedCrossEntropyWithSoftmaxNode))return New<ClassBasedCrossEntropyWithSoftmaxNode<ElemType>>(forward<_Types>(_Args)...);
if (nodeType == OperationNameOf(AbsNode)) return New<AbsNode<ElemType>>(forward<_Types>(_Args)...);
else if (nodeType == OperationNameOf(ClassBasedCrossEntropyWithSoftmaxNode))return New<ClassBasedCrossEntropyWithSoftmaxNode<ElemType>>(forward<_Types>(_Args)...);
else if (nodeType == OperationNameOf(CosDistanceNode)) return New<CosDistanceNode<ElemType>>(forward<_Types>(_Args)...);
else if (nodeType == OperationNameOf(CosDistanceWithNegativeSamplesNode)) return New<CosDistanceWithNegativeSamplesNode<ElemType>>(forward<_Types>(_Args)...);
else if (nodeType == OperationNameOf(CosineNode)) return New<CosineNode<ElemType>>(forward<_Types>(_Args)...);

Просмотреть файл

@ -77,6 +77,7 @@ public:
#ifdef COMING_SOON
ComputationNodePtr CRF(const ComputationNodePtr label, const ComputationNodePtr postDepScore, const ComputationNodePtr transition_score, const std::wstring nodeName = L"");
#endif
ComputationNodePtr Abs(const ComputationNodePtr a, const std::wstring nodeName = L"");
ComputationNodePtr ClassCrossEntropyWithSoftmax(const ComputationNodePtr label, const ComputationNodePtr prediction, const ComputationNodePtr input_weight, const ComputationNodePtr cls_log_post_prob, const std::wstring nodeName = L"");
ComputationNodePtr Cos(const ComputationNodePtr a, const std::wstring nodeName = L"");
ComputationNodePtr CosDistance(const ComputationNodePtr a, const ComputationNodePtr b, const std::wstring nodeName = L"");

Просмотреть файл

@ -117,6 +117,7 @@ DeclareUnaryElementWiseWithOpCodeNode(RectifiedLinear, LinearRectifier, Elementw
DeclareUnaryElementWiseWithOpCodeNode(Log, Log, ElementwiseProductWithLogDerivativeFromOutput, true);
DeclareUnaryElementWiseWithOpCodeNode(Exp, Exp, ElementwiseProduct, true);
DeclareUnaryElementWiseWithOpCodeNode(Cosine, Cosine, ElementwiseProductWithCosDerivative, false);
DeclareUnaryElementWiseWithOpCodeNode(Abs, Abs, ElementwiseProductWithAbsDerivative, false);
#pragma pop_macro("DeclareUnaryElementWiseWithOpCodeNode")

Просмотреть файл

@ -111,6 +111,7 @@ enum ElementWiseOperator
opElementwiseProductWithLinearRectifierDerivativeFromOutput,
opElementwiseProductWithLogDerivativeFromOutput,
opElementwiseProductWithCosDerivative,
opElementwiseProductWithAbsDerivative,
opSqrOfDifference,
// binary ops for indexing
// opIndex,
@ -160,8 +161,9 @@ enum ElementWiseOperator
Macro(ElementwiseProductWithTanhDerivativeFromOutput); \
Macro(ElementwiseProductWithLinearRectifierDerivativeFromOutput); \
Macro(ElementwiseProductWithLogDerivativeFromOutput); \
Macro(ElementwiseProductWithCosDerivative); \
Macro(SqrOfDifference); \
Macro(ElementwiseProductWithCosDerivative); \
Macro(ElementwiseProductWithAbsDerivative); \
Macro(SqrOfDifference); \
//Macro(Index);
#define ForAllTernaryOps(Macro) \

Просмотреть файл

@ -96,6 +96,14 @@ DECL ElemType LinearRectifierDerivative(ElemType z)
return z > 0 ? (ElemType) 1 : 0;
}
template <class ElemType>
DECL ElemType Sgn(ElemType z)
{
if (z > 0.0) return 1.0;
if (z < 0.0) return -1.0;
return z;
}
template <class ElemType>
DECL ElemType Sqr(ElemType z)
{
@ -227,6 +235,7 @@ DefBinaryOp(ElementwiseProductWithTanhDerivativeFromOutput, a*(1 - b * b));
DefBinaryOp(ElementwiseProductWithLinearRectifierDerivativeFromOutput, b > 0 ? a : 0);
DefBinaryOp(ElementwiseProductWithLogDerivativeFromOutput, a* exp_(-b));
DefBinaryOp(ElementwiseProductWithCosDerivative, a * -sin_(b)); // note: b = input for cos()
DefBinaryOp(ElementWideProductWithAbsDerivative, a * Sgn(b)); // note: b = input for abs()
DefBinaryOp(SqrOfDifference, Sqr(a - b));
//DefBinaryOp(Index, IndexElement(a, b, i)); // note: this one uses the third argument