first version of ResNet in TutorialImage.cntk;

breaking change: ConvolutionalLayer{} and MaxPoolingLayer{}/AveragePoolingLayer{} now both default to pad = false, like Keras
Frank Seide 2016-08-10 23:47:44 -07:00
Parent 8ecc7b13d5
Commit 474570dbb5
3 changed files with 103 additions and 44 deletions
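In practice, the new default means a ConvolutionalLayer{} that previously padded implicitly now shrinks its output unless padding is requested explicitly. A minimal before/after sketch, with argument values copied from the tutorial config touched by this commit:

# before: autoPadding = true was the ConvolutionalLayer{} default, so this padded implicitly
ConvolutionalLayer {32, (5:5), activation = ReLU, init = "gaussian", initValueScale = 0.0043}
# after: pad = false is the default; pass pad = true to keep the old output dimensions
ConvolutionalLayer {32, (5:5), pad = true, activation = ReLU, init = "gaussian", initValueScale = 0.0043}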

View file

@@ -21,11 +21,11 @@ TrainConvNet = [
model = Sequential (
Subtract128 :
ConvolutionalLayer {32, (5:5), activation = ReLU, init = "gaussian", initValueScale = 0.0043} :
ConvolutionalLayer {32, (5:5), pad = true, activation = ReLU, init = "gaussian", initValueScale = 0.0043} :
MaxPoolingLayer {(3:3), stride = (2:2)} :
ConvolutionalLayer {32, (5:5), activation = ReLU, init = "gaussian", initValueScale = 1.414} :
ConvolutionalLayer {32, (5:5), pad = true, activation = ReLU, init = "gaussian", initValueScale = 1.414} :
MaxPoolingLayer {(3:3), stride = (2:2)} :
ConvolutionalLayer {64, (5:5), activation = ReLU, init = "gaussian", initValueScale = 1.414} :
ConvolutionalLayer {64, (5:5), pad = true, activation = ReLU, init = "gaussian", initValueScale = 1.414} :
MaxPoolingLayer {(3:3), stride = (2:2)} :
DenseLayer {64, activation = ReLU, init = "gaussian", initValueScale = 12} :
Dropout :
@@ -86,13 +86,13 @@ TrainConvNetWithBN = [
model = Sequential (
Subtract128 :
ConvolutionalLayer {32, (5:5), bias = false, init = "gaussian", initValueScale = 0.0043} :
ConvolutionalLayer {32, (5:5), pad = true, bias = false, init = "gaussian", initValueScale = 0.0043} :
BatchNormalizationLayer {spatialRank = 2, normalizationTimeConstant = 4096} : ReLU :
MaxPoolingLayer {(3:3), stride = (2:2)} :
ConvolutionalLayer {32, (5:5), bias = false, init = "gaussian", initValueScale = 1.414} :
ConvolutionalLayer {32, (5:5), pad = true, bias = false, init = "gaussian", initValueScale = 1.414} :
BatchNormalizationLayer {spatialRank = 2, normalizationTimeConstant = 4096} : ReLU :
MaxPoolingLayer {(3:3), stride = (2:2)} :
ConvolutionalLayer {64, (5:5), bias = false, init = "gaussian", initValueScale = 1.414} :
ConvolutionalLayer {64, (5:5), pad = true, bias = false, init = "gaussian", initValueScale = 1.414} :
BatchNormalizationLayer {spatialRank = 2, normalizationTimeConstant = 4096} : ReLU :
MaxPoolingLayer {(3:3), stride = (2:2)} :
LinearLayer {64, bias = false, init = "gaussian", initValueScale = 12} :

View file

@@ -18,33 +18,33 @@ TrainConvNet = [
labelDim = 10
# basic model
Subtract128 (x) = x - Constant (128)
model_basic (features) =
{
featNorm = features - Constant (128)
l1 = ConvolutionalLayer {32, (5:5), activation = ReLU, init = "gaussian", initValueScale = 0.0043} (featNorm)
# BUGBUG: pad=true desired?
l1 = ConvolutionalLayer {32, (5:5), pad = true, activation = ReLU, init = "gaussian", initValueScale = 0.0043} (featNorm)
p1 = MaxPoolingLayer {(3:3), stride = (2:2)} (l1)
l2 = ConvolutionalLayer {32, (5:5), activation = ReLU, init = "gaussian", initValueScale = 1.414} (p1)
l2 = ConvolutionalLayer {32, (5:5), pad = true, activation = ReLU, init = "gaussian", initValueScale = 1.414} (p1)
p2 = MaxPoolingLayer {(3:3), stride = (2:2)} (l2)
l3 = ConvolutionalLayer {64, (5:5), activation = ReLU, init = "gaussian", initValueScale = 1.414} (p2)
l3 = ConvolutionalLayer {64, (5:5), pad = true, activation = ReLU, init = "gaussian", initValueScale = 1.414} (p2)
p3 = MaxPoolingLayer {(3:3), stride = (2:2)} (l3)
d1 = DenseLayer {64, activation = ReLU, init = "gaussian", initValueScale = 12} (p3)
z = LinearLayer {labelDim, init = "gaussian", initValueScale = 1.5} (d1)
}.z
# with self-defined layers
MyConvLayer {dim, initValueScale} =
MyConvReLUPoolLayer {dim, initValueScale} =
{
C = ConvolutionalLayer {dim, (5:5), activation = ReLU, init = "gaussian", initValueScale = initValueScale}
C = ConvolutionalLayer {dim, (5:5), pad = true, activation = ReLU, init = "gaussian", initValueScale = initValueScale}
P = MaxPoolingLayer {(3:3), stride = (2:2)}
f(x) = P(C(x))
}.f
model_layers (features) =
{
featNorm = features - Constant (128)
h1 = MyConvLayer {32, 0.0043} (featNorm)
h2 = MyConvLayer {32, 1.414} (h1)
h3 = MyConvLayer {64, 1.414} (h2)
h1 = MyConvReLUPoolLayer {32, 0.0043} (featNorm)
h2 = MyConvReLUPoolLayer {32, 1.414} (h1)
h3 = MyConvReLUPoolLayer {64, 1.414} (h2)
d1 = DenseLayer {64, activation = ReLU, init = "gaussian", initValueScale = 12} (h3)
z = LinearLayer {labelDim, init = "gaussian", initValueScale = 1.5} (d1)
}.z
@@ -53,29 +53,30 @@ TrainConvNet = [
model_layerStack (features) =
{
featNorm = features - Constant (128)
h3 = LayerStack {3, i => MyConvLayer {(32:32:64)[i], (0.0043:1.414:1.414)[i]} } (featNorm)
h3 = LayerStack {3, i => MyConvReLUPoolLayer {(32:32:64)[i], (0.0043:1.414:1.414)[i]} } (featNorm)
d1 = DenseLayer {64, activation = ReLU, init = "gaussian", initValueScale = 12} (h3)
z = LinearLayer {labelDim, init = "gaussian", initValueScale = 1.5} (d1)
}.z
# model-composition style
# ...TODO: test this again; last run was a little worse
Subtract128 (x) = x - Constant (128)
model_compositionStyle = Sequential (
Subtract128 :
LayerStack {3, i => MyConvLayer {(32:32:64)[i], (0.0043:1.414:1.414)[i]} } :
LayerStack {3, i => MyConvReLUPoolLayer {(32:32:64)[i], (0.0043:1.414:1.414)[i]} } :
DenseLayer {64, activation = ReLU, init = "gaussian", initValueScale = 12} :
LinearLayer {labelDim, init = "gaussian", initValueScale = 1.5}
)
// --- with BatchNorm
MyBNConvLayer {dim, initValueScale} =
MyConvBNReLUPoolLayer {dim, initValueScale} =
{
C = ConvolutionalLayer {dim, (5:5), bias = false, init = "gaussian", initValueScale = initValueScale}
C = ConvolutionalLayer {dim, (5:5), pad = true, bias = false, init = "gaussian", initValueScale = initValueScale}
B = BatchNormalizationLayer {spatialRank = 2, normalizationTimeConstant = 4096}
P = MaxPoolingLayer {(3:3), stride = (2:2)}
f = Sequential (C:B:ReLU:P)
}.f
MyBNDenseLayer {dim, initValueScale} =
MyDenseBNReLULayer {dim, initValueScale} =
{
D = DenseLayer {dim, bias = false, init = "gaussian", initValueScale = initValueScale}
B = BatchNormalizationLayer {normalizationTimeConstant = 4096}
@@ -84,17 +85,69 @@ TrainConvNet = [
model_withBatchNorm (features) =
{
featNorm = features - Constant (128)
h3 = LayerStack {3, i => MyBNConvLayer {(32:32:64)[i], (0.0043:1.414:1.414)[i]} } (featNorm)
d1 = MyBNDenseLayer {64, 12} (h3)
h3 = LayerStack {3, i => MyConvBNReLUPoolLayer {(32:32:64)[i], (0.0043:1.414:1.414)[i]} } (featNorm)
d1 = MyDenseBNReLULayer {64, 12} (h3)
z = LinearLayer {labelDim, init = "gaussian", initValueScale = 1.5} (d1)
}.z
// --- ResNet
MyConvBNLayer {dim, initValueScale, stride} =
{
# note: (3:3), while the macro above is (5:5)
C = ConvolutionalLayer {dim, (3:3), pad = true, stride = (stride:stride), bias = false, init = "gaussian", initValueScale = initValueScale}
B = BatchNormalizationLayer {spatialRank = 2, normalizationTimeConstant = 4096}
f = Sequential (C:B)
}.f
MyConvBNReLULayer {dim, initValueScale, stride} = # TODO: get rid of this
{
CB = MyConvBNLayer {dim, initValueScale, stride}
f = Sequential (CB:ReLU)
}.f
ResNetNode {dim, initValueScale} =
{
C1 = MyConvBNReLULayer {dim, initValueScale, 1} # first convolution layer
C2 = MyConvBNLayer {dim, initValueScale, 1} # second convolution layer, no ReLU
f(x) = ReLU (x + C2(C1(x)))
}.f
ResNetIncNode {dim, initValueScale} =
{
# one branch. C2 o c1 doubles the #channels but halves the image size
C1 = MyConvBNReLULayer {dim, initValueScale, 2} # first convolution layer, stride = 2
C2 = MyConvBNLayer {dim, initValueScale, 1} # second convolution layer, no ReLU
# second branch:
# - sub-sample spatially by a factor of 2
# - append dim/2 zero output channels
Down2 = MaxPoolingLayer {(1:1), stride = (2:2)} # this is a function that sub-samples spatially by a factor of 2
pad = ConstantTensor (0, (1:1:dim/2)) # the 1s will broadcast to image size
P(x) = Splice ((Down2(x) : pad), axis = 3)
B = BatchNormalizationLayer {spatialRank = 2, normalizationTimeConstant = 4096}
# layer sums both branches and rectifies the result
f(x) = ReLU (B(P(x)) + C2(C1(x)))
}.f
model_resNet (features) =
{
conv1 = MyConvBNReLULayer {16, 0.26, 1} (features)
rn1 = LayerStack {3, i=> ResNetNode {16, 7.07}} (conv1)
rn2_1 = ResNetIncNode {32, 7.07} (rn1)
rn2 = LayerStack {2, i=> ResNetNode {32, 7.07}} (rn2_1)
rn3_1 = ResNetIncNode {64, 7.07} (rn2)
rn3 = LayerStack {2, i=> ResNetNode {64, 7.07}} (rn3_1)
pool = AveragePoolingLayer {(8:8)} (rn3)
z = LinearLayer {labelDim, init = "gaussian", initValueScale = 0.4} (pool)
}.z
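# Worked shape trace for model_resNet above (a sketch, assuming the 32:32:3 CIFAR-10 imageShape used by this config):
#   conv1                    3x3 conv, pad = true, stride 1            -> 32 x 32 x 16
#   rn1    3 x ResNetNode{16}                                          -> 32 x 32 x 16
#   rn2_1  ResNetIncNode{32}: conv branch strides by 2                 -> 16 x 16 x 32
#          shortcut branch:   Down2 gives 16 x 16 x 16, Splice appends 16 zero channels -> 16 x 16 x 32
#   rn2    2 x ResNetNode{32}                                          -> 16 x 16 x 32
#   rn3_1  ResNetIncNode{64}, then rn3 = 2 x ResNetNode{64}            ->  8 x  8 x 64
#   pool   8x8 average pooling                                         ->  1 x  1 x 64
#   z      LinearLayer{labelDim}                                       -> 10 class scores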
# inputs
features = Input {imageShape}
labels = Input {labelDim}
# apply model to features
z = model_withBatchNorm (features)
z = model_resNet (features)
# connect to system
ce = CrossEntropyWithSoftmax (labels, z)
@@ -109,21 +162,27 @@ TrainConvNet = [
]
SGD = [
epochSize = 49984
epochSize = 50000 # 49984 --TODO: why 16 less?
# without BatchNormalization:
#minibatchSize = 64
#maxEpochs = 30 ; minibatchSize = 64
#learningRatesPerSample = 0.00015625*10:0.000046875*10:0.000015625
#momentumAsTimeConstant = 600*20:6400
#maxEpochs = 30
#L2RegWeight = 0.03
# with BatchNormalization:
minibatchSize = 64
learningRatesPerSample = 0.00046875*7:0.00015625
momentumAsTimeConstant = 0
maxEpochs = 30
L2RegWeight = 0
#maxEpochs = 30 ; minibatchSize = 64
#learningRatesPerSample = 0.00046875*7:0.00015625
#momentumAsTimeConstant = 0
#L2RegWeight = 0
# ResNet
maxEpochs = 160 ; minibatchSize = 128
learningRatesPerSample = 0.0078125*80:0.00078125*40:0.000078125
momentumAsTimeConstant = 1200
L2RegWeight = 0.0001
#learningRatesPerMB = 1.0*80:0.1*40:0.01
#momentumPerMB = 0.9
firstMBsToShowResult = 10 ; numMBsToShowResult = 500
]
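For orientation, the commented per-minibatch values correspond to the per-sample settings above, assuming CNTK's usual conversions (per-sample rate = per-MB rate / minibatch size, and momentumPerMB = exp(-minibatchSize / momentumAsTimeConstant)):

learningRatesPerSample:  1.0 / 128 = 0.0078125,  0.1 / 128 = 0.00078125,  0.01 / 128 = 0.000078125
momentumAsTimeConstant:  -128 / ln(0.9) ≈ 1215, rounded here to 1200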
@@ -156,13 +215,13 @@ TrainConvNetWithBN = [
model = Sequential (
Subtract128 :
ConvolutionalLayer {32, (5:5), bias = false, init = "gaussian", initValueScale = 0.0043} :
ConvolutionalLayer {32, (5:5), pad = true, bias = false, init = "gaussian", initValueScale = 0.0043} :
BatchNormalizationLayer {spatialRank = 2, normalizationTimeConstant = 4096} : ReLU :
MaxPoolingLayer {(3:3), stride = (2:2)} :
ConvolutionalLayer {32, (5:5), bias = false, init = "gaussian", initValueScale = 1.414} :
ConvolutionalLayer {32, (5:5), pad = true, bias = false, init = "gaussian", initValueScale = 1.414} :
BatchNormalizationLayer {spatialRank = 2, normalizationTimeConstant = 4096} : ReLU :
MaxPoolingLayer {(3:3), stride = (2:2)} :
ConvolutionalLayer {64, (5:5), bias = false, init = "gaussian", initValueScale = 1.414} :
ConvolutionalLayer {64, (5:5), pad = true, bias = false, init = "gaussian", initValueScale = 1.414} :
BatchNormalizationLayer {spatialRank = 2, normalizationTimeConstant = 4096} : ReLU :
MaxPoolingLayer {(3:3), stride = (2:2)} :
LinearLayer {64, bias = false, init = "gaussian", initValueScale = 12} :

View file

@@ -66,9 +66,9 @@ ConvolutionalLayer {numOutputChannels, # e.g. (1) or BS.Constants.None
bias = true,
activation = (x=>x),
init = "uniform",
initValueScale = 1,
initValueScale = 1, # TODO: rename to initScale
#reductionRank = 1, # TODO: support this
stride = 1, autoPadding = true,
stride = 1, pad = false,
lowerPad = 0, upperPad = 0,
#transpose = false, # TODO: support this
maxTempMemSizeInSamples = 0} =
@@ -84,7 +84,7 @@ ConvolutionalLayer {numOutputChannels, # e.g. (1) or BS.Constants.None
sharing = true # TODO: support this
transpose = false # TODO: support this
f(x) = {
c = Convolution (W, x, filterShape, mapDims = numOutputChannels, stride = stride, sharing = sharing, autoPadding = autoPadding, lowerPad = lowerPad, upperPad = upperPad, transpose = transpose, maxTempMemSizeInSamples = maxTempMemSizeInSamples)
c = Convolution (W, x, filterShape, mapDims = numOutputChannels, stride = stride, sharing = sharing, autoPadding = pad, lowerPad = lowerPad, upperPad = upperPad, transpose = transpose, maxTempMemSizeInSamples = maxTempMemSizeInSamples)
res = activation (if bias then c + b else c)
}.res
}.f
@@ -92,15 +92,15 @@ ConvolutionalLayer {numOutputChannels, # e.g. (1) or BS.Constants.None
# MaxPoolingLayer, AveragePoolingLayer -- create a max- or average-pooling layer
_PoolingLayer {poolKind, # "max" or "average"
filterShape, # e.g. (3:3)
stride = 1, autoPadding = false,
stride = 1, pad = false,
lowerPad = 0, upperPad = 0} = # TODO: support this
{
f(x) = Pooling (x, poolKind, filterShape, stride = stride, autoPadding = autoPadding, lowerPad = lowerPad, upperPad = upperPad)
f(x) = Pooling (x, poolKind, filterShape, stride = stride, autoPadding = pad, lowerPad = lowerPad, upperPad = upperPad)
}.f
MaxPoolingLayer {filterShape, stride = 1, autoPadding = false, lowerPad = 0, upperPad = 0} =
_PoolingLayer {"max", filterShape, stride = stride, autoPadding = autoPadding, lowerPad = lowerPad, upperPad = upperPad}
AveragePoolingLayer {filterShape, stride = 1, autoPadding = false, lowerPad = 0, upperPad = 0} =
_PoolingLayer {"average", filterShape, stride = stride, autoPadding = autoPadding, lowerPad = lowerPad, upperPad = upperPad}
MaxPoolingLayer {filterShape, stride = 1, pad = false, lowerPad = 0, upperPad = 0} =
_PoolingLayer {"max", filterShape, stride = stride, pad = pad, lowerPad = lowerPad, upperPad = upperPad}
AveragePoolingLayer {filterShape, stride = 1, pad = false, lowerPad = 0, upperPad = 0} =
_PoolingLayer {"average", filterShape, stride = stride, pad = pad, lowerPad = lowerPad, upperPad = upperPad}
# RecurrentLSTMLayer -- create an LSTM layer
RecurrentLSTMLayer {outputDim,
@ -441,8 +441,8 @@ WeightedLogistic(label, probability, instanceWeight, tag='') = new ComputationNo
ReconcileDynamicAxis(dataInput, layoutInput, tag='') = new ComputationNode [ operation = 'ReconcileDynamicAxis' ; inputs = (dataInput : layoutInput) /*plus the function args*/ ]
ReconcileMBLayout = ReconcileDynamicAxis # back compat
CastAs (type, data) = ReconcileDynamicAxis (data, type) # read as CastAs<type>(data) where the cast may consist of rearranging the data w.r.t. MBLayout or broadcasting across sequence items
# ND convo & pooling/unpooling --why is autoPadding true? Normally one would want to reduce dimensions, no?
Convolution(weightNode, inputValueNode, kernelDims, mapDims = 0, stride = 1, sharing = true, autoPadding = true, lowerPad = 0, upperPad = 0, transpose=false, imageLayout='CHW', maxTempMemSizeInSamples = 0, tag='') = new ComputationNode [ operation = 'Convolution' ; inputs = (weightNode : inputValueNode); kernelShape = new TensorShape [ dims = kernelDims ] ; mapCount = new TensorShape [ dims = mapDims ] ; strideShape = new TensorShape [ dims = stride ] ; dimSharing = new BoolVector [ items = sharing ] ; dimPadding = new BoolVector [ items = autoPadding ] ; dimPadLower = new TensorShape [ dims = lowerPad ] ; dimPadUpper = new TensorShape [ dims = upperPad ] /*plus the function args*/ ]
# ND pooling/unpooling --why is autoPadding true? Normally one would want to reduce dimensions, no?
Pooling(input, poolKind/*'max'|'average'*/, kernelDims, stride=1, autoPadding = true, lowerPad = 0, upperPad = 0, imageLayout='CHW', tag='') = new ComputationNode [ operation = 'Pooling' ; inputs = (input); pool = poolKind ; kernelShape = new TensorShape [ dims = kernelDims ] ; strideShape = new TensorShape [ dims = stride ] ; dimPadding = new BoolVector [ items = autoPadding ] ; dimPadLower = new TensorShape [ dims = lowerPad ] ; dimPadUpper = new TensorShape [ dims = upperPad ] /*plus the function args*/ ]
MaxUnpooling(unpoolInput, poolInput, kernelDims, stride=1, autoPadding = true, lowerPad = 0, upperPad = 0, imageLayout='CHW', tag='') = new ComputationNode [ operation = 'MaxUnpooling' ; inputs = (unpoolInput : poolInput); kernelShape = new TensorShape [ dims = kernelDims ] ; strideShape = new TensorShape [ dims = stride ] ; dimPadding = new BoolVector [ items = autoPadding ] ; dimPadLower = new TensorShape [ dims = lowerPad ] ; dimPadUpper = new TensorShape [ dims = upperPad ] /*plus the function args*/ ]
# 2D pooling