diff --git a/Source/CNTKv2LibraryDll/Function.cpp b/Source/CNTKv2LibraryDll/Function.cpp index fdff1da5e..6f490bf77 100644 --- a/Source/CNTKv2LibraryDll/Function.cpp +++ b/Source/CNTKv2LibraryDll/Function.cpp @@ -2386,10 +2386,14 @@ namespace CNTK InvalidArgument("Input argument epsilon must be non-negative."); auto operandPlaceholder = PlaceholderVariable(L"operand"); size_t operandRank = operand.Shape().Rank(); - if (operandRank < 2 && !useStatsAcrossChannels) - InvalidArgument("When rank of the operand is < 2, useStatsAcrossChannels must be set to false, because there is no channel dimension."); + size_t numAxesToReduce; + if (operandRank < 1) + InvalidArgument("The rank of the operand must be >= 1."); + else if (operandRank < 2) + numAxesToReduce = operandRank; // Operand's a vector, useStatsAcrossChannels is ignored and mean is computed over the vector. + else + numAxesToReduce = useStatsAcrossChannels ? operandRank : operandRank - 1; // Assuming last dim to be the channel dim. - auto numAxesToReduce = useStatsAcrossChannels ? operandRank : operandRank - 1; // Assuming last dim to be the channel dim. std::vector axesToReduce(numAxesToReduce); for (size_t i = 0; i < numAxesToReduce; ++i) axesToReduce[i] = Axis(i); diff --git a/Source/CNTKv2LibraryDll/proto/onnx/CNTKToONNX.cpp b/Source/CNTKv2LibraryDll/proto/onnx/CNTKToONNX.cpp index 3124874c1..54f233a7d 100644 --- a/Source/CNTKv2LibraryDll/proto/onnx/CNTKToONNX.cpp +++ b/Source/CNTKv2LibraryDll/proto/onnx/CNTKToONNX.cpp @@ -1274,20 +1274,8 @@ std::tuple, bool, int, bool> CNTKToONNXHelper::CalculateBroadca } // prepare an input node arg with correct name and meta data so that LotusIR can make the connection. -void CNTKToONNXHelper::PrepareRNNInput(const Variable &X, std::vector &nodeInputs) +void CNTKToONNXHelper::PrepareRNNInput(const Variable &input, std::vector &nodeInputs) { - Variable input; - wstring opName = X.Owner() ? X.Owner()->OpName() : L""; - if (X.BlockFunctionVariableMapping().IsInitialized() && !Operators::IsRNNOp(ToString(opName))) - { - input = X.BlockFunctionVariableMapping(); - } - else - { - input = X; - } - - std::string inputName = ToString(input.Uid()); onnx::TypeProto inputArgType = ToTypeProto(input.Shape(), (int)(input.DynamicAxes().size())); @@ -3122,7 +3110,7 @@ ONNXIR::Node* CNTKToONNXHelper::AddNode(const FunctionPtr& src, ONNXIR::Graph* g mulNode->AddAttribute("broadcast", static_cast(1)); auto input2 = inputs[biasIndexInOnnxInputs]; - ONNXIR::NodeArg addTensorOutputArg(nodeName + string("_add_output0"), &input0ArgType); + ONNXIR::NodeArg addTensorOutputArg(nodeName + string("_Output_0"), &input0ArgType); node = graph->AddNode(nodeName + string("_add"), "Add", "", { mulTensorOutputArg, input2 }, { addTensorOutputArg }); node->AddAttribute("broadcast", static_cast(1));