Speed up TimesNodeBase for sparse by avoid unroll
This change improves speed of CrossEntropyWithSoftmax/ClassificationError with sparse label. Example/LanguageUnderstanding/ATIS/Python/LanguageUnderstanding.py got 4x speed up.
This commit is contained in:
Родитель
a9c5c611fd
Коммит
16daa4acb5
|
@ -316,30 +316,35 @@ protected:
|
|||
|
||||
private:
|
||||
// Check if TimesNodeBase could be simplified to ElementTimes to avoid unroll when:
|
||||
// 1. input0: DENSE, is rank-1 and transposed, or is rank-2 with Dim(0)==1
|
||||
// 2. input1: DENSE, is rank-1
|
||||
// 1. input0: is rank-1 and transposed, or is rank-2 with Dim(0)==1
|
||||
// 2. input1: is rank-1
|
||||
// 3. output: rank-1 and reduced to a scalar (Dim(0)==1)
|
||||
// NOTE we might relax the condition on DENSE when ElementTimes support sparse in future
|
||||
bool IsReduceableDotProduct(const FrameRange& fr)
|
||||
// 4. m_transpose (becomes Matrix::InnerProduct), or both input being dense
|
||||
bool IsReduceableDotProduct(const FrameRange& fr, bool& hasSparse)
|
||||
{
|
||||
const auto& shape0 = InputRef(0).GetSampleLayout();
|
||||
const auto& shape1 = InputRef(1).GetSampleLayout();
|
||||
const auto& shapeOut = GetSampleLayout();
|
||||
|
||||
bool input0Sparse = (InputRef(0).Value().GetMatrixType() != DENSE);
|
||||
bool input1Sparse = (InputRef(1).Value().GetMatrixType() != DENSE);
|
||||
|
||||
bool input0_ok =
|
||||
((shape0.GetRank() == 1 && m_transpose) ||
|
||||
(shape0.GetRank() == 2 && shape0.GetDim(0) == 1)) &&
|
||||
(InputRef(0).Value().GetMatrixType() == DENSE); // TODO: add support in ElementTimes for sparse and remove this limitation
|
||||
(shape0.GetRank() == 1 && m_transpose) ||
|
||||
(shape0.GetRank() == 2 && shape0.GetDim(0) == 1);
|
||||
|
||||
bool input1_ok =
|
||||
(shape1.GetRank() == 1) &&
|
||||
(InputRef(1).Value().GetMatrixType() == DENSE);
|
||||
(shape1.GetRank() == 1);
|
||||
|
||||
bool outputScalar =
|
||||
(shapeOut.GetRank() == 1) &&
|
||||
(shapeOut.GetDim(0) == 1);
|
||||
|
||||
return input0_ok && input1_ok && outputScalar;
|
||||
bool notBothSparse = !(input0Sparse && input1Sparse);
|
||||
|
||||
hasSparse = (input0Sparse || input1Sparse);
|
||||
|
||||
return input0_ok && input1_ok && outputScalar && notBothSparse && (m_transpose || !hasSparse);
|
||||
}
|
||||
|
||||
public:
|
||||
|
@ -350,9 +355,24 @@ public:
|
|||
if (!fr.IsOneColumnWrt(InputRef(0).GetMBLayout()))
|
||||
{
|
||||
// speed up using ElementTimes to avoid unroll if possible
|
||||
if (IsReduceableDotProduct(fr))
|
||||
bool hasSparse;
|
||||
if (IsReduceableDotProduct(fr, hasSparse))
|
||||
{
|
||||
ElementTimesNode<ElemType>::ForwardPropImpl(*this, fr, false/*allowBroadcast*/);
|
||||
// for sparse transposed, use InnerProduct
|
||||
if (hasSparse)
|
||||
{
|
||||
Matrix<ElemType> value = ValueFor(fr);
|
||||
Matrix<ElemType> input0 = InputRef(0).ValueFor(fr);
|
||||
Matrix<ElemType> input1 = InputRef(1).ValueFor(fr);
|
||||
if (input0.GetMatrixType() == SPARSE)
|
||||
Matrix<ElemType>::InnerProduct(input0, input1, value, true/*isColWise*/);
|
||||
else
|
||||
Matrix<ElemType>::InnerProduct(input1, input0, value, true/*isColWise*/);
|
||||
}
|
||||
else
|
||||
{
|
||||
ElementTimesNode<ElemType>::ForwardPropImpl(*this, fr, false/*allowBroadcast*/);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
|
@ -380,9 +400,25 @@ public:
|
|||
if (!fr.IsOneColumnWrt(InputRef(0).GetMBLayout()))
|
||||
{
|
||||
// speed up using ElementTimes to avoid unroll if possible
|
||||
if (IsReduceableDotProduct(fr))
|
||||
bool hasSparse;
|
||||
if (IsReduceableDotProduct(fr, hasSparse))
|
||||
{
|
||||
ElementTimesNode<ElemType>::BackpropToImpl(*this, inputIndex, fr, false/*allowBroadcast*/);
|
||||
if (hasSparse)
|
||||
{
|
||||
Matrix<ElemType> gradient = GradientFor(fr);
|
||||
Matrix<ElemType> inputValue = InputRef(1 - inputIndex).ValueFor(fr);
|
||||
Matrix<ElemType> inputGradient = InputRef(inputIndex).GradientFor(fr);
|
||||
Matrix<ElemType> gradientDiagonal(gradient.GetNumCols(), gradient.GetNumCols(), gradient.GetDeviceId());
|
||||
gradientDiagonal.SetDiagonalValue(gradient);
|
||||
Matrix<ElemType>::MultiplyAndWeightedAdd(
|
||||
(ElemType)1.0, inputValue, false, gradientDiagonal, true,
|
||||
Input(inputIndex)->ParentOverwritesGradient() ? (ElemType)0.0 : (ElemType)1.0,
|
||||
inputGradient);
|
||||
}
|
||||
else
|
||||
{
|
||||
ElementTimesNode<ElemType>::BackpropToImpl(*this, inputIndex, fr, false/*allowBroadcast*/);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
|
|
|
@ -964,7 +964,7 @@ void CPUMatrix<ElemType>::SetDiagonalValue(const CPUMatrix<ElemType>& vector)
|
|||
|
||||
if (vector.GetNumElements() == 1) // reduce to simple form
|
||||
SetDiagonalValue(vector(0, 0));
|
||||
else if (vector.GetNumRows() != GetNumRows())
|
||||
else if (vector.GetNumRows() != GetNumRows() && vector.GetNumCols() != GetNumRows())
|
||||
LogicError("SetDiagonalValue: input vector's dimension does not agree with [this].");
|
||||
else
|
||||
{
|
||||
|
|
|
@ -1126,6 +1126,62 @@ template <class ElemType>
|
|||
return result;
|
||||
}
|
||||
|
||||
template<class ElemType>
|
||||
void CPUSparseMatrix<ElemType>::InnerProduct(const CPUSparseMatrix<ElemType>& a, const CPUMatrix<ElemType>& b, CPUMatrix<ElemType>& c, const bool isColWise)
|
||||
{
|
||||
if (a.IsEmpty() || b.IsEmpty())
|
||||
LogicError("InnerProduct: one of the input matrices is empty.");
|
||||
|
||||
const int m = (int)a.GetNumRows();
|
||||
const int n = (int)a.GetNumCols();
|
||||
const int k = (int)b.GetNumRows();
|
||||
const int l = (int)b.GetNumCols();
|
||||
|
||||
assert(m > 0 && n > 0 && k > 0 && l > 0); // converting from size_t to int may cause overflow
|
||||
assert(m == k && n == l); // converting from size_t to int may cause overflow
|
||||
if (m != k || n != l)
|
||||
InvalidArgument("InnerProduct: Matrices a and b should have same dimension.");
|
||||
|
||||
if (isColWise) // col-wise
|
||||
{
|
||||
c.RequireSize(1, n);
|
||||
|
||||
#pragma omp parallel for
|
||||
foreach_column(j, c)
|
||||
{
|
||||
ElemType sum = 0;
|
||||
for (CPUSPARSE_INDEX_TYPE iRow = a.ColLocation()[j]; iRow < a.ColLocation()[j+1]; ++iRow)
|
||||
{
|
||||
size_t row = a.RowLocation()[iRow];
|
||||
sum += a.Data()[iRow] * b(row, j);
|
||||
}
|
||||
c(0, j) = sum;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
c.RequireSize(m, 1);
|
||||
|
||||
#pragma omp parallel for
|
||||
foreach_row(i, c)
|
||||
{
|
||||
ElemType sum = 0;
|
||||
for(CPUSPARSE_INDEX_TYPE j = 0; j < n; ++j)
|
||||
{
|
||||
for (CPUSPARSE_INDEX_TYPE iRow = a.ColLocation()[j]; iRow < a.ColLocation()[j + 1]; ++iRow)
|
||||
{
|
||||
if (a.RowLocation()[iRow] == i)
|
||||
{
|
||||
sum += a.Data()[iRow] * b(i, j);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
c(i, 0) = sum;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// A helper method used in MomentumSGDUpdate and NesterovAcceleratedMomentumSGDUpdate.
|
||||
// Modifies the smoothed gradients "c", as well as the current gradients "this" on which this method is invoked.
|
||||
// Classic momentum (unitGainFactor == 1.0):
|
||||
|
|
|
@ -140,6 +140,8 @@ public:
|
|||
NOT_IMPLEMENTED;
|
||||
}
|
||||
|
||||
static void InnerProduct(const CPUSparseMatrix<ElemType>& a, const CPUMatrix<ElemType>& b, CPUMatrix<ElemType>& c, const bool isColWise);
|
||||
|
||||
static void AddScaledDifference(const ElemType /*alpha*/, const CPUSparseMatrix<ElemType>& /*a*/, const CPUMatrix<ElemType>& /*b*/, CPUMatrix<ElemType>& /*c*/,
|
||||
bool /*bDefaultZero*/)
|
||||
{
|
||||
|
|
|
@ -1270,7 +1270,7 @@ void GPUMatrix<ElemType>::SetDiagonalValue(const GPUMatrix<ElemType>& vector)
|
|||
if (vector.GetNumElements() == 1) // reduce to simple form
|
||||
SetDiagonalValue(vector.Data()[0]);
|
||||
|
||||
else if (vector.GetNumRows() != GetNumRows())
|
||||
else if (vector.GetNumRows() != GetNumRows() && vector.GetNumCols() != GetNumRows())
|
||||
LogicError("SetDiagonalValue: input vector's dimension does not agree with [this].");
|
||||
else
|
||||
{
|
||||
|
@ -3848,7 +3848,7 @@ void GPUMatrix<ElemType>::InnerProduct(const GPUMatrix<ElemType>& a, const GPUMa
|
|||
else
|
||||
c.RequireSize(m, 1);
|
||||
|
||||
if ((isColWise && m == 1) || !isColWise && n == 1) // in this case it's equivalent to element-wise product
|
||||
if ((isColWise && m == 1) || (!isColWise && n == 1)) // in this case it's equivalent to element-wise product
|
||||
{
|
||||
c.AssignElementProductOf(a, b);
|
||||
}
|
||||
|
|
|
@ -2317,6 +2317,51 @@ __global__ void _innerProduct(
|
|||
c[id] = sum;
|
||||
}
|
||||
|
||||
template <class ElemType>
|
||||
__global__ void _innerProduct4SparseCSC(
|
||||
ElemType* c,
|
||||
const ElemType* a,
|
||||
const GPUSPARSE_INDEX_TYPE* aRowIndex,
|
||||
const GPUSPARSE_INDEX_TYPE* aColCSCIndex,
|
||||
const ElemType* b,
|
||||
const CUDA_LONG M, // a.GetNumRows();
|
||||
const CUDA_LONG N, // a.GetNumCols();
|
||||
const bool isColWise)
|
||||
{
|
||||
CUDA_LONG id = blockDim.x * blockIdx.x + threadIdx.x;
|
||||
if ((isColWise && id >= N) || (!isColWise && id >= M))
|
||||
return;
|
||||
|
||||
ElemType sum = 0;
|
||||
CUDA_LONG index;
|
||||
|
||||
if (isColWise)
|
||||
{
|
||||
for (CUDA_LONG i = aColCSCIndex[id]; i < aColCSCIndex[id+1]; i++)
|
||||
{
|
||||
index = IDX2C(aRowIndex[i], id, M);
|
||||
sum += a[i] * b[index];
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
for (CUDA_LONG j = 0; j < N; ++j)
|
||||
{
|
||||
for (CUDA_LONG i = aColCSCIndex[j]; i < aColCSCIndex[j+1]; i++)
|
||||
{
|
||||
if (aRowIndex[i] == id)
|
||||
{
|
||||
index = IDX2C(id, j, M);
|
||||
sum += a[i] * b[index];
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
c[id] = sum;
|
||||
}
|
||||
|
||||
template <class ElemType>
|
||||
__global__ void _assignSignOf(
|
||||
ElemType* a,
|
||||
|
|
|
@ -2087,6 +2087,55 @@ ElemType GPUSparseMatrix<ElemType>::InnerProductOfMatrices(const GPUMatrix<ElemT
|
|||
return GPUSparseMatrix<ElemType>::InnerProductOfMatrices(b, a);
|
||||
}
|
||||
|
||||
template <class ElemType>
|
||||
void GPUSparseMatrix<ElemType>::InnerProduct(const GPUSparseMatrix<ElemType>& a, const GPUMatrix<ElemType>& b, GPUMatrix<ElemType>& c, const bool isColWise)
|
||||
{
|
||||
if (a.GetComputeDeviceId() != b.GetComputeDeviceId() || b.GetComputeDeviceId() != c.GetComputeDeviceId()) // different GPUs
|
||||
InvalidArgument("All matrices must be on the same GPU");
|
||||
|
||||
if (a.IsEmpty() || b.IsEmpty())
|
||||
LogicError("Scale: one of the input matrices is empty.");
|
||||
|
||||
if (a.GetFormat() != MatrixFormat::matrixFormatSparseCSC)
|
||||
{
|
||||
NOT_IMPLEMENTED;
|
||||
}
|
||||
|
||||
const int m = (int)a.GetNumRows();
|
||||
const int n = (int)a.GetNumCols();
|
||||
const int k = (int)b.GetNumRows();
|
||||
const int l = (int)b.GetNumCols();
|
||||
|
||||
assert(m > 0 && n > 0 && k > 0 && l > 0); // converting from size_t to int may cause overflow
|
||||
assert(m == k && n == l); // converting from size_t to int may cause overflow
|
||||
if (m != k || n != l)
|
||||
InvalidArgument("Matrices a and b should have same dimension.");
|
||||
|
||||
if (isColWise)
|
||||
c.RequireSize(1, n);
|
||||
else
|
||||
c.RequireSize(m, 1);
|
||||
|
||||
c.PrepareDevice();
|
||||
|
||||
int blocksPerGrid = 0;
|
||||
if (isColWise) // col-wise
|
||||
{
|
||||
blocksPerGrid = (int)ceil(1.0 * n / GridDim::maxThreadsPerBlock);
|
||||
}
|
||||
else
|
||||
{
|
||||
blocksPerGrid = (int)ceil(1.0 * m / GridDim::maxThreadsPerBlock);
|
||||
}
|
||||
|
||||
SyncGuard syncGuard;
|
||||
_innerProduct4SparseCSC<ElemType><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, t_stream>>>(
|
||||
c.Data(),
|
||||
a.Data(), a.RowLocation(), a.ColLocation(),
|
||||
b.Data(),
|
||||
m, n, isColWise);
|
||||
}
|
||||
|
||||
// This is an utility function useful for debugging issues with sparse matrices.
|
||||
// It just checks that the CSC format indices are not corrupted / pointing to invalid memory.
|
||||
template <class ElemType>
|
||||
|
|
|
@ -420,6 +420,7 @@ public:
|
|||
|
||||
static ElemType InnerProductOfMatrices(const GPUSparseMatrix<ElemType>& a, const GPUMatrix<ElemType>& b);
|
||||
static ElemType InnerProductOfMatrices(const GPUMatrix<ElemType>& a, const GPUSparseMatrix<ElemType>& b);
|
||||
static void InnerProduct(const GPUSparseMatrix<ElemType>& a, const GPUMatrix<ElemType>& b, GPUMatrix<ElemType>& c, const bool isColWise);
|
||||
static void ScaleAndAdd(ElemType alpha, const GPUSparseMatrix<ElemType>& a, ElemType beta, const GPUSparseMatrix<ElemType>& b, GPUSparseMatrix<ElemType>& c);
|
||||
static void ScaleAndAdd(ElemType alpha, const GPUSparseMatrix<ElemType>& a, ElemType beta, const GPUMatrix<ElemType>& b, GPUMatrix<ElemType>& c);
|
||||
static void ScaleAndAdd(ElemType alpha, const GPUMatrix<ElemType>& a, ElemType beta, const GPUSparseMatrix<ElemType>& b, GPUMatrix<ElemType>& c);
|
||||
|
|
|
@ -1404,7 +1404,7 @@ void Matrix<ElemType>::SetDiagonalValue(const Matrix<ElemType>& vector)
|
|||
SetDiagonalValue(vector.m_GPUMatrix->Get00Element()) // BUGBUG: efficiency
|
||||
);
|
||||
}
|
||||
else if (vector.GetNumRows() != GetNumRows())
|
||||
else if (vector.GetNumRows() != GetNumRows() && vector.GetNumCols() != GetNumRows())
|
||||
LogicError("SetDiagonalValue: input vector's dimension does not agree with [this].");
|
||||
else
|
||||
{
|
||||
|
@ -5132,17 +5132,17 @@ void Matrix<ElemType>::InnerProduct(const Matrix<ElemType>& a, const Matrix<Elem
|
|||
|
||||
DecideAndMoveToRightDevice(a, b, c);
|
||||
|
||||
if (a.GetMatrixType() != b.GetMatrixType())
|
||||
if (b.GetMatrixType() != DENSE) // only support a being sparse/dense. Both b and c should be dense
|
||||
NOT_IMPLEMENTED;
|
||||
|
||||
c.SwitchToMatrixType(a.GetMatrixType(), a.GetFormat(), false);
|
||||
c.SwitchToMatrixType(b.GetMatrixType(), b.GetFormat(), false);
|
||||
|
||||
DISPATCH_MATRIX_ON_FLAG(&c,
|
||||
&c,
|
||||
DISPATCH_MATRIX_ON_FLAG(&a,
|
||||
&a,
|
||||
CPUMatrix<ElemType>::InnerProduct(*a.m_CPUMatrix, *b.m_CPUMatrix, *c.m_CPUMatrix, isColWise),
|
||||
GPUMatrix<ElemType>::InnerProduct(*a.m_GPUMatrix, *b.m_GPUMatrix, *c.m_GPUMatrix, isColWise),
|
||||
NOT_IMPLEMENTED,
|
||||
NOT_IMPLEMENTED);
|
||||
CPUSparseMatrix<ElemType>::InnerProduct(*a.m_CPUSparseMatrix, *b.m_CPUMatrix, *c.m_CPUMatrix, isColWise),
|
||||
GPUSparseMatrix<ElemType>::InnerProduct(*a.m_GPUSparseMatrix, *b.m_GPUMatrix, *c.m_GPUMatrix, isColWise));
|
||||
}
|
||||
|
||||
template <class ElemType>
|
||||
|
|
|
@ -357,6 +357,11 @@ ElemType GPUSparseMatrix<ElemType>::InnerProductOfMatrices(const GPUMatrix<ElemT
|
|||
return ElemType(0);
|
||||
}
|
||||
|
||||
template <class ElemType>
|
||||
void GPUSparseMatrix<ElemType>::InnerProduct(const GPUSparseMatrix<ElemType>&, const GPUMatrix<ElemType>&, GPUMatrix<ElemType>&, const bool)
|
||||
{
|
||||
}
|
||||
|
||||
template <class ElemType>
|
||||
bool GPUSparseMatrix<ElemType>::AreEqual(const GPUSparseMatrix<ElemType>& a, const GPUSparseMatrix<ElemType>& /*b*/,
|
||||
const ElemType threshold)
|
||||
|
|
|
@ -424,6 +424,17 @@ BOOST_FIXTURE_TEST_CASE(GPUSparseMatrixInnerProduct, RandomSeedFixture)
|
|||
const float x1 = GPUSparseMatrix<float>::InnerProductOfMatrices(matrixOp1CSC, matrixOp2);
|
||||
|
||||
BOOST_CHECK(fabsf(x1 - y) < c_epsilonFloatE5);
|
||||
|
||||
for(bool isColWise : {true, false})
|
||||
{
|
||||
GPUMatrix<float> matrixInnerProductDense(c_deviceIdZero);
|
||||
GPUMatrix<float> matrixInnerProductSparse(c_deviceIdZero);
|
||||
|
||||
GPUMatrix<float>::InnerProduct(matrixOp1Dense, matrixOp2, matrixInnerProductDense, isColWise);
|
||||
GPUSparseMatrix<float>::InnerProduct(matrixOp1CSC, matrixOp2, matrixInnerProductSparse, isColWise);
|
||||
|
||||
BOOST_CHECK(matrixInnerProductDense.IsEqualTo(matrixInnerProductSparse, c_epsilonFloatE5));
|
||||
}
|
||||
}
|
||||
|
||||
BOOST_FIXTURE_TEST_CASE(GPUSparseMatrixColumnSlice, RandomSeedFixture)
|
||||
|
|
Загрузка…
Ссылка в новой задаче