moved all tensor ops to a new header TensorOps.h so they can be shared between matrix types;
also moved the float/double-unified math overloads (e.g. exp_()) there, as well as additional commonly needed functions such as Sigmoid()

parent 625a877c29
commit cb668a9378
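
To illustrate the float/double-unified overload trick the message refers to, here is a minimal, self-contained sketch of the pattern (the real definitions now live in the TensorOps.h hunk further down):

#include <math.h>
#include <stdio.h>

// generate float/double overloads of a C math function under the name func_,
// e.g. exp_ -> expf(float) / exp(double)
#define OverloadUnaryMathFns(func) \
    static inline float  func ## _(float  arg) { return func ## f(arg); } \
    static inline double func ## _(double arg) { return func(arg); }

OverloadUnaryMathFns(exp); OverloadUnaryMathFns(log);

// templated code can now call exp_() without caring which precision it runs in
template<class ElemType>
ElemType Sigmoid(ElemType z) { return 1 / (1 + exp_(-z)); }

int main()
{
    printf("%g %g\n", Sigmoid(0.0f), Sigmoid(0.0)); // both print 0.5
    return 0;
}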
@@ -9154,7 +9154,7 @@
 \begin_layout Standard
 \begin_inset Formula
 \begin{eqnarray}
-\alpha_{t}\left(i\right) & \leftarrow & h_{it}+logadd_{k}\left(\delta_{t-1}(k)+\eta a_{ki}\right)\\
+\alpha_{t}\left(i\right) & \leftarrow & h_{it}+LogAdd_{k}\left(\delta_{t-1}(k)+\eta a_{ki}\right)\\
 \mathbf{\frac{\partial R}{\partial\delta_{t-1}(i)}} & \leftarrow & \sum_{j}\frac{\partial C_{logadd}}{\partial\delta_{t}(j)}\frac{\exp(\delta_{t-1}(i)+a_{i,j})}{\sum_{k}\exp(\delta_{t-1}(k)+a_{k,j})}\\
 \mathbf{\frac{\partial R}{\partial\delta_{T}(i)}} & \leftarrow & \frac{\exp(\delta_{T}(i))}{\sum_{k}\exp(\delta_{T}(k))}\\
 \frac{\partial R}{\partial h_{t}(i)} & \leftarrow & l_{t}(i)-\frac{\partial R}{\partial\delta_{t}(i)}\\
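
For reference, the LogAdd (log-sum-exp) operation these formulas and the code below rely on is, in the overflow-safe form used throughout:

\mathrm{LogAdd}(x,y)=\log\left(e^{x}+e^{y}\right)=\max(x,y)+\log\left(1+e^{-\left|x-y\right|}\right)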
@@ -9,12 +9,13 @@
 #include "stdafx.h"
 #include "Basics.h"
 #include "File.h"
+#include "CPUMatrix.h"
+#include "TensorOps.h"
 #include <assert.h>
 #include <stdexcept>
 #include <omp.h>
 #include <math.h>
-#include "CPUMatrix.h"
 #include <random>
 #include <chrono>
 #include <exception>
 
@@ -4304,7 +4305,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
                 if (sample_id == 0)
                     sample_prob = -sample_prob;
                 double score_noise = log_num_noise_samples + sample_prob;
-                double z = logadd(score, score_noise);
+                double z = LogAdd(score, score_noise);
                 double logprob = score - z;
                 double logprob_noise = score_noise - z;
                 tmp(sample_id, instance_id) = (ElemType)-std::exp(logprob);
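
Aside: z here is the log of the total (true + noise) score mass, so the two logprob lines form a two-way log-softmax; writing s for score and s_n for score_noise:

z=\log\left(e^{s}+e^{s_{n}}\right),\qquad \log p=s-z,\qquad \log p_{n}=s_{n}-z,\qquad p+p_{n}=1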
@@ -5258,32 +5259,15 @@ namespace Microsoft { namespace MSR { namespace CNTK {
 
 #pragma endregion Static BLAS Functions
 
-    template<typename ElemType>
-    ElemType logadd_(ElemType x, ElemType y)
-    {
-        if (x < y)
-        {
-            ElemType temp = x; x = y; y = temp;
-        }
-        ElemType diff = y - x;
-        if (diff < (ElemType)MINLOGEXP)
-        {
-            return (x < (ElemType)LSMALL) ? (ElemType)LZERO : x;
-        }
-        else
-        {
-            ElemType z = exp_(diff);
-            return x + log_((ElemType)1.0 + z);
-        }
-    }
-    double logadd(double x, double y) { return logadd_(x, y); }
+    // 'double' version of LogAdd
+    double LogAddD(double x, double y) { return LogAdd(x, y); }
 
     template<class ElemType>
     ElemType CPUMatrix<ElemType>::LogAddSumOfElements() const
     {
         ElemType fAlpha = (ElemType)LZERO;
         for (int k = 0; k < GetNumElements(); k++)
-            fAlpha = (ElemType) logadd(fAlpha, m_pArray[k]);
+            fAlpha = (ElemType) LogAddD(fAlpha, m_pArray[k]);
         return fAlpha;
     }
 
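
Note what LogAddD preserves: the CRF code below is templated on ElemType, but accumulation deliberately stays in double (the TODO on Matrix::LogAdd further down makes the same point). A sketch of the effect, with a hypothetical caller and a simplified body standing in for the TensorOps.h one:

#include <cmath>

// templated version, as now defined in TensorOps.h (body simplified here)
template<class ElemType>
ElemType LogAdd(ElemType x, ElemType y)
{
    return x > y ? x + std::log(1 + std::exp(y - x))
                 : y + std::log(1 + std::exp(x - y));
}

// 'double' version: both arguments convert up to double, so the exp/log
// inside run in double even when the caller is instantiated for float
double LogAddD(double x, double y) { return LogAdd(x, y); }

template<class ElemType>
ElemType LogSumOf(const ElemType* p, int n) // hypothetical caller
{
    double fSum = -1e10;                    // LZERO-style log-domain floor
    for (int i = 0; i < n; i++)
        fSum = LogAddD(fSum, p[i]);         // float inputs promoted to double
    return (ElemType)fSum;
}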
@@ -5330,7 +5314,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
             fSum = (ElemType)LZERO;
             for (int j = 0; j < iNumLab; j++)
             {
-                fSum = (ElemType)logadd((double)fSum, alpha(j, t));
+                fSum = (ElemType)LogAddD(fSum, alpha(j, t));
             }
 
             fTmp = alpha(k, t) - fSum;
@@ -5343,10 +5327,10 @@ namespace Microsoft { namespace MSR { namespace CNTK {
                 fSum = (ElemType)LZERO;
                 for (int m = 0; m < iNumLab; m++)
                 {
-                    fSum = (ElemType)logadd((double)fSum, alpha(m, t) + pair_scores(j, m));
+                    fSum = (ElemType)LogAddD(fSum, alpha(m, t) + pair_scores(j, m));
                 }
 
-                fTmp = (ElemType)logadd(fTmp, beta(j, t + 1) + alpha(k, t) + pair_scores(j, k) - fSum);
+                fTmp = (ElemType)LogAddD(fTmp, beta(j, t + 1) + alpha(k, t) + pair_scores(j, k) - fSum);
             }
             beta(k, t) = fTmp;
         }
@@ -5455,7 +5439,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
             else{
                 fTmp2 = a(k, 0);
             }
-            fSum = (ElemType)logadd(fSum, fTmp2 + pair_scores(j, k));
+            fSum = (ElemType)LogAddD(fSum, fTmp2 + pair_scores(j, k));
         }
 
         fTmp -= fSum;
@@ -5537,6 +5521,8 @@ namespace Microsoft { namespace MSR { namespace CNTK {
     // TensorView support
     // -----------------------------------------------------------------------
 
+    // To save time, this makes extensive use of templates and macros.
+
     // perform loop over reduction index m
     // This function is declared inside a wrapper struct to allow partial specialization (m = -1).
     template<class ElemType, size_t N, typename OPFN, int m>
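
The "wrapper struct" comment names a standard C++ workaround: function templates cannot be partially specialized, so the recursive loop function is wrapped in a class template and the m = -1 base case is a partial specialization of that class. A toy sketch of the idiom (not the commit's actual loop body):

#include <stdio.h>

template<class ElemType, int m>
struct TensorOpReduction                   // general case: recurse over reduction index m
{
    static ElemType Loop() { return TensorOpReduction<ElemType, m - 1>::Loop() + 1; }
};

template<class ElemType>
struct TensorOpReduction<ElemType, -1>     // partial specialization: terminate at m = -1
{
    static ElemType Loop() { return 0; }
};

int main()
{
    printf("%g\n", (double)TensorOpReduction<float, 2>::Loop()); // prints 3
    return 0;
}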
@@ -5654,43 +5640,6 @@ namespace Microsoft { namespace MSR { namespace CNTK {
         }
     }
 
-    template<class ElemType>
-    static inline ElemType Sigmoid(ElemType z)
-    {
-        if (z >= 0)
-            return 1 / (1 + exp_(-z));
-        else
-        {
-            ElemType v = exp_(z);
-            return v / (1 + v);
-        }
-    }
-    template<class ElemType>
-    static inline ElemType SigmoidDerivative(ElemType z)
-    {
-        ElemType v = Sigmoid(z);
-        return v * (1 - v);
-    }
-    template<class ElemType>
-    static inline ElemType LinearRectifierDerivative(ElemType z)
-    {
-        return z > 0 ? (ElemType)1 : 0;
-    }
-    template<class ElemType>
-    static inline ElemType Sqrt(ElemType z)
-    {
-        return sqrt_(max(0, z));
-    }
-
-    // define a static function for every operation
-    #define DefUnaryOp(op, expr) template<class ElemType> static inline ElemType Op ## op(ElemType a) { return expr; }
-
-    DefUnaryOp(Copy, a);
-    DefUnaryOp(Negate, -a); DefUnaryOp(Not, !a);
-    DefUnaryOp(Abs, fabs_(a));
-    DefUnaryOp(Sigmoid, Sigmoid(a)); DefUnaryOp(SigmoidDerivative, SigmoidDerivative(a)); DefUnaryOp(Tanh, tanh_(a)); DefUnaryOp(Sqrt, Sqrt(a)); DefUnaryOp(Exp, exp_(a)); DefUnaryOp(Log, log_(a)); DefUnaryOp(LinearRectifierDerivative, LinearRectifierDerivative(a)); DefUnaryOp(Cosine, cos_(a)); DefUnaryOp(NegativeSine, -sin_(a));
-    //DefUnaryOp(SaturateBetaAlpha); DefUnaryOp(SumAlpha); DefUnaryOp(SubDifferenceToAlpha); DefUnaryOp(SubDifferenceFromAlpha);
-
     // perform unary operation 'op' on a giving 'this', reinterpreting the matrices as tensors as specified by the dims and strides
     // This maps 'op' to a lambda.
     template<class ElemType>
@@ -5699,29 +5648,18 @@ namespace Microsoft { namespace MSR { namespace CNTK {
         const std::vector<size_t> & regularOpDims, const std::array<std::vector<ptrdiff_t>, 2> & regularStrides,
         const std::vector<size_t> & reducingOpDims, const std::array<std::vector<ptrdiff_t>, 2> & reducingStrides)
     {
+#define CaseUnaryTensorOp(oper) \
+        case ElementWiseOperator::op ## oper: \
+            return TensorOpWithFn(beta, pointers, alpha, [](const array<ElemType*, 2> & pp) { return Op ## oper((*(pp[0]))); }, offsets, regularOpDims, regularStrides, reducingOpDims, reducingStrides)
+
         array<ElemType*, 2> pointers = { a.m_pArray, m_pArray };
-#define CaseUnaryTensorOp(oper) \
-        case ElementWiseOperator::op ## oper: \
-            return TensorOpWithFn(beta, pointers, alpha, [](const array<ElemType*, 2> & pp) { return Op ## oper((*(pp[0]))); }, offsets, regularOpDims, regularStrides, reducingOpDims, reducingStrides)
         switch (op)
         {
-        CaseUnaryTensorOp(Copy);
-        CaseUnaryTensorOp(Negate); CaseUnaryTensorOp(Not);
-        CaseUnaryTensorOp(Abs);
-        CaseUnaryTensorOp(Sigmoid); CaseUnaryTensorOp(SigmoidDerivative); CaseUnaryTensorOp(Tanh); CaseUnaryTensorOp(Sqrt); CaseUnaryTensorOp(Exp); CaseUnaryTensorOp(Log); CaseUnaryTensorOp(LinearRectifierDerivative); CaseUnaryTensorOp(Cosine); CaseUnaryTensorOp(NegativeSine);
-        // functions with lambda arguments--these are different
-        //CaseUnaryTensorOp(SaturateBetaAlpha); CaseUnaryTensorOp(SumAlpha); CaseUnaryTensorOp(SubDifferenceToAlpha); CaseUnaryTensorOp(SubDifferenceFromAlpha);
+        ForAllUnaryOps(CaseUnaryTensorOp);
         default: LogicError("TensorUnaryOp: Unknown op code %d.", (int)op);
         }
     }
 
-    // define a static function for every operation
-#define DefBinaryOp(op, expr) template<class ElemType> static inline ElemType Op ## op(ElemType a, ElemType b) { return expr; }
-
-    DefBinaryOp(Sum, a + b); DefBinaryOp(Difference, a - b); DefBinaryOp(ElementWiseProduct, a*b); DefBinaryOp(ElementWiseQuotient, a / b);
-    DefBinaryOp(LogSum, logadd_(a, b)); DefBinaryOp(Max, a > b ? a : b); DefBinaryOp(Min, a < b ? a : b);
-    DefBinaryOp(EQ, a == b); DefBinaryOp(NE, a != b); DefBinaryOp(GT, a > b); DefBinaryOp(LT, a < b); DefBinaryOp(GE, a >= b); DefBinaryOp(LE, a <= b);
-
     // perform binary operation 'op' on a and b giving 'this', reinterpreting the matrices as tensors as specified by the dims and strides
     // This maps 'op' to a lambda.
     template<class ElemType>
@@ -5730,24 +5668,18 @@ namespace Microsoft { namespace MSR { namespace CNTK {
         const std::vector<size_t> & regularOpDims, const std::array<std::vector<ptrdiff_t>, 3> & regularStrides,
         const std::vector<size_t> & reducingOpDims, const std::array<std::vector<ptrdiff_t>, 3> & reducingStrides)
     {
+#define CaseBinaryTensorOp(oper) \
+        case ElementWiseOperator::op ## oper: \
+            return TensorOpWithFn(beta, pointers, alpha, [](const array<ElemType*, 3> & pp) { return Op ## oper((*(pp[0])), (*(pp[1]))); }, offsets, regularOpDims, regularStrides, reducingOpDims, reducingStrides)
+
         array<ElemType*, 3> pointers = { a.m_pArray, b.m_pArray, m_pArray };
-#define CaseBinaryTensorOp(oper) \
-        case ElementWiseOperator::op ## oper: \
-            return TensorOpWithFn(beta, pointers, alpha, [](const array<ElemType*, 3> & pp) { return Op ## oper((*(pp[0])), (*(pp[1]))); }, offsets, regularOpDims, regularStrides, reducingOpDims, reducingStrides)
         switch (op)
         {
-        CaseBinaryTensorOp(Sum); CaseBinaryTensorOp(Difference); CaseBinaryTensorOp(ElementWiseProduct); CaseBinaryTensorOp(ElementWiseQuotient);
-        CaseBinaryTensorOp(LogSum); CaseBinaryTensorOp(Max); CaseBinaryTensorOp(Min);
-        CaseBinaryTensorOp(EQ); CaseBinaryTensorOp(NE); CaseBinaryTensorOp(GT); CaseBinaryTensorOp(LT); CaseBinaryTensorOp(GE); CaseBinaryTensorOp(LE);
+        ForAllBinaryOps(CaseBinaryTensorOp);
         default: LogicError("TensorBinaryOp: Unknown op code %d.", (int)op);
         }
     }
 
-    // define a static function for every operation
-#define DefTernaryOp(op, expr) template<class ElemType> static inline ElemType Op ## op(ElemType a, ElemType b, ElemType c) { return expr; }
-
-    DefTernaryOp(Cond, a ? b : c);
-
     // perform ternary operation 'op' on a, and c giving 'this', reinterpreting the matrices as tensors as specified by the dims and strides
     // This maps 'op' to a lambda.
     template<class ElemType>
@@ -5756,18 +5688,22 @@ namespace Microsoft { namespace MSR { namespace CNTK {
         const std::vector<size_t> & regularOpDims, const std::array<std::vector<ptrdiff_t>, 4> & regularStrides,
         const std::vector<size_t> & reducingOpDims, const std::array<std::vector<ptrdiff_t>, 4> & reducingStrides)
    {
+#define CaseTernaryTensorOp(oper) \
+        case ElementWiseOperator::op ## oper: \
+            return TensorOpWithFn(beta, pointers, alpha, [](const array<ElemType*, 4> & pp) { return Op ## oper((*(pp[0])), (*(pp[1])), (*(pp[2]))); }, offsets, regularOpDims, regularStrides, reducingOpDims, reducingStrides)
+
         array<ElemType*, 4> pointers = { a.m_pArray, b.m_pArray, c.m_pArray, m_pArray };
-#define CaseTernaryTensorOp(oper) \
-        case ElementWiseOperator::op ## oper: \
-            return TensorOpWithFn(beta, pointers, alpha, [](const array<ElemType*, 4> & pp) { return Op ## oper((*(pp[0])), (*(pp[1])), (*(pp[2]))); }, offsets, regularOpDims, regularStrides, reducingOpDims, reducingStrides)
         switch (op)
         {
-        CaseTernaryTensorOp(Cond);
+        ForAllTernaryOps(CaseTernaryTensorOp);
         default: LogicError("TensorTernaryOp: Unknown op code %d.", (int)op);
         }
     }
 
-    // The explicit instantiation part
+    // -----------------------------------------------------------------------
+    // explicit instantiations
+    // -----------------------------------------------------------------------
+
     template class MATH_API CPUMatrix<float>;
     template class MATH_API CPUMatrix<double>;
 
@@ -64,15 +64,23 @@ namespace Microsoft { namespace MSR { namespace CNTK {
         // Note: not all of the above are actually implement at present; and not all that's implemented has an opcode.
     };
 
-    // declare float and double versions of a func under f_
-    // e.g. exp_ -> exp(double), expf(float)
-    #define OverloadUnaryMathFns(func) \
-        static inline float func ## _(float arg) { return func ## f(arg); } \
-        static inline double func ## _(double arg) { return func(arg); }
-
-    OverloadUnaryMathFns(fabs); OverloadUnaryMathFns(sqrt);
-    OverloadUnaryMathFns(exp); OverloadUnaryMathFns(log);
-    OverloadUnaryMathFns(tanh); OverloadUnaryMathFns(cos); OverloadUnaryMathFns(sin);
+    // helper to apply a C macro for all operations of each kind
+#define ForAllUnaryOps(Macro) \
+    Macro(Copy); \
+    Macro(Negate); Macro(Not); \
+    Macro(Abs); \
+    Macro(Sigmoid); Macro(SigmoidDerivative); Macro(Tanh); Macro(Sqrt); Macro(Exp); Macro(Log); Macro(LinearRectifierDerivative); Macro(Cosine); Macro(NegativeSine);
+
+#define ForAllParameterizedUnaryOps(Macro) \
+    Macro(SaturateBetaAlpha); Macro(SumAlpha); Macro(SubDifferenceToAlpha); Macro(SubDifferenceFromAlpha);
+
+#define ForAllBinaryOps(Macro) \
+    Macro(Sum); Macro(Difference); Macro(ElementWiseProduct); Macro(ElementWiseQuotient); \
+    Macro(LogSum); Macro(Max); Macro(Min); \
+    Macro(EQ); Macro(NE); Macro(GT); Macro(LT); Macro(GE); Macro(LE);
+
+#define ForAllTernaryOps(Macro) \
+    Macro(Cond);
 
     // -----------------------------------------------------------------------
     // various enums to describe
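
These list macros let each dispatch site apply its own per-op macro once, instead of enumerating every opcode by hand as the CPUMatrix.cpp switch statements did before this commit. A toy, self-contained illustration of how the expansion plays out:

#include <stdio.h>

enum class ElementWiseOperator { opCopy, opNegate, opNot };

// same pattern as above, shortened to three ops for the example
#define ForAllUnaryOps(Macro) \
    Macro(Copy); \
    Macro(Negate); Macro(Not);

const char* OpName(ElementWiseOperator op)
{
#define CaseUnaryTensorOp(oper) case ElementWiseOperator::op ## oper: return #oper
    switch (op)
    {
    ForAllUnaryOps(CaseUnaryTensorOp); // expands to the three 'case' labels
    }
    return "unknown";
}

int main()
{
    printf("%s\n", OpName(ElementWiseOperator::opNegate)); // prints Negate
    return 0;
}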
@@ -162,6 +162,7 @@
     <ClInclude Include="CommonMatrix.h" />
     <ClInclude Include="ConvolutionEngine.h" />
     <ClInclude Include="CPUMatrix.h" />
+    <ClInclude Include="TensorOps.h" />
     <ClInclude Include="TensorView.h" />
     <None Include="ClassDiagram.cd" />
     <None Include="GPUWatcher.cu" />
@@ -70,6 +70,9 @@
     <ClInclude Include="TensorView.h">
       <Filter>Tensors</Filter>
     </ClInclude>
+    <ClInclude Include="TensorOps.h">
+      <Filter>Tensors</Filter>
+    </ClInclude>
   </ItemGroup>
   <ItemGroup>
     <None Include="GPUMatrix.h">
@@ -4887,6 +4887,8 @@ namespace Microsoft { namespace MSR { namespace CNTK {
         return x - y * floor(x / y);
     }
 
+    // TODO: use static LogAdd() as defined in TensorOps.h
+    // Not doing this currently because that one uses ElemType for all ops, while this one uses double inside. Must compare before making this change.
     template<class ElemType>
     ElemType Matrix<ElemType>::LogAdd(ElemType x, ElemType y)
     {
@@ -0,0 +1,132 @@
+//
+// <copyright file="TensorOps.h" company="Microsoft">
+//     Copyright (c) Microsoft Corporation. All rights reserved.
+// </copyright>
+//
+
+// This implements the elementwise tensor operations, including helper macros and some actual functions.
+
+#pragma once
+
+#include "Basics.h"
+#include "CommonMatrix.h"
+
+#pragma push_macro("TENSOR_OPS_DECL")
+#ifndef TENSOR_OPS_DECL // to make these accessible to CUDA kernels, say '#define TENSOR_OPS_DECL __device__ __host__'
+#define TENSOR_OPS_DECL
+#endif
+
+#pragma push_macro("DECL")
+#define DECL static inline TENSOR_OPS_DECL
+
+// This class is exported from the Math.dll.
+namespace Microsoft { namespace MSR { namespace CNTK {
+
+    // -----------------------------------------------------------------------
+    // unified overloads for float/double math functions
+    //
+    // Declare float and double versions of the functions f we need as f_(),
+    // e.g. exp_ -> exp(double), expf(float).
+    // -----------------------------------------------------------------------
+
+#pragma push_macro("OverloadUnaryMathFns")
+#define OverloadUnaryMathFns(func) \
+    DECL float func ## _(float arg) { return func ## f(arg); } \
+    DECL double func ## _(double arg) { return func(arg); }
+
+    OverloadUnaryMathFns(fabs); OverloadUnaryMathFns(sqrt);
+    OverloadUnaryMathFns(exp); OverloadUnaryMathFns(log);
+    OverloadUnaryMathFns(tanh); OverloadUnaryMathFns(cos); OverloadUnaryMathFns(sin);
+#pragma pop_macro("OverloadUnaryMathFns")
+
+    // -----------------------------------------------------------------------
+    // additional functions that are standard in our context
+    // -----------------------------------------------------------------------
+
+    template<class ElemType>
+    DECL ElemType Sigmoid(ElemType z)
+    {
+        if (z >= 0)
+            return 1 / (1 + exp_(-z));
+        else
+        {
+            ElemType v = exp_(z);
+            return v / (1 + v);
+        }
+    }
+
+    template<class ElemType>
+    DECL ElemType SigmoidDerivative(ElemType z)
+    {
+        ElemType v = Sigmoid(z);
+        return v * (1 - v);
+    }
+
+    template<class ElemType>
+    DECL ElemType LinearRectifierDerivative(ElemType z)
+    {
+        return z > 0 ? (ElemType)1 : 0;
+    }
+
+    template<class ElemType>
+    DECL ElemType Sqrt(ElemType z)
+    {
+        // BUGBUG: Why clip to 0? An invalid sqrt() should show up as a NaN in the result, instead of hiding it.
+        return sqrt_(z > 0 ? z : 0);
+    }
+
+    // TODO: call this LogAdd() for consistency
+    template<typename ElemType>
+    DECL ElemType LogAdd(ElemType x, ElemType y)
+    {
+        if (x < y)
+        {
+            ElemType temp = x; x = y; y = temp;
+        }
+        ElemType diff = y - x;
+        if (diff < (ElemType)MINLOGEXP)
+        {
+            return (x < (ElemType)LSMALL) ? (ElemType)LZERO : x;
+        }
+        else
+        {
+            ElemType z = exp_(diff);
+            return x + log_((ElemType)1.0 + z);
+        }
+    }
+
+    // -----------------------------------------------------------------------
+    // ElementWiseOperator implementations
+    //
+    // Define a static function for every ElementWiseOperator (CommonMatrix.h).
+    // -----------------------------------------------------------------------
+
+#pragma push_macro("DefUnaryOp")
+#define DefUnaryOp(op, expr) template<class ElemType> DECL ElemType Op ## op(ElemType a) { return expr; }
+
+    DefUnaryOp(Copy, a);
+    DefUnaryOp(Negate, -a); DefUnaryOp(Not, !a);
+    DefUnaryOp(Abs, fabs_(a));
+    DefUnaryOp(Sigmoid, Sigmoid(a)); DefUnaryOp(SigmoidDerivative, SigmoidDerivative(a)); DefUnaryOp(Tanh, tanh_(a)); DefUnaryOp(Sqrt, Sqrt(a)); DefUnaryOp(Exp, exp_(a)); DefUnaryOp(Log, log_(a)); DefUnaryOp(LinearRectifierDerivative, LinearRectifierDerivative(a)); DefUnaryOp(Cosine, cos_(a)); DefUnaryOp(NegativeSine, -sin_(a));
+#pragma pop_macro("DefUnaryOp")
+
+    // parameterized unary ops
+    //DefUnaryOp(SaturateBetaAlpha); DefUnaryOp(SumAlpha); DefUnaryOp(SubDifferenceToAlpha); DefUnaryOp(SubDifferenceFromAlpha);
+
+#pragma push_macro("DefBinaryOp")
+#define DefBinaryOp(op, expr) template<class ElemType> DECL ElemType Op ## op(ElemType a, ElemType b) { return expr; }
+
+    DefBinaryOp(Sum, a + b); DefBinaryOp(Difference, a - b); DefBinaryOp(ElementWiseProduct, a*b); DefBinaryOp(ElementWiseQuotient, a / b);
+    DefBinaryOp(LogSum, LogAdd(a, b)); DefBinaryOp(Max, a > b ? a : b); DefBinaryOp(Min, a < b ? a : b);
+    DefBinaryOp(EQ, a == b); DefBinaryOp(NE, a != b); DefBinaryOp(GT, a > b); DefBinaryOp(LT, a < b); DefBinaryOp(GE, a >= b); DefBinaryOp(LE, a <= b);
+#pragma pop_macro("DefBinaryOp")
+
+#pragma push_macro("DefTernaryOp")
+#define DefTernaryOp(op, expr) template<class ElemType> DECL ElemType Op ## op(ElemType a, ElemType b, ElemType c) { return expr; }
+
+    DefTernaryOp(Cond, a ? b : c);
+#pragma pop_macro("DefTernaryOp")
+
+}}}
+#pragma pop_macro("DECL")
+#pragma pop_macro("TENSOR_OPS_DECL")
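
A side note on the new header's Sigmoid: it branches on the sign of z so that exp_() is only ever called with a non-positive argument. The naive one-liner overflows its intermediate exp for large negative z in float; a standalone comparison (assumed example, not from the commit):

#include <math.h>
#include <stdio.h>

float SigmoidNaive(float z)    // expf(-z) overflows to +inf once -z exceeds ~88
{
    return 1 / (1 + expf(-z));
}

float SigmoidBranched(float z) // TensorOps.h form: the exponent is always <= 0
{
    if (z >= 0)
        return 1 / (1 + expf(-z));
    float v = expf(z);
    return v / (1 + v);
}

int main()
{
    // both print 0 for z = -100, but the naive form goes through an inf
    // intermediate (raising FE_OVERFLOW); the branched form never overflows
    printf("%g %g\n", SigmoidNaive(-100), SigmoidBranched(-100));
    return 0;
}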
@@ -4,7 +4,7 @@
 // </copyright>
 //
 
-// This implements the TensorView class, which is a layer around Matrix that reinterprets its content as a generic tensor.
+// This implements the TensorView class, which is a layer around Matrix that reinterprets its content as a generic tensor. [fseide]
 
 #pragma once
 
@@ -56,26 +56,35 @@ namespace Microsoft { namespace MSR { namespace CNTK {
         // If beta == 0, c is not read out, i.e. it can be uninitialized or contain NaNs.
         // -------------------------------------------------------------------
 
+#pragma push_macro("DeclareUnaryTensorOp")
 #define DeclareUnaryTensorOp(oper) \
         void Do ## oper ## Of(ElemType beta, const TensorView & a, ElemType alpha) { DoUnaryOpOf(beta, a, alpha, ElementWiseOperator::op ## oper); }
 
-        DeclareUnaryTensorOp(Copy);
-        DeclareUnaryTensorOp(Negate); DeclareUnaryTensorOp(Not);
-        DeclareUnaryTensorOp(Abs);
-        DeclareUnaryTensorOp(Sigmoid); DeclareUnaryTensorOp(SigmoidDerivative); DeclareUnaryTensorOp(Tanh); DeclareUnaryTensorOp(Sqrt); DeclareUnaryTensorOp(Exp); DeclareUnaryTensorOp(Log); DeclareUnaryTensorOp(LinearRectifierDerivative); DeclareUnaryTensorOp(Cosine); DeclareUnaryTensorOp(NegativeSine);
-        DeclareUnaryTensorOp(SaturateBetaAlpha); DeclareUnaryTensorOp(SumAlpha); DeclareUnaryTensorOp(SubDifferenceToAlpha); DeclareUnaryTensorOp(SubDifferenceFromAlpha);
+        ForAllUnaryOps(DeclareUnaryTensorOp);
+        ForAllParameterizedUnaryOps(DeclareUnaryTensorOp);
+        //DeclareUnaryTensorOp(Copy);
+        //DeclareUnaryTensorOp(Negate); DeclareUnaryTensorOp(Not);
+        //DeclareUnaryTensorOp(Abs);
+        //DeclareUnaryTensorOp(Sigmoid); DeclareUnaryTensorOp(SigmoidDerivative); DeclareUnaryTensorOp(Tanh); DeclareUnaryTensorOp(Sqrt); DeclareUnaryTensorOp(Exp); DeclareUnaryTensorOp(Log); DeclareUnaryTensorOp(LinearRectifierDerivative); DeclareUnaryTensorOp(Cosine); DeclareUnaryTensorOp(NegativeSine);
+        //DeclareUnaryTensorOp(SaturateBetaAlpha); DeclareUnaryTensorOp(SumAlpha); DeclareUnaryTensorOp(SubDifferenceToAlpha); DeclareUnaryTensorOp(SubDifferenceFromAlpha);
+#pragma pop_macro("DeclareUnaryTensorOp")
 
+#pragma push_macro("DeclareBinaryTensorOp")
 #define DeclareBinaryTensorOp(oper) \
         void Do ## oper ## Of(ElemType beta, const TensorView & a, const TensorView & b, ElemType alpha) { DoBinaryOpOf(beta, a, b, alpha, ElementWiseOperator::op ## oper); }
 
-        DeclareBinaryTensorOp(Sum); DeclareBinaryTensorOp(Difference); DeclareBinaryTensorOp(ElementWiseProduct); DeclareBinaryTensorOp(ElementWiseQuotient);
-        DeclareBinaryTensorOp(LogSum); DeclareBinaryTensorOp(Max); DeclareBinaryTensorOp(Min);
-        DeclareBinaryTensorOp(EQ); DeclareBinaryTensorOp(NE); DeclareBinaryTensorOp(GT); DeclareBinaryTensorOp(LT); DeclareBinaryTensorOp(GE); DeclareBinaryTensorOp(LE);
+        ForAllBinaryOps(DeclareBinaryTensorOp);
+        //DeclareBinaryTensorOp(Sum); DeclareBinaryTensorOp(Difference); DeclareBinaryTensorOp(ElementWiseProduct); DeclareBinaryTensorOp(ElementWiseQuotient);
+        //DeclareBinaryTensorOp(LogSum); DeclareBinaryTensorOp(Max); DeclareBinaryTensorOp(Min);
+        //DeclareBinaryTensorOp(EQ); DeclareBinaryTensorOp(NE); DeclareBinaryTensorOp(GT); DeclareBinaryTensorOp(LT); DeclareBinaryTensorOp(GE); DeclareBinaryTensorOp(LE);
+#pragma pop_macro("DeclareBinaryTensorOp")
 
+#pragma push_macro("DeclareTernaryTensorOp")
 #define DeclareTernaryTensorOp(oper) \
         void Do ## oper ## Of(ElemType beta, const TensorView & a, const TensorView & b, const TensorView & c, ElemType alpha) { DoTernaryOpOf(beta, a, b, c, alpha, ElementWiseOperator::op ## oper); }
 
-        DeclareTernaryTensorOp(Cond);
+        ForAllTernaryOps(DeclareTernaryTensorOp);
+#pragma pop_macro("DeclareTernaryTensorOp")
 
         static void Test();
 
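
With this in place, every op listed in the ForAll*Ops macros gets a generated member; for instance, Macro(Sigmoid) inside ForAllUnaryOps(DeclareUnaryTensorOp) expands to (spelled out here for readability):

void DoSigmoidOf(ElemType beta, const TensorView & a, ElemType alpha)
{
    DoUnaryOpOf(beta, a, alpha, ElementWiseOperator::opSigmoid);
}

so a call like result.DoSigmoidOf(0, in, 1) assigns the elementwise sigmoid of in to result (per the beta/alpha convention documented above, beta == 0 means the destination is not read).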