Added NDConvolution node. Added ConvolveGeometry.

Alexey Kamenev 2016-03-01 11:10:17 -08:00
Parent 65de62c934
Commit cee4cfa6f8
9 changed files with 394 additions and 0 deletions


@@ -0,0 +1,152 @@
# Parameters can be overwritten on the command line
# for example: cntk configFile=myConfigFile RootDir=../..
# For running from Visual Studio add
# currentDirectory=$(SolutionDir)/<path to corresponding data folder>
RootDir = ".."
ConfigDir = "$RootDir$/Config"
DataDir = "$RootDir$/Data"
OutputDir = "$RootDir$/Output"
ModelDir = "$OutputDir$/Models"
deviceId = 0
imageLayout = "cudnn"
# override the above as follows when running on CPU:
# deviceId = -1
# imageLayout = "legacy"
command = train:test
precision = "float"
modelPath = "$ModelDir$/04_NDConvolution"
# the following line redirects log output to a file:
stderr = "$OutputDir$/04_NDConvolution_out"
traceLevel=1
numMBsToShowResult=500
prefetch=true
useCuDnn=true
#######################################
# TRAINING CONFIG #
#######################################
train = [
    action = "train"

    BrainScriptNetworkBuilder = [
        useCuDnn = $useCuDnn$

        // macros
        NDConvReLULayer(inp, kW, kH, inMap, outMap, hStride, vStride, wScale, bValue) = [ // ReLU non-linearity
            convW = Parameter(outMap, kW * kH * inMap, init="uniform", initValueScale=wScale, initOnCPUOnly=false)
            //conv = NDConvolution(convW, inp, kernelShape=(kW : kH : inMap), mapShape=(1 : 1 : outMap), stride=(hStride : vStride : 1),
            //                     sharing=(true : true : true), padding=(true : true : false), imageLayout=if useCuDnn then "cudnn" else "legacy")
            //conv = NDConvolution(convW, inp, (kW : kH : inMap), (1 : 1 : outMap), (hStride : vStride : 1),
            //                     (1 : 1 : 1), (1 : 1 : 1), (1 : 1 : 1), (1 : 1 : 1), imageLayout=if useCuDnn then "cudnn" else "legacy")
            conv = NDConvolution(convW, inp, (kW : kH : inMap), (1 : 1 : outMap), stride=1, sharing=true, autoPadding=false, lowerPad=1, imageLayout=if useCuDnn then "cudnn" else "legacy")
            convB = if useCuDnn
                    then ParameterTensor((1 : 1 : outMap), init="fixedValue", value=bValue)
                    else Parameter(outMap, 1, init="fixedValue", value=bValue)
            convPlusB = Plus(conv, convB);
            out = RectifiedLinear(convPlusB);
        ]
        DNNSigmoidLayer(inDim, outDim, x, parmScale) = [ // Sigmoid non-linearity
            W = Parameter(outDim, inDim, init="uniform", initValueScale=parmScale, initOnCPUOnly=false)
            b = Parameter(outDim, 1, init="uniform", initValueScale=parmScale, initOnCPUOnly=false)
            t = Times(W, x)
            z = Plus(t, b)
            out = Sigmoid(z)
        ]
        DNNLayer(inDim, outDim, x, parmScale) = [ // no non-linearity, used as input to SoftMax
            W = Parameter(outDim, inDim, init="uniform", initValueScale=parmScale, initOnCPUOnly=false)
            b = Parameter(outDim, 1, init="uniform", initValueScale=parmScale, initOnCPUOnly=false)
            t = Times(W, x)
            out = Plus(t, b)
        ]

        imageW = 28
        imageH = 28
        labelDim = 10

        features = ImageInput(imageW, imageH, 1, imageLayout=if useCuDnn then "cudnn" else "legacy", tag="feature")
        featScale = Constant(0.00390625)  # 1/256, scales byte pixel values into [0, 1)
        featScaled = Scale(featScale, features)
        labels = Input(labelDim, tag="label")

        # conv1
        kW1 = 5
        kH1 = 5
        cMap1 = 16
        hStride1 = 1
        vStride1 = 1
        # weight[cMap1, kW1 * kH1 * inputChannels]
        conv1_act = NDConvReLULayer(featScaled, kW1, kH1, 1, cMap1, hStride1, vStride1, 10, 1).out

        h1Dim = 128
        # DNNSigmoidLayer and DNNLayer are defined above; 12544 = 28 * 28 * cMap1, a conv output of the same spatial size as the input
        h1 = DNNSigmoidLayer(12544, h1Dim, conv1_act, 1).out
        ol = DNNLayer(h1Dim, labelDim, h1, 1).out

        ce = CrossEntropyWithSoftmax(labels, ol, tag="criterion")
        err = ErrorPrediction(labels, ol, tag="eval")
        outputNodes = ol
    ]

    SGD = [
        epochSize = 60000
        minibatchSize = 32
        learningRatesPerMB = 0.5
        momentumPerMB = 0*10:0.7
        maxEpochs = 15
    ]

    reader = [
        readerType = "UCIFastReader"
        # To get the data (Train-28x28.txt), run `python mnist_convert.py`
        # from the 'AdditionalFiles' folder. See README.md for details.
        file = "$DataDir$/Train-28x28.txt"

        features = [
            dim = 784
            start = 1
        ]

        labels = [
            dim = 1
            start = 0
            labelDim = 10
            labelMappingFile = "$DataDir$/labelsmap.txt"
        ]
    ]
]
#######################################
# TEST CONFIG #
#######################################
test = [
    action = "test"
    minibatchSize = 16

    reader = [
        readerType = "UCIFastReader"
        file = "$DataDir$/Test-28x28.txt"

        features = [
            dim = 784
            start = 1
        ]

        labels = [
            dim = 1
            start = 0
            labelDim = 10
            labelMappingFile = "$DataDir$/labelsmap.txt"
        ]
    ]
]


@@ -63,6 +63,9 @@ L"ParameterTensor(dims, learningRateMultiplier = 1.0, init = 'uniform'/*|fixedVa
L"WeightedLogistic(label, probability, instanceWeight, tag='') = new ComputationNode [ operation = 'Logistic' ; inputs = (label : probability : instanceWeight) /*plus the function args*/ ]\n"
L"ReconcileMBLayout(dataInput, layoutInput, tag='') = new ComputationNode [ operation = 'ReconcileMBLayout' ; inputs = (dataInput : layoutInput) /*plus the function args*/ ]\n"
L"Convolution(weightNode, inputValueNode, kernelWidth, kernelHeight, outputChannels, horizontalSubsample, verticalSubsample, zeroPadding = false, maxTempMemSizeInSamples = 0, imageLayout='CHW', tag='') = new ComputationNode [ operation = 'Convolution' ; inputs = (weightNode : inputValueNode) /*plus the function args*/ ]\n"
L"NDConvolution(weightNode, inputValueNode, kernelDims, mapDims, stride=1, sharing = true, autoPadding = true, lowerPad = 0, upperPad = 0, imageLayout='CHW', maxTempMemSizeInSamples = 0, tag='') = new ComputationNode [ operation = 'NDConvolution' ; inputs = (weightNode : inputValueNode); kernelShape = new TensorShape [ dims = kernelDims ] ; mapShape = new TensorShape [ dims = mapDims ] ; strideShape = new TensorShape [ dims = stride ] ; dimSharing = new BoolVector [ items = sharing ] ; dimPadding = new BoolVector [ items = autoPadding ] ; dimPadLower = new TensorShape [ dims = lowerPad ] ; dimPadUpper = new TensorShape [ dims = upperPad ] /*plus the function args*/ ]\n"
//L"NDConvolution(weightNode, inputValueNode, kernelShape = new TensorShape [ /*dims*/ ], mapShape = new TensorShape[ /*dims*/ ], stride = new TensorShape [ /*dims*/ ], sharing = new BoolVector [ /*dims*/ ], padding = new BoolVector [ /*dims*/ ], lowerPad = new TensorShape [ /*dims*/ ], upperPad = new TensorShape [ /*dims*/ ], imageLayout='CHW', maxTempMemSizeInSamples = 0, tag='') = new ComputationNode [ operation = 'NDConvolution' ; inputs = (weightNode : inputValueNode) /*plus the function args*/ ]\n"
//L"NDConvolution(weightNode, inputValueNode, kernelShape, mapShape, stride, sharing, padding, lowerPad, upperPad, imageLayout='CHW', maxTempMemSizeInSamples = 0, tag='') = new ComputationNode [ operation = 'NDConvolution' ; inputs = (weightNode : inputValueNode) /*plus the function args*/ ]\n"
L"MaxPooling(input, windowWidth, windowHeight, horizontalSubsample, verticalSubsample, imageLayout='CHW', tag='') = new ComputationNode [ operation = 'MaxPooling' ; inputs = input /*plus the function args*/ ]\n"
L"AveragePooling(input, windowWidth, windowHeight, horizontalSubsample, verticalSubsample, imageLayout='CHW', tag='') = new ComputationNode [ operation = 'AveragePooling' ; inputs = input /*plus the function args*/ ]\n"
// TODO: define DelayedValue, with negative delay for future; cannot do this yet, need to be able to say something like delay = -(^.delay)
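
The scalar defaults in the NDConvolution definition above (stride=1, sharing=true, autoPadding=true, lowerPad=0, upperPad=0) each become a rank-1 TensorShape or single-element BoolVector, which NDConvolutionNode::Validate (further down in this commit) broadcasts across all input dimensions. A minimal sketch of that broadcast idiom, with illustrative names that are not CNTK API:

#include <cstddef>
#include <vector>

// Illustrative stand-in for a TensorShape / BoolVector parameter.
template <typename T>
T BroadcastAt(const std::vector<T>& param, size_t i)
{
    // A single-element parameter (e.g. stride=1 in BrainScript) supplies
    // the same value for every input dimension; otherwise index per dimension.
    return param[param.size() == 1 ? 0 : i];
}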


@@ -127,6 +127,7 @@ static shared_ptr<ComputationNode<ElemType>> CreateNode(const std::wstring& node
if (nodeType == OperationNameOf(AveragePoolingNode)) return New<AveragePoolingNode<ElemType>>(forward<_Types>(_Args)...);
else if (nodeType == OperationNameOf(BatchNormalizationNode)) return New<BatchNormalizationNode<ElemType>>(forward<_Types>(_Args)...);
else if (nodeType == OperationNameOf(ConvolutionNode)) return New<ConvolutionNode<ElemType>>(forward<_Types>(_Args)...);
else if (nodeType == OperationNameOf(NDConvolutionNode)) return New<NDConvolutionNode<ElemType>>(forward<_Types>(_Args)...);
else if (nodeType == OperationNameOf(SparseInputValue)) return New<SparseInputValue<ElemType>>(forward<_Types>(_Args)...);
else if (nodeType == OperationNameOf(InputValue)) return New<InputValue<ElemType>>(forward<_Types>(_Args)...);
else if (nodeType == OperationNameOf(LearnableParameter)) return New<LearnableParameter<ElemType>>(forward<_Types>(_Args)...);


@@ -563,5 +563,6 @@ public:
ScriptableObjects::ConfigurableRuntimeTypeRegister::Add<BoxedTensorShape> registerTensorShape(L"TensorShape");
ScriptableObjects::ConfigurableRuntimeTypeRegister::Add<BoxedVector<int>> registerIntVector(L"IntVector");
ScriptableObjects::ConfigurableRuntimeTypeRegister::Add<BoxedVector<size_t>> registerSizeVector(L"SizeVector");
ScriptableObjects::ConfigurableRuntimeTypeRegister::Add<BoxedVector<bool>> registerBoolVector(L"BoolVector");
}}}


@@ -25,6 +25,157 @@
namespace Microsoft { namespace MSR { namespace CNTK {

template <class ElemType>
class NDConvolutionNode : public ComputationNode<ElemType>, public NumInputs<2>
{
    typedef ComputationNode<ElemType> Base;
    UsingComputationNodeMembersBoilerplate;
    static const std::wstring TypeName()
    {
        return L"NDConvolution";
    }

public:
    NDConvolutionNode(DEVICEID_TYPE deviceId, const wstring& name)
        : Base(deviceId, name)
    {
    }
    NDConvolutionNode(DEVICEID_TYPE deviceId, const wstring& name, const TensorShape& kernelShape, const TensorShape& mapShape, const TensorShape& strideShape,
                      const std::vector<bool>& sharing, const std::vector<bool>& autoPadding, const TensorShape& lowerPad, const TensorShape& upperPad,
                      ImageLayoutKind imageLayout, size_t maxTempMemSizeInSamples)
        : Base(deviceId, name), m_imageLayoutKind(imageLayout), m_kernel(kernelShape), m_stride(strideShape),
          m_sharing(sharing), m_autoPad(autoPadding), m_lowerPad(lowerPad), m_upperPad(upperPad)
    {
    }
    NDConvolutionNode(const ScriptableObjects::IConfigRecordPtr configp)
        : NDConvolutionNode(configp->Get(L"deviceId"), L"<placeholder>", configp->Get(L"kernelShape"), configp->Get(L"mapShape"), configp->Get(L"strideShape"),
                            configp->Get(L"dimSharing"), configp->Get(L"dimPadding"), configp->Get(L"dimPadLower"), configp->Get(L"dimPadUpper"),
                            ImageLayoutKindFrom(configp->Get(L"imageLayout")), configp->Get(L"maxTempMemSizeInSamples"))
    {
        AttachInputs(configp, this->GetExpectedNumInputs());
    }

public:
    void Save(File& fstream) const override
    {
        Base::Save(fstream);
    }

    void Load(File& fstream, size_t modelVersion) override
    {
        Base::Load(fstream, modelVersion);
    }

    void CopyTo(ComputationNodeBasePtr nodeP, const std::wstring& newName, const CopyNodeFlags flags) const override
    {
        Base::CopyTo(nodeP, newName, flags);
        if (flags & CopyNodeFlags::copyNodeValue)
        {
        }
    }

    void BackpropTo(const size_t inputIndex, const FrameRange& fr) override
    {
    }

    virtual bool OutputUsedInComputingInputNodesGradients() const override
    {
        return false;
    }

    void ForwardProp(const FrameRange& fr) override
    {
    }

    void Validate(bool isFinalValidationPass) override
    {
        Base::Validate(isFinalValidationPass);
        InferMBLayoutFromInputsForStandardCase();

        bool validate = isFinalValidationPass;
        if (validate && m_imageLayoutKind != ImageLayoutKind::CHW)
        {
            InvalidArgument(
                "NDConvolution supports only cuDNN (CHW) data layout. "
                "Please specify imageLayout=\"cudnn\" in NDConvolution node in your BrainScript "
                "and make sure input data layout is CHW.");
        }

        auto dimsInput = GetInputSampleLayout(1);
        if (validate)
        {
            if (dimsInput.GetRank() != m_kernel.GetRank())
                InvalidArgument("Convolution input and kernel tensors must have the same rank.");
            if (m_stride.GetRank() != 1 && dimsInput.GetRank() != m_stride.GetRank())
                InvalidArgument("Convolution stride tensor must have rank 1 or the same rank as the input tensor.");
            if (m_sharing.size() != 1 && dimsInput.GetRank() != m_sharing.size())
                InvalidArgument("Convolution sharing tensor must have rank 1 or the same rank as the input tensor.");
            if (m_autoPad.size() != 1 && dimsInput.GetRank() != m_autoPad.size())
                InvalidArgument("Convolution padding tensor must have rank 1 or the same rank as the input tensor.");
            if (m_lowerPad.GetRank() != 1 && dimsInput.GetRank() != m_lowerPad.GetRank())
                InvalidArgument("Convolution lower pad tensor must have rank 1 or the same rank as the input tensor.");
            if (m_upperPad.GetRank() != 1 && dimsInput.GetRank() != m_upperPad.GetRank())
                InvalidArgument("Convolution upper pad tensor must have rank 1 or the same rank as the input tensor.");
        }

        SmallVector<size_t> dimsOutput(dimsInput.GetRank());
        for (size_t i = 0; i < dimsInput.GetRank(); i++)
        {
            assert(dimsInput[i] >= 1);
            if (validate && m_kernel[i] > dimsInput[i])
                InvalidArgument("NDConvolution operation requires that kernel dim %d <= input dim %d.", (int)m_kernel[i], (int)dimsInput[i]);
            size_t delta = m_stride[m_stride.GetRank() == 1 ? 0 : i];
            if (validate && delta > m_kernel[i])
                InvalidArgument("NDConvolution operation requires that stride %d <= kernel dim %d.", (int)delta, (int)m_kernel[i]);

            size_t dim = dimsInput[i];
            bool autoPad = m_autoPad[m_autoPad.size() == 1 ? 0 : i];
            if (autoPad)
            {
                // "same" auto-padding: output size is ceil(inputDim / stride).
                dim -= 1;
                dimsOutput[i] = dim / delta + 1;
            }
            else
            {
                size_t lo = m_lowerPad[m_lowerPad.size() == 1 ? 0 : i];
                size_t hi = m_upperPad[m_upperPad.size() == 1 ? 0 : i];
                // Standard convolution output size: (paddedDim - kernelDim) / stride + 1.
                dim += lo + hi;
                dimsOutput[i] = (dim - m_kernel[i]) / delta + 1;
            }
        }
        SetDims(TensorShape(dimsOutput), true);

        if (isFinalValidationPass)
        {
        }
    }

    void RequestMatricesBeforeForwardProp(MatrixPool& matrixPool) override
    {
        Base::RequestMatricesBeforeForwardProp(matrixPool);
    }

    void RequestMatricesBeforeBackprop(MatrixPool& matrixPool) override
    {
        Base::RequestMatricesBeforeBackprop(matrixPool);
    }

    void ReleaseMatricesAfterBackprop(MatrixPool& matrixPool) override
    {
        Base::ReleaseMatricesAfterBackprop(matrixPool);
    }

private:
    ImageLayoutKind m_imageLayoutKind;
    TensorShape m_kernel;
    TensorShape m_stride;
    std::vector<bool> m_sharing;
    std::vector<bool> m_autoPad;
    TensorShape m_lowerPad;
    TensorShape m_upperPad;
};

// -----------------------------------------------------------------------
// ConvolutionNode (convolutionWeights, inputFeature)
// -----------------------------------------------------------------------
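
For reference, a minimal sketch of the output-size arithmetic that Validate() performs above: the standard convolution size formula, with "same" auto-padding in one branch and explicit lower/upper padding in the other. The helper name and the worked numbers below are illustrative only, not part of the commit.

#include <cassert>
#include <cstddef>

// Expected output size along one dimension, mirroring the two cases in Validate().
size_t ConvOutputDim(size_t inputDim, size_t kernelDim, size_t stride,
                     bool autoPad, size_t lowerPad, size_t upperPad)
{
    assert(kernelDim <= inputDim && stride <= kernelDim);
    if (autoPad)
        return (inputDim - 1) / stride + 1; // "same" padding: ceil(inputDim / stride)
    return (inputDim + lowerPad + upperPad - kernelDim) / stride + 1;
}

// Example from the sample config: a 28x28 image, 5x5 kernel, stride 1.
// With auto-padding the spatial size is preserved (28 -> 28), so the conv1
// output has 28 * 28 * 16 = 12544 values, the input dimension the config
// passes to DNNSigmoidLayer.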


@@ -18,6 +18,7 @@
#include "Matrix.h"
#include "TensorShape.h" // for ImageLayoutKind
#include "ConvolveGeometry.h"
namespace Microsoft { namespace MSR { namespace CNTK {


@@ -0,0 +1,81 @@
//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
//

#pragma once

#include <vector>

#include "Basics.h"
#include "TensorShape.h"

namespace Microsoft { namespace MSR { namespace CNTK {

// Notes:
// * ConvolveGeometry represents the application of one or more rectangular "kernels" (all of the same size)
//   to a rectangular input to produce a rectangular output.
// * A "cell" in the rectangular input is identified by a single coordinate called a "col" (for column).
// * A "cell" in the rectangular output is identified by a single coordinate called a "row".
// * The kernels may involve weights, in which case MpRowIwht indicates the starting index of the weights
//   used for a given output cell.
// The overall idea of ConvolveGeometry is to precompute maps that can be used to apply convolutions of
// arbitrary configuration and dimension. The generic implementation then becomes very simple and invariant
// with respect to the convolution configuration and dimensionality. For specific cases, such as 2D convolutions
// with full sharing, highly optimized implementations (e.g. cuDNN) are used instead.
class ConvolveGeometry final
{
public:
    using IntVec = std::vector<int>;

    // Maps from a "row" (index of output cell) to its base "col" (index of input cell). For a given row,
    // the cols that contribute to it are { MpRowCol[row] + Indices[i0 + 1 + i] | 0 <= i < Indices[i0] },
    // where i0 = MpRowIndices[row].
    const IntVec& MpRowCol() const { return m_mpRowCol; }

    // Maps from a "row" (index of output cell) to where to start in the weights array. Each run of weights
    // consists of KernelSize weights.
    const IntVec& MpRowIwht() const { return m_mpRowIwht; }

    // Maps from a "row" (index of output cell) to its starting index in Runs. A run consists of:
    // * skip count (to skip that many weights)
    // * item count
    // * relative indices into source (item count of these)
    // * masks (all 1's or all 0's) (item count of these)
    // For items that are masked out (0 mask), the index stored is the next valid index.
    // This ensures that accessing the corresponding neuron value doesn't fault and that
    // backprop operations write the correct value last (any previous writes won't change
    // the value).
    // NOTE: The first (zeroth) run is always the "full" kernel run. Also, MpRowRun can be empty,
    // indicating that all values are zero (all outputs use the "full" kernel run).
    const IntVec& MpRowRun() const { return m_mpRowRun; }
    const IntVec& Runs() const { return m_runs; }

    // Maps from a "row" (index of output cell) to its starting index in Indices. Note that "Runs" is intended
    // for kernels that have weights, while "Indices" is intended for kernels that don't need to access weights.
    // As a result, the encoding in Indices is simpler and more direct.
    // A run in Indices consists of:
    // * item count
    // * relative indices into source (item count of these)
    // NOTE: The first run of indices is always the "full" kernel run. Also, MpRowIndices can be empty,
    // indicating that all values are zero (all outputs use the "full" kernel run).
    const IntVec& MpRowIndices() const { return m_mpRowIndices; }
    const IntVec& Indices() const { return m_indices; }

    // The indices of the first ("top-left-most") "kernel-center" cell in the source.
    const IntVec& Start() const { return m_start; }
    int StartIndex() const { return m_startIndex; }

    ConvolveGeometry(const TensorShape& input, const TensorShape& kernel)
        : m_startIndex(0)
    {
        assert(input.GetRank() == kernel.GetRank());
    }

private:
    IntVec m_mpRowCol;
    IntVec m_mpRowIwht;
    IntVec m_mpRowRun;
    IntVec m_runs;
    IntVec m_mpRowIndices;
    IntVec m_indices;
    IntVec m_start;
    int m_startIndex;
};

} } }
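
To make the precomputed-maps idea concrete, here is a hypothetical sketch, not part of this commit, of how a weight-free kernel such as max pooling could consume MpRowCol, MpRowIndices, and Indices. It assumes the maps have already been populated, which the constructor above does not yet do.

#include <algorithm>
#include <limits>

#include "ConvolveGeometry.h"

using namespace Microsoft::MSR::CNTK;

// For output cell `row`, gather the contributing input cells and reduce with max.
// Per the header comments: i0 = MpRowIndices[row] (or run 0 if the map is empty),
// Indices[i0] is the item count, and the following items are offsets from MpRowCol[row].
float GatherMax(const ConvolveGeometry& g, const float* input, int row)
{
    const auto& mpRowCol = g.MpRowCol();
    const auto& mpRowIndices = g.MpRowIndices();
    const auto& indices = g.Indices();

    int i0 = mpRowIndices.empty() ? 0 : mpRowIndices[row];
    int count = indices[i0];
    float res = -std::numeric_limits<float>::infinity();
    for (int i = 0; i < count; i++)
        res = std::max(res, input[mpRowCol[row] + indices[i0 + 1 + i]]);
    return res;
}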


@@ -158,6 +158,7 @@
<ClInclude Include="..\Common\Include\fileutil.h" />
<ClInclude Include="CommonMatrix.h" />
<ClInclude Include="ConvolutionEngine.h" />
<ClInclude Include="ConvolveGeometry.h" />
<ClInclude Include="CPUMatrix.h" />
<ClInclude Include="MatrixQuantizerImpl.h" />
<ClInclude Include="TensorOps.h" />


@@ -97,6 +97,9 @@
<ClInclude Include="MatrixQuantizerImpl.h">
<Filter>1bitSGD</Filter>
</ClInclude>
<ClInclude Include="ConvolveGeometry.h">
<Filter>Convolution</Filter>
</ClInclude>
</ItemGroup>
<ItemGroup>
<None Include="GPUMatrix.h">