Integrate sptiwari/mean_var_norm_issue into master
This commit is contained in:
Коммит
3a229624be
|
@ -2386,10 +2386,14 @@ namespace CNTK
|
|||
InvalidArgument("Input argument epsilon must be non-negative.");
|
||||
auto operandPlaceholder = PlaceholderVariable(L"operand");
|
||||
size_t operandRank = operand.Shape().Rank();
|
||||
if (operandRank < 2 && !useStatsAcrossChannels)
|
||||
InvalidArgument("When rank of the operand is < 2, useStatsAcrossChannels must be set to false, because there is no channel dimension.");
|
||||
size_t numAxesToReduce;
|
||||
if (operandRank < 1)
|
||||
InvalidArgument("The rank of the operand must be >= 1.");
|
||||
else if (operandRank < 2)
|
||||
numAxesToReduce = operandRank; // Operand's a vector, useStatsAcrossChannels is ignored and mean is computed over the vector.
|
||||
else
|
||||
numAxesToReduce = useStatsAcrossChannels ? operandRank : operandRank - 1; // Assuming last dim to be the channel dim.
|
||||
|
||||
auto numAxesToReduce = useStatsAcrossChannels ? operandRank : operandRank - 1; // Assuming last dim to be the channel dim.
|
||||
std::vector<Axis> axesToReduce(numAxesToReduce);
|
||||
for (size_t i = 0; i < numAxesToReduce; ++i)
|
||||
axesToReduce[i] = Axis(i);
|
||||
|
|
|
@ -1274,20 +1274,8 @@ std::tuple<std::vector<int>, bool, int, bool> CNTKToONNXHelper::CalculateBroadca
|
|||
}
|
||||
|
||||
// prepare an input node arg with correct name and meta data so that LotusIR can make the connection.
|
||||
void CNTKToONNXHelper::PrepareRNNInput(const Variable &X, std::vector<ONNXIR::NodeArg> &nodeInputs)
|
||||
void CNTKToONNXHelper::PrepareRNNInput(const Variable &input, std::vector<ONNXIR::NodeArg> &nodeInputs)
|
||||
{
|
||||
Variable input;
|
||||
wstring opName = X.Owner() ? X.Owner()->OpName() : L"";
|
||||
if (X.BlockFunctionVariableMapping().IsInitialized() && !Operators::IsRNNOp(ToString(opName)))
|
||||
{
|
||||
input = X.BlockFunctionVariableMapping();
|
||||
}
|
||||
else
|
||||
{
|
||||
input = X;
|
||||
}
|
||||
|
||||
|
||||
std::string inputName = ToString(input.Uid());
|
||||
onnx::TypeProto inputArgType = ToTypeProto(input.Shape(), (int)(input.DynamicAxes().size()));
|
||||
|
||||
|
@ -3122,7 +3110,7 @@ ONNXIR::Node* CNTKToONNXHelper::AddNode(const FunctionPtr& src, ONNXIR::Graph* g
|
|||
mulNode->AddAttribute("broadcast", static_cast<int64_t>(1));
|
||||
|
||||
auto input2 = inputs[biasIndexInOnnxInputs];
|
||||
ONNXIR::NodeArg addTensorOutputArg(nodeName + string("_add_output0"), &input0ArgType);
|
||||
ONNXIR::NodeArg addTensorOutputArg(nodeName + string("_Output_0"), &input0ArgType);
|
||||
node = graph->AddNode(nodeName + string("_add"), "Add",
|
||||
"", { mulTensorOutputArg, input2 }, { addTensorOutputArg });
|
||||
node->AddAttribute("broadcast", static_cast<int64_t>(1));
|
||||
|
|
Загрузка…
Ссылка в новой задаче