merged GPUMatrixCUDAKernels.cuh's DEF_ELEMENT_PRIMITIVE macro with TensorOps.h's OverloadUnaryMathFns, as both did the same thing;
new #define ENABLE_BROADCASTING_ELEMENTTIMES to specifically select whether ScaleNode etc. should be replaced with ElementTimesNode
Parent: 42e39e218d
Commit: 526615abbf
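The second change relies on ElementTimesNode broadcasting its inputs: an input whose row or column dimension is 1 is implicitly expanded to the other input's shape, which is what ScaleNode (a 1x1 scaling factor) and the Row/ColumnElementTimesNode pair (a vector whose singleton dimension is broadcast) hard-coded as three separate nodes. A minimal sketch of those semantics, assuming column-major storage (illustrative code only, not the actual TensorView implementation):

#include <cstddef>
#include <vector>

// Multiplies b (m x n, column-major) in place by a (aRows x aCols), where each
// dimension of a must be either 1 (broadcast) or equal to b's dimension.
void elementTimesBroadcast(const std::vector<double>& a, std::size_t aRows, std::size_t aCols,
                           std::vector<double>& b, std::size_t m, std::size_t n)
{
    for (std::size_t j = 0; j < n; j++)
        for (std::size_t i = 0; i < m; i++)
        {
            std::size_t ai = (aRows == 1) ? 0 : i; // singleton row dimension: reuse row 0
            std::size_t aj = (aCols == 1) ? 0 : j; // singleton column dimension: reuse column 0
            b[j * m + i] *= a[aj * aRows + ai];
        }
}

With this in place, Scale(s, M) is just ElementTimes(s, M) with a 1x1 first input, which is why the specialized nodes below can be retired behind the flag.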

@@ -1050,7 +1050,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
 ComputationNodePtr scalar = builder.CreateLearnableParameter(msra::strfun::wstrprintf(L"SV%d", i), 1, 1);
 scalar->Value().SetValue((ElemType)0.01);
-#ifndef ENABLE_TENSORVIEW
+#ifndef ENABLE_BROADCASTING_ELEMENTTIMES
 ComputationNodePtr scaled = builder.Scale(scalar, directOutput, msra::strfun::wstrprintf(L"S%d", i));
 #else
 ComputationNodePtr scaled = builder.ElementTimes(scalar, directOutput, msra::strfun::wstrprintf(L"S%d", i));
@@ -35,7 +35,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
 // please keep this table sorted
 if (nodeType == OperationNameOf(CRFNode)) return New<CRFNode<ElemType>>(forward<_Types>(_Args)...);
 else if (nodeType == OperationNameOf(ClassBasedCrossEntropyWithSoftmaxNode)) return New<ClassBasedCrossEntropyWithSoftmaxNode<ElemType>>(forward<_Types>(_Args)...);
-#ifdef ENABLE_TENSORVIEW
+#ifdef ENABLE_BROADCASTING_ELEMENTTIMES
 else if (nodeType == L"ColumnElementTimes") return New<ElementTimesNode<ElemType>>(forward<_Types>(_Args)...);
 #else
 else if (nodeType == OperationNameOf(ColumnElementTimesNode)) return New<ColumnElementTimesNode<ElemType>>(forward<_Types>(_Args)...);
@@ -76,7 +76,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
 else if (nodeType == OperationNameOf(ReconcileMBLayoutNode)) return New<ReconcileMBLayoutNode<ElemType>>(forward<_Types>(_Args)...);
 else if (nodeType == OperationNameOf(RectifiedLinearNode)) return New<RectifiedLinearNode<ElemType>>(forward<_Types>(_Args)...);
 else if (nodeType == OperationNameOf(ReshapeNode)) return New<ReshapeNode<ElemType>>(forward<_Types>(_Args)...);
-#ifdef ENABLE_TENSORVIEW
+#ifdef ENABLE_BROADCASTING_ELEMENTTIMES
 else if (nodeType == L"RowElementTimes") return New<ElementTimesNode<ElemType>>(forward<_Types>(_Args)...);
 #else
 else if (nodeType == OperationNameOf(RowElementTimesNode)) return New<RowElementTimesNode<ElemType>>(forward<_Types>(_Args)...);
@@ -85,7 +85,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
 else if (nodeType == OperationNameOf(DiagonalNode)) return New<DiagonalNode<ElemType>>(forward<_Types>(_Args)...);
 else if (nodeType == OperationNameOf(RowSliceNode)) return New<RowSliceNode<ElemType>>(forward<_Types>(_Args)...);
 else if (nodeType == OperationNameOf(RowStackNode)) return New<RowStackNode<ElemType>>(forward<_Types>(_Args)...);
-#ifdef ENABLE_TENSORVIEW
+#ifdef ENABLE_BROADCASTING_ELEMENTTIMES
 else if (nodeType == L"Scale") return New<ElementTimesNode<ElemType>>(forward<_Types>(_Args)...);
 #else
 else if (nodeType == OperationNameOf(ScaleNode)) return New<ScaleNode<ElemType>>(forward<_Types>(_Args)...);
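The net effect of the three factory hunks above: with ENABLE_BROADCASTING_ELEMENTTIMES set, models saved with the legacy node types still load, because the old type-name strings are aliased to ElementTimesNode at creation time. A condensed sketch of that aliasing pattern (hypothetical Node/ElementTimesNode types, not the real ComputationNetworkBuilder signatures):

#include <functional>
#include <map>
#include <memory>
#include <string>

struct Node { virtual ~Node() = default; };
struct ElementTimesNode : Node {};

std::shared_ptr<Node> CreateNode(const std::wstring& nodeType)
{
    static const std::map<std::wstring, std::function<std::shared_ptr<Node>()>> factory =
    {
#ifdef ENABLE_BROADCASTING_ELEMENTTIMES
        // legacy type names alias to the one broadcasting node
        { L"Scale",              [] { return std::make_shared<ElementTimesNode>(); } },
        { L"RowElementTimes",    [] { return std::make_shared<ElementTimesNode>(); } },
        { L"ColumnElementTimes", [] { return std::make_shared<ElementTimesNode>(); } },
#endif
        { L"ElementTimes",       [] { return std::make_shared<ElementTimesNode>(); } },
    };
    auto it = factory.find(nodeType);
    return it == factory.end() ? nullptr : it->second();
}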
@@ -486,7 +486,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
 return net.AddNodeToNetAndAttachInputs(New<SumElementsNode<ElemType>>(net.GetDeviceId(), nodeName), a);
 }

-#ifndef ENABLE_TENSORVIEW
+#ifndef ENABLE_BROADCASTING_ELEMENTTIMES
 template<class ElemType> shared_ptr<ComputationNode<ElemType>> ComputationNetworkBuilder<ElemType>::Scale(const ComputationNodePtr scalar, const ComputationNodePtr matrix, const std::wstring nodeName)
 {
 return net.AddNodeToNetAndAttachInputs(New<ScaleNode<ElemType>>(net.GetDeviceId(), nodeName), scalar, matrix);
@@ -513,7 +513,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
 return net.AddNodeToNetAndAttachInputs(New<ElementTimesNode<ElemType>>(net.GetDeviceId(), nodeName), a, b);
 }

-#ifndef ENABLE_TENSORVIEW
+#ifndef ENABLE_BROADCASTING_ELEMENTTIMES
 template<class ElemType> shared_ptr<ComputationNode<ElemType>> ComputationNetworkBuilder<ElemType>::RowElementTimes(const ComputationNodePtr a, const ComputationNodePtr b, const std::wstring nodeName)
 {
 return net.AddNodeToNetAndAttachInputs(New<RowElementTimesNode<ElemType>>(net.GetDeviceId(), nodeName), a, b);
@@ -111,14 +111,14 @@ namespace Microsoft { namespace MSR { namespace CNTK {
 ComputationNodePtr Hardmax(const ComputationNodePtr a, const std::wstring nodeName = L"");
 ComputationNodePtr LogSoftmax(const ComputationNodePtr a, const std::wstring nodeName = L"");
 ComputationNodePtr Sum(const ComputationNodePtr a, const std::wstring nodeName = L"");
-#ifndef ENABLE_TENSORVIEW
+#ifndef ENABLE_BROADCASTING_ELEMENTTIMES
 ComputationNodePtr Scale(const ComputationNodePtr scalar, const ComputationNodePtr matrix, const std::wstring nodeName = L"");
 #endif
 ComputationNodePtr Transpose(const ComputationNodePtr matrix, const std::wstring nodeName = L"");
 ComputationNodePtr Times(const ComputationNodePtr a, const ComputationNodePtr b, const std::wstring nodeName = L"");
 ComputationNodePtr TransposeTimes(const ComputationNodePtr a, const ComputationNodePtr b, const std::wstring nodeName = L"");
 ComputationNodePtr ElementTimes(const ComputationNodePtr a, const ComputationNodePtr b, const std::wstring nodeName = L"");
-#ifndef ENABLE_TENSORVIEW
+#ifndef ENABLE_BROADCASTING_ELEMENTTIMES
 ComputationNodePtr RowElementTimes(const ComputationNodePtr a, const ComputationNodePtr b, const std::wstring nodeName = L"");
 ComputationNodePtr ColumnElementTimes(const ComputationNodePtr a, const ComputationNodePtr b, const std::wstring nodeName = L"");
 #endif
@@ -27,6 +27,7 @@
 #include <iostream>

 #define ENABLE_TENSORVIEW // flip this switch once the tensor lib is confirmed to be working
+// #define ENABLE_BROADCASTING_ELEMENTTIMES // if set then ScaleNode and Row/ColumnElementTimes are redirected to ElementTimes

 #define DEFAULT_HIDDEN_ACTIVATION 0.1
@@ -304,7 +304,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
 template class MinusNode<float>;
 template class MinusNode<double>;

-#ifndef ENABLE_TENSORVIEW
+#ifndef ENABLE_BROADCASTING_ELEMENTTIMES
 // -----------------------------------------------------------------------
 // ScaleNode (scalar scaling factor, matrix)
 //
@@ -742,7 +742,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
 template class ElementTimesNode<float>;
 template class ElementTimesNode<double>;

-#ifndef ENABLE_TENSORVIEW
+#ifndef ENABLE_BROADCASTING_ELEMENTTIMES
 // -----------------------------------------------------------------------
 // RowElementTimesNode (left, right) --TODO: what are left and right?
 //
@@ -39,7 +39,7 @@

 MATH_API DEVICEID_TYPE EnforceOneGPUOnly(DEVICEID_TYPE requestedDeviceId);

-namespace Microsoft { namespace MSR { namespace CNTK {
+namespace Microsoft { namespace MSR { namespace CNTK {

 // -----------------------------------------------------------------------
 // ElementWiseOperator -- This enum represents which function to apply.
@@ -10,12 +10,16 @@

 #ifndef CPUONLY

+#pragma push_macro("TENSOR_OPS_DECL")
+#define TENSOR_OPS_DECL __device__ __host__
 #include "CommonMatrix.h"
 #include "GPUMatrix.h"
+#include "TensorOps.h" // for exp_() etc.
 #include "device_functions.h"
 #include <cuda_runtime.h>
 #include <assert.h>
 #include <float.h>
+#pragma pop_macro("TENSOR_OPS_DECL")

 // REVIEW alexeyk: disable warnings properly for GCC/clang
 #ifdef _MSC_VER
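The push/pop pair is what makes the merge safe to consume from CUDA code: TensorOps.h declares its overloads with a TENSOR_OPS_DECL qualifier, so this .cuh pins it to __device__ __host__ for the includes in between and then restores whatever definition (or absence of one) was in effect before. A stripped-down illustration of the #pragma push_macro/pop_macro mechanism, using a stand-in macro name:

#define QUALIFIER static inline       // some pre-existing definition
#pragma push_macro("QUALIFIER")       // save the current definition
#undef QUALIFIER
#define QUALIFIER __device__ __host__ // local redefinition for the next includes
// ... headers included here see the local definition ...
#pragma pop_macro("QUALIFIER")        // restore: QUALIFIER is 'static inline' again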
@@ -39,6 +43,43 @@

 #define IDX2C(i,j,ld) (((j)*(ld))+(i)) // 0 based indexing

+// CUDA atomicAdd() only exists for 'float'. This is the 'double' version.
+static __inline__ __device__ double atomicAdd(double* address, double val)
+{
+    unsigned long long int* address_as_ull = (unsigned long long int*)address;
+    unsigned long long int old = *address_as_ull, assumed;
+    do {
+        assumed = old;
+        old = atomicCAS(address_as_ull, assumed, __double_as_longlong(val + __longlong_as_double(assumed)));
+    } while (assumed != old);
+    return __longlong_as_double(old);
+}
+
+// TODO: replace this with TensorOps.h LogAdd(). It differs in using ElemType throughout, while this one seems to use 'double' versions of exp() and log().
+// The 'k' in the name is to avoid naming conflicts with various versions of logadd() that are defined throughout the codebase.
+template<class ElemType>
+static inline __device__ __host__ ElemType logaddk(ElemType x, ElemType y)
+{
+    ElemType temp, diff, z;
+
+    if (x < y)
+    {
+        temp = x; x = y; y = temp;
+    }
+    diff = y - x;
+    if (diff < MINLOGEXP)
+    {
+        return (x < LSMALL) ? LZERO : x;
+    }
+    else
+    {
+        z = exp(diff);
+        return x + log(1.0 + z);
+    }
+}
+
+namespace Microsoft { namespace MSR { namespace CNTK {
+
 // ---------------------------------------------------------------------------
 // GridDim -- helper to choose the CUDA grid dimensions
 // ---------------------------------------------------------------------------
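Two helpers are hoisted to the top of the file here (their old mid-file copies are deleted in a later hunk). The double-precision atomicAdd emulates the missing hardware instruction with a compare-and-swap loop: reinterpret the 64-bit word as an integer, compute the new value, and retry until no other thread raced the update in between. A hypothetical usage sketch, not part of the commit:

// Illustrative kernel: every thread folds one element into a single
// double-precision accumulator through the CAS-based atomicAdd above.
__global__ void _accumulateSketch(const double* a, CUDA_LONG N, double* total)
{
    CUDA_LONG id = blockDim.x * blockIdx.x + threadIdx.x;
    if (id < N)
        atomicAdd(total, a[id]); // resolves to the double overload defined above
}

logaddk is the usual two-term log-sum-exp, arranged so that exp() only ever sees a non-positive argument and therefore cannot overflow; MINLOGEXP and LZERO (by their names, an underflow threshold and a log-domain zero) short-circuit the case where the smaller term cannot affect the result:

\log\left(e^x + e^y\right) = \max(x, y) + \log\left(1 + e^{-\left|x - y\right|}\right)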
@@ -49,7 +90,6 @@ static INT CeilDiv(INT a, INT2 b)
 return (a + b - 1) / b;
 }

 // TODO: move the computation of 'id' here as well
 struct GridDim
 {
 static const CUDA_LONG maxThreadsPerBlock = 512; // use this many threads per block
@@ -124,9 +164,6 @@ struct GridDim
 #define UNUSED_FUNCTION_ATTRIBUTE
 #endif

-// Predefine this for later.
-static __inline__ __device__ double atomicAdd(double* address, double val) UNUSED_FUNCTION_ATTRIBUTE;
-
 // ===========================================================================
 // CUDA kernels follow, lots of them
 // ===========================================================================
@@ -138,18 +175,6 @@ static __inline__ __device__ double atomicAdd(double* address, double val) UNUSED_FUNCTION_ATTRIBUTE;
 // (ElemenType *res, CUDA_LONG N), a pointer and length of the output block. Each thread computes a function
 // of the inputs for one value in the output.

-// This macro overloads _x() with float and double arguments, and inlines the correct library function. This simplifies templated kernel code.
-// TODO: merge with similar definition in TensorOps.h
-#define DEF_ELEMENT_PRIMITIVE(x) __device__ __forceinline__ float _##x(float f) { return x##f(f); } __device__ __forceinline__ double _##x(double f) { return x(f); }
-
-DEF_ELEMENT_PRIMITIVE(exp)
-DEF_ELEMENT_PRIMITIVE(log)
-DEF_ELEMENT_PRIMITIVE(tanh)
-DEF_ELEMENT_PRIMITIVE(sqrt)
-DEF_ELEMENT_PRIMITIVE(fabs)
-DEF_ELEMENT_PRIMITIVE(cos)
-DEF_ELEMENT_PRIMITIVE(sin)
-
 template<class ElemType>
 __global__ void _elementWisePowerOnCuda(
 const ElemType alpha,
@@ -188,6 +213,7 @@ __global__ void _elementWisePowerOnCuda(
 };

+// Note that this code is inefficient on CUDA due to diverging code paths.
+// Use Sigmoid() in TensorOps.h instead, which solves this problem.
 template<class ElemType>
 __global__ void _elementWiseSigmoidOnCuda(
 const ElemType *a,
@@ -200,12 +226,12 @@ __global__ void _elementWiseSigmoidOnCuda(
 #else
 if (a[id] >= 0)
 {
-    ElemType e = _exp(-a[id]);
+    ElemType e = exp_(-a[id]);
     res[id] = 1 / (1 + e);
 }
 else
 {
-    ElemType e = _exp(a[id]);
+    ElemType e = exp_(a[id]);
     res[id] = e / (1 + e);
 }
 #endif
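The branching kept above (under the #else) is exactly the divergence the new comment warns about: threads in a warp whose inputs straddle zero take different paths through two separate exp_() calls. A branch-light alternative in the spirit of the TensorOps.h Sigmoid() it recommends (a sketch assuming exp_ is the TensorOps.h overload; this is not the actual TensorOps.h code):

// One exp_() on a guaranteed non-positive argument, then predicated selects
// instead of two divergent code paths.
template <class ElemType>
__device__ __host__ ElemType SigmoidSketch(ElemType z)
{
    ElemType e = exp_(z > 0 ? -z : z); // argument <= 0, so exp_ cannot overflow
    ElemType r = e / (1 + e);          // equals sigmoid(-|z|)
    return z > 0 ? 1 - r : r;          // mirror back for positive inputs
}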
@@ -227,7 +253,7 @@ __global__ void _assignSigmoidOf(
 res[id] = Microsoft::MSR::CNTK::Sigmoid(a[id]);
 #else
 ElemType negElem = -a[id];
-ElemType e = _exp(negElem);
+ElemType e = exp_(negElem);

 res[id] = 1 / (e + 1);
 #endif
@@ -260,7 +286,7 @@ __global__ void _elementWiseTanhOnCuda(
 const CUDA_LONG N)
 {
 CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id,N);
-res[id] = _tanh(a[id]);
+res[id] = tanh_(a[id]);
 };

 //to prevent negative values caused by floating operations, we force inputs to be >=0
@@ -272,7 +298,7 @@ __global__ void _elementWiseSqrtOnCuda(
 const CUDA_LONG N)
 {
 CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id,N);
-res[id] = _sqrt(max((ElemType)0, a[id]));
+res[id] = sqrt_(max((ElemType)0, a[id]));
 };

 template<class ElemType>
@@ -282,7 +308,7 @@ __global__ void _elementWiseExpOnCuda(
 const CUDA_LONG N)
 {
 CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id,N);
-res[id] = _exp(a[id]);
+res[id] = exp_(a[id]);
 };

 template<class ElemType>
@@ -292,7 +318,7 @@ __global__ void _elementWiseLogOnCuda(
 const CUDA_LONG N)
 {
 CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id,N);
-res[id] = (a[id] < EPS_IN_LOG) ? LOG_OF_EPS_IN_LOG : _log(a[id]);
+res[id] = (a[id] < EPS_IN_LOG) ? LOG_OF_EPS_IN_LOG : log_(a[id]);
 };

 template<class ElemType>
@@ -302,7 +328,7 @@ __global__ void _elementWiseAbsOnCuda(
 const CUDA_LONG N)
 {
 CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id,N);
-res[id] = _fabs(a[id]);
+res[id] = fabs_(a[id]);
 };

 template<class ElemType>
@@ -312,7 +338,7 @@ __global__ void _elementWiseCosineOnCuda(
 const CUDA_LONG N)
 {
 CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id,N);
-res[id] = _cos(a[id]);
+res[id] = cos_(a[id]);
 };

 template<class ElemType>
@@ -322,7 +348,7 @@ __global__ void _elementWiseNegativeSineOnCuda(
 const CUDA_LONG N)
 {
 CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id,N);
-res[id] = -_sin(a[id]);
+res[id] = -sin_(a[id]);
 };

 template<class ElemType>
@@ -3464,7 +3490,7 @@ __global__ void _assignNoiseContrastiveEstimation(
 if (positive)
     prob = -prob;
 ElemType score_noise = log_num_noise_samples + prob;
-ElemType z = logadd(tmp[i], score_noise);
+ElemType z = logaddk(tmp[i], score_noise);
 ElemType logprob = tmp[i] - z;
 ElemType logprob_noise = score_noise - z;
 tmp[i] = -exp(logprob);
@@ -3756,40 +3782,6 @@ __global__ void _normalGradForSparseBlock(
 lhsValues[index] = rhs[IDX2C(row, col, numRows)];
 }

-static __inline__ __device__ double atomicAdd(double* address, double val)
-{
-    unsigned long long int* address_as_ull = (unsigned long long int*)address;
-    unsigned long long int old = *address_as_ull, assumed;
-
-    do {
-        assumed = old;
-        old = atomicCAS(address_as_ull, assumed, __double_as_longlong(val + __longlong_as_double(assumed)));
-    } while (assumed != old);
-
-    return __longlong_as_double(old);
-}
-
-template<class ElemType>
-static __inline__ __device__ ElemType logadd(ElemType x, ElemType y)
-{
-    ElemType temp, diff, z;
-
-    if (x < y)
-    {
-        temp = x; x = y; y = temp;
-    }
-    diff = y - x;
-    if (diff < MINLOGEXP)
-    {
-        return (x < LSMALL)?LZERO:x;
-    }
-    else
-    {
-        z = exp(diff);
-        return x + log(1.0 + z);
-    }
-}
-
 //This function should be called with 1024 threads per block and 1 block
 //THIS IS NOT THE MOST EFFICIENT IMPLEMENTATION!!!
 template<class ElemType>
@@ -4554,7 +4546,7 @@ __global__ void _rcrfBackwardCompute(
 fSum = LZERO;
 for (int j = 0; j < iNumLab; j++)
 {
-    fSum = logadd(fSum, alpha[IDX2C(j, t, iNumLab)]);
+    fSum = logaddk(fSum, alpha[IDX2C(j, t, iNumLab)]);
 }

 fTmp = alpha[IDX2C(id, t, iNumLab)] - fSum;
@@ -4566,10 +4558,10 @@ __global__ void _rcrfBackwardCompute(
 fSum = LZERO;
 for (int m = 0; m < iNumLab; m++)
 {
-    fSum = logadd(fSum, alpha[IDX2C(m, t, iNumLab)] + pair_scores[IDX2C(j, m, iNumLab)]);
+    fSum = logaddk(fSum, alpha[IDX2C(m, t, iNumLab)] + pair_scores[IDX2C(j, m, iNumLab)]);
 }

-fTmp = logadd(fTmp, beta[IDX2C(j, t + 1, iNumLab)] + alpha[IDX2C(id, t, iNumLab)] + pair_scores[IDX2C(j, id, iNumLab)] - fSum);
+fTmp = logaddk(fTmp, beta[IDX2C(j, t + 1, iNumLab)] + alpha[IDX2C(id, t, iNumLab)] + pair_scores[IDX2C(j, id, iNumLab)] - fSum);
 }
 }
@@ -4630,7 +4622,7 @@ __global__ void _rcrfBackwardCompute(
 {
 for (int j = 0; j < iNumLab; j++)
 {
-    fTmp = logadd(fTmp, beta_t1[j] + alpha[id] + pair_scores[j] - zeta[j]);
+    fTmp = logaddk(fTmp, beta_t1[j] + alpha[id] + pair_scores[j] - zeta[j]);
 }
 }
@@ -4671,9 +4663,9 @@ __global__ void _rcrfBackwardComputeZeta(
 for (int m = 0; m < iNumLab; m++)
 {
 if (t == iNumPos - 1)
-    fSum = logadd(fSum, alpha[IDX2C(m, 0, iNumLab)]);
+    fSum = logaddk(fSum, alpha[IDX2C(m, 0, iNumLab)]);
 else
-    fSum = logadd(fSum, alpha[IDX2C(m, 0, iNumLab)] + pair_scores[m]);
+    fSum = logaddk(fSum, alpha[IDX2C(m, 0, iNumLab)] + pair_scores[m]);
 }

 gzeta[id] = fSum;
@@ -4725,7 +4717,7 @@ __global__ void _rcrfTransGrdComputeZeta(
 else
     fTmp = alpha[m];

-fSum = logadd(fSum, pair_scores[m] + fTmp);
+fSum = logaddk(fSum, pair_scores[m] + fTmp);
 }

 gzeta[id] = fSum;
@@ -4828,7 +4820,7 @@ __global__ void _reductionLogAddSum(
 {
 ElemType lSum = LZERO;
 if (tid < s){
-    lSum = logadd(partialLogAddSum[tid], partialLogAddSum[tid + s]);
+    lSum = logaddk(partialLogAddSum[tid], partialLogAddSum[tid + s]);
     partialLogAddSum[tid] = lSum;
 }
 }
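For context, _reductionLogAddSum is a standard shared-memory tree reduction with logaddk standing in for ordinary addition; a condensed sketch of the surrounding loop (names as in the hunk above, the loop frame assumed rather than copied):

// Each round halves the number of active threads; thread tid log-adds the
// partial result s slots away, and the barrier orders the rounds.
for (int s = blockDim.x / 2; s > 0; s >>= 1)
{
    if (tid < s)
        partialLogAddSum[tid] = logaddk(partialLogAddSum[tid], partialLogAddSum[tid + s]);
    __syncthreads();
}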
@@ -4953,4 +4945,6 @@ __global__ void _maskColumnsValue(ElemType *a, const char *columnsMask, CUDA_LONG
 }
 }

+}}}
+
 #endif // !CPUONLY
@@ -25,18 +25,22 @@ namespace Microsoft { namespace MSR { namespace CNTK {
 // -----------------------------------------------------------------------
 // unified overloads for float/double math functions
 //
-// Declare float and double versions of the functions f we need as f_(),
-// e.g. exp_ -> exp(double), expf(float).
+// Declare float and double versions of the functions x we need as x_().
+// This macro overloads x_() with float and double arguments, and inlines the correct library function,
+// e.g. exp_ -> exp(double), expf(float). This simplifies templated kernel code.
 // -----------------------------------------------------------------------

-#define OverloadUnaryMathFns(func) \
-    DECL float func ## _(float arg) { return func ## f(arg); } \
-    DECL double func ## _(double arg) { return func(arg); }
+#pragma push_macro("OverloadUnaryMathFns")
+#define OverloadUnaryMathFns(x) DECL float x ## _(float f) { return x ## f(f); } DECL double x ## _(double f) { return x(f); }

-OverloadUnaryMathFns(exp);
-OverloadUnaryMathFns(log);
-OverloadUnaryMathFns(tanh);
-OverloadUnaryMathFns(sqrt);
-OverloadUnaryMathFns(fabs);
-OverloadUnaryMathFns(cos);
-OverloadUnaryMathFns(sin);
+OverloadUnaryMathFns(fabs); OverloadUnaryMathFns(sqrt);
+OverloadUnaryMathFns(exp); OverloadUnaryMathFns(log);
+OverloadUnaryMathFns(tanh); OverloadUnaryMathFns(cos); OverloadUnaryMathFns(sin);
+#pragma pop_macro("OverloadUnaryMathFns")

 // -----------------------------------------------------------------------
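To make the merged macro concrete, this is the mechanical expansion of OverloadUnaryMathFns(exp) (DECL being the declaration qualifier the header is compiled with, e.g. routed through TENSOR_OPS_DECL in the CUDA build):

DECL float exp_(float f) { return expf(f); } // float maps to the C library's expf()
DECL double exp_(double f) { return exp(f); } // double maps to exp()

Templated kernel code can then call exp_(x) with either float or double ElemType and get the right library function without dispatching by hand, which is what the GPUMatrixCUDAKernels.cuh hunks above switch to (_exp -> exp_, and so on).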
@@ -85,7 +89,6 @@ namespace Microsoft { namespace MSR { namespace CNTK {
 return sqrt_(z > 0 ? z : 0);
 }

-// TODO: call this LogAdd() for consistency
 template<typename ElemType>
 DECL ElemType LogAdd(ElemType x, ElemType y)
 {