Changed "NeutralValue" to function templates

This commit is contained in:
Thilo Will 2016-07-22 16:49:54 +02:00
Родитель 80fdb8f53d
Коммит bd776dc849
1 изменённых файлов: 7 добавлений и 11 удалений

Просмотреть файл

@ -264,16 +264,12 @@ struct TensorOps
//----------------------------------------------------------------------------
// For reductions we need the neutral elements of the corresponding binary ops
//----------------------------------------------------------------------------
class BinaryOpConstants
template <typename ElemType> __device__ ElemType NeutralValue(ElementWiseOperator op)
{
public:
template <typename ElemType> __device__ static ElemType NeutralValue(ElementWiseOperator op)
{
return 0; // error, only the explicit instantiations below should be used.
}
return 0; // error, only the explicit instantiations below should be used.
};
template<> __device__ static float BinaryOpConstants::NeutralValue<float>(ElementWiseOperator op)
template<> __device__ float NeutralValue<float>(ElementWiseOperator op)
{
switch (op)
{
@ -282,9 +278,9 @@ template<> __device__ static float BinaryOpConstants::NeutralValue<float>(Elemen
case ElementWiseOperator::opSum: return 0;
default: return 0; // error
}
}
};
template<> __device__ static double BinaryOpConstants::NeutralValue<double>(ElementWiseOperator op)
template<> __device__ double NeutralValue<double>(ElementWiseOperator op)
{
switch (op)
{
@ -293,7 +289,7 @@ template<> __device__ static double BinaryOpConstants::NeutralValue<double>(Elem
case ElementWiseOperator::opSum: return 0;
default: return 0; // error
}
}
};
// ----------------------------------------------------------------------------
@ -505,7 +501,7 @@ struct TensorOpElement<ElemType, N, M, K, /*parallelReduce=*/true, /*k=*/-1>
CUDA_LONG reductionEnd = min(reductionBegin + reductionChunkSize, reductionDim);
// compute the operation for this input coordinate
ReduceElemType aggregate = BinaryOpConstants::NeutralValue<ReduceElemType>(reductionOp);
ReduceElemType aggregate = NeutralValue<ReduceElemType>(reductionOp);
for (CUDA_LONG redId = reductionBegin + tid; redId < reductionEnd; redId += tids)
{