Make sparse label CE backprop faster
When the minibatch size is large (e.g. 10000), the diagonal matrix that the Times node builds for its gradient computation can become very large. This change implements ColumnwiseScaleAndWeightedAdd to reduce the cost in that case.
This commit is contained in:
Родитель
c4b9dd5cf0
Коммит
59a2c26d7d
|
@ -591,10 +591,8 @@ public:
|
|||
Matrix<ElemType> gradient = GradientFor(fr);
|
||||
Matrix<ElemType> inputValue = InputRef(1 - inputIndex).ValueFor(fr);
|
||||
Matrix<ElemType> inputGradient = InputRef(inputIndex).GradientFor(fr);
|
||||
Matrix<ElemType> gradientDiagonal(gradient.GetNumCols(), gradient.GetNumCols(), gradient.GetDeviceId());
|
||||
gradientDiagonal.SetDiagonalValue(gradient);
|
||||
Matrix<ElemType>::MultiplyAndWeightedAdd(
|
||||
(ElemType)1.0, inputValue, false, gradientDiagonal, true,
|
||||
Matrix<ElemType>::ColumnwiseScaleAndWeightedAdd(
|
||||
(ElemType)1.0, inputValue, gradient,
|
||||
Input(inputIndex)->ParentOverwritesGradient() ? (ElemType)0.0 : (ElemType)1.0,
|
||||
inputGradient);
|
||||
// TODO: better move this special-casing into TensorView::AssignElementwiseProductOf()
|
||||
|
|
|
@ -4836,6 +4836,29 @@ void CPUMatrix<ElemType>::Multiply1x1AndWeightedAdd(ElemType alpha, const CPUMat
|
|||
c(i, j) = b(i, j) * f + c(i, j) * beta;
|
||||
}
|
||||
|
||||
template <class ElemType>
|
||||
void CPUMatrix<ElemType>::ColumnwiseScaleAndWeightedAdd(ElemType alpha, const CPUMatrix<ElemType>& a, const CPUMatrix<ElemType>& v, ElemType beta, CPUMatrix<ElemType>& c)
|
||||
{
|
||||
if (v.GetNumRows() != 1 && v.GetNumCols() != 1)
|
||||
InvalidArgument("the argument v must be a vector"); // v is a vector
|
||||
|
||||
if (beta == 0)
|
||||
c.RequireSize(a.GetNumRows(), a.GetNumCols());
|
||||
else
|
||||
c.VerifySize(a.GetNumRows(), a.GetNumCols()); // Can't resize if beta != 0
|
||||
|
||||
const ElemType* vd = v.Data();
|
||||
|
||||
if (beta == 0) // don't even read the memory if beta is 0
|
||||
#pragma omp parallel for
|
||||
foreach_coord(i, j, c)
|
||||
c(i, j) = alpha * a(i, j) * vd[j];
|
||||
else
|
||||
#pragma omp parallel for
|
||||
foreach_coord(i, j, c)
|
||||
c(i, j) = alpha * a(i, j) * vd[j] + c(i, j) * beta;
|
||||
}
|
||||
|
||||
/* compute singular value decomposition as
|
||||
A = U*SIGMA*VT
|
||||
W is used as temp working memory
|
||||
|
|
|
@ -412,6 +412,8 @@ public:
|
|||
static void Multiply(const CPUMatrix<ElemType>& a, const CPUMatrix<ElemType>& b, CPUMatrix<ElemType>& c);
|
||||
static void Multiply1x1AndWeightedAdd(ElemType alpha, const CPUMatrix<ElemType>& a, const CPUMatrix<ElemType>& b, ElemType beta, CPUMatrix<ElemType>& c);
|
||||
|
||||
// c(i, j) = alpha * a(i, j) * v[j] + beta * c(i, j); v must be a row or column vector.
static void ColumnwiseScaleAndWeightedAdd(ElemType alpha, const CPUMatrix<ElemType>& a, const CPUMatrix<ElemType>& v, ElemType beta, CPUMatrix<ElemType>& c);
|
||||
|
||||
static void ScaleAndAdd(ElemType alpha, const CPUMatrix<ElemType>& a, CPUMatrix<ElemType>& c);
|
||||
static void AddScaledDifference(const ElemType alpha, const CPUMatrix<ElemType>& a, const CPUMatrix<ElemType>& b, CPUMatrix<ElemType>& c);
|
||||
static void AssignScaledDifference(const ElemType alpha, const CPUMatrix<ElemType>& a, const CPUMatrix<ElemType>& b, CPUMatrix<ElemType>& c);
|
||||
|
|
|
@ -1116,6 +1116,45 @@ void CPUSparseMatrix<ElemType>::MultiplyAndAdd(ElemType alpha, const CPUMatrix<E
|
|||
}
|
||||
}
|
||||
|
||||
// c[:,j] = alpha * v[j] * a[:,j] + beta * c[:,j]
|
||||
template <class ElemType>
|
||||
void CPUSparseMatrix<ElemType>::ColumnwiseScaleAndWeightedAdd(ElemType alpha, const CPUSparseMatrix<ElemType>& a, const CPUMatrix<ElemType>& v, ElemType beta, CPUMatrix<ElemType>& c)
|
||||
{
|
||||
if (v.GetNumRows() != 1 && v.GetNumCols() != 1)
|
||||
InvalidArgument("the argument v must be a vector"); // v is a vector
|
||||
|
||||
if (a.GetFormat() != matrixFormatSparseCSC)
|
||||
NOT_IMPLEMENTED;
|
||||
|
||||
if (beta == 0)
|
||||
{
|
||||
c.RequireSize(a.GetNumRows(), a.GetNumCols());
|
||||
c.SetValue((ElemType)0);
|
||||
}
|
||||
else
|
||||
c.VerifySize(a.GetNumRows(), a.GetNumCols()); // Can't resize if beta != 0
|
||||
|
||||
const ElemType* vd = v.Data();
|
||||
|
||||
#pragma omp parallel for
|
||||
for (long col = 0; col < (long)a.GetNumCols(); col++)
|
||||
{
|
||||
auto start = a.SecondaryIndexLocation()[col];
|
||||
auto end = a.SecondaryIndexLocation()[col + 1];
|
||||
|
||||
for (auto p = start; p < end; p++)
|
||||
{
|
||||
auto row = a.MajorIndexLocation()[p];
|
||||
ElemType val = a.Buffer()[p];
|
||||
|
||||
if (beta == 0) // don't even read the memory if beta is 0
|
||||
c(row, col) = alpha * vd[col] * val;
|
||||
else
|
||||
c(row, col) = alpha * vd[col] * val + beta * c(row, col);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// dense += sparse
|
||||
template <class ElemType>
|
||||
void CPUSparseMatrix<ElemType>::ScaleAndAdd(const ElemType alpha, const CPUSparseMatrix<ElemType>& lhs, CPUMatrix<ElemType>& rhs)
|
||||
|
|
|
@ -132,6 +132,8 @@ public:
|
|||
static void MultiplyAndAdd(ElemType alpha, const CPUMatrix<ElemType>& lhs, const bool transposeA,
|
||||
const CPUSparseMatrix<ElemType>& rhs, const bool transposeB, CPUSparseMatrix<ElemType>& c);
|
||||
|
||||
// c[:,j] = alpha * v[j] * a[:,j] + beta * c[:,j]; a must be in CSC format, v must be a row or column vector, c is dense.
static void ColumnwiseScaleAndWeightedAdd(ElemType alpha, const CPUSparseMatrix<ElemType>& a, const CPUMatrix<ElemType>& v, ElemType beta, CPUMatrix<ElemType>& c);
|
||||
|
||||
static void ScaleAndAdd(const ElemType alpha, const CPUSparseMatrix<ElemType>& lhs, CPUMatrix<ElemType>& c);
|
||||
|
||||
static bool AreEqual(const CPUSparseMatrix<ElemType>& a, const CPUSparseMatrix<ElemType>& b, const ElemType threshold = 1e-8);
|
||||
|
|
|
@ -3487,6 +3487,22 @@ void GPUMatrix<ElemType>::Multiply(const GPUMatrix<ElemType>& a, const GPUMatrix
|
|||
return GPUMatrix<ElemType>::MultiplyAndWeightedAdd(1, a, false, b, false, 0, c);
|
||||
}
|
||||
|
||||
template <class ElemType>
|
||||
void GPUMatrix<ElemType>::ColumnwiseScaleAndWeightedAdd(ElemType alpha, const GPUMatrix<ElemType>& a, const GPUMatrix<ElemType>& v, ElemType beta, GPUMatrix<ElemType>& c)
|
||||
{
|
||||
if (v.GetNumRows() != 1 && v.GetNumCols() != 1)
|
||||
InvalidArgument("the argument v must be a vector"); // v is a vector
|
||||
|
||||
if (beta == 0)
|
||||
c.RequireSize(a.GetNumRows(), a.GetNumCols());
|
||||
else
|
||||
c.VerifySize(a.GetNumRows(), a.GetNumCols()); // Can't resize if beta != 0
|
||||
|
||||
int blocksPerGrid = (int)ceil(1.0 * c.GetNumElements() / GridDim::maxThreadsPerBlock);
|
||||
SyncGuard syncGuard;
|
||||
_columnwiseScaleAndWeightedAdd<ElemType><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, t_stream >>>(alpha, a.Data(), v.Data(), beta, c.Data(), a.GetNumRows(), a.GetNumCols());
|
||||
}
|
||||
|
||||
/// <summary>Matrix-scalar multiply with col-major matrices: c = alpha * a + c</summary>
|
||||
/// if a is a column vector, add to all columns of c
|
||||
/// if a is a row vector, add to all rows of c
|
||||
|
|
|
@ -521,6 +521,8 @@ public:
|
|||
static void Multiply(const GPUMatrix<ElemType>& a, const GPUMatrix<ElemType>& b, GPUMatrix<ElemType>& c);
|
||||
static void Multiply1x1AndWeightedAdd(ElemType alpha, const GPUMatrix<ElemType>& a, const GPUMatrix<ElemType>& b, ElemType beta, GPUMatrix<ElemType>& c);
|
||||
|
||||
// c(i, j) = alpha * a(i, j) * v[j] + beta * c(i, j); v must be a row or column vector.
static void ColumnwiseScaleAndWeightedAdd(ElemType alpha, const GPUMatrix<ElemType>& a, const GPUMatrix<ElemType>& v, ElemType beta, GPUMatrix<ElemType>& c);
|
||||
|
||||
static void ScaleAndAdd(ElemType alpha, const GPUMatrix<ElemType>& a, GPUMatrix<ElemType>& c);
|
||||
static void ScaleAndAdd(ElemType alpha, const GPUMatrix<ElemType>& a, const GPUMatrix<ElemType>& b, GPUMatrix<ElemType>& c);
|
||||
static void AddScaledDifference(const ElemType alpha, const GPUMatrix<ElemType>& a, const GPUMatrix<ElemType>& b, GPUMatrix<ElemType>& c);
|
||||
|
|
|
@ -3136,6 +3136,55 @@ __global__ void _dense1DConvMultSparseCSCTransposeAndAddToDense(
|
|||
}
|
||||
}
|
||||
|
||||
template <class ElemType>
|
||||
__global__ void _columnwiseScaleAndWeightedAdd(
|
||||
ElemType alpha,
|
||||
const ElemType* aData,
|
||||
const ElemType* vData,
|
||||
ElemType beta,
|
||||
ElemType* cData,
|
||||
int m, int n)
|
||||
{
|
||||
CUDA_LONG id = blockDim.x * blockIdx.x + threadIdx.x;
|
||||
if (id >= m * n)
|
||||
return;
|
||||
|
||||
CUDA_LONG col = id / m;
|
||||
|
||||
if (beta == 0) // don't even read the memory if beta is 0
|
||||
cData[id] = alpha * vData[col] * aData[id];
|
||||
else
|
||||
cData[id] = alpha * vData[col] * aData[id] + beta * cData[id];
|
||||
}
|
||||
|
||||
template <class ElemType>
|
||||
__global__ void _columnwiseScaleAndWeightedAdd4CSC(
|
||||
ElemType alpha,
|
||||
const ElemType* aData, const GPUSPARSE_INDEX_TYPE* aSecondaryIndices, const GPUSPARSE_INDEX_TYPE* aMajorIndices,
|
||||
const ElemType* vData,
|
||||
ElemType beta,
|
||||
ElemType* cData,
|
||||
int m, int n)
|
||||
{
|
||||
CUDA_LONG col = blockDim.x * blockIdx.x + threadIdx.x;
|
||||
if (col >= n)
|
||||
return;
|
||||
|
||||
GPUSPARSE_INDEX_TYPE start = aSecondaryIndices[col];
|
||||
GPUSPARSE_INDEX_TYPE end = aSecondaryIndices[col + 1];
|
||||
|
||||
for (GPUSPARSE_INDEX_TYPE p = start; p < end; p++)
|
||||
{
|
||||
GPUSPARSE_INDEX_TYPE row = aMajorIndices[p];
|
||||
ElemType val = aData[p];
|
||||
|
||||
if (beta == 0) // don't even read the memory if beta is 0
|
||||
cData[IDX2C(row, col, m)] = alpha * vData[col] * val;
|
||||
else
|
||||
cData[IDX2C(row, col, m)] = alpha * vData[col] * val + beta * cData[IDX2C(row, col, m)];
|
||||
}
|
||||
}
|
||||
|
||||
template <class ElemType>
|
||||
__global__ void _reshape(
|
||||
const int oldNumRows, // old row count
|
||||
|
|
|
@ -1189,6 +1189,35 @@ void GPUSparseMatrix<ElemType>::ConvolveAndWeightedAdd(ElemType alpha, const GPU
|
|||
}
|
||||
}
|
||||
|
||||
// c[:,j] = alpha * v[j] * a[:,j] + beta * c[:,j]
|
||||
template <class ElemType>
|
||||
void GPUSparseMatrix<ElemType>::ColumnwiseScaleAndWeightedAdd(ElemType alpha, const GPUSparseMatrix<ElemType>& a, const GPUMatrix<ElemType>& v, ElemType beta, GPUMatrix<ElemType>& c)
|
||||
{
|
||||
if (v.GetNumRows() != 1 && v.GetNumCols() != 1)
|
||||
InvalidArgument("the argument v must be a vector"); // v is a vector
|
||||
|
||||
if (a.GetFormat() != matrixFormatSparseCSC)
|
||||
NOT_IMPLEMENTED;
|
||||
|
||||
if (beta == 0)
|
||||
{
|
||||
c.RequireSize(a.GetNumRows(), a.GetNumCols());
|
||||
c.SetValue((ElemType)0);
|
||||
}
|
||||
else
|
||||
c.VerifySize(a.GetNumRows(), a.GetNumCols()); // Can't resize if beta != 0
|
||||
|
||||
int blocksPerGrid = (int)ceil(1.0 * a.GetNumCols() / GridDim::maxThreadsPerBlock);
|
||||
SyncGuard syncGuard;
|
||||
_columnwiseScaleAndWeightedAdd4CSC<ElemType><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, t_stream>>>(
|
||||
alpha,
|
||||
a.Data(), a.SecondaryIndexLocation(), a.MajorIndexLocation(),
|
||||
v.Data(),
|
||||
beta,
|
||||
c.Data(),
|
||||
a.GetNumRows(), a.GetNumCols());
|
||||
}
|
||||
|
||||
template <class ElemType>
|
||||
void GPUSparseMatrix<ElemType>::TensorShuffleScaleAndAdd(ElemType keepWeight, const GPUSparseMatrix<ElemType>& a, size_t D, size_t S, size_t M, size_t K, size_t T,
|
||||
ElemType scaleFactor, const GPUSparseMatrix<ElemType>& b, GPUSparseMatrix<ElemType>& c)
|
||||
|
@ -3060,6 +3089,7 @@ template GPUMatrix<char> GPUSparseMatrix<char>::CopyColumnSliceToDense(size_t, s
|
|||
template GPUSparseMatrix<char>& GPUSparseMatrix<char>::operator=(GPUSparseMatrix<char>&&);
|
||||
template void GPUSparseMatrix<char>::Reshape(const size_t, const size_t);
|
||||
template void GPUSparseMatrix<char>::ScaleAndAdd(char, GPUSparseMatrix<char> const &, GPUMatrix<char> &);
|
||||
template void GPUSparseMatrix<char>::ColumnwiseScaleAndWeightedAdd(char, const GPUSparseMatrix<char>&, const GPUMatrix<char>&, char, GPUMatrix<char>&);
|
||||
|
||||
// Support <short>
|
||||
template GPUSparseMatrix<short>::GPUSparseMatrix(DEVICEID_TYPE, const MatrixFormat);
|
||||
|
@ -3084,6 +3114,7 @@ template GPUMatrix<short> GPUSparseMatrix<short>::CopyColumnSliceToDense(size_t,
|
|||
template GPUSparseMatrix<short>& GPUSparseMatrix<short>::operator=(GPUSparseMatrix<short>&&);
|
||||
template void GPUSparseMatrix<short>::Reshape(const size_t, const size_t);
|
||||
template void GPUSparseMatrix<short>::ScaleAndAdd(short, GPUSparseMatrix<short> const &, GPUMatrix<short> &);
|
||||
template void GPUSparseMatrix<short>::ColumnwiseScaleAndWeightedAdd(short, const GPUSparseMatrix<short>&, const GPUMatrix<short>&, short, GPUMatrix<short>&);
|
||||
|
||||
template GPUSparseMatrix<int>::GPUSparseMatrix(DEVICEID_TYPE, const MatrixFormat);
|
||||
template GPUSparseMatrix<int>::~GPUSparseMatrix();
|
||||
|
|
|
@ -409,6 +409,9 @@ public:
|
|||
const bool transposeD, ElemType beta, GPUMatrix<ElemType>& C);
|
||||
static void MultiplyAndAdd(ElemType alpha, const GPUMatrix<ElemType>& lhs, const bool transposeA, const GPUSparseMatrix<ElemType>& rhs,
|
||||
const bool transposeB, GPUSparseMatrix<ElemType>& c);
|
||||
|
||||
// c[:,j] = alpha * v[j] * a[:,j] + beta * c[:,j]; a must be in CSC format, v must be a row or column vector, c is dense.
static void ColumnwiseScaleAndWeightedAdd(ElemType alpha, const GPUSparseMatrix<ElemType>& a, const GPUMatrix<ElemType>& v, ElemType beta, GPUMatrix<ElemType>& c);
|
||||
|
||||
static void ScaleAndAdd(const ElemType alpha, const GPUSparseMatrix<ElemType>& lhs, GPUMatrix<ElemType>& c);
|
||||
static void ConvolveAndWeightedAdd(ElemType alpha, const GPUMatrix<ElemType>& lhs, const bool transposeA, const GPUSparseMatrix<ElemType>& rhs,
|
||||
const bool transposeB, ElemType beta, GPUMatrix<ElemType>& c, size_t numChannels, size_t horizontalSubsample, bool padding, bool channelwise);
|
||||
|
|
|
@ -4954,6 +4954,25 @@ void Matrix<ElemType>::ConvolveAndWeightedAdd(ElemType alpha, const Matrix<ElemT
|
|||
}
|
||||
}
|
||||
|
||||
/// <summary>Columnwise scale with col-major matrix and accumulate.</summary>
|
||||
/// <param name="alpha">Scalar</param>
|
||||
/// <param name="a">Input matrix</param>
|
||||
/// <param name="v">Input scale vector for each column of a</param>
|
||||
/// <param name="beta">Scalar</param>
|
||||
/// <param name="c">Resulting matrix, the same shape as a</param>
|
||||
template <class ElemType>
|
||||
void Matrix<ElemType>::ColumnwiseScaleAndWeightedAdd(ElemType alpha, const Matrix<ElemType>& a, const Matrix<ElemType>& v, ElemType beta, Matrix<ElemType>& c)
|
||||
{
|
||||
DecideAndMoveToRightDevice(a, v, c);
|
||||
|
||||
DISPATCH_MATRIX_ON_FLAG(&a,
|
||||
nullptr,
|
||||
CPUMatrix<ElemType>::ColumnwiseScaleAndWeightedAdd(alpha, *a.m_CPUMatrix, *v.m_CPUMatrix, beta, *c.m_CPUMatrix),
|
||||
GPUMatrix<ElemType>::ColumnwiseScaleAndWeightedAdd(alpha, *a.m_GPUMatrix, *v.m_GPUMatrix, beta, *c.m_GPUMatrix),
|
||||
CPUSparseMatrix<ElemType>::ColumnwiseScaleAndWeightedAdd(alpha, *a.m_CPUSparseMatrix, *v.m_CPUMatrix, beta, *c.m_CPUMatrix),
|
||||
GPUSparseMatrix<ElemType>::ColumnwiseScaleAndWeightedAdd(alpha, *a.m_GPUSparseMatrix, *v.m_GPUMatrix, beta, *c.m_GPUMatrix));
|
||||
}
|
||||
|
||||
/// <summary>Matrix-scalar multiply with col-major matrices: c = alpha * a + c</summary>
|
||||
/// if a is a column vector, add to all columns of c
|
||||
/// if a is a row vector, add to all rows of c
|
||||
|
|
|
@ -568,6 +568,8 @@ public:
|
|||
static void Multiply1x1AndWeightedAdd(ElemType alpha, const Matrix<ElemType>& a, const Matrix<ElemType>& b, ElemType beta, Matrix<ElemType>& c);
|
||||
static void ConvolveAndWeightedAdd(ElemType alpha, const Matrix<ElemType>& a, const bool transposeA, const Matrix<ElemType>& b, const bool transposeB, ElemType beta, Matrix<ElemType>& c, size_t numChannels, size_t horizontalSubsample, bool padding, bool channelwise);
|
||||
|
||||
// Column-wise scale and accumulate: a is the input matrix, v holds one scale per column of a, c is the result (same shape as a).
static void ColumnwiseScaleAndWeightedAdd(ElemType alpha, const Matrix<ElemType>& a, const Matrix<ElemType>& v, ElemType beta, Matrix<ElemType>& c);
|
||||
|
||||
static void ScaleAndAdd(ElemType alpha, const Matrix<ElemType>& a, Matrix<ElemType>& c);
|
||||
static void ScaleAndAdd(ElemType alpha, const Matrix<ElemType>& a, ElemType beta, Matrix<ElemType>& c);
|
||||
static void AddScaledDifference(const ElemType alpha, const Matrix<ElemType>& a, const Matrix<ElemType>& b, Matrix<ElemType>& c);
|
||||
|
|
|
@ -233,6 +233,11 @@ void GPUSparseMatrix<ElemType>::MultiplyAndAdd(ElemType alpha, const GPUMatrix<E
|
|||
{
|
||||
}
|
||||
|
||||
template <class ElemType>
|
||||
void GPUSparseMatrix<ElemType>::ColumnwiseScaleAndWeightedAdd(ElemType alpha, const GPUSparseMatrix<ElemType>& a, const GPUMatrix<ElemType>& v, ElemType beta, GPUMatrix<ElemType>& c)
|
||||
{
|
||||
}
|
||||
|
||||
// used for gradients udpate
|
||||
template <class ElemType>
|
||||
void GPUSparseMatrix<ElemType>::ScaleAndAdd(const ElemType alpha, const GPUSparseMatrix<ElemType>& lhs, GPUMatrix<ElemType>& rhs)
|
||||
|
@ -1952,6 +1957,11 @@ void GPUMatrix<ElemType>::Multiply(const GPUMatrix<ElemType>& /*a*/, const GPUMa
|
|||
{
|
||||
}
|
||||
|
||||
template <class ElemType>
|
||||
void GPUMatrix<ElemType>::ColumnwiseScaleAndWeightedAdd(ElemType alpha, const GPUMatrix<ElemType>& a, const GPUMatrix<ElemType>& v, ElemType beta, GPUMatrix<ElemType>& c)
|
||||
{
|
||||
}
|
||||
|
||||
/// <summary>Matrix-scalar multiply with col-major matrices: c = alpha * a + c</summary>
|
||||
/// if a is a column vector, add to all columns of c
|
||||
/// if a is a row vector, add to all rows of c
|
||||
|
|
|
@ -234,8 +234,8 @@ Test module "MathTests" has passed with:
|
|||
770 assertions out of 770 passed
|
||||
|
||||
Test suite "CPUMatrixSuite" has passed with:
|
||||
35 test cases out of 35 passed
|
||||
5248724 assertions out of 5248724 passed
|
||||
36 test cases out of 36 passed
|
||||
5248732 assertions out of 5248732 passed
|
||||
|
||||
Test case "CPUMatrixSuite/CPUMatrixConstructorNoFlags" has passed with:
|
||||
8 assertions out of 8 passed
|
||||
|
@ -318,6 +318,9 @@ Test module "MathTests" has passed with:
|
|||
Test case "CPUMatrixSuite/MatrixMultiplyAndPlusAndMinus" has passed with:
|
||||
1050120 assertions out of 1050120 passed
|
||||
|
||||
Test case "CPUMatrixSuite/MatrixColumnwiseScaleAndWeightedAdd" has passed with:
|
||||
8 assertions out of 8 passed
|
||||
|
||||
Test case "CPUMatrixSuite/MatrixScaleAndAdd" has passed with:
|
||||
1572864 assertions out of 1572864 passed
|
||||
|
||||
|
|
|
@ -119,6 +119,37 @@ BOOST_FIXTURE_TEST_CASE(MatrixMultiplyAndPlusAndMinus, RandomSeedFixture)
|
|||
}
|
||||
}
|
||||
|
||||
// Verifies ColumnwiseScaleAndWeightedAdd for a dense input and for the same input converted
// to CSC-sparse, by comparing both against MultiplyAndWeightedAdd with an explicit diagonal
// matrix built from the scale vector. Runs for device ids -1 and 0 and for beta = 0 and 1.
// The result matrices are deliberately carried over from the beta = 0 pass into the
// beta = 1 pass.
BOOST_FIXTURE_TEST_CASE(MatrixColumnwiseScaleAndWeightedAdd, RandomSeedFixture)
{
    const size_t numRows = 256;
    const size_t numCols = 64;

    for (int deviceId : {-1, 0})
    {
        // Input matrix. NOTE(review): AssignTruncateBottomOf(..., 0) presumably clamps the
        // mostly-negative random values at 0 to produce many exact zeros - confirm.
        SingleMatrix denseInput(deviceId);
        denseInput.AssignTruncateBottomOf(SingleMatrix::RandomUniform(numRows, numCols, deviceId, -200, 1, IncrementCounter()), 0);

        // One scale per column of the input, stored as a column vector.
        const SingleMatrix columnScales = SingleMatrix::RandomUniform(numCols, 1, deviceId, 0, 1, IncrementCounter());

        // Same values as denseInput, switched to CSC sparse storage.
        SingleMatrix sparseInput = denseInput.DeepClone();
        sparseInput.SwitchToMatrixType(SPARSE, matrixFormatSparseCSC, true);

        // All three result matrices start from identical random contents.
        SingleMatrix expected = SingleMatrix::RandomUniform(numRows, numCols, deviceId, 0, 1, IncrementCounter());
        SingleMatrix resultFromDense = expected.DeepClone();
        SingleMatrix resultFromSparse = expected.DeepClone();

        // Reference operand: the scales placed on the diagonal of an otherwise zero matrix.
        SingleMatrix scalesAsDiagonal(numCols, numCols, deviceId);
        scalesAsDiagonal.SetValue(0);
        scalesAsDiagonal.SetDiagonalValue(columnScales);

        for (float beta : {0.0f, 1.0f})
        {
            SingleMatrix::MultiplyAndWeightedAdd(1, denseInput, false, scalesAsDiagonal, false, beta, expected);
            SingleMatrix::ColumnwiseScaleAndWeightedAdd(1, denseInput, columnScales, beta, resultFromDense);
            SingleMatrix::ColumnwiseScaleAndWeightedAdd(1, sparseInput, columnScales, beta, resultFromSparse);

            BOOST_CHECK(expected.IsEqualTo(resultFromDense, c_epsilonFloatE4));
            BOOST_CHECK(expected.IsEqualTo(resultFromSparse, c_epsilonFloatE4));
        }
    }
}
|
||||
|
||||
BOOST_FIXTURE_TEST_CASE(MatrixScaleAndAdd, RandomSeedFixture)
|
||||
{
|
||||
std::mt19937 rng(0);
|
||||
|
|
Загрузка…
Ссылка в новой задаче