Adding mean_variance_normalization CNTK and ONNX op, and LayerNormalization ONNX suort.

This commit is contained in:
Spandan Tiwari 2018-02-16 00:41:08 -08:00
Родитель 5f1d710709
Коммит 4017a1664a
11 изменённых файлов: 224 добавлений и 5 удалений

Просмотреть файл

@ -4505,6 +4505,11 @@ namespace CNTK
///
CNTK_API FunctionPtr PerDimMeanVarianceNormalize(const Variable& operand, const Variable& mean, const Variable& invStdDev, const std::wstring& name = L"");
///
/// Mean-variance normalization of the specified input operand.
///
CNTK_API FunctionPtr MeanVarianceNormalization(const Variable& operand, const bool useStatsAcrossChannels = false, const bool doVarianceScaling = true, const std::wstring& name = L"");
///
/// Per dimension mean-variance normalization of the specified input operand.
///

Просмотреть файл

@ -2340,7 +2340,7 @@ namespace CNTK
size_t channels = operand.Shape()[2];
if (channels != biases.size())
LogicError("ImageScaler: number of biase (%d) does not equal channels of the image (%d)", (int)biases.size(), (int)(channels));
LogicError("ImageScaler: number of biases (%d) does not equal channels of the image (%d)", (int)biases.size(), (int)(channels));
auto additionalProperties = Dictionary();
additionalProperties[L"Scaler"] = scale;
@ -2375,6 +2375,32 @@ namespace CNTK
return AsBlock(std::move(ElementTimes(Minus(operandPlaceholder, meanPlaceholder), invStdDevPlaceholder)), { { operandPlaceholder, operand },{ meanPlaceholder, mean },{ invStdDevPlaceholder, invStdDev } }, L"PerDimMeanVarianceNormalize", name);
}
FunctionPtr MeanVarianceNormalization(const Variable& operand, const bool useStatsAcrossChannels, const bool doVarianceScaling, const std::wstring& name)
{
Dictionary additionalAttributes;
additionalAttributes[PrimitiveFunction::AttributeNameUseStatsAcrossChannels] = useStatsAcrossChannels;
additionalAttributes[PrimitiveFunction::AttributeNameDoVarianceScaling] = doVarianceScaling;
auto operandPlaceholder = PlaceholderVariable(L"operand");
size_t operandRank = operand.Shape().Rank();
if (operandRank < 2 && !useStatsAcrossChannels)
InvalidArgument("When rank of the operand is < 2, useStatsAcrossChannels must be set to false, because there is no channel dimension.");
auto numAxesToReduce = useStatsAcrossChannels ? operandRank : operandRank - 1; // Assuming last dim to be the channel dim.
std::vector<Axis> axesToReduce(numAxesToReduce);
for (size_t i = 0; i < numAxesToReduce; ++i)
axesToReduce[i] = Axis(i);
FunctionPtr operandMeanRemoved = Minus(operandPlaceholder, ReduceMean(operandPlaceholder, axesToReduce));
if (!doVarianceScaling)
{
return AsBlock(std::move(operandMeanRemoved), { { operandPlaceholder, operand } }, std::move(additionalAttributes), L"MeanVarianceNormalization", name);
}
else
{
return AsBlock(std::move(ElementDivide(operandMeanRemoved, Sqrt(ReduceMean(Square(operandMeanRemoved), axesToReduce)))),
{ { operandPlaceholder, operand } }, std::move(additionalAttributes), L"MeanVarianceNormalization", name);
}
}
FunctionPtr Convolution(const Variable& convolutionMap,
const Variable& operand,
const NDShape& strides,

Просмотреть файл

@ -312,6 +312,8 @@ namespace CNTK
static const std::wstring AttributeNameCustomAttributes;
static const std::wstring AttributeNameNumItems;
static const std::wstring AttributeNameFillValue;
static const std::wstring AttributeNameUseStatsAcrossChannels;
static const std::wstring AttributeNameDoVarianceScaling;
protected:
PrimitiveFunction(PrimitiveOpType op, const std::vector<Variable>& inputs, Dictionary&& functionConfig, const std::wstring& functionName, const std::wstring& uid)

Просмотреть файл

@ -115,4 +115,6 @@ namespace CNTK
/*static*/ const std::wstring PrimitiveFunction::AttributeNameCustomAttributes = L"customAttributes";
/*static*/ const std::wstring PrimitiveFunction::AttributeNameNumItems = L"numItems";
/*static*/ const std::wstring PrimitiveFunction::AttributeNameFillValue = L"fillValue";
/*static*/ const std::wstring PrimitiveFunction::AttributeNameUseStatsAcrossChannels = L"useStatsAcrossChannels";
/*static*/ const std::wstring PrimitiveFunction::AttributeNameDoVarianceScaling = L"doVarianceScaling";
}

