moved all tensor ops to a new header TensorOps.h so they can be shared between matrix types;

also moved the float/double-unified math overloads (e.g. exp_()) there, as well as additional typically needed functions such as Sigmoid()
This commit is contained in:
Frank Seide 2015-12-18 08:54:19 -08:00
Родитель 625a877c29
Коммит cb668a9378
8 изменённых файлов: 206 добавлений и 115 удалений

Просмотреть файл

@ -9154,7 +9154,7 @@ L
\begin_layout Standard
\begin_inset Formula
\begin{eqnarray}
\alpha_{t}\left(i\right) & \leftarrow & h_{it}+logadd_{k}\left(\delta_{t-1}(k)+\eta a_{ki}\right)\\
\alpha_{t}\left(i\right) & \leftarrow & h_{it}+LogAdd{k}\left(\delta_{t-1}(k)+\eta a_{ki}\right)\\
\mathbf{\frac{\partial R}{\partial\delta_{t-1}(i)}} & \leftarrow & \sum_{j}\frac{\partial C_{logadd}}{\partial\delta_{t}(j)}\frac{\exp(\delta_{t-1}(i)+a_{i,j})}{\sum_{k}\exp(\delta_{t-1}(k)+a_{k,j})}\\
\mathbf{\frac{\partial R}{\partial\delta_{T}(i)}} & \leftarrow & \frac{\exp(\delta_{T}(i))}{\sum_{k}\exp(\delta_{T}(k))}\\
\frac{\partial R}{\partial h_{t}(i)} & \leftarrow & l_{t}(i)-\frac{\partial R}{\partial\delta_{t}(i)}\\

Просмотреть файл

