diff --git a/Source/Math/GPUTensor.cu b/Source/Math/GPUTensor.cu
index 04a65bd70..8dcc2265d 100644
--- a/Source/Math/GPUTensor.cu
+++ b/Source/Math/GPUTensor.cu
@@ -258,7 +258,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
         // template-recursive version loops over indices
         static __device__ void Compute(CUDA_LONG id, ElemType beta, FixedArray & pointers, ElemType alpha, ElementWiseOperator op,
                                        const FixedArray & regularOpStrides, const FixedMatrix & regularStrides,
-                                       const FixedArray & reducingOpDims, const FixedMatrix & reducingStrides)
+                                       const FixedArray & reducingOpDims, const FixedMatrix & reducingStrides, CUDA_LONG reductionDim)
         {
             // map id (location on grid) to index[k]
             C_size_t stride = regularOpStrides[(C_size_t)k];
@@ -268,7 +268,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
             for (C_size_t i = 0; i < N; i++)
                 pointers[i] += index * regularStrides(i,(C_size_t)k); // now this dimension is taken care of
             // process the previous index
-            TensorOpElement::Compute(id, beta, pointers, alpha, op, regularOpStrides, regularStrides, reducingOpDims, reducingStrides);
+            TensorOpElement::Compute(id, beta, pointers, alpha, op, regularOpStrides, regularStrides, reducingOpDims, reducingStrides, reductionDim);
         }
     };
@@ -279,7 +279,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
         // template-recursive version loops over indices
         static __device__ void Compute(CUDA_LONG id, ElemType beta, FixedArray & pointers, ElemType alpha, ElementWiseOperator op,
                                        const FixedArray & regularOpStrides, const FixedMatrix & regularStrides,
-                                       const FixedArray & reducingOpDims, const FixedMatrix & reducingStrides)
+                                       const FixedArray & reducingOpDims, const FixedMatrix & reducingStrides, CUDA_LONG reductionDim)
         {
             // map id (location on grid) to index[k]
             C_size_t index = id; // this dimension
@@ -287,7 +287,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
             for (C_size_t i = 0; i < N; i++)
                 pointers[i] += index * regularStrides(i,0); // now this dimension is taken care of
             // process the previous index
-            TensorOpElement::Compute(/*id*/0, beta, pointers, alpha, op, regularOpStrides, regularStrides, reducingOpDims, reducingStrides);
+            TensorOpElement::Compute(/*id*/0, beta, pointers, alpha, op, regularOpStrides, regularStrides, reducingOpDims, reducingStrides, reductionDim);
         }
     };
@@ -313,7 +313,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
         // now the output pointers point to the right element (input pointers may still iterate for reduction)
         static __device__ void Compute(CUDA_LONG /*id*/, ElemType beta, FixedArray & pointers, ElemType alpha, ElementWiseOperator op,
                                        const FixedArray & /*regularOpStrides*/, const FixedMatrix & /*regularStrides*/,
-                                       const FixedArray & reducingOpDims, const FixedMatrix & reducingStrides)
+                                       const FixedArray & reducingOpDims, const FixedMatrix & reducingStrides, CUDA_LONG /*reductionDim*/)
         {
             // compute the operation for this output coordinate
             // This may still involve a reduction over inverse-broadcasting dimensions.
@@ -331,29 +331,30 @@ namespace Microsoft { namespace MSR { namespace CNTK {
         // now the output pointers point to the right element (input pointers may still iterate for reduction)
         static __device__ void Compute(CUDA_LONG /*id*/, ElemType beta, FixedArray & pointers, ElemType alpha, ElementWiseOperator op,
                                        const FixedArray & /*regularOpStrides*/, const FixedMatrix & /*regularStrides*/,
-                                       const FixedArray & reducingOpDims, const FixedMatrix & reducingStrides)
+                                       const FixedArray & reducingOpDims, const FixedMatrix & reducingStrides, CUDA_LONG reductionDim)
         {
-            CUDA_LONG reductionDim = blockDim.x;
-            CUDA_LONG redId = threadIdx.x;
+            CUDA_LONG redId = threadIdx.x; // note: blockDim.x might be out of bounds w.r.t. redId in case we split reduction
+            reductionDim = blockDim.x;
             // accumulator
             __shared__ double accumulators[GridDim::maxThreadsPerBlock];
             // compute the operation for this input coordinate
-            accumulators[redId] = TensorOpParallelReduce::Compute(redId, pointers, op, reducingOpDims, reducingStrides);
+            if (redId < reductionDim)
+                accumulators[redId] = TensorOpParallelReduce::Compute(redId, pointers, op, reducingOpDims, reducingStrides);
-            // reduce
+            // reduce --cf https://docs.nvidia.com/cuda/samples/6_Advanced/reduction/doc/reduction.pdf
             __syncthreads();
             static_assert(GridDim::maxThreadsPerBlock <= 512, "GridDim::maxThreadsPerBlock too large, need to add manually unrolled steps");
-            if (redId < 256 && redId + 256 < reductionDim) { accumulators[redId] += accumulators[redId + 256]; __syncthreads(); }
-            if (redId < 128 && redId + 128 < reductionDim) { accumulators[redId] += accumulators[redId + 128]; __syncthreads(); }
-            if (redId < 64 && redId + 64 < reductionDim) { accumulators[redId] += accumulators[redId + 64]; __syncthreads(); }
-            if (redId < 32 && redId + 32 < reductionDim) accumulators[redId] += accumulators[redId + 32];
-            if (redId < 16 && redId + 16 < reductionDim) accumulators[redId] += accumulators[redId + 16];
-            if (redId < 8 && redId + 8 < reductionDim) accumulators[redId] += accumulators[redId + 8];
-            if (redId < 4 && redId + 4 < reductionDim) accumulators[redId] += accumulators[redId + 4];
-            if (redId < 2 && redId + 2 < reductionDim) accumulators[redId] += accumulators[redId + 2];
-            if (redId < 1 && redId + 1 < reductionDim) accumulators[redId] += accumulators[redId + 1];
+            if (redId < 256 && redId + 256 < reductionDim) accumulators[redId] += accumulators[redId + 256]; if (0 + 256 < reductionDim) __syncthreads(); // sync if condition true for at least one thread
+            if (redId < 128 && redId + 128 < reductionDim) accumulators[redId] += accumulators[redId + 128]; if (0 + 128 < reductionDim) __syncthreads();
+            if (redId < 64 && redId + 64 < reductionDim) accumulators[redId] += accumulators[redId + 64]; if (0 + 64 < reductionDim) __syncthreads();
+            if (redId < 32 && redId + 32 < reductionDim) accumulators[redId] += accumulators[redId + 32]; if (0 + 32 < reductionDim) __syncthreads(); // somehow I still need to sync, contradicting the PDF
+            if (redId < 16 && redId + 16 < reductionDim) accumulators[redId] += accumulators[redId + 16]; if (0 + 16 < reductionDim) __syncthreads();
+            if (redId < 8 && redId + 8 < reductionDim) accumulators[redId] += accumulators[redId + 8]; __syncthreads();
+            if (redId < 4 && redId + 4 < reductionDim) accumulators[redId] += accumulators[redId + 4]; __syncthreads();
+            if (redId < 2 && redId + 2 < reductionDim) accumulators[redId] += accumulators[redId + 2]; __syncthreads();
+            if (redId < 1 && redId + 1 < reductionDim) accumulators[redId] += accumulators[redId + 1];
             // now set final value to output coordinate
             if (redId == 0)
@@ -371,21 +372,21 @@ namespace Microsoft { namespace MSR { namespace CNTK {
     // the top-level kernel
     template
     __global__ void _launchTensorOp(ElemType beta, FixedArray pointers, ElemType alpha, ElementWiseOperator op,
-                                    FixedArray regularOpStrides, FixedMatrix regularStrides,
-                                    FixedArray reducingOpDims, FixedMatrix reducingStrides, CUDA_LONG numElements)
+                                    FixedArray regularOpStrides, FixedMatrix regularStrides, CUDA_LONG numElements,
+                                    FixedArray reducingOpDims, FixedMatrix reducingStrides, CUDA_LONG reductionDim)
     {
         CUDA_LONG id = GridDim::GetLinearThreadId();
-        if (id < numElements)
-            TensorOpElement::Compute(id, beta, pointers, alpha, op, regularOpStrides, regularStrides, reducingOpDims, reducingStrides);
+        if (id < numElements) // note: there are no __syncthread() calls inside
+            TensorOpElement::Compute(id, beta, pointers, alpha, op, regularOpStrides, regularStrides, reducingOpDims, reducingStrides, reductionDim);
     }

     template
     __global__ void _launchTensorOpParallelReduction(ElemType beta, FixedArray pointers, ElemType alpha, ElementWiseOperator op,
-                                                     FixedArray regularOpStrides, FixedMatrix regularStrides,
-                                                     FixedArray reducingOpDims, FixedMatrix reducingStrides, CUDA_LONG numElements)
+                                                     FixedArray regularOpStrides, FixedMatrix regularStrides, CUDA_LONG numElements,
+                                                     FixedArray reducingOpDims, FixedMatrix reducingStrides, CUDA_LONG reductionDim)
     {
         CUDA_LONG id = gridDim.y * blockIdx.x + blockIdx.y; // input dimensions are Y dimension of blocks in this case, so we can use thread dim for shared-memory/parallelization
-        if (id < numElements)
-            TensorOpElement::Compute(id, beta, pointers, alpha, op, regularOpStrides, regularStrides, reducingOpDims, reducingStrides);
+        if (id < numElements) // note: we have __syncthread() calls but only entire blocks in sync, so this is OK
+            TensorOpElement::Compute(id, beta, pointers, alpha, op, regularOpStrides, regularStrides, reducingOpDims, reducingStrides, reductionDim);
     }

     // launch tensor op with CUDA
@@ -426,9 +427,9 @@ namespace Microsoft { namespace MSR { namespace CNTK {
         for (C_size_t k = 0; k < reducingOpDimVector.size(); k++)
             reductionDim *= (C_size_t)reducingOpDimVector[k];
         let & props = GridDim::GetDeviceProps();
-        if (reductionDim > 1 && NN < props.multiProcessorCount * props.warpSize)
+        if (reductionDim > 1 && NN < props.multiProcessorCount * props.warpSize && reductionDim > GridDim::maxThreadsPerBlock)
         {
-            printf("%d %d\n", (int)reductionDim, (int)props.multiProcessorCount * props.warpSize && reductionDim > GridDim::maxThreadsPerBlock);
+            fprintf(stderr, "%d %d\n", (int)reductionDim, (int)props.multiProcessorCount * props.warpSize);
         }
         if (reductionDim > 1 && NN < props.multiProcessorCount * props.warpSize && reductionDim <= GridDim::maxThreadsPerBlock)
         {
@@ -439,7 +440,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
             GridDim grid(NN);
             let blocksPerGrid = dim3(grid.m_blocksPerGrid, grid.m_threadsPerBlock); // block Y is element dimension
             let threadsPerBlock = dim3(reductionDim); // X dimension is reduction dimension
-            _launchTensorOpParallelReduction <<<blocksPerGrid, threadsPerBlock>>> (beta, pointers, alpha, op, regularOpStrides, regularStrides, reducingOpDims, reducingStrides, grid.m_N);
+            _launchTensorOpParallelReduction <<<blocksPerGrid, threadsPerBlock>>> (beta, pointers, alpha, op, regularOpStrides, regularStrides, grid.m_N, reducingOpDims, reducingStrides, reductionDim);
         }
         else
         {
@@ -471,7 +472,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
 #endif
         {
             GridDim grid(NN);
-            _launchTensorOp <<<grid.m_blocksPerGrid, grid.m_threadsPerBlock>>> (beta, pointers, alpha, op, regularOpStrides, regularStrides, reducingOpDims, reducingStrides, grid.m_N);
+            _launchTensorOp <<<grid.m_blocksPerGrid, grid.m_threadsPerBlock>>> (beta, pointers, alpha, op, regularOpStrides, regularStrides, grid.m_N, reducingOpDims, reducingStrides, 1);
         }
         if (do_sync) CUDA_CALL(cudaEventRecord(done));
         if (do_sync) CUDA_CALL(cudaEventSynchronize(done));
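For reference, below is a minimal standalone sketch (not part of this patch and not CNTK's TensorOpParallelReduce code) of the block-level shared-memory tree reduction that the __syncthreads() changes above are concerned with. In the textbook form shown here the barrier sits outside the divergent if, so every thread of the block reaches it; the patch instead guards each unrolled step's barrier with a condition (0 + 256 < reductionDim, etc.) that is uniform across the block, which has the same effect. All names in the sketch (blockReduceSum, MAX_THREADS) are illustrative only.

#include <cstdio>
#include <cuda_runtime.h>

static const int MAX_THREADS = 512; // upper bound on blockDim.x, analogous to GridDim::maxThreadsPerBlock

// Single-block sum reduction: each of the blockDim.x threads contributes one value,
// thread 0 ends up with the block's total. blockDim.x must be a power of two <= MAX_THREADS.
__global__ void blockReduceSum(const float* in, float* out, int n)
{
    __shared__ double acc[MAX_THREADS];
    int tid = threadIdx.x;

    // each thread loads one element; threads beyond n contribute 0
    acc[tid] = (tid < n) ? in[tid] : 0.0;
    __syncthreads();

    // tree reduction: halve the number of active threads each step
    for (int stride = blockDim.x / 2; stride > 0; stride >>= 1)
    {
        if (tid < stride)
            acc[tid] += acc[tid + stride];
        __syncthreads(); // barrier is outside the 'if', so all threads of the block reach it
    }

    if (tid == 0)
        *out = (float) acc[0];
}

int main()
{
    const int n = 300;
    float hostIn[n], hostOut = 0;
    for (int i = 0; i < n; i++)
        hostIn[i] = 1.0f;

    float *devIn, *devOut;
    cudaMalloc(&devIn, n * sizeof(float));
    cudaMalloc(&devOut, sizeof(float));
    cudaMemcpy(devIn, hostIn, n * sizeof(float), cudaMemcpyHostToDevice);

    // round the thread count up to a power of two so the halving loop covers all inputs
    int threads = 1;
    while (threads < n)
        threads *= 2;
    blockReduceSum<<<1, threads>>>(devIn, devOut, n);

    cudaMemcpy(&hostOut, devOut, sizeof(float), cudaMemcpyDeviceToHost);
    printf("sum = %f (expected %d)\n", hostOut, n);
    cudaFree(devIn);
    cudaFree(devOut);
    return 0;
}

The guarded-barrier variant used in the patch is only safe because its guard does not depend on threadIdx.x: either every thread of the block executes the __syncthreads() for a given step or none does, so no thread can wait on a barrier that another thread skipped.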