Просмотреть файл

@ -1276,6 +1276,13 @@ void CNTKToONNXHelper::CopyAttributes(const FunctionPtr& src, ONNXIR::Node* node
node->AddAttribute("scale", scale);
node->AddAttribute("bias", biases);
}
else if (src->OpName() == L"MeanVarianceNormalization")
{
auto useStatsAcrossChannels = (int64_t)(src->Attributes()[L"useStatsAcrossChannels"].Value<bool>());
auto doVarianceScaling = (int64_t)(src->Attributes()[L"doVarianceScaling"].Value<bool>());
node->AddAttribute(attributesMap[L"useStatsAcrossChannels"], useStatsAcrossChannels);
node->AddAttribute(attributesMap[L"doVarianceScaling"], doVarianceScaling);
}
}
else
{
@ -1468,6 +1475,7 @@ ONNXIR::Node* CNTKToONNXHelper::AddNode(const FunctionPtr& src, ONNXIR::Graph* g
}
}
else
{
//
// CNTK Times OP is way more flexible for ONNX, so depend on the inputs and output shape,
// we will need to insert some reshapes.
@ -1505,8 +1513,37 @@ ONNXIR::Node* CNTKToONNXHelper::AddNode(const FunctionPtr& src, ONNXIR::Graph* g
else
node = graph->AddNode(nodeName, ToOPName(src), "", orderedInputs, outputs);
}
else if (src->OpName() == L"LayerNormalization")
{
// Special handling of LayerNormalization to use MeanVarianceNormalization (and not reduce* ops).
// This assumes that the orderedInputs are in the order:
// [0]: tensor operand, [1]: scale constant, [2]: bias constant.
// Also assumes that tensor operand is index [2] in src->Inputs().
auto input0 = orderedInputs[0];
onnx::TypeProto input0ArgType = ToTypeProto(src->Inputs()[2].Shape(), src->Inputs()[2].HasBatchAxis());
UpdateONNXType(src->Inputs()[2].GetDataType(), input0ArgType);
ONNXIR::NodeArg mvnTensorOutputArg(nodeName + string("_mvn_output0"), &input0ArgType);
ONNXIR::Node* mvnNode = graph->AddNode(nodeName + string("_MVN"), "MeanVarianceNormalization",
"", { input0 }, { mvnTensorOutputArg });
mvnNode->AddAttribute("across_channels", static_cast<int64_t>(1));
mvnNode->AddAttribute("normalize_variance", static_cast<int64_t>(1));
auto input1 = orderedInputs[1];
ONNXIR::NodeArg mulTensorOutputArg(nodeName + string("_mul_output0"), &input0ArgType);
ONNXIR::Node* mulNode = graph->AddNode(nodeName + string("_mul"), "Mul",
"", { mvnTensorOutputArg, input1 }, { mulTensorOutputArg });
mulNode->AddAttribute("broadcast", static_cast<int64_t>(1));
auto input2 = orderedInputs[2];
ONNXIR::NodeArg addTensorOutputArg(nodeName + string("_add_output0"), &input0ArgType);
node = graph->AddNode(nodeName + string("_add"), "Add",
"", { mulTensorOutputArg, input2 }, { addTensorOutputArg });
node->AddAttribute("broadcast", static_cast<int64_t>(1));
}
else
node = graph->AddNode(nodeName, ToOPName(src), "", orderedInputs, outputs);
}
//
// Copy and validate attributes.

Просмотреть файл

