diff --git a/Source/CNTKv2LibraryDll/proto/onnx/CNTKToONNX.cpp b/Source/CNTKv2LibraryDll/proto/onnx/CNTKToONNX.cpp index 3124874c1..4f1abc6cf 100644 --- a/Source/CNTKv2LibraryDll/proto/onnx/CNTKToONNX.cpp +++ b/Source/CNTKv2LibraryDll/proto/onnx/CNTKToONNX.cpp @@ -107,7 +107,7 @@ namespace CNTK static ONNXIR::Node *InsertReshapeNodeToCNTKFunction(const FunctionPtr &src, ONNXIR::Node* node, const std::vector &shape, ONNXIR::Graph* graph); // - // Create a LSTM node. + // methods to create a RNN/LSTM/GRU node. // static ONNXIR::Node* CreateLSTMNode(const FunctionPtr& src, ONNXIR::Graph* graph, @@ -119,19 +119,29 @@ namespace CNTK std::unordered_map& functionNodes, std::unordered_map& variableNodes, const std::unordered_map& compositeOutputsMap); + static ONNXIR::Node *CreateRNNNode(const FunctionPtr &src, + ONNXIR::Graph* graph, + std::unordered_map& functionNodes, + std::unordered_map& variableNodes, + const std::unordered_map& compositeOutputsMap); static void PrepareRNNInput(const Variable &X, std::vector &nodeInputs); static void PrepareLSTMInitialStateNode(ONNXIR::Graph* graph, std::unordered_map& variableNodes, const std::vector &initialVariables, int batchSize, int cellSize, const std::string &uid, std::vector &nodeInputs); - static void PrepareGRUWeightNode(ONNXIR::Graph* graph, std::unordered_map& variableNodes, - const std::vector &Ws, std::vector &nodeInputs); + static void PrepareRNNWeightNode(ONNXIR::Graph* graph, std::unordered_map& variableNodes, + const std::vector &Ws, std::vector &nodeInputs, + std::function &srcTensors, + onnx::TensorProto& dst, const onnx::TypeProto &inputArgType)> weightConverter); static void PrepareGRUZRHWeightNode(ONNXIR::Graph* graph, std::unordered_map& variableNodes, const std::vector &Rs, const std::vector &Rh1s, std::vector &nodeInputs); static void PrepareGRUBiasNode(ONNXIR::Graph* graph, std::unordered_map& variableNodes, const std::vector &Bs, std::vector &nodeInputs); + static void PrepareRNNBiasNode(ONNXIR::Graph* graph, std::unordered_map& variableNodes, + const std::vector &Bs, std::vector &nodeInputs); + static void PrepareLSTMWeightNode(ONNXIR::Graph* graph, std::unordered_map& variableNodes, const std::vector &Ws, double *stabilizerConstants, std::vector &nodeInputs); static void PrepareLSTMBiasNode(ONNXIR::Graph* graph, std::unordered_map& variableNodes, @@ -157,7 +167,7 @@ namespace CNTK onnx::TensorProto& dst, const onnx::TypeProto &inputArgType); - static void CopyGRUBiasTensors(const std::vector &srcTensors, + static void CopyRNNBiasTensors(const std::vector &srcTensors, onnx::TensorProto& dst, const onnx::TypeProto &inputArgType); static void CopyGRUWeightTensors(const std::vector &srcTensors, @@ -167,6 +177,9 @@ namespace CNTK const std::vector &srcZRTensors, const std::vector &srcHTensors, onnx::TensorProto& dst, const onnx::TypeProto &inputArgType); + static void CopyRNNWeightTensors(const std::vector &srcTensors, + onnx::TensorProto& dst, const onnx::TypeProto &inputArgType); + static void FillTensorWithScalar(const std::vector& src, onnx::TensorProto& dst, const std::vector dstShape); // @@ -434,7 +447,7 @@ void AppendCNTKBiasWeightToONNXTensor(DType *data, const NDShape &shape, onnx::T row -= 2 * cell_size; } - // soruce is collmn major + // source is collmn major int src_index = row; if (typeid(DType) == typeid(float)) *(dst.mutable_float_data()->Add()) = (float)data[src_index]; @@ -502,7 +515,7 @@ void AppendCNTKWeightToONNXTensor(DType *data, const NDShape &shape, onnx::Tenso row -= 2 * cell_size; } - // soruce is collum 
major + // source is column major int src_index = LSTMWeightDimensionHiddenMultiplier * cell_size * col + row; if (typeid(DType) == typeid(float)) *(dst.mutable_float_data()->Add()) = (float)(data[src_index] * stabilizer); @@ -618,7 +631,7 @@ void CNTKToONNXHelper::CopyTensorsWithMultipliers(const std::vector &srcTensors, +void CNTKToONNXHelper::CopyRNNBiasTensors(const std::vector &srcTensors, onnx::TensorProto& dst, const onnx::TypeProto &inputArgType) { if (srcTensors.empty()) @@ -688,7 +701,7 @@ void CNTKToONNXHelper::CopyGRUWeightTensors(const std::vector &s int row = targetIndex / input_size, col = targetIndex % input_size; - // soruce is collum major + // source is column major int srcIndex = 3 * cell_size * col + row; AddDataElementArrayViewToTensorProto(srcTemp, srcIndex, dst); } @@ -755,6 +768,42 @@ void CNTKToONNXHelper::CopyGRUStateWeightTensors( CopyShapeTypeProtoToTensorProto(inputArgType, dst); } +void CNTKToONNXHelper::CopyRNNWeightTensors(const std::vector &srcTensors, + onnx::TensorProto& dst, const onnx::TypeProto &inputArgType) +{ + if (srcTensors.empty()) + return; + + DataType dataType = srcTensors[0]->GetDataType(); + SetTensorType(dst, dataType); + + for (int i = 0; i < srcTensors.size(); i++) + { + auto srcTemp = srcTensors[i]->DeepClone(); + auto srcShape = srcTemp->Shape(); + + int cell_size = srcShape[0]; + int input_size = srcShape[1]; + + // This is our own copy so move it to the CPU. + srcTemp->ChangeDevice(DeviceDescriptor::CPUDevice()); + + auto totalSize = srcShape.TotalSize(); + for (size_t targetIndex = 0; targetIndex < totalSize; targetIndex++) + { + // row major layout + int row = targetIndex / input_size, + col = targetIndex % input_size; + + // source is column major + int srcIndex = cell_size * col + row; + AddDataElementArrayViewToTensorProto(srcTemp, srcIndex, dst); + } + } + + CopyShapeTypeProtoToTensorProto(inputArgType, dst); +} + void CNTKToONNXHelper::CopyTensor(const NDArrayViewPtr src, onnx::TensorProto& dst, onnx::TypeProto *inputArgType /*=nullptr*/) { auto dataType = src->GetDataType(); @@ -1496,31 +1545,7 @@ ONNXIR::Node* CNTKToONNXHelper::CreateLSTMNode(const FunctionPtr &src, std::unordered_map& variableNodes, const std::unordered_map& compositeOutputsMap) { - // sanity check: - std::vector lstms; - if (src->OpName() == L"LSTM") - { - lstms.push_back(src); - } - else if (src->OpName() == L"Splice") // src is a Splice op with inputs from two LSTM ops. - { - for (auto &input : src->Inputs()) - { - lstms.push_back(input.Owner()); - } - } - else - { - LogicError("An LSTM op should start with an LSTM op (single direction) or a Splice op (bidirectional)."); - } - - // For single direction LSTM, lstms.size() == 1. For bidirectional LSTM, lstms.size() == 2. - // It is an error otherwise. 
- if (lstms.size() == 0 || lstms.size() > 2 || - std::any_of(lstms.cbegin(), lstms.cend(), [](const FunctionPtr &f) {return f->OpName() != L"LSTM"; })) - { - LogicError("Invalid number of LSTM ops to construct an ONNX LSTM node."); - } + std::vector lstms = GetRNNBlocksFromSingleOrBidirectionalRNN(src, "LSTM"); // order forward, backward std::map directionCount({ { RNNDirection::Forward, 0 } ,{ RNNDirection::Backward, 0 } }); @@ -1789,7 +1814,7 @@ void CNTKToONNXHelper::PrepareGRUBiasNode(ONNXIR::Graph* graph, std::unordered_m onnx::TensorProto dstTensor; - CopyGRUBiasTensors(srcTensors, dstTensor, inputArgType); + CopyRNNBiasTensors(srcTensors, dstTensor, inputArgType); variableNode->AddAttribute("value", dstTensor); nodeInputs.push_back(inputArg); @@ -1828,8 +1853,11 @@ void CNTKToONNXHelper::PrepareGRUZRHWeightNode(ONNXIR::Graph* graph, std::unorde variableNodes.emplace(Rzrs[0], variableNode); } -void CNTKToONNXHelper::PrepareGRUWeightNode(ONNXIR::Graph* graph, std::unordered_map& variableNodes, - const std::vector &Ws, std::vector &nodeInputs) + +void CNTKToONNXHelper::PrepareRNNWeightNode(ONNXIR::Graph* graph, std::unordered_map& variableNodes, + const std::vector &Ws, std::vector &nodeInputs, + std::function &srcTensors, + onnx::TensorProto& dst, const onnx::TypeProto &inputArgType)> weightConverter) { // TODO: sanity check for all variables to have the same shape and data types. bool doReverseVec = false; @@ -1852,7 +1880,7 @@ void CNTKToONNXHelper::PrepareGRUWeightNode(ONNXIR::Graph* graph, std::unordered onnx::TensorProto dstTensor; - CopyGRUWeightTensors(srcTensors, dstTensor, inputArgType); + weightConverter(srcTensors, dstTensor, inputArgType); variableNode->AddAttribute("value", dstTensor); nodeInputs.push_back(inputArg); @@ -1865,31 +1893,7 @@ ONNXIR::Node *CNTKToONNXHelper::CreateGRUNode(const FunctionPtr &src, std::unordered_map& variableNodes, const std::unordered_map& compositeOutputsMap) { - // sanity check: - std::vector grus; - if (src->OpName() == L"GRU") - { - grus.push_back(src); - } - else if (src->OpName() == L"Splice") // src is a Splice op with inputs from two LSTM ops. - { - for (auto &input : src->Inputs()) - { - grus.push_back(input.Owner()); - } - } - else - { - LogicError("An GRU op should start with an GRU op (single direction) or a Splice op (bidirectional)."); - } - - // For single direction GRU, grus.size() == 1. For bidirectional GRU, grus.size() == 2. - // It is an error otherwise. 
- if (grus.size() == 0 || grus.size() > 2 || - std::any_of(grus.cbegin(), grus.cend(), [](const FunctionPtr &f) {return f->OpName() != L"GRU"; })) - { - LogicError("Invalid number of GRU ops to construct an ONNX GRU node."); - } + std::vector grus = GetRNNBlocksFromSingleOrBidirectionalRNN(src, "GRU"); // order forward, backward std::map directionCount({ { RNNDirection::Forward, 0 } ,{ RNNDirection::Backward, 0 } }); @@ -1932,7 +1936,6 @@ ONNXIR::Node *CNTKToONNXHelper::CreateGRUNode(const FunctionPtr &src, initialHs[directionIndex] = initStateH; - activations[directionIndex * GRUActivationCount + GRUActivationFIndex] = f_activation; activations[directionIndex * GRUActivationCount + GRUActivationGIndex] = g_activation; @@ -1978,7 +1981,7 @@ ONNXIR::Node *CNTKToONNXHelper::CreateGRUNode(const FunctionPtr &src, // inputs std::vector nodeInputs; PrepareRNNInput(Xs[0], nodeInputs); - PrepareGRUWeightNode(graph, variableNodes, Ws, nodeInputs); + PrepareRNNWeightNode(graph, variableNodes, Ws, nodeInputs, CopyGRUWeightTensors); PrepareGRUZRHWeightNode(graph, variableNodes, Rzrs, Rhs, nodeInputs); { @@ -2071,6 +2074,219 @@ ONNXIR::Node *CNTKToONNXHelper::CreateGRUNode(const FunctionPtr &src, return squeezedLSTMNode; } +void CNTKToONNXHelper::PrepareRNNBiasNode(ONNXIR::Graph* graph, std::unordered_map& variableNodes, + const std::vector &Bs, std::vector &nodeInputs) +{ + // TODO: sanity check for all variables to have the same shape and data types. + bool doReverseVec = false; + int numDirections = Bs.size(); + int hiddenSize = Bs[0].Shape()[0]; + + std::vector shape({ numDirections, 2 * hiddenSize }); + + // ONNX GRU spec has 2 bias, for forward and backward. + onnx::TypeProto inputArgType = ToTypeProto(shape, doReverseVec); + UpdateONNXType(Bs[0].GetDataType(), inputArgType); + ONNXIR::NodeArg inputArg(ToString(Bs[0].Uid()), &inputArgType); + std::vector varOutputs({ inputArg }); + std::vector varInputs; + std::string inputName = inputArg.Name(); + ONNXIR::Node* variableNode = graph->AddNode(inputName, "Constant", "", varInputs, varOutputs); + + std::vector srcTensors; + for (int i = 0; i < Bs.size(); i++) + { + const Variable &variable = Bs[i]; + srcTensors.push_back(variable.IsParameter() ? Parameter(variable).Value() : Constant(variable).Value()); + } + + onnx::TensorProto dstTensor; + + CopyRNNBiasTensors(srcTensors, dstTensor, inputArgType); + variableNode->AddAttribute("value", dstTensor); + nodeInputs.push_back(inputArg); + + variableNodes.emplace(Bs[0], variableNode); +} + + +ONNXIR::Node *CNTKToONNXHelper::CreateRNNNode(const FunctionPtr &src, + ONNXIR::Graph* graph, + std::unordered_map& functionNodes, + std::unordered_map& variableNodes, + const std::unordered_map& compositeOutputsMap) +{ + std::vector rnns = GetRNNBlocksFromSingleOrBidirectionalRNN(src, "RNNStep"); + + // order forward, backward + std::map directionCount({ { RNNDirection::Forward, 0 } ,{ RNNDirection::Backward, 0 } }); + + // The following construct refers to ONNX spec: + // https://github.com/onnx/onnx/blob/master/docs/Operators.md#lstm + // specifically, for attrubute and variable dimension. + // We use the term from the spec as possible as we can to maintain a close correlation + // to the ONNX specification. + + int num_directions = rnns.size(); + // A list of 3 (or 6 if bidirectional) activation functions for input, output, forget, cell, and hidden. + std::vector activations(num_directions); + + // TODO: + // In principle all these variables shall be treated as either constant or op output. 
+ // In reality except X, all other inputs to LSTM can be treated as constant. + std::vector Xs(num_directions), Ws(num_directions), Rs(num_directions), + Bs(num_directions), initialHs(num_directions); + + std::vector Yhs(rnns.size()); + + for (std::vector::const_iterator itRNNBlock = rnns.cbegin(); itRNNBlock != rnns.cend(); itRNNBlock++) + { + // src has to be an RNN node. + const FunctionPtr& rnn = *itRNNBlock; + std::vector inputs = rnn->Inputs(); + if (inputs.size() != CNTKRNNInputCount) + LogicError("A RNN block does not have expected input count (%d). Actual input count is %d", (int)CNTKRNNInputCount, (int)inputs.size()); + + string activation; + RNNDirection direction; + Variable initStateH; + TraceRNNPathes(rnn, activation, direction, initStateH); + + directionCount[direction]++; + + int directionIndex = rnns.size() == 1 ? 0 : (direction ? 1 : 0); + + initialHs[directionIndex] = initStateH; + + activations[directionIndex] = activation; + + Xs[directionIndex] = inputs[CNTKRNNInputIndex]; + + Ws[directionIndex] = inputs[CNTKRNNWeightIndex]; + + Rs[directionIndex] = inputs[CNTKRNNHweightIndex]; + + Bs[directionIndex] = inputs[CNTKRNNBiasIndex]; + + std::vector outputs = rnn->Outputs(); + + Yhs[directionIndex] = outputs[CNTKRNNOutputYhIndex]; + } + + SanityCheckForConstantOrParameters(Ws); + SanityCheckForConstantOrParameters(Rs); + SanityCheckForConstantOrParameters(Bs); + + // ensure that if there is one direction, it is not backward. + // if there two directions, they are forward and backward, and + // that the inputs (Xs) are the same. + if (std::any_of(directionCount.begin(), directionCount.end(), [](std::map::value_type &v) {return v.second > 1; })) + { + LogicError("RNN node is invalid because there should be no more than one path in each direction."); + } + if (rnns.size() == 2 && Xs[0] != Xs[1]) + { + LogicError("Bi-directional RNN node is invalid because the two RNN nodes do not share one same input."); + } + + string direction = DeriveDirectionString(rnns, directionCount); + + // an RNN output size is the hidden size + int hidden_size = rnns[0]->Outputs()[0].Shape()[0]; + + // inputs + std::vector nodeInputs; + PrepareRNNInput(Xs[0], nodeInputs); + PrepareRNNWeightNode(graph, variableNodes, Ws, nodeInputs, CopyRNNWeightTensors); + PrepareRNNWeightNode(graph, variableNodes, Rs, nodeInputs, CopyRNNWeightTensors); + + { + bool hasBias = std::all_of(Bs.begin(), Bs.end(), [](Variable &v) {return v.IsInitialized(); }); + if (hasBias) + { + PrepareRNNBiasNode(graph, variableNodes, Bs, nodeInputs); + } + else + { + AddEmptyInput(nodeInputs); + } + + { + // sequence_lens is not supported + AddEmptyInput(nodeInputs); + } + + bool has_initial_h = std::all_of(initialHs.begin(), initialHs.end(), [](Variable &v) {return v.IsInitialized(); }); + if (has_initial_h) + { + std::string hiddenUid = ToString(Yhs[0].Uid()) + "_initial_h"; + PrepareLSTMInitialStateNode(graph, variableNodes, initialHs, FreeBatchSize, hidden_size, hiddenUid, nodeInputs); + } + else + { + AddEmptyInput(nodeInputs); + } + } + + const int output_sequence = RNNOutputSequence; // RNN in CNTK always output full sequence of output + std::vector nodeOutputs; + { + if (output_sequence == 1) + { + std::string nodeName; + if (rnns.size() == 1) + nodeName = ToString(Yhs[0].Uid()); + else + nodeName = ToString(src->Output().Uid()); + + auto outputArgType = ToTypeProto(std::vector({ FreeSequenceLen, (int)Yhs.size(), FreeBatchSize, (int)Yhs[0].Shape()[0] }), false); + UpdateONNXType(Yhs[0].GetDataType(), outputArgType); + 
ONNXIR::NodeArg outputArg(nodeName, &outputArgType); + nodeOutputs.push_back(outputArg); + } + else + { + ONNXIR::NodeArg outputArg("", nullptr); + nodeOutputs.push_back(outputArg); + } + + { + Variable Yh = Yhs[0]; + std::string nodeName = ToString(Yh.Uid()) + "_h"; + + const int batchSize = 1; + const bool doReverseVec = false; + auto outputArgType = ToTypeProto(std::vector({ (int)Yhs.size(), batchSize, (int)Yh.Shape()[0] }), doReverseVec); + UpdateONNXType(Yh.GetDataType(), outputArgType); + ONNXIR::NodeArg outputArg(nodeName, &outputArgType); + nodeOutputs.push_back(outputArg); + } + } + + if (Xs[0].Owner().get() != nullptr) + CreateNode(Xs[0].Owner(), graph, functionNodes, variableNodes, compositeOutputsMap); + + auto nodeName = src->Name().empty() ? ToString(src->Uid()) : ToString(src->Name()); + ONNXIR::Node *rnnNode = graph->AddNode(nodeName, "RNN", "", nodeInputs, nodeOutputs); + + rnnNode->AddAttribute("activations", activations); + rnnNode->AddAttribute("direction", direction); + rnnNode->AddAttribute("hidden_size", (int64_t)hidden_size); + rnnNode->AddAttribute("output_sequence", (int64_t)output_sequence); + + //// TODO: make bidirectional RNN work by figuring out output data + //// layout transpose in InsertReshapeNodeToCNTKFunction. + if (rnns.size() == 2) + NOT_IMPLEMENTED; + + //// TODO: uncomment this code once LotusRT output shape matches ONNX + //// squeeze direction axis out. This is safe because it is not bi-directional node. + std::vector shape({ FreeSequenceLen, 1, hidden_size }); + ONNXIR::Node *squeezedRNNNode = InsertReshapeNodeToCNTKFunction(src, rnnNode, shape, graph); + functionNodes.emplace(src, squeezedRNNNode); + return squeezedRNNNode; +} + ONNXIR::Node *CNTKToONNXHelper::AddReshapeNode(const ONNXIR::NodeArg &nodeArg, const std::vector &newShape, const std::string &outArgName, ONNXIR::Graph* graph) { ONNXIR::NodeArg outputArg(outArgName, nullptr); @@ -2176,7 +2392,11 @@ ONNXIR::Node* CNTKToONNXHelper::CreateNode(const FunctionPtr& src, // return CreateLSTMNode(src, graph, functionNodes, variableNodes, compositeOutputsMap); //} //else - if (opName == "GRU") + if (opName == "RNNStep") + { + return CreateRNNNode(src, graph, functionNodes, variableNodes, compositeOutputsMap); + } + else if (opName == "GRU") { return CreateGRUNode(src, graph, functionNodes, variableNodes, compositeOutputsMap); } diff --git a/Source/CNTKv2LibraryDll/proto/onnx/ONNXToCNTK.cpp b/Source/CNTKv2LibraryDll/proto/onnx/ONNXToCNTK.cpp index df2b10140..f52545fee 100644 --- a/Source/CNTKv2LibraryDll/proto/onnx/ONNXToCNTK.cpp +++ b/Source/CNTKv2LibraryDll/proto/onnx/ONNXToCNTK.cpp @@ -641,7 +641,7 @@ std::vector CreateRNNConstant( { case LSTMInputIndexX: // X, should not come to here - return inputs; + CNTK::LogicError("input to a recurrent node shall not be a constant"); case LSTMInputIndexW: case LSTMInputIndexH: // W, R: @@ -659,7 +659,7 @@ std::vector CreateRNNConstant( for (int dir = 0; dir < num_directions; dir++) { - std::string nodeName = name + (index == 1 ? "_W_" : "_R_") + (char)dir; + std::string nodeName = name + (index == 1 ? "_W_" : "_R_") + (char)('0' + dir); int totalSizePerDirection = rows * cols; // TODO: what about double? 
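Note on the W/R layout conversion used on both sides of this patch: CopyRNNWeightTensors (export) and the W/R cases of CreateRNNConstant (import) translate between CNTK's column-major parameter storage and the row-major [num_directions, hidden_size, input_size] tensors of the ONNX RNN spec. The index arithmetic (srcIndex = cell_size * col + row on export, targetIndex = col * cell_size + row on import) amounts to a transpose plus flat copy. A minimal numpy sketch of that equivalence, with illustrative names (W_onnx, cntk_buffer) and sizes that are not taken from the code:

import numpy as np

cell_size, input_size = 3, 2   # hidden and input dimensions (made up for the sketch)
W_onnx = np.arange(cell_size * input_size, dtype=np.float32).reshape(cell_size, input_size)

# Import direction (CreateRNNConstant): ONNX row-major -> CNTK column-major buffer.
cntk_buffer = np.empty(cell_size * input_size, dtype=np.float32)
for count in range(cell_size * input_size):
    row, col = divmod(count, input_size)            # row-major source index
    cntk_buffer[col * cell_size + row] = W_onnx.flat[count]

# The loop is just a Fortran-order (column-major) flatten of the same matrix.
assert np.array_equal(cntk_buffer, W_onnx.flatten(order='F'))

# Export direction (CopyRNNWeightTensors) applies the inverse mapping: srcIndex = cell_size * col + row.

The sketch only illustrates element ordering; how the resulting buffer is shaped on the CNTK side is decided by the surrounding NDShape/ToTypeProto calls in the patch.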
@@ -706,7 +706,7 @@ std::vector CreateRNNConstant( NDShape weightShape({ (size_t)(4 * cell_size) }); for (int dir = 0; dir < num_directions; dir++) { - std::string nodeName = name + std::string(1, (char)dir) + LSTMInputBiasNameHint; + std::string nodeName = name + std::string(1, (char)('0' + dir)) + LSTMInputBiasNameHint; int totalSizePerDirection = 4 * cell_size; float *data = new float[totalSizePerDirection]; for (size_t targetIndex = 0; targetIndex < totalSizePerDirection; targetIndex++) @@ -726,7 +726,7 @@ std::vector CreateRNNConstant( row -= 2 * cell_size; } - // soruce is collmn major + // source is column major int src_index = row; // "fuse" data[targetIndex] = @@ -760,7 +760,7 @@ std::vector CreateRNNConstant( NDShape weightShape({ (size_t)(cell_size) }); for (int dir = 0; dir < num_directions; dir++) { - std::string nodeName = name + std::string(1, (char)dir); + std::string nodeName = name + std::string(1, (char)('0' + dir)); if (index == 5) nodeName += LSTMInputInitialHNameHint; else @@ -787,7 +787,7 @@ std::vector CreateRNNConstant( for (int i = 0; i < 3; i++) { std::string nodeName = name + ((i == 0) ? "_i" : ((i == 1) ? "_o" : "_f")) + - std::string(1, (char)dir) + LSTMInputPeepholeNameHint; + std::string(1, (char)('0' + dir)) + LSTMInputPeepholeNameHint; float *data = new float[cell_size]; NDShape weightShape({ (size_t)(cell_size) }); for (size_t targetIndex = 0; targetIndex < cell_size; targetIndex++) @@ -800,9 +800,8 @@ std::vector CreateRNNConstant( } return inputs; } - break; default: - CNTK::LogicError("CreateRNNConstant received unepxpeted index: %d", index); + CNTK::LogicError("CreateRNNConstant received unexpected index: %d", index); } } else if (parentONNXOpName == "GRU") @@ -812,7 +811,7 @@ std::vector CreateRNNConstant( { case GRUInputIndexX: // X, should not come to here - return inputs; + CNTK::LogicError("input to a recurrent node shall not be a constant"); case GRUInputIndexW: { // see ONNX spec for the tensor shape @@ -828,7 +827,7 @@ std::vector CreateRNNConstant( for (int dir = 0; dir < num_directions; dir++) { - std::string nodeName = name + "_W_" + (char)dir; + std::string nodeName = name + "_W_" + (char)('0' + dir); int totalSizePerDirection = rows * cols; // TODO: what about double? @@ -863,8 +862,8 @@ std::vector CreateRNNConstant( inputs.resize(num_directions * 2); for (int dir = 0; dir < num_directions; dir++) { - std::string hNodeName = name + "_H_" + (char)dir; - std::string h1NodeName = name + "_H1_" + (char)dir; + std::string hNodeName = name + "_H_" + (char)('0' + dir); + std::string h1NodeName = name + "_H1_" + (char)('0' + dir); int totalSizePerDirection = rows * cols; float *hData = new float[hShape.TotalSize()]; @@ -900,18 +899,18 @@ std::vector CreateRNNConstant( // see ONNX spec for the tensor shape int num_directions = valueProto.dims(0); int cell_size = valueProto.dims(1) / GRUBiasDimensionHiddenMultiplier; - // shape size is devided by 2 so that it only applies to input (CNTK) + // shape size is divided by 2 so that it only applies to input (CNTK) // TODO: this incompatibility needs further investigation. 
NDShape weightShape({ (size_t)(GRUBiasDimensionHiddenMultiplier / 2 * cell_size) }); for (int dir = 0; dir < num_directions; dir++) { - std::string nodeName = name + std::string(1, (char)dir) + LSTMInputBiasNameHint; + std::string nodeName = name + std::string(1, '0' + dir) + LSTMInputBiasNameHint; int totalSizePerDirection = GRUBiasDimensionHiddenMultiplier / 2 * cell_size; float *data = new float[totalSizePerDirection]; for (size_t targetIndex = 0; targetIndex < totalSizePerDirection; targetIndex++) { int row = targetIndex; - // soruce is collmn major + // source is column major int src_index = row; // "fuse" data[targetIndex] = @@ -934,7 +933,7 @@ std::vector CreateRNNConstant( NDShape weightShape({ (size_t)(cell_size) }); for (int dir = 0; dir < num_directions; dir++) { - std::string nodeName = name + std::string(1, (char)dir) + LSTMInputInitialHNameHint; + std::string nodeName = name + std::string(1, (char)('0' + dir)) + LSTMInputInitialHNameHint; float *data = new float[cell_size]; for (size_t targetIndex = 0; targetIndex < cell_size; targetIndex++) @@ -947,10 +946,113 @@ std::vector CreateRNNConstant( } return inputs; } - break; - return inputs; default: - CNTK::LogicError("CreateRNNConstant for GRU op received unepxpeted index: %d", index); + CNTK::LogicError("CreateRNNConstant for GRU op received unexpected index: %d", index); + } + } + else if (parentONNXOpName == "RNN") + { + // https://github.com/onnx/onnx/blob/master/docs/Operators.md#inputs-3---6-1 + switch (index) + { + case RNNInputIndexX: + // X, should not come to here + CNTK::LogicError("input to a recurrent node shall not be a constant"); + case RNNInputIndexW: + case RNNInputIndexR: + { + // see ONNX spec for the tensor shape + int num_directions = valueProto.dims(0); + size_t rows = valueProto.dims(1); + size_t cols = valueProto.dims(2); + + // CNTK cpp requires shape: (input_size, 3 * hidden_size) + NDShape weightShape({ rows, cols }); + + int input_size = cols; + int cell_size = rows; + + for (int dir = 0; dir < num_directions; dir++) + { + std::string nodeName = name + (index == RNNInputIndexW ? "_W_" : "_R_") + (char)('0' + dir); + int totalSizePerDirection = rows * cols; + + // TODO: what about double? + float *data = new float[totalSizePerDirection]; + for (size_t count = 0; count < totalSizePerDirection; count++) + { + int row = count / input_size; + int col = count % input_size; + int sourceIndex = dir * totalSizePerDirection + count; + int targetIndex = col * cell_size + row; + data[targetIndex] = valueProto.float_data()[sourceIndex]; + } + + Constant constant = CreateConstantWithRawData(&data[0], weightShape, nodeName, computeDevice); + inputs.push_back(constant); + } + return inputs; + } + case RNNInputIndexB: + // B + { + // see ONNX spec for the tensor shape: + // https://github.com/onnx/onnx/blob/master/docs/Operators.md#inputs-3---6-1 + // shape of bias is [num_directions, 2*hidden_size] thus we divide dim(1) by 2 + // to get cell_size. + int num_directions = valueProto.dims(0); + int cell_size = valueProto.dims(1) / 2; + NDShape weightShape({ (size_t)(cell_size) }); + for (int dir = 0; dir < num_directions; dir++) + { + std::string nodeName = name + std::string(1, '0' + dir) + LSTMInputBiasNameHint; + int totalSizePerDirection = cell_size; + float *data = new float[totalSizePerDirection]; + for (size_t targetIndex = 0; targetIndex < totalSizePerDirection; targetIndex++) + { + int row = targetIndex; + // source is column major + int src_index = row; + // "fuse" + // RNN only has one bias vector. 
It is applied after element-wise addition + // of projected input and hidden states. Therefore we need to fuse two biases + // in ONNX into one. + // RNNBiasMultiplier = 2 + data[targetIndex] = + valueProto.float_data()[dir * RNNBiasMultiplier * totalSizePerDirection + src_index] + + valueProto.float_data()[dir * RNNBiasMultiplier * totalSizePerDirection + totalSizePerDirection + src_index]; + } + + Constant constant = CreateConstantWithRawData(data, weightShape, nodeName, computeDevice); + inputs.push_back(constant); + } + return inputs; + } + case RNNInputIndexSequenceLens: + return inputs; + case RNNInitialH: + { + // initial_h + int num_directions = valueProto.dims(0); + int cell_size = valueProto.dims(2); + NDShape weightShape({ (size_t)(cell_size) }); + for (int dir = 0; dir < num_directions; dir++) + { + std::string nodeName = name + std::string(1, (char)('0' + dir)) + LSTMInputInitialHNameHint; + + float *data = new float[cell_size]; + for (size_t targetIndex = 0; targetIndex < cell_size; targetIndex++) + { + data[targetIndex] = valueProto.float_data()[dir * cell_size + targetIndex]; + } + + Constant constant = CreateConstantWithRawData(data, weightShape, nodeName, computeDevice); + inputs.push_back(constant); + } + return inputs; + } + default: + CNTK::LogicError("CreateRNNConstant for GRU op received unexpected index: %d", index); } } else @@ -1065,7 +1167,37 @@ std::vector ONNXToCNTKHelper::CreateRNNLeafVariableOrConstant(const No case GRUInitialH: NOT_IMPLEMENTED; default: - LogicError("LSTM node has unexpected input"); + LogicError("GRU node has unexpected input"); + } + } + else if (parentONNXOpName == "RNN") + { + int inputIndex = CalculateNodeArgInputIndex(nodeArg, parentNode); + switch (inputIndex) + { + case GRUInputIndexX: + // X: `[seq_length, batch_size, input_size]`. 
+ { + Variable inputVariable; + if (constructedNodeArgVariableMap.find(nodeArg->Name()) == constructedNodeArgVariableMap.end()) + { + DataType dataType = FromONNXType(nodeArg->ToProto().type()); + int input_size = shapeProto->dim(2).dim_value(); + NDShape shape({ (size_t)(input_size) }); + inputVariable = InputVariable(shape, dataType, ToWString(nodeArg->Name()), dynamicAxes); + constructedNodeArgVariableMap.insert(ONNXToCNTKVariableMap::value_type(nodeArg->Name(), inputVariable)); + } + return std::vector({ constructedNodeArgVariableMap[nodeArg->Name()] }); + } + // other inputs shall be ONNX constant node and be created as CNTK Constant in CreateRNNConstant + case GRUInputIndexW: + case GRUInputIndexR: + case GRUInputIndexB: + case GRUInputIndexSequenceLens: + case GRUInitialH: + NOT_IMPLEMENTED; + default: + LogicError("RNN node has unexpected input"); } } else @@ -1652,6 +1784,15 @@ FunctionPtr ONNXToCNTKHelper::CreateFunction(const Node *node, const std::vector std::vector({ "Sigmoid", "Tanh" })); return CreateGRU(node, inputs, direction, activations, activation_alpha, activation_beta); } + else if (onnxOpName == "RNN") + { + const string direction = GetNamedAttributeAsString(node, "direction"); + std::vector activation_alpha = GetNamedAttributeAsFloatVec(node, "activation_alpha", std::vector()); + std::vector activation_beta = GetNamedAttributeAsFloatVec(node, "activation_beta", std::vector()); + const std::vector activations = GetNamedAttributeAsStringVec(node, "activations", + std::vector({ "Tanh" })); + return CreateRNN(node, inputs, direction, activations, activation_alpha, activation_beta); + } if (onnxOpName == "FC") { return CreateCNTKFCNode(ToWString(node->Name()), inputs); diff --git a/Source/CNTKv2LibraryDll/proto/onnx/Operators.cpp b/Source/CNTKv2LibraryDll/proto/onnx/Operators.cpp index 722250d86..dc0ad725d 100644 --- a/Source/CNTKv2LibraryDll/proto/onnx/Operators.cpp +++ b/Source/CNTKv2LibraryDll/proto/onnx/Operators.cpp @@ -448,7 +448,7 @@ namespace ONNX bool Operators::IsRNNOp(const std::string &opName) { - return opName == "LSTM" || opName == "GRU" || opName == "RNN"; + return opName == "LSTM" || opName == "GRU" || opName == "RNN" || opName == "RNNStep"; } std::unordered_map> Operators::_cntkBlockOPInvalidIndices = { { L"Clip",{ 1, 2 } }, diff --git a/Source/CNTKv2LibraryDll/proto/onnx/RNNHelper.cpp b/Source/CNTKv2LibraryDll/proto/onnx/RNNHelper.cpp index a45cd5bd2..d8ea24c0b 100644 --- a/Source/CNTKv2LibraryDll/proto/onnx/RNNHelper.cpp +++ b/Source/CNTKv2LibraryDll/proto/onnx/RNNHelper.cpp @@ -122,7 +122,7 @@ std::tuple, std::function &activations, const std::vector &activation_alpha, const std::vector &activation_beta, int direction) { if (activations.size() < (direction + 1) * LSTMActivationCount) - CNTK::LogicError("LSTM activations shall be %d or %d of strings", LSTMActivationCount, LSTMActivationCount * 2); + CNTK::LogicError("LSTM activations shall be a list of strings of size %d or %d ", LSTMActivationCount, LSTMActivationCount * 2); // int iofActivationIndex = direction * LSTMActivationCount + LSTMActivationFIndex; @@ -162,7 +162,7 @@ std::tuple, std::function &activations, const std::vector &activation_alpha, const std::vector &activation_beta, int direction) { if (activations.size() < (direction + 1) * GRUActivationCount) - CNTK::LogicError("LSTM activations shall be %d or %d of strings", GRUActivationCount, GRUActivationCount * 2); + CNTK::LogicError("GRU activations shall be a list of strings of size %d or %d", GRUActivationCount, GRUActivationCount * 2); 
// int fActivationIndex = direction * GRUActivationCount + GRUActivationFIndex; @@ -190,6 +190,34 @@ GetGRUActivations(const std::vector &activations, const std::vector return std::make_tuple(fActivationOp, gActivationOp); } +std::function +GetRNNActivations(const std::vector &activations, const std::vector &activation_alpha, const std::vector &activation_beta, int direction) +{ + if (activations.size() < (direction + 1)) + CNTK::LogicError("RNN activations shall be a list of strings of size 1 or 2"); + + // + int activationIndex = direction; + + bool hasAlpha = activation_alpha.size() == (direction + 1); + bool hasAlphaBeta = hasAlpha && activation_beta.size() == (direction + 1); + std::function activationOp; + if (hasAlphaBeta) + { + activationOp = ActivationMap(activations[activationIndex], activation_alpha[activationIndex], activation_beta[activationIndex]); + } + else if (hasAlpha) + { + activationOp = ActivationMap(activations[activationIndex], activation_alpha[activationIndex]); + } + else + { + activationOp = ActivationMap(activations[activationIndex]); + } + + return activationOp; +} + std::pair LSTMPCell(Variable input, const std::function &iofActivationOp, const std::function &cellActivationOp, @@ -285,6 +313,20 @@ FunctionPtr GRUCell(Variable input, return ht; } +FunctionPtr RNNCell(Variable input, + const std::function &activationOp, + Variable prevOutput, + Constant &W, Constant &R, Constant &B) +{ + FunctionPtr proj = Times(W, input) + Times(R, prevOutput);; + if (B.IsInitialized()) + proj = B + proj; + + FunctionPtr h = activationOp(proj); + return h; +} + + #include "PrimitiveFunction.h" #include "BlockFunction.h" @@ -332,8 +374,30 @@ FunctionPtr GRUComponent(Variable input, auto actualDh = recurrenceHookH(gruCell); - gruCell->ReplacePlaceholders({ { inputPlaceholder , input },{ dh, actualDh } }); - return gruCell; + gruCell->ReplacePlaceholders({ { dh, actualDh } }); + + auto gruBlock = AsBlock(std::move(gruCell), { { inputPlaceholder , input } }, L"GRU", L""); + return gruBlock; +} + +FunctionPtr RNNComponent(Variable input, + const NDShape& cellShape, + const std::function &activationOp, + const std::function& recurrenceHookH, + Constant &W, Constant &R, Constant &B) +{ + auto dh = PlaceholderVariable(cellShape, input.DynamicAxes()); + auto inputPlaceholder = PlaceholderVariable(input.Shape(), input.DynamicAxes()); + + auto rnnCell = RNNCell( + inputPlaceholder, + activationOp, + dh, W, R, B); + + auto actualDh = recurrenceHookH(rnnCell); + + rnnCell->ReplacePlaceholders({ { inputPlaceholder , input },{ dh, actualDh } }); + return rnnCell; } const std::vector FindByNameHint(const std::vector &inputs, const std::string &hint) @@ -511,6 +575,56 @@ FunctionPtr CreateGRU(const ONNXIR::Node *node, const std::vector &inp } } +FunctionPtr CreateRNN(const ONNXIR::Node *node, const std::vector &inputs, const std::string &direction, + const std::vector &activations, const std::vector &activation_alpha, const std::vector &activation_beta) +{ + int numDirections = direction == RNNDirectionBidirection ? 2 : 1; + std::vector outputHs; + for (int dir = 0; dir < numDirections; dir++) + { + std::function activationOp = + GetRNNActivations(activations, activation_alpha, activation_beta, dir); + + // the first a few inputs are (in order): X, numDirections * W, numDirections * R, numDirections * H1 + Variable X = inputs[0]; + Variable W = inputs[1 * numDirections + dir - ((numDirections == 2) ? 1 : 0)]; + Variable R = inputs[2 * numDirections + dir - ((numDirections == 2) ? 
1 : 0)]; + Variable B; + std::vector biasVariables = FindByNameHint(inputs, LSTMInputBiasNameHint); + if (numDirections == 1 && biasVariables.size() >= 1) + B = biasVariables[0]; + else if (numDirections == 2 && biasVariables.size() == 2) + B = biasVariables[1]; + + Variable initHVariable = GetInitialStateVariable(inputs, numDirections, GRUInputInitialHNameHint, X.GetDataType()); + + int hiddenDim = W.Shape()[0]; + + FunctionPtr outputH; + + // if it is bidirectional LSTM, the second one will be the backword one. + bool go_backwards = direction == RNNDirectionReverse || (numDirections == 2 && dir == 1); + + std::function recurrenceHook; + if (go_backwards) + recurrenceHook = [initHVariable](const Variable& x) { return FutureValue(x, initHVariable); }; + else + recurrenceHook = [initHVariable](const Variable& x) { return PastValue(x, initHVariable); }; + + outputH = RNNComponent( + X, { (size_t)hiddenDim }, activationOp, + recurrenceHook, (Constant &)W, (Constant &)R, (Constant &)B); + outputHs.push_back(outputH); + } + if (outputHs.size() == 1) + return outputHs[0]; + else + { + std::vector operands({ outputHs[0], outputHs[1] }); + return Splice(operands, Axis(0), ToWString(node->Name())); + } +} + template void TraverseGraphWithPrePostActions(FunctionPtr cntkFunction, std::unordered_set& visitedFunctions, FunctionType preFunctor, FunctionType postFunctor) @@ -922,3 +1036,60 @@ void TraceGRUPathes(const FunctionPtr& src, string &f_activation, string &g_acti f_activation = "Sigmoid"; g_activation = MapActivationNameCNTKToONNX(ToString(gActivation->OpName())); } + +void TraceRNNPathes(const FunctionPtr& src, string &activation, + RNNDirection &direction, Variable &initStateH) +{ + std::vector inputVars = src->Inputs(); + std::vector pastValueOps, futureValueOps; + GetDelayOps(inputVars, pastValueOps, futureValueOps); + + // indices here coresponding with CNTK python layer code. + if (pastValueOps.size() == 1 && futureValueOps.size() == 0) + { + direction = RNNDirection::Forward; + initStateH = pastValueOps[0]->Inputs()[1]; + } + else if (pastValueOps.size() == 0 && futureValueOps.size() == 1) + { + direction = RNNDirection::Backward; + initStateH = futureValueOps[0]->Inputs()[1]; + } + else + { + CNTK::LogicError("Node %s (%s) is not a valid RNN node", ToString(src->Name()).c_str(), ToString(src->Uid()).c_str()); + } + + FunctionPtr activationFunction = src->BlockRoot(); + activation = MapActivationNameCNTKToONNX(ToString(activationFunction->OpName())); +} + +std::vector GetRNNBlocksFromSingleOrBidirectionalRNN(const FunctionPtr src, const std::string &RNNStepOpName) +{ + std::vector rnns; + if (ToString(src->OpName()) == RNNStepOpName) + { + rnns.push_back(src); + } + else if (src->OpName() == L"Splice") // src is a Splice op with inputs from two LSTM ops. + { + for (auto &input : src->Inputs()) + { + rnns.push_back(input.Owner()); + } + } + else + { + CNTK::LogicError("An %s op should start with an GRU op (single direction) or a Splice op (bidirectional).", RNNStepOpName.c_str()); + } + + // For single direction RNN, rnns.size() == 1. For bidirectional RNN, rnns.size() == 2. + // It is an error otherwise. 
+ if (rnns.size() == 0 || rnns.size() > 2 || + std::any_of(rnns.cbegin(), rnns.cend(), [RNNStepOpName](const FunctionPtr &f) {return ToString(f->OpName()) != RNNStepOpName; })) + { + CNTK::LogicError("Invalid number of RNN ops to construct an ONNX %s node.", RNNStepOpName.c_str()); + } + + return rnns; +} diff --git a/Source/CNTKv2LibraryDll/proto/onnx/RNNHelper.h b/Source/CNTKv2LibraryDll/proto/onnx/RNNHelper.h index 9113e7eaa..87767ebc5 100644 --- a/Source/CNTKv2LibraryDll/proto/onnx/RNNHelper.h +++ b/Source/CNTKv2LibraryDll/proto/onnx/RNNHelper.h @@ -107,6 +107,21 @@ enum GRUInitialH = 5, }; +enum +{ + RNNInputIndexX = 0, + RNNInputIndexW = 1, + RNNInputIndexR = 2, + RNNInputIndexB = 3, + RNNInputIndexSequenceLens = 4, + RNNInitialH = 5, +}; + +enum +{ + CNTKRNNOutputYhIndex = 0 +}; + // https://github.com/onnx/onnx/blob/master/docs/Operators.md#inputs-3---6 // size of weight/bias matrix is a multiple of hidden size enum @@ -130,6 +145,20 @@ enum CNTKGRUInputCount = 7 }; +enum +{ + CNTKRNNWeightIndex = 0, + CNTKRNNHweightIndex = 1, + CNTKRNNBiasIndex = 2, + CNTKRNNDelayIndex = 3, + CNTKRNNInputIndex = 4, + CNTKRNNInputCount = 5 +}; + +enum +{ + RNNBiasMultiplier = 2 +}; const string RNNDirectionBidirection = "bidirectional"; const string RNNDirectionReverse = "reverse"; @@ -141,6 +170,9 @@ FunctionPtr CreateLSTM(const ONNXIR::Node *node, const std::vector &in FunctionPtr CreateGRU(const ONNXIR::Node *node, const std::vector &inputs, const std::string &direction, const std::vector &activations, const std::vector &activation_alpha, const std::vector &activation_beta); +FunctionPtr CreateRNN(const ONNXIR::Node *node, const std::vector &inputs, const std::string &direction, + const std::vector &activations, const std::vector &activation_alpha, const std::vector &activation_beta); + void TraceLSTMPathes(const FunctionPtr& src, string &f_activation, string &g_activation, string &h_activation, RNNDirection &direction, Variable &initStateH, Variable &initStateC, Variable &peepholeCi, Variable &peepholeCo, Variable &peepholeCf, double &stabilizer_dh, double &stabilizer_dc, double &stabilizer_c); @@ -148,5 +180,10 @@ void TraceLSTMPathes(const FunctionPtr& src, string &f_activation, string &g_act void TraceGRUPathes(const FunctionPtr& src, string &f_activation, string &g_activation, RNNDirection &direction, Variable &initStateH); +void TraceRNNPathes(const FunctionPtr& src, string &activation, + RNNDirection &direction, Variable &initStateH); + std::string MapActivationNameONNXToCNTK(const std::string &onnxOp); -std::string MapActivationNameCNTKToONNX(const std::string &cntkOp); \ No newline at end of file +std::string MapActivationNameCNTKToONNX(const std::string &cntkOp); + +std::vector GetRNNBlocksFromSingleOrBidirectionalRNN(const FunctionPtr src, const std::string &RNNStepOpName); \ No newline at end of file diff --git a/bindings/python/cntk/tests/onnx_op_test.py b/bindings/python/cntk/tests/onnx_op_test.py index dc4aa7022..f6677cf81 100644 --- a/bindings/python/cntk/tests/onnx_op_test.py +++ b/bindings/python/cntk/tests/onnx_op_test.py @@ -433,26 +433,26 @@ def test_Greater(tmpdir): verify_no_input(model, tmpdir, 'Greater_0') #GRU -def MakeGRUNameFromConfig(backward, initial_state, activtion): - model_name = 'GRU.' 
+ activtion.__name__ - if (initial_state != 0): - model_name += '.initial' - if (backward): - model_name += '.backward' - else: - model_name += '.forward' - return model_name - -direction_options = [False, True] -activation_options = [C.tanh] -initial_state_options = [0] - -input_dim = 2 -cell_dim = 3 -batch_size = 1 -sequence_len = 5 - def test_GRU(tmpdir): + def MakeGRUNameFromConfig(backward, initial_state, activition): + model_name = 'GRU.' + activition.__name__ + if (initial_state != 0): + model_name += '.initial' + if (backward): + model_name += '.backward' + else: + model_name += '.forward' + return model_name + + direction_options = [False, True] + activation_options = [C.tanh] + initial_state_options = [0] + + input_dim = 2 + cell_dim = 3 + batch_size = 1 + sequence_len = 5 + for config in list(product(direction_options, initial_state_options, activation_options)): model_filename = MakeGRUNameFromConfig(*config) print(model_filename) @@ -560,43 +560,44 @@ def test_LRN(tmpdir): verify_one_input(model, img, tmpdir, 'LRN_1') #LSTM -def CreateLSTMModel(activation, - peepholes, - self_stabilization, - cell_dim, - initial_state): - return C.layers.Sequential([ - C.layers.Recurrence(C.layers.LSTM(cell_dim, - use_peepholes = peepholes, - activation = activation, - enable_self_stabilization = self_stabilization), - initial_state = initial_state) - ]) - -# lstm attributes -use_peepholes_options = [False] -enable_self_stabilization_options = [False] -activation_options = [C.tanh] - -#Recurrence attributes -initial_state_options = [0] - -input_dim = 2 -cell_dim = 3 -batch_size = 1 -sequence_len = 5 - -def MakeLSTMNameFromConfig(use_peepholes, enable_self_stabilization, initial_state, activtion): - model_name = 'LSTM.' + activtion.__name__ - if (use_peepholes): - model_name += '.peephole' - if(enable_self_stabilization): - model_name += '.stabilize' - if (initial_state != 0): - model_name += '.initial' - return model_name - def test_LSTM(tmpdir): + def CreateLSTMModel(activation, + peepholes, + self_stabilization, + cell_dim, + initial_state): + return C.layers.Sequential([ + C.layers.Recurrence(C.layers.LSTM(cell_dim, + use_peepholes = peepholes, + activation = activation, + enable_self_stabilization = self_stabilization), + initial_state = initial_state) + ]) + + + def MakeLSTMNameFromConfig(use_peepholes, enable_self_stabilization, initial_state, activition): + model_name = 'LSTM.' 
+ activition.__name__ + if (use_peepholes): + model_name += '.peephole' + if(enable_self_stabilization): + model_name += '.stabilize' + if (initial_state != 0): + model_name += '.initial' + return model_name + + # lstm attributes + use_peepholes_options = [False] + enable_self_stabilization_options = [False] + activation_options = [C.tanh] + + #Recurrence attributes + initial_state_options = [0] + + input_dim = 2 + cell_dim = 3 + batch_size = 1 + sequence_len = 5 + for config in list(product(use_peepholes_options, enable_self_stabilization_options, initial_state_options, activation_options)): model_filename = MakeLSTMNameFromConfig(*config) @@ -830,6 +831,81 @@ def test_Reshape(tmpdir): model = C.reshape(i1, (2,3)) verify_one_input(model, data, tmpdir, 'Reshape_1') +#RNN +def test_GRU(tmpdir): + def CreatRNN(cell_dim, + activation, + initial_state, + direction, + num_layers, + init=C.default_override_or(C.glorot_uniform()), + init_bias=C.default_override_or(0)): + if direction == 'bidirectional': + return C.layers.Sequential([ + C.layers.For(range(num_layers), lambda i: [ + (C.layers.Recurrence(C.layers.RNNStep(cell_dim, + activation = activation, + init = init, + init_bias = init_bias), + initial_state = initial_state, + return_full_state = False, go_backwards=False), + C.layers.Recurrence(C.layers.RNNStep(cell_dim, activation = activation, + init = init, + init_bias = init_bias), + initial_state = initial_state, + return_full_state = False, go_backwards=True)), + C.splice])]) + else: + go_backward = False if direction == 'forward' else True + return C.layers.Sequential([ + C.layers.For(range(num_layers), lambda i: [ + C.layers.Recurrence(C.layers.RNNStep(cell_dim, + activation = activation, + init = init, + init_bias = init_bias), + initial_state = initial_state, + return_full_state = False, go_backwards=go_backward)])]) + + def MakeRNNNameFromConfig(direction, num_layers, initial_state, activition): + model_name = 'GRU.' + direction + '.' + + if num_layers == 1: + model_name += 'one_layer.' + else: + assert (num_layers == 2), "needs 1 or 2 layers!" + model_name += 'two_layer.' + + if (initial_state != 0): + model_name += 'initial.' + + model_name += activition.__name__ + return model_name + + direction_options = ['forward', 'reverse', 'bidirectional'] + num_layers_options = [1, 2] + initial_state_options = [0] + activation_options = [C.tanh, C.relu, C.sigmoid] + + input_dim = 2 + hidden_dim = 3 + batch_size = 1 + sequence_len = 5 + + for config in list(product(direction_options, num_layers_options, initial_state_options, activation_options)): + model_filename = MakeRNNNameFromConfig(*config) + print(model_filename) + direction, num_layers, initial_state, activation = config + + x = C.input_variable(input_dim, dynamic_axes=[C.Axis.default_batch_axis(), C.Axis('sequenceAxis')]) + RNNModel = CreatRNN( + hidden_dim, + activation, + initial_state, + direction, + num_layers)(x) + data = np.random.uniform(low=0.0, high=1.0, size=(batch_size, sequence_len, input_dim)).astype('f') + verify_one_input(RNNModel, data, tmpdir, model_filename) + #Selu def test_Selu(tmpdir): model = C.selu([[-1, -0.5, 0, 1, 2]])
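For reference on the bias handling in the RNN import path above (CreateRNNConstant, case RNNInputIndexB, with RNNBiasMultiplier = 2): the ONNX RNN op carries two bias vectors, Wb and Rb, applied as h_t = f(W x_t + Wb + R h_{t-1} + Rb), while the CNTK cell built by RNNCell uses a single bias B added after the two projections are summed. Fusing the two ONNX biases into B = Wb + Rb is therefore lossless, as this small numpy sketch (made-up sizes, random values) illustrates:

import numpy as np

hidden_size, input_size = 3, 2
rng = np.random.RandomState(0)
W  = rng.randn(hidden_size, input_size)     # input projection
R  = rng.randn(hidden_size, hidden_size)    # recurrent projection
Wb = rng.randn(hidden_size)                 # ONNX input bias
Rb = rng.randn(hidden_size)                 # ONNX recurrent bias
x, h_prev = rng.randn(input_size), rng.randn(hidden_size)

# ONNX RNN cell: two separate biases.
h_onnx = np.tanh(W.dot(x) + Wb + R.dot(h_prev) + Rb)

# CNTK RNNStep cell (see RNNCell in RNNHelper.cpp): one fused bias B = Wb + Rb.
h_cntk = np.tanh(W.dot(x) + R.dot(h_prev) + (Wb + Rb))

assert np.allclose(h_onnx, h_cntk)

The LSTM and GRU branches of CreateRNNConstant perform the same kind of fusion; the plain RNN case simply works with a single hidden_size-sized vector per direction.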