Merge pull request #791 from kangshiyin/trace-mat-mat

2 CUDA kernels for TraceMatMat, with and without transpose, for all matrix sizes.
This commit is contained in:
Daniel Povey 2016-05-18 23:37:47 -04:00
Parent 0d4f1b2483 70df8813d6
Commit fb3c66c0ed
4 changed files: 143 additions and 95 deletions
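For orientation: both kernels compute the trace of a matrix product, trace(A*B) = sum_{i,j} A(i,j)*B(j,i) in the no-transpose case and trace(A*B^T) = sum_{i,j} A(i,j)*B(i,j) in the transposed case. A minimal CPU reference sketch of the two cases (hypothetical names, not part of this commit) might look like:

#include <cstddef>
#include <vector>

// Hypothetical reference for what the GPU kernels compute.
// A is rows x cols (row-major). For trans_B == false, B is cols x rows and
// we compute trace(A * B); for trans_B == true, B is rows x cols and we
// compute trace(A * B^T).
template <typename Real>
Real TraceMatMatRef(const std::vector<Real> &A, const std::vector<Real> &B,
                    std::size_t rows, std::size_t cols, bool trans_B) {
  Real sum = Real(0);
  for (std::size_t i = 0; i < rows; ++i) {
    for (std::size_t j = 0; j < cols; ++j) {
      const Real b = trans_B ? B[i * cols + j]   // B(i, j): trace(A * B^T)
                             : B[j * rows + i];  // B(j, i): trace(A * B)
      sum += A[i * cols + j] * b;
    }
  }
  return sum;
}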

View file

@@ -107,8 +107,8 @@ void cudaF_vec_mul_elements(int Gr, int Bl, float* v, const float* a, int dim);
void cudaF_vec_soft_max(int Gr, int Bl, float* v, int dim);
void cudaF_vec_min(const float* v, float* value, int dim);
void cudaF_vec_max(const float* v, float* value, int dim);
void cudaF_trace_mat_mat_trans(const float* A, const float* B, MatrixDim dA, int B_stride, float* value);
void cudaF_trace_mat_mat(const float* A, const float* B, MatrixDim dA, int B_stride, float* value);
void cudaF_trace_mat_mat_trans(dim3 Gr, dim3 Bl, const float* A, const float* B, MatrixDim dA, int B_stride, float* value);
void cudaF_trace_mat_mat(dim3 Gr, dim3 Bl, const float* A, const float* B, MatrixDim dA, int B_stride, float* value);
void cudaF_add_diag_mat_mat(int Gr, int Bl, float alpha, float* v, int v_dim, const float* M,
int M_cols, int M_row_stride, int M_col_stride, const float *N, int N_row_stride,
int N_col_stride, int threads_per_element, float beta);
@@ -248,8 +248,8 @@ void cudaD_vec_mul_elements(int Gr, int Bl, double* v, const double* a, int dim)
void cudaD_vec_soft_max(int Gr, int Bl, double* v, int dim);
void cudaD_vec_min(const double* v, double* value, int dim);
void cudaD_vec_max(const double* v, double* value, int dim);
void cudaD_trace_mat_mat_trans(const double* A, const double* B, MatrixDim dA, int B_stride, double* value);
void cudaD_trace_mat_mat(const double* A, const double* B, MatrixDim dA, int B_stride, double* value);
void cudaD_trace_mat_mat_trans(dim3 Gr, dim3 Bl, const double* A, const double* B, MatrixDim dA, int B_stride, double* value);
void cudaD_trace_mat_mat(dim3 Gr, dim3 Bl, const double* A, const double* B, MatrixDim dA, int B_stride, double* value);
void cudaD_add_diag_mat_mat(int Gr, int Bl, double alpha, double* v, int v_dim, const double* M,
int M_cols, int M_row_stride, int M_col_stride, const double *N, int N_row_stride,
int N_col_stride, int threads_per_element, double beta);

View file

@@ -834,78 +834,122 @@ static void _vec_max(const Real* v, Real* value, int dim) {
}
// _trace_mat_mat expects to be called with 1 blocks, each of dimension
// CU1DBLOCK. Each block outputs a partial sum to value[blockIdx.x],
// i.e. value[0 through 0].
template<typename Real, int num_blocks>
// _trace_mat_mat reduces its partial sum into value[blockIdx.y * gridDim.x + blockIdx.x].
// It uses shared memory to transpose a tile of matrix B so that global memory accesses are coalesced.
template<int TileDim, typename Real>
__global__
static void _trace_mat_mat(const Real* A, const Real* B, MatrixDim dA, int B_stride, Real* value) {
int32_cuda i = blockIdx.x * blockDim.x + threadIdx.x;
static void _trace_mat_mat(const Real* A, const Real* B, MatrixDim dA,
int B_stride, Real* value) {
// Reuse shared mem and make indexing easier. "+1" to avoid bank conflict
__shared__ union {
Real trans[TileDim][TileDim + 1];
Real sum[CU1DBLOCK];
} smem;
const int32_cuda tid = threadIdx.y * blockDim.x + threadIdx.x; // linear thread id;
const int32_cuda grid_height = gridDim.y * TileDim;
if(blockIdx.x > num_blocks || threadIdx.x > CU1DBLOCK) return;
const int32_cuda ja = blockIdx.x * TileDim + threadIdx.x;
const int32_cuda ib = blockIdx.x * TileDim + threadIdx.y;
int32_cuda ia = blockIdx.y * TileDim + threadIdx.y;
int32_cuda jb = blockIdx.y * TileDim + threadIdx.x;
int num_elements = dA.rows * dA.cols,
num_threads = CU1DBLOCK * num_blocks;
int block_size = (num_elements + num_threads - 1) / num_threads;
int loop_start = i * block_size, loop_end = (i + 1) * block_size;
if (loop_end > num_elements)
loop_end = num_elements;
// Grid reduce
Real tsum = Real(0);
for (int32_cuda i0 = 0; i0 < dA.rows; i0 += grid_height) {
// Load from B, transpose the block and store in shared mem
if (jb < dA.rows) {
# pragma unroll
for (int i = 0; i < TileDim; i += CU1DBLOCK / TileDim) {
if (ib + i < dA.cols) {
smem.trans[threadIdx.x][threadIdx.y + i] =
B[(ib + i) * B_stride + jb];
}
}
}
__syncthreads();
Real sum = 0.0;
for (int j = loop_start; j < loop_end; j++) {
// for (int j = i; j < num_elements; j += num_threads) {
int row = j / dA.cols, col = j % dA.cols; // "row" is row-index in A, "col" is
// col-index in A; in B, it's reversed.
int index_A = col + row * dA.stride,
index_B = row + col * B_stride;
sum += A[index_A] * B[index_B];
// Load from A, sum up the product.
if (ja < dA.cols) {
# pragma unroll
for (int i = 0; i < TileDim; i += CU1DBLOCK / TileDim) {
if (ia + i < dA.rows) {
tsum += A[(ia + i) * dA.stride + ja]
* smem.trans[threadIdx.y + i][threadIdx.x];
}
}
}
__syncthreads();
ia += grid_height;
jb += grid_height;
}
__shared__ Real row_data[CU1DBLOCK];
row_data[threadIdx.x] = sum;
smem.sum[tid] = tsum;
__syncthreads();
Real ans = _sum_reduce(row_data);
if (threadIdx.x == 0)
value[blockIdx.x] = ans;
// Block reduce
# pragma unroll
for (int shift = CU1DBLOCK / 2; shift > warpSize; shift >>= 1) {
if (tid < shift)
smem.sum[tid] += smem.sum[tid + shift];
__syncthreads();
}
// Warp reduce. Implicitly synchronized within a warp.
if (tid < warpSize) {
# pragma unroll
for (int shift = warpSize; shift > 0; shift >>= 1) {
smem.sum[tid] += smem.sum[tid + shift];
}
}
// output 1 sum per thread block
if (tid == 0) {
value[blockIdx.y * gridDim.x + blockIdx.x] = smem.sum[0];
}
}
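The comments above refer to staging a tile of B through shared memory so that both reads and writes are coalesced, padding each tile row by one element to avoid shared-memory bank conflicts. A standalone sketch of that idiom is below; the kernel name, the TILE_DIM macro, and the full 32x32 thread block are illustrative simplifications (the kernel above uses a 32x8 block and loops over the tile), not code from this commit.

#define TILE_DIM 32

// Illustrative only: write the transpose of src (rows x cols, row-major)
// into dst (cols x rows, row-major), staging each 32x32 tile in shared
// memory so that global-memory reads and writes are both coalesced.
__global__ void transpose_tile(const float *src, float *dst,
                               int rows, int cols) {
  // "+1" pads each tile row so that a warp reading a tile column hits
  // distinct shared-memory banks instead of conflicting on one bank.
  __shared__ float tile[TILE_DIM][TILE_DIM + 1];

  int x = blockIdx.x * TILE_DIM + threadIdx.x;  // column in src
  int y = blockIdx.y * TILE_DIM + threadIdx.y;  // row in src
  if (x < cols && y < rows)
    tile[threadIdx.y][threadIdx.x] = src[y * cols + x];
  __syncthreads();

  // Swap block coordinates for the write and read the tile transposed.
  x = blockIdx.y * TILE_DIM + threadIdx.x;      // column in dst
  y = blockIdx.x * TILE_DIM + threadIdx.y;      // row in dst
  if (x < rows && y < cols)
    dst[y * rows + x] = tile[threadIdx.x][threadIdx.y];
}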
// _trace_mat_mat_trans expects to be called with 4 blocks, each of dimension
// CU1DBLOCK. Each block outputs a partial sum to value[blockIdx.x],
// i.e. value[0 through 3].
template<typename Real, int num_blocks>
// _trace_mat_mat_trans reduces its partial sum into value[blockIdx.y * gridDim.x + blockIdx.x].
template<typename Real>
__global__
static void _trace_mat_mat_trans(const Real* A, const Real* B, MatrixDim dA, int B_stride, Real* value) {
int32_cuda i = blockIdx.x * blockDim.x + threadIdx.x;
__shared__ Real ssum[CU1DBLOCK];
const int32_cuda tid = threadIdx.y * blockDim.x + threadIdx.x; // linear thread id;
const int32_cuda j = blockIdx.x * blockDim.x + threadIdx.x;
const int32_cuda grid_height = gridDim.y * blockDim.y;
int32_cuda i = blockIdx.y * blockDim.y + threadIdx.y;
if(blockIdx.x > num_blocks || threadIdx.x > CU1DBLOCK) return;
int num_elements = dA.rows * dA.cols,
num_threads = CU1DBLOCK * num_blocks;
// int block_size = (num_elements + num_threads - 1) / num_threads;
// int loop_start = i * block_size, loop_end = (i + 1) * block_size;
// if (loop_end > num_elements)
// loop_end = num_elements;
Real sum = 0.0;
// for (int j = loop_start; j < loop_end; j++) {
for (int j = i; j < num_elements; j += num_threads) {
int row = j / dA.cols, col = j % dA.cols; // "row" is row-index in A, "col" is
// col-index in A; in B, it's reversed.
int index_A = col + row * dA.stride,
index_B = col + row * B_stride;
sum += A[index_A] * B[index_B];
// Grid reduce
Real tsum = Real(0);
if (j < dA.cols) {
while (i < dA.rows) {
tsum += A[i * dA.stride + j] * B[i * B_stride + j];
i += grid_height;
}
}
__shared__ Real row_data[CU1DBLOCK];
row_data[threadIdx.x] = sum;
ssum[tid] = tsum;
__syncthreads();
Real ans = _sum_reduce(row_data);
if (threadIdx.x == 0)
value[blockIdx.x] = ans;
// Block reduce
# pragma unroll
for (int shift = CU1DBLOCK / 2; shift > warpSize; shift >>= 1) {
if (tid < shift)
ssum[tid] += ssum[tid + shift];
__syncthreads();
}
// Warp reduce. Implicitly synchronized within a warp.
if (tid < warpSize) {
# pragma unroll
for (int shift = warpSize; shift > 0; shift >>= 1) {
ssum[tid] += ssum[tid + shift];
}
}
// output 1 sum per thread block
if (tid == 0) {
value[blockIdx.y * gridDim.x + blockIdx.x] = ssum[0];
}
}
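Both kernels above finish with the same two-stage reduction: a tree reduction in shared memory until one warp's worth of partial sums remains, then an unrolled warp-level reduction that relies on implicit intra-warp synchronization. A standalone sketch of that pattern, assuming a power-of-two block size of at least 64, is below; the kernel name and the volatile pointer in the warp stage are illustrative choices (and on Volta-or-newer GPUs a __syncwarp() between steps is the safer idiom), not code from this commit.

#define BLOCK_SIZE 256  // assumed power-of-two block size, >= 2 * warpSize

// Illustrative only: each block sums its slice of "in" and writes one
// partial result to out[blockIdx.x].
__global__ void block_sum(const float *in, float *out, int n) {
  __shared__ float ssum[BLOCK_SIZE];
  const int tid = threadIdx.x;

  // Grid-stride accumulation into one running sum per thread.
  float tsum = 0.0f;
  for (int i = blockIdx.x * blockDim.x + tid; i < n;
       i += blockDim.x * gridDim.x)
    tsum += in[i];
  ssum[tid] = tsum;
  __syncthreads();

  // Tree reduction in shared memory until one warp of sums remains.
  for (int shift = BLOCK_SIZE / 2; shift > warpSize; shift >>= 1) {
    if (tid < shift)
      ssum[tid] += ssum[tid + shift];
    __syncthreads();
  }

  // Final warp: the classic warp-synchronous idiom; "volatile" stops the
  // compiler from caching shared-memory values in registers.
  if (tid < warpSize) {
    volatile float *vsum = ssum;
    for (int shift = warpSize; shift > 0; shift >>= 1)
      vsum[tid] += vsum[tid + shift];
    if (tid == 0)
      out[blockIdx.x] = vsum[0];  // one partial sum per thread block
  }
}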
@@ -2387,12 +2431,12 @@ void cudaF_vec_max(const float* v, float* value, int dim) {
_vec_max<<<1,CU1DBLOCK>>>(v, value, dim);
}
void cudaF_trace_mat_mat_trans(const float* A, const float* B, MatrixDim dA, int B_stride, float* value) {
_trace_mat_mat_trans<float,4> <<<4,CU1DBLOCK>>>(A,B,dA,B_stride,value);
void cudaF_trace_mat_mat_trans(dim3 Gr, dim3 Bl, const float* A, const float* B, MatrixDim dA, int B_stride, float* value) {
_trace_mat_mat_trans<<<Gr,Bl>>>(A,B,dA,B_stride,value);
}
void cudaF_trace_mat_mat(const float* A, const float* B, MatrixDim dA, int B_stride, float* value) {
_trace_mat_mat<float,2> <<<2,CU1DBLOCK>>>(A,B,dA,B_stride,value);
void cudaF_trace_mat_mat(dim3 Gr, dim3 Bl, const float* A, const float* B, MatrixDim dA, int B_stride, float* value) {
_trace_mat_mat<32><<<Gr,Bl>>>(A,B,dA,B_stride,value);
}
@@ -2851,12 +2895,12 @@ void cudaD_vec_max(const double* v, double* value, int dim) {
_vec_max<<<1,CU1DBLOCK>>>(v, value, dim);
}
void cudaD_trace_mat_mat_trans(const double* A, const double* B, MatrixDim dA, int B_stride, double* value) {
_trace_mat_mat_trans<double,4> <<<4,CU1DBLOCK>>>(A,B,dA,B_stride,value);
void cudaD_trace_mat_mat_trans(dim3 Gr, dim3 Bl, const double* A, const double* B, MatrixDim dA, int B_stride, double* value) {
_trace_mat_mat_trans<<<Gr,Bl>>>(A,B,dA,B_stride,value);
}
void cudaD_trace_mat_mat(const double* A, const double* B, MatrixDim dA, int B_stride, double* value) {
_trace_mat_mat<double,2> <<<2,CU1DBLOCK>>>(A,B,dA,B_stride,value);
void cudaD_trace_mat_mat(dim3 Gr, dim3 Bl, const double* A, const double* B, MatrixDim dA, int B_stride, double* value) {
_trace_mat_mat<32><<<Gr,Bl>>>(A,B,dA,B_stride,value);
}
void cudaD_add_diag_mat_mat(int Gr, int Bl, double alpha, double* v, int v_dim, const double* M,

View file

@@ -192,8 +192,8 @@ inline void cuda_vec_mul_elements(int Gr, int Bl, float* v, const float* a, int
inline void cuda_vec_soft_max(int Gr, int Bl, float* v, int dim) { cudaF_vec_soft_max(Gr,Bl,v,dim); }
inline void cuda_vec_min(const float* v, float* value, int dim) { cudaF_vec_min(v,value,dim); }
inline void cuda_vec_max(const float* v, float* value, int dim) { cudaF_vec_max(v,value,dim); }
inline void cuda_trace_mat_mat_trans(const float* A, const float* B, MatrixDim dA, int B_stride, float* value) { cudaF_trace_mat_mat_trans(A,B,dA,B_stride,value); }
inline void cuda_trace_mat_mat(const float* A, const float* B, MatrixDim dA, int B_stride, float* value) { cudaF_trace_mat_mat(A,B,dA,B_stride,value); }
inline void cuda_trace_mat_mat_trans(dim3 Gr, dim3 Bl, const float* A, const float* B, MatrixDim dA, int B_stride, float* value) { cudaF_trace_mat_mat_trans(Gr,Bl,A,B,dA,B_stride,value); }
inline void cuda_trace_mat_mat(dim3 Gr, dim3 Bl, const float* A, const float* B, MatrixDim dA, int B_stride, float* value) { cudaF_trace_mat_mat(Gr,Bl,A,B,dA,B_stride,value); }
inline void cuda_add_diag_mat_mat(int Gr, int Bl, float alpha, float* v, int v_dim, const float* M,
int M_cols, int M_row_stride, int M_col_stride, const float *N, int N_row_stride,
int N_col_stride, int threads_per_element, float beta) {
@@ -377,8 +377,8 @@ inline void cuda_vec_mul_elements(int Gr, int Bl, double* v, const double* a, in
inline void cuda_vec_soft_max(int Gr, int Bl, double* v, int dim) { cudaD_vec_soft_max(Gr,Bl,v,dim); }
inline void cuda_vec_min(const double* v, double* value, int dim) { cudaD_vec_min(v,value,dim); }
inline void cuda_vec_max(const double* v, double* value, int dim) { cudaD_vec_max(v,value,dim); }
inline void cuda_trace_mat_mat_trans(const double* A, const double* B, MatrixDim dA, int B_stride, double* value) { cudaD_trace_mat_mat_trans(A,B,dA,B_stride,value); }
inline void cuda_trace_mat_mat(const double* A, const double* B, MatrixDim dA, int B_stride, double* value) { cudaD_trace_mat_mat(A,B,dA,B_stride,value); }
inline void cuda_trace_mat_mat_trans(dim3 Gr, dim3 Bl, const double* A, const double* B, MatrixDim dA, int B_stride, double* value) { cudaD_trace_mat_mat_trans(Gr,Bl,A,B,dA,B_stride,value); }
inline void cuda_trace_mat_mat(dim3 Gr, dim3 Bl, const double* A, const double* B, MatrixDim dA, int B_stride, double* value) { cudaD_trace_mat_mat(Gr,Bl,A,B,dA,B_stride,value); }
inline void cuda_add_diag_mat_mat(int Gr, int Bl, double alpha, double* v, int v_dim, const double* M,
int M_cols, int M_row_stride, int M_col_stride, const double *N, int N_row_stride,
int N_col_stride, int threads_per_element, double beta) {

View file

@@ -1726,31 +1726,35 @@ Real TraceMatMat(const CuMatrixBase<Real> &A,
} else {
KALDI_ASSERT(A.NumRows() == B.NumRows() && A.NumCols() == B.NumCols());
}
if (A.NumRows() * A.NumCols() > 16384) {
// This version in which we don't use a special-purpose kernel, but
// do AddDiagMat on a temporary vector and returns its sum, seems to be
// faster for larger matrices. The cutoff is approximate and
// we only looked at the time on square matrices, which
// is what we test in cu-matrix-speed-test.cc.
CuVector<Real> sum_vec(A.NumRows());
sum_vec.AddDiagMatMat(1.0, A, kNoTrans,
B, trans, 0.0);
return sum_vec.Sum();
} else {
Timer tim;
// the sizes of result_vec must match what we
// call the kernels with, in cu-kernels.cu
CuVector<Real> result_vec(trans == kTrans ? 4 : 2, kUndefined);
if (trans == kNoTrans) {
cuda_trace_mat_mat(A.Data(), B.Data(), A.Dim(), B.Stride(), result_vec.Data());
} else {
cuda_trace_mat_mat_trans(A.Data(), B.Data(), A.Dim(), B.Stride(), result_vec.Data());
Timer tim;
// 2D blocks: each (8x32) block sums up a (32x32) tile of elements.
// 2D grid: try to cover all of matrix A unless it is too big.
// The kernel reduces the result to at most ~256 partial sums with good
// performance, as long as the matrix is not badly shaped
// (wider or taller than 32x8192).
// The CPU then reduces those partial sums to a single element.
const int kWarpSize = 32;
dim3 dimBlock(kWarpSize, CU1DBLOCK / kWarpSize);
dim3 dimGrid(n_blocks(A.NumCols(), kWarpSize),
n_blocks(A.NumRows(), kWarpSize));
if (dimGrid.x * dimGrid.y > 256) {
dimGrid.y = 256 / dimGrid.x;
if (dimGrid.y == 0) {
dimGrid.y = 1;
}
CU_SAFE_CALL(cudaGetLastError());
Vector<Real> result_cpu(result_vec); // copying from CUDA faster than summing in CUDA.
result = result_cpu.Sum();
CuDevice::Instantiate().AccuProfile(__func__, tim.Elapsed());
}
CuVector<Real> result_vec(dimGrid.x * dimGrid.y, kUndefined);
if (trans == kNoTrans) {
cuda_trace_mat_mat(dimGrid, dimBlock, A.Data(), B.Data(), A.Dim(),
B.Stride(), result_vec.Data());
} else {
cuda_trace_mat_mat_trans(dimGrid, dimBlock, A.Data(), B.Data(), A.Dim(),
B.Stride(), result_vec.Data());
}
CU_SAFE_CALL(cudaGetLastError());
Vector<Real> result_cpu(result_vec); // copying from CUDA is faster than summing in CUDA.
result = result_cpu.Sum();
CuDevice::Instantiate().AccuProfile(__func__, tim.Elapsed());
} else
#endif
{
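As a concrete example of the grid sizing above: for a 1000 x 2000 matrix A, covering A with 32x32 tiles needs a 63 x 32 grid, which exceeds the 256-block cap, so dimGrid.y is clamped to 256 / 63 = 4 and each block strides over the rows in bands of 4 * 32 = 128. A host-side sketch of that calculation follows; n_blocks is assumed to be Kaldi's round-up division helper, and the matrix shape is purely an example.

#include <cstdio>

// Assumed round-up division, as in Kaldi's n_blocks() helper.
static int n_blocks(int size, int block_size) {
  return (size + block_size - 1) / block_size;
}

int main() {
  const int kWarpSize = 32, kBlockThreads = 256;  // CU1DBLOCK
  const int rows = 1000, cols = 2000;             // example shape of A
  int block_x = kWarpSize;                        // 32 threads wide
  int block_y = kBlockThreads / kWarpSize;        // 8 threads tall
  int grid_x = n_blocks(cols, kWarpSize);         // 63 tiles across
  int grid_y = n_blocks(rows, kWarpSize);         // 32 tiles down
  if (grid_x * grid_y > 256) {                    // cap the partial sums
    grid_y = 256 / grid_x;                        // 256 / 63 = 4
    if (grid_y == 0) grid_y = 1;
  }
  // 63 * 4 = 252 partial sums; each block sweeps ceil(1000 / (4 * 32)) = 8
  // bands of rows before its final block-level reduction.
  std::printf("block %dx%d, grid %dx%d, %d partial sums\n",
              block_x, block_y, grid_x, grid_y, grid_x * grid_y);
  return 0;
}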