Merge branch 'liqun/NewSequenceSliceStage2Stage'
Commit ef836bc293
@@ -600,24 +600,35 @@ namespace CNTK
         VerifyStaticAxis(ax, m_inputs[0].Shape());

         size_t sliceAxisDim = m_inputs[0].Shape()[ax.StaticAxisIndex()];
-        int realBeginIndex = (beginIndex[i] >= 0) ? beginIndex[i] : beginIndex[i] + sliceAxisDim;
-        int realEndIndex = (endIndex[i] > 0) ? endIndex[i] : endIndex[i] + sliceAxisDim;
-        if ((sliceAxisDim < realEndIndex) || (realEndIndex < realBeginIndex) || (realBeginIndex < 0))
-            RuntimeError("Function '%S': Slice operation index range [%d,%d), interpreted as [%d,%d), is invalid for input '%S' shape '%S'.",
-                         AsString().c_str(),
-                         beginIndex[i],
-                         endIndex[i],
-                         realBeginIndex,
-                         realEndIndex,
-                         m_inputs[0].AsString().c_str(),
-                         m_inputs[0].Shape().AsString().c_str());
-        // propagate as much as we can
-        // Note: If the sliceAxisDim is a free dimension and the slice size is relative to the sliceAxisDim then the
-        // corresponding outputDim is also a free dimension
-        if ((((sliceAxisDim != NDShape::FreeDimension) && (sliceAxisDim != NDShape::InferredDimension)) || (((beginIndex[i] >= 0) && (endIndex[i] > 0)) || ((beginIndex[i] < 0) && (endIndex[i] <= 0)))) &&
-            ((ax.StaticAxisIndex() < (int)outputTensorShape.GetRank()) && (0 <= realBeginIndex) && (realBeginIndex <= realEndIndex) && (realEndIndex <= sliceAxisDim)))
-        {
-            outputTensorShape.NarrowTo(ax.StaticAxisIndex(), realBeginIndex, realEndIndex, strides[i]);
+        if (sliceAxisDim == NDShape::FreeDimension && (beginIndex[i] < 0 || endIndex[i] <= 0))
+        {
+            // Not able to calculate the real indices, so do not narrow either.
+            // Note that endIndex[i] = 0 means up to (and including) the last element.
+            // One case for this condition is exporting and then importing, in ONNX format, a CNTK Sequence.Slice op.
+            // In that case, if the batch size is larger than 1 and the input data form a jagged array (i.e. sequences of various lengths),
+            // model evaluation will not match the original CNTK model.
+        }
+        else
+        {
+            int realBeginIndex = (beginIndex[i] >= 0) ? beginIndex[i] : beginIndex[i] + sliceAxisDim;
+            int realEndIndex = (endIndex[i] > 0) ? endIndex[i] : endIndex[i] + sliceAxisDim;
+            if ((sliceAxisDim < realEndIndex) || (realEndIndex < realBeginIndex) || (realBeginIndex < 0))
+                RuntimeError("Function '%S': Slice operation index range [%d,%d), interpreted as [%d,%d), is invalid for input '%S' shape '%S'.",
+                             AsString().c_str(),
+                             beginIndex[i],
+                             endIndex[i],
+                             realBeginIndex,
+                             realEndIndex,
+                             m_inputs[0].AsString().c_str(),
+                             m_inputs[0].Shape().AsString().c_str());
+            // propagate as much as we can
+            // Note: If sliceAxisDim is a free dimension and the slice size is relative to sliceAxisDim, then the
+            // corresponding outputDim is also a free dimension.
+            if ((((sliceAxisDim != NDShape::FreeDimension) && (sliceAxisDim != NDShape::InferredDimension)) || (((beginIndex[i] >= 0) && (endIndex[i] > 0)) || ((beginIndex[i] < 0) && (endIndex[i] <= 0)))) &&
+                ((ax.StaticAxisIndex() < (int)outputTensorShape.GetRank()) && (0 <= realBeginIndex) && (realBeginIndex <= realEndIndex) && (realEndIndex <= sliceAxisDim)))
+            {
+                outputTensorShape.NarrowTo(ax.StaticAxisIndex(), realBeginIndex, realEndIndex, strides[i]);
+            }
         }
     }
     outputShape = AsNDShape(outputTensorShape, /*allowNonFlattenableTensorShapes = */ true);
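A minimal Python sketch of the index normalization this hunk performs (illustrative names only; normalize_slice is not a CNTK API):

# Mirrors the realBeginIndex/realEndIndex computation above; purely illustrative.
def normalize_slice(begin, end, axis_dim):
    real_begin = begin if begin >= 0 else begin + axis_dim
    real_end = end if end > 0 else end + axis_dim  # end == 0 means up to (and including) the last
    if axis_dim < real_end or real_end < real_begin or real_begin < 0:
        raise ValueError("invalid slice [%d,%d) for dimension %d" % (begin, end, axis_dim))
    return real_begin, real_end

# e.g. normalize_slice(-2, 0, 5) -> (3, 5); normalize_slice(1, -1, 5) -> (1, 4)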
@@ -83,7 +83,7 @@ private:
         std::vector<LotusIR::NodeArg *>& outputs, Graph *graph);

     static LotusIR::Node *AddReshapeNode(LotusIR::NodeArg &nodeArg, const std::vector<int> &newShape, const std::string &outArgName,
-        LotusIR::Graph* graph, int dynamicAxisCount);
+        LotusIR::Graph* graph);
     static LotusIR::Node *AddMatMulNode(LotusIR::NodeArg &nodeArg1, LotusIR::NodeArg &nodeArg2, LotusIR::Graph* graph,
         const std::string &out_arg_name);
     static LotusIR::Node *AddArgMaxNode(LotusIR::NodeArg &nodeArg, LotusIR::Graph* graph, int axis);
@@ -1181,10 +1181,13 @@ bool IsUnSupportedLayerNormalization(const FunctionPtr src)
 FunctionPtr SkipBatchAndSequenceAxisOp(const FunctionPtr src)
 {
     if ((src->OpName() == L"ToSequenceOp" && src->Inputs()[0].Owner() &&
         src->Inputs()[0].Owner()->OpName() == L"ToBatchAxis") ||
         (src->OpName() == L"UnpackBatchAxis" && src->Inputs()[0].Owner() &&
         src->Inputs()[0].Owner()->OpName() == L"UnpackSequenceOp"))
         return src->Inputs()[0].Owner()->Inputs()[0].Owner();
+    else if (src->OpName() == L"UnpackBatchAxis" && src->Inputs()[0].Owner() &&
+        src->Inputs()[0].Owner()->OpName() == L"Sequence::Slice")
+        return src->Inputs()[0].Owner();
     else
         return src;
 }
@@ -1342,11 +1345,12 @@ bool CNTKToONNXHelper::FilterInput(const FunctionPtr& src, const CNTK::Variable&
 }

 /*
-CNTK python static axis is zero based. Free/Inferred axis is not static.
-ONNX batch axis, if exists, is 0. in this case static axes start from 1.
-CNTK cpp get static axis in a dis-normalized form (e.g. -axis - 1)
-In general CNTK node attribute contains axis in this dis-normalized form.
-This function converts dis-normalized form to ONNX form.
+CNTK Python static axes are zero based. Batch and Sequence axes are not static axes.
+The CNTK C++ API gets a static axis in a sanitized form (e.g. -axis - 1, via sanitize_axis).
+In general a CNTK node attribute contains the axis
+in this de-normalized form (i.e. indexed from the last dimension).
+This function converts the axis to ONNX form
+(i.e. indexed from the first dimension of the shape, including both static and dynamic axes).
 */
 int64_t CNTKToONNXHelper::ConvertAxisToOnnx(const Axis &axis, const Variable &operand)
 {
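To make the axis convention above concrete, here is a hedged sketch of the conversion (simplified; not the actual ConvertAxisToOnnx implementation):

# Illustrative only: a CNTK static axis stored as an index from the last static dimension,
# converted to an ONNX axis indexed from the front of the full shape (dynamic axes included).
def cntk_axis_to_onnx_axis(axis_from_end, static_rank, dynamic_axis_count):
    return dynamic_axis_count + (static_rank - 1 - axis_from_end)

# e.g. with static_rank=3 and dynamic_axis_count=2 (batch + sequence):
# cntk_axis_to_onnx_axis(0, 3, 2) -> 4, cntk_axis_to_onnx_axis(2, 3, 2) -> 2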
@@ -2315,9 +2319,9 @@ LotusIR::Node *CNTKToONNXHelper::AddReshapeNodeAccordingToONNXVersion(Graph *gra


 LotusIR::Node *CNTKToONNXHelper::AddReshapeNode(LotusIR::NodeArg &nodeArg, const std::vector<int> &newShape, const std::string &outArgName,
-    LotusIR::Graph *graph, int dynamicAxisCount)
+    LotusIR::Graph *graph)
 {
-    onnx::TypeProto typeProto = ToTypeProto(newShape, dynamicAxisCount);
+    onnx::TypeProto typeProto = ToTypeProto(newShape, false);
     UpdateONNXType(CNTK::DataType::Float, typeProto);

     LotusIR::NodeArg &outputArg = graph->GetOrCreateNodeArg(outArgName, &typeProto);
@@ -2386,6 +2390,144 @@ LotusIR::Node *CNTKToONNXHelper::InsertReshapeNodeToCNTKFunction(const FunctionP
     return reshapeNode;
 }

+// Parse the Sequence.Slice node graph to collect the axis/begin index/end index
+// and build an ONNX Slice op.
+// IMPORTANT NOTE:
+// This function converts a CNTK Sequence::Slice op to an ONNX Slice op.
+// CNTK Sequence::Slice can handle jagged-array input (i.e. sequences of various lengths).
+// ONNX Slice does not support the jagged-array data format.
+// Therefore, when the batch size is larger than 1 and the input data form a jagged array,
+// we do not expect model evaluation to produce matching numbers between CNTK and ONNX.
+// With the following CNTK example:
+//     model = C.sequence.slice(C.sequence.input_variable((1)), -2, -1)
+//     model.eval([[0, 1, 2], [0, 1, 2, 3, 4]])
+// CNTK output is:
+//     array([[1.], [3.]], dtype = float32)
+// output from the exported ONNX model will be:
+//     array([[padding_value], [3.]], dtype = float32)
+LotusIR::Node* CNTKToONNXHelper::CreateSequenceSliceNode(const FunctionPtr& src,
+    LotusIR::Graph* graph,
+    std::unordered_map<FunctionPtr, LotusIR::Node*>& functionNodes,
+    std::unordered_map<Variable, LotusIR::Node*>& variableNodes,
+    const std::unordered_map<Variable, Variable>& compositeOutputsMap)
+{
+    auto f = src->BlockRoot();
+    int64_t beginIndex = 0, endIndex = 0;
+
+    auto packedIndex = f->Inputs()[1].Owner();
+    auto whereFunc = packedIndex->Inputs()[1].Owner();
+    auto inputToWhere = whereFunc->Inputs()[0].Owner();
+    // The input to the Where node can be:
+    // ElementTimes        - both indices are non-zero; beginIndex/endIndex come from the first/second inputs
+    // FutureValue         - beginIndex is negative, endIndex is zero
+    // PastValue           - endIndex is positive, beginIndex is zero
+    // 1 Minus FutureValue - endIndex is negative, beginIndex is zero
+    // 1 Minus PastValue   - beginIndex is positive, endIndex is zero
+    auto reportLogicError = [&src]()
+    {
+        LogicError("Failed to parse Sequence.Slice node %s(%s).", ToLegacyString(ToUTF8(src->Name())).c_str(), ToLegacyString(ToUTF8(src->Uid())).c_str());
+    };
+    if (inputToWhere->OpName() == L"ElementTimes")
+    {
+        {
+            auto beginToWhere = inputToWhere->Inputs()[0].Owner();
+            if (beginToWhere->OpName() == L"Minus")
+            {
+                auto beginToMinusMustBeAPastValueOp = beginToWhere->Inputs()[1].Owner();
+                if (beginToMinusMustBeAPastValueOp->OpName() == L"PastValue")
+                    beginIndex = static_cast<int64_t>(beginToMinusMustBeAPastValueOp->Attributes()[PrimitiveFunction::AttributeNameOffset].Value<size_t>());
+                else
+                    reportLogicError();
+            }
+            else if (beginToWhere->OpName() == L"FutureValue")
+            {
+                beginIndex = -static_cast<int64_t>(beginToWhere->Attributes()[PrimitiveFunction::AttributeNameOffset].Value<size_t>());
+            }
+            else
+                reportLogicError();
+        }
+        {
+            auto endToWhere = inputToWhere->Inputs()[1].Owner();
+            if (endToWhere->OpName() == L"Minus")
+            {
+                auto endToMinusMustBeAFutureValueOp = endToWhere->Inputs()[1].Owner();
+                if (endToMinusMustBeAFutureValueOp->OpName() == L"FutureValue")
+                    endIndex = -static_cast<int64_t>(endToMinusMustBeAFutureValueOp->Attributes()[PrimitiveFunction::AttributeNameOffset].Value<size_t>());
+                else
+                    reportLogicError();
+            }
+            else if (endToWhere->OpName() == L"PastValue")
+            {
+                endIndex = static_cast<int64_t>(endToWhere->Attributes()[PrimitiveFunction::AttributeNameOffset].Value<size_t>());
+            }
+            else
+                reportLogicError();
+        }
+    }
+    else if (inputToWhere->OpName() == L"FutureValue")
+    {
+        beginIndex = -static_cast<int64_t>(inputToWhere->Attributes()[PrimitiveFunction::AttributeNameOffset].Value<size_t>());
+    }
+    else if (inputToWhere->OpName() == L"PastValue")
+    {
+        endIndex = static_cast<int64_t>(inputToWhere->Attributes()[PrimitiveFunction::AttributeNameOffset].Value<size_t>());
+    }
+    else if (inputToWhere->OpName() == L"Minus")
+    {
+        auto inputToMinus = inputToWhere->Inputs()[1].Owner();
+        if (inputToMinus->OpName() == L"FutureValue")
+        {
+            endIndex = -static_cast<int64_t>(inputToMinus->Attributes()[PrimitiveFunction::AttributeNameOffset].Value<size_t>());
+        }
+        else if (inputToMinus->OpName() == L"PastValue")
+        {
+            beginIndex = static_cast<int64_t>(inputToMinus->Attributes()[PrimitiveFunction::AttributeNameOffset].Value<size_t>());
+        }
+    }
+
+    if (endIndex == 0)
+        // This is where CNTK and numpy disagree: numpy would output an empty matrix,
+        // whereas CNTK outputs from beginIndex up to (and including) the last element.
+        endIndex = INT_MAX;
+
+    std::vector<LotusIR::NodeArg *> inputs;
+    ProcessInputs(src, graph, functionNodes, variableNodes, compositeOutputsMap, inputs);
+
+    //std::vector<LotusIR::NodeArg *> outputs;
+    //ProcessOutputs(src, outputs, graph);
+    auto outputArgType = ToTypeProto(src->Output().Shape(), src->Output().HasBatchAxis(), src->Output().HasSequenceAxis());
+    UpdateONNXType(src->Output().GetDataType(), outputArgType);
+
+    std::string outputName = ToLegacyString(ToUTF8(src->BlockRoot()->Output().Uid()));
+    std::string sliceOutputName = outputName;
+    bool seq_dim_is_1 = endIndex - beginIndex == 1 || (endIndex == INT_MAX && beginIndex == -1);
+    if (seq_dim_is_1)
+    {
+        // It appears that sequence.slice squeezes the sequence axis out if the slice length is 1.
+        sliceOutputName += "_PreReshape";
+    }
+
+    LotusIR::NodeArg &outputNodeArg = graph->GetOrCreateNodeArg(sliceOutputName, &outputArgType);
+
+    const std::string &nodeName = ToLegacyString(ToUTF8(src->Name()));
+    LotusIR::Node *sequenceSliceNode = graph->AddNode(nodeName, "Slice", "", { inputs[inputs.size() - 1] }, { &outputNodeArg });
+    sequenceSliceNode->AddAttribute("axes", std::vector<int64_t>({ int64_t(0) }));
+    sequenceSliceNode->AddAttribute("ends", std::vector<int64_t>({ endIndex }));
+    sequenceSliceNode->AddAttribute("starts", std::vector<int64_t>({ beginIndex }));
+    if (seq_dim_is_1)
+    {
+        // The CNTK Sequence.Slice op squeezes the sequence axis if it is of dimension 1,
+        // so insert a reshape to remove the sequence axis.
+        std::vector<int> newShape(reverse(Cast<size_t, int>(src->Output().Shape().Dimensions())));
+        // Prepend the batch dimension (1).
+        newShape.insert(newShape.begin(), 1);
+        const std::string outArgName = sliceOutputName;
+        return AddReshapeNode(outputNodeArg, newShape, outputName, graph);
+    }
+    else
+        return sequenceSliceNode;
+}
+
 //
 // This is the main workhorse: it navigates the CNTK graph recursively, keeping track of all visited nodes and variables,
 // and creates the corresponding ONNX graph.
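For reference, a hedged sketch (onnx.helper, not CNTK code) of what the exported node amounts to for the example in the comment above, C.sequence.slice(x, -2, -1):

import onnx.helper as helper

# Opset-9-style Slice with axes/starts/ends as attributes, matching the AddAttribute calls above.
slice_node = helper.make_node(
    "Slice", inputs=["x"], outputs=["y"],
    axes=[0], starts=[-2], ends=[-1])

# If end - begin == 1 (or end == INT_MAX with begin == -1), the CNTK op squeezes the sequence
# axis, so the exporter appends a Reshape node to drop that axis, as in the code above.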
@@ -2417,7 +2559,15 @@ LotusIR::Node* CNTKToONNXHelper::CreateNode(const FunctionPtr& initialSrc,
     //    return CreateLSTMNode(src, graph, functionNodes, variableNodes, compositeOutputsMap);
     //}
     //else
-    if (cntkOpName == "RNNStep")
+    if (cntkOpName == "Sequence::Slice")
+    {
+        return CreateSequenceSliceNode(src,
+            graph,
+            functionNodes,
+            variableNodes,
+            compositeOutputsMap);
+    }
+    else if (cntkOpName == "RNNStep")
     {
         return CreateRNNNode(src, graph, functionNodes, variableNodes, compositeOutputsMap);
     }
@@ -2678,7 +2678,6 @@ FunctionPtr ONNXToCNTKHelper::CreateFunction(const Node *node, const std::vector
     // { L"", "Split)
     else if (onnxOpName == "Slice")
     {
-        // axes is optional so provide a default
         std::vector<Axis> axes = ConvertONNXAxesToCNTKCppApi(GetNamedAttributeAsInt64Vec(node, "axes"), inputs[0]);

         std::vector<int64_t> starts64 = GetNamedAttributeAsInt64Vec(node, "starts");
@@ -2692,7 +2691,14 @@ FunctionPtr ONNXToCNTKHelper::CreateFunction(const Node *node, const std::vector

         std::vector<int> starts = VecInt64ToVecInt(starts64);
         std::vector<int> ends = VecInt64ToVecInt(ends64);
+        for (auto &e : ends)
+        {
+            // CNTK treats an endIndex of 0 as meaning up to (and including) the last element.
+            if (e == INT_MAX)
+                e = 0;
+        }

+        // axes is optional so provide a default
         if (axes.empty())
         {
             for (int i = 0; i < starts.size(); i++)
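A hedged sketch of the end-index round trip between the exporter (0 -> INT_MAX) and this importer loop (INT_MAX -> 0):

# Illustrative only, not library code.
INT_MAX = 2**31 - 1

def cntk_end_to_onnx(end):   # export: CNTK end == 0 means up to (and including) the last
    return INT_MAX if end == 0 else end

def onnx_end_to_cntk(end):   # import: undo the mapping, as the loop above does
    return 0 if end == INT_MAX else end

assert onnx_end_to_cntk(cntk_end_to_onnx(0)) == 0
assert onnx_end_to_cntk(cntk_end_to_onnx(-1)) == -1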
@@ -2702,13 +2708,6 @@ FunctionPtr ONNXToCNTKHelper::CreateFunction(const Node *node, const std::vector
             }
         }

-        bool workaroundONNXRT = false;
-        if (workaroundONNXRT)
-        {
-            axes.erase(axes.begin());
-            starts.erase(starts.begin());
-            ends.erase(ends.begin());
-        }
         FunctionPtr cntkFunction = Slice(inputs[0], axes, starts, ends, ToFixedWStringFromMultiByte(node->Name()));
         return cntkFunction;
     }
@@ -463,11 +463,21 @@ Variable ToBatchAndSequence(Variable input)
     return operandWithBatchAndSequenceAxis;
 }

-FunctionPtr UnpackBatchAndSequence(FunctionPtr rnnFunction)
+FunctionPtr UnpackBatchAndSequence(FunctionPtr rnnFunction, bool doTranspose)
 {
     FunctionPtr cntkFunctionWithoutSequenceAxis = Sequence::Unpack(rnnFunction, 0, L"");
     FunctionPtr cntkFunctionWithoutDynamicAxis = UnpackBatch(cntkFunctionWithoutSequenceAxis, L"");
-    return cntkFunctionWithoutDynamicAxis;
+    if (doTranspose)
+    {
+        FunctionPtr transpose = TransposeAxes(cntkFunctionWithoutDynamicAxis,
+            Axis(cntkFunctionWithoutDynamicAxis->Output().Shape().Rank() - 2),
+            Axis(cntkFunctionWithoutDynamicAxis->Output().Shape().Rank() - 1), L"");
+        return transpose;
+    }
+    else
+        // In the case of RNN ops, a transpose is inserted after the op, so we do not transpose again.
+        // TODO: do not transpose after RNN ops so that there is a single code path here.
+        return cntkFunctionWithoutDynamicAxis;
 }

 FunctionPtr CreateLSTM(const LotusIR::Node *node, const std::vector<Variable> &inputs, const std::string &direction,
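A minimal numpy sketch of what the added TransposeAxes call does, assuming the two former dynamic axes end up as the last two static axes (an assumption based on the Rank() - 2 / Rank() - 1 arguments above):

import numpy as np

unpacked = np.zeros((3, 4, 5))           # (..., axis Rank-2, axis Rank-1)
swapped = np.swapaxes(unpacked, -2, -1)  # swap the last two axes
print(swapped.shape)                     # (3, 5, 4)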
@@ -557,7 +567,7 @@ FunctionPtr CreateLSTM(const LotusIR::Node *node, const std::vector<Variable> &i
         rnnFunction = Splice(operands, Axis(0), ToFixedWStringFromMultiByte(node->Name()));
     }

-    FunctionPtr unpackedRnnFunction = UnpackBatchAndSequence(rnnFunction);
+    FunctionPtr unpackedRnnFunction = UnpackBatchAndSequence(rnnFunction, false);
     return unpackedRnnFunction;
 }

@@ -622,7 +632,7 @@ FunctionPtr CreateGRU(const LotusIR::Node *node, const std::vector<Variable> &in
         rnnFunction = Splice(operands, Axis(0), ToFixedWStringFromMultiByte(node->Name()));
     }

-    FunctionPtr unpackedRnnFunction = UnpackBatchAndSequence(rnnFunction);
+    FunctionPtr unpackedRnnFunction = UnpackBatchAndSequence(rnnFunction, false);
     return unpackedRnnFunction;
 }

@@ -678,7 +688,7 @@ FunctionPtr CreateRNN(const LotusIR::Node *node, const std::vector<Variable> &in
         rnnFunction = Splice(operands, Axis(0), ToFixedWStringFromMultiByte(node->Name()));
     }

-    FunctionPtr unpackedRnnFunction = UnpackBatchAndSequence(rnnFunction);
+    FunctionPtr unpackedRnnFunction = UnpackBatchAndSequence(rnnFunction, false);
     return unpackedRnnFunction;
 }

@@ -184,4 +184,4 @@ std::vector<CNTK::FunctionPtr> GetRNNBlocksFromSingleOrBidirectionalRNN(const CN

 CNTK::Variable ToBatchAndSequence(CNTK::Variable input);

-CNTK::FunctionPtr UnpackBatchAndSequence(CNTK::FunctionPtr rnnFunction);
+CNTK::FunctionPtr UnpackBatchAndSequence(CNTK::FunctionPtr rnnFunction, bool doTranspose = true);
@@ -93,20 +93,17 @@ def verify_sequence_model(model, data, tmpdir, name, device=None, loaded_model=N
         perm[0], perm[1] = perm[1], perm[0]
         return np.transpose(data, perm)

-    assert model.output.has_sequence_axis() and model.output.has_batch_axis()
-
     # data here is a reference to the outside data object. Create a deepcopy to avoid changing the outside data since it might get reused.
     data = deepcopy(data)
     opname = model.owner.op_name

     loaded_model = try_save_load_resave_onnx_model(model, tmpdir, name, loaded_model)

-    model_shape = model.shape
-
-    # When both batch and sequence axes exist, model input will have batch and sequence axes
-    # swapped to match the onnx model. In this case, input and output data shall be adjusted accordingly.
-    model_shape = (CNTK_FREEDIM_AXIS_DENOTATION, CNTK_FREEDIM_AXIS_DENOTATION, ) + model_shape
-    assert model_shape == loaded_model.shape
+    # In cases like RNN models, where models have both the batch and the sequence axis
+    # as dynamic axes, imported models will have the dynamic axes as free dimensions in their static shapes.
+    if model.output.dynamic_axes == (C.Axis('defaultBatchAxis'), C.Axis('defaultDynamicAxis')):
+        assert (CNTK_FREEDIM_AXIS_DENOTATION, CNTK_FREEDIM_AXIS_DENOTATION, ) + model.shape == loaded_model.shape

     dataOnnx = TranposeDynamicAxis(data)
     if device:

@@ -116,13 +113,13 @@ def verify_sequence_model(model, data, tmpdir, name, device=None, loaded_model=N
     o0 = model.eval({model.arguments[0]:data})
     o1 = loaded_model.eval({loaded_model.arguments[0]:dataOnnx})

-    if (type(o0) is list):
-        o0 = o0[0]
-    if (type(o1) is list):
-        o1 = o1[0]
+    o0 = np.array(o0)
+    o1 = np.array(o1)

-    # squeeze the batch axis (=1) to match original model output
-    o1 = np.squeeze(o1, axis=1)
+    # If there is a sequence axis in the output, it must be swapped with the batch axis
+    # to match the original CNTK model's output.
+    if model.outputs[0].has_sequence_axis():
+        o1 = TranposeDynamicAxis(o1)

     assert np.allclose(o0, o1)
     return loaded_model
@@ -1352,6 +1349,21 @@ def test_Slice(tmpdir, dtype):
     model = C.slice(x1, [0,1], [1,0], [2,1]);
     verify_one_input(model, data, tmpdir, 'Slice2_1')

+#Sequence.Slice
+@pytest.mark.parametrize("beginIndex, endIndex", (
+    (-2, -1), (0, -1), (1, -1), (-1, 0), (1, 0), (-4, 2), (0, 1), (1, 2)))
+@pytest.mark.parametrize("dtype", DType_Config)
+def test_SequenceSlice(tmpdir, dtype, beginIndex, endIndex):
+    batch_size = 1
+    sequence_length = 5
+    feature_shape = (3,)
+    shape = (batch_size, sequence_length, *feature_shape)
+    data = np.reshape(range(0, np.prod(shape)), shape).astype(dtype)
+    testName = "test_sequence_slice_{0}.{1}".format(beginIndex, endIndex)
+    print(testName)
+    model = C.sequence.slice(C.sequence.input_variable((feature_shape)), beginIndex, endIndex)
+    verify_sequence_model(model, data, tmpdir, testName)
+
 #Softmax
 @pytest.mark.parametrize("dtype", DType_Config)
 def test_Softmax(tmpdir, dtype):