From f1f3bb4e637ee65ab86034258ca1a02366d8419e Mon Sep 17 00:00:00 2001 From: liqfu Date: Tue, 13 Mar 2018 11:24:33 -0700 Subject: [PATCH] Support ONNX GRU --- .../proto/onnx/CNTKToONNX.cpp | 1093 ++++++++++------- .../proto/onnx/ONNXToCNTK.cpp | 193 ++- .../CNTKv2LibraryDll/proto/onnx/RNNHelper.cpp | 651 +++++++++- .../CNTKv2LibraryDll/proto/onnx/RNNHelper.h | 83 +- bindings/python/cntk/tests/onnx_op_test.py | 57 +- 5 files changed, 1591 insertions(+), 486 deletions(-) diff --git a/Source/CNTKv2LibraryDll/proto/onnx/CNTKToONNX.cpp b/Source/CNTKv2LibraryDll/proto/onnx/CNTKToONNX.cpp index 710594964..e94005f81 100644 --- a/Source/CNTKv2LibraryDll/proto/onnx/CNTKToONNX.cpp +++ b/Source/CNTKv2LibraryDll/proto/onnx/CNTKToONNX.cpp @@ -114,15 +114,27 @@ namespace CNTK std::unordered_map& functionNodes, std::unordered_map& variableNodes, const std::unordered_map& compositeOutputsMap); + static ONNXIR::Node *CreateGRUNode(const FunctionPtr &src, + ONNXIR::Graph* graph, + std::unordered_map& functionNodes, + std::unordered_map& variableNodes, + const std::unordered_map& compositeOutputsMap); - static void PrepareLSTMInput(const Variable &X, std::vector &nodeInputs); + static void PrepareRNNInput(const Variable &X, std::vector &nodeInputs); static void PrepareLSTMInitialStateNode(ONNXIR::Graph* graph, std::unordered_map& variableNodes, const std::vector &initialVariables, int batchSize, int cellSize, const std::string &uid, std::vector &nodeInputs); + static void PrepareGRUWeightNode(ONNXIR::Graph* graph, std::unordered_map& variableNodes, + const std::vector &Ws, std::vector &nodeInputs); + static void PrepareGRUZRHWeightNode(ONNXIR::Graph* graph, std::unordered_map& variableNodes, + const std::vector &Rs, const std::vector &Rh1s, std::vector &nodeInputs); + static void PrepareGRUBiasNode(ONNXIR::Graph* graph, std::unordered_map& variableNodes, + const std::vector &Bs, std::vector &nodeInputs); + static void PrepareLSTMWeightNode(ONNXIR::Graph* graph, std::unordered_map& variableNodes, const std::vector &Ws, double *stabilizerConstants, std::vector &nodeInputs); - static void PrepareBiasNode(ONNXIR::Graph* graph, std::unordered_map& variableNodes, + static void PrepareLSTMBiasNode(ONNXIR::Graph* graph, std::unordered_map& variableNodes, const std::vector &Ws, std::vector &nodeInputs); static void PrepareLSTMPeepholeNode(ONNXIR::Graph* graph, std::unordered_map& variableNodes, const std::vector &Ps, @@ -141,17 +153,30 @@ namespace CNTK // static void CopyTensor(const NDArrayViewPtr src, onnx::TensorProto& dst, onnx::TypeProto *inputArgType = nullptr); - static void CopyTensorsWithMultipliers(const std::vector srcTensors, const std::vector multipliers, - onnx::TensorProto& dst, onnx::TypeProto *inputArgType); + static void CopyTensorsWithMultipliers(const std::vector &srcTensors, const std::vector &multipliers, + onnx::TensorProto& dst, const onnx::TypeProto &inputArgType); + static void CopyGRUBiasTensors(const std::vector &srcTensors, + onnx::TensorProto& dst, const onnx::TypeProto &inputArgType); + + static void CopyGRUWeightTensors(const std::vector &srcTensors, + onnx::TensorProto& dst, const onnx::TypeProto &inputArgType); + + static void CopyGRUStateWeightTensors( + const std::vector &srcZRTensors, const std::vector &srcHTensors, + onnx::TensorProto& dst, const onnx::TypeProto &inputArgType); + static void FillTensorWithScalar(const std::vector& src, onnx::TensorProto& dst, const std::vector dstShape); // // Create an ONNX weight tensor for LSTM op. It handles memory mapping from CNTK to ONNX. 
// static void CopyTensorsWithCNTKToONNXLSTMWeightLayoutConversion(const std::vector &src, double *stabilizerConstants, - onnx::TensorProto& dst, onnx::TypeProto *inputArgType); + onnx::TensorProto& dst, const onnx::TypeProto &inputArgType); + + static void CopyShapeTypeProtoToTensorProto(const onnx::TypeProto &inputArgType, onnx::TensorProto& dst); + // // Copy supported attributes from CNTK node to corresponding ONNX node. // @@ -363,13 +388,35 @@ void CNTKToONNXHelper::Copy(const FunctionPtr& src, ONNXIR::Graph* dst) CreateNode(src, dst, functionNodes, variableNodes, compositeOutputsMap); } +void AddDataElementArrayViewToTensorProto(const NDArrayViewPtr src, int srcIndex, onnx::TensorProto& dst) +{ + DataType dataType = src->GetDataType(); + switch (dataType) + { + case DataType::Float: + { + auto data = src->DataBuffer(); + *(dst.mutable_float_data()->Add()) = data[srcIndex]; + } + break; + case DataType::Double: + { + auto data = src->DataBuffer(); + *(dst.mutable_double_data()->Add()) = data[srcIndex]; + } + break; + default: + NOT_IMPLEMENTED; + } +} + // LSTM gate bias order difference between CNTK (icfo) and ONNX (iofc) is // handled while building ONNX LSTM bias tensor. template void AppendCNTKBiasWeightToONNXTensor(DType *data, const NDShape &shape, onnx::TensorProto& dst) { auto totalSize = shape.TotalSize(); - int cell_size = shape[0] / 4; + int cell_size = shape[0] / LSTMWeightDimensionHiddenMultiplier; for (size_t targetIndex = 0; targetIndex < totalSize; targetIndex++) { int row = targetIndex; @@ -425,7 +472,7 @@ void AppendCNTKWeightToONNXTensor(DType *data, const NDShape &shape, onnx::Tenso auto totalSize = shape.TotalSize(); for (size_t targetIndex = 0; targetIndex < totalSize; targetIndex++) { - int cell_size = shape[0] / 4; + int cell_size = shape[0] / LSTMWeightDimensionHiddenMultiplier; int input_size = shape[1]; bool rowMajor = true; @@ -438,8 +485,8 @@ void AppendCNTKWeightToONNXTensor(DType *data, const NDShape &shape, onnx::Tenso } else { - row = targetIndex % (cell_size * 4); - col = targetIndex / (cell_size * 4); + row = targetIndex % (cell_size * LSTMWeightDimensionHiddenMultiplier); + col = targetIndex / (cell_size * LSTMWeightDimensionHiddenMultiplier); } // TODO: specific to LSTM. icfo (CNTK) to iofc(ONNX) @@ -456,7 +503,7 @@ void AppendCNTKWeightToONNXTensor(DType *data, const NDShape &shape, onnx::Tenso } // soruce is collum major - int src_index = 4 * cell_size * col + row; + int src_index = LSTMWeightDimensionHiddenMultiplier * cell_size * col + row; if (typeid(DType) == typeid(float)) *(dst.mutable_float_data()->Add()) = (float)(data[src_index] * stabilizer); else if(typeid(DType) == typeid(double)) @@ -466,16 +513,8 @@ void AppendCNTKWeightToONNXTensor(DType *data, const NDShape &shape, onnx::Tenso } } -void CNTKToONNXHelper::CopyTensorsWithCNTKToONNXLSTMWeightLayoutConversion(const std::vector &src, double *stabilizerConstants, - onnx::TensorProto& dst, onnx::TypeProto *inputArgType) +void SetTensorType(onnx::TensorProto& dst, DataType dataType) { - // TODO: all NDArrayViewPtr shall have the same shape and data types. 
- if (src.empty()) - { - // TODO: error - return; - } - auto dataType = src[0]->GetDataType(); switch (dataType) { case DataType::Float: @@ -487,6 +526,20 @@ void CNTKToONNXHelper::CopyTensorsWithCNTKToONNXLSTMWeightLayoutConversion(const default: NOT_IMPLEMENTED; } +} + +void CNTKToONNXHelper::CopyShapeTypeProtoToTensorProto(const onnx::TypeProto &inputArgType, onnx::TensorProto& dst) +{ + std::vector dimensions = CNTKToONNXHelper::ToINTS(inputArgType); + for (auto dim : dimensions) + *(dst.mutable_dims()->Add()) = dim; +} + +void CNTKToONNXHelper::CopyTensorsWithCNTKToONNXLSTMWeightLayoutConversion(const std::vector &src, double *stabilizerConstants, + onnx::TensorProto& dst, const onnx::TypeProto &inputArgType) +{ + auto dataType = src[0]->GetDataType(); + SetTensorType(dst, dataType); for (int i = 0; i < src.size(); i++) { @@ -517,21 +570,24 @@ void CNTKToONNXHelper::CopyTensorsWithCNTKToONNXLSTMWeightLayoutConversion(const } } - std::vector dimensions = CNTKToONNXHelper::ToINTS(*inputArgType); - for (auto dim : dimensions) - *(dst.mutable_dims()->Add()) = dim; + CopyShapeTypeProtoToTensorProto(inputArgType, dst); } -void CNTKToONNXHelper::CopyTensorsWithMultipliers(const std::vector srcTensors, - const std::vector multipliers, - onnx::TensorProto& dst, onnx::TypeProto *inputArgType) +void CNTKToONNXHelper::CopyTensorsWithMultipliers(const std::vector &srcTensors, + const std::vector &multipliers, + onnx::TensorProto& dst, const onnx::TypeProto &inputArgType) { // TODO: verify that srcTensors has consistant shapes + if (multipliers.size() != srcTensors.size()) + LogicError("To apply multiplier when copying tensors, number of multipliers must be the same as number of tensors."); + for (int viewIndex = 0; viewIndex < srcTensors.size(); viewIndex++) { auto view = srcTensors[viewIndex]; double multiplier = multipliers[viewIndex]; auto dataType = view->GetDataType(); + SetTensorType(dst, dataType); + auto srcTemp = view->DeepClone(); auto srcShape = srcTemp->Shape(); auto totalSize = srcShape.TotalSize(); @@ -540,16 +596,14 @@ void CNTKToONNXHelper::CopyTensorsWithMultipliers(const std::vectorDataBuffer(); for (size_t index = 0; index < totalSize; index++) - *(dst.mutable_float_data()->Add()) = (float) (data[index] * multiplier); + *(dst.mutable_float_data()->Add()) = (float)(data[index] * multiplier); break; } case DataType::Double: { - dst.set_data_type(onnx::TensorProto_DataType_DOUBLE); auto data = srcTemp->DataBuffer(); for (size_t index = 0; index < totalSize; index++) *(dst.mutable_double_data()->Add()) = data[index] * multiplier; @@ -561,9 +615,144 @@ void CNTKToONNXHelper::CopyTensorsWithMultipliers(const std::vector dimensions = CNTKToONNXHelper::ToINTS(*inputArgType); - for (auto dim : dimensions) - *(dst.mutable_dims()->Add()) = dim; + CopyShapeTypeProtoToTensorProto(inputArgType, dst); +} + +void CNTKToONNXHelper::CopyGRUBiasTensors(const std::vector &srcTensors, + onnx::TensorProto& dst, const onnx::TypeProto &inputArgType) +{ + if (srcTensors.empty()) + return; + + DataType dataType = srcTensors[0]->GetDataType(); + SetTensorType(dst, dataType); + + for (int i = 0; i < srcTensors.size(); i++) + { + auto srcTemp = srcTensors[i]->DeepClone(); + auto srcShape = srcTemp->Shape(); + + // This is our own copy so move it to the CPU. 
+ srcTemp->ChangeDevice(DeviceDescriptor::CPUDevice()); + + auto totalSize = srcShape.TotalSize(); + for (size_t index = 0; index < totalSize; index++) + { + AddDataElementArrayViewToTensorProto(srcTemp, index, dst); + } + + // fill zeros for Rb[zrh] because CNTK GRU does not support Rb. + for (size_t index = 0; index < totalSize; index++) + switch (dataType) + { + case DataType::Float: + { + *(dst.mutable_float_data()->Add()) = 0; + } + break; + case DataType::Double: + { + *(dst.mutable_double_data()->Add()) = 0; + } + break; + } + } + + CopyShapeTypeProtoToTensorProto(inputArgType, dst); +} + +void CNTKToONNXHelper::CopyGRUWeightTensors(const std::vector &srcTensors, + onnx::TensorProto& dst, const onnx::TypeProto &inputArgType) +{ + if (srcTensors.empty()) + return; + + DataType dataType = srcTensors[0]->GetDataType(); + SetTensorType(dst, dataType); + + for (int i = 0; i < srcTensors.size(); i++) + { + auto srcTemp = srcTensors[i]->DeepClone(); + auto srcShape = srcTemp->Shape(); + + // This is our own copy so move it to the CPU. + srcTemp->ChangeDevice(DeviceDescriptor::CPUDevice()); + + auto totalSize = srcShape.TotalSize(); + for (size_t targetIndex = 0; targetIndex < totalSize; targetIndex++) + { + int cell_size = srcShape[0] / 3; + int input_size = srcShape[1]; + + // row major layout + int row = targetIndex / input_size, + col = targetIndex % input_size; + + // soruce is collum major + int srcIndex = 3 * cell_size * col + row; + AddDataElementArrayViewToTensorProto(srcTemp, srcIndex, dst); + } + } + + CopyShapeTypeProtoToTensorProto(inputArgType, dst); +} + +void CNTKToONNXHelper::CopyGRUStateWeightTensors( + const std::vector &srcZRTensors, const std::vector &srcHTensors, + onnx::TensorProto& dst, const onnx::TypeProto &inputArgType) +{ + if (srcZRTensors.size() < 1 || srcZRTensors.size() > 2 || srcZRTensors.size() != srcHTensors.size()) + LogicError("Invalid number of GRU weight tensors"); + + DataType dataType = srcZRTensors[0]->GetDataType(); + SetTensorType(dst, dataType); + + for (int i = 0; i < srcZRTensors.size(); i++) + { + auto srcZRTemp = srcZRTensors[i]->DeepClone(); + auto srcZRShape = srcZRTemp->Shape(); + + auto srcHTemp = srcHTensors[i]->DeepClone(); + auto srcHShape = srcHTemp->Shape(); + + int cell_size = srcZRShape[1]; + + // This is our own copy so move it to the CPU. + srcZRTemp->ChangeDevice(DeviceDescriptor::CPUDevice()); + srcHTemp->ChangeDevice(DeviceDescriptor::CPUDevice()); + + auto totalSize = srcZRShape.TotalSize() + srcHShape.TotalSize(); + for (size_t targetIndex = 0; targetIndex < totalSize; targetIndex++) + { + // row major layout + int row = targetIndex / cell_size, + col = targetIndex % cell_size; + + int src_index; + NDArrayViewPtr srcBlockTensor; + int block = row / cell_size; + if (block == 0 || block == 1) + { + // zr blocks + srcBlockTensor = srcZRTemp; + src_index = 2 * cell_size * col + row; + } + else if (block == 2) + { + // h block + srcBlockTensor = srcHTemp; + src_index = cell_size * col + row - cell_size * 2; + } + else + { + LogicError("Invalid GRU state weight shape"); + } + + AddDataElementArrayViewToTensorProto(srcBlockTensor, src_index, dst); + } + } + + CopyShapeTypeProtoToTensorProto(inputArgType, dst); } void CNTKToONNXHelper::CopyTensor(const NDArrayViewPtr src, onnx::TensorProto& dst, onnx::TypeProto *inputArgType /*=nullptr*/) @@ -576,11 +765,12 @@ void CNTKToONNXHelper::CopyTensor(const NDArrayViewPtr src, onnx::TensorProto& d // This is our own copy so move it to the CPU. 
srcTemp->ChangeDevice(DeviceDescriptor::CPUDevice()); + SetTensorType(dst, dataType); + switch (dataType) { case DataType::Float: { - dst.set_data_type(onnx::TensorProto_DataType_FLOAT); auto data = srcTemp->DataBuffer(); for (size_t index = 0; index < totalSize; index++) *(dst.mutable_float_data()->Add()) = data[index]; @@ -589,7 +779,6 @@ void CNTKToONNXHelper::CopyTensor(const NDArrayViewPtr src, onnx::TensorProto& d } case DataType::Double: { - dst.set_data_type(onnx::TensorProto_DataType_DOUBLE); auto data = srcTemp->DataBuffer(); for (size_t index = 0; index < totalSize; index++) *(dst.mutable_double_data()->Add()) = data[index]; @@ -603,9 +792,7 @@ void CNTKToONNXHelper::CopyTensor(const NDArrayViewPtr src, onnx::TensorProto& d // use if (inputArgType != nullptr) { - std::vector dimensions = CNTKToONNXHelper::ToINTS(*inputArgType); - for (auto dim : dimensions) - *(dst.mutable_dims()->Add()) = dim; + CopyShapeTypeProtoToTensorProto(*inputArgType, dst); } else { @@ -1062,370 +1249,32 @@ std::tuple, bool, int, bool> CNTKToONNXHelper::CalculateBroadca return make_tuple(dims2, broadCast, axis_start, swapInput); } -axis_start = axis_start > 0 ? axis_start : 0; + axis_start = std::max(0, axis_start); -const std::vector broadcaseInputDims = swapInput ? dims1 : dims2; -// sanity check; -for (int i = 0; i < broadcaseInputDims.size(); i++) -{ - if ((i < axis_start || i >= axis_stop) && broadcaseInputDims[i] != 1) + const std::vector broadcaseInputDims = swapInput ? dims1 : dims2; + // sanity check; + for (int i = 0; i < broadcaseInputDims.size(); i++) { - LogicError("dimension %d cannot be broadcasted", i); - } - else if (i >= axis_start && i < axis_stop && dims1[i] != dims2[i]) - { - LogicError("dimension %d cannot be broadcasted", i); - } -} -std::vector dimensions; -for (int i = axis_start; i < axis_stop; i++) -{ - dimensions.push_back(broadcaseInputDims[i]); -} - -return make_tuple(dimensions, broadCast, axis_start, swapInput); -} - -template -void TraverseGraphWithPrePostActions(FunctionPtr cntkFunction, std::unordered_set& visitedFunctions, - FunctionType preFunctor, FunctionType postFunctor) -{ -visitedFunctions.insert(cntkFunction); -preFunctor(cntkFunction); - -std::vector functionInputs = cntkFunction->Inputs(); -for (const auto& input : functionInputs) -{ - if (input.IsOutput() && visitedFunctions.find(input.Owner()) == visitedFunctions.end()) - { - const auto& inputFunction = input.Owner(); - TraverseGraphWithPrePostActions(inputFunction, visitedFunctions, preFunctor, postFunctor); - } -} - -postFunctor(cntkFunction); -} - -bool IsSupportedRNNActivation(const std::wstring &cntkOpName) -{ - static std::vector supportedRNNActivations( + if ((i < axis_start || i >= axis_stop) && broadcaseInputDims[i] != 1) { - L"ReLU", - L"Tanh", - L"StableSigmoid" - }); - return std::find(supportedRNNActivations.cbegin(), supportedRNNActivations.cend(), cntkOpName) != - supportedRNNActivations.cend(); -} - -std::string FindActivation(const std::vector &path, int nth) -{ - int count = 0; - for (std::vector::const_iterator it = path.begin(); it != path.end(); it++) - { - std::wstring opName = (*it)->OpName(); - if (IsSupportedRNNActivation(opName)) + LogicError("dimension %d cannot be broadcasted", i); + } + else if (i >= axis_start && i < axis_stop && dims1[i] != dims2[i]) { - if (count == nth) - { - std::unordered_multimap::const_iterator itLookup = Operators::CntkToONNXLookup().find(opName); - if (itLookup == Operators::CntkToONNXLookup().cend()) - CNTK::LogicError("Invalid activation (%s)", 
ToString(opName).c_str()); - - std::unordered_map::const_iterator itMap = (*itLookup).second.map.find(opName); - if (itMap == (*itLookup).second.map.cend()) - CNTK::LogicError("Invalid activation (%s)", ToString(opName).c_str()); - return itMap->second; - } - count++; + LogicError("dimension %d cannot be broadcasted", i); } } - return ""; -} - -Variable GetPeepholeVariableFromOp(FunctionPtr peepholeOp) -{ - // peephole variable is that child of peepholeOp that is neither stabilizer nor place holder - if (peepholeOp->OpName() != L"ElementTimes") - CNTK::LogicError("Peephole operation must be ElementTimes"); - - Variable peepholeVariable; - FunctionPtr stabilizerOp; - for (int i = 0; i < peepholeOp->Inputs().size(); i++) + std::vector dimensions; + for (int i = axis_start; i < axis_stop; i++) { - if (peepholeOp->Inputs()[i].Owner() && peepholeOp->Inputs()[i].Owner()->OpName() == L"Stabilizer") - { - stabilizerOp = peepholeOp->Inputs()[i].Owner(); - } - else if (peepholeOp->Inputs()[i].IsConstant() || peepholeOp->Inputs()[i].IsParameter()) - { - if (!peepholeVariable.IsInitialized()) - peepholeVariable = peepholeOp->Inputs()[i]; - else - CNTK::LogicError("Cannot find peephole variable from peephole op. Multiple qualified variables found."); - } + dimensions.push_back(broadcaseInputDims[i]); } - if (!peepholeVariable.IsInitialized()) - CNTK::LogicError("Cannot find peephole variable from peephole op."); - return peepholeVariable; -} - -// this method helps getting a stabilizer op from its parent/grandparent op. -// the parent op can be Times or ElementTimes. The grandparent op can be a plus op. -FunctionPtr GetStabilizerOp(FunctionPtr parentOp) -{ - FunctionPtr timesOp; - if (parentOp->OpName() == L"Plus") - { - for (int i = 0; i < parentOp->Inputs().size(); i++) - { - if (parentOp->Inputs()[i].Owner() && - (parentOp->Inputs()[i].Owner()->OpName() == L"Times" || - parentOp->Inputs()[i].Owner()->OpName() == L"ElementTimes")) - { - timesOp = parentOp->Inputs()[i].Owner(); - break; - } - } - } - else if (parentOp->OpName() == L"Times" || parentOp->OpName() == L"ElementTimes") - { - timesOp = parentOp; - } - - if (!timesOp) - { - CNTK::LogicError("Cannot find stabilizer op. A stabilizer op must be from Times or ElementTimes ops or skipped from a Plus op."); - } - - for (int j = 0; j < timesOp->Inputs().size(); j++) - { - if (timesOp->Inputs()[j].Owner() && timesOp->Inputs()[j].Owner()->OpName() == L"Stabilizer") - { - return timesOp->Inputs()[j].Owner(); - } - } - - return nullptr; -} - -double GetScaler(Variable variable) -{ - NDArrayViewPtr v = variable.IsParameter() ? Parameter(variable).Value() : Constant(variable).Value(); - NDArrayViewPtr cpuV = v->DeepClone(); - cpuV->ChangeDevice(DeviceDescriptor::CPUDevice()); - - switch (variable.GetDataType()) - { - case DataType::Float: - return *((float *)cpuV->DataBuffer()); - case DataType::Double: - return *((double *)cpuV->DataBuffer()); - default: - NOT_IMPLEMENTED; - } -} - -double GetStabilizerCoef(const FunctionPtr stabilizerDhOp) -{ - double alpha = GetScaler(stabilizerDhOp->Inputs()[3]); - double steepness = GetScaler(stabilizerDhOp->Inputs()[1]); - return (log(exp(alpha * steepness) + 1.0F) / steepness); -} - -// A CNTK LSTM op is created with stacked matmul followed by a slice op for 4 gates. -// Slice order tells which graph path is for which gate. This method is -// to traverse the graph to find the 4 paths along the 4 gates. It helps to -// subsequently find needed attributes in order to build an ONNX LSTM op. 
-void TraceLSTMPathes(const FunctionPtr& src, - string &f_activation, - string &g_activation, - string &h_activation, - LSTMDirection &direction, - Variable &initStateH, - Variable &initStateC, - Variable &peepholeCi, - Variable &peepholeCo, - Variable &peepholeCf, - double &stabilizer_dh, - double &stabilizer_dc, - double &stabilizer_c) -{ - // src has to be an LSTM node. - std::vector inputVars = src->Inputs(); - std::vector pastValueOps, futureValueOps; - for (std::vector::iterator it = inputVars.begin(); it != inputVars.end(); ++it) - { - if ((*it).Owner() != nullptr && (*it).Owner()->OpName() == L"PastValue") - pastValueOps.push_back((*it).Owner()); - else if ((*it).Owner() != nullptr && (*it).Owner()->OpName() == L"FutureValue") - futureValueOps.push_back((*it).Owner()); - } - - if (pastValueOps.size() == 2 && futureValueOps.size() == 0) - { - direction = LSTMDirection::Forward; - initStateH = pastValueOps[0]->Inputs()[1]; - initStateC = pastValueOps[1]->Inputs()[1]; - } - else if (pastValueOps.size() == 0 && futureValueOps.size() == 2) - { - direction = LSTMDirection::Backward; - initStateH = futureValueOps[0]->Inputs()[1]; - initStateC = futureValueOps[1]->Inputs()[1]; - } - else - { - CNTK::LogicError("Node %s (%s) is not a valid LSTM node", ToString(src->Name()).c_str(), ToString(src->Uid()).c_str()); - } - - // set up traverse boundary - std::unordered_set visitedFunctions; - for (std::vector::const_iterator it = inputVars.begin(); it != inputVars.end(); it++) - { - visitedFunctions.insert(it->Owner()); - } - - // First find the peephole op node. - // see CNTK\bindings\python\cntk\layers\blocks.py node references. - std::vector> pathesBitBftJoint; - { - std::vector currentPeepholePath; - - // make a copy of traverse boundary - std::unordered_set peepHoleVisitedFunctions = visitedFunctions; - - // traverse to find the joint of bit and bft - TraverseGraphWithPrePostActions(src->BlockRoot(), - peepHoleVisitedFunctions, - (std::function)[ - &peepHoleVisitedFunctions, &pathesBitBftJoint, ¤tPeepholePath](const FunctionPtr& function) - { - currentPeepholePath.push_back(function); - if (function->OpName() == L"Plus" && - function->Inputs()[0].Owner() && function->Inputs()[0].Owner()->OpName() == L"ElementTimes" && - function->Inputs()[1].Owner() && function->Inputs()[1].Owner()->OpName() == L"ElementTimes") - { - pathesBitBftJoint.push_back(currentPeepholePath); - peepHoleVisitedFunctions.erase(std::find_if(peepHoleVisitedFunctions.begin(), peepHoleVisitedFunctions.end(), - [function](FunctionPtr f) {return function == f; })); - } - }, - (std::function)[¤tPeepholePath](const FunctionPtr& function) - { - currentPeepholePath.pop_back(); - }); - } - - FunctionPtr peepholeCoOp; - bool haspeephole = pathesBitBftJoint.size() == 3; - if (haspeephole) - { - // the last ElementTimes op is the peephole op - std::vector &peepholePath = *std::max_element(pathesBitBftJoint.begin(), pathesBitBftJoint.end(), - [](std::vector &p1, std::vector &p2) {return p1.size() < p2.size(); }); - std::vector::reverse_iterator itPeepholeOp = std::find_if(peepholePath.rbegin(), peepholePath.rend(), - [](FunctionPtr function) {return function->OpName() == L"ElementTimes"; }); - if (itPeepholeOp == peepholePath.rend()) - { - CNTK::LogicError("Cannot find peephole op from a LSTM graph"); - } - - peepholeCoOp = *itPeepholeOp; - peepholeCo = GetPeepholeVariableFromOp(peepholeCoOp); - - FunctionPtr stabilizer_h_op = GetStabilizerOp(peepholeCoOp); - if (stabilizer_h_op) - { - stabilizer_c = 
GetStabilizerCoef(stabilizer_h_op); - } - } - - std::vector> pathesToPlusSlice; - std::vector currentPath; - - if (haspeephole) - // so that traverse will not be affected by the peephole path - visitedFunctions.insert(peepholeCoOp); - - TraverseGraphWithPrePostActions(src->BlockRoot(), - visitedFunctions, - (std::function)[&pathesToPlusSlice, ¤tPath](const FunctionPtr& function) - { - currentPath.push_back(function); - if (function->OpName() == L"Slice") - { - FunctionPtr functionSource = function->Inputs()[0].Owner(); - if (functionSource->OpName() == L"Plus") - { - pathesToPlusSlice.push_back(currentPath); - } - } - }, - (std::function)[¤tPath](const FunctionPtr& function) - { - currentPath.pop_back(); - }); - - if (pathesToPlusSlice.size() != 4) - { - CNTK::LogicError("pathesToPlusSlice.size() != 4"); - } - - std::sort(pathesToPlusSlice.begin(), pathesToPlusSlice.end(), - [](const std::vector& path1, const std::vector& path2) - { - FunctionPtr slice1 = *path1.rbegin(); - FunctionPtr slice2 = *path2.rbegin(); - int beginIndex1 = slice1->Attributes()[PrimitiveFunction::AttributeNameBeginIndex].Value(); - int beginIndex2 = slice2->Attributes()[PrimitiveFunction::AttributeNameBeginIndex].Value(); - return beginIndex1 < beginIndex2; - }); - - std::vector &ht_it_path = pathesToPlusSlice[0]; - std::vector &ht_bit_path = pathesToPlusSlice[1]; - std::vector &ht_ft_path = pathesToPlusSlice[2]; - std::vector &ht_ot_path = pathesToPlusSlice[3]; - - f_activation = FindActivation(ht_ot_path, 0); - g_activation = FindActivation(ht_bit_path, 1); - h_activation = FindActivation(ht_bit_path, 0); - - // stabilizer_dh - FunctionPtr stackedProjPlusOp = ht_it_path[ht_it_path.size() - 1]->Inputs()[0].Owner(); - FunctionPtr stabilizerDhOp = GetStabilizerOp(stackedProjPlusOp); - if (stabilizerDhOp) - { - stabilizer_dh = GetStabilizerCoef(stabilizerDhOp); - } - - if (haspeephole) - { - { - // Ci merges to ht_it_path via element-wise time - FunctionPtr plusOp = ht_it_path[ht_it_path.size() - 2]; - FunctionPtr peepholeOp = plusOp->Inputs()[0].Owner()->OpName() != L"Slice" ? - plusOp->Inputs()[0].Owner() : plusOp->Inputs()[1].Owner(); - peepholeCi = GetPeepholeVariableFromOp(peepholeOp); - - } - { - // Cf merges to ht_ft_path via element-wise time - FunctionPtr plusOp = ht_ft_path[ht_ft_path.size() - 2]; - FunctionPtr peepholeOp = plusOp->Inputs()[0].Owner()->OpName() != L"Slice" ? - plusOp->Inputs()[0].Owner() : plusOp->Inputs()[1].Owner(); - peepholeCf = GetPeepholeVariableFromOp(peepholeOp); - - FunctionPtr stabilizerDcOp = GetStabilizerOp(peepholeOp); - if (stabilizerDcOp) - stabilizer_dc = GetStabilizerCoef(stabilizerDcOp); - } - } + return make_tuple(dimensions, broadCast, axis_start, swapInput); } // prepare an input node arg with correct name and meta data so that LotusIR can make the connection. -void CNTKToONNXHelper::PrepareLSTMInput(const Variable &X, std::vector &nodeInputs) +void CNTKToONNXHelper::PrepareRNNInput(const Variable &X, std::vector &nodeInputs) { Variable input; wstring opName = X.Owner() ? 
X.Owner()->OpName() : L""; @@ -1442,7 +1291,6 @@ void CNTKToONNXHelper::PrepareLSTMInput(const Variable &X, std::vectormutable_shape()->mutable_dim())[0].set_dim_param(FreeSequenceDimParam); @@ -1487,47 +1335,9 @@ void CNTKToONNXHelper::PrepareLSTMInitialStateNode(ONNXIR::Graph* graph, std::un variableNode->AddAttribute("value", dstTensor); nodeInputs.push_back(inputArg); - // TODO: variableNodes.emplace(initialVariables[0], variableNode); } -void CNTKToONNXHelper::PrepareBiasNode(ONNXIR::Graph* graph, std::unordered_map& variableNodes, - const std::vector &Bs, std::vector &nodeInputs) -{ - // TODO: sanity check for all variables to have the same shape and data types. - // NDShape is in reversed order relative CNTK python so doReverseVec need to be true - // when converting to ONNX tensor. - // However with LSTM, CNTK python weight tensor shape is already reversed relative to ONNX. - // We do not want to reverse again. - bool doReverseVec = false; - - std::vector shape = Cast((NDShape({ Bs.size() }).AppendShape(Bs[0].Shape())).Dimensions()); - shape[1] *= 2; - onnx::TypeProto inputArgType = ToTypeProto(shape, doReverseVec); - UpdateONNXType(Bs[0].GetDataType(), inputArgType); - ONNXIR::NodeArg inputArg(ToString(Bs[0].Uid()), &inputArgType); - std::vector varOutputs({ inputArg }); - std::vector varInputs; - std::string inputName = inputArg.Name(); - ONNXIR::Node* variableNode = graph->AddNode(inputName, "Constant", "", varInputs, varOutputs); - - std::vector srcTensors; - for (int i = 0; i < Bs.size(); i++) - { - const Variable &variable = Bs[i]; - srcTensors.push_back(variable.IsParameter() ? Parameter(variable).Value() : Constant(variable).Value()); - } - - onnx::TensorProto dstTensor; - - CopyTensorsWithCNTKToONNXLSTMWeightLayoutConversion(srcTensors, nullptr, dstTensor, &inputArgType); - variableNode->AddAttribute("value", dstTensor); - nodeInputs.push_back(inputArg); - - // TODO: - variableNodes.emplace(Bs[0], variableNode); -} - void CNTKToONNXHelper::PrepareLSTMPeepholeNode(ONNXIR::Graph* graph, std::unordered_map& variableNodes, const std::vector &Ps, const std::vector &stabilizerDcCoefs, const std::vector &stabilizerCCoefs, @@ -1575,15 +1385,52 @@ void CNTKToONNXHelper::PrepareLSTMPeepholeNode(ONNXIR::Graph* graph, onnx::TensorProto dstTensor; - CopyTensorsWithMultipliers(srcTensors, multipliers, dstTensor, &inputArgType); + CopyTensorsWithMultipliers(srcTensors, multipliers, dstTensor, inputArgType); variableNode->AddAttribute("value", dstTensor); nodeInputs.push_back(inputArg); - // TODO: variableNodes.emplace(Ps[0], variableNode); } +void CNTKToONNXHelper::PrepareLSTMBiasNode(ONNXIR::Graph* graph, std::unordered_map& variableNodes, + const std::vector &Bs, std::vector &nodeInputs) +{ + // TODO: sanity check for all variables to have the same shape and data types. + // NDShape is in reversed order relative CNTK python so doReverseVec need to be true + // when converting to ONNX tensor. + // However with LSTM, CNTK python weight tensor shape is already reversed relative to ONNX. + // We do not want to reverse again. + bool doReverseVec = false; + + std::vector shape = Cast((NDShape({ Bs.size() }).AppendShape(Bs[0].Shape())).Dimensions()); + + // ONNX LSTM spec has 2 bias, for forward and backward. 
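+    // Per the ONNX spec, B has shape [num_directions, 8 * hidden_size]: Wb[iofc] followed by Rb[iofc].
+    // CNTK stores a single 4 * hidden_size bias per direction, so the second dimension is doubled below
+    // to make room for both halves.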
+ shape[1] *= 2; + onnx::TypeProto inputArgType = ToTypeProto(shape, doReverseVec); + UpdateONNXType(Bs[0].GetDataType(), inputArgType); + ONNXIR::NodeArg inputArg(ToString(Bs[0].Uid()), &inputArgType); + std::vector varOutputs({ inputArg }); + std::vector varInputs; + std::string inputName = inputArg.Name(); + ONNXIR::Node* variableNode = graph->AddNode(inputName, "Constant", "", varInputs, varOutputs); + + std::vector srcTensors; + for (int i = 0; i < Bs.size(); i++) + { + const Variable &variable = Bs[i]; + srcTensors.push_back(variable.IsParameter() ? Parameter(variable).Value() : Constant(variable).Value()); + } + + onnx::TensorProto dstTensor; + + CopyTensorsWithCNTKToONNXLSTMWeightLayoutConversion(srcTensors, nullptr, dstTensor, inputArgType); + variableNode->AddAttribute("value", dstTensor); + nodeInputs.push_back(inputArg); + + variableNodes.emplace(Bs[0], variableNode); +} + void CNTKToONNXHelper::PrepareLSTMWeightNode(ONNXIR::Graph* graph, std::unordered_map& variableNodes, const std::vector &Ws, double *stabilizerConstants, std::vector &nodeInputs) { @@ -1612,14 +1459,37 @@ void CNTKToONNXHelper::PrepareLSTMWeightNode(ONNXIR::Graph* graph, std::unordere onnx::TensorProto dstTensor; - CopyTensorsWithCNTKToONNXLSTMWeightLayoutConversion(srcTensors, stabilizerConstants, dstTensor, &inputArgType); + CopyTensorsWithCNTKToONNXLSTMWeightLayoutConversion(srcTensors, stabilizerConstants, dstTensor, inputArgType); variableNode->AddAttribute("value", dstTensor); nodeInputs.push_back(inputArg); - // TODO: variableNodes.emplace(Ws[0], variableNode); } +std::string DeriveDirectionString(const std::vector lstms, + std::map directionCount) +{ + return lstms.size() == 2 ? RNNDirectionBidirection : + (directionCount[RNNDirection::Backward] == 1 ? RNNDirectionReverse : RNNDirectionForward); +} + +void AddEmptyInput(std::vector &nodeInputs) +{ + ONNXIR::NodeArg inputArg("", nullptr); + nodeInputs.emplace_back(inputArg); +} + +void SanityCheckForConstantOrParameters(const std::vector &variables) +{ + for (auto variable : variables) + { + if (variable.IsInitialized() && !variable.IsConstant() && !variable.IsParameter()) + CNTK::LogicError("Input to RNN op is not a constant or parameter: Variable Name: %S, Variable Uid: %S", + variable.Name().c_str(), + variable.Uid().c_str()); + } +} + ONNXIR::Node* CNTKToONNXHelper::CreateLSTMNode(const FunctionPtr &src, ONNXIR::Graph* graph, std::unordered_map& functionNodes, @@ -1653,7 +1523,7 @@ ONNXIR::Node* CNTKToONNXHelper::CreateLSTMNode(const FunctionPtr &src, } // order forward, backward - std::map directionCount({ { LSTMDirection::Forward, 0 } ,{ LSTMDirection::Backward, 0 } }); + std::map directionCount({ { RNNDirection::Forward, 0 } ,{ RNNDirection::Backward, 0 } }); // The following construct refers to ONNX spec: // https://github.com/onnx/onnx/blob/master/docs/Operators.md#lstm @@ -1682,7 +1552,7 @@ ONNXIR::Node* CNTKToONNXHelper::CreateLSTMNode(const FunctionPtr &src, // src has to be an LSTM node. 
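+    // TraceLSTMPathes (called below) walks the block graph to recover the f/g/h activations, the direction,
+    // the initial h/c states, the peephole variables and the stabilizer coefficients that are needed
+    // to assemble the ONNX LSTM attributes and inputs.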
const FunctionPtr& lstm = *itLSTMBlock; string f_activation, g_activation, h_activation; - LSTMDirection direction; + RNNDirection direction; Variable initStateH, initStateC; Variable peepholeCi, peepholeCo, peepholeCf; double stabilizer_dh = 1, stabilizer_dc = 1, stabilizer_c = 1; @@ -1726,10 +1596,14 @@ ONNXIR::Node* CNTKToONNXHelper::CreateLSTMNode(const FunctionPtr &src, Ycs[directionIndex] = outputs[CNTKLSTMOutputChIndex]; } + SanityCheckForConstantOrParameters(initialHs); + SanityCheckForConstantOrParameters(initialCs); + SanityCheckForConstantOrParameters(Ps); + // ensure that if there is one direction, it is not backward. // if there two directions, they are forward and backward, and // that the inputs (Xs) are the same. - if (std::any_of(directionCount.begin(), directionCount.end(), [](std::map::value_type &v) {return v.second > 1; })) + if (std::any_of(directionCount.begin(), directionCount.end(), [](std::map::value_type &v) {return v.second > 1; })) { LogicError("LSTM node is invalid because there should be no more than one path in each direction."); } @@ -1738,22 +1612,21 @@ ONNXIR::Node* CNTKToONNXHelper::CreateLSTMNode(const FunctionPtr &src, LogicError("Bi-directional LSTM node is invalid because the two LSTM nodes do not share one same input."); } - string direction = lstms.size() == 2 ? LSTMDirectionBidirection : - (directionCount[LSTMDirection::Backward] == 1 ? LSTMDirectionReverse : LSTMDirectionForward); + string direction = DeriveDirectionString(lstms, directionCount); // TODO: following commented out attributes are not supported. Use default. // float clip; // no clip yet // std::vector activation_alpha; // no supported activation need alpha. // std::vector activation_beta; // no supported activation need beta. int hidden_size = lstms[0]->Outputs()[0].Shape()[0]; - int output_sequence = 1; // LSTM in CNTK always output full sequence of output + int output_sequence = RNNOutputSequence; // LSTM in CNTK always output full sequence of output // TODO: implement peephole // Variable P; // inputs std::vector nodeInputs; - PrepareLSTMInput(Xs[0], nodeInputs); + PrepareRNNInput(Xs[0], nodeInputs); PrepareLSTMWeightNode(graph, variableNodes, Ws, nullptr, nodeInputs); PrepareLSTMWeightNode(graph, variableNodes, Rs, &stabilizerDhCoefs[0], nodeInputs); @@ -1761,12 +1634,11 @@ ONNXIR::Node* CNTKToONNXHelper::CreateLSTMNode(const FunctionPtr &src, bool hasBias = std::all_of(Bs.begin(), Bs.end(), [](Variable &v) {return v.IsInitialized(); }); if (hasBias) { - PrepareBiasNode(graph, variableNodes, Bs, nodeInputs); + PrepareLSTMBiasNode(graph, variableNodes, Bs, nodeInputs); } else { - ONNXIR::NodeArg inputArg("", nullptr); - nodeInputs.push_back(inputArg); + AddEmptyInput(nodeInputs); } // TODO: enable sequence_lens. It requires additional model input of batched sequence data layout. 
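+    // sequence_lens is an optional ONNX input of shape [batch_size] holding the true length of each
+    // sequence; passing an empty input means every sequence is treated as full length.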
@@ -1782,8 +1654,7 @@ ONNXIR::Node* CNTKToONNXHelper::CreateLSTMNode(const FunctionPtr &src, } else { - ONNXIR::NodeArg inputArg("", nullptr); - nodeInputs.push_back(inputArg); + AddEmptyInput(nodeInputs); } bool has_initial_h = std::all_of(initialHs.begin(), initialHs.end(), [](Variable &v) {return v.IsInitialized(); }); @@ -1794,10 +1665,7 @@ ONNXIR::Node* CNTKToONNXHelper::CreateLSTMNode(const FunctionPtr &src, } else { - { - ONNXIR::NodeArg inputArg("", nullptr); - nodeInputs.push_back(inputArg); - } + AddEmptyInput(nodeInputs); } bool has_initial_c = std::all_of(initialCs.begin(), initialCs.end(), [](Variable &v) {return v.IsInitialized(); }); @@ -1807,8 +1675,7 @@ ONNXIR::Node* CNTKToONNXHelper::CreateLSTMNode(const FunctionPtr &src, PrepareLSTMInitialStateNode(graph, variableNodes, initialCs, FreeBatchSize, hidden_size, cellUid, nodeInputs); } else { - ONNXIR::NodeArg inputArg("", nullptr); - nodeInputs.push_back(inputArg); + AddEmptyInput(nodeInputs); } // peephole @@ -1819,8 +1686,7 @@ ONNXIR::Node* CNTKToONNXHelper::CreateLSTMNode(const FunctionPtr &src, } else { - ONNXIR::NodeArg inputArg("", nullptr); - nodeInputs.push_back(inputArg); + AddEmptyInput(nodeInputs); } } @@ -1848,8 +1714,8 @@ ONNXIR::Node* CNTKToONNXHelper::CreateLSTMNode(const FunctionPtr &src, { Variable Yh = Yhs[0]; std::string nodeName = ToString(Yh.Uid()) + "_h"; - // TODO: - int batchSize = 1; + // TODO: batchSize is fixed to one. Needs to find out how to handle bacth axis as a free dimension. + const int batchSize = 1; auto outputArgType = ToTypeProto(std::vector({ (int)Yhs.size(), batchSize, (int)Yh.Shape()[0]}), false); UpdateONNXType(Yh.GetDataType(), outputArgType); ONNXIR::NodeArg outputArg(nodeName, &outputArgType); @@ -1895,6 +1761,316 @@ ONNXIR::Node* CNTKToONNXHelper::CreateLSTMNode(const FunctionPtr &src, return squeezedLSTMNode; } +void CNTKToONNXHelper::PrepareGRUBiasNode(ONNXIR::Graph* graph, std::unordered_map& variableNodes, + const std::vector &Bs, std::vector &nodeInputs) +{ + // TODO: sanity check for all variables to have the same shape and data types. + bool doReverseVec = false; + int numDirections = Bs.size(); + int hiddenSize = Bs[0].Shape()[0] / GRUWeightDimensionHiddenMultiplier; + + std::vector shape({ numDirections, GRUBiasDimensionHiddenMultiplier * hiddenSize }); + + // ONNX GRU spec has 2 bias, for forward and backward. + onnx::TypeProto inputArgType = ToTypeProto(shape, doReverseVec); + UpdateONNXType(Bs[0].GetDataType(), inputArgType); + ONNXIR::NodeArg inputArg(ToString(Bs[0].Uid()), &inputArgType); + std::vector varOutputs({ inputArg }); + std::vector varInputs; + std::string inputName = inputArg.Name(); + ONNXIR::Node* variableNode = graph->AddNode(inputName, "Constant", "", varInputs, varOutputs); + + std::vector srcTensors; + for (int i = 0; i < Bs.size(); i++) + { + const Variable &variable = Bs[i]; + srcTensors.push_back(variable.IsParameter() ? 
Parameter(variable).Value() : Constant(variable).Value()); + } + + onnx::TensorProto dstTensor; + + CopyGRUBiasTensors(srcTensors, dstTensor, inputArgType); + variableNode->AddAttribute("value", dstTensor); + nodeInputs.push_back(inputArg); + + variableNodes.emplace(Bs[0], variableNode); +} + +void CNTKToONNXHelper::PrepareGRUZRHWeightNode(ONNXIR::Graph* graph, std::unordered_map& variableNodes, + const std::vector &Rzrs, const std::vector &Rhs, std::vector &nodeInputs) +{ + int numDirections = Rzrs.size(); + int hiddenSize = Rzrs[0].Shape().Dimensions()[1]; + std::vector shape({ numDirections, GRUWeightDimensionHiddenMultiplier * hiddenSize, hiddenSize }); + onnx::TypeProto inputArgType = ToTypeProto(shape, false); + UpdateONNXType(Rzrs[0].GetDataType(), inputArgType); + ONNXIR::NodeArg inputArg(ToString(Rzrs[0].Uid()), &inputArgType); + std::vector varOutputs({ inputArg }); + std::vector varInputs; + std::string inputName = inputArg.Name(); + ONNXIR::Node* variableNode = graph->AddNode(inputName, "Constant", "", varInputs, varOutputs); + + std::vector srcZRTensors, srcHTensors; + for (int i = 0; i < Rzrs.size(); i++) + { + const Variable &variable = Rzrs[i]; + srcZRTensors.push_back(variable.IsParameter() ? Parameter(variable).Value() : Constant(variable).Value()); + + const Variable &variableH1 = Rhs[i]; + srcHTensors.push_back(variableH1.IsParameter() ? Parameter(variableH1).Value() : Constant(variableH1).Value()); + } + + onnx::TensorProto dstTensor; + + CopyGRUStateWeightTensors(srcZRTensors, srcHTensors, dstTensor, inputArgType); + variableNode->AddAttribute("value", dstTensor); + nodeInputs.push_back(inputArg); + + variableNodes.emplace(Rzrs[0], variableNode); +} +void CNTKToONNXHelper::PrepareGRUWeightNode(ONNXIR::Graph* graph, std::unordered_map& variableNodes, + const std::vector &Ws, std::vector &nodeInputs) +{ + // TODO: sanity check for all variables to have the same shape and data types. + bool doReverseVec = false; + + std::vector shape = Cast((NDShape({ Ws.size() }).AppendShape(Ws[0].Shape())).Dimensions()); + onnx::TypeProto inputArgType = ToTypeProto(shape, doReverseVec); + UpdateONNXType(Ws[0].GetDataType(), inputArgType); + ONNXIR::NodeArg inputArg(ToString(Ws[0].Uid()), &inputArgType); + std::vector varOutputs({ inputArg }); + std::vector varInputs; + std::string inputName = inputArg.Name(); + ONNXIR::Node* variableNode = graph->AddNode(inputName, "Constant", "", varInputs, varOutputs); + + std::vector srcTensors; + for (int i = 0; i < Ws.size(); i++) + { + const Variable &variable = Ws[i]; + srcTensors.push_back(variable.IsParameter() ? Parameter(variable).Value() : Constant(variable).Value()); + } + + onnx::TensorProto dstTensor; + + CopyGRUWeightTensors(srcTensors, dstTensor, inputArgType); + variableNode->AddAttribute("value", dstTensor); + nodeInputs.push_back(inputArg); + + variableNodes.emplace(Ws[0], variableNode); +} + +ONNXIR::Node *CNTKToONNXHelper::CreateGRUNode(const FunctionPtr &src, + ONNXIR::Graph* graph, + std::unordered_map& functionNodes, + std::unordered_map& variableNodes, + const std::unordered_map& compositeOutputsMap) +{ + // sanity check: + std::vector grus; + if (src->OpName() == L"GRU") + { + grus.push_back(src); + } + else if (src->OpName() == L"Splice") // src is a Splice op with inputs from two LSTM ops. 
+    {
+        for (auto &input : src->Inputs())
+        {
+            grus.push_back(input.Owner());
+        }
+    }
+    else
+    {
+        LogicError("A GRU op should start with a GRU op (single direction) or a Splice op (bidirectional).");
+    }
+
+    // For single direction GRU, grus.size() == 1. For bidirectional GRU, grus.size() == 2.
+    // It is an error otherwise.
+    if (grus.size() == 0 || grus.size() > 2 ||
+        std::any_of(grus.cbegin(), grus.cend(), [](const FunctionPtr &f) {return f->OpName() != L"GRU"; }))
+    {
+        LogicError("Invalid number of GRU ops to construct an ONNX GRU node.");
+    }
+
+    // order forward, backward
+    std::map directionCount({ { RNNDirection::Forward, 0 } ,{ RNNDirection::Backward, 0 } });
+
+    // The following construct refers to the ONNX spec:
+    // https://github.com/onnx/onnx/blob/master/docs/Operators.md#lstm
+    // specifically, for attribute and variable dimensions.
+    // We use the terms from the spec as much as we can to maintain a close correlation
+    // to the ONNX specification.
+
+    int num_directions = grus.size();
+    // A list of 2 (or 4 if bidirectional) activation functions: f for the update and reset gates,
+    // g for the hidden gate.
+    std::vector activations(num_directions * GRUActivationCount);
+
+    // TODO:
+    // In principle all these variables shall be treated as either constant or op output.
+    // In reality except X, all other inputs to GRU can be treated as constant.
+    std::vector Xs(num_directions), Ws(num_directions), Rzrs(num_directions),
+        Rhs(num_directions), Bs(num_directions),
+        initialHs(num_directions);
+
+    std::vector Yhs(grus.size());
+
+    for (std::vector::const_iterator itGRUBlock = grus.cbegin(); itGRUBlock != grus.cend(); itGRUBlock++)
+    {
+        // src has to be a GRU node.
+        const FunctionPtr& gru = *itGRUBlock;
+        std::vector inputs = gru->Inputs();
+        if (inputs.size() != CNTKGRUInputCount)
+            LogicError("Unknown GRU configuration. The GRU node might be created with self stabilization. Such GRU ops cannot be converted to ONNX.");
+
+        string f_activation, g_activation;
+        RNNDirection direction;
+        Variable initStateH;
+        TraceGRUPathes(gru, f_activation, g_activation, direction, initStateH);
+
+        directionCount[direction]++;
+
+        int directionIndex = grus.size() == 1 ? 0 : (direction ? 1 : 0);
+
+        initialHs[directionIndex] = initStateH;
+
+        activations[directionIndex * GRUActivationCount + GRUActivationFIndex] = f_activation;
+        activations[directionIndex * GRUActivationCount + GRUActivationGIndex] = g_activation;
+
+        // input (always the last one), weight, hidden weight, and bias have fixed indices.
+        // Thus we do not bother to obtain them through traversing.
+        int inputIndex = inputs.size() - 1;
+        Xs[directionIndex] = inputs[inputIndex];
+
+        Ws[directionIndex] = inputs[CNTKGRUWeightIndex];
+        SanityCheckForConstantOrParameters(Ws);
+
+        Rzrs[directionIndex] = inputs[CNTKGRUHiddenWeightZRIndex];
+        SanityCheckForConstantOrParameters(Rzrs);
+
+        Rhs[directionIndex] = inputs[CNTKGRUHiddenWeightHIndex];
+        SanityCheckForConstantOrParameters(Rhs);
+
+        Bs[directionIndex] = inputs[CNTKGRUBiasIndex];
+        SanityCheckForConstantOrParameters(Bs);
+
+        std::vector outputs = gru->Outputs();
+
+        Yhs[directionIndex] = outputs[CNTKLSTMOutputYhIndex];
+    }
+
+    // ensure that if there is one direction, it is not backward.
+    // if there are two directions, they are forward and backward, and
+    // that the inputs (Xs) are the same.
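+    // a count above one in either direction would mean two GRU blocks running the same way,
+    // which cannot be folded into a single ONNX GRU node.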
+ if (std::any_of(directionCount.begin(), directionCount.end(), [](std::map::value_type &v) {return v.second > 1; })) + { + LogicError("GRU node is invalid because there should be no more than one path in each direction."); + } + if (grus.size() == 2 && Xs[0] != Xs[1]) + { + LogicError("Bi-directional GRU node is invalid because the two LSTM nodes do not share one same input."); + } + + string direction = DeriveDirectionString(grus, directionCount); + + // an RNN output size is the hidden size + int hidden_size = grus[0]->Outputs()[0].Shape()[0]; + + // inputs + std::vector nodeInputs; + PrepareRNNInput(Xs[0], nodeInputs); + PrepareGRUWeightNode(graph, variableNodes, Ws, nodeInputs); + PrepareGRUZRHWeightNode(graph, variableNodes, Rzrs, Rhs, nodeInputs); + + { + bool hasBias = std::all_of(Bs.begin(), Bs.end(), [](Variable &v) {return v.IsInitialized(); }); + if (hasBias) + { + PrepareGRUBiasNode(graph, variableNodes, Bs, nodeInputs); + } + else + { + AddEmptyInput(nodeInputs); + } + + { + // sequence_lens is not supported + AddEmptyInput(nodeInputs); + } + + bool has_initial_h = std::all_of(initialHs.begin(), initialHs.end(), [](Variable &v) {return v.IsInitialized(); }); + if (has_initial_h) + { + std::string hiddenUid = ToString(Yhs[0].Uid()) + "_initial_h"; + PrepareLSTMInitialStateNode(graph, variableNodes, initialHs, FreeBatchSize, hidden_size, hiddenUid, nodeInputs); + } + else + { + AddEmptyInput(nodeInputs); + } + } + + const int output_sequence = RNNOutputSequence; // GRU in CNTK always output full sequence of output + std::vector nodeOutputs; + { + if (output_sequence == 1) + { + std::string nodeName; + if (grus.size() == 1) + nodeName = ToString(Yhs[0].Uid()); + else + nodeName = ToString(src->Output().Uid()); + + auto outputArgType = ToTypeProto(std::vector({ FreeSequenceLen, (int)Yhs.size(), FreeBatchSize, (int)Yhs[0].Shape()[0] }), false); + UpdateONNXType(Yhs[0].GetDataType(), outputArgType); + ONNXIR::NodeArg outputArg(nodeName, &outputArgType); + nodeOutputs.push_back(outputArg); + } + else + { + ONNXIR::NodeArg outputArg("", nullptr); + nodeOutputs.push_back(outputArg); + } + + { + Variable Yh = Yhs[0]; + std::string nodeName = ToString(Yh.Uid()) + "_h"; + // TODO: batchSize is fixed to one. Needs to find out how to handle bacth axis as a free dimension. + const int batchSize = 1; + const bool doReverseVec = false; + auto outputArgType = ToTypeProto(std::vector({ (int)Yhs.size(), batchSize, (int)Yh.Shape()[0] }), doReverseVec); + UpdateONNXType(Yh.GetDataType(), outputArgType); + ONNXIR::NodeArg outputArg(nodeName, &outputArgType); + nodeOutputs.push_back(outputArg); + } + } + + // TODO: Except X, all other inputs to GRU are treated as constant. + // It is highly unlikely that any other input is an output of an op. + // We will investigate once it is real. + if (Xs[0].Owner().get() != nullptr) + CreateNode(Xs[0].Owner(), graph, functionNodes, variableNodes, compositeOutputsMap); + + auto nodeName = src->Name().empty() ? ToString(src->Uid()) : ToString(src->Name()); + ONNXIR::Node *gruNode = graph->AddNode(nodeName, "GRU", "", nodeInputs, nodeOutputs); + + gruNode->AddAttribute("activations", activations); + gruNode->AddAttribute("direction", direction); + gruNode->AddAttribute("hidden_size", (int64_t)hidden_size); + gruNode->AddAttribute("output_sequence", (int64_t)output_sequence); + + // TODO: make bidirectional GRU work by figuring out output data + // layout transpose in InsertReshapeNodeToCNTKFunction. 
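+    // The ONNX Y output is laid out as [seq_length, num_directions, batch_size, hidden_size], while the
+    // CNTK Splice above concatenates the two directions (presumably along the feature axis), so the
+    // direction axis would need to be transposed back; until then the bidirectional case is rejected.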
+ if (grus.size() == 2) + NOT_IMPLEMENTED; + + // TODO: uncomment this code once LotusRT output shape matches ONNX + // squeeze direction axis out. This is safe because it is not bi-directional node. + std::vector shape({ FreeSequenceLen, 1, hidden_size }); + ONNXIR::Node *squeezedLSTMNode = InsertReshapeNodeToCNTKFunction(src, gruNode, shape, graph); + functionNodes.emplace(src, squeezedLSTMNode); + return squeezedLSTMNode; +} + ONNXIR::Node *CNTKToONNXHelper::AddReshapeNode(const ONNXIR::NodeArg &nodeArg, const std::vector &newShape, const std::string &outArgName, ONNXIR::Graph* graph) { ONNXIR::NodeArg outputArg(outArgName, nullptr); @@ -1938,7 +2114,7 @@ ONNXIR::Node *CNTKToONNXHelper::InsertReshapeNodeToCNTKFunction(const FunctionPt { FunctionPtr blockRoot = src->BlockRoot(); Variable output; - if (src->OpName() == L"LSTM") + if (Operators::IsRNNOp(ToString(src->OpName()))) output = src->Outputs()[0]; else // a bidirection LSTM case @@ -2000,7 +2176,11 @@ ONNXIR::Node* CNTKToONNXHelper::CreateNode(const FunctionPtr& src, // return CreateLSTMNode(src, graph, functionNodes, variableNodes, compositeOutputsMap); //} //else - if (opName == "LSTM") + if (opName == "GRU") + { + return CreateGRUNode(src, graph, functionNodes, variableNodes, compositeOutputsMap); + } + else if (opName == "LSTM") { return CreateLSTMNode(src, graph, functionNodes, variableNodes, compositeOutputsMap); } @@ -2985,13 +3165,13 @@ void CNTKToONNXHelper::FillTensorWithScalar(const std::vector &s onnx::TensorProto& dst, const std::vector dstShape) { auto dataType = srcs[0]->GetDataType(); + SetTensorType(dst, dataType); // the first dimension is for srcs count int eachSrcSize = std::accumulate(dstShape.begin() + 1, dstShape.end(), 1, std::multiplies()); switch (dataType) { case DataType::Float: { - dst.set_data_type(onnx::TensorProto_DataType_FLOAT); for (int i = 0; i < srcs.size(); i++) { auto srcTemp = srcs[i]->DeepClone(); @@ -3008,7 +3188,6 @@ void CNTKToONNXHelper::FillTensorWithScalar(const std::vector &s } case DataType::Double: { - dst.set_data_type(onnx::TensorProto_DataType_DOUBLE); for (int i = 0; i < srcs.size(); i++) { auto srcTemp = srcs[i]->DeepClone(); diff --git a/Source/CNTKv2LibraryDll/proto/onnx/ONNXToCNTK.cpp b/Source/CNTKv2LibraryDll/proto/onnx/ONNXToCNTK.cpp index 03cb8d98e..f6eb1af61 100644 --- a/Source/CNTKv2LibraryDll/proto/onnx/ONNXToCNTK.cpp +++ b/Source/CNTKv2LibraryDll/proto/onnx/ONNXToCNTK.cpp @@ -612,7 +612,6 @@ std::vector CreateRNNConstant( const Node *parentNode, int index, const std::string &name, onnx::TensorProto &valueProto, const DeviceDescriptor& computeDevice) { std::vector inputs; - string parentONNXOpName = parentNode->OpType(); auto dataType = valueProto.data_type(); switch (dataType) @@ -633,6 +632,7 @@ std::vector CreateRNNConstant( } } + string parentONNXOpName = parentNode->OpType(); // index to LSTM inputs as specified in the ONNX document. 
// https://github.com/onnx/onnx/blob/master/docs/Operators.md#inputs-3---8 if (parentONNXOpName == "LSTM") @@ -805,6 +805,154 @@ std::vector CreateRNNConstant( CNTK::LogicError("CreateRNNConstant received unepxpeted index: %d", index); } } + else if (parentONNXOpName == "GRU") + { + // https://github.com/onnx/onnx/blob/master/docs/Operators.md#inputs-3---6 + switch (index) + { + case GRUInputIndexX: + // X, should not come to here + return inputs; + case GRUInputIndexW: + { + // see ONNX spec for the tensor shape + int num_directions = valueProto.dims(0); + size_t rows = valueProto.dims(1); + size_t cols = valueProto.dims(2); + + // CNTK cpp requires shape: (input_size, 3 * hidden_size) + NDShape weightShape({ rows, cols }); + + int input_size = cols; + int cell_size = rows / 3; + + for (int dir = 0; dir < num_directions; dir++) + { + std::string nodeName = name + "_W_" + (char)dir; + int totalSizePerDirection = rows * cols; + + // TODO: what about double? + float *data = new float[totalSizePerDirection]; + for (size_t count = 0; count < totalSizePerDirection; count++) + { + int row = count / input_size; + int col = count % input_size; + int sourceIndex = dir * totalSizePerDirection + count; + int targetIndex = col * cell_size * GRUWeightDimensionHiddenMultiplier + row; + data[targetIndex] = valueProto.float_data()[sourceIndex]; + } + + Constant constant = CreateConstantWithRawData(&data[0], weightShape, nodeName, computeDevice); + inputs.push_back(constant); + } + return inputs; + } + case GRUInputIndexR: + { + // split into H and H1 for CNTK GRU implementation + int num_directions = valueProto.dims(0); + size_t rows = valueProto.dims(1); + size_t cols = valueProto.dims(2); + + int input_size = cols; + int cell_size = rows / 3; + + NDShape hShape({ (size_t)cell_size * 2, (size_t)input_size }); + NDShape h1Shape({ (size_t)cell_size, (size_t)input_size }); + + inputs.resize(num_directions * 2); + for (int dir = 0; dir < num_directions; dir++) + { + std::string hNodeName = name + "_H_" + (char)dir; + std::string h1NodeName = name + "_H1_" + (char)dir; + int totalSizePerDirection = rows * cols; + + float *hData = new float[hShape.TotalSize()]; + float *h1Data = new float[h1Shape.TotalSize()]; + for (size_t count = 0; count < totalSizePerDirection; count++) + { + int row = count / input_size; + int col = count % input_size; + int block = row / cell_size; + int sourceIndex = dir * totalSizePerDirection + count; + if (block < CNTKGRUZRWeightMultiplier) + { + int targetIndex = col * cell_size * CNTKGRUZRWeightMultiplier + row; + hData[targetIndex] = valueProto.float_data()[sourceIndex]; + } + else + { + int targetIndex = col * cell_size + row - cell_size * CNTKGRUZRWeightMultiplier; + h1Data[targetIndex] = valueProto.float_data()[sourceIndex]; + } + } + + Constant constantH = CreateConstantWithRawData(&hData[0], hShape, hNodeName, computeDevice); + Constant constantH1 = CreateConstantWithRawData(&h1Data[0], h1Shape, h1NodeName, computeDevice); + inputs[dir] = constantH; + inputs[dir + num_directions] = constantH1; + } + return inputs; + } + case GRUInputIndexB: + // B + { + // see ONNX spec for the tensor shape + int num_directions = valueProto.dims(0); + int cell_size = valueProto.dims(1) / GRUBiasDimensionHiddenMultiplier; + // shape size is devided by 2 so that it only applies to input (CNTK) + // TODO: this incompatibility needs further investigation. 
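+            // Per the ONNX spec, B is [num_directions, 6 * hidden_size]: Wb[zrh] followed by Rb[zrh].
+            // CNTK's GRU block keeps a single 3 * hidden_size bias, so the Wb and Rb halves are
+            // summed ("fused") element-wise below into that single bias.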
+ NDShape weightShape({ (size_t)(GRUBiasDimensionHiddenMultiplier / 2 * cell_size) }); + for (int dir = 0; dir < num_directions; dir++) + { + std::string nodeName = name + std::string(1, (char)dir) + LSTMInputBiasNameHint; + int totalSizePerDirection = GRUBiasDimensionHiddenMultiplier / 2 * cell_size; + float *data = new float[totalSizePerDirection]; + for (size_t targetIndex = 0; targetIndex < totalSizePerDirection; targetIndex++) + { + int row = targetIndex; + // soruce is collmn major + int src_index = row; + // "fuse" + data[targetIndex] = + valueProto.float_data()[dir * 2 * totalSizePerDirection + src_index] + + valueProto.float_data()[dir * 2 * totalSizePerDirection + totalSizePerDirection + src_index]; + } + + Constant constant = CreateConstantWithRawData(data, weightShape, nodeName, computeDevice); + inputs.push_back(constant); + } + return inputs; + } + case GRUInputIndexSequenceLens: + return inputs; + case GRUInitialH: + { + // initial_h + int num_directions = valueProto.dims(0); + int cell_size = valueProto.dims(2); + NDShape weightShape({ (size_t)(cell_size) }); + for (int dir = 0; dir < num_directions; dir++) + { + std::string nodeName = name + std::string(1, (char)dir) + LSTMInputInitialHNameHint; + + float *data = new float[cell_size]; + for (size_t targetIndex = 0; targetIndex < cell_size; targetIndex++) + { + data[targetIndex] = valueProto.float_data()[dir * cell_size + targetIndex]; + } + + Constant constant = CreateConstantWithRawData(data, weightShape, nodeName, computeDevice); + inputs.push_back(constant); + } + return inputs; + } + break; + return inputs; + default: + CNTK::LogicError("CreateRNNConstant for GRU op received unepxpeted index: %d", index); + } + } else { NOT_IMPLEMENTED; @@ -890,6 +1038,36 @@ std::vector ONNXToCNTKHelper::CreateRNNLeafVariableOrConstant(const No LogicError("LSTM node has unexpected input"); } } + else if (parentONNXOpName == "GRU") + { + int inputIndex = CalculateNodeArgInputIndex(nodeArg, parentNode); + switch (inputIndex) + { + case GRUInputIndexX: + // X: `[seq_length, batch_size, input_size]`. 
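+        // Only input_size becomes the static shape of the CNTK InputVariable created below;
+        // seq_length and batch_size are carried by CNTK dynamic axes instead.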
+ { + Variable inputVariable; + if (constructedNodeArgVariableMap.find(nodeArg->Name()) == constructedNodeArgVariableMap.end()) + { + DataType dataType = FromONNXType(nodeArg->ToProto().type()); + int input_size = shapeProto->dim(2).dim_value(); + NDShape shape({ (size_t)(input_size) }); + inputVariable = InputVariable(shape, dataType, ToWString(nodeArg->Name()), dynamicAxes); + constructedNodeArgVariableMap.insert(ONNXToCNTKVariableMap::value_type(nodeArg->Name(), inputVariable)); + } + return std::vector({ constructedNodeArgVariableMap[nodeArg->Name()] }); + } + // other inputs shall be ONNX constant node and be created as CNTK Constant in CreateRNNConstant + case GRUInputIndexW: + case GRUInputIndexR: + case GRUInputIndexB: + case GRUInputIndexSequenceLens: + case GRUInitialH: + NOT_IMPLEMENTED; + default: + LogicError("LSTM node has unexpected input"); + } + } else { NOT_IMPLEMENTED; @@ -1465,6 +1643,15 @@ FunctionPtr ONNXToCNTKHelper::CreateFunction(const Node *node, const std::vector std::vector({"Sigmoid", "Tanh", "Tanh"})); return CreateLSTM(node, inputs, direction, activations, activation_alpha, activation_beta); } + else if (onnxOpName == "GRU") + { + const string direction = GetNamedAttributeAsString(node, "direction"); + std::vector activation_alpha = GetNamedAttributeAsFloatVec(node, "activation_alpha", std::vector()); + std::vector activation_beta = GetNamedAttributeAsFloatVec(node, "activation_beta", std::vector()); + const std::vector activations = GetNamedAttributeAsStringVec(node, "activations", + std::vector({ "Sigmoid", "Tanh" })); + return CreateGRU(node, inputs, direction, activations, activation_alpha, activation_beta); + } if (onnxOpName == "FC") { return CreateCNTKFCNode(ToWString(node->Name()), inputs); @@ -2275,7 +2462,7 @@ std::vector ONNXToCNTKHelper::FromONNXNode(const Node *node, ONNXTo const Node *parentNode; int childIndex; std::tie(parentNode, childIndex) = FindParentAndChildIndex(node); - if (parentNode != nullptr && parentNode->OpType() == "LSTM") + if (parentNode != nullptr && Operators::IsRNNOp(parentNode->OpType())) { std::vector cntkFunctions = CreateRNNConstantOp(graph, node, parentNode, childIndex, computeDevice); if (!cntkFunctions.empty()) @@ -2613,7 +2800,7 @@ std::vector ONNXToCNTKHelper::CreateCNTKInputsStartingFromIndex(const else { std::string parentONNXOpName = node->OpType(); - if (parentONNXOpName == "LSTM") + if (Operators::IsRNNOp(node->OpType())) { std::vector inputVariables = CreateRNNLeafVariableOrConstant(nodeArg, node, graph, constructedNodeArgVariableMap, computeDevice); diff --git a/Source/CNTKv2LibraryDll/proto/onnx/RNNHelper.cpp b/Source/CNTKv2LibraryDll/proto/onnx/RNNHelper.cpp index 4b2afa499..a45cd5bd2 100644 --- a/Source/CNTKv2LibraryDll/proto/onnx/RNNHelper.cpp +++ b/Source/CNTKv2LibraryDll/proto/onnx/RNNHelper.cpp @@ -3,6 +3,57 @@ // Licensed under the MIT license. See LICENSE.md file in the project root for full license information. 
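The GRU branch of CreateFunction above falls back to the ONNX spec defaults {"Sigmoid", "Tanh"} when the node carries no activations attribute. A hypothetical stand-in for that attribute-with-default pattern (not the actual GetNamedAttributeAsStringVec implementation):

#include <map>
#include <string>
#include <vector>

std::vector<std::string> GetStringVecOrDefault(
    const std::map<std::string, std::vector<std::string>> &attrs,
    const std::string &name,
    const std::vector<std::string> &defaultValue)
{
    // return the stored attribute if present, otherwise the spec default
    auto it = attrs.find(name);
    return it == attrs.end() ? defaultValue : it->second;
}

int main()
{
    std::map<std::string, std::vector<std::string>> attrs;   // no "activations" set on the node
    auto activations = GetStringVecOrDefault(attrs, "activations", { "Sigmoid", "Tanh" });
    return activations.size() == 2 ? 0 : 1;
}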
// #include "RNNHelper.h" +#include "Operators.h" + +using namespace CNTK::ONNX; + +std::string MapActivationNameONNXToCNTK(const std::string &onnxOp) +{ + if (onnxOp == "Relu") + return "ReLU"; + else if (onnxOp == "Sigmoid") + return "StableSigmoid"; + else if (onnxOp == "LeakyRelu") + return "LeakyReLU"; + else if (onnxOp == "ThresholdedRelu") + return "ThresholdedReLU"; + else if (onnxOp == "Elu") + return "ELU"; + else + return onnxOp; +} + +std::string MapActivationNameCNTKToONNX(const std::string &cntkOp) +{ + if (cntkOp == "ReLU") + return "Relu"; + else if (cntkOp == "StableSigmoid") + return "Sigmoid"; + else if (cntkOp == "LeakyReLU") + return "LeakyRelu"; + else if (cntkOp == "ThresholdedReLU") + return "ThresholdedRelu"; + else if (cntkOp == "ELU") + return "Elu"; + else + return cntkOp; +} + +bool IsActivationOp(const std::string &activationName) +{ + return + activationName == "Relu" || activationName == "ReLU" || + activationName == "Tanh" || + activationName == "Sigmoid" || activationName == "StableSigmoid" || + activationName == "Affine" || + activationName == "LeakyRelu" || activationName == "LeakyReLU" || + activationName == "ThresholdedRelu" || activationName == "ThresholdedReLU" || + activationName == "ScaledTanh" || + activationName == "HardSigmoid" || + activationName == "Elu" || activationName == "ELU" || + activationName == "Softsign" || + activationName == "Softplus"; +} std::function ActivationMap(const std::string &activationName) { @@ -37,7 +88,7 @@ std::function ActivationMap(const std::string &act } else { - CNTK::LogicError("LSTM does not support activation: %s", activationName.c_str()); + CNTK::LogicError("Recurrent Op does not support activation: %s", activationName.c_str()); } } @@ -82,9 +133,9 @@ GetActivations(const std::vector &activations, const std::vector iofActivationOp, cellActivationOp, hiddenActivationOp; - if (hasBeta) + if (hasAlphaBeta) { iofActivationOp = ActivationMap(activations[iofActivationIndex], activation_alpha[iofActivationIndex], activation_beta[iofActivationIndex]); cellActivationOp = ActivationMap(activations[cellActivation], activation_alpha[cellActivation], activation_beta[cellActivation]); @@ -107,6 +158,38 @@ GetActivations(const std::vector &activations, const std::vector, std::function> +GetGRUActivations(const std::vector &activations, const std::vector &activation_alpha, const std::vector &activation_beta, int direction) +{ + if (activations.size() < (direction + 1) * GRUActivationCount) + CNTK::LogicError("LSTM activations shall be %d or %d of strings", GRUActivationCount, GRUActivationCount * 2); + + // + int fActivationIndex = direction * GRUActivationCount + GRUActivationFIndex; + int gActivationIndex = direction * GRUActivationCount + GRUActivationGIndex; + + bool hasAlpha = activation_alpha.size() == (direction + 1) * GRUActivationCount; + bool hasAlphaBeta = hasAlpha && activation_beta.size() == (direction + 1) * GRUActivationCount; + std::function fActivationOp, gActivationOp; + if (hasAlphaBeta) + { + fActivationOp = ActivationMap(activations[fActivationIndex], activation_alpha[fActivationIndex], activation_beta[fActivationIndex]); + gActivationOp = ActivationMap(activations[gActivationIndex], activation_alpha[gActivationIndex], activation_beta[gActivationIndex]); + } + else if (hasAlpha) + { + fActivationOp = ActivationMap(activations[fActivationIndex], activation_alpha[fActivationIndex]); + gActivationOp = ActivationMap(activations[gActivationIndex], activation_alpha[gActivationIndex]); + } + else + { + 
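A small self-contained sketch of how GetGRUActivations indexes the flat ONNX activations list; for a bidirectional GRU the list reads [f_forward, g_forward, f_backward, g_backward]:

#include <cstdio>

int main()
{
    const int GRUActivationCount = 2;   // f (gates) and g (candidate)
    for (int dir = 0; dir < 2; dir++)
    {
        int fIndex = dir * GRUActivationCount + 0;   // gate activation for this direction
        int gIndex = dir * GRUActivationCount + 1;   // candidate activation for this direction
        printf("direction %d: f at %d, g at %d\n", dir, fIndex, gIndex);
    }
    return 0;
}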
fActivationOp = ActivationMap(activations[fActivationIndex]); + gActivationOp = ActivationMap(activations[gActivationIndex]); + } + + return std::make_tuple(fActivationOp, gActivationOp); +} + std::pair LSTMPCell(Variable input, const std::function &iofActivationOp, const std::function &cellActivationOp, @@ -155,6 +238,53 @@ std::pair LSTMPCell(Variable input, return{ h, c }; } +FunctionPtr GRUCell(Variable input, + const std::function &fActivationOp, + const std::function &gActivationOp, + Variable prevOutput, + Constant &W, Constant &R, Constant &H1, Constant &B) +{ + size_t outputDim = prevOutput.Shape()[0]; + int stacked_dim = (int)outputDim; + + FunctionPtr projx3; + if (B.IsInitialized()) + projx3 = Plus(B, Times(W, input)); + else + projx3 = Times(W, input); + + FunctionPtr projh2 = Times(R, prevOutput); + + // both CNTK and ONNX weight and bias are in zrh order. + std::vector stack_axis({ Axis(-1) }); + FunctionPtr zt_proj = + Slice(projx3, stack_axis, { 0 * stacked_dim }, { 1 * stacked_dim }) + + Slice(projh2, stack_axis, { 0 * stacked_dim }, { 1 * stacked_dim }); + + FunctionPtr rt_proj = + Slice(projx3, stack_axis, { 1 * stacked_dim }, { 2 * stacked_dim }) + + Slice(projh2, stack_axis, { 1 * stacked_dim }, { 2 * stacked_dim }); + + FunctionPtr ct_proj = + Slice(projx3, stack_axis, { 2 * stacked_dim }, { 3 * stacked_dim }); + + FunctionPtr zt = fActivationOp(zt_proj); + + FunctionPtr rt = fActivationOp(rt_proj); + + FunctionPtr rs = ElementTimes(prevOutput, rt); + + FunctionPtr ct = gActivationOp(ct_proj + Times(H1, rs)); + + Constant one = W.GetDataType() == DataType::Float ? Constant::Scalar(1.0f) : Constant::Scalar(1.0); + + FunctionPtr ht = ElementTimes(one - zt, ct) + ElementTimes(zt, prevOutput); + + FunctionPtr h = ht; + + return ht; +} + #include "PrimitiveFunction.h" #include "BlockFunction.h" @@ -185,6 +315,27 @@ std::tuple LSTMPComponent(Variable input, return std::make_tuple(LSTMCell.first, LSTMCell.second); } +FunctionPtr GRUComponent(Variable input, + const NDShape& cellShape, + const std::function &fActivationOp, + const std::function &gActivationOp, + const std::function& recurrenceHookH, + Constant &W, Constant &R, Constant &H1, Constant &B) +{ + auto dh = PlaceholderVariable(cellShape, input.DynamicAxes()); + auto inputPlaceholder = PlaceholderVariable(input.Shape(), input.DynamicAxes()); + + auto gruCell = GRUCell( + inputPlaceholder, + fActivationOp, gActivationOp, + dh, W, R, H1, B); + + auto actualDh = recurrenceHookH(gruCell); + + gruCell->ReplacePlaceholders({ { inputPlaceholder , input },{ dh, actualDh } }); + return gruCell; +} + const std::vector FindByNameHint(const std::vector &inputs, const std::string &hint) { std::vector variables; @@ -218,7 +369,7 @@ Variable GetInitialStateVariable(const std::vector &inputs, int numDir FunctionPtr CreateLSTM(const ONNXIR::Node *node, const std::vector &inputs, const std::string &direction, const std::vector &activations, const std::vector &activation_alpha, const std::vector &activation_beta) { - int numDirections = direction == LSTMDirectionBidirection ? 2 : 1; + int numDirections = direction == RNNDirectionBidirection ? 
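For clarity, a self-contained scalar sketch of one step of the recurrence implemented by GRUCell above; the weights and inputs are made-up numbers, and in the real op W, R and H1 are stacked matrices in z, r, h order:

#include <cmath>
#include <cstdio>

int main()
{
    // hypothetical scalar weights, hidden and input size of 1, bias omitted
    const double Wz = 0.5, Wr = -0.3, Wh = 0.8;   // input weights
    const double Rz = 0.2, Rr = 0.4, H1 = 0.6;    // recurrent weights; H1 feeds the candidate branch
    const double x = 1.0, h_prev = 0.1;

    auto sigmoid = [](double v) { return 1.0 / (1.0 + std::exp(-v)); };

    double z = sigmoid(Wz * x + Rz * h_prev);                // update gate
    double r = sigmoid(Wr * x + Rr * h_prev);                // reset gate
    double c = std::tanh(Wh * x + H1 * (r * h_prev));        // candidate, reset applied to prev state
    double h = (1.0 - z) * c + z * h_prev;                   // h_t as in GRUCell
    printf("h_t = %f\n", h);
    return 0;
}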
2 : 1; std::vector outputHs; for (int dir = 0; dir < numDirections; dir++) { @@ -237,8 +388,8 @@ FunctionPtr CreateLSTM(const ONNXIR::Node *node, const std::vector &in else if (numDirections == 2 && biasVariables.size() == 2) B = biasVariables[dir]; - Variable initHVariable = GetInitialStateVariable(inputs, numDirections, LSTMInputInitialCNameHint, X.GetDataType()); - Variable initCVariable = GetInitialStateVariable(inputs, numDirections, LSTMInputInitialHNameHint, X.GetDataType()); + Variable initHVariable = GetInitialStateVariable(inputs, numDirections, LSTMInputInitialHNameHint, X.GetDataType()); + Variable initCVariable = GetInitialStateVariable(inputs, numDirections, LSTMInputInitialCNameHint, X.GetDataType()); std::vector peepholeVariables = FindByNameHint(inputs, LSTMInputPeepholeNameHint); Variable Ci, Cf, Co; @@ -271,17 +422,23 @@ FunctionPtr CreateLSTM(const ONNXIR::Node *node, const std::vector &in FunctionPtr outputC; // if it is bidirectional LSTM, the second one will be the backword one. - bool go_backwards = direction == LSTMDirectionReverse || (numDirections == 2 && dir == 1); + bool go_backwards = direction == RNNDirectionReverse || (numDirections == 2 && dir == 1); - std::function futureValueRecurrenceHook; + std::function recurrenceHookH, recurrenceHookC; if (go_backwards) - futureValueRecurrenceHook = [initHVariable](const Variable& x) { return FutureValue(x, initHVariable); }; + { + recurrenceHookH = [initHVariable](const Variable& x) { return FutureValue(x, initHVariable); }; + recurrenceHookC = [initCVariable](const Variable& x) { return FutureValue(x, initCVariable); }; + } else - futureValueRecurrenceHook = [initCVariable](const Variable& x) { return PastValue(x, initCVariable); }; + { + recurrenceHookH = [initHVariable](const Variable& x) { return PastValue(x, initHVariable); }; + recurrenceHookC = [initCVariable](const Variable& x) { return PastValue(x, initCVariable); }; + } std::tie(outputH, outputC) = LSTMPComponent( X, { (size_t)hiddenDim }, iofActivationOp, cellActivationOp, hiddenActivationOp, - futureValueRecurrenceHook, futureValueRecurrenceHook, (Constant &)W, (Constant &)R, (Constant &)B, + recurrenceHookH, recurrenceHookC, (Constant &)W, (Constant &)R, (Constant &)B, (Constant &)Ci, (Constant &)Cf, (Constant &)Co); outputHs.push_back(outputH); } @@ -293,3 +450,475 @@ FunctionPtr CreateLSTM(const ONNXIR::Node *node, const std::vector &in return Splice(operands, Axis(0), ToWString(node->Name())); } } + +FunctionPtr CreateGRU(const ONNXIR::Node *node, const std::vector &inputs, const std::string &direction, + const std::vector &activations, const std::vector &activation_alpha, const std::vector &activation_beta) +{ + int numDirections = direction == RNNDirectionBidirection ? 2 : 1; + std::vector outputHs; + for (int dir = 0; dir < numDirections; dir++) + { + std::function fActivationOp, gActivationOp; + std::tie, std::function> + (fActivationOp, gActivationOp) = GetGRUActivations(activations, activation_alpha, activation_beta, dir); + + // the first a few inputs are (in order): X, numDirections * W, numDirections * R, numDirections * H1 + Variable X = inputs[0]; + Variable W = inputs[1 * numDirections + dir - ((numDirections == 2) ? 1 : 0)]; + Variable R = inputs[2 * numDirections + dir - ((numDirections == 2) ? 1 : 0)]; + + // TODO: get H1 + Variable H1 = inputs[3 * numDirections + dir - ((numDirections == 2) ? 
1 : 0)]; + + Variable B; + std::vector biasVariables = FindByNameHint(inputs, LSTMInputBiasNameHint); + if (numDirections == 1 && biasVariables.size() >= 1) + B = biasVariables[0]; + else if (numDirections == 2 && biasVariables.size() == 2) + B = biasVariables[1]; + + Variable initHVariable = GetInitialStateVariable(inputs, numDirections, GRUInputInitialHNameHint, X.GetDataType()); + + // ONNX spec https://github.com/onnx/onnx/blob/master/docs/Operators.md#inputs-3---8 + // tells that weight has shape [num_directions, 4*hidden_size, input_size] + // here in CNTK, there is no direction axis because CNTK treats bidirectional LSTM + // as two separate LSTM. Therefore we can divide the dimension of the first axis + // by 4 to get the hidden size. + int hiddenDim = W.Shape()[0] / GRUWeightDimensionHiddenMultiplier; + + FunctionPtr outputH; + + // if it is bidirectional LSTM, the second one will be the backword one. + bool go_backwards = direction == RNNDirectionReverse || (numDirections == 2 && dir == 1); + + std::function recurrenceHook; + if (go_backwards) + recurrenceHook = [initHVariable](const Variable& x) { return FutureValue(x, initHVariable); }; + else + recurrenceHook = [initHVariable](const Variable& x) { return PastValue(x, initHVariable); }; + + outputH = GRUComponent( + X, { (size_t)hiddenDim }, fActivationOp, gActivationOp, + recurrenceHook, (Constant &)W, (Constant &)R, (Constant &)H1, (Constant &)B); + outputHs.push_back(outputH); + } + if (outputHs.size() == 1) + return outputHs[0]; + else + { + std::vector operands({ outputHs[0], outputHs[1] }); + return Splice(operands, Axis(0), ToWString(node->Name())); + } +} + +template +void TraverseGraphWithPrePostActions(FunctionPtr cntkFunction, std::unordered_set& visitedFunctions, + FunctionType preFunctor, FunctionType postFunctor) +{ + visitedFunctions.insert(cntkFunction); + preFunctor(cntkFunction); + + std::vector functionInputs = cntkFunction->Inputs(); + for (const auto& input : functionInputs) + { + if (input.IsOutput() && visitedFunctions.find(input.Owner()) == visitedFunctions.end()) + { + const auto& inputFunction = input.Owner(); + TraverseGraphWithPrePostActions(inputFunction, visitedFunctions, preFunctor, postFunctor); + } + } + + postFunctor(cntkFunction); +} + +bool IsSupportedRNNActivation(const std::wstring &cntkOpName) +{ + static std::vector supportedRNNActivations( + { + L"ReLU", + L"Tanh", + L"StableSigmoid" + }); + return std::find(supportedRNNActivations.cbegin(), supportedRNNActivations.cend(), cntkOpName) != + supportedRNNActivations.cend(); +} + +std::string FindActivation(const std::vector &path, int nth) +{ + int count = 0; + for (std::vector::const_iterator it = path.begin(); it != path.end(); it++) + { + std::wstring opName = (*it)->OpName(); + if (IsSupportedRNNActivation(opName)) + { + if (count == nth) + { + std::unordered_multimap::const_iterator itLookup = Operators::CntkToONNXLookup().find(opName); + if (itLookup == Operators::CntkToONNXLookup().cend()) + CNTK::LogicError("Invalid activation (%s)", ToString(opName).c_str()); + + std::unordered_map::const_iterator itMap = (*itLookup).second.map.find(opName); + if (itMap == (*itLookup).second.map.cend()) + CNTK::LogicError("Invalid activation (%s)", ToString(opName).c_str()); + return itMap->second; + } + count++; + } + } + return ""; +} + +Variable GetPeepholeVariableFromOp(FunctionPtr peepholeOp) +{ + // peephole variable is that child of peepholeOp that is neither stabilizer nor place holder + if (peepholeOp->OpName() != L"ElementTimes") + 
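A small sketch of the index arithmetic CreateGRU uses above to pick W, R and H1 out of the flattened input list [X, W x numDirections, R x numDirections, H1 x numDirections, ...]:

#include <cstdio>

int main()
{
    for (int numDirections = 1; numDirections <= 2; numDirections++)
        for (int dir = 0; dir < numDirections; dir++)
        {
            int adj = (numDirections == 2) ? 1 : 0;
            int w  = 1 * numDirections + dir - adj;
            int r  = 2 * numDirections + dir - adj;
            int h1 = 3 * numDirections + dir - adj;
            printf("dirs=%d dir=%d -> W at %d, R at %d, H1 at %d\n",
                   numDirections, dir, w, r, h1);
        }
    return 0;
}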
CNTK::LogicError("Peephole operation must be ElementTimes"); + + Variable peepholeVariable; + FunctionPtr stabilizerOp; + for (int i = 0; i < peepholeOp->Inputs().size(); i++) + { + if (peepholeOp->Inputs()[i].Owner() && peepholeOp->Inputs()[i].Owner()->OpName() == L"Stabilizer") + { + stabilizerOp = peepholeOp->Inputs()[i].Owner(); + } + else if (peepholeOp->Inputs()[i].IsConstant() || peepholeOp->Inputs()[i].IsParameter()) + { + if (!peepholeVariable.IsInitialized()) + peepholeVariable = peepholeOp->Inputs()[i]; + else + CNTK::LogicError("Cannot find peephole variable from peephole op. Multiple qualified variables found."); + } + } + + if (!peepholeVariable.IsInitialized()) + CNTK::LogicError("Cannot find peephole variable from peephole op."); + return peepholeVariable; +} + +// this method helps getting a stabilizer op from its parent/grandparent op. +// the parent op can be Times or ElementTimes. The grandparent op can be a plus op. +FunctionPtr GetStabilizerOp(FunctionPtr parentOp) +{ + FunctionPtr timesOp; + if (parentOp->OpName() == L"Plus") + { + for (int i = 0; i < parentOp->Inputs().size(); i++) + { + if (parentOp->Inputs()[i].Owner() && + (parentOp->Inputs()[i].Owner()->OpName() == L"Times" || + parentOp->Inputs()[i].Owner()->OpName() == L"ElementTimes")) + { + timesOp = parentOp->Inputs()[i].Owner(); + break; + } + } + } + else if (parentOp->OpName() == L"Times" || parentOp->OpName() == L"ElementTimes") + { + timesOp = parentOp; + } + + if (!timesOp) + { + CNTK::LogicError("Cannot find stabilizer op. A stabilizer op must be from Times or ElementTimes ops or skipped from a Plus op."); + } + + for (int j = 0; j < timesOp->Inputs().size(); j++) + { + if (timesOp->Inputs()[j].Owner() && timesOp->Inputs()[j].Owner()->OpName() == L"Stabilizer") + { + return timesOp->Inputs()[j].Owner(); + } + } + + return nullptr; +} + +double GetScaler(Variable variable) +{ + NDArrayViewPtr v = variable.IsParameter() ? Parameter(variable).Value() : Constant(variable).Value(); + NDArrayViewPtr cpuV = v->DeepClone(); + cpuV->ChangeDevice(DeviceDescriptor::CPUDevice()); + + switch (variable.GetDataType()) + { + case DataType::Float: + return *((float *)cpuV->DataBuffer()); + case DataType::Double: + return *((double *)cpuV->DataBuffer()); + default: + NOT_IMPLEMENTED; + } +} + +double GetStabilizerCoef(const FunctionPtr stabilizerDhOp) +{ + double alpha = GetScaler(stabilizerDhOp->Inputs()[3]); + double steepness = GetScaler(stabilizerDhOp->Inputs()[1]); + return (log(exp(alpha * steepness) + 1.0F) / steepness); +} + +void GetDelayOps(const std::vector &inputVars, + std::vector &pastValueOps, std::vector &futureValueOps) +{ + for (std::vector::const_iterator it = inputVars.cbegin(); it != inputVars.cend(); ++it) + { + if ((*it).Owner() != nullptr && (*it).Owner()->OpName() == L"PastValue") + pastValueOps.push_back((*it).Owner()); + else if ((*it).Owner() != nullptr && (*it).Owner()->OpName() == L"FutureValue") + futureValueOps.push_back((*it).Owner()); + } +} + +// A CNTK LSTM op is created with stacked matmul followed by a slice op for 4 gates. +// Slice order tells which graph path is for which gate. This method is +// to traverse the graph to find the 4 paths along the 4 gates. It helps to +// subsequently find needed attributes in order to build an ONNX LSTM op. 
+void TraceLSTMPathes(const FunctionPtr& src, + string &f_activation, + string &g_activation, + string &h_activation, + RNNDirection &direction, + Variable &initStateH, + Variable &initStateC, + Variable &peepholeCi, + Variable &peepholeCo, + Variable &peepholeCf, + double &stabilizer_dh, + double &stabilizer_dc, + double &stabilizer_c) +{ + // src has to be an LSTM node. + std::vector inputVars = src->Inputs(); + std::vector pastValueOps, futureValueOps; + GetDelayOps(inputVars, pastValueOps, futureValueOps); + + // with CNTK LSTM, the first delay node is for H, the second one is for C + // indices here also coresponding with CNTK python layer code. + if (pastValueOps.size() == 2 && futureValueOps.size() == 0) + { + direction = RNNDirection::Forward; + initStateH = pastValueOps[0]->Inputs()[1]; + initStateC = pastValueOps[1]->Inputs()[1]; + } + else if (pastValueOps.size() == 0 && futureValueOps.size() == 2) + { + direction = RNNDirection::Backward; + initStateH = futureValueOps[0]->Inputs()[1]; + initStateC = futureValueOps[1]->Inputs()[1]; + } + else + { + CNTK::LogicError("Node %s (%s) is not a valid LSTM node", ToString(src->Name()).c_str(), ToString(src->Uid()).c_str()); + } + + // set up traverse boundary + std::unordered_set visitedFunctions; + for (std::vector::const_iterator it = inputVars.begin(); it != inputVars.end(); it++) + { + visitedFunctions.insert(it->Owner()); + } + + // First find the peephole op node. + // see CNTK\bindings\python\cntk\layers\blocks.py node references. + std::vector> pathesBitBftJoint; + { + std::vector currentPeepholePath; + + // make a copy of traverse boundary + std::unordered_set peepHoleVisitedFunctions = visitedFunctions; + + // traverse to find the joint of bit and bft + TraverseGraphWithPrePostActions(src->BlockRoot(), + peepHoleVisitedFunctions, + (std::function)[ + &peepHoleVisitedFunctions, &pathesBitBftJoint, ¤tPeepholePath](const FunctionPtr& function) + { + currentPeepholePath.push_back(function); + if (function->OpName() == L"Plus" && + function->Inputs()[0].Owner() && function->Inputs()[0].Owner()->OpName() == L"ElementTimes" && + function->Inputs()[1].Owner() && function->Inputs()[1].Owner()->OpName() == L"ElementTimes") + { + pathesBitBftJoint.push_back(currentPeepholePath); + peepHoleVisitedFunctions.erase(std::find_if(peepHoleVisitedFunctions.begin(), peepHoleVisitedFunctions.end(), + [function](FunctionPtr f) {return function == f; })); + } + }, + (std::function)[¤tPeepholePath](const FunctionPtr& function) + { + currentPeepholePath.pop_back(); + }); + } + + FunctionPtr peepholeCoOp; + bool haspeephole = pathesBitBftJoint.size() == 3; + if (haspeephole) + { + // the last ElementTimes op is the peephole op + std::vector &peepholePath = *std::max_element(pathesBitBftJoint.begin(), pathesBitBftJoint.end(), + [](std::vector &p1, std::vector &p2) {return p1.size() < p2.size(); }); + std::vector::reverse_iterator itPeepholeOp = std::find_if(peepholePath.rbegin(), peepholePath.rend(), + [](FunctionPtr function) {return function->OpName() == L"ElementTimes"; }); + if (itPeepholeOp == peepholePath.rend()) + { + CNTK::LogicError("Cannot find peephole op from a LSTM graph"); + } + + peepholeCoOp = *itPeepholeOp; + peepholeCo = GetPeepholeVariableFromOp(peepholeCoOp); + + FunctionPtr stabilizer_h_op = GetStabilizerOp(peepholeCoOp); + if (stabilizer_h_op) + { + stabilizer_c = GetStabilizerCoef(stabilizer_h_op); + } + } + + std::vector> pathesToPlusSlice; + std::vector currentPath; + + if (haspeephole) + // so that traverse will not be affected 
by the peephole path + visitedFunctions.insert(peepholeCoOp); + + TraverseGraphWithPrePostActions(src->BlockRoot(), + visitedFunctions, + (std::function)[&pathesToPlusSlice, ¤tPath](const FunctionPtr& function) + { + currentPath.push_back(function); + if (function->OpName() == L"Slice") + { + FunctionPtr functionSource = function->Inputs()[0].Owner(); + if (functionSource->OpName() == L"Plus") + { + pathesToPlusSlice.push_back(currentPath); + } + } + }, + (std::function)[¤tPath](const FunctionPtr& function) + { + currentPath.pop_back(); + }); + + // 4 gates of LSTM shall be traced. + if (pathesToPlusSlice.size() != 4) + { + CNTK::LogicError("pathesToPlusSlice.size() != 4"); + } + + std::sort(pathesToPlusSlice.begin(), pathesToPlusSlice.end(), + [](const std::vector& path1, const std::vector& path2) + { + FunctionPtr slice1 = *path1.rbegin(); + FunctionPtr slice2 = *path2.rbegin(); + int beginIndex1 = slice1->Attributes()[PrimitiveFunction::AttributeNameBeginIndex].Value(); + int beginIndex2 = slice2->Attributes()[PrimitiveFunction::AttributeNameBeginIndex].Value(); + return beginIndex1 < beginIndex2; + }); + + // This code is heavily coupled with CNTK python layer code: + // https://github.com/Microsoft/CNTK/blob/44c626a483edeaff97b4f7a46847b055a1d483aa/bindings/python/cntk/layers/blocks.py#L261 + // pathesToPlusSlice is ordered by slice index so we are able to recover corresponding path here. + std::vector &ht_it_path = pathesToPlusSlice[0]; + std::vector &ht_bit_path = pathesToPlusSlice[1]; + std::vector &ht_ft_path = pathesToPlusSlice[2]; + std::vector &ht_ot_path = pathesToPlusSlice[3]; + + f_activation = MapActivationNameCNTKToONNX(FindActivation(ht_ot_path, 0)); + g_activation = MapActivationNameCNTKToONNX(FindActivation(ht_bit_path, 1)); + h_activation = MapActivationNameCNTKToONNX(FindActivation(ht_bit_path, 0)); + + // stabilizer_dh + FunctionPtr stackedProjPlusOp = ht_it_path[ht_it_path.size() - 1]->Inputs()[0].Owner(); + FunctionPtr stabilizerDhOp = GetStabilizerOp(stackedProjPlusOp); + if (stabilizerDhOp) + { + stabilizer_dh = GetStabilizerCoef(stabilizerDhOp); + } + + if (haspeephole) + { + { + // Ci merges to ht_it_path via element-wise time + FunctionPtr plusOp = ht_it_path[ht_it_path.size() - 2]; + FunctionPtr peepholeOp = plusOp->Inputs()[0].Owner()->OpName() != L"Slice" ? + plusOp->Inputs()[0].Owner() : plusOp->Inputs()[1].Owner(); + peepholeCi = GetPeepholeVariableFromOp(peepholeOp); + + } + { + // Cf merges to ht_ft_path via element-wise time + FunctionPtr plusOp = ht_ft_path[ht_ft_path.size() - 2]; + FunctionPtr peepholeOp = plusOp->Inputs()[0].Owner()->OpName() != L"Slice" ? 
+ plusOp->Inputs()[0].Owner() : plusOp->Inputs()[1].Owner(); + peepholeCf = GetPeepholeVariableFromOp(peepholeOp); + + FunctionPtr stabilizerDcOp = GetStabilizerOp(peepholeOp); + if (stabilizerDcOp) + stabilizer_dc = GetStabilizerCoef(stabilizerDcOp); + } + } +} + +FunctionPtr TraverseGraphFindFirstRNNOp(FunctionPtr src) +{ + std::vector front = src->Inputs(), back; + + while (!front.empty()) + { + for (auto f : front) + { + if (f.IsOutput() && f.Owner()) + if (IsActivationOp(ToString(f.Owner()->OpName()))) + return f.Owner(); + else + { + for (auto i : f.Owner()->Inputs()) + back.push_back(i); + } + } + front = back; + back.clear(); + } + return nullptr; +} + +void TraceGRUPathes(const FunctionPtr& src, string &f_activation, string &g_activation, + RNNDirection &direction, Variable &initStateH) +{ + std::vector inputVars = src->Inputs(); + std::vector pastValueOps, futureValueOps; + GetDelayOps(inputVars, pastValueOps, futureValueOps); + + // indices here coresponding with CNTK python layer code. + if (pastValueOps.size() == 1 && futureValueOps.size() == 0) + { + direction = RNNDirection::Forward; + initStateH = pastValueOps[0]->Inputs()[1]; + } + else if (pastValueOps.size() == 0 && futureValueOps.size() == 1) + { + direction = RNNDirection::Backward; + initStateH = futureValueOps[0]->Inputs()[1]; + } + else + { + CNTK::LogicError("Node %s (%s) is not a valid GRU node", ToString(src->Name()).c_str(), ToString(src->Uid()).c_str()); + } + + // set up traverse boundary + std::unordered_set visitedFunctions; + for (std::vector::const_iterator it = inputVars.begin(); it != inputVars.end(); it++) + { + visitedFunctions.insert(it->Owner()); + } + + std::vector> pathesToPlusSlice; + std::vector currentPath; + + FunctionPtr gActivation = TraverseGraphFindFirstRNNOp(src->BlockRoot()); + + f_activation = "Sigmoid"; + g_activation = MapActivationNameCNTKToONNX(ToString(gActivation->OpName())); +} diff --git a/Source/CNTKv2LibraryDll/proto/onnx/RNNHelper.h b/Source/CNTKv2LibraryDll/proto/onnx/RNNHelper.h index 77270a8e9..9113e7eaa 100644 --- a/Source/CNTKv2LibraryDll/proto/onnx/RNNHelper.h +++ b/Source/CNTKv2LibraryDll/proto/onnx/RNNHelper.h @@ -25,6 +25,18 @@ const std::string LSTMInputInitialHNameHint = "_initial_h_"; const std::string LSTMInputInitialCNameHint = "_initial_c_"; const std::string LSTMInputPeepholeNameHint = "_peephole_"; +const std::string GRUInputInitialHNameHint = "_initial_h_"; + +// https://github.com/onnx/onnx/blob/master/docs/Operators.md#attributes-18 +// https://github.com/onnx/onnx/blob/master/docs/Operators.md#attributes-27 +// https://github.com/onnx/onnx/blob/master/docs/Operators.md#attributes-39 +// CNTK RNN ops always output sequence. +// ONNX requires to set the output_sequence attribute to 1 to output sequence. 
+enum +{ + RNNOutputSequence = 1 +}; + enum { LSTMInputIndexX = 0, @@ -52,10 +64,18 @@ enum { LSTMPeepholeCount = 3 }; +// https://github.com/onnx/onnx/blob/master/docs/Operators.md#inputs-3---8 +// size of weight/bias matrix is a multiple of hidden size +enum +{ + LSTMWeightDimensionHiddenMultiplier = 4, + LSTMBiasDimensionHiddenMultiplier = 8 +}; + typedef enum { Forward, Backward, -} LSTMDirection; +} RNNDirection; enum { @@ -69,9 +89,64 @@ enum CNTKLSTMOutputYhIndex = 0, CNTKLSTMOutputChIndex = 1 }; -const string LSTMDirectionBidirection = "bidirectional"; -const string LSTMDirectionReverse = "reverse"; -const string LSTMDirectionForward = "forward"; + +enum +{ + GRUActivationFIndex = 0, + GRUActivationGIndex = 1, + GRUActivationCount = 2 +}; + +enum +{ + GRUInputIndexX = 0, + GRUInputIndexW = 1, + GRUInputIndexR = 2, + GRUInputIndexB = 3, + GRUInputIndexSequenceLens = 4, + GRUInitialH = 5, +}; + +// https://github.com/onnx/onnx/blob/master/docs/Operators.md#inputs-3---6 +// size of weight/bias matrix is a multiple of hidden size +enum +{ + GRUWeightDimensionHiddenMultiplier = 3, + GRUBiasDimensionHiddenMultiplier = 6 +}; + +enum +{ + CNTKGRUZRWeightMultiplier = 2 +}; +enum +{ + CNTKGRUBiasIndex = 1, + CNTKGRUWeightIndex = 2, + CNTKGRUHiddenWeightZRIndex = 3, + CNTKGRUHiddenWeightHIndex = 4, + CNTKGRUPastOrFutureIndex = 5, + CNTKGRUInputIndex = 6, + CNTKGRUInputCount = 7 +}; + + +const string RNNDirectionBidirection = "bidirectional"; +const string RNNDirectionReverse = "reverse"; +const string RNNDirectionForward = "forward"; FunctionPtr CreateLSTM(const ONNXIR::Node *node, const std::vector &inputs, const std::string &direction, const std::vector &activations, const std::vector &activation_alpha, const std::vector &activation_beta); + +FunctionPtr CreateGRU(const ONNXIR::Node *node, const std::vector &inputs, const std::string &direction, + const std::vector &activations, const std::vector &activation_alpha, const std::vector &activation_beta); + +void TraceLSTMPathes(const FunctionPtr& src, string &f_activation, string &g_activation, string &h_activation, + RNNDirection &direction, Variable &initStateH, Variable &initStateC, Variable &peepholeCi, Variable &peepholeCo, Variable &peepholeCf, + double &stabilizer_dh, double &stabilizer_dc, double &stabilizer_c); + +void TraceGRUPathes(const FunctionPtr& src, string &f_activation, string &g_activation, + RNNDirection &direction, Variable &initStateH); + +std::string MapActivationNameONNXToCNTK(const std::string &onnxOp); +std::string MapActivationNameCNTKToONNX(const std::string &cntkOp); \ No newline at end of file diff --git a/bindings/python/cntk/tests/onnx_op_test.py b/bindings/python/cntk/tests/onnx_op_test.py index 0e89ce51b..08a47eeaf 100644 --- a/bindings/python/cntk/tests/onnx_op_test.py +++ b/bindings/python/cntk/tests/onnx_op_test.py @@ -8,6 +8,7 @@ import numpy as np import cntk as C import pytest from cntk.ops.tests.ops_test_utils import cntk_device +from itertools import product ############# #helpers @@ -431,6 +432,43 @@ def test_Greater(tmpdir): model = C.greater([41., 42., 43.], [42., 42., 42.]) verify_no_input(model, tmpdir, 'Greater_0') +#GRU +def MakeGRUNameFromConfig(backward, initial_state, activtion): + model_name = 'GRU.' 
+ activtion.__name__ + if (initial_state != 0): + model_name += '.initial' + if (backward): + model_name += '.backward' + else: + model_name += '.forward' + return model_name + +direction_options = [False, True] +activation_options = [C.tanh] +initial_state_options = [0] + +input_dim = 2 +cell_dim = 3 +batch_size = 1 +sequence_len = 5 + +def test_GRU(tmpdir): + for config in list(product(direction_options, initial_state_options, activation_options)): + model_filename = MakeGRUNameFromConfig(*config) + print(model_filename) + backward, initial_state, activation = config + + x = C.input_variable(input_dim, dynamic_axes=[C.Axis.default_batch_axis(), C.Axis('sequenceAxis')]) + GRUModel = C.layers.Recurrence(C.layers.GRU(cell_dim, + activation = activation), + initial_state = initial_state, + go_backwards=backward)(x) + #CLG.plot(GRUModel, filename=cntk_pdf_filename) + #plot_block_internals(GRUModel, 'GRU', model_filename) + data = np.random.uniform(low=0.0, high=1.0, size=(batch_size, sequence_len, input_dim)).astype('f') + verify_one_input(GRUModel, data, tmpdir, model_filename) + + #Hardmax def test_Hardmax(tmpdir): data = np.asarray([1., 1., 2., 3.], dtype=np.float32) @@ -522,21 +560,18 @@ def test_LRN(tmpdir): verify_one_input(model, img, tmpdir, 'LRN_1') #LSTM -from cntk.layers import * -from itertools import product - def CreateLSTMModel(activation, peepholes, self_stabilization, cell_dim, initial_state): - return Sequential([ - Recurrence(LSTM(cell_dim, - use_peepholes = peepholes, - activation = activation, - enable_self_stabilization = self_stabilization), - initial_state = initial_state) - ]) + return C.layers.Sequential([ + C.layers.Recurrence(C.layers.LSTM(cell_dim, + use_peepholes = peepholes, + activation = activation, + enable_self_stabilization = self_stabilization), + initial_state = initial_state) + ]) # lstm attributes use_peepholes_options = [False] @@ -567,7 +602,7 @@ def test_LSTM(tmpdir): model_filename = MakeLSTMNameFromConfig(*config) use_peepholes, enable_self_stabilization, initial_state, activation = config - x = C.input_variable(input_dim, dynamic_axes=[Axis.default_batch_axis(), C.Axis('sequenceAxis')]) + x = C.input_variable(input_dim, dynamic_axes=[C.Axis.default_batch_axis(), C.Axis('sequenceAxis')]) LSTMmodel = CreateLSTMModel(peepholes = use_peepholes, activation = activation, initial_state = initial_state,