@ -9,12 +9,13 @@
#include "stdafx.h"
#include "Basics.h"
#include "File.h"
#include "CPUMatrix.h"
#include "TensorOps.h"
#include <assert.h>
#include <stdexcept>
#include <omp.h>
#include <math.h>
#include "CPUMatrix.h"
#include <random>
#include <chrono>
#include <exception>
@ -4304,7 +4305,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
if (sample_id == 0)
sample_prob = -sample_prob;
double score_noise = log_num_noise_samples + sample_prob;
double z = logadd(score, score_noise);
double z = LogAdd(score, score_noise);
double logprob = score - z;
double logprob_noise = score_noise - z;
tmp(sample_id, instance_id) = (ElemType)-std::exp(logprob);
@ -5258,32 +5259,15 @@ namespace Microsoft { namespace MSR { namespace CNTK {
#pragma endregion Static BLAS Functions
template<typename ElemType>
ElemType logadd_(ElemType x, ElemType y)
{
if (x < y)
{
ElemType temp = x; x = y; y = temp;
}
ElemType diff = y - x;
if (diff < (ElemType)MINLOGEXP)
{
return (x < (ElemType)LSMALL) ? (ElemType)LZERO : x;
}
else
{
ElemType z = exp_(diff);
return x + log_((ElemType)1.0 + z);
}
}
double logadd(double x, double y) { return logadd_(x, y); }
// 'double' version of LogAdd
double LogAddD(double x, double y) { return LogAdd(x, y); }
template<class ElemType>
ElemType CPUMatrix<ElemType>::LogAddSumOfElements() const
{
ElemType fAlpha = (ElemType)LZERO;
for (int k = 0; k < GetNumElements(); k++)
fAlpha = (ElemType) logadd(fAlpha, m_pArray[k]);
fAlpha = (ElemType) LogAddD(fAlpha, m_pArray[k]);
return fAlpha;
}
@ -5330,7 +5314,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
fSum = (ElemType)LZERO;
for (int j = 0; j < iNumLab; j++)
{
fSum = (ElemType)logadd((double)fSum, alpha(j, t));
fSum = (ElemType)LogAddD(fSum, alpha(j, t));
}
fTmp = alpha(k, t) - fSum;
@ -5343,10 +5327,10 @@ namespace Microsoft { namespace MSR { namespace CNTK {
fSum = (ElemType)LZERO;
for (int m = 0; m < iNumLab; m++)
{
fSum = (ElemType)logadd((double)fSum, alpha(m, t) + pair_scores(j, m));
fSum = (ElemType)LogAddD(fSum, alpha(m, t) + pair_scores(j, m));
}
fTmp = (ElemType)logadd(fTmp, beta(j, t + 1) + alpha(k, t) + pair_scores(j, k) - fSum);
fTmp = (ElemType)LogAddD(fTmp, beta(j, t + 1) + alpha(k, t) + pair_scores(j, k) - fSum);
}
beta(k, t) = fTmp;
}
@ -5455,7 +5439,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
else{
fTmp2 = a(k, 0);
}
fSum = (ElemType)logadd(fSum, fTmp2 + pair_scores(j, k));
fSum = (ElemType)LogAddD(fSum, fTmp2 + pair_scores(j, k));
}
fTmp -= fSum;
@ -5537,6 +5521,8 @@ namespace Microsoft { namespace MSR { namespace CNTK {
// TensorView support
// -----------------------------------------------------------------------
// To save time, this makes extensive use of templates and macros.
// perform loop over reduction index m
// This function is declared inside a wrapper struct to allow partial specialization (m = -1).
template<class ElemType, size_t N, typename OPFN, int m>
@ -5654,43 +5640,6 @@ namespace Microsoft { namespace MSR { namespace CNTK {
}
}
template<class ElemType>
static inline ElemType Sigmoid(ElemType z)
{
if (z >= 0)
return 1 / (1 + exp_(-z));
else
{
ElemType v = exp_(z);
return v / (1 + v);
}
}
template<class ElemType>
static inline ElemType SigmoidDerivative(ElemType z)
{
ElemType v = Sigmoid(z);
return v * (1 - v);
}
template<class ElemType>
static inline ElemType LinearRectifierDerivative(ElemType z)
{
return z > 0 ? (ElemType)1 : 0;
}
template<class ElemType>
static inline ElemType Sqrt(ElemType z)
{
return sqrt_(max(0, z));
}
// define a static function for every operation
#define DefUnaryOp(op, expr) template<class ElemType> static inline ElemType Op ## op(ElemType a) { return expr; }
DefUnaryOp(Copy, a);
DefUnaryOp(Negate, -a); DefUnaryOp(Not, !a);
DefUnaryOp(Abs, fabs_(a));
DefUnaryOp(Sigmoid, Sigmoid(a)); DefUnaryOp(SigmoidDerivative, SigmoidDerivative(a)); DefUnaryOp(Tanh, tanh_(a)); DefUnaryOp(Sqrt, Sqrt(a)); DefUnaryOp(Exp, exp_(a)); DefUnaryOp(Log, log_(a)); DefUnaryOp(LinearRectifierDerivative, LinearRectifierDerivative(a)); DefUnaryOp(Cosine, cos_(a)); DefUnaryOp(NegativeSine, -sin_(a));
//DefUnaryOp(SaturateBetaAlpha); DefUnaryOp(SumAlpha); DefUnaryOp(SubDifferenceToAlpha); DefUnaryOp(SubDifferenceFromAlpha);
// perform unary operation 'op' on a giving 'this', reinterpreting the matrices as tensors as specified by the dims and strides
// This maps 'op' to a lambda.
template<class ElemType>
@ -5699,29 +5648,18 @@ namespace Microsoft { namespace MSR { namespace CNTK {
const std::vector<size_t> & regularOpDims, const std::array<std::vector<ptrdiff_t>, 2> & regularStrides,
const std::vector<size_t> & reducingOpDims, const std::array<std::vector<ptrdiff_t>, 2> & reducingStrides)
{
array<ElemType*, 2> pointers = { a.m_pArray, m_pArray };
#define CaseUnaryTensorOp(oper) \
case ElementWiseOperator::op ## oper: \
return TensorOpWithFn(beta, pointers, alpha, [](const array<ElemType*, 2> & pp) { return Op ## oper((*(pp[0]))); }, offsets, regularOpDims, regularStrides, reducingOpDims, reducingStrides)
array<ElemType*, 2> pointers = { a.m_pArray, m_pArray };
switch (op)
{
CaseUnaryTensorOp(Copy);
CaseUnaryTensorOp(Negate); CaseUnaryTensorOp(Not);
CaseUnaryTensorOp(Abs);
CaseUnaryTensorOp(Sigmoid); CaseUnaryTensorOp(SigmoidDerivative); CaseUnaryTensorOp(Tanh); CaseUnaryTensorOp(Sqrt); CaseUnaryTensorOp(Exp); CaseUnaryTensorOp(Log); CaseUnaryTensorOp(LinearRectifierDerivative); CaseUnaryTensorOp(Cosine); CaseUnaryTensorOp(NegativeSine);
// functions with lambda arguments--these are different
//CaseUnaryTensorOp(SaturateBetaAlpha); CaseUnaryTensorOp(SumAlpha); CaseUnaryTensorOp(SubDifferenceToAlpha); CaseUnaryTensorOp(SubDifferenceFromAlpha);
ForAllUnaryOps(CaseUnaryTensorOp);
default: LogicError("TensorUnaryOp: Unknown op code %d.", (int)op);
}
}
// define a static function for every operation
#define DefBinaryOp(op, expr) template<class ElemType> static inline ElemType Op ## op(ElemType a, ElemType b) { return expr; }
DefBinaryOp(Sum, a + b); DefBinaryOp(Difference, a - b); DefBinaryOp(ElementWiseProduct, a*b); DefBinaryOp(ElementWiseQuotient, a / b);
DefBinaryOp(LogSum, logadd_(a, b)); DefBinaryOp(Max, a > b ? a : b); DefBinaryOp(Min, a < b ? a : b);
DefBinaryOp(EQ, a == b); DefBinaryOp(NE, a != b); DefBinaryOp(GT, a > b); DefBinaryOp(LT, a < b); DefBinaryOp(GE, a >= b); DefBinaryOp(LE, a <= b);
// perform binary operation 'op' on a and b giving 'this', reinterpreting the matrices as tensors as specified by the dims and strides
// This maps 'op' to a lambda.
template<class ElemType>
@ -5730,24 +5668,18 @@ namespace Microsoft { namespace MSR { namespace CNTK {
const std::vector<size_t> & regularOpDims, const std::array<std::vector<ptrdiff_t>, 3> & regularStrides,
const std::vector<size_t> & reducingOpDims, const std::array<std::vector<ptrdiff_t>, 3> & reducingStrides)
{
array<ElemType*, 3> pointers = { a.m_pArray, b.m_pArray, m_pArray };
#define CaseBinaryTensorOp(oper) \
case ElementWiseOperator::op ## oper: \
return TensorOpWithFn(beta, pointers, alpha, [](const array<ElemType*, 3> & pp) { return Op ## oper((*(pp[0])), (*(pp[1]))); }, offsets, regularOpDims, regularStrides, reducingOpDims, reducingStrides)
array<ElemType*, 3> pointers = { a.m_pArray, b.m_pArray, m_pArray };
switch (op)
{
CaseBinaryTensorOp(Sum); CaseBinaryTensorOp(Difference); CaseBinaryTensorOp(ElementWiseProduct); CaseBinaryTensorOp(ElementWiseQuotient);
CaseBinaryTensorOp(LogSum); CaseBinaryTensorOp(Max); CaseBinaryTensorOp(Min);
CaseBinaryTensorOp(EQ); CaseBinaryTensorOp(NE); CaseBinaryTensorOp(GT); CaseBinaryTensorOp(LT); CaseBinaryTensorOp(GE); CaseBinaryTensorOp(LE);
ForAllBinaryOps(CaseBinaryTensorOp);
default: LogicError("TensorBinaryOp: Unknown op code %d.", (int)op);
}
}
// define a static function for every operation
#define DefTernaryOp(op, expr) template<class ElemType> static inline ElemType Op ## op(ElemType a, ElemType b, ElemType c) { return expr; }
DefTernaryOp(Cond, a ? b : c);
// perform ternary operation 'op' on a, and c giving 'this', reinterpreting the matrices as tensors as specified by the dims and strides
// This maps 'op' to a lambda.
template<class ElemType>
@ -5756,18 +5688,22 @@ namespace Microsoft { namespace MSR { namespace CNTK {
const std::vector<size_t> & regularOpDims, const std::array<std::vector<ptrdiff_t>, 4> & regularStrides,
const std::vector<size_t> & reducingOpDims, const std::array<std::vector<ptrdiff_t>, 4> & reducingStrides)
{
array<ElemType*, 4> pointers = { a.m_pArray, b.m_pArray, c.m_pArray, m_pArray };
#define CaseTernaryTensorOp(oper) \
case ElementWiseOperator::op ## oper: \
return TensorOpWithFn(beta, pointers, alpha, [](const array<ElemType*, 4> & pp) { return Op ## oper((*(pp[0])), (*(pp[1])), (*(pp[2]))); }, offsets, regularOpDims, regularStrides, reducingOpDims, reducingStrides)
array<ElemType*, 4> pointers = { a.m_pArray, b.m_pArray, c.m_pArray, m_pArray };
switch (op)
{
CaseTernaryTensorOp(Cond);
ForAllTernaryOps(CaseTernaryTensorOp);
default: LogicError("TensorTernaryOp: Unknown op code %d.", (int)op);
}
}
// The explicit instantiation part
// -----------------------------------------------------------------------
// explicit instantiations
// -----------------------------------------------------------------------
template class MATH_API CPUMatrix<float>;
template class MATH_API CPUMatrix<double>;

Просмотреть файл

@ -64,15 +64,23 @@ namespace Microsoft { namespace MSR { namespace CNTK {
// Note: not all of the above are actually implement at present; and not all that's implemented has an opcode.
};
// declare float and double versions of a func under f_
// e.g. exp_ -> exp(double), expf(float)
#define OverloadUnaryMathFns(func) \
static inline float func ## _(float arg) { return func ## f(arg); } \
static inline double func ## _(double arg) { return func(arg); }
// helper to apply a C macro for all operations of each kind
#define ForAllUnaryOps(Macro) \
Macro(Copy); \
Macro(Negate); Macro(Not); \
Macro(Abs); \
Macro(Sigmoid); Macro(SigmoidDerivative); Macro(Tanh); Macro(Sqrt); Macro(Exp); Macro(Log); Macro(LinearRectifierDerivative); Macro(Cosine); Macro(NegativeSine);
OverloadUnaryMathFns(fabs); OverloadUnaryMathFns(sqrt);
OverloadUnaryMathFns(exp); OverloadUnaryMathFns(log);
OverloadUnaryMathFns(tanh); OverloadUnaryMathFns(cos); OverloadUnaryMathFns(sin);
#define ForAllParameterizedUnaryOps(Macro) \
Macro(SaturateBetaAlpha); Macro(SumAlpha); Macro(SubDifferenceToAlpha); Macro(SubDifferenceFromAlpha);
#define ForAllBinaryOps(Macro) \
Macro(Sum); Macro(Difference); Macro(ElementWiseProduct); Macro(ElementWiseQuotient); \
Macro(LogSum); Macro(Max); Macro(Min); \
Macro(EQ); Macro(NE); Macro(GT); Macro(LT); Macro(GE); Macro(LE);
#define ForAllTernaryOps(Macro) \
Macro(Cond);
// -----------------------------------------------------------------------
// various enums to describe

Просмотреть файл

@ -162,6 +162,7 @@
<ClInclude Include="CommonMatrix.h" />
<ClInclude Include="ConvolutionEngine.h" />
<ClInclude Include="CPUMatrix.h" />
<ClInclude Include="TensorOps.h" />
<ClInclude Include="TensorView.h" />
<None Include="ClassDiagram.cd" />
<None Include="GPUWatcher.cu" />

Просмотреть файл

@ -70,6 +70,9 @@
<ClInclude Include="TensorView.h">
<Filter>Tensors</Filter>
</ClInclude>
<ClInclude Include="TensorOps.h">
<Filter>Tensors</Filter>
</ClInclude>
</ItemGroup>
<ItemGroup>
<None Include="GPUMatrix.h">

Просмотреть файл

@ -4887,6 +4887,8 @@ namespace Microsoft { namespace MSR { namespace CNTK {
return x - y * floor(x / y);
}
// TODO: use static LogAdd() as defined in TensorOps.h
// Not doing this currently because that one uses ElemType for all ops, while this one uses double inside. Must compare before making this change.
template<class ElemType>
ElemType Matrix<ElemType>::LogAdd(ElemType x, ElemType y)
{

132
Source/Math/TensorOps.h Normal file
Просмотреть файл

@ -0,0 +1,132 @@
//
// <copyright file="TensorView.h" company="Microsoft">
// Copyright (c) Microsoft Corporation. All rights reserved.
// </copyright>
//
// This implements the elementwise tensor operations, including helper macros and some actual functions.
#pragma once
#include "Basics.h"
#include "CommonMatrix.h"
#pragma push_macro("TENSOR_OPS_DECL")
#ifndef TENSOR_OPS_DECL // to make these accessible to CUDA kernels, say '#define TENSOR_OPS_DECL __device__ __host__'
#define TENSOR_OPS_DECL
#endif
#pragma push_macro("DECL")
#define DECL static inline TENSOR_OPS_DECL
// This class is exported from the Math.dll.
namespace Microsoft { namespace MSR { namespace CNTK {
// -----------------------------------------------------------------------
// unified overloads for float/double math functions
//
// Declare float and double versions of the functions f we need as f_(),
// e.g. exp_ -> exp(double), expf(float).
// -----------------------------------------------------------------------
#pragma push_macro("OverloadUnaryMathFns")
#define OverloadUnaryMathFns(func) \
DECL float func ## _(float arg) { return func ## f(arg); } \
DECL double func ## _(double arg) { return func(arg); }
OverloadUnaryMathFns(fabs); OverloadUnaryMathFns(sqrt);
OverloadUnaryMathFns(exp); OverloadUnaryMathFns(log);
OverloadUnaryMathFns(tanh); OverloadUnaryMathFns(cos); OverloadUnaryMathFns(sin);
#pragma push_macro("OverloadUnaryMathFns")
// -----------------------------------------------------------------------
// additional functions that are standard in our context
// -----------------------------------------------------------------------
template<class ElemType>
DECL ElemType Sigmoid(ElemType z)
{
if (z >= 0)
return 1 / (1 + exp_(-z));
else
{
ElemType v = exp_(z);
return v / (1 + v);
}
}
template<class ElemType>
DECL ElemType SigmoidDerivative(ElemType z)
{
ElemType v = Sigmoid(z);
return v * (1 - v);
}
template<class ElemType>
DECL ElemType LinearRectifierDerivative(ElemType z)
{
return z > 0 ? (ElemType)1 : 0;
}
template<class ElemType>
DECL ElemType Sqrt(ElemType z)
{
// BUGBUG: Why clip to 0? An invalid sqrt() should show up as a NaN in the result, instead of hiding it.
return sqrt_(z > 0 ? z : 0);
}
// TODO: call this LogAdd() for consistency
template<typename ElemType>
DECL ElemType LogAdd(ElemType x, ElemType y)
{
if (x < y)
{
ElemType temp = x; x = y; y = temp;
}
ElemType diff = y - x;
if (diff < (ElemType)MINLOGEXP)
{
return (x < (ElemType)LSMALL) ? (ElemType)LZERO : x;
}
else
{
ElemType z = exp_(diff);
return x + log_((ElemType)1.0 + z);
}
}
// -----------------------------------------------------------------------
// ElementWiseOperator implementations
//
// Define a static function for every ElementWiseOperator (CommonMatrix.h).
// -----------------------------------------------------------------------
#pragma push_macro("DefUnaryOp")
#define DefUnaryOp(op, expr) template<class ElemType> DECL ElemType Op ## op(ElemType a) { return expr; }
DefUnaryOp(Copy, a);
DefUnaryOp(Negate, -a); DefUnaryOp(Not, !a);
DefUnaryOp(Abs, fabs_(a));
DefUnaryOp(Sigmoid, Sigmoid(a)); DefUnaryOp(SigmoidDerivative, SigmoidDerivative(a)); DefUnaryOp(Tanh, tanh_(a)); DefUnaryOp(Sqrt, Sqrt(a)); DefUnaryOp(Exp, exp_(a)); DefUnaryOp(Log, log_(a)); DefUnaryOp(LinearRectifierDerivative, LinearRectifierDerivative(a)); DefUnaryOp(Cosine, cos_(a)); DefUnaryOp(NegativeSine, -sin_(a));
#pragma pop_macro("DefUnaryOp")
// parameterized unary ops
//DefUnaryOp(SaturateBetaAlpha); DefUnaryOp(SumAlpha); DefUnaryOp(SubDifferenceToAlpha); DefUnaryOp(SubDifferenceFromAlpha);
#pragma push_macro("DefBinaryOp")
#define DefBinaryOp(op, expr) template<class ElemType> DECL ElemType Op ## op(ElemType a, ElemType b) { return expr; }
DefBinaryOp(Sum, a + b); DefBinaryOp(Difference, a - b); DefBinaryOp(ElementWiseProduct, a*b); DefBinaryOp(ElementWiseQuotient, a / b);
DefBinaryOp(LogSum, LogAdd(a, b)); DefBinaryOp(Max, a > b ? a : b); DefBinaryOp(Min, a < b ? a : b);
DefBinaryOp(EQ, a == b); DefBinaryOp(NE, a != b); DefBinaryOp(GT, a > b); DefBinaryOp(LT, a < b); DefBinaryOp(GE, a >= b); DefBinaryOp(LE, a <= b);
#pragma pop_macro("DefBinaryOp")
#pragma push_macro("DefTernaryOp")
#define DefTernaryOp(op, expr) template<class ElemType> DECL ElemType Op ## op(ElemType a, ElemType b, ElemType c) { return expr; }
DefTernaryOp(Cond, a ? b : c);
#pragma pop_macro("DefTernaryOp")
}}}
#pragma pop_macro("DECL")
#pragma pop_macro("TENSOR_OPS_DECL")

Просмотреть файл

@ -4,7 +4,7 @@
// </copyright>
//
// This implements the TensorView class, which is a layer around Matrix that reinterprets its content as a generic tensor.
// This implements the TensorView class, which is a layer around Matrix that reinterprets its content as a generic tensor. [fseide]
#pragma once
@ -56,26 +56,35 @@ namespace Microsoft { namespace MSR { namespace CNTK {
// If beta == 0, c is not read out, i.e. it can be uninitialized or contain NaNs.
// -------------------------------------------------------------------
#pragma push_macro("DeclareUnaryTensorOp")
#define DeclareUnaryTensorOp(oper) \
void Do ## oper ## Of(ElemType beta, const TensorView & a, ElemType alpha) { DoUnaryOpOf(beta, a, alpha, ElementWiseOperator::op ## oper); }
DeclareUnaryTensorOp(Copy);
DeclareUnaryTensorOp(Negate); DeclareUnaryTensorOp(Not);
DeclareUnaryTensorOp(Abs);
DeclareUnaryTensorOp(Sigmoid); DeclareUnaryTensorOp(SigmoidDerivative); DeclareUnaryTensorOp(Tanh); DeclareUnaryTensorOp(Sqrt); DeclareUnaryTensorOp(Exp); DeclareUnaryTensorOp(Log); DeclareUnaryTensorOp(LinearRectifierDerivative); DeclareUnaryTensorOp(Cosine); DeclareUnaryTensorOp(NegativeSine);
DeclareUnaryTensorOp(SaturateBetaAlpha); DeclareUnaryTensorOp(SumAlpha); DeclareUnaryTensorOp(SubDifferenceToAlpha); DeclareUnaryTensorOp(SubDifferenceFromAlpha);
ForAllUnaryOps(DeclareUnaryTensorOp);
ForAllParameterizedUnaryOps(DeclareUnaryTensorOp);
//DeclareUnaryTensorOp(Copy);
//DeclareUnaryTensorOp(Negate); DeclareUnaryTensorOp(Not);
//DeclareUnaryTensorOp(Abs);
//DeclareUnaryTensorOp(Sigmoid); DeclareUnaryTensorOp(SigmoidDerivative); DeclareUnaryTensorOp(Tanh); DeclareUnaryTensorOp(Sqrt); DeclareUnaryTensorOp(Exp); DeclareUnaryTensorOp(Log); DeclareUnaryTensorOp(LinearRectifierDerivative); DeclareUnaryTensorOp(Cosine); DeclareUnaryTensorOp(NegativeSine);
//DeclareUnaryTensorOp(SaturateBetaAlpha); DeclareUnaryTensorOp(SumAlpha); DeclareUnaryTensorOp(SubDifferenceToAlpha); DeclareUnaryTensorOp(SubDifferenceFromAlpha);
#pragma pop_macro("DeclareUnaryTensorOp")
#pragma push_macro("DeclareBinaryTensorOp")
#define DeclareBinaryTensorOp(oper) \
void Do ## oper ## Of(ElemType beta, const TensorView & a, const TensorView & b, ElemType alpha) { DoBinaryOpOf(beta, a, b, alpha, ElementWiseOperator::op ## oper); }
DeclareBinaryTensorOp(Sum); DeclareBinaryTensorOp(Difference); DeclareBinaryTensorOp(ElementWiseProduct); DeclareBinaryTensorOp(ElementWiseQuotient);
DeclareBinaryTensorOp(LogSum); DeclareBinaryTensorOp(Max); DeclareBinaryTensorOp(Min);
DeclareBinaryTensorOp(EQ); DeclareBinaryTensorOp(NE); DeclareBinaryTensorOp(GT); DeclareBinaryTensorOp(LT); DeclareBinaryTensorOp(GE); DeclareBinaryTensorOp(LE);
ForAllBinaryOps(DeclareBinaryTensorOp);
//DeclareBinaryTensorOp(Sum); DeclareBinaryTensorOp(Difference); DeclareBinaryTensorOp(ElementWiseProduct); DeclareBinaryTensorOp(ElementWiseQuotient);
//DeclareBinaryTensorOp(LogSum); DeclareBinaryTensorOp(Max); DeclareBinaryTensorOp(Min);
//DeclareBinaryTensorOp(EQ); DeclareBinaryTensorOp(NE); DeclareBinaryTensorOp(GT); DeclareBinaryTensorOp(LT); DeclareBinaryTensorOp(GE); DeclareBinaryTensorOp(LE);
#pragma pop_macro("DeclareBinaryTensorOp")
#pragma push_macro("DeclareTernaryTensorOp")
#define DeclareTernaryTensorOp(oper) \
void Do ## oper ## Of(ElemType beta, const TensorView & a, const TensorView & b, const TensorView & c, ElemType alpha) { DoTernaryOpOf(beta, a, b, c, alpha, ElementWiseOperator::op ## oper); }
DeclareTernaryTensorOp(Cond);
ForAllTernaryOps(DeclareTernaryTensorOp);
#pragma pop_macro("DeclareTernaryTensorOp")
static void Test();