Merge branch 'liqun/NewSequenceSliceStage2Stage'
Commit ef836bc293
@@ -600,24 +600,35 @@ namespace CNTK
VerifyStaticAxis(ax, m_inputs[0].Shape());

size_t sliceAxisDim = m_inputs[0].Shape()[ax.StaticAxisIndex()];
int realBeginIndex = (beginIndex[i] >= 0) ? beginIndex[i] : beginIndex[i] + sliceAxisDim;
int realEndIndex = (endIndex[i] > 0) ? endIndex[i] : endIndex[i] + sliceAxisDim;
if ((sliceAxisDim < realEndIndex) || (realEndIndex < realBeginIndex) || (realBeginIndex < 0))
RuntimeError("Function '%S': Slice operation index range [%d,%d), interpreted as [%d,%d), is invalid for input '%S' shape '%S'.",
AsString().c_str(),
beginIndex[i],
endIndex[i],
realBeginIndex,
realEndIndex,
m_inputs[0].AsString().c_str(),
m_inputs[0].Shape().AsString().c_str());
// propagate as much as we can
// Note: If the sliceAxisDim is a free dimension and the slice size is relative to the sliceAxisDim then the
// corresponding outputDim is also a free dimension
if ((((sliceAxisDim != NDShape::FreeDimension) && (sliceAxisDim != NDShape::InferredDimension)) || (((beginIndex[i] >= 0) && (endIndex[i] > 0)) || ((beginIndex[i] < 0) && (endIndex[i] <= 0)))) &&
((ax.StaticAxisIndex() < (int)outputTensorShape.GetRank()) && (0 <= realBeginIndex) && (realBeginIndex <= realEndIndex) && (realEndIndex <= sliceAxisDim)))
if (sliceAxisDim == NDShape::FreeDimension && (beginIndex[i] < 0 || endIndex[i] <= 0))
{
outputTensorShape.NarrowTo(ax.StaticAxisIndex(), realBeginIndex, realEndIndex, strides[i]);
// Not able to calculate the real indices; do not narrow either.
// Note that endIndex[i] = 0 means through (and including) the last element.
// One case hitting this condition is exporting and importing, in ONNX format, a CNTK Sequence.Slice op.
// In this case, if the batch size is larger than 1 and the input data are a jagged array (i.e. sequences of various lengths),
// model evaluation will not match the original CNTK model.
}
else
{
int realBeginIndex = (beginIndex[i] >= 0) ? beginIndex[i] : beginIndex[i] + sliceAxisDim;
int realEndIndex = (endIndex[i] > 0) ? endIndex[i] : endIndex[i] + sliceAxisDim;
if ((sliceAxisDim < realEndIndex) || (realEndIndex < realBeginIndex) || (realBeginIndex < 0))
RuntimeError("Function '%S': Slice operation index range [%d,%d), interpreted as [%d,%d), is invalid for input '%S' shape '%S'.",
AsString().c_str(),
beginIndex[i],
endIndex[i],
realBeginIndex,
realEndIndex,
m_inputs[0].AsString().c_str(),
m_inputs[0].Shape().AsString().c_str());
// propagate as much as we can
// Note: If the sliceAxisDim is a free dimension and the slice size is relative to the sliceAxisDim then the
// corresponding outputDim is also a free dimension
if ((((sliceAxisDim != NDShape::FreeDimension) && (sliceAxisDim != NDShape::InferredDimension)) || (((beginIndex[i] >= 0) && (endIndex[i] > 0)) || ((beginIndex[i] < 0) && (endIndex[i] <= 0)))) &&
((ax.StaticAxisIndex() < (int)outputTensorShape.GetRank()) && (0 <= realBeginIndex) && (realBeginIndex <= realEndIndex) && (realEndIndex <= sliceAxisDim)))
{
outputTensorShape.NarrowTo(ax.StaticAxisIndex(), realBeginIndex, realEndIndex, strides[i]);
}
}
}
outputShape = AsNDShape(outputTensorShape, /*allowNonFlattenableTensorShapes = */ true);
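For reference, the validation in the hunk above reduces to the following index arithmetic. This is a minimal Python sketch, not CNTK source: negative begin/end indices are offsets relative to the sliced axis dimension, and a non-positive end index (in particular 0) means "through the last element".

def normalize_slice(begin_index, end_index, slice_axis_dim):
    # mirror the C++ above: negative indices wrap around the axis dimension
    real_begin = begin_index if begin_index >= 0 else begin_index + slice_axis_dim
    # end <= 0 is interpreted relative to the end of the axis; end == 0 keeps everything through the last element
    real_end = end_index if end_index > 0 else end_index + slice_axis_dim
    if not (0 <= real_begin <= real_end <= slice_axis_dim):
        raise ValueError("invalid slice range [%d,%d) for dimension %d" % (begin_index, end_index, slice_axis_dim))
    return real_begin, real_end

# e.g. normalize_slice(-2, -1, 5) == (3, 4) and normalize_slice(1, 0, 5) == (1, 5)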
@@ -83,7 +83,7 @@ private:
std::vector<LotusIR::NodeArg *>& outputs, Graph *graph);

static LotusIR::Node *AddReshapeNode(LotusIR::NodeArg &nodeArg, const std::vector<int> &newShape, const std::string &outArgName,
LotusIR::Graph* graph, int dynamicAxisCount);
LotusIR::Graph* graph);
static LotusIR::Node *AddMatMulNode(LotusIR::NodeArg &nodeArg1, LotusIR::NodeArg &nodeArg2, LotusIR::Graph* graph,
const std::string &out_arg_name);
static LotusIR::Node *AddArgMaxNode(LotusIR::NodeArg &nodeArg, LotusIR::Graph* graph, int axis);
@@ -1181,10 +1181,13 @@ bool IsUnSupportedLayerNormalization(const FunctionPtr src)
FunctionPtr SkipBatchAndSequenceAxisOp(const FunctionPtr src)
{
if ((src->OpName() == L"ToSequenceOp" && src->Inputs()[0].Owner() &&
src->Inputs()[0].Owner()->OpName() == L"ToBatchAxis") ||
src->Inputs()[0].Owner()->OpName() == L"ToBatchAxis") ||
(src->OpName() == L"UnpackBatchAxis" && src->Inputs()[0].Owner() &&
src->Inputs()[0].Owner()->OpName() == L"UnpackSequenceOp"))
return src->Inputs()[0].Owner()->Inputs()[0].Owner();
else if (src->OpName() == L"UnpackBatchAxis" && src->Inputs()[0].Owner() &&
src->Inputs()[0].Owner()->OpName() == L"Sequence::Slice")
return src->Inputs()[0].Owner();
else
return src;
}
@@ -1342,11 +1345,12 @@ bool CNTKToONNXHelper::FilterInput(const FunctionPtr& src, const CNTK::Variable&
}

/*
CNTK python static axis is zero based. Free/Inferred axis is not static.
ONNX batch axis, if it exists, is 0. In this case static axes start from 1.
CNTK cpp gets the static axis in a dis-normalized form (e.g. -axis - 1).
In general a CNTK node attribute contains the axis in this dis-normalized form.
This function converts the dis-normalized form to ONNX form.
CNTK python static axis is zero based. Batch and Sequence axes are not static axes.
CNTK cpp gets the static axis in a sanitized form (e.g. -axis - 1 by sanitize_axis).
In general a CNTK node attribute contains the axis
in a dis-normalized form (e.g. indexed from the last dimension).
This function converts the axis to ONNX form
(e.g. indexed from the first dimension of the shape, including both static and dynamic axes).
*/
int64_t CNTKToONNXHelper::ConvertAxisToOnnx(const Axis &axis, const Variable &operand)
{
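A hypothetical illustration of the conversion the comment describes (this is not the CNTK helper itself; the names below are made up for the example): an axis counted from the last static dimension is re-expressed as an index from the front of the full ONNX shape, in which the dynamic (batch/sequence) axes come first.

def to_onnx_axis(axis_from_last, static_rank, dynamic_axis_count):
    # position within the static shape, counted from the front
    static_index_from_front = static_rank - 1 - axis_from_last
    # ONNX places the dynamic axes before the static ones
    return dynamic_axis_count + static_index_from_front

# e.g. with batch and sequence axes (dynamic_axis_count = 2) and static shape (3, 4):
# the innermost static axis (axis_from_last = 0) maps to ONNX axis 3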
@@ -2315,9 +2319,9 @@ LotusIR::Node *CNTKToONNXHelper::AddReshapeNodeAccordingToONNXVersion(Graph *gra

LotusIR::Node *CNTKToONNXHelper::AddReshapeNode(LotusIR::NodeArg &nodeArg, const std::vector<int> &newShape, const std::string &outArgName,
LotusIR::Graph *graph, int dynamicAxisCount)
LotusIR::Graph *graph)
{
onnx::TypeProto typeProto = ToTypeProto(newShape, dynamicAxisCount);
onnx::TypeProto typeProto = ToTypeProto(newShape, false);
UpdateONNXType(CNTK::DataType::Float, typeProto);

LotusIR::NodeArg &outputArg = graph->GetOrCreateNodeArg(outArgName, &typeProto);
@@ -2386,6 +2390,144 @@ LotusIR::Node *CNTKToONNXHelper::InsertReshapeNodeToCNTKFunction(const FunctionP
return reshapeNode;
}

// Parse the Sequence.Slice node graph to collect the axis/begin index/end index
// and build an ONNX Slice op.
// IMPORTANT NOTE:
// This function converts a CNTK Sequence::Slice op to an ONNX Slice op.
// CNTK Sequence::Slice is able to handle jagged-array input (i.e. sequences of various lengths).
// ONNX Slice does not support the jagged-array data format.
// Therefore, when the batch size is larger than 1 and the input data are a jagged array,
// we do not expect model evaluation to generate matching numbers between CNTK and ONNX.
// With the following CNTK example:
// model = C.sequence.slice(C.sequence.input_variable((1)), -2, -1)
// model.eval([[0, 1, 2], [0, 1, 2, 3, 4]])
// CNTK output is:
// array([[1.], [3.]], dtype = float32)
// the output from the exported ONNX model will be:
// array([[padding_value], [3.]], dtype = float32)
LotusIR::Node* CNTKToONNXHelper::CreateSequenceSliceNode(const FunctionPtr& src,
LotusIR::Graph* graph,
std::unordered_map<FunctionPtr, LotusIR::Node*>& functionNodes,
std::unordered_map<Variable, LotusIR::Node*>& variableNodes,
const std::unordered_map<Variable, Variable>& compositeOutputsMap)
{
auto f = src->BlockRoot();
int64_t beginIndex = 0, endIndex = 0;

auto packedIndex = f->Inputs()[1].Owner();
auto whereFunc = packedIndex->Inputs()[1].Owner();
auto inputToWhere = whereFunc->Inputs()[0].Owner();
// The input to the Where node can be:
// ElementTimes - both indices are non-zero, beginIndex/endIndex come from the first/second inputs
// FutureValue - beginIndex is negative, endIndex is zero
// PastValue - endIndex is positive, beginIndex is zero
// 1 Minus FutureValue - endIndex is negative, beginIndex is zero
// 1 Minus PastValue - beginIndex is positive, endIndex is zero
auto reportLogicError = [&src]()
{
LogicError("Failed to parse Sequence.Slice node %s(%s).", ToLegacyString(ToUTF8(src->Name())).c_str(), ToLegacyString(ToUTF8(src->Uid())).c_str());
};
if (inputToWhere->OpName() == L"ElementTimes")
{
{
auto beginToWhere = inputToWhere->Inputs()[0].Owner();
if (beginToWhere->OpName() == L"Minus")
{
auto beginToMinusMustBeAPastValueOp = beginToWhere->Inputs()[1].Owner();
if (beginToMinusMustBeAPastValueOp->OpName() == L"PastValue")
beginIndex = static_cast<int64_t>(beginToMinusMustBeAPastValueOp->Attributes()[PrimitiveFunction::AttributeNameOffset].Value<size_t>());
else
reportLogicError();
}
else if (beginToWhere->OpName() == L"FutureValue")
{
beginIndex = -static_cast<int64_t>(beginToWhere->Attributes()[PrimitiveFunction::AttributeNameOffset].Value<size_t>());
}
else
reportLogicError();
}
{
auto endToWhere = inputToWhere->Inputs()[1].Owner();
if (endToWhere->OpName() == L"Minus")
{
auto endToMinusMustBeAFutureValueOp = endToWhere->Inputs()[1].Owner();
if (endToMinusMustBeAFutureValueOp->OpName() == L"FutureValue")
endIndex = -static_cast<int64_t>(endToMinusMustBeAFutureValueOp->Attributes()[PrimitiveFunction::AttributeNameOffset].Value<size_t>());
else
reportLogicError();
}
else if (endToWhere->OpName() == L"PastValue")
{
endIndex = static_cast<int64_t>(endToWhere->Attributes()[PrimitiveFunction::AttributeNameOffset].Value<size_t>());
}
else
reportLogicError();
}
}
else if (inputToWhere->OpName() == L"FutureValue")
{
beginIndex = -static_cast<int64_t>(inputToWhere->Attributes()[PrimitiveFunction::AttributeNameOffset].Value<size_t>());
}
else if (inputToWhere->OpName() == L"PastValue")
{
endIndex = static_cast<int64_t>(inputToWhere->Attributes()[PrimitiveFunction::AttributeNameOffset].Value<size_t>());
}
else if (inputToWhere->OpName() == L"Minus")
{
auto inputToMinus = inputToWhere->Inputs()[1].Owner();
if (inputToMinus->OpName() == L"FutureValue")
{
endIndex = -static_cast<int64_t>(inputToMinus->Attributes()[PrimitiveFunction::AttributeNameOffset].Value<size_t>());
}
else if (inputToMinus->OpName() == L"PastValue")
{
beginIndex = static_cast<int64_t>(inputToMinus->Attributes()[PrimitiveFunction::AttributeNameOffset].Value<size_t>());
}
}

if (endIndex == 0)
// This is where CNTK and numpy disagree: numpy would output an empty matrix,
// whereas CNTK outputs from beginIndex through (and including) the last element.
endIndex = INT_MAX;

std::vector<LotusIR::NodeArg *> inputs;
ProcessInputs(src, graph, functionNodes, variableNodes, compositeOutputsMap, inputs);

//std::vector<LotusIR::NodeArg *> outputs;
//ProcessOutputs(src, outputs, graph);
auto outputArgType = ToTypeProto(src->Output().Shape(), src->Output().HasBatchAxis(), src->Output().HasSequenceAxis());
UpdateONNXType(src->Output().GetDataType(), outputArgType);

std::string outputName = ToLegacyString(ToUTF8(src->BlockRoot()->Output().Uid()));
std::string sliceOutputName = outputName;
bool seq_dim_is_1 = endIndex - beginIndex == 1 || (endIndex == INT_MAX && beginIndex == -1);
if (seq_dim_is_1)
{
// it appears that Sequence.Slice squeezes the sequence axis out if the slice length is 1
sliceOutputName += "_PreReshape";
}

LotusIR::NodeArg &outputNodeArg = graph->GetOrCreateNodeArg(sliceOutputName, &outputArgType);

const std::string & nodeName = ToLegacyString(ToUTF8(src->Name()));
LotusIR::Node *sequenceSliceNode = graph->AddNode(nodeName, "Slice", "", { inputs[inputs.size() - 1] }, { &outputNodeArg });
sequenceSliceNode->AddAttribute("axes", std::vector<int64_t>({ int64_t(0) }));
sequenceSliceNode->AddAttribute("ends", std::vector<int64_t>({ endIndex }));
sequenceSliceNode->AddAttribute("starts", std::vector<int64_t>({ beginIndex }));
if (seq_dim_is_1)
{
// The CNTK Sequence.Slice op squeezes the sequence axis if it is of dimension 1.
// Insert a Reshape to remove the sequence axis.
std::vector<int> newShape(reverse(Cast<size_t, int>(src->Output().Shape().Dimensions())));
// add batch size (1) at the front
newShape.insert(newShape.begin(), 1);
const std::string outArgName = sliceOutputName;
return AddReshapeNode(outputNodeArg, newShape, outputName, graph);
}
else
return sequenceSliceNode;
}
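To make the exported behavior concrete, here is a small numpy emulation of what the generated Slice (plus the optional Reshape) computes on the packed [sequence, batch, ...features] tensor. It is a sketch under the assumptions stated in the comments above, not part of the converter.

import numpy as np

INT_MAX = 2**31 - 1

def emulate_exported_sequence_slice(x, begin_index, end_index):
    # the exporter maps an end index of 0 to INT_MAX ("through the last step")
    end_index = INT_MAX if end_index == 0 else end_index
    # ONNX Slice on axis 0 (the sequence axis of the packed tensor)
    y = x[begin_index: None if end_index == INT_MAX else end_index, ...]
    # a Reshape is appended when the sliced sequence length is known to be 1,
    # matching CNTK's squeeze of the sequence axis
    if end_index - begin_index == 1 or (end_index == INT_MAX and begin_index == -1):
        y = y.reshape(y.shape[1:])
    return y

x = np.arange(5 * 1 * 3, dtype=np.float32).reshape(5, 1, 3)  # sequence=5, batch=1, features=3
print(emulate_exported_sequence_slice(x, -2, -1).shape)       # (1, 3)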

//
// This is the main workhorse: it navigates the CNTK graph recursively, keeping track of all visited nodes and variables,
// and creates the corresponding ONNX graph.
@@ -2417,7 +2559,15 @@ LotusIR::Node* CNTKToONNXHelper::CreateNode(const FunctionPtr& initialSrc,
// return CreateLSTMNode(src, graph, functionNodes, variableNodes, compositeOutputsMap);
//}
//else
if (cntkOpName == "RNNStep")
if (cntkOpName == "Sequence::Slice")
{
return CreateSequenceSliceNode(src,
graph,
functionNodes,
variableNodes,
compositeOutputsMap);
}
else if (cntkOpName == "RNNStep")
{
return CreateRNNNode(src, graph, functionNodes, variableNodes, compositeOutputsMap);
}
@@ -2678,7 +2678,6 @@ FunctionPtr ONNXToCNTKHelper::CreateFunction(const Node *node, const std::vector
// { L"", "Split)
else if (onnxOpName == "Slice")
{
// axes is optional so provide a default
std::vector<Axis> axes = ConvertONNXAxesToCNTKCppApi(GetNamedAttributeAsInt64Vec(node, "axes"), inputs[0]);

std::vector<int64_t> starts64 = GetNamedAttributeAsInt64Vec(node, "starts");

@@ -2692,7 +2691,14 @@ FunctionPtr ONNXToCNTKHelper::CreateFunction(const Node *node, const std::vector

std::vector<int> starts = VecInt64ToVecInt(starts64);
std::vector<int> ends = VecInt64ToVecInt(ends64);
for (auto &e : ends)
{
// CNTK treats an endIndex of 0 as "through (and including) the last element".
if (e == INT_MAX)
e = 0;
}

// axes is optional so provide a default
if (axes.empty())
{
for (int i = 0; i < starts.size(); i++)

@@ -2702,13 +2708,6 @@ FunctionPtr ONNXToCNTKHelper::CreateFunction(const Node *node, const std::vector
}
}

bool workaroundONNXRT = false;
if (workaroundONNXRT)
{
axes.erase(axes.begin());
starts.erase(starts.begin());
ends.erase(ends.begin());
}
FunctionPtr cntkFunction = Slice(inputs[0], axes, starts, ends, ToFixedWStringFromMultiByte(node->Name()));
return cntkFunction;
}
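The import direction applies the inverse of the export-time end-index mapping shown above; a one-line Python sketch of that step (illustrative only, not the importer code):

def onnx_ends_to_cntk_ends(ends, int_max=2**31 - 1):
    # INT_MAX from the exporter becomes 0, which CNTK reads as "through the last element"
    return [0 if e == int_max else e for e in ends]

# e.g. onnx_ends_to_cntk_ends([3, 2**31 - 1]) == [3, 0]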
@@ -463,11 +463,21 @@ Variable ToBatchAndSequence(Variable input)
return operandWithBatchAndSequenceAxis;
}

FunctionPtr UnpackBatchAndSequence(FunctionPtr rnnFunction)
FunctionPtr UnpackBatchAndSequence(FunctionPtr rnnFunction, bool doTranspose)
{
FunctionPtr cntkFunctionWithoutSequenceAxis = Sequence::Unpack(rnnFunction, 0, L"");
FunctionPtr cntkFunctionWithoutDynamicAxis = UnpackBatch(cntkFunctionWithoutSequenceAxis, L"");
return cntkFunctionWithoutDynamicAxis;
if (doTranspose)
{
FunctionPtr transpose = TransposeAxes(cntkFunctionWithoutDynamicAxis,
Axis(cntkFunctionWithoutDynamicAxis->Output().Shape().Rank() - 2),
Axis(cntkFunctionWithoutDynamicAxis->Output().Shape().Rank() - 1), L"");
return transpose;
}
else
// In the case of RNN ops, a transpose is inserted after the op, so we do not transpose again.
// TODO: do not transpose after RNN ops so we have one code path here.
return cntkFunctionWithoutDynamicAxis;
}

FunctionPtr CreateLSTM(const LotusIR::Node *node, const std::vector<Variable> &inputs, const std::string &direction,
@@ -557,7 +567,7 @@ FunctionPtr CreateLSTM(const LotusIR::Node *node, const std::vector<Variable> &i
rnnFunction = Splice(operands, Axis(0), ToFixedWStringFromMultiByte(node->Name()));
}

FunctionPtr unpackedRnnFunction = UnpackBatchAndSequence(rnnFunction);
FunctionPtr unpackedRnnFunction = UnpackBatchAndSequence(rnnFunction, false);
return unpackedRnnFunction;
}

@@ -622,7 +632,7 @@ FunctionPtr CreateGRU(const LotusIR::Node *node, const std::vector<Variable> &in
rnnFunction = Splice(operands, Axis(0), ToFixedWStringFromMultiByte(node->Name()));
}

FunctionPtr unpackedRnnFunction = UnpackBatchAndSequence(rnnFunction);
FunctionPtr unpackedRnnFunction = UnpackBatchAndSequence(rnnFunction, false);
return unpackedRnnFunction;
}

@@ -678,7 +688,7 @@ FunctionPtr CreateRNN(const LotusIR::Node *node, const std::vector<Variable> &in
rnnFunction = Splice(operands, Axis(0), ToFixedWStringFromMultiByte(node->Name()));
}

FunctionPtr unpackedRnnFunction = UnpackBatchAndSequence(rnnFunction);
FunctionPtr unpackedRnnFunction = UnpackBatchAndSequence(rnnFunction, false);
return unpackedRnnFunction;
}
@@ -184,4 +184,4 @@ std::vector<CNTK::FunctionPtr> GetRNNBlocksFromSingleOrBidirectionalRNN(const CN

CNTK::Variable ToBatchAndSequence(CNTK::Variable input);

CNTK::FunctionPtr UnpackBatchAndSequence(CNTK::FunctionPtr rnnFunction);
CNTK::FunctionPtr UnpackBatchAndSequence(CNTK::FunctionPtr rnnFunction, bool doTranspose = true);
@@ -93,20 +93,17 @@ def verify_sequence_model(model, data, tmpdir, name, device=None, loaded_model=N
perm[0], perm[1] = perm[1], perm[0]
return np.transpose(data, perm)

assert model.output.has_sequence_axis() and model.output.has_batch_axis()

# data here is a reference to the outside data object. Create a deepcopy to avoid changing the outside data since it might get reused.
data = deepcopy(data)
opname = model.owner.op_name

loaded_model = try_save_load_resave_onnx_model(model, tmpdir, name, loaded_model)

model_shape = model.shape

# When both batch and sequence axes exist, model input will have batch and sequence axes
# swapped to match the onnx model. In this case, input and output data shall be adjusted accordingly.
model_shape = (CNTK_FREEDIM_AXIS_DENOTATION, CNTK_FREEDIM_AXIS_DENOTATION, ) + model_shape
assert model_shape == loaded_model.shape
# in cases like RNN models, where the model has both batch and sequence axes
# as dynamic axes, the imported model will have the dynamic axes as free dimensions in its static shape.
if model.output.dynamic_axes == (C.Axis('defaultBatchAxis'), C.Axis('defaultDynamicAxis')):
    assert (CNTK_FREEDIM_AXIS_DENOTATION, CNTK_FREEDIM_AXIS_DENOTATION, ) + model.shape == loaded_model.shape

dataOnnx = TranposeDynamicAxis(data)
if device:
@@ -116,13 +113,13 @@ def verify_sequence_model(model, data, tmpdir, name, device=None, loaded_model=N
o0 = model.eval({model.arguments[0]:data})
o1 = loaded_model.eval({loaded_model.arguments[0]:dataOnnx})

if (type(o0) is list):
    o0 = o0[0]
if (type(o1) is list):
    o1 = o1[0]
o0 = np.array(o0)
o1 = np.array(o1)

# squeeze the batch axis (=1) to match original model output
o1 = np.squeeze(o1, axis=1)
# if there is a sequence axis in the output, it must be swapped with batch axis
# to match the original CNTK model's output
if model.outputs[0].has_sequence_axis():
    o1 = TranposeDynamicAxis(o1)

assert np.allclose(o0, o1)
return loaded_model
@@ -1352,6 +1349,21 @@ def test_Slice(tmpdir, dtype):
model = C.slice(x1, [0,1], [1,0], [2,1]);
verify_one_input(model, data, tmpdir, 'Slice2_1')

#Sequence.Slice
@pytest.mark.parametrize("beginIndex, endIndex", (
    (-2, -1), (0, -1), (1, -1), (-1, 0), (1, 0), (-4, 2), (0, 1), (1, 2)))
@pytest.mark.parametrize("dtype", DType_Config)
def test_SequenceSlice(tmpdir, dtype, beginIndex, endIndex):
    batch_size = 1
    sequence_length = 5
    feature_shape = (3,)
    shape = (batch_size, sequence_length, *feature_shape)
    data = np.reshape(range(0, np.prod(shape)), shape).astype(dtype)
    testName = "test_sequence_slice_{0}.{1}".format(beginIndex, endIndex)
    print(testName)
    model = C.sequence.slice(C.sequence.input_variable((feature_shape)), beginIndex, endIndex)
    verify_sequence_model(model, data, tmpdir, testName)

#Softmax
@pytest.mark.parametrize("dtype", DType_Config)
def test_Softmax(tmpdir, dtype):