зеркало из https://github.com/mozilla/kaldi.git
Merge pull request #780 from kangshiyin/faster-FindRowMaxId
A new CUDA kernel for CuMatrixBase<Real>::FindRowMaxId;
This commit is contained in:
Коммит
26ef88fed9
|
@ -148,7 +148,7 @@ void cudaF_tanh(dim3 Gr, dim3 Bl, float *y, const float *x, MatrixDim d, int src
|
|||
void cudaF_diff_tanh(dim3 Gr, dim3 Bl, float *eout, const float *e, const float *y, MatrixDim d, int e_stride, int y_stride);
|
||||
|
||||
void cudaF_regularize_l1(dim3 Gr, dim3 Bl, float *wei, float *grad, float l1, float lr, MatrixDim d, int stride_grad);
|
||||
void cudaF_find_row_max_id(dim3 Gr, dim3 Bl, const float *mat, float *vec_val, int32_cuda *vec_id, int32_cuda voff, MatrixDim d);
|
||||
void cudaF_find_row_max_id(dim3 Gr, dim3 Bl, const float *mat, float *vec_val, int32_cuda *vec_id, MatrixDim d);
|
||||
void cudaF_diff_xent(dim3 Gr, dim3 Bl, const int32_cuda *vec_tgt, float *mat_net_out, float *vec_log_post, MatrixDim d);
|
||||
void cudaF_copy_rows_from_vec(dim3 Gr, dim3 Bl, float *mat_out, MatrixDim d_out, const float *v_in);
|
||||
|
||||
|
@ -291,7 +291,7 @@ void cudaD_tanh(dim3 Gr, dim3 Bl, double *y, const double *x, MatrixDim d, int s
|
|||
void cudaD_diff_tanh(dim3 Gr, dim3 Bl, double *eout, const double *e, const double *y, MatrixDim d, int e_stride, int y_stride);
|
||||
|
||||
void cudaD_regularize_l1(dim3 Gr, dim3 Bl, double *wei, double *grad, double l1, double lr, MatrixDim d, int stride_grad);
|
||||
void cudaD_find_row_max_id(dim3 Gr, dim3 Bl, const double *mat, double *vec_val, int32_cuda *vec_id, int32_cuda voff, MatrixDim d);
|
||||
void cudaD_find_row_max_id(dim3 Gr, dim3 Bl, const double *mat, double *vec_val, int32_cuda *vec_id, MatrixDim d);
|
||||
void cudaD_diff_xent(dim3 Gr, dim3 Bl, const int32_cuda *vec_tgt, double *mat_net_out, double *vec_log_post, MatrixDim d);
|
||||
void cudaD_copy_rows_from_vec(dim3 Gr, dim3 Bl, double *mat_out, MatrixDim d_out, const double *v_in);
|
||||
|
||||
|
|
|
@ -2045,36 +2045,63 @@ static void _regularize_l1(Real* wei, Real* grad, Real l1, Real lr, MatrixDim d,
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
template<typename Real>
|
||||
__global__
|
||||
static void _find_row_max_id(const Real* mat, Real* vec_val, int32_cuda* vec_id, int32_cuda voff, MatrixDim d) {
|
||||
int32_cuda i = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
int32_cuda j = blockIdx.y * blockDim.y + threadIdx.y;
|
||||
static void _find_row_max_id(const Real* mat, Real* vec_val, int32_cuda* vec_id,
|
||||
MatrixDim d) {
|
||||
const int32_cuda i = blockIdx.x;
|
||||
const int32_cuda base = i * d.stride;
|
||||
const int32_cuda tid = threadIdx.x;
|
||||
|
||||
if(blockIdx.x > 0) return;
|
||||
if(blockDim.y != 1) return;
|
||||
__shared__ Real smax[CU1DBLOCK];
|
||||
__shared__ int32_cuda sidx[CU1DBLOCK];
|
||||
|
||||
__shared__ Real value[CU1DBLOCK];
|
||||
__shared__ int32_cuda index[CU1DBLOCK];
|
||||
Real tmax = -1e20;
|
||||
int32_cuda tidx = -1;
|
||||
|
||||
//copy to shared memory
|
||||
value[threadIdx.x] = mat[i+j*d.stride];
|
||||
index[threadIdx.x] = threadIdx.x;
|
||||
__syncthreads();
|
||||
|
||||
//get the id of the max value
|
||||
int32_cuda out_max = _max_id_reduce(value, index);
|
||||
__syncthreads();
|
||||
|
||||
//see if it's bigger value
|
||||
if(threadIdx.x == 0) {
|
||||
if(vec_val[j] <= mat[out_max+j*d.stride]) {
|
||||
vec_val[j] = mat[out_max+j*d.stride];
|
||||
vec_id[j] = voff+out_max;
|
||||
// Loop over blocks for coalesced memory access.
|
||||
for (int32_cuda j = tid; j < d.cols; j += CU1DBLOCK) {
|
||||
const Real val = mat[base + j];
|
||||
if (val > tmax) {
|
||||
tmax = val;
|
||||
tidx = j;
|
||||
}
|
||||
}
|
||||
|
||||
smax[tid] = tmax;
|
||||
sidx[tid] = tidx;
|
||||
|
||||
// Parallel reduce
|
||||
#pragma unroll
|
||||
for (int32_cuda num_working_threads = CU1DBLOCK / 2;
|
||||
num_working_threads >= warpSize; num_working_threads >>= 1) {
|
||||
__syncthreads();
|
||||
if (tid < num_working_threads) {
|
||||
if (smax[tid + num_working_threads] > smax[tid]) {
|
||||
smax[tid] = smax[tid + num_working_threads];
|
||||
sidx[tid] = sidx[tid + num_working_threads];
|
||||
}
|
||||
}
|
||||
}
|
||||
// Warp reduce without __syncthreads()
|
||||
// (note.: synchronizes implicitly within a warp at the multiprocessor)
|
||||
if (tid < warpSize / 2) {
|
||||
#pragma unroll
|
||||
for (int32_cuda num_working_threads = warpSize / 2; num_working_threads > 0;
|
||||
num_working_threads >>= 1) {
|
||||
if (smax[tid + num_working_threads] > smax[tid]) {
|
||||
smax[tid] = smax[tid + num_working_threads];
|
||||
sidx[tid] = sidx[tid + num_working_threads];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (tid == 0) {
|
||||
if (vec_val) {
|
||||
vec_val[i] = smax[0];
|
||||
}
|
||||
vec_id[i] = sidx[0];
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
@ -2534,8 +2561,8 @@ void cudaF_regularize_l1(dim3 Gr, dim3 Bl, float* wei, float* grad, float l1, fl
|
|||
_regularize_l1<<<Gr,Bl>>>(wei,grad,l1,lr,d,stride_grad);
|
||||
}
|
||||
|
||||
void cudaF_find_row_max_id(dim3 Gr, dim3 Bl, const float* mat, float* vec_val, int32_cuda* vec_id, int32_cuda voff, MatrixDim d) {
|
||||
_find_row_max_id<<<Gr,Bl>>>(mat, vec_val, vec_id, voff, d);
|
||||
void cudaF_find_row_max_id(dim3 Gr, dim3 Bl, const float* mat, float* vec_val, int32_cuda* vec_id, MatrixDim d) {
|
||||
_find_row_max_id<<<Gr,Bl>>>(mat, vec_val, vec_id, d);
|
||||
}
|
||||
|
||||
void cudaF_diff_xent(dim3 Gr, dim3 Bl, const int32_cuda* vec_tgt, float* mat_net_out, float* vec_log_post, MatrixDim d) {
|
||||
|
@ -2996,8 +3023,8 @@ void cudaD_regularize_l1(dim3 Gr, dim3 Bl, double* wei, double* grad, double l1,
|
|||
_regularize_l1<<<Gr,Bl>>>(wei,grad,l1,lr,d,stride_grad);
|
||||
}
|
||||
|
||||
void cudaD_find_row_max_id(dim3 Gr, dim3 Bl, const double* mat, double* vec_val, int32_cuda* vec_id, int32_cuda voff, MatrixDim d) {
|
||||
_find_row_max_id<<<Gr,Bl>>>(mat, vec_val, vec_id, voff, d);
|
||||
void cudaD_find_row_max_id(dim3 Gr, dim3 Bl, const double* mat, double* vec_val, int32_cuda* vec_id, MatrixDim d) {
|
||||
_find_row_max_id<<<Gr,Bl>>>(mat, vec_val, vec_id, d);
|
||||
}
|
||||
|
||||
void cudaD_diff_xent(dim3 Gr, dim3 Bl, const int32_cuda* vec_tgt, double* mat_net_out, double* vec_log_post, MatrixDim d) {
|
||||
|
|
|
@ -249,7 +249,7 @@ inline void cuda_softmax_reduce(size_t Gr, size_t Bl, float *y, const float *x,
|
|||
inline void cuda_log_softmax_reduce(size_t Gr, size_t Bl, float *y, const float *x, MatrixDim d, int src_stride) { cudaF_log_softmax_reduce(Gr,Bl,y,x,d,src_stride); }
|
||||
|
||||
inline void cuda_regularize_l1(dim3 Gr, dim3 Bl, float *wei, float *grad, float l1, float lr, MatrixDim d, int stride_grad) { cudaF_regularize_l1(Gr,Bl,wei,grad,l1,lr,d,stride_grad); }
|
||||
inline void cuda_find_row_max_id(dim3 Gr, dim3 Bl, const float *mat, float *vec_val, int32_cuda *vec_id, int32_cuda voff, MatrixDim d) { cudaF_find_row_max_id(Gr,Bl,mat,vec_val,vec_id,voff,d); }
|
||||
inline void cuda_find_row_max_id(dim3 Gr, dim3 Bl, const float *mat, float *vec_val, int32_cuda *vec_id, MatrixDim d) { cudaF_find_row_max_id(Gr,Bl,mat,vec_val,vec_id,d); }
|
||||
inline void cuda_diff_xent(dim3 Gr, dim3 Bl, const int32_cuda *vec_tgt, float *mat_net_out, float *vec_log_post, MatrixDim d) { cudaF_diff_xent(Gr,Bl,vec_tgt,mat_net_out,vec_log_post,d); }
|
||||
inline void cuda_copy_rows_from_vec(dim3 Gr, dim3 Bl, float *mat_out, MatrixDim d_out, const float *v_in) {
|
||||
cudaF_copy_rows_from_vec(Gr, Bl, mat_out, d_out, v_in);
|
||||
|
@ -428,7 +428,7 @@ inline void cuda_softmax_reduce(size_t Gr, size_t Bl, double *y, const double *x
|
|||
inline void cuda_log_softmax_reduce(size_t Gr, size_t Bl, double *y, const double *x, MatrixDim d, int src_stride) { cudaD_log_softmax_reduce(Gr,Bl,y,x,d,src_stride); }
|
||||
|
||||
inline void cuda_regularize_l1(dim3 Gr, dim3 Bl, double *wei, double *grad, double l1, double lr, MatrixDim d, int stride_grad) { cudaD_regularize_l1(Gr,Bl,wei,grad,l1,lr,d,stride_grad); }
|
||||
inline void cuda_find_row_max_id(dim3 Gr, dim3 Bl, const double *mat, double *vec_val, int32_cuda *vec_id, int32_cuda voff, MatrixDim d) { cudaD_find_row_max_id(Gr,Bl,mat,vec_val,vec_id,voff,d); }
|
||||
inline void cuda_find_row_max_id(dim3 Gr, dim3 Bl, const double *mat, double *vec_val, int32_cuda *vec_id, MatrixDim d) { cudaD_find_row_max_id(Gr,Bl,mat,vec_val,vec_id,d); }
|
||||
inline void cuda_diff_xent(dim3 Gr, dim3 Bl, const int32_cuda *vec_tgt, double *mat_net_out, double *vec_log_post, MatrixDim d) {
|
||||
cudaD_diff_xent(Gr,Bl,vec_tgt,mat_net_out,vec_log_post,d);
|
||||
}
|
||||
|
|
|
@ -1489,36 +1489,15 @@ void CuMatrixBase<Real>::FindRowMaxId(CuArray<int32> *id) const {
|
|||
#if HAVE_CUDA == 1
|
||||
if (CuDevice::Instantiate().Enabled()) {
|
||||
Timer tim;
|
||||
|
||||
// initialize the vectors
|
||||
CuVector<Real> max(num_rows_);
|
||||
max.Set(-1e21);
|
||||
id->Resize(num_rows_);
|
||||
id->Set(-1);
|
||||
MatrixDim d = Dim();
|
||||
|
||||
MatrixDim d=Dim(); // only stride will be used!
|
||||
// CUDA thread layout: one thread block per matrix-row.
|
||||
dim3 dimBlock(CU1DBLOCK);
|
||||
dim3 dimGrid(num_rows_);
|
||||
cuda_find_row_max_id(dimGrid, dimBlock, data_, NULL, id->Data(), d);
|
||||
CU_SAFE_CALL(cudaGetLastError());
|
||||
|
||||
// process per 256 column blocks
|
||||
for (int32 block = 0; (block+1)*256 <= num_cols_; block++) {
|
||||
dim3 dimBlock(CU1DBLOCK, 1);
|
||||
dim3 dimGrid(1, num_rows_);
|
||||
int32 offset = block*CU1DBLOCK;
|
||||
|
||||
cuda_find_row_max_id(dimGrid, dimBlock, data_ + offset,
|
||||
max.data_, id->Data(), offset, d);
|
||||
}
|
||||
|
||||
// process the remainder
|
||||
int32 div = num_cols_ / 256;
|
||||
int32 mod = num_cols_ % 256;
|
||||
if (mod != 0) {
|
||||
dim3 dimBlock(mod, 1);
|
||||
dim3 dimGrid(1, num_rows_);
|
||||
int32 offset=div*256;
|
||||
|
||||
cuda_find_row_max_id(dimGrid, dimBlock, data_ + offset,
|
||||
max.data_, id->Data(), offset, d);
|
||||
}
|
||||
// now we have the indices!
|
||||
CuDevice::Instantiate().AccuProfile(__func__, tim.Elapsed());
|
||||
} else
|
||||
|
@ -1529,11 +1508,11 @@ void CuMatrixBase<Real>::FindRowMaxId(CuArray<int32> *id) const {
|
|||
id->Set(-1);
|
||||
// find maxima
|
||||
MatrixIndexT num_rows = num_rows_, num_cols = num_cols_;
|
||||
for(MatrixIndexT r = 0; r < num_rows; r++) {
|
||||
for (MatrixIndexT r = 0; r < num_rows; r++) {
|
||||
Real max = -1e21;
|
||||
int32 max_id = -1;
|
||||
const Real *row_data = Mat().RowData(r);
|
||||
for(MatrixIndexT c = 0; c < num_cols; c++) {
|
||||
for (MatrixIndexT c = 0; c < num_cols; c++) {
|
||||
if (max < row_data[c]) {
|
||||
max = row_data[c];
|
||||
max_id = c;
|
||||
|
|
Загрузка…
Ссылка в новой задаче