Revert "passing through reductionOp. Not yet compiling."
This reverts commit 84c9b0caa0
.
This commit is contained in:
Родитель
84c9b0caa0
Коммит
73d1e32d3a
|
@ -268,7 +268,7 @@ struct TensorOps
|
|||
//#define ReduceElemType double
|
||||
#define ReduceElemType ElemType // (note: we could use 'double' here, but that would cause problems with CUDA cards that don't support double)
|
||||
|
||||
template <class ElemType, ElementWiseOperator reductionOp, C_size_t N, C_int M, C_int m>
|
||||
template <class ElemType, C_size_t N, C_int M, C_int m>
|
||||
struct TensorOpReduce
|
||||
{
|
||||
// this version for m >= 0
|
||||
|
@ -277,7 +277,7 @@ struct TensorOpReduce
|
|||
{
|
||||
// start with index 0
|
||||
// We may use 'double' since we are memory-bound anyway.
|
||||
ReduceElemType aggregate = TensorOpReduce<ElemType, reductionOp, N, M, m - 1>::Compute(pointers, op, reducingOpDims, reducingStrides);
|
||||
ReduceElemType aggregate = TensorOpReduce<ElemType, N, M, m - 1>::Compute(pointers, op, reducingOpDims, reducingStrides);
|
||||
// apply this index to the pointers
|
||||
C_size_t dim = reducingOpDims[m];
|
||||
for (C_size_t k = 1 /*done with k=0 already*/; k < dim; k++)
|
||||
|
@ -285,21 +285,8 @@ struct TensorOpReduce
|
|||
// bump the pointers
|
||||
for (C_size_t i = 0; i < N - 1; i++) // N-1 because output is not used here
|
||||
pointers[i] += reducingStrides(i, (C_size_t) m);
|
||||
ElemType val = TensorOpReduce<ElemType, reductionOp, N, M, m - 1>::Compute(pointers, op, reducingOpDims, reducingStrides);
|
||||
|
||||
switch (reductionOp)
|
||||
{
|
||||
case ElementWiseOperator::opSum:
|
||||
aggregate += val;
|
||||
break;
|
||||
case ElementWiseOperator::opMax:
|
||||
aggregate = std::max<ReduceElemType>(aggregate, val);
|
||||
break;
|
||||
case ElementWiseOperator::opMin:
|
||||
aggregate = std::max<ReduceElemType>(aggregate, val);
|
||||
break;
|
||||
}
|
||||
|
||||
ElemType val = TensorOpReduce<ElemType, N, M, m - 1>::Compute(pointers, op, reducingOpDims, reducingStrides);
|
||||
aggregate += val;
|
||||
}
|
||||
return (ElemType) aggregate;
|
||||
}
|
||||
|
@ -308,7 +295,7 @@ struct TensorOpReduce
|
|||
// this one terminates the template recursion over reduction dimensions
|
||||
// The pointers are pointing to the input element.
|
||||
template <class ElemType, C_size_t N, C_int M>
|
||||
struct TensorOpReduce<ElemType, ElementWiseOperator reductionOp, N, M, /*m=*/-1>
|
||||
struct TensorOpReduce<ElemType, N, M, /*m=*/-1>
|
||||
{
|
||||
// this version for m = -1
|
||||
// the pointers are pointing to the right location(s) to take the operation over
|
||||
|
|
Загрузка…
Ссылка в новой задаче