Updates to several ONNX exports.

* ConvTranspose outputShape: 'pads' values are now always exported, even
when outputShape is given, because CNTK and ONNX have different
padding specs.
* Flatten: in CNTK, Flatten does not affect the batch axis; this behavior
should be preserved in ONNX.
Bowen Bao 2018-10-25 17:50:19 -07:00
Parent a19ce9ef7c
Commit 3f46cf0269
4 changed files with 186 additions and 53 deletions


@@ -559,18 +559,26 @@ private:
//
static void ValidatePadValueForCeilOutDim(const std::vector<int64_t> lowerPad, const std::vector<int64_t> upperPad, const std::vector<bool>& autoPadding,
const NDShape& kernelShape, const NDShape& inputShape, const NDShape& strides, const NDShape& dilation, bool transpose = false);
//
// Check if CNTK node's pad attribute value is provided and valid.
//
static bool IsPadValueValid(const std::vector<int64_t>& lowerPad, const std::vector<int64_t>& upperPad, const std::vector<bool>& autoPadding, const bool ceilOutDim);
//
// Get ONNX 'pads' attribute value based on CNTK node's autoPadding attribute value.
//
static std::pair<std::vector<int>, std::vector<int> > GetONNXPadsAttributeFromCNTKNode(
const std::vector<bool>& cntkAutoPadding, const NDShape& kernelShape, const NDShape& inputShape,
const NDShape& strides, const NDShape& dilation, bool ceilOutDim, bool transpose);
const std::vector<bool>& cntkAutoPadding, const NDShape& kernelShape, const NDShape& inputShape,
const NDShape& strides, const NDShape& dilation, const NDShape& outputShape, bool ceilOutDim, bool transpose);
//
// Adds attributes 'pads' to saved node (typically convolution or pooling).
//
static void PutPadAttrInNode(onnxruntime::Node* node, const std::vector<bool>& autoPadding,
const NDShape& kernelShape, const NDShape& inputShape, const NDShape& strides, const NDShape& dilation, bool ceilOutDim = false, bool transpose = false);
const NDShape& kernelShape, const NDShape& inputShape, const NDShape& strides, const NDShape& dilation,
bool ceilOutDim = false, bool transpose = false);
static void PutPadAttrInNode(onnxruntime::Node* node, const std::vector<bool>& autoPadding,
const NDShape& kernelShape, const NDShape& inputShape, const NDShape& strides, const NDShape& dilation, const NDShape& outputShape,
bool ceilOutDim = false, bool transpose = false);
//
// Takes CNTK's OneHotOp node and converts it into OneHotEncoder op on the ONNX side.
@@ -671,12 +679,22 @@ private:
std::unordered_map<Variable, onnxruntime::Node*>& variableNodes,
const std::unordered_map<Variable, Variable>& compositeOutputsMap,
std::vector<ScanLoop> &scanLoops, int createLoopIndex);
static onnxruntime::Node * CreateBatchNormalization(const FunctionPtr & src,
onnxruntime::Graph * graph, std::unordered_map<FunctionPtr,
onnxruntime::Node*>& functionNodes,
std::unordered_map<Variable, onnxruntime::Node*>& variableNodes,
const std::unordered_map<Variable, Variable>& compositeOutputsMap,
std::vector<ScanLoop> &scanLoops, int createLoopIndex);
// Takes CNTK's Flatten node and converts it into a series of ONNX nodes.
static onnxruntime::Node * CreateONNXNodesForFlatten(const FunctionPtr & src,
onnxruntime::Graph * graph, std::unordered_map<FunctionPtr,
onnxruntime::Node*>& functionNodes,
std::unordered_map<Variable, onnxruntime::Node*>& variableNodes,
const std::unordered_map<Variable, Variable>& compositeOutputsMap,
std::vector<ScanLoop> &scanLoops, int createLoopIndex);
//
// Method to create ONNX nodes that have an explicit batch axis from their CNTK
// counterparts.
@@ -4122,6 +4140,10 @@ onnxruntime::Node* CNTKToONNXHelper::CreateNode(const FunctionPtr& src,
scanLoops, createLoopIndex);
}
}
else if (cntkOpName == "Flatten")
{
return CreateONNXNodesForFlatten(src, graph, functionNodes, variableNodes, compositeOutputsMap, scanLoops, createLoopIndex);
}
//
// If this block node is equivalent to a primitive ONNX op, it is treated as such.
@@ -4913,7 +4935,11 @@ void CNTKToONNXHelper::CopyAttributes(const FunctionPtr& src, onnxruntime::Node*
axis = (Axis)(src->Attributes()[L"axis"].Value<Axis>());
}
// Flatten op takes a single axis. It is safe here to assume that the axis is a static axis.
int64_t ax = ConvertAxisToOnnx(axis, src->Inputs()[0]) + 1 /* TODO: Figure out how to remove this hardcoded 1 */;
// ax needs an additional 1 here.
int64_t ax = ConvertAxisToOnnx(axis, src->Inputs()[0]) + 1;
// Flatten op in ONNX doesn't count batch axis.
if (src->Inputs()[0].HasBatchAxis())
ax--;
node->AddAttribute(attributesMap[L"axis"], ax);
}
else if (src->OpName() == L"Squeeze")
@@ -5044,17 +5070,39 @@ void CNTKToONNXHelper::CopyAttributes(const FunctionPtr& src, onnxruntime::Node*
node->AddAttribute("dilations", ToINTS(dilations));
node->AddAttribute("group", (int64_t)groups);
if (transpose)
{
auto outputShape = (NDShape)src->Attributes()[L"outputShape"].Value<NDShape>();
if(outputShape != NDShape({ 0 }))
node->AddAttribute("output_shape", ToINTS(outputShape, src->Inputs()[1].HasBatchAxis()));
// Notes on outputShape vs pads for convolution and convTranspose.
// The ONNX spec for ConvTranspose states: "If output_shape is specified pads values are ignored".
// So if outputShape is exported for convTranspose, auto_pad/pads can be skipped.
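// (Per my reading of the ONNX ConvTranspose spec: with output_shape given, an
// importer infers
//   total_padding[i] = stride[i] * (input_size[i] - 1) + output_padding[i]
//                      + ((kernel[i] - 1) * dilation[i] + 1) - output_shape[i],
// and splits it per auto_pad; that inferred split need not match the pads CNTK
// actually used, which is why explicit pads are still exported below.)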
}
const NDShape& inputShape = src->Inputs()[1].Shape();
PutPadAttrInNode(node, autoPadding, kernelShape, inputShape, strides, dilations, ceilOutDim, transpose);
auto lowerPadShape = (NDShape)src->Attributes()[L"lowerPad"].Value<NDShape>();
auto upperPadShape = (NDShape)src->Attributes()[L"upperPad"].Value<NDShape>();
if (lowerPadShape.Rank() > kernelShape.Rank())
lowerPadShape = lowerPadShape.SubShape(0, lowerPadShape.Rank() - 1);
if (upperPadShape.Rank() > kernelShape.Rank())
upperPadShape = upperPadShape.SubShape(0, upperPadShape.Rank() - 1);
auto lowerPad = ToINTS(lowerPadShape);
auto upperPad = ToINTS(upperPadShape);
if (IsPadValueValid(lowerPad, upperPad, autoPadding, transpose))
{
if (ceilOutDim)
ValidatePadValueForCeilOutDim(lowerPad, upperPad, autoPadding, kernelShape, inputShape, strides,
/*dilation=*/std::vector<size_t>(kernelShape.Rank(), 1), transpose);
lowerPad.insert(lowerPad.end(), upperPad.cbegin(), upperPad.cend());
node->AddAttribute("pads", lowerPad);
}
else
{
if (transpose && src->Attributes().Contains(L"outputShape"))
{
auto outputShape = (NDShape)src->Attributes()[L"outputShape"].Value<NDShape>();
PutPadAttrInNode(node, autoPadding, kernelShape, inputShape, strides, dilations, outputShape, ceilOutDim, transpose);
}
else
{
PutPadAttrInNode(node, autoPadding, kernelShape, inputShape, strides, dilations, ceilOutDim, transpose);
}
}
}
else if (src->OpName() == L"Pooling")
{
@@ -5092,16 +5140,8 @@ void CNTKToONNXHelper::CopyAttributes(const FunctionPtr& src, onnxruntime::Node*
auto lowerPad = ToINTS(src->Attributes()[L"lowerPad"].Value<NDShape>());
auto upperPad = ToINTS(src->Attributes()[L"upperPad"].Value<NDShape>());
// lowerPad/upperPad is set to NDShape({0}) by default.
// If this node has explicitly set the lowerPad and upperPad values (i.e., nodes that are constructed with lowerPad/upperPad values and autoPadding=False),
// export these values directly. Otherwise, check autoPadding and export accordingly.
bool isAllPadsZero = std::all_of(lowerPad.begin(), lowerPad.end(), [](int64_t i) { return i == 0; });
isAllPadsZero = isAllPadsZero & std::all_of(upperPad.begin(), upperPad.end(), [](int64_t i) { return i == 0; });
bool isAnyAutoPadTrue = std::any_of(autoPadding.begin(), autoPadding.end(), [](bool i) { return i; });
if (lowerPad.size() > 0 && upperPad.size() > 0
&& !(lowerPad.size() == 1 && upperPad.size() == 1 && lowerPad[0] == 0 && upperPad[0] == 0)
&& !(isAllPadsZero && ceilOutDim)
&& !isAnyAutoPadTrue)
if (IsPadValueValid(lowerPad, upperPad, autoPadding, ceilOutDim))
{
if (ceilOutDim)
ValidatePadValueForCeilOutDim(lowerPad, upperPad, autoPadding, kernelShape, inputShape, strides,
@@ -5161,7 +5201,8 @@ void CNTKToONNXHelper::SetReduceElementsAttributes(const FunctionPtr src, Node *
void CNTKToONNXHelper::ValidatePadValueForCeilOutDim(const std::vector<int64_t> lowerPad, const std::vector<int64_t> upperPad, const std::vector<bool>& autoPadding,
const NDShape& kernelShape, const NDShape& inputShape, const NDShape& strides, const NDShape& dilations, bool transpose)
{
auto padsValueVectorsForONNX = GetONNXPadsAttributeFromCNTKNode(autoPadding, kernelShape, inputShape, strides, dilations, /*ceilOutDim=*/true, transpose);
auto padsValueVectorsForONNX = GetONNXPadsAttributeFromCNTKNode(autoPadding, kernelShape, inputShape, strides, dilations,
/*outputShape=*/{0}, /*ceilOutDim=*/true, transpose);
auto onnxLowerPads = ToINTS(padsValueVectorsForONNX.first);
auto onnxUpperPads = ToINTS(padsValueVectorsForONNX.second);
@@ -5172,10 +5213,31 @@ void CNTKToONNXHelper::ValidatePadValueForCeilOutDim(const std::vector<int64_t>
}
}
bool CNTKToONNXHelper::IsPadValueValid(const std::vector<int64_t>& lowerPad, const std::vector<int64_t>& upperPad, const std::vector<bool>& autoPadding, const bool ceilOutDim)
{
// lowerPad/upperPad is set to NDShape({0}) by default.
// If this node has explicitly set the lowerPad and upperPad values (i.e., nodes that are constructed with lowerPad/upperPad values and autoPadding=False),
// export these values directly. Otherwise, check autoPadding and export accordingly.
bool isAllPadsZero = std::all_of(lowerPad.begin(), lowerPad.end(), [](int64_t i) { return i == 0; });
isAllPadsZero = isAllPadsZero & std::all_of(upperPad.begin(), upperPad.end(), [](int64_t i) { return i == 0; });
bool isAnyAutoPadTrue = std::any_of(autoPadding.begin(), autoPadding.end(), [](bool i) { return i; });
return lowerPad.size() > 0 && upperPad.size() > 0
&& !(lowerPad.size() == 1 && upperPad.size() == 1 && lowerPad[0] == 0 && upperPad[0] == 0)
&& !(isAllPadsZero && ceilOutDim)
&& !isAnyAutoPadTrue;
}
void CNTKToONNXHelper::PutPadAttrInNode(onnxruntime::Node* node,
const std::vector<bool>& autoPadding, const NDShape& kernelShape, const NDShape& inputShape, const NDShape& strides,
const NDShape& dilations, bool ceilOutDim, bool transpose)
{
PutPadAttrInNode(node, autoPadding, kernelShape, inputShape, strides, dilations, /*outputShape=*/{0}, ceilOutDim, transpose);
}
void CNTKToONNXHelper::PutPadAttrInNode(onnxruntime::Node* node,
const std::vector<bool>& autoPadding, const NDShape& kernelShape, const NDShape& inputShape, const NDShape& strides,
const NDShape& dilations, const NDShape& outputShape, bool ceilOutDim, bool transpose)
{
// To fully support CNTK exporting of convolution & pooling ops for all input settings,
// the padding attributes must be exported in 'pads' instead of 'autoPad'.
@@ -5184,7 +5246,7 @@ void CNTKToONNXHelper::PutPadAttrInNode(onnxruntime::Node* node,
// E.g.
// operand shape: [7, 8], kernel shape: [2, 3], strides: [2, 2].
// The pad values CNTK generates are [0, 1, 1, 0]. This cannot be expressed as a single "SAME_UPPER" or "SAME_LOWER".
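// (Sanity check of the example, assuming the ONNX 2-D pads order
// [x1_begin, x2_begin, x1_end, x2_end]: SAME padding needs a total of
// (ceil(in / stride) - 1) * stride + kernel - in per axis, i.e. 1 for both axes
// here; CNTK places it as begins [0, 1] and ends [1, 0], so one axis pads at the
// end and the other at the begin, and no single auto_pad mode expresses that.)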
auto padsValueVectorsForONNX = GetONNXPadsAttributeFromCNTKNode(autoPadding, kernelShape, inputShape, strides, dilations, ceilOutDim, transpose);
auto padsValueVectorsForONNX = GetONNXPadsAttributeFromCNTKNode(autoPadding, kernelShape, inputShape, strides, dilations, outputShape, ceilOutDim, transpose);
auto lowerPads = ToINTS(padsValueVectorsForONNX.first);
auto upperPads = ToINTS(padsValueVectorsForONNX.second);
lowerPads.insert(lowerPads.end(), upperPads.cbegin(), upperPads.cend());
@@ -5478,15 +5540,15 @@ onnxruntime::Node* CNTKToONNXHelper::AddNode(const FunctionPtr& src, onnxruntime
}
std::pair<std::vector<int>, std::vector<int>> CNTKToONNXHelper::GetONNXPadsAttributeFromCNTKNode(
const std::vector<bool>& cntkAutoPadding, const NDShape& kernelShape, const NDShape& inputShape, const NDShape& strides,
const NDShape& dilations, bool ceilOutDim, bool transpose)
const std::vector<bool>& cntkAutoPadding, const NDShape& kernelShape, const NDShape& inputShape, const NDShape& strides,
const NDShape& dilations, const NDShape& outputShape, bool ceilOutDim, bool transpose)
{
// Reuse ConvolveGeometry to compute outputShape and pad values.
// The difference here is that ConvolveGeometry expects parameters to be non-spatial shapes, which includes the channel info that
// we have just excluded. Thus an emulated channel axis is inserted.
assert(inputShape.Rank() > 0);
const size_t channelSize = inputShape[inputShape.Rank() - 1];
const NDShape& kernelShapeWithChannel = kernelShape.AppendShape({ channelSize });
NDShape kernelShapeWithChannel = kernelShape.AppendShape({ channelSize });
const NDShape& stridesWithChannel = strides.Rank() == 1 ? strides : strides.AppendShape({ channelSize });
const NDShape& dilationsWithChannel = dilations.Rank() == 1 ? dilations : dilations.AppendShape({ 1 });
std::vector<bool> cntkAutoPaddingWithChannel = cntkAutoPadding.size() > 0 ? cntkAutoPadding : std::vector<bool>({ false });
@@ -5500,13 +5562,23 @@ std::pair<std::vector<int>, std::vector<int>> CNTKToONNXHelper::GetONNXPadsAttri
NDShape convOperandShape = inputShape;
if (transpose)
{
// For convolution transpose, reuse the logic in ConvolveGeometry to compute the actual pads value.
// First, get CNTK convTranspose outputShape by ConvolveGeometry::ComputeInputShape.
// Next, treat this as a normal convolution, and use the resulting outputShape as the inputShape to compute the pads values.
convOperandShape = AsNDShape(ConvolveGeometry::ComputeInputShape(AsTensorShape(inputShape), AsTensorShape(kernelShapeWithChannel),
/*mapCount=*/AsTensorShape({ 1 }), AsTensorShape(stridesWithChannel), /*sharing=*/std::vector<bool>({ true }), cntkAutoPaddingWithChannel,
/*lowerPad=*/AsTensorShape({ 0 }), /*upperPad=*/AsTensorShape({ 0 }), AsTensorShape(dilationsWithChannel), /*groups=*/1,
ceilOutDim, /*(UNUSED)needsDynamicValidation=*/false, /*(UNUSED)isFinalValidationPass=*/false));
if (outputShape.Rank() == 1 && outputShape[0] == 0)
{
// outputShape is not available.
// For convolution transpose, reuse the logic in ConvolveGeometry to compute the actual pads value.
// First, get CNTK convTranspose outputShape by ConvolveGeometry::ComputeInputShape.
// Next, treat this as a normal convolution, and use the resulting outputShape as the inputShape to compute the pads values.
convOperandShape = AsNDShape(ConvolveGeometry::ComputeInputShape(AsTensorShape(inputShape), AsTensorShape(kernelShapeWithChannel),
/*mapCount=*/AsTensorShape({ 1 }), AsTensorShape(stridesWithChannel), /*sharing=*/std::vector<bool>({ true }), cntkAutoPaddingWithChannel,
/*lowerPad=*/AsTensorShape({ 0 }), /*upperPad=*/AsTensorShape({ 0 }), AsTensorShape(dilationsWithChannel), /*groups=*/1,
ceilOutDim, /*(UNUSED)needsDynamicValidation=*/false, /*(UNUSED)isFinalValidationPass=*/false));
}
else
{
convOperandShape = outputShape;
// Use the correct channel size.
kernelShapeWithChannel[kernelShapeWithChannel.Rank() - 1] = convOperandShape[convOperandShape.Rank() - 1];
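// (The conv operand at this point is the convTranspose output, whose channel
// count can differ from the input channel that was appended to
// kernelShapeWithChannel above.)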
}
}
auto geometry = std::make_shared<ConvolveGeometry>(AsTensorShape(convOperandShape), AsTensorShape(kernelShapeWithChannel),
@@ -6369,6 +6441,60 @@ onnxruntime::Node* CNTKToONNXHelper::CreateONNXNodesForTimesTranspose(const Func
return matmulNode;
}
onnxruntime::Node* CNTKToONNXHelper::CreateONNXNodesForFlatten(const FunctionPtr &src,
onnxruntime::Graph* graph,
std::unordered_map<FunctionPtr, onnxruntime::Node*>& functionNodes,
std::unordered_map<Variable, onnxruntime::Node*>& variableNodes,
const std::unordered_map<Variable, Variable>& compositeOutputsMap,
std::vector<ScanLoop> &scanLoops, int createLoopIndex)
{
std::vector<onnxruntime::NodeArg *> inputs;
ProcessInputs(src, graph, functionNodes, variableNodes, compositeOutputsMap, inputs,
scanLoops, createLoopIndex);
std::vector<onnxruntime::NodeArg *> outputs;
ProcessOutputs(src, outputs, graph);
auto flattenInput = src->Inputs()[0];
auto flattenOutput = src->Outputs()[0];
const std::string& nodeName = UniqueNodeNameStorage::GetUniqueNodeName(src);
const std::string& inputNodeName = UniqueNodeNameStorage::GetUniqueInputNodeName(flattenInput);
auto functionNode = [&]() -> onnxruntime::Node* {
if (flattenInput.HasBatchAxis())
{
onnx::TypeProto inputReshapeOut = ToTypeProto(flattenInput.Shape());
onnx::TypeProto outputReshapeIn = ToTypeProto(flattenOutput.Shape());
onnx::TypeProto outputReshapeOut = ToTypeProto(flattenOutput.Shape(), /*hasBatchAxis=*/true);
UpdateONNXType(flattenInput.GetDataType(), inputReshapeOut);
UpdateONNXType(flattenOutput.GetDataType(), outputReshapeIn);
UpdateONNXType(flattenOutput.GetDataType(), outputReshapeOut);
onnxruntime::NodeArg &preReshapeOutputArg = graph->GetOrCreateNodeArg(UniqueNodeNameStorage::GetUniqueNodeNameWithoutUid(inputNodeName + "_reshape"),
&inputReshapeOut);
onnxruntime::Node* preReshapeNode = AddReshapeNodeImpl(graph, UniqueNodeNameStorage::GetUniqueNodeNameWithoutUid(nodeName + "_pre_reshape"),
inputs[0], &preReshapeOutputArg, ToINTS(inputReshapeOut));
onnxruntime::NodeArg &postReshapeInputArg = graph->GetOrCreateNodeArg(UniqueNodeNameStorage::GetUniqueNodeNameWithoutUid(nodeName + "_out_reshape"),
&outputReshapeIn);
onnxruntime::Node* postReshapeNode = AddReshapeNodeImpl(graph, UniqueNodeNameStorage::GetUniqueNodeNameWithoutUid(nodeName + "_post_reshape"),
&postReshapeInputArg, outputs[0], ToINTS(outputReshapeOut));
onnxruntime::Node* flattenNode = graph->AddNode(nodeName, ToOPName(src), "", { &preReshapeOutputArg }, { &postReshapeInputArg });
CopyAttributes(src, flattenNode);
return postReshapeNode;
}
else
{
return AddNode(src, graph, inputs, outputs);
}
}();
functionNodes.emplace(src, functionNode);
return functionNode;
}
onnxruntime::Node* CNTKToONNXHelper::CreateONNXNodesForSelect(const FunctionPtr &src,
onnxruntime::Graph* graph,
std::unordered_map<FunctionPtr, onnxruntime::Node*>& functionNodes,

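For reference, a numpy sketch of the Reshape -> Flatten -> Reshape pattern that CreateONNXNodesForFlatten above emits for inputs with a batch axis (an illustration only; it assumes batch size 1, which is what the fixed static reshape shapes imply):

import numpy as np

x = np.arange(24.0).reshape(1, 2, 3, 4)  # batch axis of size 1 plus static shape (2, 3, 4)
pre = x.reshape(2, 3, 4)                 # pre_reshape: drop the batch axis
axis = 1                                 # ONNX Flatten axis, already batch-corrected
flat = pre.reshape(int(np.prod(pre.shape[:axis])), -1)  # ONNX Flatten semantics -> (2, 12)
out = flat.reshape((1,) + flat.shape)    # post_reshape: restore the batch axis -> (1, 2, 12)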

@@ -143,6 +143,11 @@ skip_round_trip_model_names = [
'test_top_k',
'test_transpose_default',
'test_upsample_nearest',
# Lacking proper support for the ONNX ConvTranspose output_padding attribute.
'test_convtranspose_kernel_shape',
'test_convtranspose_output_shape',
'test_convtranspose_pad',
'test_convtranspose_with_kernel',
]
skip_cntk_model_names = []


@@ -25,7 +25,7 @@ set_of_batch_ops = {'Pooling', 'Convolution', 'GlobalAveragePooling', 'GlobalMax
# of whether the input has batch axis or not.
# Basically, for these ops we don't prepend 1 to the output shape
# when the input has batch axis.
set_of_batch_irrelevant_ops = {'Flatten'}
set_of_batch_irrelevant_ops = {}
##########################################
## helper verification functions
@@ -632,7 +632,6 @@ def test_DepthToSpace(tmpdir, dtype):
image_shape = (4, 5)
input_val = np.array(np.reshape(range(num_channels), (num_channels, 1, 1)), dtype=dtype)
input_val = np.tile(input_val, (1,) + image_shape)
input_val.shape = (1,) + input_val.shape
img = C.input_variable((num_channels,) + image_shape, dtype=dtype)
model = C.depth_to_space(img, block_size)
@@ -752,11 +751,14 @@ def test_Gather(tmpdir, dtype):
if (dtype == np.float16):
pytest.skip("TO BE FIXED")
with C.default_options(dtype = dtype):
c = np.asarray([[[0],[1]]]).astype(dtype)
#c = np.asarray([[[0],[1]],[[4],[5]]]).astype(dtype) # batch size = 2 not supported yet.
c = np.asarray([[0],[1]]).astype(dtype)
x = C.input_variable((2,1))
d = np.arange(12).reshape(6,2).astype(dtype)
y = C.constant(d)
x_constant = C.constant(c)
model = C.gather(y, x_constant)
verify_no_input(model, tmpdir, 'Gather_0')
model = C.gather(y, x)
verify_one_input(model, c, tmpdir, 'Gather_1')
@@ -771,6 +773,7 @@ def test_Gather_With_Axis(tmpdir, dtype):
x = C.input_variable(np.shape(data))
y = C.input_variable(np.shape(indices))
axis = 1
model = C.gather(data, y, axis, 'gather_with_axis')
verify_one_input(model, indices, tmpdir, 'Gather_With_Axis_1')
@@ -1848,7 +1851,6 @@ def test_SpaceToDepth(tmpdir, dtype):
image_shape = (12, 15)
input_val = np.array(np.reshape(range(num_channels), (num_channels, 1, 1)), dtype=dtype)
input_val = np.tile(input_val, (1,) + image_shape)
input_val.shape = (1,) + input_val.shape
img = C.input_variable((num_channels,) + image_shape, dtype=dtype)
model = C.space_to_depth(img, block_size)


@@ -11,19 +11,19 @@ windows = os.getenv("OS")=="Windows_NT"
known_issues = [
'BatchNormalization_float160',
'SpatialBatchNormalization_float160',
'DepthToSpace',
'RNN',
'test_sequence_slice_',
'test_sequence_slice_0',
'test_sequence_slice_1',
'RNN.reverse.one_layer.relu',
'RNN.bidirectional.two_layer.tanh',
'test_sequence_slice_-1.0',
'test_sequence_slice_0.-1',
'test_sequence_slice_0.1',
'test_sequence_slice_1.-1',
'test_sequence_slice_1.0',
'test_sequence_slice_1.2',
'test_sequence_slice_-2.-1',
'test_sequence_slice_-4.2',
'SequenceSoftmax',
'SpaceToDepth',
'top_k',
'ConvTranspose_with_OutputShape_0',
'Flatten_1',
'Gather_1',
# Not in onnxruntime
'LayerNorm_0',
'MVN_0',
@@ -33,7 +33,7 @@ known_issues = [
]
def parse_single_result_case(case_str):
fails = re.search(r'Failed Test Cases:\w+', case_str)
fails = re.search(r'Failed Test Cases:[\w\.\-]+', case_str)
if fails:
failed_case = fails.group().split(':')[1]
if not failed_case in known_issues:
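The regex widening above matters because the new known_issues names contain dots and hyphens, which \w does not match. A quick check (the log line is hypothetical):

import re

line = 'Failed Test Cases:test_sequence_slice_-1.0'
old = re.search(r'Failed Test Cases:\w+', line)
new = re.search(r'Failed Test Cases:[\w\.\-]+', line)
print(old.group().split(':')[1])  # 'test_sequence_slice_' -- \w+ stops at '-'
print(new.group().split(':')[1])  # 'test_sequence_slice_-1.0' -- the full case name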