fix squeezenet onnx_backend_test of issues with Softmax, hardmax, logsoftmax and flatten with freedimension on batchaxis.

This commit is contained in:
Bowen Bao 2018-08-13 17:15:36 -07:00
Родитель 541e2100eb
Коммит ceaec5636f
2 изменённых файлов: 22 добавлений и 3 удалений

Просмотреть файл

@ -1649,8 +1649,22 @@ namespace CNTK
(int)operand.Shape().Rank(), ToLegacyString(ToUTF8(axis.AsString())).c_str());
}
size_t dim0 = cntk_index == 0 ? 1 : operand.Shape().SubShape(0, cntk_index).TotalSize();
size_t dim1 = cntk_index == operand.Shape().Rank() ? 1 : operand.Shape().SubShape(cntk_index).TotalSize();
auto getFlattenedDim = [&](bool isDim0) -> size_t {
const NDShape& operandSubShape = isDim0 ? operand.Shape().SubShape(0, cntk_index) : operand.Shape().SubShape(cntk_index);
// If subshape contains free or inferred dimension, we have reshape node try to infer what should the flattened dimension be.
if (operandSubShape.HasFreeDimension() || operandSubShape.HasInferredDimension())
{
return NDShape::InferredDimension;
}
if ((isDim0 && cntk_index == 0) || (!isDim0 && cntk_index == operand.Shape().Rank()))
return 1u;
else
return operandSubShape.TotalSize();
};
size_t dim0 = getFlattenedDim(true);
size_t dim1 = getFlattenedDim(false);
NDShape newShape({ dim0, dim1 });

Просмотреть файл

@ -2442,7 +2442,12 @@ FunctionPtr ONNXToCNTKHelper::CreateFunction(const Node *node, const std::vector
{
cntkFunction = Hardmax(input, ToFixedWStringFromMultiByte(node->Name()));
}
return Reshape(cntkFunction, inputs[0].Shape());
NDShape originalShape = inputs[0].Shape();
assert(originalShape.Rank() > 0);
// If original shape has free dimension(batch axis), we'll need to have reshape node infer that for us.
if (originalShape[originalShape.Rank() - 1] == NDShape::FreeDimension)
originalShape[originalShape.Rank() - 1] = NDShape::InferredDimension;
return Reshape(cntkFunction, originalShape);
}
else if (onnxOpName == "Softplus")
{