Merge pull request #780 from kangshiyin/faster-FindRowMaxId

A new CUDA kernel for CuMatrixBase<Real>::FindRowMaxId;
This commit is contained in:
Daniel Povey 2016-05-18 14:42:23 -04:00
Родитель 609653117b 24b886a243
Коммит 26ef88fed9
4 изменённых файлов: 66 добавлений и 60 удалений

Просмотреть файл

@ -148,7 +148,7 @@ void cudaF_tanh(dim3 Gr, dim3 Bl, float *y, const float *x, MatrixDim d, int src
void cudaF_diff_tanh(dim3 Gr, dim3 Bl, float *eout, const float *e, const float *y, MatrixDim d, int e_stride, int y_stride);
void cudaF_regularize_l1(dim3 Gr, dim3 Bl, float *wei, float *grad, float l1, float lr, MatrixDim d, int stride_grad);
void cudaF_find_row_max_id(dim3 Gr, dim3 Bl, const float *mat, float *vec_val, int32_cuda *vec_id, int32_cuda voff, MatrixDim d);
void cudaF_find_row_max_id(dim3 Gr, dim3 Bl, const float *mat, float *vec_val, int32_cuda *vec_id, MatrixDim d);
void cudaF_diff_xent(dim3 Gr, dim3 Bl, const int32_cuda *vec_tgt, float *mat_net_out, float *vec_log_post, MatrixDim d);
void cudaF_copy_rows_from_vec(dim3 Gr, dim3 Bl, float *mat_out, MatrixDim d_out, const float *v_in);
@ -291,7 +291,7 @@ void cudaD_tanh(dim3 Gr, dim3 Bl, double *y, const double *x, MatrixDim d, int s
void cudaD_diff_tanh(dim3 Gr, dim3 Bl, double *eout, const double *e, const double *y, MatrixDim d, int e_stride, int y_stride);
void cudaD_regularize_l1(dim3 Gr, dim3 Bl, double *wei, double *grad, double l1, double lr, MatrixDim d, int stride_grad);
void cudaD_find_row_max_id(dim3 Gr, dim3 Bl, const double *mat, double *vec_val, int32_cuda *vec_id, int32_cuda voff, MatrixDim d);
void cudaD_find_row_max_id(dim3 Gr, dim3 Bl, const double *mat, double *vec_val, int32_cuda *vec_id, MatrixDim d);
void cudaD_diff_xent(dim3 Gr, dim3 Bl, const int32_cuda *vec_tgt, double *mat_net_out, double *vec_log_post, MatrixDim d);
void cudaD_copy_rows_from_vec(dim3 Gr, dim3 Bl, double *mat_out, MatrixDim d_out, const double *v_in);

Просмотреть файл

@ -2045,36 +2045,63 @@ static void _regularize_l1(Real* wei, Real* grad, Real l1, Real lr, MatrixDim d,
}
}
template<typename Real>
__global__
static void _find_row_max_id(const Real* mat, Real* vec_val, int32_cuda* vec_id, int32_cuda voff, MatrixDim d) {
int32_cuda i = blockIdx.x * blockDim.x + threadIdx.x;
int32_cuda j = blockIdx.y * blockDim.y + threadIdx.y;
static void _find_row_max_id(const Real* mat, Real* vec_val, int32_cuda* vec_id,
MatrixDim d) {
const int32_cuda i = blockIdx.x;
const int32_cuda base = i * d.stride;
const int32_cuda tid = threadIdx.x;
if(blockIdx.x > 0) return;
if(blockDim.y != 1) return;
__shared__ Real smax[CU1DBLOCK];
__shared__ int32_cuda sidx[CU1DBLOCK];
__shared__ Real value[CU1DBLOCK];
__shared__ int32_cuda index[CU1DBLOCK];
Real tmax = -1e20;
int32_cuda tidx = -1;
//copy to shared memory
value[threadIdx.x] = mat[i+j*d.stride];
index[threadIdx.x] = threadIdx.x;
__syncthreads();
//get the id of the max value
int32_cuda out_max = _max_id_reduce(value, index);
__syncthreads();
//see if it's bigger value
if(threadIdx.x == 0) {
if(vec_val[j] <= mat[out_max+j*d.stride]) {
vec_val[j] = mat[out_max+j*d.stride];
vec_id[j] = voff+out_max;
// Loop over blocks for coalesced memory access.
for (int32_cuda j = tid; j < d.cols; j += CU1DBLOCK) {
const Real val = mat[base + j];
if (val > tmax) {
tmax = val;
tidx = j;
}
}
smax[tid] = tmax;
sidx[tid] = tidx;
// Parallel reduce
#pragma unroll
for (int32_cuda num_working_threads = CU1DBLOCK / 2;
num_working_threads >= warpSize; num_working_threads >>= 1) {
__syncthreads();
if (tid < num_working_threads) {
if (smax[tid + num_working_threads] > smax[tid]) {
smax[tid] = smax[tid + num_working_threads];
sidx[tid] = sidx[tid + num_working_threads];
}
}
}
// Warp reduce without __syncthreads()
// (note.: synchronizes implicitly within a warp at the multiprocessor)
if (tid < warpSize / 2) {
#pragma unroll
for (int32_cuda num_working_threads = warpSize / 2; num_working_threads > 0;
num_working_threads >>= 1) {
if (smax[tid + num_working_threads] > smax[tid]) {
smax[tid] = smax[tid + num_working_threads];
sidx[tid] = sidx[tid + num_working_threads];
}
}
}
if (tid == 0) {
if (vec_val) {
vec_val[i] = smax[0];
}
vec_id[i] = sidx[0];
}
}
@ -2534,8 +2561,8 @@ void cudaF_regularize_l1(dim3 Gr, dim3 Bl, float* wei, float* grad, float l1, fl
_regularize_l1<<<Gr,Bl>>>(wei,grad,l1,lr,d,stride_grad);
}
void cudaF_find_row_max_id(dim3 Gr, dim3 Bl, const float* mat, float* vec_val, int32_cuda* vec_id, int32_cuda voff, MatrixDim d) {
_find_row_max_id<<<Gr,Bl>>>(mat, vec_val, vec_id, voff, d);
void cudaF_find_row_max_id(dim3 Gr, dim3 Bl, const float* mat, float* vec_val, int32_cuda* vec_id, MatrixDim d) {
_find_row_max_id<<<Gr,Bl>>>(mat, vec_val, vec_id, d);
}
void cudaF_diff_xent(dim3 Gr, dim3 Bl, const int32_cuda* vec_tgt, float* mat_net_out, float* vec_log_post, MatrixDim d) {
@ -2996,8 +3023,8 @@ void cudaD_regularize_l1(dim3 Gr, dim3 Bl, double* wei, double* grad, double l1,
_regularize_l1<<<Gr,Bl>>>(wei,grad,l1,lr,d,stride_grad);
}
void cudaD_find_row_max_id(dim3 Gr, dim3 Bl, const double* mat, double* vec_val, int32_cuda* vec_id, int32_cuda voff, MatrixDim d) {
_find_row_max_id<<<Gr,Bl>>>(mat, vec_val, vec_id, voff, d);
void cudaD_find_row_max_id(dim3 Gr, dim3 Bl, const double* mat, double* vec_val, int32_cuda* vec_id, MatrixDim d) {
_find_row_max_id<<<Gr,Bl>>>(mat, vec_val, vec_id, d);
}
void cudaD_diff_xent(dim3 Gr, dim3 Bl, const int32_cuda* vec_tgt, double* mat_net_out, double* vec_log_post, MatrixDim d) {

Просмотреть файл

@ -249,7 +249,7 @@ inline void cuda_softmax_reduce(size_t Gr, size_t Bl, float *y, const float *x,
inline void cuda_log_softmax_reduce(size_t Gr, size_t Bl, float *y, const float *x, MatrixDim d, int src_stride) { cudaF_log_softmax_reduce(Gr,Bl,y,x,d,src_stride); }
inline void cuda_regularize_l1(dim3 Gr, dim3 Bl, float *wei, float *grad, float l1, float lr, MatrixDim d, int stride_grad) { cudaF_regularize_l1(Gr,Bl,wei,grad,l1,lr,d,stride_grad); }
inline void cuda_find_row_max_id(dim3 Gr, dim3 Bl, const float *mat, float *vec_val, int32_cuda *vec_id, int32_cuda voff, MatrixDim d) { cudaF_find_row_max_id(Gr,Bl,mat,vec_val,vec_id,voff,d); }
inline void cuda_find_row_max_id(dim3 Gr, dim3 Bl, const float *mat, float *vec_val, int32_cuda *vec_id, MatrixDim d) { cudaF_find_row_max_id(Gr,Bl,mat,vec_val,vec_id,d); }
inline void cuda_diff_xent(dim3 Gr, dim3 Bl, const int32_cuda *vec_tgt, float *mat_net_out, float *vec_log_post, MatrixDim d) { cudaF_diff_xent(Gr,Bl,vec_tgt,mat_net_out,vec_log_post,d); }
inline void cuda_copy_rows_from_vec(dim3 Gr, dim3 Bl, float *mat_out, MatrixDim d_out, const float *v_in) {
cudaF_copy_rows_from_vec(Gr, Bl, mat_out, d_out, v_in);
@ -428,7 +428,7 @@ inline void cuda_softmax_reduce(size_t Gr, size_t Bl, double *y, const double *x
inline void cuda_log_softmax_reduce(size_t Gr, size_t Bl, double *y, const double *x, MatrixDim d, int src_stride) { cudaD_log_softmax_reduce(Gr,Bl,y,x,d,src_stride); }
inline void cuda_regularize_l1(dim3 Gr, dim3 Bl, double *wei, double *grad, double l1, double lr, MatrixDim d, int stride_grad) { cudaD_regularize_l1(Gr,Bl,wei,grad,l1,lr,d,stride_grad); }
inline void cuda_find_row_max_id(dim3 Gr, dim3 Bl, const double *mat, double *vec_val, int32_cuda *vec_id, int32_cuda voff, MatrixDim d) { cudaD_find_row_max_id(Gr,Bl,mat,vec_val,vec_id,voff,d); }
inline void cuda_find_row_max_id(dim3 Gr, dim3 Bl, const double *mat, double *vec_val, int32_cuda *vec_id, MatrixDim d) { cudaD_find_row_max_id(Gr,Bl,mat,vec_val,vec_id,d); }
inline void cuda_diff_xent(dim3 Gr, dim3 Bl, const int32_cuda *vec_tgt, double *mat_net_out, double *vec_log_post, MatrixDim d) {
cudaD_diff_xent(Gr,Bl,vec_tgt,mat_net_out,vec_log_post,d);
}

Просмотреть файл

@ -1489,36 +1489,15 @@ void CuMatrixBase<Real>::FindRowMaxId(CuArray<int32> *id) const {
#if HAVE_CUDA == 1
if (CuDevice::Instantiate().Enabled()) {
Timer tim;
// initialize the vectors
CuVector<Real> max(num_rows_);
max.Set(-1e21);
id->Resize(num_rows_);
id->Set(-1);
MatrixDim d = Dim();
MatrixDim d=Dim(); // only stride will be used!
// CUDA thread layout: one thread block per matrix-row.
dim3 dimBlock(CU1DBLOCK);
dim3 dimGrid(num_rows_);
cuda_find_row_max_id(dimGrid, dimBlock, data_, NULL, id->Data(), d);
CU_SAFE_CALL(cudaGetLastError());
// process per 256 column blocks
for (int32 block = 0; (block+1)*256 <= num_cols_; block++) {
dim3 dimBlock(CU1DBLOCK, 1);
dim3 dimGrid(1, num_rows_);
int32 offset = block*CU1DBLOCK;
cuda_find_row_max_id(dimGrid, dimBlock, data_ + offset,
max.data_, id->Data(), offset, d);
}
// process the remainder
int32 div = num_cols_ / 256;
int32 mod = num_cols_ % 256;
if (mod != 0) {
dim3 dimBlock(mod, 1);
dim3 dimGrid(1, num_rows_);
int32 offset=div*256;
cuda_find_row_max_id(dimGrid, dimBlock, data_ + offset,
max.data_, id->Data(), offset, d);
}
// now we have the indices!
CuDevice::Instantiate().AccuProfile(__func__, tim.Elapsed());
} else
@ -1529,11 +1508,11 @@ void CuMatrixBase<Real>::FindRowMaxId(CuArray<int32> *id) const {
id->Set(-1);
// find maxima
MatrixIndexT num_rows = num_rows_, num_cols = num_cols_;
for(MatrixIndexT r = 0; r < num_rows; r++) {
for (MatrixIndexT r = 0; r < num_rows; r++) {
Real max = -1e21;
int32 max_id = -1;
const Real *row_data = Mat().RowData(r);
for(MatrixIndexT c = 0; c < num_cols; c++) {
for (MatrixIndexT c = 0; c < num_cols; c++) {
if (max < row_data[c]) {
max = row_data[c];
max_id = c;