TensorOpElement::Compute() now uses tree-based reduction

This commit is contained in:
Frank Seide 2015-12-28 20:53:04 -08:00
Родитель 24900d073f
Коммит e37a053842
1 изменённый файл: 15 добавлений и 8 удалений

Просмотреть файл

@@ -343,15 +343,18 @@ namespace Microsoft { namespace MSR { namespace CNTK {
accumulators[redId] = TensorOpParallelReduce<ElemType, N, M, M - 1>::Compute(redId, pointers, op, reducingOpDims, reducingStrides);
// reduce
__syncthreads(); // watch out to do this outside 'if (inRange)'
if (redId == 0)
{
double sum = 0;
for (int i = 0; i < reductionDim; i++)
sum += accumulators[i];
accumulators[0] = sum;
}
__syncthreads();
static_assert(GridDim::maxThreadsPerBlock <= 512, "GridDim::maxThreadsPerBlock too large, need to add manually unrolled steps");
if (redId < 256 && redId + 256 < reductionDim) { accumulators[redId] += accumulators[redId + 256]; __syncthreads(); }
if (redId < 128 && redId + 128 < reductionDim) { accumulators[redId] += accumulators[redId + 128]; __syncthreads(); }
if (redId < 64 && redId + 64 < reductionDim) { accumulators[redId] += accumulators[redId + 64]; __syncthreads(); }
if (redId < 32 && redId + 32 < reductionDim) accumulators[redId] += accumulators[redId + 32];
if (redId < 16 && redId + 16 < reductionDim) accumulators[redId] += accumulators[redId + 16];
if (redId < 8 && redId + 8 < reductionDim) accumulators[redId] += accumulators[redId + 8];
if (redId < 4 && redId + 4 < reductionDim) accumulators[redId] += accumulators[redId + 4];
if (redId < 2 && redId + 2 < reductionDim) accumulators[redId] += accumulators[redId + 2];
if (redId < 1 && redId + 1 < reductionDim) accumulators[redId] += accumulators[redId + 1];
// now set final value to output coordinate
if (redId == 0)
{
@@ -423,6 +426,10 @@ namespace Microsoft { namespace MSR { namespace CNTK {
for (C_size_t k = 0; k < reducingOpDimVector.size(); k++)
reductionDim *= (C_size_t)reducingOpDimVector[k];
let & props = GridDim::GetDeviceProps();
if (reductionDim > 1 && NN < props.multiProcessorCount * props.warpSize)
{
printf("%d %d\n", (int)reductionDim, (int)props.multiProcessorCount * props.warpSize && reductionDim > GridDim::maxThreadsPerBlock);
}
if (reductionDim > 1 && NN < props.multiProcessorCount * props.warpSize && reductionDim <= GridDim::maxThreadsPerBlock)
{
if (reductionDim <= GridDim::maxThreadsPerBlock)