first version of ResNet in TutorialImage.cntk;
breaking change: ConvolutionalLayer{} and XXPoolingLayer{} now both default to pad = false, like Keras
This commit is contained in:
Parent
8ecc7b13d5
Commit
474570dbb5
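
The padding change is the breaking part: any configuration that relied on the old implicit padding of ConvolutionalLayer{} must now ask for it explicitly. A minimal migration sketch in BrainScript (the layer arguments here are illustrative placeholders, not taken from a specific config):

    # before this commit, ConvolutionalLayer{} padded by default (autoPadding = true):
    #   c = ConvolutionalLayer {32, (5:5), activation = ReLU} (x)
    # after this commit, the default is pad = false, so request padding explicitly
    # to keep the same output dimensions:
    c = ConvolutionalLayer {32, (5:5), pad = true, activation = ReLU} (x)
    p = MaxPoolingLayer {(3:3), stride = (2:2)} (c)   # pooling layers take the same pad parameter (default false)

The diff below updates the example configs accordingly and renames the layer-level parameter from autoPadding to pad (the underlying Convolution/Pooling nodes keep autoPadding).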
@@ -21,11 +21,11 @@ TrainConvNet = [

model = Sequential (
Subtract128 :
ConvolutionalLayer {32, (5:5), activation = ReLU, init = "gaussian", initValueScale = 0.0043} :
ConvolutionalLayer {32, (5:5), pad = true, activation = ReLU, init = "gaussian", initValueScale = 0.0043} :
MaxPoolingLayer {(3:3), stride = (2:2)} :
ConvolutionalLayer {32, (5:5), activation = ReLU, init = "gaussian", initValueScale = 1.414} :
ConvolutionalLayer {32, (5:5), pad = true, activation = ReLU, init = "gaussian", initValueScale = 1.414} :
MaxPoolingLayer {(3:3), stride = (2:2)} :
ConvolutionalLayer {64, (5:5), activation = ReLU, init = "gaussian", initValueScale = 1.414} :
ConvolutionalLayer {64, (5:5), pad = true, activation = ReLU, init = "gaussian", initValueScale = 1.414} :
MaxPoolingLayer {(3:3), stride = (2:2)} :
DenseLayer {64, activation = ReLU, init = "gaussian", initValueScale = 12} :
Dropout :

@@ -86,13 +86,13 @@ TrainConvNetWithBN = [

model = Sequential (
Subtract128 :
ConvolutionalLayer {32, (5:5), bias = false, init = "gaussian", initValueScale = 0.0043} :
ConvolutionalLayer {32, (5:5), pad = true, bias = false, init = "gaussian", initValueScale = 0.0043} :
BatchNormalizationLayer {spatialRank = 2, normalizationTimeConstant = 4096} : ReLU :
MaxPoolingLayer {(3:3), stride = (2:2)} :
ConvolutionalLayer {32, (5:5), bias = false, init = "gaussian", initValueScale = 1.414} :
ConvolutionalLayer {32, (5:5), pad = true, bias = false, init = "gaussian", initValueScale = 1.414} :
BatchNormalizationLayer {spatialRank = 2, normalizationTimeConstant = 4096} : ReLU :
MaxPoolingLayer {(3:3), stride = (2:2)} :
ConvolutionalLayer {64, (5:5), bias = false, init = "gaussian", initValueScale = 1.414} :
ConvolutionalLayer {64, (5:5), pad = true, bias = false, init = "gaussian", initValueScale = 1.414} :
BatchNormalizationLayer {spatialRank = 2, normalizationTimeConstant = 4096} : ReLU :
MaxPoolingLayer {(3:3), stride = (2:2)} :
LinearLayer {64, bias = false, init = "gaussian", initValueScale = 12} :

@@ -18,33 +18,33 @@ TrainConvNet = [
labelDim = 10

# basic model
Subtract128 (x) = x - Constant (128)
model_basic (features) =
{
featNorm = features - Constant (128)
l1 = ConvolutionalLayer {32, (5:5), activation = ReLU, init = "gaussian", initValueScale = 0.0043} (featNorm)
# BUGBUG: pad=true desired?
l1 = ConvolutionalLayer {32, (5:5), pad = true, activation = ReLU, init = "gaussian", initValueScale = 0.0043} (featNorm)
p1 = MaxPoolingLayer {(3:3), stride = (2:2)} (l1)
l2 = ConvolutionalLayer {32, (5:5), activation = ReLU, init = "gaussian", initValueScale = 1.414} (p1)
l2 = ConvolutionalLayer {32, (5:5), pad = true, activation = ReLU, init = "gaussian", initValueScale = 1.414} (p1)
p2 = MaxPoolingLayer {(3:3), stride = (2:2)} (l2)
l3 = ConvolutionalLayer {64, (5:5), activation = ReLU, init = "gaussian", initValueScale = 1.414} (p2)
l3 = ConvolutionalLayer {64, (5:5), pad = true, activation = ReLU, init = "gaussian", initValueScale = 1.414} (p2)
p3 = MaxPoolingLayer {(3:3), stride = (2:2)} (l3)
d1 = DenseLayer {64, activation = ReLU, init = "gaussian", initValueScale = 12} (p3)
z = LinearLayer {labelDim, init = "gaussian", initValueScale = 1.5} (d1)
}.z

# with self-defined layers
MyConvLayer {dim, initValueScale} =
MyConvReLUPoolLayer {dim, initValueScale} =
{
C = ConvolutionalLayer {dim, (5:5), activation = ReLU, init = "gaussian", initValueScale = initValueScale}
C = ConvolutionalLayer {dim, (5:5), pad = true, activation = ReLU, init = "gaussian", initValueScale = initValueScale}
P = MaxPoolingLayer {(3:3), stride = (2:2)}
f(x) = P(C(x))
}.f
model_layers (features) =
{
featNorm = features - Constant (128)
h1 = MyConvLayer {32, 0.0043} (featNorm)
h2 = MyConvLayer {32, 1.414} (h1)
h3 = MyConvLayer {64, 1.414} (h2)
h1 = MyConvReLUPoolLayer {32, 0.0043} (featNorm)
h2 = MyConvReLUPoolLayer {32, 1.414} (h1)
h3 = MyConvReLUPoolLayer {64, 1.414} (h2)
d1 = DenseLayer {64, activation = ReLU, init = "gaussian", initValueScale = 12} (h3)
z = LinearLayer {labelDim, init = "gaussian", initValueScale = 1.5} (d1)
}.z

@@ -53,29 +53,30 @@ TrainConvNet = [
model_layerStack (features) =
{
featNorm = features - Constant (128)
h3 = LayerStack {3, i => MyConvLayer {(32:32:64)[i], (0.0043:1.414:1.414)[i]} } (featNorm)
h3 = LayerStack {3, i => MyConvReLUPoolLayer {(32:32:64)[i], (0.0043:1.414:1.414)[i]} } (featNorm)
d1 = DenseLayer {64, activation = ReLU, init = "gaussian", initValueScale = 12} (h3)
z = LinearLayer {labelDim, init = "gaussian", initValueScale = 1.5} (d1)
}.z

# model-composition style
# ...TODO: test this again; last run was a little worse
Subtract128 (x) = x - Constant (128)
model_compositionStyle = Sequential (
Subtract128 :
LayerStack {3, i => MyConvLayer {(32:32:64)[i], (0.0043:1.414:1.414)[i]} } :
LayerStack {3, i => MyConvReLUPoolLayer {(32:32:64)[i], (0.0043:1.414:1.414)[i]} } :
DenseLayer {64, activation = ReLU, init = "gaussian", initValueScale = 12} :
LinearLayer {labelDim, init = "gaussian", initValueScale = 1.5}
)

// --- with BatchNorm
MyBNConvLayer {dim, initValueScale} =
MyConvBNReLUPoolLayer {dim, initValueScale} =
{
C = ConvolutionalLayer {dim, (5:5), bias = false, init = "gaussian", initValueScale = initValueScale}
C = ConvolutionalLayer {dim, (5:5), pad = true, bias = false, init = "gaussian", initValueScale = initValueScale}
B = BatchNormalizationLayer {spatialRank = 2, normalizationTimeConstant = 4096}
P = MaxPoolingLayer {(3:3), stride = (2:2)}
f = Sequential (C:B:ReLU:P)
}.f
MyBNDenseLayer {dim, initValueScale} =
MyDenseBNReLULayer {dim, initValueScale} =
{
D = DenseLayer {dim, bias = false, init = "gaussian", initValueScale = initValueScale}
B = BatchNormalizationLayer {normalizationTimeConstant = 4096}

@@ -84,17 +85,69 @@ TrainConvNet = [
model_withBatchNorm (features) =
{
featNorm = features - Constant (128)
h3 = LayerStack {3, i => MyBNConvLayer {(32:32:64)[i], (0.0043:1.414:1.414)[i]} } (featNorm)
d1 = MyBNDenseLayer {64, 12} (h3)
h3 = LayerStack {3, i => MyConvBNReLUPoolLayer {(32:32:64)[i], (0.0043:1.414:1.414)[i]} } (featNorm)
d1 = MyDenseBNReLULayer {64, 12} (h3)
z = LinearLayer {labelDim, init = "gaussian", initValueScale = 1.5} (d1)
}.z

// --- ResNet
MyConvBNLayer {dim, initValueScale, stride} =
{
# note: (3:3), while the macro above is (5:5)
C = ConvolutionalLayer {dim, (3:3), pad = true, stride = (stride:stride), bias = false, init = "gaussian", initValueScale = initValueScale}
B = BatchNormalizationLayer {spatialRank = 2, normalizationTimeConstant = 4096}
f = Sequential (C:B)
}.f
MyConvBNReLULayer {dim, initValueScale, stride} = # TODO: get rid of this
{
CB = MyConvBNLayer {dim, initValueScale, stride}
f = Sequential (CB:ReLU)
}.f
ResNetNode {dim, initValueScale} =
{
C1 = MyConvBNReLULayer {dim, initValueScale, 1} # first convolution layer
C2 = MyConvBNLayer {dim, initValueScale, 1} # second convolution layer, no ReLU
f(x) = ReLU (x + C2(C1(x)))
}.f
ResNetIncNode {dim, initValueScale} =
{
# one branch. C2 o c1 doubles the #channels but halves the image size
C1 = MyConvBNReLULayer {dim, initValueScale, 2} # first convolution layer, stride = 2
C2 = MyConvBNLayer {dim, initValueScale, 1} # second convolution layer, no ReLU

# second branch:
# - sub-sample spatially by a factor of 2
# - append dim/2 zero output channels
Down2 = MaxPoolingLayer {(1:1), stride = (2:2)} # this is a function that sub-samples spatially by a factor of 2
pad = ConstantTensor (0, (1:1:dim/2)) # the 1s will broadcast to image size
P(x) = Splice ((Down2(x) : pad), axis = 3)
B = BatchNormalizationLayer {spatialRank = 2, normalizationTimeConstant = 4096}

# layer sums both branches and rectifies the result
f(x) = ReLU (B(P(x)) + C2(C1(x)))
}.f
model_resNet (features) =
{
conv1 = MyConvBNReLULayer {16, 0.26, 1} (features)
rn1 = LayerStack {3, i=> ResNetNode {16, 7.07}} (conv1)

rn2_1 = ResNetIncNode {32, 7.07} (rn1)
rn2 = LayerStack {2, i=> ResNetNode {32, 7.07}} (rn2_1)

rn3_1 = ResNetIncNode {64, 7.07} (rn2)
rn3 = LayerStack {2, i=> ResNetNode {64, 7.07}} (rn3_1)

pool = AveragePoolingLayer {(8:8)} (rn3)

z = LinearLayer {labelDim, init = "gaussian", initValueScale = 0.4} (pool)
}.z

# inputs
features = Input {imageShape}
labels = Input {labelDim}

# apply model to features
z = model_withBatchNorm (features)
z = model_resNet (features)

# connect to system
ce = CrossEntropyWithSoftmax (labels, z)

@@ -109,21 +162,27 @@ TrainConvNet = [
]

SGD = [
epochSize = 49984
epochSize = 50000 # 49984 --TODO: why 16 less?

# without BatchNormalization:
#minibatchSize = 64
#maxEpochs = 30 ; minibatchSize = 64
#learningRatesPerSample = 0.00015625*10:0.000046875*10:0.000015625
#momentumAsTimeConstant = 600*20:6400
#maxEpochs = 30
#L2RegWeight = 0.03

# with BatchNormalization:
minibatchSize = 64
learningRatesPerSample = 0.00046875*7:0.00015625
momentumAsTimeConstant = 0
maxEpochs = 30
L2RegWeight = 0
#maxEpochs = 30 ; minibatchSize = 64
#learningRatesPerSample = 0.00046875*7:0.00015625
#momentumAsTimeConstant = 0
#L2RegWeight = 0

# ResNet
maxEpochs = 160 ; minibatchSize = 128
learningRatesPerSample = 0.0078125*80:0.00078125*40:0.000078125
momentumAsTimeConstant = 1200
L2RegWeight = 0.0001
#learningRatesPerMB = 1.0*80:0.1*40:0.01
#momentumPerMB = 0.9

firstMBsToShowResult = 10 ; numMBsToShowResult = 500
]

@@ -156,13 +215,13 @@ TrainConvNetWithBN = [

model = Sequential (
Subtract128 :
ConvolutionalLayer {32, (5:5), bias = false, init = "gaussian", initValueScale = 0.0043} :
ConvolutionalLayer {32, (5:5), pad = true, bias = false, init = "gaussian", initValueScale = 0.0043} :
BatchNormalizationLayer {spatialRank = 2, normalizationTimeConstant = 4096} : ReLU :
MaxPoolingLayer {(3:3), stride = (2:2)} :
ConvolutionalLayer {32, (5:5), bias = false, init = "gaussian", initValueScale = 1.414} :
ConvolutionalLayer {32, (5:5), pad = true, bias = false, init = "gaussian", initValueScale = 1.414} :
BatchNormalizationLayer {spatialRank = 2, normalizationTimeConstant = 4096} : ReLU :
MaxPoolingLayer {(3:3), stride = (2:2)} :
ConvolutionalLayer {64, (5:5), bias = false, init = "gaussian", initValueScale = 1.414} :
ConvolutionalLayer {64, (5:5), pad = true, bias = false, init = "gaussian", initValueScale = 1.414} :
BatchNormalizationLayer {spatialRank = 2, normalizationTimeConstant = 4096} : ReLU :
MaxPoolingLayer {(3:3), stride = (2:2)} :
LinearLayer {64, bias = false, init = "gaussian", initValueScale = 12} :

@@ -66,9 +66,9 @@ ConvolutionalLayer {numOutputChannels, # e.g. (1) or BS.Constants.None
bias = true,
activation = (x=>x),
init = "uniform",
initValueScale = 1,
initValueScale = 1, # TODO: rename to initScale
#reductionRank = 1, # TODO: support this
stride = 1, autoPadding = true,
stride = 1, pad = false,
lowerPad = 0, upperPad = 0,
#transpose = false, # TODO: support this
maxTempMemSizeInSamples = 0} =

@@ -84,7 +84,7 @@ ConvolutionalLayer {numOutputChannels, # e.g. (1) or BS.Constants.None
sharing = true # TODO: support this
transpose = false # TODO: support this
f(x) = {
c = Convolution (W, x, filterShape, mapDims = numOutputChannels, stride = stride, sharing = sharing, autoPadding = autoPadding, lowerPad = lowerPad, upperPad = upperPad, transpose = transpose, maxTempMemSizeInSamples = maxTempMemSizeInSamples)
c = Convolution (W, x, filterShape, mapDims = numOutputChannels, stride = stride, sharing = sharing, autoPadding = pad, lowerPad = lowerPad, upperPad = upperPad, transpose = transpose, maxTempMemSizeInSamples = maxTempMemSizeInSamples)
res = activation (if bias then c + b else c)
}.res
}.f

@@ -92,15 +92,15 @@ ConvolutionalLayer {numOutputChannels, # e.g. (1) or BS.Constants.None
# MaxPoolingLayer, AveragePoolingLayer -- create a max- or average-pooling layer
_PoolingLayer {poolKind, # "max" or "average"
filterShape, # e.g. (3:3)
stride = 1, autoPadding = false,
stride = 1, pad = false,
lowerPad = 0, upperPad = 0} = # TODO: support this
{
f(x) = Pooling (x, poolKind, filterShape, stride = stride, autoPadding = autoPadding, lowerPad = lowerPad, upperPad = upperPad)
f(x) = Pooling (x, poolKind, filterShape, stride = stride, autoPadding = pad, lowerPad = lowerPad, upperPad = upperPad)
}.f
MaxPoolingLayer {filterShape, stride = 1, autoPadding = false, lowerPad = 0, upperPad = 0} =
_PoolingLayer {"max", filterShape, stride = stride, autoPadding = autoPadding, lowerPad = lowerPad, upperPad = upperPad}
AveragePoolingLayer {filterShape, stride = 1, autoPadding = false, lowerPad = 0, upperPad = 0} =
_PoolingLayer {"average", filterShape, stride = stride, autoPadding = autoPadding, lowerPad = lowerPad, upperPad = upperPad}
MaxPoolingLayer {filterShape, stride = 1, pad = false, lowerPad = 0, upperPad = 0} =
_PoolingLayer {"max", filterShape, stride = stride, pad = pad, lowerPad = lowerPad, upperPad = upperPad}
AveragePoolingLayer {filterShape, stride = 1, pad = false, lowerPad = 0, upperPad = 0} =
_PoolingLayer {"average", filterShape, stride = stride, pad = pad, lowerPad = lowerPad, upperPad = upperPad}

# RecurrentLSTMLayer -- create an LSTM layer
RecurrentLSTMLayer {outputDim,

@@ -441,8 +441,8 @@ WeightedLogistic(label, probability, instanceWeight, tag='') = new ComputationNo
ReconcileDynamicAxis(dataInput, layoutInput, tag='') = new ComputationNode [ operation = 'ReconcileDynamicAxis' ; inputs = (dataInput : layoutInput) /*plus the function args*/ ]
ReconcileMBLayout = ReconcileDynamicAxis # back compat
CastAs (type, data) = ReconcileDynamicAxis (data, type) # read as CastAs<type>(data) where the cast may consist of rearranging the data w.r.t. MBLayout or broadcasting across sequence items
# ND convo & pooling/unpooling --why is autoPadding true? Normally one would want to reduce dimensions, no?
Convolution(weightNode, inputValueNode, kernelDims, mapDims = 0, stride = 1, sharing = true, autoPadding = true, lowerPad = 0, upperPad = 0, transpose=false, imageLayout='CHW', maxTempMemSizeInSamples = 0, tag='') = new ComputationNode [ operation = 'Convolution' ; inputs = (weightNode : inputValueNode); kernelShape = new TensorShape [ dims = kernelDims ] ; mapCount = new TensorShape [ dims = mapDims ] ; strideShape = new TensorShape [ dims = stride ] ; dimSharing = new BoolVector [ items = sharing ] ; dimPadding = new BoolVector [ items = autoPadding ] ; dimPadLower = new TensorShape [ dims = lowerPad ] ; dimPadUpper = new TensorShape [ dims = upperPad ] /*plus the function args*/ ]
# ND pooling/unpooling --why is autoPadding true? Normally one would want to reduce dimensions, no?
Pooling(input, poolKind/*'max'|'average'*/, kernelDims, stride=1, autoPadding = true, lowerPad = 0, upperPad = 0, imageLayout='CHW', tag='') = new ComputationNode [ operation = 'Pooling' ; inputs = (input); pool = poolKind ; kernelShape = new TensorShape [ dims = kernelDims ] ; strideShape = new TensorShape [ dims = stride ] ; dimPadding = new BoolVector [ items = autoPadding ] ; dimPadLower = new TensorShape [ dims = lowerPad ] ; dimPadUpper = new TensorShape [ dims = upperPad ] /*plus the function args*/ ]
MaxUnpooling(unpoolInput, poolInput, kernelDims, stride=1, autoPadding = true, lowerPad = 0, upperPad = 0, imageLayout='CHW', tag='') = new ComputationNode [ operation = 'MaxUnpooling' ; inputs = (unpoolInput : poolInput); kernelShape = new TensorShape [ dims = kernelDims ] ; strideShape = new TensorShape [ dims = stride ] ; dimPadding = new BoolVector [ items = autoPadding ] ; dimPadLower = new TensorShape [ dims = lowerPad ] ; dimPadUpper = new TensorShape [ dims = upperPad ] /*plus the function args*/ ]
# 2D pooling