Adding robustness to cuDNN convolution (handling failures).

This commit is contained in:
Cha Zhang 2017-02-03 13:06:10 -08:00
Parent dcf2835a38
Commit 5be5a3e71b
1 changed file with 31 additions and 6 deletions

View file

@@ -246,7 +246,9 @@ protected:
m_fwdAlgo.Algo.algo, ptr(workspace), m_fwdAlgo.Algo.memory, &C::Zero, m_outT, ptr(out));
// There might be a case where cuDNN fails due to workspace being too small, try using no-workspace algo instead.
// REVIEW alexeyk: NVIDIA is currently reviewing this issue.
if (CUDNN_STATUS_INVALID_VALUE == err && m_fwdAlgo.Algo.memory > 0)
// chazhang: it seems even the no-workspace algo can fail from time to time. Hence we give it a second chance here anyway (and the second attempt usually succeeds).
// NVIDIA should definitely investigate this.
if (CUDNN_STATUS_SUCCESS != err)
{
if (m_forceDeterministicAlgorithms)
RuntimeError("Falling back to other algorithms is not allowed. Please set 'forceDeterministicAlgorithms=false'.");
@@ -260,7 +262,6 @@ protected:
// Only supported in MatrixPool enable
// NOTE: it's unnecessary to keep the workspace.
workspace.Resize(0, 0);
CUDNN_CALL(err);
}
@@ -290,9 +291,21 @@ protected:
if (m_backDataAlgo.Algo.memory > 0)
workspace.Resize((m_backDataAlgo.Algo.memory + sizeof(ElemType) - 1) / sizeof(ElemType), 1);
// Compute gradients with respect to the input tensor (data).
CUDNN_CALL(cudnnConvolutionBackwardData(*m_cudnn, &C::One, *m_kernelT, ptr(kernel), m_outT, ptr(srcGrad), *m_conv, m_backDataAlgo.Algo.algo,
ptr(workspace), m_backDataAlgo.Algo.memory, accumulateGradient ? &C::One : &C::Zero, m_inT, ptr(grad)));
auto err = cudnnConvolutionBackwardData(*m_cudnn, &C::One, *m_kernelT, ptr(kernel), m_outT, ptr(srcGrad), *m_conv, m_backDataAlgo.Algo.algo,
ptr(workspace), m_backDataAlgo.Algo.memory, accumulateGradient ? &C::One : &C::Zero, m_inT, ptr(grad));
// Handle cuDNN failure. This retry is only safe when accumulateGradient == false; otherwise the gradient buffer may already have been partially updated.
if (CUDNN_STATUS_SUCCESS != err && !accumulateGradient)
{
if (m_forceDeterministicAlgorithms)
RuntimeError("Falling back to other algorithms is not allowed. Please set 'forceDeterministicAlgorithms=false'.");
auto err2 = cudnnConvolutionBackwardData(*m_cudnn, &C::One, *m_kernelT, ptr(kernel), m_outT, ptr(srcGrad), *m_conv, m_backDataAlgo.NoWorkspaceAlgo,
nullptr, 0, &C::Zero, m_inT, ptr(grad));
// Update original error in case of success.
if (CUDNN_STATUS_SUCCESS == err2)
err = CUDNN_STATUS_SUCCESS;
}
workspace.Resize(0, 0);
CUDNN_CALL(err);
}
void BackwardKernelCore(const Mat& srcGrad, const Mat& in, Mat& kernelGrad, bool accumulateGradient, bool /*allowReuse*/, Mat& workspace) override
@@ -321,9 +334,21 @@ protected:
if (m_backFiltAlgo.Algo.memory > 0)
workspace.Resize((m_backFiltAlgo.Algo.memory + sizeof(ElemType) - 1) / sizeof(ElemType), 1);
// Compute gradients with respect to the filter (kernel).
CUDNN_CALL(cudnnConvolutionBackwardFilter(*m_cudnn, &C::One, m_inT, ptr(in), m_outT, ptr(srcGrad), *m_conv, m_backFiltAlgo.Algo.algo,
ptr(workspace), m_backFiltAlgo.Algo.memory, accumulateGradient ? &C::One : &C::Zero, *m_kernelT, ptr(kernelGrad)));
auto err = cudnnConvolutionBackwardFilter(*m_cudnn, &C::One, m_inT, ptr(in), m_outT, ptr(srcGrad), *m_conv, m_backFiltAlgo.Algo.algo,
ptr(workspace), m_backFiltAlgo.Algo.memory, accumulateGradient ? &C::One : &C::Zero, *m_kernelT, ptr(kernelGrad));
// Handle cuDNN failure. This retry is only safe when accumulateGradient == false; otherwise the gradient buffer may already have been partially updated.
if (CUDNN_STATUS_SUCCESS != err && !accumulateGradient)
{
if (m_forceDeterministicAlgorithms)
RuntimeError("Falling back to other algorithms is not allowed. Please set 'forceDeterministicAlgorithms=false'.");
auto err2 = cudnnConvolutionBackwardFilter(*m_cudnn, &C::One, m_inT, ptr(in), m_outT, ptr(srcGrad), *m_conv, m_backFiltAlgo.NoWorkspaceAlgo,
nullptr, 0, &C::Zero, *m_kernelT, ptr(kernelGrad));
// Update original error in case of success.
if (CUDNN_STATUS_SUCCESS == err2)
err = CUDNN_STATUS_SUCCESS;
}
workspace.Resize(0, 0);
CUDNN_CALL(err);
}
void EnsurePoolingInitialized() override
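
All three call sites in this commit apply the same capture-and-retry policy: run the chosen workspace-backed algorithm, and if it returns an error and a retry is safe, give the zero-workspace algorithm one chance before reporting the original status through CUDNN_CALL. The sketch below factors that policy out for reference. It is a minimal illustration, not code from the commit: RunWithFallback and its parameters are hypothetical names introduced here, while cudnnStatus_t and CUDNN_STATUS_SUCCESS are the real cuDNN types.

    #include <cudnn.h>
    #include <functional>

    // Try the preferred (workspace-backed) algorithm first; if it fails and a
    // retry is safe, give the zero-workspace algorithm one chance before
    // reporting the original error. "Safe" means the destination buffer may be
    // overwritten, i.e. the retry runs with beta == 0.
    inline cudnnStatus_t RunWithFallback(
        const std::function<cudnnStatus_t()>& preferredAlgo,   // hypothetical
        const std::function<cudnnStatus_t()>& noWorkspaceAlgo, // hypothetical
        bool retryIsSafe)
    {
        cudnnStatus_t err = preferredAlgo();
        if (CUDNN_STATUS_SUCCESS != err && retryIsSafe)
        {
            // Second chance with the no-workspace algorithm.
            if (CUDNN_STATUS_SUCCESS == noWorkspaceAlgo())
                err = CUDNN_STATUS_SUCCESS; // retry succeeded; drop the first error
        }
        return err;
    }

This framing also explains the accumulateGradient guard in the two backward paths: the no-workspace retry writes with beta == 0 and so overwrites its destination. If a failed accumulating call (beta == 1) had already partially updated grad or kernelGrad, rerunning from scratch would silently discard the accumulated contribution, so the fallback is skipped in that case.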