This commit is contained in:
liqfu 2018-03-05 13:13:03 -08:00
Родитель 6a8773bcf2
Коммит 7651b0d567
11 изменённых файлов: 2387 добавлений и 97 удалений

Просмотреть файл

@ -531,7 +531,8 @@ CNTKLIBRARY_COMMON_SRC =\
$(SOURCEDIR)/CNTKv2LibraryDll/proto/onnx/core/graph.cpp \
$(SOURCEDIR)/CNTKv2LibraryDll/proto/onnx/core/model.cpp \
$(SOURCEDIR)/CNTKv2LibraryDll/proto/onnx/Operators.cpp \
$(SOURCEDIR)/CNTKv2LibraryDll/proto/onnx/CNTKToONNX.cpp \
$(SOURCEDIR)/CNTKv2LibraryDll/proto/onnx/RNNHelper.cpp \
$(SOURCEDIR)/CNTKv2LibraryDll/proto/onnx/CNTKToONNX.cpp \
$(SOURCEDIR)/CNTKv2LibraryDll/proto/onnx/ONNXToCNTK.cpp \
$(SOURCEDIR)/CNTKv2LibraryDll/proto/onnx/ONNX.cpp \

Просмотреть файл

@ -1975,6 +1975,13 @@ namespace CNTK
[](const Axis& axis) { return (axis == Axis::DefaultBatchAxis()); });
bool HasSequenceAxis() const {
return (DynamicAxes().size() - (HasBatchAxis() ? 1 : 0)) > 0;
bool IsInitialized() const {
return m_dataFields != nullptr;
/// Returns the name of 'this' variable

Просмотреть файл

@ -187,6 +187,7 @@
<ClInclude Include="proto\onnx\ONNX.h" />
<ClInclude Include="proto\onnx\ONNXToCNTK.h" />
<ClInclude Include="proto\onnx\Operators.h" />
<ClInclude Include="proto\onnx\RNNHelper.h" />
<ClInclude Include="Serialization.h" />
<ClInclude Include="tensorboard\TensorBoardUtils.h" />
<ClInclude Include="UserDefinedFunction.h" />
@ -243,6 +244,7 @@
<ClCompile Include="proto\onnx\ONNXToCNTK.cpp" />
<ClCompile Include="proto\onnx\Operators.cpp" />
<ClCompile Include="proto\onnx\protobuf\" />
<ClCompile Include="proto\onnx\RNNHelper.cpp" />
<ClCompile Include="Serialization.cpp" />
<ClCompile Include="stdafx.cpp">

Просмотреть файл

@ -108,6 +108,9 @@
<ClCompile Include="proto\onnx\protobuf\">
<ClCompile Include="proto\onnx\RNNHelper.cpp">
<ClInclude Include="stdafx.h" />
@ -190,6 +193,9 @@
<ClInclude Include="API\CNTKLibraryC.h">
<ClInclude Include="proto\onnx\RNNHelper.h">
<Filter Include="API">
@ -252,4 +258,4 @@

Разница между файлами не показана из-за своего большого размера Загрузить разницу

Просмотреть файл

@ -9,6 +9,7 @@
#include "Operators.h"
#include <algorithm>
#include <iostream>
#include "RNNHelper.h"
using namespace ONNXIR;
using namespace CNTK;
@ -17,14 +18,16 @@ using namespace CNTK::ONNX;
namespace CNTK
typedef std::unordered_map<const Node *, FunctionPtr> ONNXToCNTKMap;
typedef std::unordered_map<const Node *, std::vector<FunctionPtr>> ONNXToCNTKMap;
typedef std::unordered_map<std::string, Variable> ONNXToCNTKVariableMap;
class ONNXToCNTKHelper
// Convert an ONNX graph to a CNTK graph (Function).
static FunctionPtr FromONNXNode(const Node *node, ONNXToCNTKMap &constructedNodeMap,
static std::vector<FunctionPtr> FromONNXNode(const Node *node, ONNXToCNTKMap &constructedNodeMap,
ONNXToCNTKVariableMap &constructedNodeArgVariableMap,
const Graph* graph, const DeviceDescriptor& computeDevice);
@ -37,6 +40,9 @@ namespace CNTK
const DeviceDescriptor& computeDevice);
static Variable CreateLeafVariableOrConstant(const NodeArg *nodeArg, const Node *parentNode, const Graph *graph,
const DeviceDescriptor& computeDevice);
static std::vector<Variable> CreateRNNLeafVariableOrConstant(const NodeArg *nodeArg,
const Node *parentNode, const Graph* graph,
ONNXToCNTKVariableMap &constructedNodeArgVariableMap, const DeviceDescriptor& computeDevice);
static FunctionPtr CreateFunction(const Node *node, const std::vector<Variable> &inputs);
static bool IsSecondInputOfElementWiseOpsWithBroadcast(const Node *parentNode, const NodeArg *nodeArg);
@ -100,14 +106,17 @@ namespace CNTK
static string GetNamedAttributeAsString(const Node *node, const string &attributeName);
static string GetNamedAttributeAsString(const Node *node, const string &attributeName, const string& defaultValue);
static std::vector<std::string> GetNamedAttributeAsStringVec(const Node *node, const string &attributeName,
const std::vector<std::string> &defaultValues);
static std::vector<int64_t> GetNamedAttributeAsInt64Vec(const Node *node, const string &attributeName);
static std::vector<int64_t> GetNamedAttributeAsInt64Vec(const Node *node, const string &attributeName, const std::vector<int64_t>& defaultValue);
static std::vector<float> GetNamedAttributeAsFloatVec(const Node *node, const string &attributeName);
static std::vector<float> GetNamedAttributeAsFloatVec(const Node *node, const string &attributeName, const std::vector<float>& defaultValue);
static Axis ConvertAxisToCNTKCppApi(const Axis& axes, const Variable& input);
static std::vector<Axis> ConvertAxesToCNTKCppApi(const std::vector<Axis> &axes, const Variable& operand);
static Axis ConvertONNXAxisToCNTKCppApi(int64_t axes, const Variable& input);
static std::vector<Axis> ConvertONNXAxesToCNTKCppApi(const std::vector<int64_t> &axes, const Variable& operand);
static void AdjustAutoPaddingAndStrideForCNTKSpecialCases(const Variable &operand,
std::vector<bool> &autoPadding, NDShape &strides);
@ -547,18 +556,344 @@ bool ONNXToCNTKHelper::FixConstantShapeForConstantVariableInputPair(const std::v
return true;
Variable ONNXToCNTKHelper::CreateLeafVariableOrConstant(const NodeArg *nodeArg,
const Node *parentNode, const Graph* graph, const DeviceDescriptor& computeDevice)
// Returns the position of nodeArg within parentNode's input definitions,
// matching by name, or -1 when parentNode has no input with that name.
int CalculateNodeArgInputIndex(const NodeArg *nodeArg, const Node *parentNode)
// NOTE(review): nodeName is not referenced again in the visible lines of this function.
std::string nodeName = nodeArg->Name();
// Locate the first input definition whose name equals nodeArg's name.
std::vector<NodeArg>::const_iterator it = std::find_if(parentNode->InputDefs().cbegin(),
parentNode->InputDefs().cend(), [nodeArg](const NodeArg &other) {return other.Name() == nodeArg->Name(); });
if (it == parentNode->InputDefs().cend())
return -1;
// The iterator distance from the start of the input list is the input index.
return it - parentNode->InputDefs().cbegin();
// Builds a CNTK Constant named `name` with shape `shape` from the raw buffer `data`.
// The buffer is first wrapped in an NDArrayView on the CPU device. When computeDevice
// is a CPU device that view backs the returned Constant directly; otherwise a dense
// NDArrayView on computeDevice backs it.
template<typename DType>
Constant CreateConstantWithRawData(DType *data,const NDShape &shape, const std::string &name,
const DeviceDescriptor& computeDevice)
DataType dataType = AsDataType<DType>();
// Element count of the target shape; the byte size passed below is totalSize * sizeof(DType).
int totalSize = shape.TotalSize();
NDArrayViewPtr dstFinal(new NDArrayView(dataType, shape, data,
totalSize * sizeof(DType), computeDevice.CPUDevice()));
// CPU target: the CPU view is used as the constant's value.
if (computeDevice.Type() == DeviceKind::CPU)
Constant constantVariable(dstFinal, ToWString(name));
return constantVariable;
// This is the way to load values into GPU:
// create a GPU NDArrayView and CopyFrom a CPU NDArrayView that holds the data.
// NOTE(review): no copy from dstFinal into dstFinalGPU is visible in these lines —
// confirm the CopyFrom step described above is actually performed.
NDArrayViewPtr dstFinalGPU(new NDArrayView(dataType, StorageFormat::Dense, shape, computeDevice));
Constant constantVariable(dstFinalGPU, ToWString(name));
return constantVariable;
// Converts the ONNX tensor valueProto, which feeds input number `index` of the RNN node
// parentNode, into CNTK constants. Only parent nodes whose op type is "LSTM" are handled
// in the visible code. For each direction (dims(0) of the tensor) a Constant is built
// from a re-laid-out copy of the float data; peephole input builds three per direction.
// Returns `inputs`; an unhandled index ends in CNTK::LogicError.
// NOTE(review): the visible lines never append the created constants to `inputs` —
// confirm each `constant` is pushed into `inputs` before the case returns.
std::vector<Variable> CreateRNNConstant(
const Node *parentNode, int index, const std::string &name, onnx::TensorProto &valueProto, const DeviceDescriptor& computeDevice)
std::vector<Variable> inputs;
string parentONNXOpName = parentNode->OpType();
auto dataType = valueProto.data_type();
// Check whether the typed data field matching the tensor's declared type is empty.
// NOTE(review): the action taken when it is empty is not visible here.
switch (dataType)
case TensorProto_DataType_FLOAT:
if (valueProto.float_data().empty())
case TensorProto_DataType_DOUBLE:
if (valueProto.double_data().empty())
// index to LSTM inputs as specified in the ONNX document.
if (parentONNXOpName == "LSTM")
switch (index)
case LSTMInputIndexX:
// X, should not come to here
return inputs;
case LSTMInputIndexW:
case LSTMInputIndexH:
// W, R:
// see ONNX spec for the tensor shape
// dims(0) is read as the direction count, dims(1) as rows and dims(2) as columns.
int num_directions = valueProto.dims(0);
size_t rows = valueProto.dims(1);
size_t cols = valueProto.dims(2);
// CNTK cpp requires shape being (input_size, 4 * hidden_size)
NDShape weightShape({ rows, cols });
int input_size = cols;
// rows is treated as four stacked gate blocks of cell_size rows each.
int cell_size = rows / 4;
for (int dir = 0; dir < num_directions; dir++)
// NOTE(review): (char)dir appends the raw character code (0, 1, ...) rather than a digit,
// and the literal 1 is used here instead of LSTMInputIndexW — confirm both are intended.
std::string nodeName = name + (index == 1 ? "_W_" : "_R_") + (char)dir;
int totalSizePerDirection = rows * cols;
// TODO: what about double?
// NOTE(review): no release of this new[] buffer is visible — confirm who owns it.
float *data = new float[totalSizePerDirection];
// Walk the source slice for this direction element by element, swapping gate
// blocks 1 and 3 and writing to the transposed position in `data`.
for (size_t count = 0; count < totalSizePerDirection; count++)
int row = count / input_size;
int col = count % input_size;
int block = row / cell_size;
// NOTE(review): the gate labels below (o / c) are the reverse of the labels used for
// the same block numbers in the bias case — confirm which labelling is correct.
if (block == 1)
// o
row += cell_size * 2;
else if (block == 3)
// c
row -= cell_size * 2;
// Source offset is within this direction's slice; target offset swaps row and column.
int sourceIndex = dir * totalSizePerDirection + count;
int targetIndex = col * cell_size * 4 + row;
data[targetIndex] = valueProto.float_data()[sourceIndex];
Constant constant = CreateConstantWithRawData(&data[0], weightShape, nodeName, computeDevice);
return inputs;
case LSTMInputIndexB:
// B
// see ONNX spec for the tensor shape
int num_directions = valueProto.dims(0);
// dims(1) is divided by 8: two bias vectors of four gate blocks each.
int cell_size = valueProto.dims(1) / 8;
// There is an ONNX spec issue with the bias input. It states that
// "This tensor has shape `[num_directions, 8*hidden_size]`", which means
// hidden and input are applied with bias separately after weight.
// In CNTK, bias is applied in fused form, after hidden and input
// are element-wise added. In this case
// the bias shape is [num_directions, 4*hidden_size]
NDShape weightShape({ (size_t)(4 * cell_size) });
for (int dir = 0; dir < num_directions; dir++)
// NOTE(review): (char)dir appends the raw character code, not a digit — confirm.
std::string nodeName = name + std::string(1, (char)dir) + LSTMInputBiasNameHint;
int totalSizePerDirection = 4 * cell_size;
// NOTE(review): no release of this new[] buffer is visible — confirm who owns it.
float *data = new float[totalSizePerDirection];
for (size_t targetIndex = 0; targetIndex < totalSizePerDirection; targetIndex++)
int row = targetIndex;
// TODO: specific to LSTM. icfo (CNTK) to iofc(ONNX)
int block = row / cell_size;
if (block == 1)
// c
row += 2 * cell_size;
else if (block == 3)
// o
row -= 2 * cell_size;
// source is column major
int src_index = row;
// "fuse": sum the element from the first and the second 4*cell_size half
// of this direction's 8*cell_size bias slice.
data[targetIndex] =
valueProto.float_data()[dir * 2 * totalSizePerDirection + src_index] +
valueProto.float_data()[dir * 2 * totalSizePerDirection + totalSizePerDirection + src_index];
Constant constant = CreateConstantWithRawData(data, weightShape, nodeName, computeDevice);
return inputs;
case LSTMInputIndexSequenceLens:
// sequence length is treated as free dimension
return inputs;
case LSTMInputIndexinitial_h:
case LSTMInputIndexinitial_c:
// initial_h, initial_c
int num_directions = valueProto.dims(0);
// TODO: batch shall be one?
// int batchSize = valueProto.dims(1);
int cell_size = valueProto.dims(2);
// Each direction yields one constant of shape [cell_size], copied from
// float_data starting at offset dir * cell_size.
// NOTE(review): the comment that previously stood here was a verbatim copy of the
// bias-shape remark from the B case and did not describe initial_h/initial_c.
// NOTE(review): the offset ignores dims(1) (batch) — see the TODO above.
NDShape weightShape({ (size_t)(cell_size) });
for (int dir = 0; dir < num_directions; dir++)
std::string nodeName = name + std::string(1, (char)dir);
// NOTE(review): the literal 5 is used instead of LSTMInputIndexinitial_h, and as shown
// the C hint is appended right after the H hint — confirm an else separates them.
if (index == 5)
nodeName += LSTMInputInitialHNameHint;
nodeName += LSTMInputInitialCNameHint;
// NOTE(review): no release of this new[] buffer is visible — confirm who owns it.
float *data = new float[cell_size];
for (size_t targetIndex = 0; targetIndex < cell_size; targetIndex++)
data[targetIndex] = valueProto.float_data()[dir * cell_size + targetIndex];
Constant constant = CreateConstantWithRawData(data, weightShape, nodeName, computeDevice);
return inputs;
case LSTMInputIndexP:
// P
int num_directions = valueProto.dims(0);
// dims(1) is divided by 3: one peephole vector per gate suffix (_i, _o, _f) below.
int cell_size = valueProto.dims(1) / 3;
for (int dir = 0; dir < num_directions; dir++)
for (int i = 0; i < 3; i++)
std::string nodeName = name + ((i == 0) ? "_i" : ((i == 1) ? "_o" : "_f")) +
std::string(1, (char)dir) + LSTMInputPeepholeNameHint;
// NOTE(review): no release of this new[] buffer is visible — confirm who owns it.
float *data = new float[cell_size];
NDShape weightShape({ (size_t)(cell_size) });
// Copy the i-th cell_size-long segment of this direction's slice.
for (size_t targetIndex = 0; targetIndex < cell_size; targetIndex++)
data[targetIndex] = valueProto.float_data()[(dir * 3 + i) * cell_size + targetIndex];
Constant constant = CreateConstantWithRawData(data, weightShape, nodeName, computeDevice);
return inputs;
// Reached when no case above returned.
CNTK::LogicError("CreateRNNConstant received unepxpeted index: %d", index);
// Builds the CNTK functions for an ONNX node that supplies input `index` of the RNN node
// parentNode. The tensor is taken from the graph's initial tensor registered under the
// node's name; failing that, from the node's "value" attribute. Returns an empty vector
// when neither source exists.
std::vector<FunctionPtr> CreateRNNConstantOp(const Graph* graph, const Node *node, const Node *parentNode, int index,
const DeviceDescriptor& computeDevice)
onnx::TensorProto valueProto;
// Fall back to the "value" attribute only when the graph holds no initial tensor for this node.
if (!graph->GetInitialTensor(node->Name(), valueProto))
NodeAttributes::const_iterator itValue = node->GetAttributes().find("value");
if (itValue == node->GetAttributes().cend())
return std::vector<FunctionPtr>();
valueProto = itValue->second.t();
std::vector<Variable> constantNodes = CreateRNNConstant(parentNode, index, node->Name(), valueProto, computeDevice);
std::vector<FunctionPtr> returns;
// NOTE(review): the loop body is not visible here; presumably each constant is
// converted and appended to `returns` — confirm.
for (auto c : constantNodes)
return returns;
// Creates the CNTK variables for one input (nodeArg) of the RNN node parentNode.
// - If the graph has an initial tensor named after nodeArg, the result comes from
//   CreateRNNConstant.
// - If nodeArg has no shape, an empty vector is returned (dummy input).
// - For an "LSTM" parent, input X becomes an InputVariable with sequence and batch
//   dynamic axes. It is cached by name in constructedNodeArgVariableMap, so repeated
//   requests return the same variable.
// - Any other LSTM input index without an initial tensor raises LogicError.
std::vector<Variable> ONNXToCNTKHelper::CreateRNNLeafVariableOrConstant(const NodeArg *nodeArg,
const Node *parentNode, const Graph* graph, ONNXToCNTKVariableMap &constructedNodeArgVariableMap,
const DeviceDescriptor& computeDevice)
string parentONNXOpName = parentNode->OpType();
std::string nodeName = nodeArg->Name();
onnx::TensorProto valueProto;
// An initializer with this name means the input is a constant.
if (graph->GetInitialTensor(nodeName, valueProto))
int index = CalculateNodeArgInputIndex(nodeArg, parentNode);
return CreateRNNConstant(parentNode, index, nodeName, valueProto, computeDevice);
const TensorShapeProto *shapeProto = nodeArg->Shape();
if (shapeProto == nullptr)
// dummy input,
return std::vector<Variable>();
// std::string nodeName = nodeArg->Name();
std::vector<Axis> dynamicAxes({ Axis::OperandSequenceAxis() , Axis::DefaultBatchAxis()});
if (parentONNXOpName == "LSTM")
// index to LSTM inputs as specified in the ONNX document.
int inputIndex = CalculateNodeArgInputIndex(nodeArg, parentNode);
switch (inputIndex)
case LSTMInputIndexX:
// X: `[seq_length, batch_size, input_size]`.
Variable inputVariable;
// Create the input variable only the first time this name is seen.
if (constructedNodeArgVariableMap.find(nodeArg->Name()) == constructedNodeArgVariableMap.end())
DataType dataType = FromONNXType(nodeArg->ToProto().type());
// Only the third dimension (input_size) becomes the static shape.
int input_size = shapeProto->dim(2).dim_value();
NDShape shape({ (size_t)(input_size) });
inputVariable = InputVariable(shape, dataType, ToWString(nodeArg->Name()), dynamicAxes);
constructedNodeArgVariableMap.insert(ONNXToCNTKVariableMap::value_type(nodeArg->Name(), inputVariable));
return std::vector<Variable>({ constructedNodeArgVariableMap[nodeArg->Name()]});
// other inputs shall be ONNX constant node and be created as CNTK Constant in CreateRNNConstant
case LSTMInputIndexW: // W
case LSTMInputIndexH: // R
case LSTMInputIndexB: // B
case LSTMInputIndexSequenceLens: // sequence_lens
case LSTMInputIndexinitial_h: // initial_h
case LSTMInputIndexinitial_c: // initial_c
case LSTMInputIndexP: // P
LogicError("LSTM node has unexpected input");
Variable ONNXToCNTKHelper::CreateLeafVariableOrConstant(const NodeArg *nodeArg,
const Node *parentNode, const Graph* graph, const DeviceDescriptor& computeDevice)
string parentONNXOpName = parentNode->OpType();
std::string nodeName = nodeArg->Name();
onnx::TensorProto valueProto;
if (graph->GetInitialTensor(nodeName, valueProto))
return CreateConstant(valueProto, nodeName, computeDevice);
auto dataType = FromONNXType(nodeArg->ToProto().type());
auto shapeProto = nodeArg->Shape();
// in CNTK constants are created as Node (not a leaf) with values.
@ -571,15 +906,38 @@ Variable ONNXToCNTKHelper::CreateLeafVariableOrConstant(const NodeArg *nodeArg,
shape = shape.SubShape(0, shape.Rank() - 1);
std::vector<Axis> dynamicAxes({ Axis::DefaultBatchAxis() });
// TODO: this is not fully correct. We need to get hasSequenceAxis
// over the traverse path. An input will have a sequence axis
// only if it outputs to an RNN op along the path.
// This requires support from LotusIR.
// Now traversing starts from arbitrary nodes, which may miss the RNN op.
bool hasSequenceAxis = false;
for (Graph::NodeIterator nodeIt = (const_cast<Graph *>(graph))->Nodes_begin();
nodeIt != (const_cast<Graph *>(graph))->Nodes_end(); ++nodeIt)
if (Operators::IsRNNOp((*nodeIt)->OpType()))
hasSequenceAxis = true;
if (hasSequenceAxis)
shape = shape.SubShape(0, shape.Rank() - 1);
dynamicAxes.insert(dynamicAxes.begin(), Axis::OperandSequenceAxis());
auto dataType = FromONNXType(nodeArg->ToProto().type());
switch (dataType)
case DataType::Float:
return InputVariable(shape, DataType::Float, ToWString(nodeArg->Name()), { Axis::DefaultBatchAxis() });
return InputVariable(shape, DataType::Float, ToWString(nodeArg->Name()), dynamicAxes);
case DataType::Double:
return InputVariable(shape, DataType::Double, ToWString(nodeArg->Name()), { Axis::DefaultBatchAxis() });
return InputVariable(shape, DataType::Double, ToWString(nodeArg->Name()), dynamicAxes);
@ -775,6 +1133,17 @@ string ONNXToCNTKHelper::GetNamedAttributeAsString(const Node *node, const strin
return attributeProto.s();
// Returns the string-list attribute attributeName of node as a vector of strings,
// or defaultValues when the node has no such attribute.
std::vector<std::string> ONNXToCNTKHelper::GetNamedAttributeAsStringVec(const Node *node, const string &attributeName,
const std::vector<std::string> &defaultValues)
// NOTE(review): the `false` argument presumably marks the attribute as optional — confirm
// against FindAttributeIterator.
NodeAttributes::const_iterator itValue = FindAttributeIterator(node, attributeName, false);
if (itValue == node->GetAttributes().end())
return defaultValues;
const AttributeProto &attributeProto = itValue->second;
// Copy the repeated `strings` field into a std::vector.
return std::vector<std::string>(attributeProto.strings().begin(), attributeProto.strings().end());
std::vector<int64_t> ONNXToCNTKHelper::GetNamedAttributeAsInt64Vec(const Node *node, const string &attributeName)
NodeAttributes::const_iterator itValue = FindAttributeIterator(node, attributeName, true);
@ -970,40 +1339,50 @@ std::pair<Variable, Variable> ONNXToCNTKHelper::BroadcastElementWiseInput(
// this is to handle edge cases where one input has batch (and sequence) axis and the other does not.
// the input with rank higher is caused by extra batch/sequence dimension which need to be squeezed away.
// TODO: investigate if this edge case can be avoided when converting CNTK to ONNX.
if (input0.Shape().Rank() == input1.Shape().Rank() - 1)
if (input0.HasBatchAxis() && !input1.HasBatchAxis())
return{ input0, Reshape(input1, shape1.SubShape(0, shape1.Rank() - 1)) };
else if (input0.Shape().Rank() == input1.Shape().Rank() - 2 && input0.DynamicAxes().size() == 2 &&
input1.DynamicAxes().size() == 0)
return{ input0, Reshape(input1, shape1.SubShape(0, shape1.Rank() - 2)) };
else if (input0.Shape().Rank() - 1 == input1.Shape().Rank())
if (!input0.HasBatchAxis() && input1.HasBatchAxis())
return{ Reshape(input0, shape0.SubShape(0, shape0.Rank() - 1)), input1 };
else if (input0.Shape().Rank() - 2 == input1.Shape().Rank() && input1.DynamicAxes().size() == 2 &&
input0.DynamicAxes().size() == 0)
return{ Reshape(input0, shape0.SubShape(0, shape0.Rank() - 2)), input1 };
return{ input0 , input1 };
Axis ONNXToCNTKHelper::ConvertAxisToCNTKCppApi(const Axis& axis, const Variable& operand)
Axis ONNXToCNTKHelper::ConvertONNXAxisToCNTKCppApi(int64_t axis, const Variable& operand)
// reverse CNTKToONNXHelper::ConvertAxisToOnnx
// note that axis is already decreased by one (assuming there is a batch axis)
int64_t index = axis.StaticAxisIndex();
if (!operand.HasBatchAxis())
int index = axis - operand.DynamicAxes().size();
if (index < 0)
LogicError("ConvertAxisToCNTKCppApi cannot convert index < 0 to axis");
// apply -index - 1
return Axis(-index - 1);
std::vector<Axis> ONNXToCNTKHelper::ConvertAxesToCNTKCppApi(const std::vector<Axis> &axes, const Variable& operand)
std::vector<Axis> ONNXToCNTKHelper::ConvertONNXAxesToCNTKCppApi(const std::vector<int64_t> &axes, const Variable& operand)
std::vector<Axis> cntkAxes(axes.size());
for (int i = 0; i < axes.size(); i++)
cntkAxes[i] = ConvertAxisToCNTKCppApi(axes[i], operand);
cntkAxes[i] = ConvertONNXAxisToCNTKCppApi(axes[i], operand);
return cntkAxes;
@ -1050,15 +1429,29 @@ FunctionPtr ONNXToCNTKHelper::CreateFunction(const Node *node, const std::vector
string onnxOpName = node->OpType();
if (onnxOpName == "Cast" && inputs[0].GetDataType() == DataType::Float && inputs[0].Owner() != nullptr)
// CNTK does not support cast op. Only float is available with ONNX support.
// Question for having a cast op: Why not cast data as necessary internally.
return inputs[0].Owner();
else if (onnxOpName == "LSTM")
const string direction = GetNamedAttributeAsString(node, "direction");
std::vector<float> activation_alpha = GetNamedAttributeAsFloatVec(node, "activation_alpha", std::vector<float>());
std::vector<float> activation_beta = GetNamedAttributeAsFloatVec(node, "activation_beta", std::vector<float>());
const std::vector<string> activations = GetNamedAttributeAsStringVec(node, "activations",
std::vector<string>({"Sigmoid", "Tanh", "Tanh"}));
return CreateLSTM(node, inputs, direction, activations, activation_alpha, activation_beta);
if (onnxOpName == "FC")
return CreateCNTKFCNode(ToWString(node->Name()), inputs);
else if (onnxOpName == "Flatten")
Axis defaultAxis(-1);
Axis axis = GetNamedAttributeAsAxis(node, "axis", defaultAxis);
axis = ConvertAxisToCNTKCppApi(axis, inputs[0]);
int64_t axisIndex = (size_t)GetNamedAttributeAsInt64(node, "axis", 0);
Axis axis = ConvertONNXAxisToCNTKCppApi(axisIndex, inputs[0]);
FunctionPtr cntkFunction = Flatten(inputs[0], axis, ToWString(node->Name()));
return cntkFunction;
@ -1513,57 +1906,57 @@ FunctionPtr ONNXToCNTKHelper::CreateFunction(const Node *node, const std::vector
else if (onnxOpName == "ReduceMax")
std::vector<Axis> axes = ConvertAxesToCNTKCppApi(GetNamedAttributeAsAxes(node, "axes"), inputs[0]);
std::vector<Axis> axes = ConvertONNXAxesToCNTKCppApi(GetNamedAttributeAsInt64Vec(node, "axes"), inputs[0]);
FunctionPtr cntkFunction = ReduceMax(inputs[0], axes[0], ToWString(node->Name()));
return cntkFunction;
else if (onnxOpName == "ReduceMin")
std::vector<Axis> axes = ConvertAxesToCNTKCppApi(GetNamedAttributeAsAxes(node, "axes"), inputs[0]);
std::vector<Axis> axes = ConvertONNXAxesToCNTKCppApi(GetNamedAttributeAsInt64Vec(node, "axes"), inputs[0]);
FunctionPtr cntkFunction = ReduceMin(inputs[0], axes[0], ToWString(node->Name()));
return cntkFunction;
else if (onnxOpName == "ReduceSum")
std::vector<Axis> axes = ConvertAxesToCNTKCppApi(GetNamedAttributeAsAxes(node, "axes"), inputs[0]);
std::vector<Axis> axes = ConvertONNXAxesToCNTKCppApi(GetNamedAttributeAsInt64Vec(node, "axes"), inputs[0]);
FunctionPtr cntkFunction = ReduceSum(inputs[0], axes[0], ToWString(node->Name()));
return cntkFunction;
else if (onnxOpName == "ReduceMean")
std::vector<Axis> axes = ConvertAxesToCNTKCppApi(GetNamedAttributeAsAxes(node, "axes"), inputs[0]);
std::vector<Axis> axes = ConvertONNXAxesToCNTKCppApi(GetNamedAttributeAsInt64Vec(node, "axes"), inputs[0]);
FunctionPtr cntkFunction = ReduceMean(inputs[0], axes[0], ToWString(node->Name()));
return cntkFunction;
else if (onnxOpName == "ReduceProd")
std::vector<Axis> axes = ConvertAxesToCNTKCppApi(GetNamedAttributeAsAxes(node, "axes"), inputs[0]);
std::vector<Axis> axes = ConvertONNXAxesToCNTKCppApi(GetNamedAttributeAsInt64Vec(node, "axes"), inputs[0]);
FunctionPtr cntkFunction = ReduceProd(inputs[0], axes[0], ToWString(node->Name()));
return cntkFunction;
else if (onnxOpName == "ReduceLogSumExp")
std::vector<Axis> axes = ConvertAxesToCNTKCppApi(GetNamedAttributeAsAxes(node, "axes"), inputs[0]);
std::vector<Axis> axes = ConvertONNXAxesToCNTKCppApi(GetNamedAttributeAsInt64Vec(node, "axes"), inputs[0]);
FunctionPtr cntkFunction = ReduceLogSum(inputs[0], axes[0], ToWString(node->Name()));
return cntkFunction;
else if (onnxOpName == "ReduceL1")
std::vector<Axis> axes = ConvertAxesToCNTKCppApi(GetNamedAttributeAsAxes(node, "axes"), inputs[0]);
std::vector<Axis> axes = ConvertONNXAxesToCNTKCppApi(GetNamedAttributeAsInt64Vec(node, "axes"), inputs[0]);
bool keepdims = GetNamedAttributeAsInt64(node, "keepdims", 1) == 1;
FunctionPtr cntkFunction = ReduceL1(inputs[0], axes, keepdims, ToWString(node->Name()));
return cntkFunction;
else if (onnxOpName == "ReduceL2")
std::vector<Axis> axes = ConvertAxesToCNTKCppApi(GetNamedAttributeAsAxes(node, "axes"), inputs[0]);
std::vector<Axis> axes = ConvertONNXAxesToCNTKCppApi(GetNamedAttributeAsInt64Vec(node, "axes"), inputs[0]);
bool keepdims = GetNamedAttributeAsInt64(node, "keepdims", 1) == 1;
FunctionPtr cntkFunction = ReduceL2(inputs[0], axes, keepdims, ToWString(node->Name()));
return cntkFunction;
else if (onnxOpName == "ReduceSumSquare")
std::vector<Axis> axes = ConvertAxesToCNTKCppApi(GetNamedAttributeAsAxes(node, "axes"), inputs[0]);
std::vector<Axis> axes = ConvertONNXAxesToCNTKCppApi(GetNamedAttributeAsInt64Vec(node, "axes"), inputs[0]);
bool keepdims = GetNamedAttributeAsInt64(node, "keepdims", 1) == 1;
FunctionPtr cntkFunction = ReduceSumSquare(inputs[0], axes, keepdims, ToWString(node->Name()));
return cntkFunction;
@ -1572,34 +1965,60 @@ FunctionPtr ONNXToCNTKHelper::CreateFunction(const Node *node, const std::vector
int64_t axisIndex = GetNamedAttributeAsInt64(node, "axis");
// -1 to compensate what ConvertAxisToCNTKCppApi assumes that axis is already decreased by 1
Axis axis(axisIndex - 1);
axis = ConvertAxisToCNTKCppApi(axis, inputs[0]);
Axis axis = ConvertONNXAxisToCNTKCppApi(axisIndex, inputs[0]);
FunctionPtr cntkfunction = Argmax(inputs[0], axis, ToWString(node->Name()));
return cntkfunction;
else if (onnxOpName == "ArgMin")
int64_t axisIndex = GetNamedAttributeAsInt64(node, "axis");
// -1 to compensate what ConvertAxisToCNTKCppApi assumes that axis is already decreased by 1
Axis axis(axisIndex - 1);
axis = ConvertAxisToCNTKCppApi(axis, inputs[0]);
Axis axis = ConvertONNXAxisToCNTKCppApi(axisIndex, inputs[0]);
FunctionPtr cntkFunction = Argmin(inputs[0], axis, ToWString(node->Name()));
return cntkFunction;
else if (onnxOpName == "Reshape")
// Skip reshape if it follows an LSTM.
const Node* childNode = GetChildNode(node, &node->InputDefs()[0]);
if (childNode != nullptr && childNode->OpType() == "LSTM")
// TODO: this is to undo reshape after LSTM in CNTKToONNX to work around
// output shape mismatch with input to the next LSTM layer.
NDShape newShape = GetNamedAttributeAsShape(node, "shape", false);
newShape = newShape.SubShape(0, newShape.Rank() - inputs[0].DynamicAxes().size());
return Reshape(inputs[0], newShape, ToWString(node->Name()));
NDShape newShape = GetNamedAttributeAsShape(node, "shape", false);
if (inputs[0].HasBatchAxis())
if (inputs[0].DynamicAxes().size() > 0)
if (newShape.Rank() == 1)
LogicError("Reshape: 'shape' attribute must include element for batch axis.");
newShape = newShape.SubShape(0, newShape.Rank() - 1);
newShape = newShape.SubShape(0, newShape.Rank() - inputs[0].DynamicAxes().size());
int inferredDim = -1;
for (size_t i = 0; i < newShape.Dimensions().size(); ++i)
if (newShape[i] == 0)
newShape[i] = inputs[0].Shape()[i];
else if (newShape[i] == 0)
if (inferredDim == -1)
inferredDim = i;
LogicError("Reshape: 'shape' contains more than one inferred dimension.");
if (inferredDim != -1)
if (inputs[0].Shape().TotalSize() % newShape.TotalSize() != 0)
LogicError("Reshape: 'shape' contains more than one inferred dimension.");
int inferredDimSize = inputs[0].Shape().TotalSize() / newShape.TotalSize();
newShape[inferredDim] = inferredDimSize;
FunctionPtr cntkFunction = Reshape(inputs[0], newShape, ToWString(node->Name()));
return cntkFunction;
@ -1608,8 +2027,8 @@ FunctionPtr ONNXToCNTKHelper::CreateFunction(const Node *node, const std::vector
// We allow the 'axis' attribute to be optional, and not required (as
// given in Concat's ONNX spec), to be consistent with other frameworks.
// 'axis' can be enforced as a required attribute, if needed.
Axis axis = GetNamedAttributeAsAxis(node, "axis", Axis(0));
axis = ConvertAxisToCNTKCppApi(axis, inputs[0]);
int64_t onnxAxis = GetNamedAttributeAsInt64(node, "axis", 0);
Axis axis = ConvertONNXAxisToCNTKCppApi(onnxAxis, inputs[0]);
std::vector<Variable> fixedInputs;
if (FixConstantShapeForConstantVariableInputPair(inputs, fixedInputs))
@ -1626,12 +2045,7 @@ FunctionPtr ONNXToCNTKHelper::CreateFunction(const Node *node, const std::vector
else if (onnxOpName == "Slice")
// axes is optional so provide a default
std::vector<Axis> axes;
axes = GetNamedAttributeAsAxes(node, "axes", axes);
for (int i = 0; i < axes.size(); i++)
axes[i] = ConvertAxisToCNTKCppApi(axes[i], inputs[0]);
std::vector<Axis> axes = ConvertONNXAxesToCNTKCppApi(GetNamedAttributeAsInt64Vec(node, "axes"), inputs[0]);
std::vector<int64_t> starts64 = GetNamedAttributeAsInt64Vec(node, "starts");
std::vector<int64_t> ends64 = GetNamedAttributeAsInt64Vec(node, "ends");
@ -1713,13 +2127,15 @@ FunctionPtr ONNXToCNTKHelper::CreateFunction(const Node *node, const std::vector
if (HasNamedAttribute(node, "axis"))
Axis defaultAxes(-1);
Axis axis = GetNamedAttributeAsAxis(node, "axis", defaultAxes);
axis = ConvertAxisToCNTKCppApi(axis, inputs[0]);
int64_t axisIndex = GetNamedAttributeAsInt64(node, "axis", 0);
Axis axis = ConvertONNXAxisToCNTKCppApi(axisIndex, inputs[0]);
return GatherOp(inputs[1], inputs[0], axis, ToWString(node->Name()));
return GatherOp(inputs[1], inputs[0], ToWString(node->Name()));
FunctionPtr cntkFunction = GatherOp(inputs[1], inputs[0], ToWString(node->Name()));
return cntkFunction;
else if (onnxOpName == "DepthToSpace")
@ -1768,14 +2184,34 @@ FunctionPtr ONNXToCNTKHelper::CreateFunction(const Node *node, const std::vector
FunctionPtr ONNXToCNTKHelper::FromONNXNode(const Node *node, ONNXToCNTKMap &constructedNodeMap,
// Looks at the first output node of `node` (called the parent here) and searches its
// input definitions for one whose name equals node's name. Returns the parent with an
// index, or {nullptr, -1} when no match is returned.
// NOTE(review): only the first output node is examined.
std::pair<const Node *, int> FindParentAndChildIndex(const Node *node)
Node::NodeConstIterator it = node->OutputNodes_begin();
if (it != node->OutputNodes_end())
const Node *parent = *it;
int index = 0;
// NOTE(review): no increment of `index` is visible in this loop — confirm it advances
// once per input definition; otherwise the returned index is always 0.
for (auto nodeArg : parent->InputDefs())
if (nodeArg.Name() == node->Name())
return std::make_pair(parent, index);
return std::make_pair(nullptr, -1);
std::vector<FunctionPtr> ONNXToCNTKHelper::FromONNXNode(const Node *node, ONNXToCNTKMap &constructedNodeMap,
ONNXToCNTKVariableMap &constructedNodeArgVariableMap,
const Graph* graph, const DeviceDescriptor& computeDevice)
auto nodeOpStr = node->OpType();
ONNXToCNTKMap::iterator itONNXToCNTKMap = constructedNodeMap.find(node);
if (itONNXToCNTKMap != constructedNodeMap.end())
return itONNXToCNTKMap->second;
return std::vector<FunctionPtr>({ itONNXToCNTKMap->second });
std::vector<Variable> inputs;
@ -1789,24 +2225,52 @@ FunctionPtr ONNXToCNTKHelper::FromONNXNode(const Node *node, ONNXToCNTKMap &cons
ONNXToCNTKMap::iterator itNodeMap = constructedNodeMap.find(const_cast<Node *>(inputNode));
if (itNodeMap != constructedNodeMap.end())
inputs.insert(inputs.end(), itNodeMap->second.begin(), itNodeMap->second.end());
FunctionPtr input = FromONNXNode(inputNode, constructedNodeMap, graph, computeDevice);
std::vector<FunctionPtr> inputVariables = FromONNXNode(inputNode, constructedNodeMap,
constructedNodeArgVariableMap, graph, computeDevice);
inputs.insert(inputs.end(), inputVariables.begin(), inputVariables.end());
Variable inputVariable = CreateLeafVariableOrConstant(nodeArg, node, graph, computeDevice);
std::string parentONNXOpName = node->OpType();
if (parentONNXOpName == "LSTM")
std::vector<Variable> inputVariables =
CreateRNNLeafVariableOrConstant(nodeArg, node, graph, constructedNodeArgVariableMap, computeDevice);
inputs.insert(inputs.end(), inputVariables.begin(), inputVariables.end());
Variable inputVariable = CreateLeafVariableOrConstant(nodeArg, node, graph, computeDevice);
FunctionPtr cntkFunction = CreateCNTKNode(node, inputs, computeDevice);
constructedNodeMap.insert(ONNXToCNTKMap::value_type(node, cntkFunction));
return cntkFunction;
const Node *parentNode;
int childIndex;
std::tie<const Node *, int>(parentNode, childIndex) = FindParentAndChildIndex(node);
if (parentNode != nullptr && parentNode->OpType() == "LSTM")
std::vector<FunctionPtr> cntkFunctions = CreateRNNConstantOp(graph, node, parentNode, childIndex, computeDevice);
if (!cntkFunctions.empty())
// TODO: make node map to vector of FunctionPtr
constructedNodeMap.insert(ONNXToCNTKMap::value_type(node, cntkFunctions));
return cntkFunctions;
FunctionPtr cntkFunction = CreateCNTKNode(node, inputs, computeDevice);
constructedNodeMap.insert(ONNXToCNTKMap::value_type(node, std::vector<FunctionPtr>({ cntkFunction })));
return std::vector<FunctionPtr>({ cntkFunction });
FunctionPtr ONNXToCNTKHelper::CreateCNTKNode(const Node *node, const std::vector<Variable> &inputs,
@ -2062,13 +2526,15 @@ FunctionPtr ONNXToCNTK::CreateGraph(ONNXIR::Graph* src, const DeviceDescriptor&
// To use depth-first-traversal, keeps a collection of visited nodes.
ONNXToCNTKMap constructedFunctions;
ONNXToCNTKVariableMap constructedNodeArgVariableMap;
for (Graph::NodeIterator it = src->Nodes_begin(); it != src->Nodes_end(); ++it)
const Node *node = *it;
if (constructedFunctions.find(node) == constructedFunctions.end())
FunctionPtr cntkNode = ONNXToCNTKHelper::FromONNXNode(node, constructedFunctions, src, computeDevice);
std::vector<FunctionPtr> cntkNode = ONNXToCNTKHelper::FromONNXNode(node,
constructedFunctions, constructedNodeArgVariableMap, src, computeDevice);
@ -2083,7 +2549,7 @@ FunctionPtr ONNXToCNTK::CreateGraph(ONNXIR::Graph* src, const DeviceDescriptor&
std::vector<FunctionPtr> functions;
for (Node::NodeConstIterator it = itNodeFn->first->InputNodes_begin(); it != itNodeFn->first->InputNodes_end(); ++it)
functions.insert(functions.end(), constructedFunctions[*it].begin(), constructedFunctions[*it].end());
if (functions.empty())

Просмотреть файл

@ -370,6 +370,9 @@ namespace ONNX
{ L"useStatsAcrossChannels", "across_channels" },
{ L"doVarianceScaling", "normalize_variance" },
} } },
{ L"Embedding",{ {
{ L"Embedding", "Gather" },
} } },
// given a cntkOpName and cntk attribute OpName which is saved in CNTK::Function's attribute,
@ -425,7 +428,15 @@ namespace ONNX
(onnxOpName == "And") || (onnxOpName == "Or") || (onnxOpName == "Xor");
// Reports whether opName is one of the two op names treated as loop ops:
// "PastValue" or "FutureValue".
bool Operators::IsLoopOp(const std::string &opName)
{
    if (opName == "PastValue")
        return true;
    return opName == "FutureValue";
}
// Reports whether opName is one of the recurrent op names:
// "LSTM", "GRU" or "RNN".
bool Operators::IsRNNOp(const std::string &opName)
{
    if (opName == "LSTM")
        return true;
    if (opName == "GRU")
        return true;
    return opName == "RNN";
}
std::unordered_map<std::wstring, std::set<size_t>> Operators::_cntkBlockOPInvalidIndices = {
{ L"Clip",{ 1, 2 } },
{ L"LeakyReLU",{ 0, 1 } },

Просмотреть файл

@ -108,6 +108,9 @@ namespace CNTK
static bool SupportBroadcast(const std::wstring& cntkOpName);
static bool SupportBroadcastONNXOp(const std::string& onnxOpName);
// True when opName is "PastValue" or "FutureValue".
static bool IsLoopOp(const std::string &opName);
// True when opName is "LSTM", "GRU" or "RNN".
static bool IsRNNOp(const std::string &opName);
static std::unordered_multimap<std::wstring, AttributesMapping> _cntkToONNXOpName;
static std::unordered_map<std::wstring, std::set<size_t>> _cntkBlockOPInvalidIndices;

Просмотреть файл

@ -0,0 +1,295 @@
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See file in the project root for full license information.
#include "RNNHelper.h"
std::function<FunctionPtr(const Variable&)> ActivationMap(const std::string &activationName)
if (activationName == "Relu")
return [](const Variable& x) { return ReLU(x); };
else if (activationName == "Tanh")
return [](const Variable& x) { return Tanh(x); };
else if (activationName == "Sigmoid")
return [](const Variable& x) { return Sigmoid(x); };
// else if (activationName == "Affine")
// else if (activationName == "LeakyRelu")
// else else if (activationName == "ThresholdedRelu")
// else else if (activationName == "ScaledTanh")
// else if (activationName == "HardSigmoid")
else if (activationName == "Elu")
return [](const Variable& x) { return ELU(x); };
else if (activationName == "Softsign")
return [](const Variable& x) { return Softsign(x); };
else if (activationName == "Softplus")
return [](const Variable& x) { return Softplus(x); };
CNTK::LogicError("LSTM does not support activation: %s", activationName.c_str());
std::function<FunctionPtr(const Variable&)> ActivationMap(const std::string &activationName,
float activation_alpha)
if (activationName == "LeakyRelu")
return [activation_alpha](const Variable& x) { return LeakyReLU(x, activation_alpha); };
return ActivationMap(activationName);
std::function<FunctionPtr(const Variable&)> ActivationMap(const std::string &activationName,
float activation_alpha, float activation_beta)
if (activationName == "HardSigmoid")
return [activation_alpha, activation_beta](const Variable& x) { return HardSigmoid(x, activation_alpha, activation_beta); };
return ActivationMap(activationName, activation_alpha);
std::tuple<std::function<FunctionPtr(const Variable&)>, std::function<FunctionPtr(const Variable&)>, std::function<FunctionPtr(const Variable&)>>
GetActivations(const std::vector<std::string> &activations, const std::vector<float> &activation_alpha, const std::vector<float> &activation_beta, int direction)
if (activations.size() < (direction + 1) * LSTMActivationCount)
CNTK::LogicError("LSTM activations shall be %d or %d of strings", LSTMActivationCount, LSTMActivationCount * 2);
int iofActivationIndex = direction * LSTMActivationCount + LSTMActivationFIndex;
int cellActivation = direction * LSTMActivationCount + LSTMActivationGIndex;
int hiddenActivationIndex = direction * LSTMActivationCount + LSTMActivationHIndex;
// ONNX spec is not clear on how activation alpha and beta is set.
// Here we assume that if they are set, they are set for all activations, regardless whether
// an activation needs those values or not.
bool hasAlpha = activation_alpha.size() == (direction + 1) * LSTMActivationCount;
bool hasBeta = hasAlpha && activation_beta.size() == (direction + 1) * LSTMActivationCount;
std::function<FunctionPtr(const Variable&)> iofActivationOp, cellActivationOp, hiddenActivationOp;
if (hasBeta)
iofActivationOp = ActivationMap(activations[iofActivationIndex], activation_alpha[iofActivationIndex], activation_beta[iofActivationIndex]);
cellActivationOp = ActivationMap(activations[cellActivation], activation_alpha[cellActivation], activation_beta[cellActivation]);
hiddenActivationOp = ActivationMap(activations[hiddenActivationIndex], activation_alpha[hiddenActivationIndex], activation_beta[hiddenActivationIndex]);
else if (hasAlpha)
iofActivationOp = ActivationMap(activations[iofActivationIndex], activation_alpha[iofActivationIndex]);
cellActivationOp = ActivationMap(activations[cellActivation], activation_alpha[cellActivation]);
hiddenActivationOp = ActivationMap(activations[hiddenActivationIndex], activation_alpha[hiddenActivationIndex]);
iofActivationOp = ActivationMap(activations[iofActivationIndex]);
cellActivationOp = ActivationMap(activations[cellActivation]);
hiddenActivationOp = ActivationMap(activations[hiddenActivationIndex]);
return std::make_tuple(iofActivationOp, cellActivationOp, hiddenActivationOp);
// Builds one LSTM step (optionally with peephole connections).
//
// input              - the step input.
// iofActivationOp    - activation applied to the input, forget and output gates.
// cellActivationOp   - activation applied to the candidate cell value.
// hiddenActivationOp - activation applied to the cell state before the output gate.
// prevOutput         - previous hidden output; its first dimension is used as
//                      the per-gate slice width.
// prevCellState      - previous cell state.
// W, R               - input and recurrent weights, applied via Times.
// B                  - bias; skipped when not initialized.
// Ci, Cf, Co         - peephole weights; peepholes are used only when Ci is
//                      initialized.
//
// Returns { h, c }: the new hidden output and the new cell state.
std::pair<FunctionPtr, FunctionPtr> LSTMPCell(Variable input,
    const std::function<FunctionPtr(const Variable&)> &iofActivationOp,
    const std::function<FunctionPtr(const Variable&)> &cellActivationOp,
    const std::function<FunctionPtr(const Variable&)> &hiddenActivationOp,
    Variable prevOutput, Variable prevCellState,
    Constant &W, Constant &R, Constant &B, Constant &Ci, Constant &Cf, Constant &Co)
{
    const int stacked_dim = (int)prevOutput.Shape()[0];

    // The bias is optional; the unbiased projection must not overwrite the
    // biased one, hence the explicit else.
    FunctionPtr proj4;
    if (B.IsInitialized())
        proj4 = Plus(Plus(B, Times(W, input)), Times(R, prevOutput));
    else
        proj4 = Plus(Times(W, input), Times(R, prevOutput));

    // CNTK weight and bias are in icfo order.
    std::vector<Axis> stack_axis({ Axis(-1) });
    const int IGateIndex = 0, CGateIndex = 1, FGateIndex = 2, OGateIndex = 3;

    FunctionPtr it_proj = Slice(proj4, stack_axis, { IGateIndex * stacked_dim }, { (IGateIndex + 1) * stacked_dim });
    FunctionPtr bit_proj = Slice(proj4, stack_axis, { CGateIndex * stacked_dim }, { (CGateIndex + 1) * stacked_dim });
    FunctionPtr ft_proj = Slice(proj4, stack_axis, { FGateIndex * stacked_dim }, { (FGateIndex + 1) * stacked_dim });
    FunctionPtr ot_proj = Slice(proj4, stack_axis, { OGateIndex * stacked_dim }, { (OGateIndex + 1) * stacked_dim });

    const bool hasPeephole = Ci.IsInitialized();

    // The gate activation requested by the caller is used on both the
    // peephole and the non-peephole path.

    // Input gate
    auto it = hasPeephole ? iofActivationOp(it_proj + ElementTimes(Ci, prevCellState)) : iofActivationOp(it_proj);
    auto bit = ElementTimes(it, cellActivationOp(bit_proj));

    // Forget gate
    auto ft = hasPeephole ? iofActivationOp(ft_proj + ElementTimes(Cf, prevCellState)) : iofActivationOp(ft_proj);
    auto bft = ElementTimes(ft, prevCellState);

    // New cell state
    auto ct = Plus(bft, bit);

    // Output gate (its peephole looks at the new cell state)
    auto ot = hasPeephole ? iofActivationOp(ot_proj + ElementTimes(Co, ct)) : iofActivationOp(ot_proj);
    auto ht = ElementTimes(ot, hiddenActivationOp(ct));

    auto c = ct;
    auto h = ht;

    return{ h, c };
}
#include "PrimitiveFunction.h"
#include "BlockFunction.h"
// Builds a recurrent LSTM over `input`.
//
// A single LSTM step is created on placeholders for the input, the previous
// hidden output and the previous cell state. The step outputs are passed
// through recurrenceHookH / recurrenceHookC and the placeholders are then
// replaced, which closes the recurrence loop.
//
// cellShape - shape used for the hidden-output and cell-state placeholders.
// The remaining parameters are forwarded to LSTMPCell.
//
// Returns (hidden output, cell state) of the step function.
std::tuple<FunctionPtr, FunctionPtr> LSTMPComponent(Variable input,
    const NDShape& cellShape,
    const std::function<FunctionPtr(const Variable&)> &iofActivationOp,
    const std::function<FunctionPtr(const Variable&)> &cellActivationOp,
    const std::function<FunctionPtr(const Variable&)> &hiddenActivationOp,
    const std::function<FunctionPtr(const Variable&)>& recurrenceHookH,
    const std::function<FunctionPtr(const Variable&)>& recurrenceHookC,
    Constant &W, Constant &R, Constant &B,
    Constant &Ci, Constant &Cf, Constant &Co)
{
    auto dh = PlaceholderVariable(cellShape, input.DynamicAxes());
    auto dc = PlaceholderVariable(cellShape, input.DynamicAxes());
    auto inputPlaceholder = PlaceholderVariable(input.Shape(), input.DynamicAxes());

    // The cell is built on inputPlaceholder so that the replacement below can
    // substitute the real input.
    auto LSTMCell = LSTMPCell(
        inputPlaceholder,
        iofActivationOp, cellActivationOp, hiddenActivationOp,
        dh, dc, W, R, B, Ci, Cf, Co);

    auto actualDh = recurrenceHookH(LSTMCell.first);
    auto actualDc = recurrenceHookC(LSTMCell.second);

    // Form the recurrence loop by replacing the dh and dc placeholders with the actualDh and actualDc
    LSTMCell.first->ReplacePlaceholders({ { inputPlaceholder, input }, { dh, actualDh }, { dc, actualDc } });

    return std::make_tuple(LSTMCell.first, LSTMCell.second);
}
const std::vector<Variable> FindByNameHint(const std::vector<Variable> &inputs, const std::string &hint)
std::vector<Variable> variables;
for (auto v : inputs)
if (ToString(v.Name()).find(hint) != -1)
return variables;
// Picks the initial-state variable for an LSTM from `inputs`.
//
// Candidates are the inputs whose name contains `nameHint`. With one direction
// the first candidate is used; with two directions the second candidate is
// used (only when at least two exist). When no candidate qualifies, a scalar
// zero constant of the requested data type is returned instead.
// NOTE(review): there is no per-direction argument, so both directions of a
// bidirectional LSTM receive the same (second) candidate - confirm intended.
Variable GetInitialStateVariable(const std::vector<Variable> &inputs, int numDirections,
    const std::string &nameHint, DataType datatype)
{
    const std::vector<Variable> candidates = FindByNameHint(inputs, nameHint);

    if (numDirections == 1 && candidates.size() >= 1)
        return candidates[0];

    if (numDirections == 2 && candidates.size() >= 2)
        return candidates[1];

    // Fallback: zero scalar in the precision of the data.
    Variable zeroState = datatype == DataType::Double ? Constant::Scalar(0.0) : Constant::Scalar(0.0f);
    return zeroState;
}
// Builds the CNTK function for an ONNX LSTM node.
//
// node             - the ONNX node; its name is used for the spliced output.
// inputs           - X, then numDirections weights W, then numDirections
//                    recurrent weights R, followed by optional inputs that are
//                    located by name hint (bias, initial states, peepholes).
// direction        - compared against LSTMDirectionBidirection and
//                    LSTMDirectionReverse; any other value runs forward.
// activations, activation_alpha, activation_beta - see GetActivations.
//
// Returns the hidden output of the single LSTM, or for a bidirectional LSTM
// the splice of both hidden outputs along Axis(0).
FunctionPtr CreateLSTM(const ONNXIR::Node *node, const std::vector<Variable> &inputs, const std::string &direction,
    const std::vector<std::string> &activations, const std::vector<float> &activation_alpha, const std::vector<float> &activation_beta)
{
    const int numDirections = direction == LSTMDirectionBidirection ? 2 : 1;
    std::vector<FunctionPtr> outputHs;
    for (int dir = 0; dir < numDirections; dir++)
    {
        std::function<FunctionPtr(const Variable&)> iofActivationOp, cellActivationOp, hiddenActivationOp;
        std::tie(iofActivationOp, cellActivationOp, hiddenActivationOp) =
            GetActivations(activations, activation_alpha, activation_beta, dir);

        // The first few inputs are (in order): X, numDirections * W, numDirections * R
        Variable X = inputs[0];
        Variable W = inputs[1 + dir];
        Variable R = inputs[1 + numDirections + dir];

        // Optional bias: one per direction, selected by the current direction.
        Variable B;
        std::vector<Variable> biasVariables = FindByNameHint(inputs, LSTMInputBiasNameHint);
        if (numDirections == 1 && biasVariables.size() >= 1)
            B = biasVariables[0];
        else if (numDirections == 2 && biasVariables.size() == 2)
            B = biasVariables[dir];

        // Initial hidden state comes from the "initial_h" input and initial
        // cell state from the "initial_c" input.
        // NOTE(review): GetInitialStateVariable has no direction argument, so
        // both directions of a bidirectional LSTM get the same initial state.
        Variable initHVariable = GetInitialStateVariable(inputs, numDirections, LSTMInputInitialHNameHint, X.GetDataType());
        Variable initCVariable = GetInitialStateVariable(inputs, numDirections, LSTMInputInitialCNameHint, X.GetDataType());

        // Optional peepholes: LSTMPeepholeCount per direction.
        // NOTE(review): assumes the peephole inputs are ordered by direction,
        // like W and R - confirm against the code that creates them.
        std::vector<Variable> peepholeVariables = FindByNameHint(inputs, LSTMInputPeepholeNameHint);
        Variable Ci, Cf, Co;
        if (peepholeVariables.size() != 0 && peepholeVariables.size() != LSTMPeepholeCount && peepholeVariables.size() != 2 * LSTMPeepholeCount)
        {
            CNTK::LogicError("Peephole Variable count (%d) should be 0, 1 or 2 times the number of peephole factors (%d).",
                (int)(peepholeVariables.size()), (int)LSTMPeepholeCount);
        }
        else if (numDirections == 1 && peepholeVariables.size() >= LSTMPeepholeCount)
        {
            Ci = peepholeVariables[LSTMPeepholeCountCiIndex];
            Co = peepholeVariables[LSTMPeepholeCountCoIndex];
            Cf = peepholeVariables[LSTMPeepholeCountCfIndex];
        }
        else if (numDirections == 2 && peepholeVariables.size() == numDirections * LSTMPeepholeCount)
        {
            const size_t peepholeOffset = static_cast<size_t>(dir) * LSTMPeepholeCount;
            Ci = peepholeVariables[peepholeOffset + LSTMPeepholeCountCiIndex];
            Co = peepholeVariables[peepholeOffset + LSTMPeepholeCountCoIndex];
            Cf = peepholeVariables[peepholeOffset + LSTMPeepholeCountCfIndex];
        }

        // ONNX spec
        // tells that weight has shape [num_directions, 4*hidden_size, input_size]
        // here in CNTK, there is no direction axis because CNTK treats bidirectional LSTM
        // as two separate LSTM. Therefore we can divide the dimension of the first axis
        // by 4 to get the hidden size.
        const size_t hiddenDim = W.Shape()[0] / 4;

        FunctionPtr outputH;
        FunctionPtr outputC;

        // If it is a bidirectional LSTM, the second one will be the backward one.
        const bool go_backwards = direction == LSTMDirectionReverse || (numDirections == 2 && dir == 1);

        // H and C each get their own hook, seeded with their own initial
        // state; the direction only selects FutureValue vs PastValue.
        std::function<FunctionPtr(const Variable&)> recurrenceHookH, recurrenceHookC;
        if (go_backwards)
        {
            recurrenceHookH = [initHVariable](const Variable& x) { return FutureValue(x, initHVariable); };
            recurrenceHookC = [initCVariable](const Variable& x) { return FutureValue(x, initCVariable); };
        }
        else
        {
            recurrenceHookH = [initHVariable](const Variable& x) { return PastValue(x, initHVariable); };
            recurrenceHookC = [initCVariable](const Variable& x) { return PastValue(x, initCVariable); };
        }

        std::tie(outputH, outputC) = LSTMPComponent(
            X, { hiddenDim }, iofActivationOp, cellActivationOp, hiddenActivationOp,
            recurrenceHookH, recurrenceHookC, (Constant &)W, (Constant &)R, (Constant &)B,
            (Constant &)Ci, (Constant &)Cf, (Constant &)Co);

        outputHs.push_back(outputH);
    }

    if (outputHs.size() == 1)
        return outputHs[0];

    std::vector<Variable> operands({ outputHs[0], outputHs[1] });
    return Splice(operands, Axis(0), ToWString(node->Name()));
}

Просмотреть файл

@ -0,0 +1,77 @@
#pragma once
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See file in the project root for full license information.
// originally from CNTK\Tests\EndToEndTests\CNTKv2Library\Common\Common.h
#pragma once
#include "stdafx.h"
#include "CNTKLibrary.h"
#include "Utils.h"
#include "proto/onnx/core/model.h"
#include <algorithm>
#include "CNTKLibrary.h"
#include <functional>
using namespace CNTK;
using namespace ONNXIR;
// Substrings looked for in input variable names (see CreateLSTM in
// RNNHelper.cpp) to locate the optional LSTM inputs: bias, initial hidden
// state, initial cell state and peephole weights.
const std::string LSTMInputBiasNameHint = "_bias_";
const std::string LSTMInputInitialHNameHint = "_initial_h_";
const std::string LSTMInputInitialCNameHint = "_initial_c_";
const std::string LSTMInputPeepholeNameHint = "_peephole_";
// NOTE(review): several enum declaration lines of this header are not visible
// in this view; the comments below describe each enumerator group only.

// Input positions of an LSTM node (X, W, H, B, sequence lengths, initial_h,
// initial_c, P). Presumably the ONNX LSTM input order - TODO confirm.
LSTMInputIndexX = 0,
LSTMInputIndexW = 1,
LSTMInputIndexH = 2,
LSTMInputIndexB = 3,
LSTMInputIndexSequenceLens = 4,
LSTMInputIndexinitial_h = 5,
LSTMInputIndexinitial_c = 6,
LSTMInputIndexP = 7
// Offsets of the f, g and h activations within one direction's slice of the
// activation lists, and the number of activations per direction
// (used by GetActivations).
LSTMActivationFIndex = 0,
LSTMActivationGIndex = 1,
LSTMActivationHIndex = 2,
LSTMActivationCount = 3
// Offsets of the Ci, Co and Cf peephole variables within one direction's
// group, and the number of peephole variables per direction
// (used by CreateLSTM). Note the order is i, o, f.
enum {
LSTMPeepholeCountCiIndex = 0,
LSTMPeepholeCountCoIndex = 1,
LSTMPeepholeCountCfIndex = 2,
LSTMPeepholeCount = 3
// LSTMDirection enumeration; its enumerators are not visible in this view.
typedef enum {
} LSTMDirection;
// Positions of bias, weight and hidden weight.
// Presumably indices into a CNTK LSTM block's inputs - TODO confirm.
CNTKLSTMBiasIndex = 0,
CNTKLSTMWeightIndex = 1,
CNTKLSTMHiddenWeightIndex = 2
// Positions of the Yh and Ch outputs.
// Presumably indices into a CNTK LSTM block's outputs - TODO confirm.
CNTKLSTMOutputYhIndex = 0,
CNTKLSTMOutputChIndex = 1
// Direction strings. CreateLSTM compares its `direction` argument against the
// bidirectional and reverse values; anything else is processed forward.
const string LSTMDirectionBidirection = "bidirectional";
const string LSTMDirectionReverse = "reverse";
const string LSTMDirectionForward = "forward";
// Builds the CNTK function for an LSTM node.
//   node             - source node; its name is given to the spliced output.
//   inputs           - X, the per-direction W and R, then optional inputs that
//                      are found by name hint (bias, initial states, peepholes).
//   direction        - one of the LSTMDirection* strings.
//   activations      - activation names, LSTMActivationCount per direction.
//   activation_alpha - optional alpha values for the activations.
//   activation_beta  - optional beta values for the activations.
// Returns the hidden output for a single direction, or the two hidden outputs
// spliced along Axis(0) for a bidirectional LSTM.
FunctionPtr CreateLSTM(const ONNXIR::Node *node, const std::vector<Variable> &inputs, const std::string &direction,
const std::vector<std::string> &activations, const std::vector<float> &activation_alpha, const std::vector<float> &activation_beta);

Просмотреть файл

@ -59,6 +59,9 @@ def verify_one_input(model, data, tmpdir, name):
o0 = model.eval({model.arguments[0]:data})
o1 = loaded_model.eval({loaded_model.arguments[0]:data})
if (type(o0) is list):
o0 = o0[0]
assert np.allclose(o0, o1)
validation_filename = os.path.join(str(tmpdir), name + R'_validation.onnx')
@ -515,6 +518,61 @@ def test_LRN(tmpdir):
model = C.local_response_normalization(x_r, 2, 1.0, 0.0001, 0.75)
verify_one_input(model, img, tmpdir, 'LRN_1')
from cntk.layers import *
from itertools import product
# Builds the model under test and returns it wrapped in a Sequential container.
# NOTE(review): the parameter list and the layer expression are only partly
# visible here; the keyword arguments suggest an LSTM step (use_peepholes,
# activation, enable_self_stabilization) inside a recurrence that takes
# initial_state - confirm against the complete file.
def CreateLSTMModel(activation,
return Sequential([
use_peepholes = peepholes,
activation = activation,
enable_self_stabilization = self_stabilization),
initial_state = initial_state)
# Option lists that test_LSTM combines with itertools.product; every list
# currently holds a single value, so exactly one configuration is exercised.

# lstm attributes
use_peepholes_options = [False]
enable_self_stabilization_options = [False]
activation_options = [C.tanh]
#Recurrence attributes
initial_state_options = [0]
# Sizes used by test_LSTM: input_dim and cell_dim configure the model,
# batch_size and sequence_len shape the random input data.
input_dim = 2
cell_dim = 3
batch_size = 1
sequence_len = 5
def MakeLSTMNameFromConfig(use_peepholes, enable_self_stabilization, initial_state, activtion):
    """Build a model name that encodes one LSTM test configuration.

    The name starts with 'LSTM.' followed by the activation's ``__name__``.
    A suffix is appended for each enabled option, in this order:
    '.peephole' when use_peepholes is truthy, '.stabilize' when
    enable_self_stabilization is truthy, and '.initial' when initial_state
    is not 0.
    """
    model_name = 'LSTM.' + activtion.__name__
    if use_peepholes:
        model_name += '.peephole'
    # Only tag the name when self-stabilization is actually enabled.
    if enable_self_stabilization:
        model_name += '.stabilize'
    if initial_state != 0:
        model_name += '.initial'
    return model_name
def test_LSTM(tmpdir):
    """Export/import round trip for every combination of the LSTM option lists."""
    option_grid = product(use_peepholes_options, enable_self_stabilization_options,
                          initial_state_options, activation_options)
    for use_peepholes, enable_self_stabilization, initial_state, activation in option_grid:
        # The file name encodes the configuration under test.
        model_filename = MakeLSTMNameFromConfig(use_peepholes, enable_self_stabilization,
                                                initial_state, activation)

        sequence_axes = [Axis.default_batch_axis(), C.Axis('sequenceAxis')]
        x = C.input_variable(input_dim, dynamic_axes=sequence_axes)

        model_factory = CreateLSTMModel(peepholes = use_peepholes,
                                        activation = activation,
                                        initial_state = initial_state,
                                        cell_dim = cell_dim,
                                        self_stabilization = enable_self_stabilization)
        LSTMmodel = model_factory(x)

        data_shape = (batch_size, sequence_len, input_dim)
        data = np.random.uniform(low=0.0, high=1.0, size=data_shape).astype('f')
        verify_one_input(LSTMmodel, data, tmpdir, model_filename)
def test_MatMul(tmpdir):
data0 = np.asarray([[1,2],[3,4]], dtype=np.float32)