Frank Seide 2015-12-28 21:35:57 -08:00
Parent e37a053842
Commit c3cf76fc14
1 changed file with 32 additions and 31 deletions


@@ -258,7 +258,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
// template-recursive version loops over indices
static __device__ void Compute(CUDA_LONG id, ElemType beta, FixedArray<ElemType*, N> & pointers, ElemType alpha, ElementWiseOperator op,
const FixedArray<C_unsigned_int, K> & regularOpStrides, const FixedMatrix<C_int, N, K> & regularStrides,
const FixedArray<C_unsigned_int, M> & reducingOpDims, const FixedMatrix<C_int, N, M> & reducingStrides)
const FixedArray<C_unsigned_int, M> & reducingOpDims, const FixedMatrix<C_int, N, M> & reducingStrides, CUDA_LONG reductionDim)
{
// map id (location on grid) to index[k]
C_size_t stride = regularOpStrides[(C_size_t)k];
@@ -268,7 +268,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
for (C_size_t i = 0; i < N; i++)
pointers[i] += index * regularStrides(i,(C_size_t)k); // now this dimension is taken care of
// process the previous index
TensorOpElement<ElemType, N, M, K, parallelReduce, k - 1>::Compute(id, beta, pointers, alpha, op, regularOpStrides, regularStrides, reducingOpDims, reducingStrides);
TensorOpElement<ElemType, N, M, K, parallelReduce, k - 1>::Compute(id, beta, pointers, alpha, op, regularOpStrides, regularStrides, reducingOpDims, reducingStrides, reductionDim);
}
};
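The recursion above peels one regular (output) dimension off the linear thread id per template level: divide by that dimension's op stride to get the index, keep the remainder for the next level, and advance each tensor's pointer by index times that tensor's own stride. A minimal host-side sketch of the same arithmetic, with made-up shapes and strides (not the CNTK code itself):

// hypothetical illustration of the index decomposition done by TensorOpElement
#include <cstdio>

int main()
{
    // assume a 3 x 4 output shape in column-major layout; op strides are then { 1, 3 }
    const int K = 2;
    const int opStrides[K]    = { 1, 3 };    // stride of each regular (non-reduced) dimension
    const int tensorStride[K] = { 1, 3 };    // strides of one dense input tensor

    int id = 7;                              // linear element id, as GetLinearThreadId() would yield
    long offset = 0;                         // pointer offset accumulated for this tensor
    for (int k = K - 1; k >= 0; k--)         // highest dimension first, as in the template recursion
    {
        int index = id / opStrides[k];       // index along dimension k
        id        = id % opStrides[k];       // remainder is handled by the next level
        offset   += index * tensorStride[k]; // advance this tensor's pointer
    }
    printf("element offset = %ld\n", offset); // prints 7 for a dense tensor, as expected
    return 0;
}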
@@ -279,7 +279,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
// template-recursive version loops over indices
static __device__ void Compute(CUDA_LONG id, ElemType beta, FixedArray<ElemType*, N> & pointers, ElemType alpha, ElementWiseOperator op,
const FixedArray<C_unsigned_int, K> & regularOpStrides, const FixedMatrix<C_int, N, K> & regularStrides,
const FixedArray<C_unsigned_int, M> & reducingOpDims, const FixedMatrix<C_int, N, M> & reducingStrides)
const FixedArray<C_unsigned_int, M> & reducingOpDims, const FixedMatrix<C_int, N, M> & reducingStrides, CUDA_LONG reductionDim)
{
// map id (location on grid) to index[k]
C_size_t index = id; // this dimension
@@ -287,7 +287,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
for (C_size_t i = 0; i < N; i++)
pointers[i] += index * regularStrides(i,0); // now this dimension is taken care of
// process the previous index
TensorOpElement<ElemType, N, M, K, parallelReduce, -1>::Compute(/*id*/0, beta, pointers, alpha, op, regularOpStrides, regularStrides, reducingOpDims, reducingStrides);
TensorOpElement<ElemType, N, M, K, parallelReduce, -1>::Compute(/*id*/0, beta, pointers, alpha, op, regularOpStrides, regularStrides, reducingOpDims, reducingStrides, reductionDim);
}
};
@@ -313,7 +313,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
// now the output pointers point to the right element (input pointers may still iterate for reduction)
static __device__ void Compute(CUDA_LONG /*id*/, ElemType beta, FixedArray<ElemType*, N> & pointers, ElemType alpha, ElementWiseOperator op,
const FixedArray<C_unsigned_int, K> & /*regularOpStrides*/, const FixedMatrix<C_int, N, K> & /*regularStrides*/,
const FixedArray<C_unsigned_int, M> & reducingOpDims, const FixedMatrix<C_int, N, M> & reducingStrides)
const FixedArray<C_unsigned_int, M> & reducingOpDims, const FixedMatrix<C_int, N, M> & reducingStrides, CUDA_LONG /*reductionDim*/)
{
// compute the operation for this output coordinate
// This may still involve a reduction over inverse-broadcasting dimensions.
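In this non-parallel case, the reduction over the inverse-broadcasting dimensions is carried out serially by the single thread that owns the output element. A rough standalone sketch of that serial walk, with hypothetical data and strides (the real code recurses over all M reducing dimensions and applies the chosen element-wise op):

// hypothetical serial reduction over one strided reducing dimension
#include <cstdio>

int main()
{
    const int reducingDim    = 4;   // extent of the inverse-broadcasting dimension
    const int reducingStride = 3;   // distance between consecutive reduced elements
    const double in[12] = { 1, 0, 0, 2, 0, 0, 3, 0, 0, 4, 0, 0 };

    const double* p = in;           // pointer already positioned on the output coordinate
    double sum = 0;                 // local accumulator
    for (int r = 0; r < reducingDim; r++)
    {
        sum += *p;                  // the element-wise op is a plain read here
        p   += reducingStride;      // step along the reduced dimension
    }
    printf("reduced value = %g\n", sum); // 1 + 2 + 3 + 4 = 10
    return 0;
}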
@@ -331,29 +331,30 @@ namespace Microsoft { namespace MSR { namespace CNTK {
// now the output pointers point to the right element (input pointers may still iterate for reduction)
static __device__ void Compute(CUDA_LONG /*id*/, ElemType beta, FixedArray<ElemType*, N> & pointers, ElemType alpha, ElementWiseOperator op,
const FixedArray<C_unsigned_int, K> & /*regularOpStrides*/, const FixedMatrix<C_int, N, K> & /*regularStrides*/,
const FixedArray<C_unsigned_int, M> & reducingOpDims, const FixedMatrix<C_int, N, M> & reducingStrides)
const FixedArray<C_unsigned_int, M> & reducingOpDims, const FixedMatrix<C_int, N, M> & reducingStrides, CUDA_LONG reductionDim)
{
CUDA_LONG reductionDim = blockDim.x;
CUDA_LONG redId = threadIdx.x;
CUDA_LONG redId = threadIdx.x; // note: blockDim.x might be out of bounds w.r.t. redId in case we split reduction
reductionDim = blockDim.x;
// accumulator
__shared__ double accumulators[GridDim::maxThreadsPerBlock];
// compute the operation for this input coordinate
accumulators[redId] = TensorOpParallelReduce<ElemType, N, M, M - 1>::Compute(redId, pointers, op, reducingOpDims, reducingStrides);
if (redId < reductionDim)
accumulators[redId] = TensorOpParallelReduce<ElemType, N, M, M - 1>::Compute(redId, pointers, op, reducingOpDims, reducingStrides);
// reduce
// reduce --cf https://docs.nvidia.com/cuda/samples/6_Advanced/reduction/doc/reduction.pdf
__syncthreads();
static_assert(GridDim::maxThreadsPerBlock <= 512, "GridDim::maxThreadsPerBlock too large, need to add manually unrolled steps");
if (redId < 256 && redId + 256 < reductionDim) { accumulators[redId] += accumulators[redId + 256]; __syncthreads(); }
if (redId < 128 && redId + 128 < reductionDim) { accumulators[redId] += accumulators[redId + 128]; __syncthreads(); }
if (redId < 64 && redId + 64 < reductionDim) { accumulators[redId] += accumulators[redId + 64]; __syncthreads(); }
if (redId < 32 && redId + 32 < reductionDim) accumulators[redId] += accumulators[redId + 32];
if (redId < 16 && redId + 16 < reductionDim) accumulators[redId] += accumulators[redId + 16];
if (redId < 8 && redId + 8 < reductionDim) accumulators[redId] += accumulators[redId + 8];
if (redId < 4 && redId + 4 < reductionDim) accumulators[redId] += accumulators[redId + 4];
if (redId < 2 && redId + 2 < reductionDim) accumulators[redId] += accumulators[redId + 2];
if (redId < 1 && redId + 1 < reductionDim) accumulators[redId] += accumulators[redId + 1];
if (redId < 256 && redId + 256 < reductionDim) accumulators[redId] += accumulators[redId + 256]; if (0 + 256 < reductionDim) __syncthreads(); // sync if condition true for at least one thread
if (redId < 128 && redId + 128 < reductionDim) accumulators[redId] += accumulators[redId + 128]; if (0 + 128 < reductionDim) __syncthreads();
if (redId < 64 && redId + 64 < reductionDim) accumulators[redId] += accumulators[redId + 64]; if (0 + 64 < reductionDim) __syncthreads();
if (redId < 32 && redId + 32 < reductionDim) accumulators[redId] += accumulators[redId + 32]; if (0 + 32 < reductionDim) __syncthreads(); // somehow I still need to sync, contradicting the PDF
if (redId < 16 && redId + 16 < reductionDim) accumulators[redId] += accumulators[redId + 16]; if (0 + 16 < reductionDim) __syncthreads();
if (redId < 8 && redId + 8 < reductionDim) accumulators[redId] += accumulators[redId + 8]; __syncthreads();
if (redId < 4 && redId + 4 < reductionDim) accumulators[redId] += accumulators[redId + 4]; __syncthreads();
if (redId < 2 && redId + 2 < reductionDim) accumulators[redId] += accumulators[redId + 2]; __syncthreads();
if (redId < 1 && redId + 1 < reductionDim) accumulators[redId] += accumulators[redId + 1];
// now set final value to output coordinate
if (redId == 0)
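The rewritten cascade moves __syncthreads() out of the per-thread guard: a barrier inside a divergent branch is undefined behavior in CUDA, so each step now synchronizes under a condition that is uniform across the block (0 + 256 < reductionDim evaluates identically for every thread). Below is a minimal self-contained kernel, not the CNTK code and written as a loop rather than manually unrolled, showing the same shared-memory tree reduction with uniform barriers:

// hypothetical block-level tree reduction with barriers kept outside divergent branches
#include <cuda_runtime.h>
#include <cstdio>

__global__ void blockReduceSketch(const float* in, float* out, int n)
{
    __shared__ float acc[512];              // assumes blockDim.x <= 512, like maxThreadsPerBlock here
    int tid = threadIdx.x;
    acc[tid] = (tid < n) ? in[tid] : 0.0f;  // threads beyond n contribute a neutral 0
    __syncthreads();
    for (int step = blockDim.x / 2; step > 0; step /= 2)
    {
        if (tid < step)                     // divergent: only the lower half adds
            acc[tid] += acc[tid + step];
        __syncthreads();                    // uniform: every thread of the block reaches this
    }
    if (tid == 0)
        *out = acc[0];                      // thread 0 writes the block's sum
}

int main()
{
    const int n = 300;
    float h[n];
    for (int i = 0; i < n; i++) h[i] = 1.0f;
    float *dIn, *dOut, result = 0;
    cudaMalloc(&dIn, n * sizeof(float));
    cudaMalloc(&dOut, sizeof(float));
    cudaMemcpy(dIn, h, n * sizeof(float), cudaMemcpyHostToDevice);
    blockReduceSketch<<<1, 512>>>(dIn, dOut, n);
    cudaMemcpy(&result, dOut, sizeof(float), cudaMemcpyDeviceToHost);
    printf("sum = %g\n", result);           // expect 300
    cudaFree(dIn); cudaFree(dOut);
    return 0;
}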
@@ -371,21 +372,21 @@ namespace Microsoft { namespace MSR { namespace CNTK {
// the top-level kernel
template<class ElemType, C_size_t N, C_int M, C_int K>
__global__ void _launchTensorOp(ElemType beta, FixedArray<ElemType*, N> pointers, ElemType alpha, ElementWiseOperator op,
FixedArray<C_unsigned_int, K> regularOpStrides, FixedMatrix<C_int, N, K> regularStrides,
FixedArray<C_unsigned_int, M> reducingOpDims, FixedMatrix<C_int, N, M> reducingStrides, CUDA_LONG numElements)
FixedArray<C_unsigned_int, K> regularOpStrides, FixedMatrix<C_int, N, K> regularStrides, CUDA_LONG numElements,
FixedArray<C_unsigned_int, M> reducingOpDims, FixedMatrix<C_int, N, M> reducingStrides, CUDA_LONG reductionDim)
{
CUDA_LONG id = GridDim::GetLinearThreadId();
if (id < numElements)
TensorOpElement<ElemType, N, M, K, false, K - 1>::Compute(id, beta, pointers, alpha, op, regularOpStrides, regularStrides, reducingOpDims, reducingStrides);
if (id < numElements) // note: there are no __syncthread() calls inside
TensorOpElement<ElemType, N, M, K, false, K - 1>::Compute(id, beta, pointers, alpha, op, regularOpStrides, regularStrides, reducingOpDims, reducingStrides, reductionDim);
}
template<class ElemType, C_size_t N, C_int M, C_int K>
__global__ void _launchTensorOpParallelReduction(ElemType beta, FixedArray<ElemType*, N> pointers, ElemType alpha, ElementWiseOperator op,
FixedArray<C_unsigned_int, K> regularOpStrides, FixedMatrix<C_int, N, K> regularStrides,
FixedArray<C_unsigned_int, M> reducingOpDims, FixedMatrix<C_int, N, M> reducingStrides, CUDA_LONG numElements)
FixedArray<C_unsigned_int, K> regularOpStrides, FixedMatrix<C_int, N, K> regularStrides, CUDA_LONG numElements,
FixedArray<C_unsigned_int, M> reducingOpDims, FixedMatrix<C_int, N, M> reducingStrides, CUDA_LONG reductionDim)
{
CUDA_LONG id = gridDim.y * blockIdx.x + blockIdx.y; // input dimensions are Y dimension of blocks in this case, so we can use thread dim for shared-memory/parallelization
if (id < numElements)
TensorOpElement<ElemType, N, M, K, true, K - 1>::Compute(id, beta, pointers, alpha, op, regularOpStrides, regularStrides, reducingOpDims, reducingStrides);
if (id < numElements) // note: we have __syncthread() calls but only entire blocks in sync, so this is OK
TensorOpElement<ElemType, N, M, K, true, K - 1>::Compute(id, beta, pointers, alpha, op, regularOpStrides, regularStrides, reducingOpDims, reducingStrides, reductionDim);
}
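In this kernel, blocks rather than threads enumerate the output elements: the X and Y block indices are folded into one linear id, and threadIdx.x is reserved for the reduction dimension. A small hypothetical host-side check that id = gridDim.y * blockIdx.x + blockIdx.y visits every output element exactly once under the guard id < numElements:

// hypothetical check of the block-to-element mapping used above
#include <cstdio>
#include <vector>

int main()
{
    const int gridX = 3, gridY = 4;         // stand-ins for gridDim.x and gridDim.y
    const int numElements = 10;             // e.g. grid.m_N; the last two ids fall out of range
    std::vector<int> hits(numElements, 0);
    for (int bx = 0; bx < gridX; bx++)      // simulate blockIdx.x
        for (int by = 0; by < gridY; by++)  // simulate blockIdx.y
        {
            int id = gridY * bx + by;       // same formula as in the kernel
            if (id < numElements)           // same guard as in the kernel
                hits[id]++;
        }
    for (int i = 0; i < numElements; i++)
        printf("element %d visited %d time(s)\n", i, hits[i]); // each exactly once
    return 0;
}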
// launch tensor op with CUDA
@@ -426,9 +427,9 @@ namespace Microsoft { namespace MSR { namespace CNTK {
for (C_size_t k = 0; k < reducingOpDimVector.size(); k++)
reductionDim *= (C_size_t)reducingOpDimVector[k];
let & props = GridDim::GetDeviceProps();
if (reductionDim > 1 && NN < props.multiProcessorCount * props.warpSize)
if (reductionDim > 1 && NN < props.multiProcessorCount * props.warpSize && reductionDim > GridDim::maxThreadsPerBlock)
{
printf("%d %d\n", (int)reductionDim, (int)props.multiProcessorCount * props.warpSize && reductionDim > GridDim::maxThreadsPerBlock);
fprintf(stderr, "%d %d\n", (int)reductionDim, (int)props.multiProcessorCount * props.warpSize);
}
if (reductionDim > 1 && NN < props.multiProcessorCount * props.warpSize && reductionDim <= GridDim::maxThreadsPerBlock)
{
@@ -439,7 +440,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
GridDim grid(NN);
let blocksPerGrid = dim3(grid.m_blocksPerGrid, grid.m_threadsPerBlock); // block Y is element dimension
let threadsPerBlock = dim3(reductionDim); // X dimension is reduction dimension
_launchTensorOpParallelReduction<ElemType, N, M, K> << <blocksPerGrid, threadsPerBlock, reductionDim * sizeof(double), t_stream >> >(beta, pointers, alpha, op, regularOpStrides, regularStrides, reducingOpDims, reducingStrides, grid.m_N);
_launchTensorOpParallelReduction<ElemType, N, M, K> << <blocksPerGrid, threadsPerBlock, reductionDim * sizeof(double), t_stream >> >(beta, pointers, alpha, op, regularOpStrides, regularStrides, grid.m_N, reducingOpDims, reducingStrides, reductionDim);
}
else
{
@@ -471,7 +472,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
#endif
{
GridDim grid(NN);
_launchTensorOp<ElemType, N, M, K> << <grid.m_blocksPerGrid, grid.m_threadsPerBlock, 0, t_stream >> >(beta, pointers, alpha, op, regularOpStrides, regularStrides, reducingOpDims, reducingStrides, grid.m_N);
_launchTensorOp<ElemType, N, M, K> << <grid.m_blocksPerGrid, grid.m_threadsPerBlock, 0, t_stream >> >(beta, pointers, alpha, op, regularOpStrides, regularStrides, grid.m_N, reducingOpDims, reducingStrides, 1);
}
if (do_sync) CUDA_CALL(cudaEventRecord(done));
if (do_sync) CUDA_CALL(cudaEventSynchronize(done));
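For reference, the launch path chosen above depends on the reduction size and the output element count: the parallel-reduction kernel is used only when the reduction is non-trivial, there are too few output elements to keep the multiprocessors busy element-wise, and the whole reduction fits into a single block; otherwise the plain element-wise kernel is launched. A sketch of that predicate as it reads in this diff (device properties below are hypothetical):

// sketch of the launch-path decision; thresholds mirror the ones visible in this diff
#include <cstdio>

struct DeviceProps { int multiProcessorCount; int warpSize; };

// true when the parallel-reduction kernel would be chosen
bool useParallelReduction(long numOutputElements, long reductionDim,
                          const DeviceProps& props, int maxThreadsPerBlock)
{
    return reductionDim > 1                                                     // something to reduce
        && numOutputElements < (long)props.multiProcessorCount * props.warpSize // few output elements
        && reductionDim <= maxThreadsPerBlock;                                  // reduction fits in one block
}

int main()
{
    DeviceProps props{ 24, 32 };  // hypothetical device: 24 SMs, warp size 32
    printf("%d\n", useParallelReduction(100,    256,  props, 512)); // 1: parallel reduction
    printf("%d\n", useParallelReduction(100000, 256,  props, 512)); // 0: enough elements, element-wise kernel
    printf("%d\n", useParallelReduction(100,    4096, props, 512)); // 0: reduction too large for one block
    return 0;
}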