This commit is contained in:
Vadim Mazalov 2016-11-16 18:20:50 -08:00
Родитель d53d13ec3e
Коммит faef2b540a
21 изменённых файлов: 382 добавлений и 72 удалений

Просмотреть файл

@ -1114,6 +1114,7 @@ UNITTEST_MATH_SRC = \
$(SOURCEDIR)/../Tests/UnitTests/MathTests/CPUSparseMatrixTests.cpp \
$(SOURCEDIR)/../Tests/UnitTests/MathTests/fixtures.cpp \
$(SOURCEDIR)/../Tests/UnitTests/MathTests/QuantizersTests.cpp \
$(SOURCEDIR)/../Tests/UnitTests/MathTests/QuantizedOperationsTests.cpp \
$(SOURCEDIR)/../Tests/UnitTests/MathTests/TensorTests.cpp \
$(SOURCEDIR)/../Tests/UnitTests/MathTests/GPUMatrixCudaBlasTests.cpp \
$(SOURCEDIR)/../Tests/UnitTests/MathTests/GPUMatrixTests.cpp \

Просмотреть файл

@ -645,6 +645,7 @@ Tanh(z, tag='') = new ComputationNode [ operation = 'Tanh' ; inputs = _AsNodes (
TimeReverse(vectorSequence, tag='') = new ComputationNode [ operation = 'TimeReverse' ; inputs = _AsNodes (vectorSequence) /*plus the function args*/ ]
Trace (node, say='', logFrequency=100, logFirst=10, logGradientToo=false, onlyUpToRow=100000000, onlyUpToT=100000000, format=[], tag='') = new ComputationNode [ operation = 'Trace' ; inputs = _AsNodes (node) ]
TransposeTimes(leftMatrix, rightMatrix, tag='') = new ComputationNode [ operation = 'TransposeTimes' ; inputs = _AsNodes (leftMatrix : rightMatrix) /*plus the function args*/ ]
QuantizedTimes(leftMatrix, rightMatrix, bitSmoothingA=1, bitSmoothingB=1, outputRank=1, inferInputRankToMap=-1, tag='') = new ComputationNode [ operation = 'QuantizedTimes' ; inputs = _AsNodes (leftMatrix : rightMatrix) /*plus the function args*/ ]
Where(cond, tag='') = new ComputationNode [ operation = 'Where' ; inputs = _AsNodes (cond) /*plus the function args*/ ]
##############################################################################

Просмотреть файл

@ -16,6 +16,7 @@
#include <stdarg.h>
#ifdef _WIN32
#include <Windows.h>
#undef max
#endif
#if __unix__
#include <dlfcn.h> // for Plugin
@ -586,6 +587,7 @@ struct nocase_compare
// ----------------------------------------------------------------------------
// Array class
// Wrapper that holds pointer to data, as well as size
template <class T>
class ArrayRef
{
@ -612,7 +614,9 @@ public:
// TODO: Move assignment operator
ArrayRef& operator=(ArrayRef&& rhs) = delete;
size_t size() const { return count; }
size_t size() const { return count; }
void setSize(size_t size) { count = size; }
T* data() const { return elements; }
T operator[](size_t i) const

Просмотреть файл

@ -124,6 +124,7 @@ static shared_ptr<ComputationNode<ElemType>> CreateStandardNode(const std::wstri
else if (nodeType == OperationNameOf(TimesNode)) return New<TimesNode<ElemType>>(forward<_Types>(_Args)...);
else if (nodeType == OperationNameOf(TransposeDimensionsNode)) return New<TransposeDimensionsNode<ElemType>>(forward<_Types>(_Args)...);
else if (nodeType == OperationNameOf(TransposeTimesNode)) return New<TransposeTimesNode<ElemType>>(forward<_Types>(_Args)...);
else if (nodeType == OperationNameOf(QuantizedTimesNode)) return New<QuantizedTimesNode<ElemType>>(forward<_Types>(_Args)...);
else if (nodeType == OperationNameOf(WhereNode)) return New<WhereNode<ElemType>>(forward<_Types>(_Args)...);
// legacy names we also support for back compat of model-files
else if (nodeType == L"ColumnElementTimes") return New<ElementTimesNode<ElemType>>(forward<_Types>(_Args)...);
@ -693,6 +694,12 @@ shared_ptr<ComputationNode<ElemType>> ComputationNetworkBuilder<ElemType>::Trans
return net.AddNodeToNetAndAttachInputs(New<TransposeTimesNode<ElemType>>(net.GetDeviceId(), nodeName), { a, b });
}
template <class ElemType>
shared_ptr<ComputationNode<ElemType>> ComputationNetworkBuilder<ElemType>::QuantizedTimes(const ComputationNodePtr a, const ComputationNodePtr b, size_t bitSmoothingA, size_t bitSmoothingB, size_t outputRank, const std::wstring nodeName)
{
return net.AddNodeToNetAndAttachInputs(New<QuantizedTimesNode<ElemType>>(net.GetDeviceId(), nodeName, bitSmoothingA, bitSmoothingB, outputRank), { a, b });
}
template <class ElemType>
shared_ptr<ComputationNode<ElemType>> ComputationNetworkBuilder<ElemType>::ElementTimes(const ComputationNodePtr a, const ComputationNodePtr b, const std::wstring nodeName)
{

Просмотреть файл

@ -183,6 +183,7 @@ public:
ComputationNodePtr Times(const ComputationNodePtr a, const ComputationNodePtr b, size_t outputRank = 1, const std::wstring nodeName = L"");
ComputationNodePtr TransposeDimensions(const ComputationNodePtr matrix, int dim1, int dim2, const std::wstring nodeName = L"");
ComputationNodePtr TransposeTimes(const ComputationNodePtr a, const ComputationNodePtr b, const std::wstring nodeName = L"");
ComputationNodePtr QuantizedTimes(const ComputationNodePtr a, const ComputationNodePtr b, size_t bitSmoothingA = 1, size_t bitSmoothingB = 1, size_t outputRank = 1, const std::wstring nodeName = L"");
#if 1 // legacy
ComputationNodePtr LegacyReshape(const ComputationNodePtr a, const size_t num_rows, const TensorShape& imageLayout, const std::wstring nodeName = L"");
#endif

Просмотреть файл

@ -168,7 +168,9 @@ void ComputationNetwork::DeleteNode(const std::wstring& nodeName)
}
// replace a named node by newNode of the same type under the same name, including moving over all network links
// This is used in the KL-reg based adaptation to reduce feature copy
// This is used in
// 1. Update nodes to quantized versions.
// 2. The KL-reg based adaptation to reduce feature copy (deprecated)
// need to update all the mappings as well childrens.
void ComputationNetwork::ReplaceNode(wstring nodeName, ComputationNodeBasePtr newNode)
{
@ -176,8 +178,6 @@ void ComputationNetwork::ReplaceNode(wstring nodeName, ComputationNodeBasePtr ne
if (newNode->NodeName() != nodeName) // TODO: This was not tested for earlier; I hope no code depends on this.
InvalidArgument("ChangeNode: newNode must have the same name as the old node.");
if (oldNode->OperationName() != newNode->OperationName())
InvalidArgument("ReplaceNode: newNode must have the same type as the old node.");
InvalidateCompiledNetwork();

Просмотреть файл

@ -8,7 +8,6 @@
#include "ComputationNode.h"
#include "Matrix.h"
#include "TensorView.h"
#include <unordered_set>
#include <map>
#include <string>
@ -19,6 +18,8 @@
#include <algorithm>
#include <utility>
#include <assert.h>
#include "Quantizers.h"
#include "InputAndParamNodes.h"
namespace Microsoft { namespace MSR { namespace CNTK {
@ -288,7 +289,7 @@ public:
m_inferInputRankToMap = -1;
}
private:
protected:
// if the left argument of the matrix product (A) has a time axis, it can only be applied sample by sample
// where each sample is treated as a separate matrix object (as a consequence, it then also applies to B and the result as well)
TensorView<ElemType> OneSampleTensorFor(int inputIndex/*-1 for output*/, bool gradient/*instead of value*/, const FrameRange& fr)
@ -304,6 +305,7 @@ private:
return TensorView<ElemType>(data, tensorShape);
}
private:
// Check if TimesNodeBase could be simplified to ElementTimes to avoid unroll when:
// 1. input0: DENSE, is rank-1 and transposed, or is rank-2 with Dim(0)==1
// 2. input1: DENSE, is rank-1
@ -360,7 +362,7 @@ public:
auto input0 = OneSampleTensorFor(0, /*gradient=*/false, fr.AllowBroadcast());
auto input1 = OneSampleTensorFor(1, /*gradient=*/false, fr.AllowBroadcast());
auto output = OneSampleTensorFor(-1, /*gradient=*/false, fr);
output.AssignMatrixProductOf(false/*transC*/, input0, m_transpose/*transA*/, input1, false/*transB*/);
output.AssignMatrixProductOf(false/*transC*/, input0, m_transpose/*transA*/, input1, false/*transB*/, 1.0f, this->m_pQuantizedMultiplier);
}
virtual void /*ComputationNode::*/ BackpropTo(const size_t inputIndex, const FrameRange& fr) override
@ -583,6 +585,9 @@ public:
size_t OutputRank() const { return m_outputRank; }
int InferInputRankToMap() const { return m_inferInputRankToMap; }
protected:
shared_ptr<QuantizedMultiplier<ElemType>> m_pQuantizedMultiplier;
private:
size_t m_outputRank;
int m_inferInputRankToMap; // -1 (not specified) or says how to expand shape of W, to keep this many mapping dims
@ -655,6 +660,96 @@ public:
template class TransposeTimesNode<float>;
template class TransposeTimesNode<double>;
// Fixed-point matrix product. This scales inputs to 16bit signed integers by Symmetric quantizers, performs
// integer multiplication using SSE/AVX2, and transforms the results back.
// Only dense untransposed matrix multiplication will be quantized. If at least one matrix is sparse then it will fall back to un-quantized default evaluation
// Currently it works for CPU only. On GPU logicError will be thrown.
// One way to include this node to the network is with the Edit command:
// ...
// node => if node.name == 'LSTMoutput1.output' then QuantizedTimes(node.inputs[0], node.inputs[1], bitShiftA=1, bitShiftB=2) else node,
// ...
// bitShift(A|B) - bit shift parameters of quantizers for matrices A and B, see the quantizers for more details. Decreases the maximum range of quantziation by 2^bitShift to prevent integer overflow during BLAS routines.
// bitShift=0 doesn't change the range; higher bitShift will decrease precision of quantization, but will make BLAS routines less prone to overflow.
// Other parameters - refer to the base multiplication class
template <class ElemType>
class QuantizedTimesNode : public TimesNodeBase<ElemType, false>
{
typedef TimesNodeBase<ElemType, false> Base;
UsingComputationNodeMembersBoilerplate;
static const std::wstring TypeName()
{
return L"QuantizedTimes";
}
private:
// Quantizer bit shift for matrices A and B
size_t m_bitShiftA;
size_t m_bitShiftB;
public:
QuantizedTimesNode(DEVICEID_TYPE deviceId, const wstring& name, size_t bitShiftA = 1, size_t bitShiftB = 1, size_t outputRank = 1, int inferInputRankToMap = -1)
: Base(deviceId, name, outputRank, inferInputRankToMap), m_bitShiftA(bitShiftA), m_bitShiftB(bitShiftB)
{
// TODO support multiplication on GPUs as well.
if (deviceId != CPUDEVICE)
LogicError("Quantized operation is supposed to be used on CPU device only.");
shared_ptr<SymmetricQuantizer<ElemType, short>> pQA(new SymmetricQuantizer<ElemType, short>(m_bitShiftA));
shared_ptr<SymmetricQuantizer<ElemType, short>> qQB(new SymmetricQuantizer<ElemType, short>(m_bitShiftB));
this->m_pQuantizedMultiplier = shared_ptr<QuantizedMultiplier<ElemType>>(new QuantizedMultiplier<ElemType>(pQA, qQB));
}
QuantizedTimesNode(const ScriptableObjects::IConfigRecordPtr configp)
: QuantizedTimesNode(configp->Get(L"deviceId"), L"<placeholder>", configp->Get(L"bitShiftA"), configp->Get(L"bitShiftB"), configp->Get(L"outputRank"), configp->Get(L"inferInputRankToMap"))
{
AttachInputsFromConfig(configp, this->GetExpectedNumInputs());
}
virtual void CopyTo(ComputationNodeBasePtr nodeP, const std::wstring& newName, const CopyNodeFlags flags) const override
{
Base::CopyTo(nodeP, newName, flags);
if (flags & CopyNodeFlags::copyNodeValue)
{
auto node = dynamic_pointer_cast<QuantizedTimesNode<ElemType>>(nodeP);
node->m_bitShiftA = m_bitShiftA;
node->m_bitShiftB = m_bitShiftB;
}
}
void Save(File& fstream) const
{
Base::Save(fstream);
fstream << m_bitShiftA;
fstream << m_bitShiftB;
}
virtual void Load(File& fstream, size_t modelVersion) override
{
Base::Load(fstream, modelVersion);
fstream >> m_bitShiftA;
fstream >> m_bitShiftB;
}
virtual void /*ComputationNode::*/ ForwardProp(const FrameRange& fr) override
{
if (dynamic_pointer_cast<LearnableParameter<ElemType>>(Input(0)))
this->m_pQuantizedMultiplier->SetIsAConstant(true);
if (dynamic_pointer_cast<LearnableParameter<ElemType>>(Input(1)))
this->m_pQuantizedMultiplier->SetIsBConstant(true);
Base::ForwardProp(fr);
}
virtual void /*ComputationNode::*/ BackpropTo(const size_t /*inputIndex*/, const FrameRange& /*fr*/) override
{
// This operation is intended only for inference
NOT_IMPLEMENTED;
}
};
template class QuantizedTimesNode<float>;
template class QuantizedTimesNode<double>;
// -----------------------------------------------------------------------
// SumElementsNode (input)
// Sums up all elements in the input across all samples into a single scalar.

Просмотреть файл

@ -4624,7 +4624,7 @@ void CPUMatrix<ElemType>::BatchNormalizationBackward(const CPUMatrix<ElemType>&
/// <param name="c">Resulting matrix, user is responsible for allocating this</param>
template <class ElemType>
void CPUMatrix<ElemType>::MultiplyAndWeightedAdd(ElemType alpha, const CPUMatrix<ElemType>& a, const bool transposeA, const CPUMatrix<ElemType>& b, const bool transposeB,
ElemType beta, CPUMatrix<ElemType>& c)
ElemType beta, CPUMatrix<ElemType>& c, shared_ptr<QuantizedMultiplier<ElemType>> pQuantizedMultiplier)
{
if (a.IsEmpty() || b.IsEmpty())
return;
@ -4676,14 +4676,25 @@ void CPUMatrix<ElemType>::MultiplyAndWeightedAdd(ElemType alpha, const CPUMatrix
ldc = (int) c.GetNumRows();
if (sizeof(ElemType) == sizeof(double))
if (pQuantizedMultiplier == nullptr)
{
cblas_dgemm((CBLAS_ORDER) (int)MatrixOrder::ColMajor, mklTransA, mklTransB, m, n, k, alpha, reinterpret_cast<double*>(a.Data()), lda, reinterpret_cast<double*>(b.Data()), ldb, beta, reinterpret_cast<double*>(c.Data()), ldc);
if (sizeof(ElemType) == sizeof(double))
{
cblas_dgemm((CBLAS_ORDER) (int)MatrixOrder::ColMajor, mklTransA, mklTransB, m, n, k, alpha, reinterpret_cast<double*>(a.Data()), lda, reinterpret_cast<double*>(b.Data()), ldb, beta, reinterpret_cast<double*>(c.Data()), ldc);
}
else
{
#pragma warning(suppress : 4244)
cblas_sgemm((CBLAS_ORDER) (int)MatrixOrder::ColMajor, mklTransA, mklTransB, m, n, k, alpha, reinterpret_cast<float*>(a.Data()), lda, reinterpret_cast<float*>(b.Data()), ldb, beta, reinterpret_cast<float*>(c.Data()), ldc);
}
}
else
{
#pragma warning(suppress : 4244)
cblas_sgemm((CBLAS_ORDER) (int)MatrixOrder::ColMajor, mklTransA, mklTransB, m, n, k, alpha, reinterpret_cast<float*>(a.Data()), lda, reinterpret_cast<float*>(b.Data()), ldb, beta, reinterpret_cast<float*>(c.Data()), ldc);
// TODO: support transpose product
if (mklTransA == CBLAS_TRANSPOSE::CblasTrans || mklTransB == CBLAS_TRANSPOSE::CblasTrans)
LogicError("Quantized multiplier currently doesn't support transpose.");
pQuantizedMultiplier->Multiply(m, n, k, a.Data(), b.Data(), c.Data());
}
}

Просмотреть файл

@ -13,6 +13,7 @@
#include <stdio.h>
#include <ctime>
#include <limits.h>
#include "QuantizedOperations.h"
//#include "GPUMatrix.h"
//#include "CPUSparseMatrix.h"
@ -394,7 +395,7 @@ public:
// static BLAS functions
static void SVD(const CPUMatrix<ElemType>& A, CPUMatrix<ElemType>& SIGMA, CPUMatrix<ElemType>& U, CPUMatrix<ElemType>& VT, CPUMatrix<ElemType>& W);
static void MultiplyAndWeightedAdd(ElemType alpha, const CPUMatrix<ElemType>& a, const bool transposeA, const CPUMatrix<ElemType>& b, const bool transposeB, ElemType beta, CPUMatrix<ElemType>& c);
static void MultiplyAndWeightedAdd(ElemType alpha, const CPUMatrix<ElemType>& a, const bool transposeA, const CPUMatrix<ElemType>& b, const bool transposeB, ElemType beta, CPUMatrix<ElemType>& c, shared_ptr<QuantizedMultiplier<ElemType>> pQuantizedMultiplier=nullptr);
static void MultiplyAndAdd(const CPUMatrix<ElemType>& a, const bool transposeA, const CPUMatrix<ElemType>& b, const bool transposeB, CPUMatrix<ElemType>& c);
static void Multiply(const CPUMatrix<ElemType>& a, const bool transposeA, const CPUMatrix<ElemType>& b, const bool transposeB, CPUMatrix<ElemType>& c);
static void Multiply(const CPUMatrix<ElemType>& a, const CPUMatrix<ElemType>& b, CPUMatrix<ElemType>& c);

Просмотреть файл

@ -177,6 +177,7 @@
<ClInclude Include="TensorOps.h" />
<ClInclude Include="TensorView.h" />
<ClInclude Include="Quantizers.h" />
<ClInclude Include="QuantizedOperations.h" />
<None Include="GPUWatcher.cu" />
<None Include="GPUWatcher.h">
<FileType>CppHeader</FileType>

Просмотреть файл

@ -128,6 +128,7 @@
<Filter>CPU</Filter>
</ClInclude>
<ClInclude Include="Quantizers.h" />
<ClInclude Include="QuantizedOperations.h" />
<ClInclude Include="BlockMultiplierMatrixUtil.h" />
<ClInclude Include="DataTransferer.h" />
</ItemGroup>

Просмотреть файл

@ -17,6 +17,7 @@
#include "GPUWatcher.h" // bring in this class as well so that it gets exported from this DLL
#include <memory>
#include <atomic>
#include "Quantizers.h"
#ifndef CPUONLY
#pragma comment(lib, "MathCUDA.lib") // built by CNTKMathCUDA project
#endif
@ -4502,7 +4503,7 @@ void Matrix<ElemType>::SVD(const Matrix<ElemType>& A, Matrix<ElemType>& SIGMA, M
/// <param name="c">Resulting matrix, user is responsible for allocating this</param>
template <class ElemType>
void Matrix<ElemType>::MultiplyAndWeightedAdd(ElemType alpha, const Matrix<ElemType>& a, const bool transposeA, const Matrix<ElemType>& b, const bool transposeB,
ElemType beta, Matrix<ElemType>& c)
ElemType beta, Matrix<ElemType>& c, shared_ptr<QuantizedMultiplier<ElemType>> pQuantizedMultiplier)
{
DecideAndMoveToRightDevice(a, b, c);
@ -4552,7 +4553,7 @@ void Matrix<ElemType>::MultiplyAndWeightedAdd(ElemType alpha, const Matrix<ElemT
else // CPU, DENSE * DENSE -> DENSE (matrix c enforced to be DENSE)
{
c.SwitchToMatrixType(MatrixType::DENSE, matrixFormatDense, false);
CPUMatrix<ElemType>::MultiplyAndWeightedAdd(alpha, *a.m_CPUMatrix, transposeA, *b.m_CPUMatrix, transposeB, beta, *c.m_CPUMatrix);
CPUMatrix<ElemType>::MultiplyAndWeightedAdd(alpha, *a.m_CPUMatrix, transposeA, *b.m_CPUMatrix, transposeB, beta, *c.m_CPUMatrix, pQuantizedMultiplier);
c.SetDataLocation(CPU, DENSE);
}
}

Просмотреть файл

@ -19,6 +19,7 @@
#include <memory> // for shared_ptr
#include <array>
#include <initializer_list>
#include "QuantizedOperations.h"
// This class is exported from the Math.dll
namespace Microsoft { namespace MSR { namespace CNTK {
@ -540,7 +541,7 @@ public:
// singular value decomposition of A as A = U*SIGMA*VT
static void SVD(const Matrix<ElemType>& A, Matrix<ElemType>& SIGMA, Matrix<ElemType>& U, Matrix<ElemType>& VT, Matrix<ElemType>& W);
static void MultiplyAndWeightedAdd(ElemType alpha, const Matrix<ElemType>& a, const bool transposeA, const Matrix<ElemType>& b, const bool transposeB, ElemType beta, Matrix<ElemType>& c); // SGEMM
static void MultiplyAndWeightedAdd(ElemType alpha, const Matrix<ElemType>& a, const bool transposeA, const Matrix<ElemType>& b, const bool transposeB, ElemType beta, Matrix<ElemType>& c, shared_ptr<QuantizedMultiplier<ElemType>> pQuantizedMultiplier=nullptr); // SGEMM
static void MultiplyAndAdd(const Matrix<ElemType>& a, const bool transposeA, const Matrix<ElemType>& b, const bool transposeB, Matrix<ElemType>& c);
static void Multiply(const Matrix<ElemType>& a, const bool transposeA, const Matrix<ElemType>& b, const bool transposeB, Matrix<ElemType>& c);
static void Multiply(const Matrix<ElemType>& a, const Matrix<ElemType>& b, Matrix<ElemType>& c);

Просмотреть файл

@ -0,0 +1,89 @@
//
// Copyright (c) Microsoft. All rights resized.
// Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
//
#pragma once
#include "Quantizers.h"
namespace Microsoft { namespace MSR { namespace CNTK {
// Quantized product of two dense matrices A and B, where each matrix has its own quantizer.
// This class handles quantization of both matrices, product and de-quantization of the result.
// Other implementations should inherit from this class or extract common methods to the base class and inherit from the base.
template <class ElemType>
class QuantizedMultiplier
{
// Quantizers for matrices A and B
shared_ptr<QuantizerBase<ElemType, short>> m_pQuantizerA;
shared_ptr<QuantizerBase<ElemType, short>> m_pQuantizerB;
// Placeholders for quantized matrices A and B
vector<short> m_pMatA, m_pMatB;
// Whether matrices A and B are constant (i.e. weights)
// If the matrix is constant, the size of the underlying container for quatized values will be preserved for
// the lifespan of the object
bool m_isAConstant;
bool m_isBConstant;
bool m_firstPass;
public:
QuantizedMultiplier(shared_ptr<QuantizerBase<ElemType, short>> pQuantizerA, bool isAConstant, shared_ptr<QuantizerBase<ElemType, short>> pQuantizerB, bool isBConstant) :
m_pQuantizerA(pQuantizerA), m_pQuantizerB(pQuantizerB), m_isAConstant(isAConstant), m_isBConstant(isBConstant), m_firstPass(true)
{
if (isAConstant && isBConstant)
LogicError("Quantized multiplication is applied to two constant matrices -- it is highly inefficient. Better approach is to replace the operation with the resulting matrix.");
};
QuantizedMultiplier(shared_ptr<QuantizerBase<ElemType, short>> pQuantizerA, shared_ptr<QuantizerBase<ElemType, short>> pQuantizerB) :
QuantizedMultiplier(pQuantizerA, false, pQuantizerB, false)
{
};
// A[m,k]*B[k,n] = C[m,n]
void Multiply(int m, int n, int k, ElemType* A, ElemType* B, ElemType* C)
{
// Quantize
if (!m_isAConstant || m_firstPass)
{
m_pMatA.resize(m*k);
ArrayRef<short> refMatA(m_pMatA.data(), m_pMatA.size());
m_pQuantizerA->Quantize(ArrayRef<ElemType>(A, m_pMatA.size()), refMatA);
}
if (!m_isBConstant || m_firstPass)
{
m_pMatB.resize(n*k);
ArrayRef<short> refMatB(m_pMatB.data(), m_pMatB.size());
m_pQuantizerB->Quantize(ArrayRef<ElemType>(B, m_pMatB.size()), refMatB);
}
m_firstPass = false;
// Do multiply
// Naive inefficient product, just for demonstation
// TODO: replace with an efficient version, e.g. IPG, block multiplier, Eigen, gemmlowp, etc.
for (size_t i = 0; i < m; i++)
for (size_t j = 0; j < n; j++)
{
int dotProduct=0;
for (size_t l = 0; l < k; l++)
{
// CNTK is using column-major storage
dotProduct += m_pMatA[i + l*m] * m_pMatB[l + k*j];
}
C[i + j*m] = (ElemType)dotProduct;
}
// De-quantize
int mn = m*n;
m_pQuantizerB->Dequantize(C, C, mn);
m_pQuantizerA->Dequantize(C, C, mn);
}
void SetIsAConstant(bool v) { m_isAConstant = v; }
void SetIsBConstant(bool v) { m_isBConstant = v; }
};
}}}

Просмотреть файл

@ -18,7 +18,9 @@ public:
rangeMax = std::numeric_limits<QuantizedType>::max();
}
virtual void Quantize(const ArrayRef<RawType>& input, ArrayRef<QuantizedType>& output) = 0;
virtual void Dequantize(const ArrayRef<QuantizedType>& input, ArrayRef<RawType>& output) = 0;
virtual void Dequantize(const ArrayRef<RawType>& input, ArrayRef<RawType>& output) = 0;
virtual void Dequantize(const RawType* input, RawType* output, size_t size) = 0;
protected:
QuantizedType rangeMax;
@ -27,56 +29,68 @@ protected:
// Symmetric quantizer.
// Quantization is achieved by
// 1. Finding the absolute max of values to be quantized.
// 2. Adjusting the absolute max with extraBits parameter.
// 3. Scaling all values in the collection to be within the symmetric range of the QuantizedType
// 2. Adjusting the max with bit shifting specified with the bitShift parameter (see comment at the declaration of the parameter)
// 3. Scaling all values in the collection to be within the symmetric range of the signed integer (QuantizedType)
template <class RawType, class QuantizedType>
class SymmetricQuantizer : public QuantizerBase<RawType, QuantizedType>
{
RawType m_quantizeFactor;
RawType m_inverseQuantizerFactor;
RawType m_absMax;
// Decreases the maximum range of quantziation by 2^bitShift to prevent integer overflow during BLAS routines.
// bitShift=0 doesn't change the range; higher bitShift will decrease precision of quantization, but will make BLAS routines less prone to overflow.
// For quantization with shorts, recommended value of bitShift is from 1 to 3, but it's model and feature dependent and should be experimented with for optimal results
size_t m_bitShift;
public:
// elements - collection to be quantized
// extraBits decreases the quantization normalizer to prevent integer overflow during BLAS routines.
// Higher extraBits will decrease precision of quantization, but will make BLAS routines less prone to overflow.
// For quantization with shorts, recommended value of extraBits is 1-3.
// This constructor accepts the collection of RawType to initialize internal quantizer
// and then apply this quantizer to collections with similar range as the one it was initialized with.
SymmetricQuantizer(const ArrayRef<RawType>& input, size_t extraBits)
// bitShift - see comment above
SymmetricQuantizer(size_t bitShift) : m_bitShift(bitShift)
{
m_absMax = FindAbsMax(input);
Initialize(m_absMax, extraBits);
}
// absoluteMax - the range of the quantizer (normally represents maximum absolute value of the values in the collection to be quantized).
// extraBits - see comment in another ctor
SymmetricQuantizer(RawType absoluteMax, size_t extraBits)
{
Initialize(absoluteMax, extraBits);
}
// Perform quantization of the input collection, put result into pre-allocated output collection
virtual void Quantize(const ArrayRef<RawType>& input, ArrayRef<QuantizedType>& output)
{
if (input.size() == 0)
return;
assert(input.size() == output.size());
RawType absoluteMax = FindAbsMax(input);
RawType shiftedMax = absoluteMax * (1 << m_bitShift);
if (shiftedMax == 0)
{
// Whole input collection is 0's
// Turn output collection to 0's as well
m_quantizeFactor = 0;
m_inverseQuantizerFactor = 0;
}
else
{
m_quantizeFactor = this->rangeMax / shiftedMax;
m_inverseQuantizerFactor = 1 / m_quantizeFactor;
}
for (size_t i = 0; i < input.size(); i++)
{
#ifdef _DEBUG
assert(abs(input[i]) <= m_absMax);
#endif
output[i] = (QuantizedType) round((input[i] * m_quantizeFactor));
output[i] = (QuantizedType)round(input[i] * m_quantizeFactor);
}
}
// Accept quantized collection as input, put de-quantization result into pre-allocated output collection.
virtual void Dequantize(const ArrayRef<QuantizedType>& input, ArrayRef<RawType>& output)
virtual void Dequantize(const ArrayRef<RawType>& input, ArrayRef<RawType>& output)
{
assert(input.size() == output.size());
for (size_t i = 0; i < input.size(); i++)
Dequantize(input.data(), output.data(), input.size());
}
// Accept quantized collection as input, put de-quantization result into pre-allocated output collection.
virtual void Dequantize(const RawType* input, RawType* output, size_t size)
{
for (size_t i = 0; i < size; i++)
{
output[i] = (RawType)(input[i] * m_inverseQuantizerFactor);
output[i] = input[i] * m_inverseQuantizerFactor;
}
}
@ -84,22 +98,9 @@ private:
// Find absolute maximum value
RawType FindAbsMax(const ArrayRef<RawType>& arrayRef)
{
RawType maxElem = *std::max_element(arrayRef.begin(), arrayRef.end());
RawType minElem = *std::min_element(arrayRef.begin(), arrayRef.end());
auto minMaxPair = std::minmax_element(arrayRef.begin(), arrayRef.end());
return std::max(maxElem, std::abs(minElem));
}
void Initialize(RawType absoluteMax, size_t extraBits)
{
RawType shiftedMax = absoluteMax * (1 << extraBits);
if (shiftedMax == 0)
{
LogicError("The absolute max element in the sequence to be quantized is 0.");
}
m_absMax = absoluteMax;
m_quantizeFactor = this->rangeMax / shiftedMax;
m_inverseQuantizerFactor = 1 / m_quantizeFactor;
return std::max(arrayRef[minMaxPair.second - arrayRef.begin()], std::abs(arrayRef[minMaxPair.first - arrayRef.begin()]));
}
};

Просмотреть файл

@ -352,7 +352,7 @@ shared_ptr<Matrix<ElemType>> TensorView<ElemType>::AsMatrix() const
}
template <class ElemType>
void TensorView<ElemType>::DoMatrixProductOf(ElemType beta, bool transC, const TensorView& a, bool transA, const TensorView& b, bool transB, ElemType alpha)
void TensorView<ElemType>::DoMatrixProductOf(ElemType beta, bool transC, const TensorView& a, bool transA, const TensorView& b, bool transB, ElemType alpha, shared_ptr<QuantizedMultiplier<ElemType>> pQuantizedMultiplier)
{
// determine integration dimension offset
auto shapeA = a.m_shape;
@ -383,9 +383,9 @@ void TensorView<ElemType>::DoMatrixProductOf(ElemType beta, bool transC, const T
auto C = Reshaped(shapeC).AsMatrix();
// and go
if (!transC)
Matrix<ElemType>::MultiplyAndWeightedAdd(alpha, *A, transA, *B, transB, beta, *C);
Matrix<ElemType>::MultiplyAndWeightedAdd(alpha, *A, transA, *B, transB, beta, *C, pQuantizedMultiplier);
else // C' = A * B <==> C = (A * B)' = B' * A'
Matrix<ElemType>::MultiplyAndWeightedAdd(alpha, *B, !transB, *A, !transA, beta, *C);
Matrix<ElemType>::MultiplyAndWeightedAdd(alpha, *B, !transB, *A, !transA, beta, *C, pQuantizedMultiplier);
}
template class TensorView<float>;

Просмотреть файл

@ -10,6 +10,7 @@
#include "Basics.h"
#include "Matrix.h"
#include "TensorShape.h"
#include "Quantizers.h"
#pragma warning(push)
#pragma warning(disable : 4251) // needs to have dll-interface to be used by clients of... caused by TensorView::m_shape which is only private. We use the same compiler everywhere.
@ -143,9 +144,9 @@ public:
// If beta == 0, c is not read out, i.e. it can be uninitialized or contain NaNs.
// -------------------------------------------------------------------
void DoMatrixProductOf (ElemType beta, bool transC, const TensorView& a, bool transA, const TensorView& b, bool transB, ElemType alpha);
void AssignMatrixProductOf( bool transC, const TensorView& a, bool transA, const TensorView& b, bool transB, ElemType alpha = 1.0f) { DoMatrixProductOf(0, transC, a, transA, b, transB, alpha); }
void AddMatrixProductOf ( bool transC, const TensorView& a, bool transA, const TensorView& b, bool transB, ElemType alpha = 1.0f) { DoMatrixProductOf(1.0f, transC, a, transA, b, transB, alpha); }
void DoMatrixProductOf(ElemType beta, bool transC, const TensorView& a, bool transA, const TensorView& b, bool transB, ElemType alpha, shared_ptr<QuantizedMultiplier<ElemType>> pQuantizedMultiplier = nullptr);
void AssignMatrixProductOf( bool transC, const TensorView& a, bool transA, const TensorView& b, bool transB, ElemType alpha = 1.0f, shared_ptr<QuantizedMultiplier<ElemType>> pQuantizedMultiplier = nullptr) { DoMatrixProductOf(0, transC, a, transA, b, transB, alpha, pQuantizedMultiplier); }
void AddMatrixProductOf ( bool transC, const TensorView& a, bool transA, const TensorView& b, bool transB, ElemType alpha = 1.0f) { DoMatrixProductOf(1.0f, transC, a, transA, b, transB, alpha); }
shared_ptr<Matrix<ElemType>> AsMatrix() const;
const TensorShape& GetShape() const { return m_shape; }

Просмотреть файл

@ -6398,12 +6398,22 @@ Test module "MathTests" has passed with:
2 assertions out of 2 passed
Test suite "QuantizersUnitTests" has passed with:
1 test case out of 1 passed
2 test case out of 2 passed
12 assertions out of 12 passed
Test case "QuantizersUnitTests/FloatToShort" has passed with:
12 assertions out of 12 passed
6 assertions out of 6 passed
Test case "QuantizersUnitTests/QuantizeZeros" has passed with:
6 assertions out of 6 passed
Test suite "QuantizedOperationsUnitTests" has passed with:
1 test case out of 1 passed
65 assertions out of 65 passed
Test case "QuantizedOperationsUnitTests/MultiplyIntToShort" has passed with:
65 assertions out of 65 passed
Test suite "MathTensorTests" has passed with:
9 test cases out of 9 passed
6 assertions out of 6 passed

Просмотреть файл

@ -145,6 +145,7 @@
<ClCompile Include="MatrixSparseDenseInteractionsTests.cpp" />
<ClCompile Include="MatrixTests.cpp" />
<ClCompile Include="QuantizersTests.cpp" />
<ClCompile Include="QuantizedOperationsTests.cpp" />
<ClCompile Include="stdafx.cpp">
<PrecompiledHeader>Create</PrecompiledHeader>
</ClCompile>

Просмотреть файл

@ -0,0 +1,56 @@
//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
//
#include "stdafx.h"
#include "../../../Source/Math/QuantizedOperations.h"
#include "../../../Source/Math/Helpers.h"
using namespace Microsoft::MSR::CNTK;
namespace Microsoft { namespace MSR { namespace CNTK { namespace Test {
BOOST_AUTO_TEST_SUITE(QuantizedOperationsUnitTests)
BOOST_FIXTURE_TEST_CASE(MultiplyIntToShort, RandomSeedFixture)
{
// A[m,k]*B[k,n] = C[m,n]
int m = 5, n = 4, k = 3;
std::vector<float> A = {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15};
std::vector<float> B = {16,17,18,19,20,21,22,23,24,25,26,27};
std::vector<float> C_expected = { 316, 367, 418, 469, 520, 370, 430, 490, 550, 610, 424, 493, 562, 631, 700, 478, 556, 634, 712, 790 };
std::vector<float> C;
C.resize(m*n);
shared_ptr<QuantizerBase<float, short>> quantA(new SymmetricQuantizer<float, short>(1));
shared_ptr<QuantizerBase<float, short>> quantB(new SymmetricQuantizer<float, short>(2));
// A - is constant; B - is not
QuantizedMultiplier<float> mult(quantA, true, quantB, false);
// First pass
mult.Multiply(m, n, k, A.data(), B.data(), C.data());
for (size_t i = 0; i < m*n; i++)
BOOST_CHECK_EQUAL(round(C[i]), C_expected[i]);
// Second pass, the same matrices
mult.Multiply(m, n, k, A.data(), B.data(), C.data());
for (size_t i = 0; i < m*n; i++)
BOOST_CHECK_EQUAL(round(C[i]), C_expected[i]);
// Third pass with updated B (size and values)
int n_upd = 5;
std::vector<float> B_upd = { 1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 };
std::vector<float> C_expected_upd = { 46, 52, 58, 64, 70, 100, 115, 130, 145, 160, 154, 178, 202, 226, 250, 208, 241, 274, 307, 340, 262, 304, 346, 388, 430};
std::vector<float> C_upd;
C_upd.resize(m*n_upd);
mult.Multiply(m, n_upd, k, A.data(), B_upd.data(), C_upd.data());
for (size_t i = 0; i < m*n_upd; i++)
BOOST_CHECK_EQUAL(round(C_upd[i]), C_expected_upd[i]);
}
BOOST_AUTO_TEST_SUITE_END()
} } } }

Просмотреть файл

@ -25,29 +25,56 @@ BOOST_FIXTURE_TEST_CASE(FloatToShort, RandomSeedFixture)
ArrayRef<float> inputAr(input, 3);
ArrayRef<short> outputAr(output, 3);
std::unique_ptr<QuantizerBase<float, short>> symQuantPtr(new SymmetricQuantizer<float, short>(10.0f, 0));
std::unique_ptr<QuantizerBase<float, short>> symQuantPtr(new SymmetricQuantizer<float, short>(0));
symQuantPtr->Quantize(inputAr, outputAr);
for (size_t i = 0; i < 3; i++)
BOOST_CHECK_EQUAL(output[i], outputCorrect[i]);
symQuantPtr->Dequantize(outputAr, inputAr);
float* outputFloat = new float[3];
ArrayRef<float> outputFlAr(outputFloat, 3);
for (size_t i = 0; i < 3; i++)
outputFlAr[i] = (float)outputAr[i];
symQuantPtr->Dequantize(outputFlAr, inputAr);
for (size_t i = 0; i < 3; i++)
BOOST_CHECK_EQUAL(round(input[i] * (10^4)), round(inputCorrect[i] * (10^4)));
std::unique_ptr<QuantizerBase<float, short>> symQuantPtr2(new SymmetricQuantizer<float, short>(inputAr, 0));
symQuantPtr2->Quantize(inputAr, outputAr);
delete[] inputCorrect;
delete[] outputFloat;
}
BOOST_FIXTURE_TEST_CASE(QuantizeZeros, RandomSeedFixture)
{
float input[3] = { 0, 0, 0 };
short output[3] = { 0, 0, 0 };
float* inputCorrect = new float[3];
for (size_t i = 0; i < 3; i++)
inputCorrect[i] = input[i];
short outputCorrect[3] = { 0, 0, 0 };
ArrayRef<float> inputAr(input, 3);
ArrayRef<short> outputAr(output, 3);
std::unique_ptr<QuantizerBase<float, short>> symQuantPtr(new SymmetricQuantizer<float, short>(0));
symQuantPtr->Quantize(inputAr, outputAr);
for (size_t i = 0; i < 3; i++)
BOOST_CHECK_EQUAL(output[i], outputCorrect[i]);
symQuantPtr2->Dequantize(outputAr, inputAr);
float* outputFloat = new float[3];
ArrayRef<float> outputFlAr(outputFloat, 3);
for (size_t i = 0; i < 3; i++)
outputFlAr[i] = (float)outputAr[i];
symQuantPtr->Dequantize(outputFlAr, inputAr);
for (size_t i = 0; i < 3; i++)
BOOST_CHECK_EQUAL(round(input[i] * (10 ^ 4)), round(inputCorrect[i] * (10 ^ 4)));
delete[] inputCorrect;
delete[] outputFloat;
}
BOOST_AUTO_TEST_SUITE_END()
} } } }