@ -1752,6 +1752,12 @@ FunctionPtr ONNXToCNTKHelper::CreateFunction(const Node *node, const std::vector
return ImageScaler(inputs[0], scale, bias, ToWString(node->Name()));
}
}
else if (onnxOpName == "MeanVarianceNormalization")
{
size_t acrossChannels = GetNamedAttributeAsInt64(node, "across_channels", 0);
size_t normalizeVariance = GetNamedAttributeAsInt64(node, "normalize_variance", 1);
return MeanVarianceNormalization(inputs[0], !!acrossChannels, !!normalizeVariance, ToWString(node->Name()));
}
else
{
LogicError("ONNX (%s) is not supported in CNTK", onnxOpName.c_str());

Просмотреть файл

@ -70,7 +70,7 @@ namespace ONNX
{ L"epsilon", "epsilon" },
// { L"", "momentum" },
} } },
// from ONNX experiament, added to test Caffe models
// from ONNX experiment, added to test Caffe models
// TODO: set key as BatchNormalization instead of BatchNormalizationCaffe
{ L"BatchNormalizationCaffe",{ {
{ L"BatchNormalization", "SpatialBN" },
@ -78,7 +78,13 @@ namespace ONNX
// { L"", "is_test" },
{ L"epsilon", "epsilon" },
// { L"", "momentum" },
} } },
} } },
{ L"LayerNormalization",{ {
{ L"LayerNormalization", "LayerNormalization" },
{ L"initial_scale", "initial_scale" },
{ L"initial_bias", "initial_bias" },
{ L"epsilon", "epsilon" },
} } },
{ L"LocalResponseNormalization",{ {
{ L"LocalResponseNormalization", "LRN" },
{ L"size", "size" },
@ -359,6 +365,11 @@ namespace ONNX
{ L"ImageScaler",{ {
{ L"ImageScaler", "ImageScaler" },
} } },
{ L"MeanVarianceNormalization",{ {
{ L"MeanVarianceNormalization", "MeanVarianceNormalization" },
{ L"useStatsAcrossChannels", "across_channels" },
{ L"doVarianceScaling", "normalize_variance" },
} } },
};
// given a cntkOpName and cntk attribute OpName which is saved in CNTK::Function's attribute,
@ -442,6 +453,7 @@ namespace ONNX
{ L"Times",{ 1, 0 } },
{ L"Gather",{ 1, 0 } },
{ L"PReLU",{ 1, 0 } },
{ L"LayerNormalization",{ 1, 2, 0 } },
};
//

Просмотреть файл

@ -428,6 +428,18 @@ namespace ONNXIR {
.Attr("normalize_variance", "If false, normalize the mean only. Default is true.",
AttrType::AttributeProto_AttributeType_INT, int64_t(1));
// Manually added on 2/14/2018.
REGISTER_OPERATOR_SCHEMA(MeanVarianceNormalization)
.Description("Perform mean variance normalization.")
.Input("input", "Input tensor of any shape", "T")
.Output("output", "Output tensor of same shape and type as input X.", "T")
.TypeConstraint("T", { "tensor(float16)", "tensor(float)", "tensor(double)" }, "Constrain input and output "
"types to float tensors.")
.Attr("across_channels", "If true, mean and variance are computed across channels. "
"Default is false.", AttrType::AttributeProto_AttributeType_INT, int64_t(0))
.Attr("normalize_variance", "If false, normalize the mean only. Default is true.",
AttrType::AttributeProto_AttributeType_INT, int64_t(1));
REGISTER_OPERATOR_SCHEMA(LpNormalization)
.Description("Given a matrix, apply Lp-normalization along the provided axis. "
"For RS4 default of p = 2 and it will perform L2 normalization. Divide each "

Просмотреть файл

@ -3912,4 +3912,47 @@ def cast(node_input, dtype, name=''):
from cntk.cntk_py import cast
arg_node_input = sanitize_input(node_input, get_data_type(node_input))
arg_dtype = sanitize_dtype_cntk(dtype)
return cast(arg_node_input, arg_dtype, name)
return cast(arg_node_input, arg_dtype, name)
@typemap
def mean_variance_normalization(operand, use_stats_across_channels = False, do_variance_scaling = True, name=''):
'''
Computes mean-variance normalization of the specified input operand.
This operation computes and mean and variance for the entire tensor if use_stats_across_channels is True.
If use_stats_across_channels is False the computes mean and variance per channel and normalizes each
channel with its own mean and variance. If do_variance_scaling is False, only the mean is subtracted,
and the variance scaling is omitted.
Example:
>>> data = np.array([[[0., 2], [4., 6.]], [[0., 4], [8., 12.]]]).astype(np.float32)
>>> data
array([[[ 0., 2.],
[ 4., 6.]],
<BLANKLINE>
[[ 0., 4.],
[ 8., 12.]]], dtype=float32)
>>> saved_precision = np.get_printoptions()['precision']
>>> np.set_printoptions(precision=4) # For consistent display upto 4 decimals.
>>> C.mean_variance_normalization(data).eval()
array([[[-1.3416, -0.4472],
[ 0.4472, 1.3416]],
<BLANKLINE>
[[-1.3416, -0.4472],
[ 0.4472, 1.3416]]], dtype=float32)
>>> np.set_printoptions(precision=saved_precision) # Reseting the display precision.
Args:
operand: Input tensor, with dimensions :math:`[C \\times H \\times W]`.
use_stats_across_channels (bool): If False, mean and variance are computed per channel.
If True, mean and variance are computed over the entire tensor (all axes).
do_variance_scaling (bool): If False, only the mean is subtracted. If True, it is also
scaled by inverse of standard deviation.
name (str, optional): the name of the Function instance in the network
Returns:
:class:`~cntk.ops.functions.Function`
'''
from cntk.cntk_py import mean_variance_normalization
operand = sanitize_input(operand, get_data_type(operand))
return mean_variance_normalization(operand, use_stats_across_channels, do_variance_scaling, name)

Просмотреть файл

@ -462,3 +462,52 @@ def test_auto_broadcast_reconcile_issue():
# check does the reconcile_dynamic_axes call trigger the auto broadcast
assert len(inputs) == 2
assert inputs[0].name == 'y' and inputs[1].name == 'x'
MEAN_VARIANCE_NORMALIZATION_DATA = [
(np.array([[[0., 2], # Input tensor
[4., 6.]],
[[0., 4],
[8., 12.]]]),
False, # use_stats_across_channels
False, # do_variance_scaling
np.array([[[-3., -1.], # Output tensor
[1., 3.]],
[[-6., -2],
[2., 6.]]])
),
(np.array([[[0., 2], # Input tensor
[4., 6.]],
[[0., 4],
[8., 12.]]]),
False, # use_stats_across_channels
True, # do_variance_scaling
np.array([[[-1.34163487, -0.44721162],
[ 0.44721162, 1.34163487]],
[[-1.34163785, -0.44721264],
[ 0.44721264, 1.34163785]]])
),
(np.array([[[0., 2], # Input tensor
[4., 6.]],
[[8., 10],
[12., 14.]]]),
True, # use_stats_across_channels
True, # do_variance_scaling
np.array([[[-1.52752209, -1.0910871],
[-0.6546523, -0.21821743]],
[[ 0.21821743, 0.6546523],
[ 1.0910871, 1.52752209]]])
),
]
@pytest.mark.parametrize("input_operand, use_stats_across_channels, do_variance_scaling, output_ref", MEAN_VARIANCE_NORMALIZATION_DATA)
def test_op_mean_variance_normalization(input_operand, use_stats_across_channels, do_variance_scaling, output_ref, device_id, precision):
dt_precision = PRECISION_TO_TYPE[precision]
input_ref = AA(input_operand, dtype=dt_precision)
a = C.input_variable(shape=input_ref.shape,
dtype=sanitize_dtype_cntk(precision),
needs_gradient=False,
name='a')
norm_op = C.mean_variance_normalization(a, use_stats_across_channels=use_stats_across_channels, do_variance_scaling=do_variance_scaling)
output_test = norm_op.eval({a:input_ref}, device=cntk_device(device_id))
assert np.allclose(output_test, output_ref, atol=1e-4)

Просмотреть файл

@ -458,6 +458,15 @@ def test_ImageScaler(tmpdir):
model = C.image_scaler(x, scalar, bias);
verify_one_input(model, image, tmpdir, 'ImageScaler_1')
#LayerNormalization
def test_LayerNormalization(tmpdir):
test_shapes = [(3, 5, 7), (10, ), (20, 31)]
for shape in test_shapes:
data = np.reshape(np.arange(np.prod(shape), dtype = np.float32), shape)
input_operand = C.input_variable(shape=shape)
model0 = model0 = C.layers.LayerNormalization(epsilon=0.0)(input_operand)
verify_one_input(model0, data, tmpdir, 'Pad_0')
#LeakyRelu
def test_LeakyRelu(tmpdir):
data = np.asarray([[-1, -0.5, 0, 1, 2]], dtype=np.float32)
@ -552,6 +561,22 @@ def test_Mean(tmpdir):
verify_two_input(model, in1_data, in2_data, tmpdir, 'Mean_2')
#MeanVarianceNormalization
def test_MeanVarianceNormalization(tmpdir):
shape = (3, 5, 7)
data = np.reshape(np.arange(np.prod(shape), dtype = np.float32), shape)
input_operand = C.input_variable(shape=shape)
model0 = C.mean_variance_normalization(input_operand, use_stats_across_channels=False, do_variance_scaling=True)
verify_one_input(model0, data, tmpdir, 'Pad_0')
model1 = C.mean_variance_normalization(input_operand, use_stats_across_channels=False, do_variance_scaling=False)
verify_one_input(model1, data, tmpdir, 'Pad_1')
model2 = C.mean_variance_normalization(input_operand, use_stats_across_channels=True, do_variance_scaling=True)
verify_one_input(model2, data, tmpdir, 'Pad_2')
#Min
def test_Min(tmpdir):
data0 = np.asarray([1., 1., 1., 1.], dtype=np.float32)
@ -583,7 +608,7 @@ def test_Pad(tmpdir):
x = C.input_variable(shape)
model = C.pad(x, pattern=[(1,1),(2,2)], mode=C.ops.REFLECT_PAD)
verify_one_input(model, data, tmpdir, 'Pad_1')
verify_one_input(model, data, tmpdir, 'Pad_1')
#PRelu
#def test_PRelu(tmpdir):