ONNX LSTM support

Parent: 6a8773bcf2
Commit: 7651b0d567

Makefile | 3
@@ -531,7 +531,8 @@ CNTKLIBRARY_COMMON_SRC =\
 	$(SOURCEDIR)/CNTKv2LibraryDll/proto/onnx/core/graph.cpp \
 	$(SOURCEDIR)/CNTKv2LibraryDll/proto/onnx/core/model.cpp \
 	$(SOURCEDIR)/CNTKv2LibraryDll/proto/onnx/Operators.cpp \
+	$(SOURCEDIR)/CNTKv2LibraryDll/proto/onnx/RNNHelper.cpp \
 	$(SOURCEDIR)/CNTKv2LibraryDll/proto/onnx/CNTKToONNX.cpp \
 	$(SOURCEDIR)/CNTKv2LibraryDll/proto/onnx/ONNXToCNTK.cpp \
 	$(SOURCEDIR)/CNTKv2LibraryDll/proto/onnx/ONNX.cpp \
@@ -1975,6 +1975,13 @@ namespace CNTK
         [](const Axis& axis) { return (axis == Axis::DefaultBatchAxis()); });
     }

+    bool HasSequenceAxis() const {
+        return (DynamicAxes().size() - (HasBatchAxis() ? 1 : 0)) > 0;
+    }
+
+    bool IsInitialized() const {
+        return m_dataFields != nullptr;
+    }
     ///
     /// Returns the name of 'this' variable
     ///
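Note: the new HasSequenceAxis() helper simply counts dynamic axes beyond the batch axis. A minimal usage sketch (variable names are illustrative, not from this commit):

    // A sequence input carries two dynamic axes (sequence + batch);
    // a plain batch input carries only the batch axis.
    Variable seqInput = InputVariable({ 3 }, DataType::Float, L"seq",
        { Axis::DefaultDynamicAxis(), Axis::DefaultBatchAxis() });
    Variable batchInput = InputVariable({ 3 }, DataType::Float, L"batch",
        { Axis::DefaultBatchAxis() });
    // seqInput.HasSequenceAxis() == true, batchInput.HasSequenceAxis() == false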
@@ -187,6 +187,7 @@
     <ClInclude Include="proto\onnx\ONNX.h" />
     <ClInclude Include="proto\onnx\ONNXToCNTK.h" />
     <ClInclude Include="proto\onnx\Operators.h" />
+    <ClInclude Include="proto\onnx\RNNHelper.h" />
     <ClInclude Include="Serialization.h" />
     <ClInclude Include="tensorboard\TensorBoardUtils.h" />
     <ClInclude Include="UserDefinedFunction.h" />
@@ -243,6 +244,7 @@
     <ClCompile Include="proto\onnx\ONNXToCNTK.cpp" />
     <ClCompile Include="proto\onnx\Operators.cpp" />
     <ClCompile Include="proto\onnx\protobuf\onnx-ml.pb.cc.VS_wrapper.cpp" />
+    <ClCompile Include="proto\onnx\RNNHelper.cpp" />
     <ClCompile Include="Serialization.cpp" />
     <ClCompile Include="stdafx.cpp">
       <PrecompiledHeader>Create</PrecompiledHeader>
@@ -108,6 +108,9 @@
     <ClCompile Include="proto\onnx\protobuf\onnx-ml.pb.cc.VS_wrapper.cpp">
       <Filter>proto\onnx\protobuf</Filter>
     </ClCompile>
+    <ClCompile Include="proto\onnx\RNNHelper.cpp">
+      <Filter>proto\onnx</Filter>
+    </ClCompile>
   </ItemGroup>
   <ItemGroup>
     <ClInclude Include="stdafx.h" />
@@ -190,6 +193,9 @@
     <ClInclude Include="API\CNTKLibraryC.h">
       <Filter>API</Filter>
     </ClInclude>
+    <ClInclude Include="proto\onnx\RNNHelper.h">
+      <Filter>proto\onnx</Filter>
+    </ClInclude>
   </ItemGroup>
   <ItemGroup>
     <Filter Include="API">
@@ -252,4 +258,4 @@
       <Filter>proto\onnx\protobuf</Filter>
     </Proto>
   </ItemGroup>
-</Project>
+</Project>
(The diff for one file is not shown here because of its large size.)
@@ -9,6 +9,7 @@
 #include "Operators.h"
 #include <algorithm>
 #include <iostream>
+#include "RNNHelper.h"

 using namespace ONNXIR;
 using namespace CNTK;
@@ -17,14 +18,16 @@ using namespace CNTK::ONNX;
 namespace CNTK
 {

-    typedef std::unordered_map<const Node *, FunctionPtr> ONNXToCNTKMap;
+    typedef std::unordered_map<const Node *, std::vector<FunctionPtr>> ONNXToCNTKMap;
+    typedef std::unordered_map<std::string, Variable> ONNXToCNTKVariableMap;
     class ONNXToCNTKHelper
     {
     public:
         //
         // Convert an ONNX graph to a CNTK graph (Function).
         //
-        static FunctionPtr FromONNXNode(const Node *node, ONNXToCNTKMap &constructedNodeMap,
+        static std::vector<FunctionPtr> FromONNXNode(const Node *node, ONNXToCNTKMap &constructedNodeMap,
+            ONNXToCNTKVariableMap &constructedNodeArgVariableMap,
             const Graph* graph, const DeviceDescriptor& computeDevice);

     private:
@@ -37,6 +40,9 @@ namespace CNTK
             const DeviceDescriptor& computeDevice);
         static Variable CreateLeafVariableOrConstant(const NodeArg *nodeArg, const Node *parentNode, const Graph *graph,
             const DeviceDescriptor& computeDevice);
+        static std::vector<Variable> CreateRNNLeafVariableOrConstant(const NodeArg *nodeArg,
+            const Node *parentNode, const Graph* graph,
+            ONNXToCNTKVariableMap &constructedNodeArgVariableMap, const DeviceDescriptor& computeDevice);
         static FunctionPtr CreateFunction(const Node *node, const std::vector<Variable> &inputs);

         static bool IsSecondInputOfElementWiseOpsWithBroadcast(const Node *parentNode, const NodeArg *nodeArg);
@@ -100,14 +106,17 @@ namespace CNTK
         static string GetNamedAttributeAsString(const Node *node, const string &attributeName);
         static string GetNamedAttributeAsString(const Node *node, const string &attributeName, const string& defaultValue);

+        static std::vector<std::string> GetNamedAttributeAsStringVec(const Node *node, const string &attributeName,
+            const std::vector<std::string> &defaultValues);
+
         static std::vector<int64_t> GetNamedAttributeAsInt64Vec(const Node *node, const string &attributeName);
         static std::vector<int64_t> GetNamedAttributeAsInt64Vec(const Node *node, const string &attributeName, const std::vector<int64_t>& defaultValue);

         static std::vector<float> GetNamedAttributeAsFloatVec(const Node *node, const string &attributeName);
         static std::vector<float> GetNamedAttributeAsFloatVec(const Node *node, const string &attributeName, const std::vector<float>& defaultValue);

-        static Axis ConvertAxisToCNTKCppApi(const Axis& axes, const Variable& input);
-        static std::vector<Axis> ConvertAxesToCNTKCppApi(const std::vector<Axis> &axes, const Variable& operand);
+        static Axis ConvertONNXAxisToCNTKCppApi(int64_t axes, const Variable& input);
+        static std::vector<Axis> ConvertONNXAxesToCNTKCppApi(const std::vector<int64_t> &axes, const Variable& operand);

         static void AdjustAutoPaddingAndStrideForCNTKSpecialCases(const Variable &operand,
             std::vector<bool> &autoPadding, NDShape &strides);
@@ -547,18 +556,344 @@ bool ONNXToCNTKHelper::FixConstantShapeForConstantVariableInputPair(const std::v
     return true;
 }

-Variable ONNXToCNTKHelper::CreateLeafVariableOrConstant(const NodeArg *nodeArg,
-    const Node *parentNode, const Graph* graph, const DeviceDescriptor& computeDevice)
+int CalculateNodeArgInputIndex(const NodeArg *nodeArg, const Node *parentNode)
 {
-    std::string nodeName = nodeArg->Name();
+    std::vector<NodeArg>::const_iterator it = std::find_if(parentNode->InputDefs().cbegin(),
+        parentNode->InputDefs().cend(), [nodeArg](const NodeArg &other) { return other.Name() == nodeArg->Name(); });
+    if (it == parentNode->InputDefs().cend())
+    {
+        return -1;
+    }
+    else
+        return it - parentNode->InputDefs().cbegin();
+}
+template<typename DType>
+Constant CreateConstantWithRawData(DType *data, const NDShape &shape, const std::string &name,
+    const DeviceDescriptor& computeDevice)
+{
+    DataType dataType = AsDataType<DType>();
+
+    int totalSize = shape.TotalSize();
+    NDArrayViewPtr dstFinal(new NDArrayView(dataType, shape, data,
+        totalSize * sizeof(DType), computeDevice.CPUDevice()));
+
+    if (computeDevice.Type() == DeviceKind::CPU)
+    {
+        Constant constantVariable(dstFinal, ToWString(name));
+        return constantVariable;
+    }
+    else
+    {
+        // This is the way to load values onto a GPU:
+        // create a GPU NDArrayView and CopyFrom a CPU NDArrayView that holds the data.
+        NDArrayViewPtr dstFinalGPU(new NDArrayView(dataType, StorageFormat::Dense, shape, computeDevice));
+        dstFinalGPU->CopyFrom(*dstFinal);
+        Constant constantVariable(dstFinalGPU, ToWString(name));
+        return constantVariable;
+    }
+}
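Note: CreateConstantWithRawData follows the usual CNTK pattern for loading raw host memory onto a device: wrap the host buffer in a CPU NDArrayView, allocate a dense device view, then CopyFrom. A hedged sketch of the same pattern in isolation (buffer, shape, and names are illustrative):

    float hostData[] = { 1.0f, 2.0f, 3.0f, 4.0f };
    NDShape shape({ 2, 2 });
    NDArrayViewPtr cpuView(new NDArrayView(DataType::Float, shape, hostData,
        4 * sizeof(float), DeviceDescriptor::CPUDevice()));
    NDArrayViewPtr gpuView(new NDArrayView(DataType::Float, StorageFormat::Dense, shape,
        DeviceDescriptor::GPUDevice(0)));   // assumes a GPU device 0 is available
    gpuView->CopyFrom(*cpuView);            // copies the host values onto the device
    Constant weights(gpuView, L"weights");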
+std::vector<Variable> CreateRNNConstant(
+    const Node *parentNode, int index, const std::string &name, onnx::TensorProto &valueProto, const DeviceDescriptor& computeDevice)
+{
+    std::vector<Variable> inputs;
+    string parentONNXOpName = parentNode->OpType();
+    auto dataType = valueProto.data_type();
+
+    switch (dataType)
+    {
+    case TensorProto_DataType_FLOAT:
+    {
+        if (valueProto.float_data().empty())
+        {
+            RetrieveRawDataAsFloat(valueProto);
+        }
+        break;
+    }
+    case TensorProto_DataType_DOUBLE:
+    {
+        if (valueProto.double_data().empty())
+        {
+            RetrieveRawDataAsDouble(valueProto);
+        }
+        break;
+    }
+    }
+
+    // index to LSTM inputs as specified in the ONNX document:
+    // https://github.com/onnx/onnx/blob/master/docs/Operators.md#inputs-3---8
+    if (parentONNXOpName == "LSTM")
+    {
+        switch (index)
+        {
+        case LSTMInputIndexX:
+            // X, should not come to here
+            return inputs;
+        case LSTMInputIndexW:
+        case LSTMInputIndexH:
+            // W, R:
+        {
+            // see the ONNX spec for the tensor shape
+            int num_directions = valueProto.dims(0);
+            size_t rows = valueProto.dims(1);
+            size_t cols = valueProto.dims(2);
+
+            // CNTK cpp requires the shape to be (input_size, 4 * hidden_size)
+            NDShape weightShape({ rows, cols });
+
+            int input_size = cols;
+            int cell_size = rows / 4;
+
+            for (int dir = 0; dir < num_directions; dir++)
+            {
+                std::string nodeName = name + (index == 1 ? "_W_" : "_R_") + (char)dir;
+                int totalSizePerDirection = rows * cols;
+
+                // TODO: what about double?
+                float *data = new float[totalSizePerDirection];
+                for (size_t count = 0; count < totalSizePerDirection; count++)
+                {
+                    int row = count / input_size;
+                    int col = count % input_size;
+                    int block = row / cell_size;
+
+                    if (block == 1)
+                    {
+                        // o
+                        row += cell_size * 2;
+                    }
+                    else if (block == 3)
+                    {
+                        // c
+                        row -= cell_size * 2;
+                    }
+
+                    int sourceIndex = dir * totalSizePerDirection + count;
+                    int targetIndex = col * cell_size * 4 + row;
+                    data[targetIndex] = valueProto.float_data()[sourceIndex];
+                }
+
+                Constant constant = CreateConstantWithRawData(&data[0], weightShape, nodeName, computeDevice);
+                inputs.push_back(constant);
+            }
+            return inputs;
+        }
+        case LSTMInputIndexB:
+            // B
+        {
+            // see the ONNX spec for the tensor shape
+            int num_directions = valueProto.dims(0);
+            int cell_size = valueProto.dims(1) / 8;
+            // There is an ONNX spec issue with the bias input. It states that
+            // "This tensor has shape `[num_directions, 8*hidden_size]`", which means
+            // the hidden and input biases are applied separately after the weights.
+            // In CNTK, bias is applied in fused form, after hidden and input
+            // are element-wise added. In this case
+            // the bias shape is [num_directions, 4*hidden_size].
+            NDShape weightShape({ (size_t)(4 * cell_size) });
+            for (int dir = 0; dir < num_directions; dir++)
+            {
+                std::string nodeName = name + std::string(1, (char)dir) + LSTMInputBiasNameHint;
+                int totalSizePerDirection = 4 * cell_size;
+                float *data = new float[totalSizePerDirection];
+                for (size_t targetIndex = 0; targetIndex < totalSizePerDirection; targetIndex++)
+                {
+                    int row = targetIndex;
+
+                    // TODO: specific to LSTM. icfo (CNTK) to iofc (ONNX)
+                    int block = row / cell_size;
+                    if (block == 1)
+                    {
+                        // c
+                        row += 2 * cell_size;
+                    }
+                    else if (block == 3)
+                    {
+                        // o
+                        row -= 2 * cell_size;
+                    }
+
+                    // source is column major
+                    int src_index = row;
+                    // "fuse" the input bias (Wb) and the hidden bias (Rb) by adding them
+                    data[targetIndex] =
+                        valueProto.float_data()[dir * 2 * totalSizePerDirection + src_index] +
+                        valueProto.float_data()[dir * 2 * totalSizePerDirection + totalSizePerDirection + src_index];
+                }
+
+                Constant constant = CreateConstantWithRawData(data, weightShape, nodeName, computeDevice);
+                inputs.push_back(constant);
+            }
+            return inputs;
+        }
+        case LSTMInputIndexSequenceLens:
+            // sequence length is treated as a free dimension
+            return inputs;
+        case LSTMInputIndexinitial_h:
+        case LSTMInputIndexinitial_c:
+        {
+            // initial_h, initial_c
+            int num_directions = valueProto.dims(0);
+
+            // TODO: batch shall be one?
+            // int batchSize = valueProto.dims(1);
+            int cell_size = valueProto.dims(2);
+
+            NDShape weightShape({ (size_t)(cell_size) });
+            for (int dir = 0; dir < num_directions; dir++)
+            {
+                std::string nodeName = name + std::string(1, (char)dir);
+                if (index == LSTMInputIndexinitial_h)
+                    nodeName += LSTMInputInitialHNameHint;
+                else
+                    nodeName += LSTMInputInitialCNameHint;
+
+                float *data = new float[cell_size];
+                for (size_t targetIndex = 0; targetIndex < cell_size; targetIndex++)
+                {
+                    data[targetIndex] = valueProto.float_data()[dir * cell_size + targetIndex];
+                }
+
+                Constant constant = CreateConstantWithRawData(data, weightShape, nodeName, computeDevice);
+                inputs.push_back(constant);
+            }
+            return inputs;
+        }
+        case LSTMInputIndexP:
+            // P
+        {
+            int num_directions = valueProto.dims(0);
+            int cell_size = valueProto.dims(1) / 3;
+            for (int dir = 0; dir < num_directions; dir++)
+                for (int i = 0; i < 3; i++)
+                {
+                    std::string nodeName = name + ((i == 0) ? "_i" : ((i == 1) ? "_o" : "_f")) +
+                        std::string(1, (char)dir) + LSTMInputPeepholeNameHint;
+                    float *data = new float[cell_size];
+                    NDShape weightShape({ (size_t)(cell_size) });
+                    for (size_t targetIndex = 0; targetIndex < cell_size; targetIndex++)
+                    {
+                        data[targetIndex] = valueProto.float_data()[(dir * 3 + i) * cell_size + targetIndex];
+                    }
+
+                    Constant constant = CreateConstantWithRawData(data, weightShape, nodeName, computeDevice);
+                    inputs.push_back(constant);
+                }
+            return inputs;
+        }
+        default:
+            CNTK::LogicError("CreateRNNConstant received unexpected index: %d", index);
+        }
+    }
+    else
+    {
+        NOT_IMPLEMENTED;
+    }
+}
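Note: the row arithmetic in CreateRNNConstant re-lays ONNX's iofc gate blocks into CNTK's icfo order (and, for B, fuses ONNX's separate Wb/Rb bias halves by adding them). A worked sketch of the row remapping under the same block convention (hypothetical helper, cell_size = 2):

    // ONNX iofc row blocks: i{0,1} o{2,3} f{4,5} c{6,7}
    // CNTK icfo row blocks: i{0,1} c{2,3} f{4,5} o{6,7}
    int MapONNXRowToCNTKRow(int row, int cell_size)
    {
        int block = row / cell_size;
        if (block == 1)           // o: moves down past c and f
            return row + 2 * cell_size;
        else if (block == 3)      // c: moves up past o and f
            return row - 2 * cell_size;
        return row;               // i and f stay in place
    }
    // MapONNXRowToCNTKRow(2, 2) == 6 and MapONNXRowToCNTKRow(6, 2) == 2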
+std::vector<FunctionPtr> CreateRNNConstantOp(const Graph* graph, const Node *node, const Node *parentNode, int index,
+    const DeviceDescriptor& computeDevice)
+{
+    onnx::TensorProto valueProto;
+    if (!graph->GetInitialTensor(node->Name(), valueProto))
+    {
+        NodeAttributes::const_iterator itValue = node->GetAttributes().find("value");
+        if (itValue == node->GetAttributes().cend())
+        {
+            return std::vector<FunctionPtr>();
+        }
+        valueProto = itValue->second.t();
+    }
+
+    std::vector<Variable> constantNodes = CreateRNNConstant(parentNode, index, node->Name(), valueProto, computeDevice);
+    std::vector<FunctionPtr> returns;
+    for (auto c : constantNodes)
+        returns.push_back(c);
+    return returns;
+}
+std::vector<Variable> ONNXToCNTKHelper::CreateRNNLeafVariableOrConstant(const NodeArg *nodeArg,
+    const Node *parentNode, const Graph* graph, ONNXToCNTKVariableMap &constructedNodeArgVariableMap,
+    const DeviceDescriptor& computeDevice)
+{
+    string parentONNXOpName = parentNode->OpType();
+    std::string nodeName = nodeArg->Name();
+    onnx::TensorProto valueProto;
+    if (graph->GetInitialTensor(nodeName, valueProto))
+    {
+        int index = CalculateNodeArgInputIndex(nodeArg, parentNode);
+        return CreateRNNConstant(parentNode, index, nodeName, valueProto, computeDevice);
+    }
+
+    const TensorShapeProto *shapeProto = nodeArg->Shape();
+    if (shapeProto == nullptr)
+    {
+        // dummy input
+        return std::vector<Variable>();
+    }
+
+    std::vector<Axis> dynamicAxes({ Axis::OperandSequenceAxis(), Axis::DefaultBatchAxis() });
+
+    if (parentONNXOpName == "LSTM")
+    {
+        // index to LSTM inputs as specified in the ONNX document:
+        // https://github.com/onnx/onnx/blob/master/docs/Operators.md#inputs-3---8
+        int inputIndex = CalculateNodeArgInputIndex(nodeArg, parentNode);
+        switch (inputIndex)
+        {
+        case LSTMInputIndexX:
+            // X: `[seq_length, batch_size, input_size]`
+        {
+            Variable inputVariable;
+            if (constructedNodeArgVariableMap.find(nodeArg->Name()) == constructedNodeArgVariableMap.end())
+            {
+                DataType dataType = FromONNXType(nodeArg->ToProto().type());
+                int input_size = shapeProto->dim(2).dim_value();
+                NDShape shape({ (size_t)(input_size) });
+                inputVariable = InputVariable(shape, dataType, ToWString(nodeArg->Name()), dynamicAxes);
+                constructedNodeArgVariableMap.insert(ONNXToCNTKVariableMap::value_type(nodeArg->Name(), inputVariable));
+            }
+            return std::vector<Variable>({ constructedNodeArgVariableMap[nodeArg->Name()] });
+        }
+        // other inputs shall be ONNX constant nodes and are created as CNTK Constants in CreateRNNConstant
+        case LSTMInputIndexW:            // W
+        case LSTMInputIndexH:            // R
+        case LSTMInputIndexB:            // B
+        case LSTMInputIndexSequenceLens: // sequence_lens
+        case LSTMInputIndexinitial_h:    // initial_h
+        case LSTMInputIndexinitial_c:    // initial_c
+        case LSTMInputIndexP:            // P
+            NOT_IMPLEMENTED;
+        default:
+            LogicError("LSTM node has unexpected input");
+        }
+    }
+    else
+    {
+        NOT_IMPLEMENTED;
+    }
+}
+Variable ONNXToCNTKHelper::CreateLeafVariableOrConstant(const NodeArg *nodeArg,
+    const Node *parentNode, const Graph* graph, const DeviceDescriptor& computeDevice)
+{
+    string parentONNXOpName = parentNode->OpType();
+
+    std::string nodeName = nodeArg->Name();
+    onnx::TensorProto valueProto;
+    if (graph->GetInitialTensor(nodeName, valueProto))
+    {
+        return CreateConstant(valueProto, nodeName, computeDevice);
+    }
+
+    auto dataType = FromONNXType(nodeArg->ToProto().type());
+    auto shapeProto = nodeArg->Shape();
+
+    // in CNTK, constants are created as Nodes (not leaves) with values
@@ -571,15 +906,38 @@ Variable ONNXToCNTKHelper::CreateLeafVariableOrConstant(const NodeArg *nodeArg,
         shape = shape.SubShape(0, shape.Rank() - 1);
     }

+    std::vector<Axis> dynamicAxes({ Axis::DefaultBatchAxis() });
+
+    // TODO: this is not fully correct. We need to track hasSequenceAxis
+    // over the traversal path. An input shall have a sequence axis
+    // only if it feeds an RNN op along the path.
+    // This requires support from LotusIR.
+    // For now traversal starts from arbitrary nodes, which may miss the RNN op.
+    bool hasSequenceAxis = false;
+    for (Graph::NodeIterator nodeIt = (const_cast<Graph *>(graph))->Nodes_begin();
+        nodeIt != (const_cast<Graph *>(graph))->Nodes_end(); ++nodeIt)
+        if (Operators::IsRNNOp((*nodeIt)->OpType()))
+        {
+            hasSequenceAxis = true;
+            break;
+        }
+
+    if (hasSequenceAxis)
+    {
+        shape = shape.SubShape(0, shape.Rank() - 1);
+        dynamicAxes.insert(dynamicAxes.begin(), Axis::OperandSequenceAxis());
+    }
+
     switch (dataType)
     {
     case DataType::Float:
     {
-        return InputVariable(shape, DataType::Float, ToWString(nodeArg->Name()), { Axis::DefaultBatchAxis() });
+        return InputVariable(shape, DataType::Float, ToWString(nodeArg->Name()), dynamicAxes);
     }
     case DataType::Double:
     {
-        return InputVariable(shape, DataType::Double, ToWString(nodeArg->Name()), { Axis::DefaultBatchAxis() });
+        return InputVariable(shape, DataType::Double, ToWString(nodeArg->Name()), dynamicAxes);
     }
     default:
         NOT_IMPLEMENTED;
@@ -775,6 +1133,17 @@ string ONNXToCNTKHelper::GetNamedAttributeAsString(const Node *node, const strin
     return attributeProto.s();
 }

+std::vector<std::string> ONNXToCNTKHelper::GetNamedAttributeAsStringVec(const Node *node, const string &attributeName,
+    const std::vector<std::string> &defaultValues)
+{
+    NodeAttributes::const_iterator itValue = FindAttributeIterator(node, attributeName, false);
+    if (itValue == node->GetAttributes().end())
+        return defaultValues;
+
+    const AttributeProto &attributeProto = itValue->second;
+    return std::vector<std::string>(attributeProto.strings().begin(), attributeProto.strings().end());
+}
+
 std::vector<int64_t> ONNXToCNTKHelper::GetNamedAttributeAsInt64Vec(const Node *node, const string &attributeName)
 {
     NodeAttributes::const_iterator itValue = FindAttributeIterator(node, attributeName, true);
@@ -970,40 +1339,50 @@ std::pair<Variable, Variable> ONNXToCNTKHelper::BroadcastElementWiseInput(
     }
     else
     {
+        // This is to handle edge cases where one input has batch (and sequence) axes and the other does not.
+        // The input with the higher rank carries extra batch/sequence dimensions that need to be squeezed away.
+        // TODO: investigate if this edge case can be avoided when converting CNTK to ONNX.
+        if (input0.Shape().Rank() == input1.Shape().Rank() - 1)
+        {
+            if (input0.HasBatchAxis() && !input1.HasBatchAxis())
+                return{ input0, Reshape(input1, shape1.SubShape(0, shape1.Rank() - 1)) };
+        }
+        else if (input0.Shape().Rank() == input1.Shape().Rank() - 2 && input0.DynamicAxes().size() == 2 &&
+            input1.DynamicAxes().size() == 0)
+        {
+            return{ input0, Reshape(input1, shape1.SubShape(0, shape1.Rank() - 2)) };
+        }
+        else if (input0.Shape().Rank() - 1 == input1.Shape().Rank())
+        {
+            if (!input0.HasBatchAxis() && input1.HasBatchAxis())
+                return{ Reshape(input0, shape0.SubShape(0, shape0.Rank() - 1)), input1 };
+        }
+        else if (input0.Shape().Rank() - 2 == input1.Shape().Rank() && input1.DynamicAxes().size() == 2 &&
+            input0.DynamicAxes().size() == 0)
+        {
+            return{ Reshape(input0, shape0.SubShape(0, shape0.Rank() - 2)), input1 };
+        }
         return{ input0, input1 };
     }
 }

-Axis ONNXToCNTKHelper::ConvertAxisToCNTKCppApi(const Axis& axis, const Variable& operand)
+Axis ONNXToCNTKHelper::ConvertONNXAxisToCNTKCppApi(int64_t axis, const Variable& operand)
 {
     // reverse CNTKToONNXHelper::ConvertAxisToOnnx
-    // note that axis is already decreased by one (assuming there is a batch axis)
-    int64_t index = axis.StaticAxisIndex();
-    if (!operand.HasBatchAxis())
-    {
-        index++;
-    }
+    int index = axis - operand.DynamicAxes().size();
+    if (index < 0)
+        LogicError("ConvertONNXAxisToCNTKCppApi cannot convert index < 0 to an axis");

     // apply -index - 1
     return Axis(-index - 1);
 }

-std::vector<Axis> ONNXToCNTKHelper::ConvertAxesToCNTKCppApi(const std::vector<Axis> &axes, const Variable& operand)
+std::vector<Axis> ONNXToCNTKHelper::ConvertONNXAxesToCNTKCppApi(const std::vector<int64_t> &axes, const Variable& operand)
 {
     std::vector<Axis> cntkAxes(axes.size());
     for (int i = 0; i < axes.size(); i++)
     {
-        cntkAxes[i] = ConvertAxisToCNTKCppApi(axes[i], operand);
+        cntkAxes[i] = ConvertONNXAxisToCNTKCppApi(axes[i], operand);
     }

     return cntkAxes;
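Note: ConvertONNXAxisToCNTKCppApi maps an ONNX axis index (which counts dynamic dimensions explicitly) to a CNTK static axis counted from the back, because CNTK stores static dimensions in reverse order relative to ONNX. A small worked sketch of the same arithmetic (hypothetical helper):

    // Assume the operand has dynamic axes (sequence, batch) and static rank 2,
    // i.e. the ONNX tensor is indexed as [seq, batch, d0, d1].
    Axis OnnxAxisToCntk(int64_t onnxAxis, size_t numDynamicAxes)
    {
        int index = (int)(onnxAxis - numDynamicAxes); // strip the dynamic axes
        return Axis(-index - 1);                      // count from the back
    }
    // OnnxAxisToCntk(2, 2) == Axis(-1), OnnxAxisToCntk(3, 2) == Axis(-2)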
@@ -1050,15 +1429,29 @@ FunctionPtr ONNXToCNTKHelper::CreateFunction(const Node *node, const std::vector
 {
     string onnxOpName = node->OpType();

     if (onnxOpName == "Cast" && inputs[0].GetDataType() == DataType::Float && inputs[0].Owner() != nullptr)
     {
         // CNTK does not support a cast op. Only float is available with ONNX support.
         // Question for having a cast op: why not cast data as necessary internally?
         return inputs[0].Owner();
     }
+    else if (onnxOpName == "LSTM")
+    {
+        const string direction = GetNamedAttributeAsString(node, "direction");
+        std::vector<float> activation_alpha = GetNamedAttributeAsFloatVec(node, "activation_alpha", std::vector<float>());
+        std::vector<float> activation_beta = GetNamedAttributeAsFloatVec(node, "activation_beta", std::vector<float>());
+        const std::vector<string> activations = GetNamedAttributeAsStringVec(node, "activations",
+            std::vector<string>({ "Sigmoid", "Tanh", "Tanh" }));
+        return CreateLSTM(node, inputs, direction, activations, activation_alpha, activation_beta);
+    }
     if (onnxOpName == "FC")
     {
         return CreateCNTKFCNode(ToWString(node->Name()), inputs);
     }
     else if (onnxOpName == "Flatten")
     {
-        Axis defaultAxis(-1);
-        Axis axis = GetNamedAttributeAsAxis(node, "axis", defaultAxis);
-        axis = ConvertAxisToCNTKCppApi(axis, inputs[0]);
+        int64_t axisIndex = (size_t)GetNamedAttributeAsInt64(node, "axis", 0);
+        Axis axis = ConvertONNXAxisToCNTKCppApi(axisIndex, inputs[0]);
         FunctionPtr cntkFunction = Flatten(inputs[0], axis, ToWString(node->Name()));
         return cntkFunction;
     }
@@ -1513,57 +1906,57 @@ FunctionPtr ONNXToCNTKHelper::CreateFunction(const Node *node, const std::vector
 }
 else if (onnxOpName == "ReduceMax")
 {
-    std::vector<Axis> axes = ConvertAxesToCNTKCppApi(GetNamedAttributeAsAxes(node, "axes"), inputs[0]);
+    std::vector<Axis> axes = ConvertONNXAxesToCNTKCppApi(GetNamedAttributeAsInt64Vec(node, "axes"), inputs[0]);
     FunctionPtr cntkFunction = ReduceMax(inputs[0], axes[0], ToWString(node->Name()));
     return cntkFunction;
 }
 else if (onnxOpName == "ReduceMin")
 {
-    std::vector<Axis> axes = ConvertAxesToCNTKCppApi(GetNamedAttributeAsAxes(node, "axes"), inputs[0]);
+    std::vector<Axis> axes = ConvertONNXAxesToCNTKCppApi(GetNamedAttributeAsInt64Vec(node, "axes"), inputs[0]);
     FunctionPtr cntkFunction = ReduceMin(inputs[0], axes[0], ToWString(node->Name()));
     return cntkFunction;
 }
 else if (onnxOpName == "ReduceSum")
 {
-    std::vector<Axis> axes = ConvertAxesToCNTKCppApi(GetNamedAttributeAsAxes(node, "axes"), inputs[0]);
+    std::vector<Axis> axes = ConvertONNXAxesToCNTKCppApi(GetNamedAttributeAsInt64Vec(node, "axes"), inputs[0]);
     FunctionPtr cntkFunction = ReduceSum(inputs[0], axes[0], ToWString(node->Name()));
     return cntkFunction;
 }
 else if (onnxOpName == "ReduceMean")
 {
-    std::vector<Axis> axes = ConvertAxesToCNTKCppApi(GetNamedAttributeAsAxes(node, "axes"), inputs[0]);
+    std::vector<Axis> axes = ConvertONNXAxesToCNTKCppApi(GetNamedAttributeAsInt64Vec(node, "axes"), inputs[0]);
     FunctionPtr cntkFunction = ReduceMean(inputs[0], axes[0], ToWString(node->Name()));
     return cntkFunction;
 }
 else if (onnxOpName == "ReduceProd")
 {
-    std::vector<Axis> axes = ConvertAxesToCNTKCppApi(GetNamedAttributeAsAxes(node, "axes"), inputs[0]);
+    std::vector<Axis> axes = ConvertONNXAxesToCNTKCppApi(GetNamedAttributeAsInt64Vec(node, "axes"), inputs[0]);
     FunctionPtr cntkFunction = ReduceProd(inputs[0], axes[0], ToWString(node->Name()));
     return cntkFunction;
 }
 else if (onnxOpName == "ReduceLogSumExp")
 {
-    std::vector<Axis> axes = ConvertAxesToCNTKCppApi(GetNamedAttributeAsAxes(node, "axes"), inputs[0]);
+    std::vector<Axis> axes = ConvertONNXAxesToCNTKCppApi(GetNamedAttributeAsInt64Vec(node, "axes"), inputs[0]);
     FunctionPtr cntkFunction = ReduceLogSum(inputs[0], axes[0], ToWString(node->Name()));
     return cntkFunction;
 }
 else if (onnxOpName == "ReduceL1")
 {
-    std::vector<Axis> axes = ConvertAxesToCNTKCppApi(GetNamedAttributeAsAxes(node, "axes"), inputs[0]);
+    std::vector<Axis> axes = ConvertONNXAxesToCNTKCppApi(GetNamedAttributeAsInt64Vec(node, "axes"), inputs[0]);
     bool keepdims = GetNamedAttributeAsInt64(node, "keepdims", 1) == 1;
     FunctionPtr cntkFunction = ReduceL1(inputs[0], axes, keepdims, ToWString(node->Name()));
     return cntkFunction;
 }
 else if (onnxOpName == "ReduceL2")
 {
-    std::vector<Axis> axes = ConvertAxesToCNTKCppApi(GetNamedAttributeAsAxes(node, "axes"), inputs[0]);
+    std::vector<Axis> axes = ConvertONNXAxesToCNTKCppApi(GetNamedAttributeAsInt64Vec(node, "axes"), inputs[0]);
     bool keepdims = GetNamedAttributeAsInt64(node, "keepdims", 1) == 1;
     FunctionPtr cntkFunction = ReduceL2(inputs[0], axes, keepdims, ToWString(node->Name()));
     return cntkFunction;
 }
 else if (onnxOpName == "ReduceSumSquare")
 {
-    std::vector<Axis> axes = ConvertAxesToCNTKCppApi(GetNamedAttributeAsAxes(node, "axes"), inputs[0]);
+    std::vector<Axis> axes = ConvertONNXAxesToCNTKCppApi(GetNamedAttributeAsInt64Vec(node, "axes"), inputs[0]);
     bool keepdims = GetNamedAttributeAsInt64(node, "keepdims", 1) == 1;
     FunctionPtr cntkFunction = ReduceSumSquare(inputs[0], axes, keepdims, ToWString(node->Name()));
     return cntkFunction;
@@ -1572,34 +1965,60 @@ FunctionPtr ONNXToCNTKHelper::CreateFunction(const Node *node, const std::vector
 {
     int64_t axisIndex = GetNamedAttributeAsInt64(node, "axis");
-    // -1 to compensate for ConvertAxisToCNTKCppApi assuming that axis is already decreased by 1
-    Axis axis(axisIndex - 1);
-    axis = ConvertAxisToCNTKCppApi(axis, inputs[0]);
+    Axis axis = ConvertONNXAxisToCNTKCppApi(axisIndex, inputs[0]);
     FunctionPtr cntkfunction = Argmax(inputs[0], axis, ToWString(node->Name()));
     return cntkfunction;
 }
 else if (onnxOpName == "ArgMin")
 {
     int64_t axisIndex = GetNamedAttributeAsInt64(node, "axis");
-    // -1 to compensate for ConvertAxisToCNTKCppApi assuming that axis is already decreased by 1
-    Axis axis(axisIndex - 1);
-    axis = ConvertAxisToCNTKCppApi(axis, inputs[0]);
+    Axis axis = ConvertONNXAxisToCNTKCppApi(axisIndex, inputs[0]);
     FunctionPtr cntkFunction = Argmin(inputs[0], axis, ToWString(node->Name()));
     return cntkFunction;
 }
 else if (onnxOpName == "Reshape")
 {
+    // Skip the reshape if it follows an LSTM.
+    const Node* childNode = GetChildNode(node, &node->InputDefs()[0]);
+    if (childNode != nullptr && childNode->OpType() == "LSTM")
+    {
+        // TODO: this undoes the reshape inserted after LSTM in CNTKToONNX, to work around
+        // an output shape mismatch with the input to the next LSTM layer.
+        NDShape newShape = GetNamedAttributeAsShape(node, "shape", false);
+        newShape = newShape.SubShape(0, newShape.Rank() - inputs[0].DynamicAxes().size());
+        return Reshape(inputs[0], newShape, ToWString(node->Name()));
+    }
+
     NDShape newShape = GetNamedAttributeAsShape(node, "shape", false);
-    if (inputs[0].HasBatchAxis())
+    if (inputs[0].DynamicAxes().size() > 0)
     {
         if (newShape.Rank() == 1)
             LogicError("Reshape: 'shape' attribute must include element for batch axis.");
-        newShape = newShape.SubShape(0, newShape.Rank() - 1);
+        newShape = newShape.SubShape(0, newShape.Rank() - inputs[0].DynamicAxes().size());
     }

+    int inferredDim = -1;
+    size_t knownSize = 1;
+    for (size_t i = 0; i < newShape.Dimensions().size(); ++i)
+    {
+        if (newShape[i] == 0)
+            newShape[i] = inputs[0].Shape()[i];    // 0 means: copy the corresponding input dimension
+        else if (newShape[i] == NDShape::InferredDimension)
+        {
+            if (inferredDim == -1)
+                inferredDim = (int)i;
+            else
+                LogicError("Reshape: 'shape' contains more than one inferred dimension.");
+        }
+        else
+            knownSize *= newShape[i];
+    }
+
+    if (inferredDim != -1)
+    {
+        if (inputs[0].Shape().TotalSize() % knownSize != 0)
+            LogicError("Reshape: inferred dimension does not evenly divide the input's total size.");
+        newShape[inferredDim] = inputs[0].Shape().TotalSize() / knownSize;
+    }
+
     FunctionPtr cntkFunction = Reshape(inputs[0], newShape, ToWString(node->Name()));
     return cntkFunction;
 }
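Note: ONNX Reshape uses 0 to mean "copy the corresponding input dimension" and a single inferred dimension for the remainder; the loop above resolves both before calling CNTK's Reshape. A quick worked example (shapes illustrative):

    // input shape (2, 3, 4) with total size 24; 'shape' attribute (2, inferred)
    // known product = 2, and 24 % 2 == 0, so the inferred dimension = 24 / 2 = 12
    // result: Reshape(input, (2, 12))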
@@ -1608,8 +2027,8 @@ FunctionPtr ONNXToCNTKHelper::CreateFunction(const Node *node, const std::vector
     // We allow the 'axis' attribute to be optional, and not required (as
     // given in Concat's ONNX spec), to be consistent with other frameworks.
     // 'axis' can be enforced as a required attribute, if needed.
-    Axis axis = GetNamedAttributeAsAxis(node, "axis", Axis(0));
-    axis = ConvertAxisToCNTKCppApi(axis, inputs[0]);
+    int64_t onnxAxis = GetNamedAttributeAsInt64(node, "axis", 0);
+    Axis axis = ConvertONNXAxisToCNTKCppApi(onnxAxis, inputs[0]);
     std::vector<Variable> fixedInputs;
     if (FixConstantShapeForConstantVariableInputPair(inputs, fixedInputs))
     {
@@ -1626,12 +2045,7 @@ FunctionPtr ONNXToCNTKHelper::CreateFunction(const Node *node, const std::vector
 else if (onnxOpName == "Slice")
 {
-    // axes is optional so provide a default
-    std::vector<Axis> axes;
-    axes = GetNamedAttributeAsAxes(node, "axes", axes);
-    for (int i = 0; i < axes.size(); i++)
-    {
-        axes[i] = ConvertAxisToCNTKCppApi(axes[i], inputs[0]);
-    }
+    std::vector<Axis> axes = ConvertONNXAxesToCNTKCppApi(GetNamedAttributeAsInt64Vec(node, "axes"), inputs[0]);

     std::vector<int64_t> starts64 = GetNamedAttributeAsInt64Vec(node, "starts");
     std::vector<int64_t> ends64 = GetNamedAttributeAsInt64Vec(node, "ends");
@@ -1713,13 +2127,15 @@ FunctionPtr ONNXToCNTKHelper::CreateFunction(const Node *node, const std::vector
 {
     if (HasNamedAttribute(node, "axis"))
     {
-        Axis defaultAxes(-1);
-        Axis axis = GetNamedAttributeAsAxis(node, "axis", defaultAxes);
-        axis = ConvertAxisToCNTKCppApi(axis, inputs[0]);
+        int64_t axisIndex = GetNamedAttributeAsInt64(node, "axis", 0);
+        Axis axis = ConvertONNXAxisToCNTKCppApi(axisIndex, inputs[0]);
         return GatherOp(inputs[1], inputs[0], axis, ToWString(node->Name()));
     }
     else
-        return GatherOp(inputs[1], inputs[0], ToWString(node->Name()));
+    {
+        FunctionPtr cntkFunction = GatherOp(inputs[1], inputs[0], ToWString(node->Name()));
+        return cntkFunction;
+    }
 }
 else if (onnxOpName == "DepthToSpace")
 {
@@ -1768,14 +2184,34 @@ FunctionPtr ONNXToCNTKHelper::CreateFunction(const Node *node, const std::vector
     }
 }

-FunctionPtr ONNXToCNTKHelper::FromONNXNode(const Node *node, ONNXToCNTKMap &constructedNodeMap,
+std::pair<const Node *, int> FindParentAndChildIndex(const Node *node)
+{
+    Node::NodeConstIterator it = node->OutputNodes_begin();
+    if (it != node->OutputNodes_end())
+    {
+        const Node *parent = *it;
+        int index = 0;
+        for (auto nodeArg : parent->InputDefs())
+        {
+            if (nodeArg.Name() == node->Name())
+            {
+                return std::make_pair(parent, index);
+            }
+            index++;
+        }
+    }
+    return std::make_pair(nullptr, -1);
+}
+
+std::vector<FunctionPtr> ONNXToCNTKHelper::FromONNXNode(const Node *node, ONNXToCNTKMap &constructedNodeMap,
+    ONNXToCNTKVariableMap &constructedNodeArgVariableMap,
     const Graph* graph, const DeviceDescriptor& computeDevice)
 {
     auto nodeOpStr = node->OpType();
     ONNXToCNTKMap::iterator itONNXToCNTKMap = constructedNodeMap.find(node);
     if (itONNXToCNTKMap != constructedNodeMap.end())
     {
         return itONNXToCNTKMap->second;
     }

     std::vector<Variable> inputs;

@@ -1789,24 +2225,52 @@ FunctionPtr ONNXToCNTKHelper::FromONNXNode(const Node *node, ONNXToCNTKMap &cons
         ONNXToCNTKMap::iterator itNodeMap = constructedNodeMap.find(const_cast<Node *>(inputNode));
         if (itNodeMap != constructedNodeMap.end())
         {
-            inputs.push_back(itNodeMap->second);
+            inputs.insert(inputs.end(), itNodeMap->second.begin(), itNodeMap->second.end());
         }
         else
         {
-            FunctionPtr input = FromONNXNode(inputNode, constructedNodeMap, graph, computeDevice);
-            inputs.push_back(input);
+            std::vector<FunctionPtr> inputVariables = FromONNXNode(inputNode, constructedNodeMap,
+                constructedNodeArgVariableMap, graph, computeDevice);
+            inputs.insert(inputs.end(), inputVariables.begin(), inputVariables.end());
         }
     }
     else
     {
-        Variable inputVariable = CreateLeafVariableOrConstant(nodeArg, node, graph, computeDevice);
-        inputs.push_back(inputVariable);
+        std::string parentONNXOpName = node->OpType();
+        if (parentONNXOpName == "LSTM")
+        {
+            std::vector<Variable> inputVariables =
+                CreateRNNLeafVariableOrConstant(nodeArg, node, graph, constructedNodeArgVariableMap, computeDevice);
+            inputs.insert(inputs.end(), inputVariables.begin(), inputVariables.end());
+        }
+        else
+        {
+            Variable inputVariable = CreateLeafVariableOrConstant(nodeArg, node, graph, computeDevice);
+            inputs.push_back(inputVariable);
+        }
     }
 }

-FunctionPtr cntkFunction = CreateCNTKNode(node, inputs, computeDevice);
-constructedNodeMap.insert(ONNXToCNTKMap::value_type(node, cntkFunction));
-return cntkFunction;
+const Node *parentNode;
+int childIndex;
+std::tie<const Node *, int>(parentNode, childIndex) = FindParentAndChildIndex(node);
+if (parentNode != nullptr && parentNode->OpType() == "LSTM")
+{
+    std::vector<FunctionPtr> cntkFunctions = CreateRNNConstantOp(graph, node, parentNode, childIndex, computeDevice);
+    if (!cntkFunctions.empty())
+    {
+        // TODO: make the node map to a vector of FunctionPtr
+        constructedNodeMap.insert(ONNXToCNTKMap::value_type(node, cntkFunctions));
+    }
+    return cntkFunctions;
+}
+else
+{
+    FunctionPtr cntkFunction = CreateCNTKNode(node, inputs, computeDevice);
+    constructedNodeMap.insert(ONNXToCNTKMap::value_type(node, std::vector<FunctionPtr>({ cntkFunction })));
+    return std::vector<FunctionPtr>({ cntkFunction });
+}
 }

 FunctionPtr ONNXToCNTKHelper::CreateCNTKNode(const Node *node, const std::vector<Variable> &inputs,
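Note: FromONNXNode now memoizes a vector of FunctionPtr per ONNX node, because an LSTM constant node can expand into several CNTK Constants (one per direction, gate, or peephole factor). A minimal, self-contained sketch of that memoized depth-first shape, detached from the LotusIR types (names hypothetical):

    #include <unordered_map>
    #include <vector>

    template <typename TNode, typename TFunc>
    std::vector<TFunc> ConvertMemoized(const TNode *node,
        std::unordered_map<const TNode *, std::vector<TFunc>> &memo,
        std::vector<TFunc> (*convert)(const TNode *))
    {
        auto hit = memo.find(node);
        if (hit != memo.end())
            return hit->second;           // each node is converted at most once
        std::vector<TFunc> results = convert(node);
        memo.emplace(node, results);
        return results;
    }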
@@ -2062,13 +2526,15 @@ FunctionPtr ONNXToCNTK::CreateGraph(ONNXIR::Graph* src, const DeviceDescriptor&

     // To use depth-first traversal, keep a collection of visited nodes.
     ONNXToCNTKMap constructedFunctions;
+    ONNXToCNTKVariableMap constructedNodeArgVariableMap;
     for (Graph::NodeIterator it = src->Nodes_begin(); it != src->Nodes_end(); ++it)
     {
         const Node *node = *it;

         if (constructedFunctions.find(node) == constructedFunctions.end())
         {
-            FunctionPtr cntkNode = ONNXToCNTKHelper::FromONNXNode(node, constructedFunctions, src, computeDevice);
+            std::vector<FunctionPtr> cntkNode = ONNXToCNTKHelper::FromONNXNode(node,
+                constructedFunctions, constructedNodeArgVariableMap, src, computeDevice);
         }
     }
@@ -2083,7 +2549,7 @@ FunctionPtr ONNXToCNTK::CreateGraph(ONNXIR::Graph* src, const DeviceDescriptor&
     std::vector<FunctionPtr> functions;
     for (Node::NodeConstIterator it = itNodeFn->first->InputNodes_begin(); it != itNodeFn->first->InputNodes_end(); ++it)
     {
-        functions.push_back(constructedFunctions[*it]);
+        functions.insert(functions.end(), constructedFunctions[*it].begin(), constructedFunctions[*it].end());
     }

     if (functions.empty())
@@ -370,6 +370,9 @@ namespace ONNX
         { L"useStatsAcrossChannels", "across_channels" },
         { L"doVarianceScaling", "normalize_variance" },
     } } },
+    { L"Embedding",{ {
+        { L"Embedding", "Gather" },
+    } } },
 };

 // given a cntkOpName and cntk attribute OpName which is saved in CNTK::Function's attribute,
@@ -425,7 +428,15 @@ namespace ONNX
         (onnxOpName == "And") || (onnxOpName == "Or") || (onnxOpName == "Xor");
 }

+bool Operators::IsLoopOp(const std::string &opName)
+{
+    return opName == "PastValue" || opName == "FutureValue";
+}
+
+bool Operators::IsRNNOp(const std::string &opName)
+{
+    return opName == "LSTM" || opName == "GRU" || opName == "RNN";
+}
+
 std::unordered_map<std::wstring, std::set<size_t>> Operators::_cntkBlockOPInvalidIndices = {
     { L"Clip",{ 1, 2 } },
     { L"LeakyReLU",{ 0, 1 } },
@@ -108,6 +108,9 @@ namespace CNTK
     static bool SupportBroadcast(const std::wstring& cntkOpName);
     static bool SupportBroadcastONNXOp(const std::string& onnxOpName);

+    static bool IsLoopOp(const std::string &opName);
+    static bool IsRNNOp(const std::string &opName);
+
 private:
     static std::unordered_multimap<std::wstring, AttributesMapping> _cntkToONNXOpName;
     static std::unordered_map<std::wstring, std::set<size_t>> _cntkBlockOPInvalidIndices;
@@ -0,0 +1,295 @@
+//
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
+//
+#include "RNNHelper.h"
+
+std::function<FunctionPtr(const Variable&)> ActivationMap(const std::string &activationName)
+{
+    if (activationName == "Relu")
+    {
+        return [](const Variable& x) { return ReLU(x); };
+    }
+    else if (activationName == "Tanh")
+    {
+        return [](const Variable& x) { return Tanh(x); };
+    }
+    else if (activationName == "Sigmoid")
+    {
+        return [](const Variable& x) { return Sigmoid(x); };
+    }
+    // else if (activationName == "Affine")
+    // else if (activationName == "LeakyRelu")
+    // else if (activationName == "ThresholdedRelu")
+    // else if (activationName == "ScaledTanh")
+    // else if (activationName == "HardSigmoid")
+    else if (activationName == "Elu")
+    {
+        return [](const Variable& x) { return ELU(x); };
+    }
+    else if (activationName == "Softsign")
+    {
+        return [](const Variable& x) { return Softsign(x); };
+    }
+    else if (activationName == "Softplus")
+    {
+        return [](const Variable& x) { return Softplus(x); };
+    }
+    else
+    {
+        CNTK::LogicError("LSTM does not support activation: %s", activationName.c_str());
+    }
+}
+
+std::function<FunctionPtr(const Variable&)> ActivationMap(const std::string &activationName,
+    float activation_alpha)
+{
+    if (activationName == "LeakyRelu")
+    {
+        return [activation_alpha](const Variable& x) { return LeakyReLU(x, activation_alpha); };
+    }
+    else
+    {
+        return ActivationMap(activationName);
+    }
+}
+
+std::function<FunctionPtr(const Variable&)> ActivationMap(const std::string &activationName,
+    float activation_alpha, float activation_beta)
+{
+    if (activationName == "HardSigmoid")
+    {
+        return [activation_alpha, activation_beta](const Variable& x) { return HardSigmoid(x, activation_alpha, activation_beta); };
+    }
+    else
+    {
+        return ActivationMap(activationName, activation_alpha);
+    }
+}
+
+std::tuple<std::function<FunctionPtr(const Variable&)>, std::function<FunctionPtr(const Variable&)>, std::function<FunctionPtr(const Variable&)>>
+GetActivations(const std::vector<std::string> &activations, const std::vector<float> &activation_alpha, const std::vector<float> &activation_beta, int direction)
+{
+    if (activations.size() < (direction + 1) * LSTMActivationCount)
+        CNTK::LogicError("LSTM expects %d activations (or %d for a bidirectional LSTM)", LSTMActivationCount, LSTMActivationCount * 2);
+
+    int iofActivationIndex = direction * LSTMActivationCount + LSTMActivationFIndex;
+    int cellActivationIndex = direction * LSTMActivationCount + LSTMActivationGIndex;
+    int hiddenActivationIndex = direction * LSTMActivationCount + LSTMActivationHIndex;
+
+    // The ONNX spec is not clear on how activation alpha and beta are set.
+    // Here we assume that if they are set, they are set for all activations, regardless of whether
+    // an activation needs those values or not.
+    bool hasAlpha = activation_alpha.size() == (direction + 1) * LSTMActivationCount;
+    bool hasBeta = hasAlpha && activation_beta.size() == (direction + 1) * LSTMActivationCount;
+    std::function<FunctionPtr(const Variable&)> iofActivationOp, cellActivationOp, hiddenActivationOp;
+    if (hasBeta)
+    {
+        iofActivationOp = ActivationMap(activations[iofActivationIndex], activation_alpha[iofActivationIndex], activation_beta[iofActivationIndex]);
+        cellActivationOp = ActivationMap(activations[cellActivationIndex], activation_alpha[cellActivationIndex], activation_beta[cellActivationIndex]);
+        hiddenActivationOp = ActivationMap(activations[hiddenActivationIndex], activation_alpha[hiddenActivationIndex], activation_beta[hiddenActivationIndex]);
+    }
+    else if (hasAlpha)
+    {
+        iofActivationOp = ActivationMap(activations[iofActivationIndex], activation_alpha[iofActivationIndex]);
+        cellActivationOp = ActivationMap(activations[cellActivationIndex], activation_alpha[cellActivationIndex]);
+        hiddenActivationOp = ActivationMap(activations[hiddenActivationIndex], activation_alpha[hiddenActivationIndex]);
+    }
+    else
+    {
+        iofActivationOp = ActivationMap(activations[iofActivationIndex]);
+        cellActivationOp = ActivationMap(activations[cellActivationIndex]);
+        hiddenActivationOp = ActivationMap(activations[hiddenActivationIndex]);
+    }
+
+    return std::make_tuple(iofActivationOp, cellActivationOp, hiddenActivationOp);
+}
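Note: GetActivations indexes the flat ONNX "activations" list per direction: entries [0..2] configure the forward pass and, for a bidirectional LSTM, entries [3..5] the backward pass. For example:

    // activations = {"Sigmoid", "Tanh", "Tanh", "Sigmoid", "Relu", "Tanh"}
    // direction 0 (forward):  f = activations[0], g = activations[1], h = activations[2]
    // direction 1 (backward): f = activations[3], g = activations[4], h = activations[5]
    // i.e. index = direction * LSTMActivationCount + LSTMActivation{F,G,H}Index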
+std::pair<FunctionPtr, FunctionPtr> LSTMPCell(Variable input,
+    const std::function<FunctionPtr(const Variable&)> &iofActivationOp,
+    const std::function<FunctionPtr(const Variable&)> &cellActivationOp,
+    const std::function<FunctionPtr(const Variable&)> &hiddenActivationOp,
+    Variable prevOutput, Variable prevCellState,
+    Constant &W, Constant &R, Constant &B, Constant &Ci, Constant &Cf, Constant &Co)
+{
+    size_t outputDim = prevOutput.Shape()[0];
+    int stacked_dim = (int)outputDim;
+
+    FunctionPtr proj4;
+    if (B.IsInitialized())
+    {
+        proj4 = Plus(Plus(B, Times(W, input)), Times(R, prevOutput));
+    }
+    else
+    {
+        proj4 = Plus(Times(W, input), Times(R, prevOutput));
+    }
+
+    // CNTK weight and bias are in icfo order.
+    std::vector<Axis> stack_axis({ Axis(-1) });
+    const int IGateIndex = 0, CGateIndex = 1, FGateIndex = 2, OGateIndex = 3;
+    FunctionPtr it_proj = Slice(proj4, stack_axis, { IGateIndex * stacked_dim }, { (IGateIndex + 1) * stacked_dim });
+    FunctionPtr bit_proj = Slice(proj4, stack_axis, { CGateIndex * stacked_dim }, { (CGateIndex + 1) * stacked_dim });
+    FunctionPtr ft_proj = Slice(proj4, stack_axis, { FGateIndex * stacked_dim }, { (FGateIndex + 1) * stacked_dim });
+    FunctionPtr ot_proj = Slice(proj4, stack_axis, { OGateIndex * stacked_dim }, { (OGateIndex + 1) * stacked_dim });
+
+    bool hasPeephole = Ci.IsInitialized();
+
+    // Input gate
+    auto it = hasPeephole ? iofActivationOp(it_proj + ElementTimes(Ci, prevCellState)) : Sigmoid(it_proj);
+    auto bit = ElementTimes(it, cellActivationOp(bit_proj));
+
+    auto ft = hasPeephole ? iofActivationOp(ft_proj + ElementTimes(Cf, prevCellState)) : Sigmoid(ft_proj);
+    auto bft = ElementTimes(ft, prevCellState);
+
+    auto ct = Plus(bft, bit);
+
+    auto ot = hasPeephole ? iofActivationOp(ot_proj + ElementTimes(Co, ct)) : Sigmoid(ot_proj);
+    auto ht = ElementTimes(ot, hiddenActivationOp(ct));
+
+    auto c = ct;
+    auto h = ht;
+
+    return{ h, c };
+}
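Note: LSTMPCell implements the standard (optionally peephole) LSTM step. In math form, with \sigma the gate activation (iofActivationOp), g the cell activation, h the hidden activation, and P_i, P_f, P_o the peephole vectors:

    i_t = \sigma(W_i x_t + R_i h_{t-1} + b_i + P_i \odot c_{t-1})
    f_t = \sigma(W_f x_t + R_f h_{t-1} + b_f + P_f \odot c_{t-1})
    c_t = f_t \odot c_{t-1} + i_t \odot g(W_c x_t + R_c h_{t-1} + b_c)
    o_t = \sigma(W_o x_t + R_o h_{t-1} + b_o + P_o \odot c_t)
    h_t = o_t \odot h(c_t)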
#include "PrimitiveFunction.h"
|
||||
#include "BlockFunction.h"
|
||||
|
||||
std::tuple<FunctionPtr, FunctionPtr> LSTMPComponent(Variable input,
|
||||
const NDShape& cellShape,
|
||||
const std::function<FunctionPtr(const Variable&)> &iofActivationOp,
|
||||
const std::function<FunctionPtr(const Variable&)> &cellActivationOp,
|
||||
const std::function<FunctionPtr(const Variable&)> &hiddenActivationOp,
|
||||
const std::function<FunctionPtr(const Variable&)>& recurrenceHookH,
|
||||
const std::function<FunctionPtr(const Variable&)>& recurrenceHookC,
|
||||
Constant &W, Constant &R, Constant &B,
|
||||
Constant &Ci, Constant &Cf, Constant &Co)
|
||||
{
|
||||
auto dh = PlaceholderVariable(cellShape, input.DynamicAxes());
|
||||
auto dc = PlaceholderVariable(cellShape, input.DynamicAxes());
|
||||
auto inputPlaceholder = PlaceholderVariable(input.Shape(), input.DynamicAxes());
|
||||
|
||||
auto LSTMCell = LSTMPCell(
|
||||
inputPlaceholder,
|
||||
iofActivationOp, cellActivationOp, hiddenActivationOp,
|
||||
dh, dc, W, R, B, Ci, Cf, Co);
|
||||
|
||||
auto actualDh = recurrenceHookH(LSTMCell.first);
|
||||
auto actualDc = recurrenceHookC(LSTMCell.second);
|
||||
|
||||
// Form the recurrence loop by replacing the dh and dc placeholders with the actualDh and actualDc
|
||||
LSTMCell.first->ReplacePlaceholders({ { inputPlaceholder , input}, { dh, actualDh },{ dc, actualDc } });
|
||||
return std::make_tuple(LSTMCell.first, LSTMCell.second);
|
||||
}
|
||||
|
||||
const std::vector<Variable> FindByNameHint(const std::vector<Variable> &inputs, const std::string &hint)
|
||||
{
|
||||
std::vector<Variable> variables;
|
||||
for (auto v : inputs)
|
||||
{
|
||||
if (ToString(v.Name()).find(hint) != -1)
|
||||
{
|
||||
variables.push_back(v);
|
||||
}
|
||||
}
|
||||
return variables;
|
||||
}
|
||||
|
||||
Variable GetInitialStateVariable(const std::vector<Variable> &inputs, int numDirections,
|
||||
const std::string &nameHint, DataType datatype)
|
||||
{
|
||||
Variable initialVariable = datatype == DataType::Double ? Constant::Scalar(0.0) : Constant::Scalar(0.0f);
|
||||
const std::vector<Variable> initialVariables = FindByNameHint(inputs, nameHint);
|
||||
if (numDirections == 1 && initialVariables.size() >= 1)
|
||||
{
|
||||
initialVariable = initialVariables[0];
|
||||
}
|
||||
else if (numDirections == 2 && initialVariables.size() >= 2)
|
||||
{
|
||||
initialVariable = initialVariables[1];
|
||||
}
|
||||
|
||||
return initialVariable;
|
||||
}
|
||||
|
||||
FunctionPtr CreateLSTM(const ONNXIR::Node *node, const std::vector<Variable> &inputs, const std::string &direction,
|
||||
const std::vector<string> &activations, const std::vector<float> &activation_alpha, const std::vector<float> &activation_beta)
|
||||
{
|
||||
int numDirections = direction == LSTMDirectionBidirection ? 2 : 1;
|
||||
std::vector<FunctionPtr> outputHs;
|
||||
for (int dir = 0; dir < numDirections; dir++)
|
||||
{
|
||||
std::function<FunctionPtr(const Variable&)> iofActivationOp, cellActivationOp, hiddenActivationOp;
|
||||
std::tie<std::function<FunctionPtr(const Variable&)>, std::function<FunctionPtr(const Variable&)>, std::function<FunctionPtr(const Variable&)>>
|
||||
(iofActivationOp, cellActivationOp, hiddenActivationOp) = GetActivations(activations, activation_alpha, activation_beta, dir);
|
||||
|
||||
// the first a few inputs are (in order): X, numDirections * W, numDirections * R
|
||||
Variable X = inputs[0];
|
||||
Variable W = inputs[1 + dir];
|
||||
Variable R = inputs[1 + numDirections + dir];
|
||||
Variable B;
|
||||
std::vector<Variable> biasVariables = FindByNameHint(inputs, LSTMInputBiasNameHint);
|
||||
if (numDirections == 1 && biasVariables.size() >= 1)
|
||||
B = biasVariables[0];
|
||||
else if (numDirections == 2 && biasVariables.size() == 2)
|
||||
B = biasVariables[1];
|
||||
|
||||
Variable initHVariable = GetInitialStateVariable(inputs, numDirections, LSTMInputInitialCNameHint, X.GetDataType());
|
||||
Variable initCVariable = GetInitialStateVariable(inputs, numDirections, LSTMInputInitialHNameHint, X.GetDataType());
|
||||
|
||||
std::vector<Variable> peepholeVariables = FindByNameHint(inputs, LSTMInputPeepholeNameHint);
|
||||
Variable Ci, Cf, Co;
|
||||
if (peepholeVariables.size() != 0 && peepholeVariables.size() != LSTMPeepholeCount && peepholeVariables.size() != 2 * LSTMPeepholeCount)
|
||||
{
|
||||
CNTK::LogicError("Peephole Variable count (%d) should be 0, 1 or 2 times the number of peephole factors (%d).",
|
||||
(int)(peepholeVariables.size()), (int)LSTMPeepholeCount);
|
||||
}
|
||||
else if (numDirections == 1 && peepholeVariables.size() >= LSTMPeepholeCount)
|
||||
{
|
||||
Ci = peepholeVariables[LSTMPeepholeCountCiIndex];
|
||||
Co = peepholeVariables[LSTMPeepholeCountCoIndex];
|
||||
Cf = peepholeVariables[LSTMPeepholeCountCfIndex];
|
||||
}
|
||||
else if (numDirections == 2 && peepholeVariables.size() == numDirections * LSTMPeepholeCount)
|
||||
{
|
||||
Ci = peepholeVariables[LSTMPeepholeCount + LSTMPeepholeCountCiIndex];
|
||||
Co = peepholeVariables[LSTMPeepholeCount + LSTMPeepholeCountCoIndex];
|
||||
Cf = peepholeVariables[LSTMPeepholeCount + LSTMPeepholeCountCfIndex];
|
||||
}
|
||||
|
||||
// ONNX spec https://github.com/onnx/onnx/blob/master/docs/Operators.md#inputs-3---8
|
||||
// tells that weight has shape [num_directions, 4*hidden_size, input_size]
|
||||
// here in CNTK, there is no direction axis because CNTK treats bidirectional LSTM
|
||||
// as two separate LSTM. Therefore we can divide the dimension of the first axis
|
||||
// by 4 to get the hidden size.
|
||||
int hiddenDim = W.Shape()[0] / 4;
|
||||
|
||||
FunctionPtr outputH;
|
||||
FunctionPtr outputC;
|
||||
|
||||
// if it is bidirectional LSTM, the second one will be the backword one.
|
||||
bool go_backwards = direction == LSTMDirectionReverse || (numDirections == 2 && dir == 1);
|
||||
|
||||
std::function<FunctionPtr(const Variable&)> futureValueRecurrenceHook;
|
||||
if (go_backwards)
|
||||
futureValueRecurrenceHook = [initHVariable](const Variable& x) { return FutureValue(x, initHVariable); };
|
||||
else
|
||||
futureValueRecurrenceHook = [initCVariable](const Variable& x) { return PastValue(x, initCVariable); };
|
||||
|
||||
std::tie<FunctionPtr, FunctionPtr>(outputH, outputC) = LSTMPComponent(
|
||||
X, { (size_t)hiddenDim }, iofActivationOp, cellActivationOp, hiddenActivationOp,
|
||||
futureValueRecurrenceHook, futureValueRecurrenceHook, (Constant &)W, (Constant &)R, (Constant &)B,
|
||||
(Constant &)Ci, (Constant &)Cf, (Constant &)Co);
|
||||
outputHs.push_back(outputH);
|
||||
}
|
||||
if (outputHs.size() == 1)
|
||||
return outputHs[0];
|
||||
else
|
||||
{
|
||||
std::vector<Variable> operands({ outputHs[0], outputHs[1] });
|
||||
return Splice(operands, Axis(0), ToWString(node->Name()));
|
||||
}
|
||||
}
|
|
@@ -0,0 +1,77 @@
+//
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
+//
+// originally from CNTK\Tests\EndToEndTests\CNTKv2Library\Common\Common.h
+//
+
+#pragma once
+#include "stdafx.h"
+#include "CNTKLibrary.h"
+#include "Utils.h"
+
+#include "proto/onnx/core/model.h"
+
+#include <algorithm>
+#include <functional>
+
+using namespace CNTK;
+using namespace ONNXIR;
+
+const std::string LSTMInputBiasNameHint = "_bias_";
+const std::string LSTMInputInitialHNameHint = "_initial_h_";
+const std::string LSTMInputInitialCNameHint = "_initial_c_";
+const std::string LSTMInputPeepholeNameHint = "_peephole_";
+
+// indices of LSTM inputs in the ONNX LSTM op, see
+// https://github.com/onnx/onnx/blob/master/docs/Operators.md#inputs-3---8
+enum
+{
+    LSTMInputIndexX = 0,
+    LSTMInputIndexW = 1,
+    LSTMInputIndexH = 2,
+    LSTMInputIndexB = 3,
+    LSTMInputIndexSequenceLens = 4,
+    LSTMInputIndexinitial_h = 5,
+    LSTMInputIndexinitial_c = 6,
+    LSTMInputIndexP = 7
+};
+
+enum
+{
+    LSTMActivationFIndex = 0,
+    LSTMActivationGIndex = 1,
+    LSTMActivationHIndex = 2,
+    LSTMActivationCount = 3
+};
+
+enum {
+    LSTMPeepholeCountCiIndex = 0,
+    LSTMPeepholeCountCoIndex = 1,
+    LSTMPeepholeCountCfIndex = 2,
+    LSTMPeepholeCount = 3
+};
+
+typedef enum {
+    Forward,
+    Backward,
+} LSTMDirection;
+
+enum
+{
+    CNTKLSTMBiasIndex = 0,
+    CNTKLSTMWeightIndex = 1,
+    CNTKLSTMHiddenWeightIndex = 2
+};
+
+enum
+{
+    CNTKLSTMOutputYhIndex = 0,
+    CNTKLSTMOutputChIndex = 1
+};
+
+const string LSTMDirectionBidirection = "bidirectional";
+const string LSTMDirectionReverse = "reverse";
+const string LSTMDirectionForward = "forward";
+
+FunctionPtr CreateLSTM(const ONNXIR::Node *node, const std::vector<Variable> &inputs, const std::string &direction,
+    const std::vector<std::string> &activations, const std::vector<float> &activation_alpha, const std::vector<float> &activation_beta);
@@ -59,6 +59,9 @@ def verify_one_input(model, data, tmpdir, name):
     o0 = model.eval({model.arguments[0]:data})
     o1 = loaded_model.eval({loaded_model.arguments[0]:data})

+    if isinstance(o0, list):
+        o0 = o0[0]
+
     assert np.allclose(o0, o1)

     validation_filename = os.path.join(str(tmpdir), name + R'_validation.onnx')
@@ -515,6 +518,61 @@ def test_LRN(tmpdir):
     model = C.local_response_normalization(x_r, 2, 1.0, 0.0001, 0.75)
     verify_one_input(model, img, tmpdir, 'LRN_1')

+#LSTM
+from cntk.layers import *
+from itertools import product
+
+def CreateLSTMModel(activation,
+                    peepholes,
+                    self_stabilization,
+                    cell_dim,
+                    initial_state):
+    return Sequential([
+        Recurrence(LSTM(cell_dim,
+                        use_peepholes = peepholes,
+                        activation = activation,
+                        enable_self_stabilization = self_stabilization),
+                   initial_state = initial_state)
+    ])
+
+# LSTM attributes
+use_peepholes_options = [False]
+enable_self_stabilization_options = [False]
+activation_options = [C.tanh]
+
+# Recurrence attributes
+initial_state_options = [0]
+
+input_dim = 2
+cell_dim = 3
+batch_size = 1
+sequence_len = 5
+
+def MakeLSTMNameFromConfig(use_peepholes, enable_self_stabilization, initial_state, activation):
+    model_name = 'LSTM.' + activation.__name__
+    if (use_peepholes):
+        model_name += '.peephole'
+    if (enable_self_stabilization):
+        model_name += '.stabilize'
+    if (initial_state != 0):
+        model_name += '.initial'
+    return model_name
+
+def test_LSTM(tmpdir):
+    for config in list(product(use_peepholes_options, enable_self_stabilization_options,
+                               initial_state_options, activation_options)):
+        model_filename = MakeLSTMNameFromConfig(*config)
+        use_peepholes, enable_self_stabilization, initial_state, activation = config
+
+        x = C.input_variable(input_dim, dynamic_axes=[Axis.default_batch_axis(), C.Axis('sequenceAxis')])
+        LSTMmodel = CreateLSTMModel(peepholes = use_peepholes,
+                                    activation = activation,
+                                    initial_state = initial_state,
+                                    cell_dim = cell_dim,
+                                    self_stabilization = enable_self_stabilization)(x)
+        data = np.random.uniform(low=0.0, high=1.0, size=(batch_size, sequence_len, input_dim)).astype('f')
+        verify_one_input(LSTMmodel, data, tmpdir, model_filename)
+
 #MatMul
 def test_MatMul(tmpdir):
     data0 = np.asarray([[1,2],[3,4]], dtype=np.float32)