TensorOpElement::Compute() now uses tree-based reduction
This commit is contained in:
Родитель
24900d073f
Коммит
e37a053842
|
@ -343,15 +343,18 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
accumulators[redId] = TensorOpParallelReduce<ElemType, N, M, M - 1>::Compute(redId, pointers, op, reducingOpDims, reducingStrides);
|
||||
|
||||
// reduce
|
||||
__syncthreads(); // watch out to do this outside 'if (inRange)'
|
||||
if (redId == 0)
|
||||
{
|
||||
double sum = 0;
|
||||
for (int i = 0; i < reductionDim; i++)
|
||||
sum += accumulators[i];
|
||||
accumulators[0] = sum;
|
||||
}
|
||||
__syncthreads();
|
||||
static_assert(GridDim::maxThreadsPerBlock <= 512, "GridDim::maxThreadsPerBlock too large, need to add manually unrolled steps");
|
||||
if (redId < 256 && redId + 256 < reductionDim) { accumulators[redId] += accumulators[redId + 256]; __syncthreads(); }
|
||||
if (redId < 128 && redId + 128 < reductionDim) { accumulators[redId] += accumulators[redId + 128]; __syncthreads(); }
|
||||
if (redId < 64 && redId + 64 < reductionDim) { accumulators[redId] += accumulators[redId + 64]; __syncthreads(); }
|
||||
if (redId < 32 && redId + 32 < reductionDim) accumulators[redId] += accumulators[redId + 32];
|
||||
if (redId < 16 && redId + 16 < reductionDim) accumulators[redId] += accumulators[redId + 16];
|
||||
if (redId < 8 && redId + 8 < reductionDim) accumulators[redId] += accumulators[redId + 8];
|
||||
if (redId < 4 && redId + 4 < reductionDim) accumulators[redId] += accumulators[redId + 4];
|
||||
if (redId < 2 && redId + 2 < reductionDim) accumulators[redId] += accumulators[redId + 2];
|
||||
if (redId < 1 && redId + 1 < reductionDim) accumulators[redId] += accumulators[redId + 1];
|
||||
|
||||
// now set final value to output coordinate
|
||||
if (redId == 0)
|
||||
{
|
||||
|
@ -423,6 +426,10 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
for (C_size_t k = 0; k < reducingOpDimVector.size(); k++)
|
||||
reductionDim *= (C_size_t)reducingOpDimVector[k];
|
||||
let & props = GridDim::GetDeviceProps();
|
||||
if (reductionDim > 1 && NN < props.multiProcessorCount * props.warpSize)
|
||||
{
|
||||
printf("%d %d\n", (int)reductionDim, (int)props.multiProcessorCount * props.warpSize && reductionDim > GridDim::maxThreadsPerBlock);
|
||||
}
|
||||
if (reductionDim > 1 && NN < props.multiProcessorCount * props.warpSize && reductionDim <= GridDim::maxThreadsPerBlock)
|
||||
{
|
||||
if (reductionDim <= GridDim::maxThreadsPerBlock)
|
||||
|
|
Загрузка…
Ссылка в новой задаче