ReduceOp passed through on all paths
This commit is contained in:
Родитель
2cc578ef0e
Коммит
204aca8563
|
@ -273,12 +273,12 @@ template <class ElemType, C_size_t N, C_int M, C_int m>
|
|||
struct TensorOpReduce
|
||||
{
|
||||
// this version for m >= 0
|
||||
static __device__ ElemType Compute(FixedArray<ElemType*, N> pointers, ElementWiseOperator op,
|
||||
static __device__ ElemType Compute(FixedArray<ElemType*, N> pointers, ElementWiseOperator op, ElementWiseOperator reductionOp,
|
||||
const FixedArray<C_unsigned_int, M>& reducingOpDims, const FixedMatrix<C_int, N, M>& reducingStrides)
|
||||
{
|
||||
// start with index 0
|
||||
// We may use 'double' since we are memory-bound anyway.
|
||||
ReduceElemType aggregate = TensorOpReduce<ElemType, N, M, m - 1>::Compute(pointers, op, reducingOpDims, reducingStrides);
|
||||
ReduceElemType aggregate = TensorOpReduce<ElemType, N, M, m - 1>::Compute(pointers, op, reductionOp, reducingOpDims, reducingStrides);
|
||||
// apply this index to the pointers
|
||||
C_size_t dim = reducingOpDims[m];
|
||||
for (C_size_t k = 1 /*done with k=0 already*/; k < dim; k++)
|
||||
|
@ -286,8 +286,21 @@ struct TensorOpReduce
|
|||
// bump the pointers
|
||||
for (C_size_t i = 0; i < N - 1; i++) // N-1 because output is not used here
|
||||
pointers[i] += reducingStrides(i, (C_size_t) m);
|
||||
ElemType val = TensorOpReduce<ElemType, N, M, m - 1>::Compute(pointers, op, reducingOpDims, reducingStrides);
|
||||
aggregate += val;
|
||||
ElemType val = TensorOpReduce<ElemType, N, M, m - 1>::Compute(pointers, op, reductionOp, reducingOpDims, reducingStrides);
|
||||
// Fold 'val' (the result for the current reduction index) into the running
// aggregate according to the requested reduction operator.
// Each case label must be paired with its own operation:
//  - opSum accumulates,
//  - opMin keeps the smaller of aggregate and val,
//  - opMax keeps the larger of aggregate and val.
// Any other reductionOp value leaves the aggregate unchanged (no default case).
switch (reductionOp)
{ // TODO replace this with some operation
case ElementWiseOperator::opSum:
    aggregate += val;
    break;
case ElementWiseOperator::opMin:
    if (val < aggregate)
        aggregate = val;
    break;
case ElementWiseOperator::opMax:
    if (val > aggregate)
        aggregate = val;
    break;
}
|
||||
}
|
||||
return (ElemType) aggregate;
|
||||
}
|
||||
|
@ -300,7 +313,7 @@ struct TensorOpReduce<ElemType, N, M, /*m=*/-1>
|
|||
{
|
||||
// this version for m = -1
|
||||
// the pointers are pointing to the right location(s) to take the operation over
|
||||
static __device__ ElemType Compute(FixedArray<ElemType*, N> pointers, ElementWiseOperator op,
|
||||
static __device__ ElemType Compute(FixedArray<ElemType*, N> pointers, ElementWiseOperator op, ElementWiseOperator reductionOp,
|
||||
const FixedArray<C_unsigned_int, M>& /*reducingOpDims*/, const FixedMatrix<C_int, N, M>& /*reducingStrides*/)
|
||||
{
|
||||
return TensorOps<ElemType>::Compute(pointers, op); // finally computing something!
|
||||
|
@ -404,7 +417,7 @@ struct TensorOpElement<ElemType, N, M, K, /*parallelReduce=*/false, /*k=*/-1>
|
|||
{
|
||||
// compute the operation for this output coordinate
|
||||
// This may still involve a reduction over inverse-broadcasting dimensions.
|
||||
ElemType val = TensorOpReduce<ElemType, N, M, M - 1>::Compute(pointers, op, reducingOpDims, reducingStrides);
|
||||
ElemType val = TensorOpReduce<ElemType, N, M, M - 1>::Compute(pointers, op, reductionOp, reducingOpDims, reducingStrides);
|
||||
// scale
|
||||
val *= alpha;
|
||||
// combine with previous value in target matrix, then write it out
|
||||
|
|
Загрузка…
Ссылка в новой задаче