Changed "NeutralValue" to function templates
This commit is contained in:
Родитель
80fdb8f53d
Коммит
bd776dc849
|
@ -264,16 +264,12 @@ struct TensorOps
|
|||
//----------------------------------------------------------------------------
|
||||
// For reductions we need the neutral elements of the corresponding binary ops
|
||||
//----------------------------------------------------------------------------
|
||||
class BinaryOpConstants
|
||||
template <typename ElemType> __device__ ElemType NeutralValue(ElementWiseOperator op)
|
||||
{
|
||||
public:
|
||||
template <typename ElemType> __device__ static ElemType NeutralValue(ElementWiseOperator op)
|
||||
{
|
||||
return 0; // error, only the explicit instantiations below should be used.
|
||||
}
|
||||
return 0; // error, only the explicit instantiations below should be used.
|
||||
};
|
||||
|
||||
template<> __device__ static float BinaryOpConstants::NeutralValue<float>(ElementWiseOperator op)
|
||||
template<> __device__ float NeutralValue<float>(ElementWiseOperator op)
|
||||
{
|
||||
switch (op)
|
||||
{
|
||||
|
@ -282,9 +278,9 @@ template<> __device__ static float BinaryOpConstants::NeutralValue<float>(Elemen
|
|||
case ElementWiseOperator::opSum: return 0;
|
||||
default: return 0; // error
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template<> __device__ static double BinaryOpConstants::NeutralValue<double>(ElementWiseOperator op)
|
||||
template<> __device__ double NeutralValue<double>(ElementWiseOperator op)
|
||||
{
|
||||
switch (op)
|
||||
{
|
||||
|
@ -293,7 +289,7 @@ template<> __device__ static double BinaryOpConstants::NeutralValue<double>(Elem
|
|||
case ElementWiseOperator::opSum: return 0;
|
||||
default: return 0; // error
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
|
@ -505,7 +501,7 @@ struct TensorOpElement<ElemType, N, M, K, /*parallelReduce=*/true, /*k=*/-1>
|
|||
CUDA_LONG reductionEnd = min(reductionBegin + reductionChunkSize, reductionDim);
|
||||
|
||||
// compute the operation for this input coordinate
|
||||
ReduceElemType aggregate = BinaryOpConstants::NeutralValue<ReduceElemType>(reductionOp);
|
||||
ReduceElemType aggregate = NeutralValue<ReduceElemType>(reductionOp);
|
||||
|
||||
for (CUDA_LONG redId = reductionBegin + tid; redId < reductionEnd; redId += tids)
|
||||
{
|
||||
|
|
Загрузка…
Ссылка в новой задаче