Support ONNX GRU
This commit is contained in:
Parent: 44c626a483
Commit: f1f3bb4e63
One file's diff is not shown because it is too large.
@@ -612,7 +612,6 @@ std::vector<Variable> CreateRNNConstant(
|
|||
const Node *parentNode, int index, const std::string &name, onnx::TensorProto &valueProto, const DeviceDescriptor& computeDevice)
|
||||
{
|
||||
std::vector<Variable> inputs;
|
||||
string parentONNXOpName = parentNode->OpType();
|
||||
auto dataType = valueProto.data_type();
|
||||
|
||||
switch (dataType)
|
||||
|
@@ -633,6 +632,7 @@ std::vector<Variable> CreateRNNConstant(
|
|||
}
|
||||
}
|
||||
|
||||
string parentONNXOpName = parentNode->OpType();
|
||||
// index to LSTM inputs as specified in the ONNX document.
|
||||
// https://github.com/onnx/onnx/blob/master/docs/Operators.md#inputs-3---8
|
||||
if (parentONNXOpName == "LSTM")
|
||||
|
@@ -805,6 +805,154 @@ std::vector<Variable> CreateRNNConstant(
|
|||
CNTK::LogicError("CreateRNNConstant received unepxpeted index: %d", index);
|
||||
}
|
||||
}
|
||||
else if (parentONNXOpName == "GRU")
|
||||
{
|
||||
// https://github.com/onnx/onnx/blob/master/docs/Operators.md#inputs-3---6
|
||||
switch (index)
|
||||
{
|
||||
case GRUInputIndexX:
|
||||
// X, should not come to here
|
||||
return inputs;
|
||||
case GRUInputIndexW:
|
||||
{
|
||||
// see ONNX spec for the tensor shape
|
||||
int num_directions = valueProto.dims(0);
|
||||
size_t rows = valueProto.dims(1);
|
||||
size_t cols = valueProto.dims(2);
|
||||
|
||||
// CNTK cpp requires shape: (input_size, 3 * hidden_size)
|
||||
NDShape weightShape({ rows, cols });
|
||||
|
||||
int input_size = cols;
|
||||
int cell_size = rows / 3;
|
||||
|
||||
for (int dir = 0; dir < num_directions; dir++)
|
||||
{
|
||||
std::string nodeName = name + "_W_" + (char)dir;
|
||||
int totalSizePerDirection = rows * cols;
|
||||
|
||||
// TODO: what about double?
|
||||
float *data = new float[totalSizePerDirection];
|
||||
for (size_t count = 0; count < totalSizePerDirection; count++)
|
||||
{
|
||||
int row = count / input_size;
|
||||
int col = count % input_size;
|
||||
int sourceIndex = dir * totalSizePerDirection + count;
|
||||
int targetIndex = col * cell_size * GRUWeightDimensionHiddenMultiplier + row;
|
||||
data[targetIndex] = valueProto.float_data()[sourceIndex];
|
||||
}
|
||||
|
||||
Constant constant = CreateConstantWithRawData(&data[0], weightShape, nodeName, computeDevice);
|
||||
inputs.push_back(constant);
|
||||
}
|
||||
return inputs;
|
||||
}
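// Illustrative sketch, not part of this commit: the loop above is a plain
// transpose of ONNX's row-major [3*hidden_size, input_size] weight block into
// the column-major buffer CNTK expects. A standalone equivalent (hypothetical
// helper, float only, single direction) would be:
//
//   void TransposeGRUWeightBlock(const float *onnxW, float *cntkW, int cellSize, int inputSize)
//   {
//       int rows = GRUWeightDimensionHiddenMultiplier * cellSize; // zrh-stacked rows
//       for (int count = 0; count < rows * inputSize; count++)
//       {
//           int row = count / inputSize, col = count % inputSize; // position in the ONNX layout
//           cntkW[col * rows + row] = onnxW[count];
//       }
//   }
//
// e.g. with cellSize = 2 and inputSize = 3, ONNX element (row 5, col 2) lands at cntkW[2 * 6 + 5] = cntkW[17].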
|
||||
case GRUInputIndexR:
|
||||
{
|
||||
// split into H and H1 for CNTK GRU implementation
|
||||
int num_directions = valueProto.dims(0);
|
||||
size_t rows = valueProto.dims(1);
|
||||
size_t cols = valueProto.dims(2);
|
||||
|
||||
int input_size = cols;
|
||||
int cell_size = rows / 3;
|
||||
|
||||
NDShape hShape({ (size_t)cell_size * 2, (size_t)input_size });
|
||||
NDShape h1Shape({ (size_t)cell_size, (size_t)input_size });
|
||||
|
||||
inputs.resize(num_directions * 2);
|
||||
for (int dir = 0; dir < num_directions; dir++)
|
||||
{
|
||||
std::string hNodeName = name + "_H_" + (char)dir;
|
||||
std::string h1NodeName = name + "_H1_" + (char)dir;
|
||||
int totalSizePerDirection = rows * cols;
|
||||
|
||||
float *hData = new float[hShape.TotalSize()];
|
||||
float *h1Data = new float[h1Shape.TotalSize()];
|
||||
for (size_t count = 0; count < totalSizePerDirection; count++)
|
||||
{
|
||||
int row = count / input_size;
|
||||
int col = count % input_size;
|
||||
int block = row / cell_size;
|
||||
int sourceIndex = dir * totalSizePerDirection + count;
|
||||
if (block < CNTKGRUZRWeightMultiplier)
|
||||
{
|
||||
int targetIndex = col * cell_size * CNTKGRUZRWeightMultiplier + row;
|
||||
hData[targetIndex] = valueProto.float_data()[sourceIndex];
|
||||
}
|
||||
else
|
||||
{
|
||||
int targetIndex = col * cell_size + row - cell_size * CNTKGRUZRWeightMultiplier;
|
||||
h1Data[targetIndex] = valueProto.float_data()[sourceIndex];
|
||||
}
|
||||
}
|
||||
|
||||
Constant constantH = CreateConstantWithRawData(&hData[0], hShape, hNodeName, computeDevice);
|
||||
Constant constantH1 = CreateConstantWithRawData(&h1Data[0], h1Shape, h1NodeName, computeDevice);
|
||||
inputs[dir] = constantH;
|
||||
inputs[dir + num_directions] = constantH1;
|
||||
}
|
||||
return inputs;
|
||||
}
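// Illustrative note, not part of this commit: ONNX stacks the recurrence matrix R
// as z, r and h rows. The CNTK GRU applies the z/r block and the h block at
// different points of the cell (see GRUCell in RNNHelper.cpp), so the first
// 2 * cell_size rows are transposed into H and the remaining cell_size rows into
// H1. For cell_size = 2, ONNX rows 0..3 go to H and rows 4..5 go to H1, each
// block stored column-major exactly as in the W case above.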
|
||||
case GRUInputIndexB:
|
||||
// B
|
||||
{
|
||||
// see ONNX spec for the tensor shape
|
||||
int num_directions = valueProto.dims(0);
|
||||
int cell_size = valueProto.dims(1) / GRUBiasDimensionHiddenMultiplier;
|
||||
// shape size is divided by 2 so that it only applies to input (CNTK)
|
||||
// TODO: this incompatibility needs further investigation.
|
||||
NDShape weightShape({ (size_t)(GRUBiasDimensionHiddenMultiplier / 2 * cell_size) });
|
||||
for (int dir = 0; dir < num_directions; dir++)
|
||||
{
|
||||
std::string nodeName = name + std::string(1, (char)dir) + LSTMInputBiasNameHint;
|
||||
int totalSizePerDirection = GRUBiasDimensionHiddenMultiplier / 2 * cell_size;
|
||||
float *data = new float[totalSizePerDirection];
|
||||
for (size_t targetIndex = 0; targetIndex < totalSizePerDirection; targetIndex++)
|
||||
{
|
||||
int row = targetIndex;
|
||||
// source is column major
|
||||
int src_index = row;
|
||||
// "fuse"
|
||||
data[targetIndex] =
|
||||
valueProto.float_data()[dir * 2 * totalSizePerDirection + src_index] +
|
||||
valueProto.float_data()[dir * 2 * totalSizePerDirection + totalSizePerDirection + src_index];
|
||||
}
|
||||
|
||||
Constant constant = CreateConstantWithRawData(data, weightShape, nodeName, computeDevice);
|
||||
inputs.push_back(constant);
|
||||
}
|
||||
return inputs;
|
||||
}
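// Illustrative note, not part of this commit: per direction ONNX packs the bias as
// [Wb(zrh); Rb(zrh)], i.e. 6 * hidden_size values, while the CNTK cell adds a single
// 3 * hidden_size bias to the input projection. The loop above therefore fuses the
// two halves element-wise:
//
//   b[i] = Wb[i] + Rb[i]   for i in [0, 3 * hidden_size)
//
// which gives the same pre-activation sums as the ONNX formulation (with the default
// linear_before_reset = 0), since both bias halves are purely additive.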
|
||||
case GRUInputIndexSequenceLens:
|
||||
return inputs;
|
||||
case GRUInitialH:
|
||||
{
|
||||
// initial_h
|
||||
int num_directions = valueProto.dims(0);
|
||||
int cell_size = valueProto.dims(2);
|
||||
NDShape weightShape({ (size_t)(cell_size) });
|
||||
for (int dir = 0; dir < num_directions; dir++)
|
||||
{
|
||||
std::string nodeName = name + std::string(1, (char)dir) + LSTMInputInitialHNameHint;
|
||||
|
||||
float *data = new float[cell_size];
|
||||
for (size_t targetIndex = 0; targetIndex < cell_size; targetIndex++)
|
||||
{
|
||||
data[targetIndex] = valueProto.float_data()[dir * cell_size + targetIndex];
|
||||
}
|
||||
|
||||
Constant constant = CreateConstantWithRawData(data, weightShape, nodeName, computeDevice);
|
||||
inputs.push_back(constant);
|
||||
}
|
||||
return inputs;
|
||||
}
|
||||
break;
|
||||
return inputs;
|
||||
default:
|
||||
CNTK::LogicError("CreateRNNConstant for GRU op received unepxpeted index: %d", index);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
NOT_IMPLEMENTED;
|
||||
|
@@ -890,6 +1038,36 @@ std::vector<Variable> ONNXToCNTKHelper::CreateRNNLeafVariableOrConstant(const No
|
|||
LogicError("LSTM node has unexpected input");
|
||||
}
|
||||
}
|
||||
else if (parentONNXOpName == "GRU")
|
||||
{
|
||||
int inputIndex = CalculateNodeArgInputIndex(nodeArg, parentNode);
|
||||
switch (inputIndex)
|
||||
{
|
||||
case GRUInputIndexX:
|
||||
// X: `[seq_length, batch_size, input_size]`.
|
||||
{
|
||||
Variable inputVariable;
|
||||
if (constructedNodeArgVariableMap.find(nodeArg->Name()) == constructedNodeArgVariableMap.end())
|
||||
{
|
||||
DataType dataType = FromONNXType(nodeArg->ToProto().type());
|
||||
int input_size = shapeProto->dim(2).dim_value();
|
||||
NDShape shape({ (size_t)(input_size) });
|
||||
inputVariable = InputVariable(shape, dataType, ToWString(nodeArg->Name()), dynamicAxes);
|
||||
constructedNodeArgVariableMap.insert(ONNXToCNTKVariableMap::value_type(nodeArg->Name(), inputVariable));
|
||||
}
|
||||
return std::vector<Variable>({ constructedNodeArgVariableMap[nodeArg->Name()] });
|
||||
}
|
||||
// other inputs shall be ONNX constant nodes and are created as CNTK Constants in CreateRNNConstant
|
||||
case GRUInputIndexW:
|
||||
case GRUInputIndexR:
|
||||
case GRUInputIndexB:
|
||||
case GRUInputIndexSequenceLens:
|
||||
case GRUInitialH:
|
||||
NOT_IMPLEMENTED;
|
||||
default:
|
||||
LogicError("LSTM node has unexpected input");
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
NOT_IMPLEMENTED;
|
||||
|
@@ -1465,6 +1643,15 @@ FunctionPtr ONNXToCNTKHelper::CreateFunction(const Node *node, const std::vector
|
|||
std::vector<string>({"Sigmoid", "Tanh", "Tanh"}));
|
||||
return CreateLSTM(node, inputs, direction, activations, activation_alpha, activation_beta);
|
||||
}
|
||||
else if (onnxOpName == "GRU")
|
||||
{
|
||||
const string direction = GetNamedAttributeAsString(node, "direction");
|
||||
std::vector<float> activation_alpha = GetNamedAttributeAsFloatVec(node, "activation_alpha", std::vector<float>());
|
||||
std::vector<float> activation_beta = GetNamedAttributeAsFloatVec(node, "activation_beta", std::vector<float>());
|
||||
const std::vector<string> activations = GetNamedAttributeAsStringVec(node, "activations",
|
||||
std::vector<string>({ "Sigmoid", "Tanh" }));
|
||||
return CreateGRU(node, inputs, direction, activations, activation_alpha, activation_beta);
|
||||
}
|
||||
if (onnxOpName == "FC")
|
||||
{
|
||||
return CreateCNTKFCNode(ToWString(node->Name()), inputs);
|
||||
|
@@ -2275,7 +2462,7 @@ std::vector<FunctionPtr> ONNXToCNTKHelper::FromONNXNode(const Node *node, ONNXTo
|
|||
const Node *parentNode;
|
||||
int childIndex;
|
||||
std::tie<const Node *, int>(parentNode, childIndex) = FindParentAndChildIndex(node);
|
||||
if (parentNode != nullptr && parentNode->OpType() == "LSTM")
|
||||
if (parentNode != nullptr && Operators::IsRNNOp(parentNode->OpType()))
|
||||
{
|
||||
std::vector<FunctionPtr> cntkFunctions = CreateRNNConstantOp(graph, node, parentNode, childIndex, computeDevice);
|
||||
if (!cntkFunctions.empty())
|
||||
|
@@ -2613,7 +2800,7 @@ std::vector<Variable> ONNXToCNTKHelper::CreateCNTKInputsStartingFromIndex(const
|
|||
else
|
||||
{
|
||||
std::string parentONNXOpName = node->OpType();
|
||||
if (parentONNXOpName == "LSTM")
|
||||
if (Operators::IsRNNOp(node->OpType()))
|
||||
{
|
||||
std::vector<Variable> inputVariables =
|
||||
CreateRNNLeafVariableOrConstant(nodeArg, node, graph, constructedNodeArgVariableMap, computeDevice);
|
||||
|
|
|
@@ -3,6 +3,57 @@
|
|||
// Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
|
||||
//
|
||||
#include "RNNHelper.h"
|
||||
#include "Operators.h"
|
||||
|
||||
using namespace CNTK::ONNX;
|
||||
|
||||
std::string MapActivationNameONNXToCNTK(const std::string &onnxOp)
|
||||
{
|
||||
if (onnxOp == "Relu")
|
||||
return "ReLU";
|
||||
else if (onnxOp == "Sigmoid")
|
||||
return "StableSigmoid";
|
||||
else if (onnxOp == "LeakyRelu")
|
||||
return "LeakyReLU";
|
||||
else if (onnxOp == "ThresholdedRelu")
|
||||
return "ThresholdedReLU";
|
||||
else if (onnxOp == "Elu")
|
||||
return "ELU";
|
||||
else
|
||||
return onnxOp;
|
||||
}
|
||||
|
||||
std::string MapActivationNameCNTKToONNX(const std::string &cntkOp)
|
||||
{
|
||||
if (cntkOp == "ReLU")
|
||||
return "Relu";
|
||||
else if (cntkOp == "StableSigmoid")
|
||||
return "Sigmoid";
|
||||
else if (cntkOp == "LeakyReLU")
|
||||
return "LeakyRelu";
|
||||
else if (cntkOp == "ThresholdedReLU")
|
||||
return "ThresholdedRelu";
|
||||
else if (cntkOp == "ELU")
|
||||
return "Elu";
|
||||
else
|
||||
return cntkOp;
|
||||
}
|
||||
|
||||
bool IsActivationOp(const std::string &activationName)
|
||||
{
|
||||
return
|
||||
activationName == "Relu" || activationName == "ReLU" ||
|
||||
activationName == "Tanh" ||
|
||||
activationName == "Sigmoid" || activationName == "StableSigmoid" ||
|
||||
activationName == "Affine" ||
|
||||
activationName == "LeakyRelu" || activationName == "LeakyReLU" ||
|
||||
activationName == "ThresholdedRelu" || activationName == "ThresholdedReLU" ||
|
||||
activationName == "ScaledTanh" ||
|
||||
activationName == "HardSigmoid" ||
|
||||
activationName == "Elu" || activationName == "ELU" ||
|
||||
activationName == "Softsign" ||
|
||||
activationName == "Softplus";
|
||||
}
|
||||
|
||||
std::function<FunctionPtr(const Variable&)> ActivationMap(const std::string &activationName)
|
||||
{
|
||||
|
@@ -37,7 +88,7 @@ std::function<FunctionPtr(const Variable&)> ActivationMap(const std::string &act
|
|||
}
|
||||
else
|
||||
{
|
||||
CNTK::LogicError("LSTM does not support activation: %s", activationName.c_str());
|
||||
CNTK::LogicError("Recurrent Op does not support activation: %s", activationName.c_str());
|
||||
}
|
||||
}
|
||||
|
||||
|
@@ -82,9 +133,9 @@ GetActivations(const std::vector<std::string> &activations, const std::vector<fl
|
|||
// Here we assume that if they are set, they are set for all activations, regardless of whether
|
||||
// an activation needs those values or not.
|
||||
bool hasAlpha = activation_alpha.size() == (direction + 1) * LSTMActivationCount;
|
||||
bool hasBeta = hasAlpha && activation_beta.size() == (direction + 1) * LSTMActivationCount;
|
||||
bool hasAlphaBeta = hasAlpha && activation_beta.size() == (direction + 1) * LSTMActivationCount;
|
||||
std::function<FunctionPtr(const Variable&)> iofActivationOp, cellActivationOp, hiddenActivationOp;
|
||||
if (hasBeta)
|
||||
if (hasAlphaBeta)
|
||||
{
|
||||
iofActivationOp = ActivationMap(activations[iofActivationIndex], activation_alpha[iofActivationIndex], activation_beta[iofActivationIndex]);
|
||||
cellActivationOp = ActivationMap(activations[cellActivation], activation_alpha[cellActivation], activation_beta[cellActivation]);
|
||||
|
@@ -107,6 +158,38 @@ GetActivations(const std::vector<std::string> &activations, const std::vector<fl
|
|||
|
||||
}
|
||||
|
||||
std::tuple<std::function<FunctionPtr(const Variable&)>, std::function<FunctionPtr(const Variable&)>>
|
||||
GetGRUActivations(const std::vector<std::string> &activations, const std::vector<float> &activation_alpha, const std::vector<float> &activation_beta, int direction)
|
||||
{
|
||||
if (activations.size() < (direction + 1) * GRUActivationCount)
|
||||
CNTK::LogicError("LSTM activations shall be %d or %d of strings", GRUActivationCount, GRUActivationCount * 2);
|
||||
|
||||
//
|
||||
int fActivationIndex = direction * GRUActivationCount + GRUActivationFIndex;
|
||||
int gActivationIndex = direction * GRUActivationCount + GRUActivationGIndex;
|
||||
|
||||
bool hasAlpha = activation_alpha.size() == (direction + 1) * GRUActivationCount;
|
||||
bool hasAlphaBeta = hasAlpha && activation_beta.size() == (direction + 1) * GRUActivationCount;
|
||||
std::function<FunctionPtr(const Variable&)> fActivationOp, gActivationOp;
|
||||
if (hasAlphaBeta)
|
||||
{
|
||||
fActivationOp = ActivationMap(activations[fActivationIndex], activation_alpha[fActivationIndex], activation_beta[fActivationIndex]);
|
||||
gActivationOp = ActivationMap(activations[gActivationIndex], activation_alpha[gActivationIndex], activation_beta[gActivationIndex]);
|
||||
}
|
||||
else if (hasAlpha)
|
||||
{
|
||||
fActivationOp = ActivationMap(activations[fActivationIndex], activation_alpha[fActivationIndex]);
|
||||
gActivationOp = ActivationMap(activations[gActivationIndex], activation_alpha[gActivationIndex]);
|
||||
}
|
||||
else
|
||||
{
|
||||
fActivationOp = ActivationMap(activations[fActivationIndex]);
|
||||
gActivationOp = ActivationMap(activations[gActivationIndex]);
|
||||
}
|
||||
|
||||
return std::make_tuple(fActivationOp, gActivationOp);
|
||||
}
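// Illustrative note, not part of this commit: with the defaults passed in from
// CreateFunction ({"Sigmoid", "Tanh"}), fActivationOp is the sigmoid applied to
// the z and r gates and gActivationOp is the tanh applied to the candidate state.
// A bidirectional GRU is expected to supply four activation strings, two per
// direction; `direction` selects the pair used for the current pass.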
|
||||
|
||||
std::pair<FunctionPtr, FunctionPtr> LSTMPCell(Variable input,
|
||||
const std::function<FunctionPtr(const Variable&)> &iofActivationOp,
|
||||
const std::function<FunctionPtr(const Variable&)> &cellActivationOp,
|
||||
|
@@ -155,6 +238,53 @@ std::pair<FunctionPtr, FunctionPtr> LSTMPCell(Variable input,
|
|||
return{ h, c };
|
||||
}
|
||||
|
||||
FunctionPtr GRUCell(Variable input,
|
||||
const std::function<FunctionPtr(const Variable&)> &fActivationOp,
|
||||
const std::function<FunctionPtr(const Variable&)> &gActivationOp,
|
||||
Variable prevOutput,
|
||||
Constant &W, Constant &R, Constant &H1, Constant &B)
|
||||
{
|
||||
size_t outputDim = prevOutput.Shape()[0];
|
||||
int stacked_dim = (int)outputDim;
|
||||
|
||||
FunctionPtr projx3;
|
||||
if (B.IsInitialized())
|
||||
projx3 = Plus(B, Times(W, input));
|
||||
else
|
||||
projx3 = Times(W, input);
|
||||
|
||||
FunctionPtr projh2 = Times(R, prevOutput);
|
||||
|
||||
// both CNTK and ONNX weight and bias are in zrh order.
|
||||
std::vector<Axis> stack_axis({ Axis(-1) });
|
||||
FunctionPtr zt_proj =
|
||||
Slice(projx3, stack_axis, { 0 * stacked_dim }, { 1 * stacked_dim }) +
|
||||
Slice(projh2, stack_axis, { 0 * stacked_dim }, { 1 * stacked_dim });
|
||||
|
||||
FunctionPtr rt_proj =
|
||||
Slice(projx3, stack_axis, { 1 * stacked_dim }, { 2 * stacked_dim }) +
|
||||
Slice(projh2, stack_axis, { 1 * stacked_dim }, { 2 * stacked_dim });
|
||||
|
||||
FunctionPtr ct_proj =
|
||||
Slice(projx3, stack_axis, { 2 * stacked_dim }, { 3 * stacked_dim });
|
||||
|
||||
FunctionPtr zt = fActivationOp(zt_proj);
|
||||
|
||||
FunctionPtr rt = fActivationOp(rt_proj);
|
||||
|
||||
FunctionPtr rs = ElementTimes(prevOutput, rt);
|
||||
|
||||
FunctionPtr ct = gActivationOp(ct_proj + Times(H1, rs));
|
||||
|
||||
Constant one = W.GetDataType() == DataType::Float ? Constant::Scalar<float>(1.0f) : Constant::Scalar<double>(1.0);
|
||||
|
||||
FunctionPtr ht = ElementTimes(one - zt, ct) + ElementTimes(zt, prevOutput);
|
||||
|
||||
FunctionPtr h = ht;
|
||||
|
||||
return ht;
|
||||
}
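// Illustrative summary, not part of this commit: with W, R (the stacked z/r block)
// and H1 (the h block) laid out as in CreateRNNConstant, the cell above computes the
// standard ONNX GRU recurrence (default linear_before_reset = 0):
//
//   z_t  = f(W_z x_t + R_z h_{t-1} + b_z)
//   r_t  = f(W_r x_t + R_r h_{t-1} + b_r)
//   ~h_t = g(W_h x_t + H1 (r_t . h_{t-1}) + b_h)
//   h_t  = (1 - z_t) . ~h_t + z_t . h_{t-1}
//
// where f and g are the activations from GetGRUActivations and "." is element-wise
// multiplication; the three slices of projx3 and the two slices of projh2 are exactly
// the z, r and h rows of the stacked matrices.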
|
||||
|
||||
#include "PrimitiveFunction.h"
|
||||
#include "BlockFunction.h"
|
||||
|
||||
|
@@ -185,6 +315,27 @@ std::tuple<FunctionPtr, FunctionPtr> LSTMPComponent(Variable input,
|
|||
return std::make_tuple(LSTMCell.first, LSTMCell.second);
|
||||
}
|
||||
|
||||
FunctionPtr GRUComponent(Variable input,
|
||||
const NDShape& cellShape,
|
||||
const std::function<FunctionPtr(const Variable&)> &fActivationOp,
|
||||
const std::function<FunctionPtr(const Variable&)> &gActivationOp,
|
||||
const std::function<FunctionPtr(const Variable&)>& recurrenceHookH,
|
||||
Constant &W, Constant &R, Constant &H1, Constant &B)
|
||||
{
|
||||
auto dh = PlaceholderVariable(cellShape, input.DynamicAxes());
|
||||
auto inputPlaceholder = PlaceholderVariable(input.Shape(), input.DynamicAxes());
|
||||
|
||||
auto gruCell = GRUCell(
|
||||
inputPlaceholder,
|
||||
fActivationOp, gActivationOp,
|
||||
dh, W, R, H1, B);
|
||||
|
||||
auto actualDh = recurrenceHookH(gruCell);
|
||||
|
||||
gruCell->ReplacePlaceholders({ { inputPlaceholder , input },{ dh, actualDh } });
|
||||
return gruCell;
|
||||
}
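// Illustrative note, not part of this commit: this is the usual CNTK recurrence
// idiom. dh is a placeholder standing in for h_{t-1}; the cell is built against it,
// recurrenceHookH wraps the cell output in PastValue (or FutureValue for the
// backward direction) seeded with the initial state, and ReplacePlaceholders closes
// the loop so the same GRUCell is unrolled along the sequence axis.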
|
||||
|
||||
const std::vector<Variable> FindByNameHint(const std::vector<Variable> &inputs, const std::string &hint)
|
||||
{
|
||||
std::vector<Variable> variables;
|
||||
|
@@ -218,7 +369,7 @@ Variable GetInitialStateVariable(const std::vector<Variable> &inputs, int numDir
|
|||
FunctionPtr CreateLSTM(const ONNXIR::Node *node, const std::vector<Variable> &inputs, const std::string &direction,
|
||||
const std::vector<string> &activations, const std::vector<float> &activation_alpha, const std::vector<float> &activation_beta)
|
||||
{
|
||||
int numDirections = direction == LSTMDirectionBidirection ? 2 : 1;
|
||||
int numDirections = direction == RNNDirectionBidirection ? 2 : 1;
|
||||
std::vector<FunctionPtr> outputHs;
|
||||
for (int dir = 0; dir < numDirections; dir++)
|
||||
{
|
||||
|
@@ -237,8 +388,8 @@ FunctionPtr CreateLSTM(const ONNXIR::Node *node, const std::vector<Variable> &in
|
|||
else if (numDirections == 2 && biasVariables.size() == 2)
|
||||
B = biasVariables[dir];
|
||||
|
||||
Variable initHVariable = GetInitialStateVariable(inputs, numDirections, LSTMInputInitialCNameHint, X.GetDataType());
|
||||
Variable initCVariable = GetInitialStateVariable(inputs, numDirections, LSTMInputInitialHNameHint, X.GetDataType());
|
||||
Variable initHVariable = GetInitialStateVariable(inputs, numDirections, LSTMInputInitialHNameHint, X.GetDataType());
|
||||
Variable initCVariable = GetInitialStateVariable(inputs, numDirections, LSTMInputInitialCNameHint, X.GetDataType());
|
||||
|
||||
std::vector<Variable> peepholeVariables = FindByNameHint(inputs, LSTMInputPeepholeNameHint);
|
||||
Variable Ci, Cf, Co;
|
||||
|
@@ -271,17 +422,23 @@ FunctionPtr CreateLSTM(const ONNXIR::Node *node, const std::vector<Variable> &in
|
|||
FunctionPtr outputC;
|
||||
|
||||
// if it is a bidirectional LSTM, the second one is the backward one.
|
||||
bool go_backwards = direction == LSTMDirectionReverse || (numDirections == 2 && dir == 1);
|
||||
bool go_backwards = direction == RNNDirectionReverse || (numDirections == 2 && dir == 1);
|
||||
|
||||
std::function<FunctionPtr(const Variable&)> futureValueRecurrenceHook;
|
||||
std::function<FunctionPtr(const Variable&)> recurrenceHookH, recurrenceHookC;
|
||||
if (go_backwards)
|
||||
futureValueRecurrenceHook = [initHVariable](const Variable& x) { return FutureValue(x, initHVariable); };
|
||||
{
|
||||
recurrenceHookH = [initHVariable](const Variable& x) { return FutureValue(x, initHVariable); };
|
||||
recurrenceHookC = [initCVariable](const Variable& x) { return FutureValue(x, initCVariable); };
|
||||
}
|
||||
else
|
||||
futureValueRecurrenceHook = [initCVariable](const Variable& x) { return PastValue(x, initCVariable); };
|
||||
{
|
||||
recurrenceHookH = [initHVariable](const Variable& x) { return PastValue(x, initHVariable); };
|
||||
recurrenceHookC = [initCVariable](const Variable& x) { return PastValue(x, initCVariable); };
|
||||
}
|
||||
|
||||
std::tie<FunctionPtr, FunctionPtr>(outputH, outputC) = LSTMPComponent(
|
||||
X, { (size_t)hiddenDim }, iofActivationOp, cellActivationOp, hiddenActivationOp,
|
||||
futureValueRecurrenceHook, futureValueRecurrenceHook, (Constant &)W, (Constant &)R, (Constant &)B,
|
||||
recurrenceHookH, recurrenceHookC, (Constant &)W, (Constant &)R, (Constant &)B,
|
||||
(Constant &)Ci, (Constant &)Cf, (Constant &)Co);
|
||||
outputHs.push_back(outputH);
|
||||
}
|
||||
|
@@ -293,3 +450,475 @@ FunctionPtr CreateLSTM(const ONNXIR::Node *node, const std::vector<Variable> &in
|
|||
return Splice(operands, Axis(0), ToWString(node->Name()));
|
||||
}
|
||||
}
|
||||
|
||||
FunctionPtr CreateGRU(const ONNXIR::Node *node, const std::vector<Variable> &inputs, const std::string &direction,
|
||||
const std::vector<string> &activations, const std::vector<float> &activation_alpha, const std::vector<float> &activation_beta)
|
||||
{
|
||||
int numDirections = direction == RNNDirectionBidirection ? 2 : 1;
|
||||
std::vector<FunctionPtr> outputHs;
|
||||
for (int dir = 0; dir < numDirections; dir++)
|
||||
{
|
||||
std::function<FunctionPtr(const Variable&)> fActivationOp, gActivationOp;
|
||||
std::tie<std::function<FunctionPtr(const Variable&)>, std::function<FunctionPtr(const Variable&)>>
|
||||
(fActivationOp, gActivationOp) = GetGRUActivations(activations, activation_alpha, activation_beta, dir);
|
||||
|
||||
// the first few inputs are (in order): X, numDirections * W, numDirections * R, numDirections * H1
|
||||
Variable X = inputs[0];
|
||||
Variable W = inputs[1 * numDirections + dir - ((numDirections == 2) ? 1 : 0)];
|
||||
Variable R = inputs[2 * numDirections + dir - ((numDirections == 2) ? 1 : 0)];
|
||||
|
||||
// TODO: get H1
|
||||
Variable H1 = inputs[3 * numDirections + dir - ((numDirections == 2) ? 1 : 0)];
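// Illustrative note, not part of this commit: with inputs ordered as
// X, numDirections * W, numDirections * R, numDirections * H1, the index
// arithmetic above resolves, for a bidirectional GRU (numDirections == 2), to
//   dir == 0: W = inputs[1], R = inputs[3], H1 = inputs[5]
//   dir == 1: W = inputs[2], R = inputs[4], H1 = inputs[6]
// and for a unidirectional GRU simply to inputs[1], inputs[2] and inputs[3].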
|
||||
|
||||
Variable B;
|
||||
std::vector<Variable> biasVariables = FindByNameHint(inputs, LSTMInputBiasNameHint);
|
||||
if (numDirections == 1 && biasVariables.size() >= 1)
|
||||
B = biasVariables[0];
|
||||
else if (numDirections == 2 && biasVariables.size() == 2)
|
||||
B = biasVariables[1];
|
||||
|
||||
Variable initHVariable = GetInitialStateVariable(inputs, numDirections, GRUInputInitialHNameHint, X.GetDataType());
|
||||
|
||||
// ONNX spec https://github.com/onnx/onnx/blob/master/docs/Operators.md#inputs-3---8
|
||||
// tells that the weight W has shape [num_directions, 3*hidden_size, input_size]
// here in CNTK, there is no direction axis because CNTK treats a bidirectional GRU
// as two separate GRUs. Therefore we can divide the dimension of the first axis
// by 3 to get the hidden size.
|
||||
int hiddenDim = W.Shape()[0] / GRUWeightDimensionHiddenMultiplier;
|
||||
|
||||
FunctionPtr outputH;
|
||||
|
||||
// if it is a bidirectional GRU, the second one is the backward one.
|
||||
bool go_backwards = direction == RNNDirectionReverse || (numDirections == 2 && dir == 1);
|
||||
|
||||
std::function<FunctionPtr(const Variable&)> recurrenceHook;
|
||||
if (go_backwards)
|
||||
recurrenceHook = [initHVariable](const Variable& x) { return FutureValue(x, initHVariable); };
|
||||
else
|
||||
recurrenceHook = [initHVariable](const Variable& x) { return PastValue(x, initHVariable); };
|
||||
|
||||
outputH = GRUComponent(
|
||||
X, { (size_t)hiddenDim }, fActivationOp, gActivationOp,
|
||||
recurrenceHook, (Constant &)W, (Constant &)R, (Constant &)H1, (Constant &)B);
|
||||
outputHs.push_back(outputH);
|
||||
}
|
||||
if (outputHs.size() == 1)
|
||||
return outputHs[0];
|
||||
else
|
||||
{
|
||||
std::vector<Variable> operands({ outputHs[0], outputHs[1] });
|
||||
return Splice(operands, Axis(0), ToWString(node->Name()));
|
||||
}
|
||||
}
|
||||
|
||||
template <typename FunctionType>
|
||||
void TraverseGraphWithPrePostActions(FunctionPtr cntkFunction, std::unordered_set<FunctionPtr>& visitedFunctions,
|
||||
FunctionType preFunctor, FunctionType postFunctor)
|
||||
{
|
||||
visitedFunctions.insert(cntkFunction);
|
||||
preFunctor(cntkFunction);
|
||||
|
||||
std::vector<Variable> functionInputs = cntkFunction->Inputs();
|
||||
for (const auto& input : functionInputs)
|
||||
{
|
||||
if (input.IsOutput() && visitedFunctions.find(input.Owner()) == visitedFunctions.end())
|
||||
{
|
||||
const auto& inputFunction = input.Owner();
|
||||
TraverseGraphWithPrePostActions(inputFunction, visitedFunctions, preFunctor, postFunctor);
|
||||
}
|
||||
}
|
||||
|
||||
postFunctor(cntkFunction);
|
||||
}
|
||||
|
||||
bool IsSupportedRNNActivation(const std::wstring &cntkOpName)
|
||||
{
|
||||
static std::vector<std::wstring> supportedRNNActivations(
|
||||
{
|
||||
L"ReLU",
|
||||
L"Tanh",
|
||||
L"StableSigmoid"
|
||||
});
|
||||
return std::find(supportedRNNActivations.cbegin(), supportedRNNActivations.cend(), cntkOpName) !=
|
||||
supportedRNNActivations.cend();
|
||||
}
|
||||
|
||||
std::string FindActivation(const std::vector<FunctionPtr> &path, int nth)
|
||||
{
|
||||
int count = 0;
|
||||
for (std::vector<FunctionPtr>::const_iterator it = path.begin(); it != path.end(); it++)
|
||||
{
|
||||
std::wstring opName = (*it)->OpName();
|
||||
if (IsSupportedRNNActivation(opName))
|
||||
{
|
||||
if (count == nth)
|
||||
{
|
||||
std::unordered_multimap<std::wstring, AttributesMapping>::const_iterator itLookup = Operators::CntkToONNXLookup().find(opName);
|
||||
if (itLookup == Operators::CntkToONNXLookup().cend())
|
||||
CNTK::LogicError("Invalid activation (%s)", ToString(opName).c_str());
|
||||
|
||||
std::unordered_map<std::wstring, std::string>::const_iterator itMap = (*itLookup).second.map.find(opName);
|
||||
if (itMap == (*itLookup).second.map.cend())
|
||||
CNTK::LogicError("Invalid activation (%s)", ToString(opName).c_str());
|
||||
return itMap->second;
|
||||
}
|
||||
count++;
|
||||
}
|
||||
}
|
||||
return "";
|
||||
}
|
||||
|
||||
Variable GetPeepholeVariableFromOp(FunctionPtr peepholeOp)
|
||||
{
|
||||
// the peephole variable is the child of peepholeOp that is neither a stabilizer nor a placeholder
|
||||
if (peepholeOp->OpName() != L"ElementTimes")
|
||||
CNTK::LogicError("Peephole operation must be ElementTimes");
|
||||
|
||||
Variable peepholeVariable;
|
||||
FunctionPtr stabilizerOp;
|
||||
for (int i = 0; i < peepholeOp->Inputs().size(); i++)
|
||||
{
|
||||
if (peepholeOp->Inputs()[i].Owner() && peepholeOp->Inputs()[i].Owner()->OpName() == L"Stabilizer")
|
||||
{
|
||||
stabilizerOp = peepholeOp->Inputs()[i].Owner();
|
||||
}
|
||||
else if (peepholeOp->Inputs()[i].IsConstant() || peepholeOp->Inputs()[i].IsParameter())
|
||||
{
|
||||
if (!peepholeVariable.IsInitialized())
|
||||
peepholeVariable = peepholeOp->Inputs()[i];
|
||||
else
|
||||
CNTK::LogicError("Cannot find peephole variable from peephole op. Multiple qualified variables found.");
|
||||
}
|
||||
}
|
||||
|
||||
if (!peepholeVariable.IsInitialized())
|
||||
CNTK::LogicError("Cannot find peephole variable from peephole op.");
|
||||
return peepholeVariable;
|
||||
}
|
||||
|
||||
// this method helps getting a stabilizer op from its parent/grandparent op.
|
||||
// the parent op can be Times or ElementTimes. The grandparent op can be a plus op.
|
||||
FunctionPtr GetStabilizerOp(FunctionPtr parentOp)
|
||||
{
|
||||
FunctionPtr timesOp;
|
||||
if (parentOp->OpName() == L"Plus")
|
||||
{
|
||||
for (int i = 0; i < parentOp->Inputs().size(); i++)
|
||||
{
|
||||
if (parentOp->Inputs()[i].Owner() &&
|
||||
(parentOp->Inputs()[i].Owner()->OpName() == L"Times" ||
|
||||
parentOp->Inputs()[i].Owner()->OpName() == L"ElementTimes"))
|
||||
{
|
||||
timesOp = parentOp->Inputs()[i].Owner();
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
else if (parentOp->OpName() == L"Times" || parentOp->OpName() == L"ElementTimes")
|
||||
{
|
||||
timesOp = parentOp;
|
||||
}
|
||||
|
||||
if (!timesOp)
|
||||
{
|
||||
CNTK::LogicError("Cannot find stabilizer op. A stabilizer op must be from Times or ElementTimes ops or skipped from a Plus op.");
|
||||
}
|
||||
|
||||
for (int j = 0; j < timesOp->Inputs().size(); j++)
|
||||
{
|
||||
if (timesOp->Inputs()[j].Owner() && timesOp->Inputs()[j].Owner()->OpName() == L"Stabilizer")
|
||||
{
|
||||
return timesOp->Inputs()[j].Owner();
|
||||
}
|
||||
}
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
double GetScaler(Variable variable)
|
||||
{
|
||||
NDArrayViewPtr v = variable.IsParameter() ? Parameter(variable).Value() : Constant(variable).Value();
|
||||
NDArrayViewPtr cpuV = v->DeepClone();
|
||||
cpuV->ChangeDevice(DeviceDescriptor::CPUDevice());
|
||||
|
||||
switch (variable.GetDataType())
|
||||
{
|
||||
case DataType::Float:
|
||||
return *((float *)cpuV->DataBuffer<float>());
|
||||
case DataType::Double:
|
||||
return *((double *)cpuV->DataBuffer<double>());
|
||||
default:
|
||||
NOT_IMPLEMENTED;
|
||||
}
|
||||
}
|
||||
|
||||
double GetStabilizerCoef(const FunctionPtr stabilizerDhOp)
|
||||
{
|
||||
double alpha = GetScaler(stabilizerDhOp->Inputs()[3]);
|
||||
double steepness = GetScaler(stabilizerDhOp->Inputs()[1]);
|
||||
return (log(exp(alpha * steepness) + 1.0F) / steepness);
|
||||
}
|
||||
|
||||
void GetDelayOps(const std::vector<Variable> &inputVars,
|
||||
std::vector<FunctionPtr> &pastValueOps, std::vector<FunctionPtr> &futureValueOps)
|
||||
{
|
||||
for (std::vector<Variable>::const_iterator it = inputVars.cbegin(); it != inputVars.cend(); ++it)
|
||||
{
|
||||
if ((*it).Owner() != nullptr && (*it).Owner()->OpName() == L"PastValue")
|
||||
pastValueOps.push_back((*it).Owner());
|
||||
else if ((*it).Owner() != nullptr && (*it).Owner()->OpName() == L"FutureValue")
|
||||
futureValueOps.push_back((*it).Owner());
|
||||
}
|
||||
}
|
||||
|
||||
// A CNTK LSTM op is created with a stacked matmul followed by a slice op for the 4 gates.
// The slice order tells which graph path belongs to which gate. This method
// traverses the graph to find the 4 paths of the 4 gates. It helps to
// subsequently find the attributes needed to build an ONNX LSTM op.
|
||||
void TraceLSTMPathes(const FunctionPtr& src,
|
||||
string &f_activation,
|
||||
string &g_activation,
|
||||
string &h_activation,
|
||||
RNNDirection &direction,
|
||||
Variable &initStateH,
|
||||
Variable &initStateC,
|
||||
Variable &peepholeCi,
|
||||
Variable &peepholeCo,
|
||||
Variable &peepholeCf,
|
||||
double &stabilizer_dh,
|
||||
double &stabilizer_dc,
|
||||
double &stabilizer_c)
|
||||
{
|
||||
// src has to be an LSTM node.
|
||||
std::vector<Variable> inputVars = src->Inputs();
|
||||
std::vector<FunctionPtr> pastValueOps, futureValueOps;
|
||||
GetDelayOps(inputVars, pastValueOps, futureValueOps);
|
||||
|
||||
// with CNTK LSTM, the first delay node is for H, the second one is for C
|
||||
// indices here also correspond to the CNTK python layer code.
|
||||
if (pastValueOps.size() == 2 && futureValueOps.size() == 0)
|
||||
{
|
||||
direction = RNNDirection::Forward;
|
||||
initStateH = pastValueOps[0]->Inputs()[1];
|
||||
initStateC = pastValueOps[1]->Inputs()[1];
|
||||
}
|
||||
else if (pastValueOps.size() == 0 && futureValueOps.size() == 2)
|
||||
{
|
||||
direction = RNNDirection::Backward;
|
||||
initStateH = futureValueOps[0]->Inputs()[1];
|
||||
initStateC = futureValueOps[1]->Inputs()[1];
|
||||
}
|
||||
else
|
||||
{
|
||||
CNTK::LogicError("Node %s (%s) is not a valid LSTM node", ToString(src->Name()).c_str(), ToString(src->Uid()).c_str());
|
||||
}
|
||||
|
||||
// set up traverse boundary
|
||||
std::unordered_set<FunctionPtr> visitedFunctions;
|
||||
for (std::vector<Variable>::const_iterator it = inputVars.begin(); it != inputVars.end(); it++)
|
||||
{
|
||||
visitedFunctions.insert(it->Owner());
|
||||
}
|
||||
|
||||
// First find the peephole op node.
|
||||
// see CNTK\bindings\python\cntk\layers\blocks.py node references.
|
||||
std::vector<std::vector<FunctionPtr>> pathesBitBftJoint;
|
||||
{
|
||||
std::vector<FunctionPtr> currentPeepholePath;
|
||||
|
||||
// make a copy of traverse boundary
|
||||
std::unordered_set<FunctionPtr> peepHoleVisitedFunctions = visitedFunctions;
|
||||
|
||||
// traverse to find the joint of bit and bft
|
||||
TraverseGraphWithPrePostActions(src->BlockRoot(),
|
||||
peepHoleVisitedFunctions,
|
||||
(std::function<void(const FunctionPtr&)>)[
|
||||
&peepHoleVisitedFunctions, &pathesBitBftJoint, ¤tPeepholePath](const FunctionPtr& function)
|
||||
{
|
||||
currentPeepholePath.push_back(function);
|
||||
if (function->OpName() == L"Plus" &&
|
||||
function->Inputs()[0].Owner() && function->Inputs()[0].Owner()->OpName() == L"ElementTimes" &&
|
||||
function->Inputs()[1].Owner() && function->Inputs()[1].Owner()->OpName() == L"ElementTimes")
|
||||
{
|
||||
pathesBitBftJoint.push_back(currentPeepholePath);
|
||||
peepHoleVisitedFunctions.erase(std::find_if(peepHoleVisitedFunctions.begin(), peepHoleVisitedFunctions.end(),
|
||||
[function](FunctionPtr f) {return function == f; }));
|
||||
}
|
||||
},
|
||||
(std::function<void(const FunctionPtr&)>)[¤tPeepholePath](const FunctionPtr& function)
|
||||
{
|
||||
currentPeepholePath.pop_back();
|
||||
});
|
||||
}
|
||||
|
||||
FunctionPtr peepholeCoOp;
|
||||
bool haspeephole = pathesBitBftJoint.size() == 3;
|
||||
if (haspeephole)
|
||||
{
|
||||
// the last ElementTimes op is the peephole op
|
||||
std::vector<FunctionPtr> &peepholePath = *std::max_element(pathesBitBftJoint.begin(), pathesBitBftJoint.end(),
|
||||
[](std::vector<FunctionPtr> &p1, std::vector<FunctionPtr> &p2) {return p1.size() < p2.size(); });
|
||||
std::vector<FunctionPtr>::reverse_iterator itPeepholeOp = std::find_if(peepholePath.rbegin(), peepholePath.rend(),
|
||||
[](FunctionPtr function) {return function->OpName() == L"ElementTimes"; });
|
||||
if (itPeepholeOp == peepholePath.rend())
|
||||
{
|
||||
CNTK::LogicError("Cannot find peephole op from a LSTM graph");
|
||||
}
|
||||
|
||||
peepholeCoOp = *itPeepholeOp;
|
||||
peepholeCo = GetPeepholeVariableFromOp(peepholeCoOp);
|
||||
|
||||
FunctionPtr stabilizer_h_op = GetStabilizerOp(peepholeCoOp);
|
||||
if (stabilizer_h_op)
|
||||
{
|
||||
stabilizer_c = GetStabilizerCoef(stabilizer_h_op);
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<std::vector<FunctionPtr>> pathesToPlusSlice;
|
||||
std::vector<FunctionPtr> currentPath;
|
||||
|
||||
if (haspeephole)
|
||||
// so that the traversal will not be affected by the peephole path
|
||||
visitedFunctions.insert(peepholeCoOp);
|
||||
|
||||
TraverseGraphWithPrePostActions(src->BlockRoot(),
|
||||
visitedFunctions,
|
||||
(std::function<void(const FunctionPtr&)>)[&pathesToPlusSlice, ¤tPath](const FunctionPtr& function)
|
||||
{
|
||||
currentPath.push_back(function);
|
||||
if (function->OpName() == L"Slice")
|
||||
{
|
||||
FunctionPtr functionSource = function->Inputs()[0].Owner();
|
||||
if (functionSource->OpName() == L"Plus")
|
||||
{
|
||||
pathesToPlusSlice.push_back(currentPath);
|
||||
}
|
||||
}
|
||||
},
|
||||
(std::function<void(const FunctionPtr&)>)[¤tPath](const FunctionPtr& function)
|
||||
{
|
||||
currentPath.pop_back();
|
||||
});
|
||||
|
||||
// 4 gates of LSTM shall be traced.
|
||||
if (pathesToPlusSlice.size() != 4)
|
||||
{
|
||||
CNTK::LogicError("pathesToPlusSlice.size() != 4");
|
||||
}
|
||||
|
||||
std::sort(pathesToPlusSlice.begin(), pathesToPlusSlice.end(),
|
||||
[](const std::vector<FunctionPtr>& path1, const std::vector<FunctionPtr>& path2)
|
||||
{
|
||||
FunctionPtr slice1 = *path1.rbegin();
|
||||
FunctionPtr slice2 = *path2.rbegin();
|
||||
int beginIndex1 = slice1->Attributes()[PrimitiveFunction::AttributeNameBeginIndex].Value<int>();
|
||||
int beginIndex2 = slice2->Attributes()[PrimitiveFunction::AttributeNameBeginIndex].Value<int>();
|
||||
return beginIndex1 < beginIndex2;
|
||||
});
|
||||
|
||||
// This code is heavily coupled with CNTK python layer code:
|
||||
// https://github.com/Microsoft/CNTK/blob/44c626a483edeaff97b4f7a46847b055a1d483aa/bindings/python/cntk/layers/blocks.py#L261
|
||||
// pathesToPlusSlice is ordered by slice index so we are able to recover corresponding path here.
|
||||
std::vector<FunctionPtr> &ht_it_path = pathesToPlusSlice[0];
|
||||
std::vector<FunctionPtr> &ht_bit_path = pathesToPlusSlice[1];
|
||||
std::vector<FunctionPtr> &ht_ft_path = pathesToPlusSlice[2];
|
||||
std::vector<FunctionPtr> &ht_ot_path = pathesToPlusSlice[3];
|
||||
|
||||
f_activation = MapActivationNameCNTKToONNX(FindActivation(ht_ot_path, 0));
|
||||
g_activation = MapActivationNameCNTKToONNX(FindActivation(ht_bit_path, 1));
|
||||
h_activation = MapActivationNameCNTKToONNX(FindActivation(ht_bit_path, 0));
|
||||
|
||||
// stabilizer_dh
|
||||
FunctionPtr stackedProjPlusOp = ht_it_path[ht_it_path.size() - 1]->Inputs()[0].Owner();
|
||||
FunctionPtr stabilizerDhOp = GetStabilizerOp(stackedProjPlusOp);
|
||||
if (stabilizerDhOp)
|
||||
{
|
||||
stabilizer_dh = GetStabilizerCoef(stabilizerDhOp);
|
||||
}
|
||||
|
||||
if (haspeephole)
|
||||
{
|
||||
{
|
||||
// Ci merges into ht_it_path via element-wise multiplication
|
||||
FunctionPtr plusOp = ht_it_path[ht_it_path.size() - 2];
|
||||
FunctionPtr peepholeOp = plusOp->Inputs()[0].Owner()->OpName() != L"Slice" ?
|
||||
plusOp->Inputs()[0].Owner() : plusOp->Inputs()[1].Owner();
|
||||
peepholeCi = GetPeepholeVariableFromOp(peepholeOp);
|
||||
|
||||
}
|
||||
{
|
||||
// Cf merges into ht_ft_path via element-wise multiplication
|
||||
FunctionPtr plusOp = ht_ft_path[ht_ft_path.size() - 2];
|
||||
FunctionPtr peepholeOp = plusOp->Inputs()[0].Owner()->OpName() != L"Slice" ?
|
||||
plusOp->Inputs()[0].Owner() : plusOp->Inputs()[1].Owner();
|
||||
peepholeCf = GetPeepholeVariableFromOp(peepholeOp);
|
||||
|
||||
FunctionPtr stabilizerDcOp = GetStabilizerOp(peepholeOp);
|
||||
if (stabilizerDcOp)
|
||||
stabilizer_dc = GetStabilizerCoef(stabilizerDcOp);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
FunctionPtr TraverseGraphFindFirstRNNOp(FunctionPtr src)
|
||||
{
|
||||
std::vector<Variable> front = src->Inputs(), back;
|
||||
|
||||
while (!front.empty())
|
||||
{
|
||||
for (auto f : front)
|
||||
{
|
||||
if (f.IsOutput() && f.Owner())
|
||||
if (IsActivationOp(ToString(f.Owner()->OpName())))
|
||||
return f.Owner();
|
||||
else
|
||||
{
|
||||
for (auto i : f.Owner()->Inputs())
|
||||
back.push_back(i);
|
||||
}
|
||||
}
|
||||
front = back;
|
||||
back.clear();
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
void TraceGRUPathes(const FunctionPtr& src, string &f_activation, string &g_activation,
|
||||
RNNDirection &direction, Variable &initStateH)
|
||||
{
|
||||
std::vector<Variable> inputVars = src->Inputs();
|
||||
std::vector<FunctionPtr> pastValueOps, futureValueOps;
|
||||
GetDelayOps(inputVars, pastValueOps, futureValueOps);
|
||||
|
||||
// indices here correspond to the CNTK python layer code.
|
||||
if (pastValueOps.size() == 1 && futureValueOps.size() == 0)
|
||||
{
|
||||
direction = RNNDirection::Forward;
|
||||
initStateH = pastValueOps[0]->Inputs()[1];
|
||||
}
|
||||
else if (pastValueOps.size() == 0 && futureValueOps.size() == 1)
|
||||
{
|
||||
direction = RNNDirection::Backward;
|
||||
initStateH = futureValueOps[0]->Inputs()[1];
|
||||
}
|
||||
else
|
||||
{
|
||||
CNTK::LogicError("Node %s (%s) is not a valid GRU node", ToString(src->Name()).c_str(), ToString(src->Uid()).c_str());
|
||||
}
|
||||
|
||||
// set up traverse boundary
|
||||
std::unordered_set<FunctionPtr> visitedFunctions;
|
||||
for (std::vector<Variable>::const_iterator it = inputVars.begin(); it != inputVars.end(); it++)
|
||||
{
|
||||
visitedFunctions.insert(it->Owner());
|
||||
}
|
||||
|
||||
std::vector<std::vector<FunctionPtr>> pathesToPlusSlice;
|
||||
std::vector<FunctionPtr> currentPath;
|
||||
|
||||
FunctionPtr gActivation = TraverseGraphFindFirstRNNOp(src->BlockRoot());
|
||||
|
||||
f_activation = "Sigmoid";
|
||||
g_activation = MapActivationNameCNTKToONNX(ToString(gActivation->OpName()));
|
||||
}
|
||||
|
|
|
@@ -25,6 +25,18 @@ const std::string LSTMInputInitialHNameHint = "_initial_h_";
|
|||
const std::string LSTMInputInitialCNameHint = "_initial_c_";
|
||||
const std::string LSTMInputPeepholeNameHint = "_peephole_";
|
||||
|
||||
const std::string GRUInputInitialHNameHint = "_initial_h_";
|
||||
|
||||
// https://github.com/onnx/onnx/blob/master/docs/Operators.md#attributes-18
|
||||
// https://github.com/onnx/onnx/blob/master/docs/Operators.md#attributes-27
|
||||
// https://github.com/onnx/onnx/blob/master/docs/Operators.md#attributes-39
|
||||
// CNTK RNN ops always output the full sequence.
// ONNX requires the output_sequence attribute to be set to 1 to output a sequence.
|
||||
enum
|
||||
{
|
||||
RNNOutputSequence = 1
|
||||
};
|
||||
|
||||
enum
|
||||
{
|
||||
LSTMInputIndexX = 0,
|
||||
|
@@ -52,10 +64,18 @@ enum {
|
|||
LSTMPeepholeCount = 3
|
||||
};
|
||||
|
||||
// https://github.com/onnx/onnx/blob/master/docs/Operators.md#inputs-3---8
|
||||
// size of weight/bias matrix is a multiple of hidden size
|
||||
enum
|
||||
{
|
||||
LSTMWeightDimensionHiddenMultiplier = 4,
|
||||
LSTMBiasDimensionHiddenMultiplier = 8
|
||||
};
|
||||
|
||||
typedef enum {
|
||||
Forward,
|
||||
Backward,
|
||||
} LSTMDirection;
|
||||
} RNNDirection;
|
||||
|
||||
enum
|
||||
{
|
||||
|
@@ -69,9 +89,64 @@ enum
|
|||
CNTKLSTMOutputYhIndex = 0,
|
||||
CNTKLSTMOutputChIndex = 1
|
||||
};
|
||||
const string LSTMDirectionBidirection = "bidirectional";
|
||||
const string LSTMDirectionReverse = "reverse";
|
||||
const string LSTMDirectionForward = "forward";
|
||||
|
||||
enum
|
||||
{
|
||||
GRUActivationFIndex = 0,
|
||||
GRUActivationGIndex = 1,
|
||||
GRUActivationCount = 2
|
||||
};
|
||||
|
||||
enum
|
||||
{
|
||||
GRUInputIndexX = 0,
|
||||
GRUInputIndexW = 1,
|
||||
GRUInputIndexR = 2,
|
||||
GRUInputIndexB = 3,
|
||||
GRUInputIndexSequenceLens = 4,
|
||||
GRUInitialH = 5,
|
||||
};
|
||||
|
||||
// https://github.com/onnx/onnx/blob/master/docs/Operators.md#inputs-3---6
|
||||
// size of weight/bias matrix is a multiple of hidden size
|
||||
enum
|
||||
{
|
||||
GRUWeightDimensionHiddenMultiplier = 3,
|
||||
GRUBiasDimensionHiddenMultiplier = 6
|
||||
};
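// Illustrative note, not part of this commit: with these multipliers a GRU with
// hidden_size = 10 carries an ONNX W of shape [num_directions, 30, input_size] and
// an ONNX B of shape [num_directions, 60] (Wb and Rb stacked); on the CNTK side only
// a 30-element bias is kept, the two halves of B being summed in CreateRNNConstant.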
|
||||
|
||||
enum
|
||||
{
|
||||
CNTKGRUZRWeightMultiplier = 2
|
||||
};
|
||||
enum
|
||||
{
|
||||
CNTKGRUBiasIndex = 1,
|
||||
CNTKGRUWeightIndex = 2,
|
||||
CNTKGRUHiddenWeightZRIndex = 3,
|
||||
CNTKGRUHiddenWeightHIndex = 4,
|
||||
CNTKGRUPastOrFutureIndex = 5,
|
||||
CNTKGRUInputIndex = 6,
|
||||
CNTKGRUInputCount = 7
|
||||
};
|
||||
|
||||
|
||||
const string RNNDirectionBidirection = "bidirectional";
|
||||
const string RNNDirectionReverse = "reverse";
|
||||
const string RNNDirectionForward = "forward";
|
||||
|
||||
FunctionPtr CreateLSTM(const ONNXIR::Node *node, const std::vector<Variable> &inputs, const std::string &direction,
|
||||
const std::vector<std::string> &activations, const std::vector<float> &activation_alpha, const std::vector<float> &activation_beta);
|
||||
|
||||
FunctionPtr CreateGRU(const ONNXIR::Node *node, const std::vector<Variable> &inputs, const std::string &direction,
|
||||
const std::vector<string> &activations, const std::vector<float> &activation_alpha, const std::vector<float> &activation_beta);
|
||||
|
||||
void TraceLSTMPathes(const FunctionPtr& src, string &f_activation, string &g_activation, string &h_activation,
|
||||
RNNDirection &direction, Variable &initStateH, Variable &initStateC, Variable &peepholeCi, Variable &peepholeCo, Variable &peepholeCf,
|
||||
double &stabilizer_dh, double &stabilizer_dc, double &stabilizer_c);
|
||||
|
||||
void TraceGRUPathes(const FunctionPtr& src, string &f_activation, string &g_activation,
|
||||
RNNDirection &direction, Variable &initStateH);
|
||||
|
||||
std::string MapActivationNameONNXToCNTK(const std::string &onnxOp);
|
||||
std::string MapActivationNameCNTKToONNX(const std::string &cntkOp);
|
|
@@ -8,6 +8,7 @@ import numpy as np
|
|||
import cntk as C
|
||||
import pytest
|
||||
from cntk.ops.tests.ops_test_utils import cntk_device
|
||||
from itertools import product
|
||||
|
||||
#############
|
||||
#helpers
|
||||
|
@@ -431,6 +432,43 @@ def test_Greater(tmpdir):
|
|||
model = C.greater([41., 42., 43.], [42., 42., 42.])
|
||||
verify_no_input(model, tmpdir, 'Greater_0')
|
||||
|
||||
#GRU
def MakeGRUNameFromConfig(backward, initial_state, activation):
    model_name = 'GRU.' + activation.__name__
    if (initial_state != 0):
        model_name += '.initial'
    if (backward):
        model_name += '.backward'
    else:
        model_name += '.forward'
    return model_name

direction_options = [False, True]
activation_options = [C.tanh]
initial_state_options = [0]

input_dim = 2
cell_dim = 3
batch_size = 1
sequence_len = 5

def test_GRU(tmpdir):
    for config in list(product(direction_options, initial_state_options, activation_options)):
        model_filename = MakeGRUNameFromConfig(*config)
        print(model_filename)
        backward, initial_state, activation = config

        x = C.input_variable(input_dim, dynamic_axes=[C.Axis.default_batch_axis(), C.Axis('sequenceAxis')])
        GRUModel = C.layers.Recurrence(C.layers.GRU(cell_dim,
                                                    activation = activation),
                                       initial_state = initial_state,
                                       go_backwards=backward)(x)
        #CLG.plot(GRUModel, filename=cntk_pdf_filename)
        #plot_block_internals(GRUModel, 'GRU', model_filename)
        data = np.random.uniform(low=0.0, high=1.0, size=(batch_size, sequence_len, input_dim)).astype('f')
        verify_one_input(GRUModel, data, tmpdir, model_filename)
|
||||
|
||||
|
||||
#Hardmax
|
||||
def test_Hardmax(tmpdir):
|
||||
data = np.asarray([1., 1., 2., 3.], dtype=np.float32)
|
||||
|
@@ -522,21 +560,18 @@ def test_LRN(tmpdir):
|
|||
verify_one_input(model, img, tmpdir, 'LRN_1')
|
||||
|
||||
#LSTM
|
||||
from cntk.layers import *
|
||||
from itertools import product
|
||||
|
||||
def CreateLSTMModel(activation,
|
||||
peepholes,
|
||||
self_stabilization,
|
||||
cell_dim,
|
||||
initial_state):
|
||||
return Sequential([
|
||||
Recurrence(LSTM(cell_dim,
|
||||
use_peepholes = peepholes,
|
||||
activation = activation,
|
||||
enable_self_stabilization = self_stabilization),
|
||||
initial_state = initial_state)
|
||||
])
|
||||
return C.layers.Sequential([
|
||||
C.layers.Recurrence(C.layers.LSTM(cell_dim,
|
||||
use_peepholes = peepholes,
|
||||
activation = activation,
|
||||
enable_self_stabilization = self_stabilization),
|
||||
initial_state = initial_state)
|
||||
])
|
||||
|
||||
# lstm attributes
|
||||
use_peepholes_options = [False]
|
||||
|
@@ -567,7 +602,7 @@ def test_LSTM(tmpdir):
|
|||
model_filename = MakeLSTMNameFromConfig(*config)
|
||||
use_peepholes, enable_self_stabilization, initial_state, activation = config
|
||||
|
||||
x = C.input_variable(input_dim, dynamic_axes=[Axis.default_batch_axis(), C.Axis('sequenceAxis')])
|
||||
x = C.input_variable(input_dim, dynamic_axes=[C.Axis.default_batch_axis(), C.Axis('sequenceAxis')])
|
||||
LSTMmodel = CreateLSTMModel(peepholes = use_peepholes,
|
||||
activation = activation,
|
||||
initial_state = initial_state,
|
||||