fix squeezenet onnx_backend_test of issues with Softmax, hardmax, logsoftmax and flatten with freedimension on batchaxis.
This commit is contained in:
Родитель
541e2100eb
Коммит
ceaec5636f
|
@ -1649,8 +1649,22 @@ namespace CNTK
|
|||
(int)operand.Shape().Rank(), ToLegacyString(ToUTF8(axis.AsString())).c_str());
|
||||
}
|
||||
|
||||
size_t dim0 = cntk_index == 0 ? 1 : operand.Shape().SubShape(0, cntk_index).TotalSize();
|
||||
size_t dim1 = cntk_index == operand.Shape().Rank() ? 1 : operand.Shape().SubShape(cntk_index).TotalSize();
|
||||
auto getFlattenedDim = [&](bool isDim0) -> size_t {
|
||||
const NDShape& operandSubShape = isDim0 ? operand.Shape().SubShape(0, cntk_index) : operand.Shape().SubShape(cntk_index);
|
||||
// If subshape contains free or inferred dimension, we have reshape node try to infer what should the flattened dimension be.
|
||||
if (operandSubShape.HasFreeDimension() || operandSubShape.HasInferredDimension())
|
||||
{
|
||||
return NDShape::InferredDimension;
|
||||
}
|
||||
|
||||
if ((isDim0 && cntk_index == 0) || (!isDim0 && cntk_index == operand.Shape().Rank()))
|
||||
return 1u;
|
||||
else
|
||||
return operandSubShape.TotalSize();
|
||||
};
|
||||
|
||||
size_t dim0 = getFlattenedDim(true);
|
||||
size_t dim1 = getFlattenedDim(false);
|
||||
|
||||
NDShape newShape({ dim0, dim1 });
|
||||
|
||||
|
|
|
@ -2442,7 +2442,12 @@ FunctionPtr ONNXToCNTKHelper::CreateFunction(const Node *node, const std::vector
|
|||
{
|
||||
cntkFunction = Hardmax(input, ToFixedWStringFromMultiByte(node->Name()));
|
||||
}
|
||||
return Reshape(cntkFunction, inputs[0].Shape());
|
||||
NDShape originalShape = inputs[0].Shape();
|
||||
assert(originalShape.Rank() > 0);
|
||||
// If original shape has free dimension(batch axis), we'll need to have reshape node infer that for us.
|
||||
if (originalShape[originalShape.Rank() - 1] == NDShape::FreeDimension)
|
||||
originalShape[originalShape.Rank() - 1] = NDShape::InferredDimension;
|
||||
return Reshape(cntkFunction, originalShape);
|
||||
}
|
||||
else if (onnxOpName == "Softplus")
|
||||
{
|
||||
|
|
Загрузка…
Ссылка в новой задаче