Adding ONNX support for rnnReLU and rnnTanh in OptimizedRNNStack.

Spandan Tiwari 2018-03-22 11:24:32 -07:00
Parent aab2567d17
Commit 3f7246991e
6 changed files with 825 additions and 798 deletions
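For context, a minimal usage sketch (hypothetical, not part of this commit) of the export path it enables: an OptimizedRNNStack built with recurrent_op='rnnReLU' (or 'rnnTanh') can now be saved in ONNX format. It assumes a GPU device, since optimized_rnnstack is CuDNN-backed, and an illustrative output filename.

import cntk as C

input_size, hidden_size, num_layers = 2, 3, 1
# Weight parameter for the fused RNN stack; its first dimension is inferred.
W = C.parameter((C.InferredDimension, input_size), init=0.1)
x = C.sequence.input_variable(shape=(input_size,))
# Before this commit only recurrent_op='lstm' could be exported to ONNX;
# 'rnnReLU' and 'rnnTanh' now map to the ONNX 'RNN' op.
f = C.optimized_rnnstack(x, W, hidden_size, num_layers,
                         bidirectional=True, recurrent_op='rnnReLU',
                         name='MyRnnStack')
f.save('rnn_relu.onnx', format=C.ModelFormat.ONNX)  # hypothetical output path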

View File

@ -3429,6 +3429,9 @@ ONNXIR::Node* CNTKToONNXHelper::CreateONNXNodesForOptimizedRNNStack(const Functi
auto bidirectional = (bool)(src->Attributes()[L"bidirectional"].Value<bool>());
auto recurrentOp = (wstring)src->Attributes()[L"recurrentOp"].Value<wstring>();
if (!Operators::IsOptimizedRnnStackOp(recurrentOp))
InvalidArgument("Recurrent op used for OptimizedRNNStack is not supported for ONNX export.");
size_t numDirections = bidirectional ? 2 : 1;
size_t inputSize = src->Inputs()[0].Shape()[0];
auto Wcombined = src->Inputs()[1];
@ -3550,10 +3553,17 @@ ONNXIR::Node* CNTKToONNXHelper::CreateONNXNodesForOptimizedRNNStack(const Functi
outputs.push_back(outputArg_Yc);
// ==== Step 6. Add ONNX LSTM node ====
auto rnnOpNameLookup = Operators::OptimizedRnnToOnnxOpLookup();
auto rnnNodeName = (src->Name().empty() ? ToString(src->Uid()) : ToString(src->Name())) + std::to_string(i);
functionNode = graph->AddNode(rnnNodeName, "LSTM", "", inputs, outputs);
functionNode = graph->AddNode(rnnNodeName, rnnOpNameLookup[recurrentOp], "", inputs, outputs);
std::vector<std::string> singleDirectionActivation({ "Sigmoid", "Tanh", "Tanh" }); // REVIEW: Check this is the order.
std::vector<std::string> singleDirectionActivation;
if (recurrentOp == L"lstm")
singleDirectionActivation = { "Sigmoid", "Tanh", "Tanh" };
else if (recurrentOp == L"rnnReLU")
singleDirectionActivation = { "Relu" };
else if (recurrentOp == L"rnnTanh")
singleDirectionActivation = { "Tanh" };
std::vector<std::string> activations;
activations.insert(activations.end(), singleDirectionActivation.begin(), singleDirectionActivation.end());
if (bidirectional)
@ -3575,7 +3585,6 @@ std::tuple<std::vector<NDArrayViewPtr>, std::vector<NDArrayViewPtr>, std::vector
CNTKToONNXHelper::SplitOptimzedRnnWtoIndivMats(Matrix<float>& WbigIn,
size_t numLayers, size_t inputSize, size_t hiddenSize, bool bidirectional, wstring recurrentOp)
{
std::vector<NDArrayViewPtr> onnxInputTensor(3);
size_t numDirections = bidirectional ? 2 : 1;
size_t numGates;
if (recurrentOp == L"lstm")
@ -3716,7 +3725,6 @@ std::vector<NDArrayViewPtr> CNTKToONNXHelper::ToRnnWeightPerLayerOnnxFormat(std:
size_t offset = 0;
for (size_t j = 0; j < numDirections; ++j)
{
// Matrix<float> temp = W[i*numDirections + j].Transpose(); // Have to do this because Matrix::InplaceTranspose is not implemented.
Matrix<float> temp(W[i*numDirections + j].GetNumCols(), W[i*numDirections + j].GetNumRows(), W[i*numDirections + j].Data(), CPUDEVICE);
currLayerWeightMatrix.SetColumnSlice(temp, offset, layerInputSize);
offset += layerInputSize;

The diff for this file is not shown because it is too large. Load diff

View File

@ -490,5 +490,17 @@ namespace ONNX
{ L"Dropout" },
};
std::set<std::wstring> Operators::_optimizedRnnStackOpNames = {
{ L"lstm" },
{ L"rnnReLU" },
{ L"rnnTanh" },
};
std::unordered_map<std::wstring, std::string> Operators::_optimizedRnnOpNameToOnnxOpName = {
{ L"lstm", "LSTM" },
{ L"rnnReLU", "RNN" },
{ L"rnnTanh","RNN" },
};
}
}

View File

@ -53,6 +53,23 @@ namespace CNTK
return _cntkToONNXOpName;
}
//
// Method to check if a name is a valid optimizedRnnStack op name.
//
static inline bool IsOptimizedRnnStackOp(const std::wstring& opName)
{
return _optimizedRnnStackOpNames.find(opName) != _optimizedRnnStackOpNames.end();
}
//
// Return a lookup table that maps CNTK's optimizedRNNStack op to one of the
// ONNX RNN ops.
//
static inline const std::unordered_map<std::wstring, std::string>& OptimizedRnnToOnnxOpLookup()
{
return _optimizedRnnOpNameToOnnxOpName;
}
static std::tuple<int, int> GetElementWiseInputIndices(const std::wstring& opName);
//
@ -115,6 +132,8 @@ namespace CNTK
static std::unordered_multimap<std::wstring, AttributesMapping> _cntkToONNXOpName;
static std::unordered_map<std::wstring, std::set<size_t>> _cntkBlockOPInvalidIndices;
static std::unordered_map<std::wstring, std::vector<int>> _cntkToONNXInputIndices;
static std::set<std::wstring>_optimizedRnnStackOpNames;
static std::unordered_map<std::wstring, std::string> _optimizedRnnOpNameToOnnxOpName;
static std::set<std::wstring> _cntkLayerOPName;
};

View File

@ -41,33 +41,32 @@ std::string MapActivationNameCNTKToONNX(const std::string &cntkOp)
bool IsActivationOp(const std::string &activationName)
{
return
activationName == "Relu" || activationName == "ReLU" ||
activationName == "Tanh" ||
activationName == "Sigmoid" || activationName == "StableSigmoid" ||
activationName == "Affine" ||
activationName == "LeakyRelu" || activationName == "LeakyReLU" ||
activationName == "ThresholdedRelu" || activationName == "ThresholdedReLU" ||
activationName == "ScaledTanh" ||
activationName == "HardSigmoid" ||
activationName == "Elu" || activationName == "ELU" ||
activationName == "Softsign" ||
activationName == "Softplus";
return activationName == "Relu" || activationName == "ReLU" ||
activationName == "Tanh" ||
activationName == "Sigmoid" || activationName == "StableSigmoid" ||
activationName == "Affine" ||
activationName == "LeakyRelu" || activationName == "LeakyReLU" ||
activationName == "ThresholdedRelu" || activationName == "ThresholdedReLU" ||
activationName == "ScaledTanh" ||
activationName == "HardSigmoid" ||
activationName == "Elu" || activationName == "ELU" ||
activationName == "Softsign" ||
activationName == "Softplus";
}
std::function<FunctionPtr(const Variable&)> ActivationMap(const std::string &activationName)
std::function<FunctionPtr(const Variable &)> ActivationMap(const std::string &activationName)
{
if (activationName == "Relu")
{
return [](const Variable& x) { return ReLU(x); };
return [](const Variable &x) { return ReLU(x); };
}
else if (activationName == "Tanh")
{
return [](const Variable& x) { return Tanh(x); };
return [](const Variable &x) { return Tanh(x); };
}
else if (activationName == "Sigmoid")
{
return [](const Variable& x) { return Sigmoid(x); };
return [](const Variable &x) { return Sigmoid(x); };
}
// else if (activationName == "Affine")
// else if (activationName == "LeakyRelu")
@ -76,15 +75,15 @@ std::function<FunctionPtr(const Variable&)> ActivationMap(const std::string &act
// else if (activationName == "HardSigmoid")
else if (activationName == "Elu")
{
return [](const Variable& x) { return ELU(x); };
return [](const Variable &x) { return ELU(x); };
}
else if (activationName == "Softsign")
{
return [](const Variable& x) { return Softsign(x); };
return [](const Variable &x) { return Softsign(x); };
}
else if (activationName == "Softplus")
{
return [](const Variable& x) { return Softplus(x); };
return [](const Variable &x) { return Softplus(x); };
}
else
{
@ -92,12 +91,12 @@ std::function<FunctionPtr(const Variable&)> ActivationMap(const std::string &act
}
}
std::function<FunctionPtr(const Variable&)> ActivationMap(const std::string &activationName,
float activation_alpha)
std::function<FunctionPtr(const Variable &)> ActivationMap(const std::string &activationName,
float activation_alpha)
{
if (activationName == "LeakyRelu")
{
return [activation_alpha](const Variable& x) { return LeakyReLU(x, activation_alpha); };
return [activation_alpha](const Variable &x) { return LeakyReLU(x, activation_alpha); };
}
else
{
@ -105,12 +104,12 @@ std::function<FunctionPtr(const Variable&)> ActivationMap(const std::string &act
}
}
std::function<FunctionPtr(const Variable&)> ActivationMap(const std::string &activationName,
float activation_alpha, float activation_beta)
std::function<FunctionPtr(const Variable &)> ActivationMap(const std::string &activationName,
float activation_alpha, float activation_beta)
{
if (activationName == "HardSigmoid")
{
return [activation_alpha, activation_beta](const Variable& x) { return HardSigmoid(x, activation_alpha, activation_beta); };
return [activation_alpha, activation_beta](const Variable &x) { return HardSigmoid(x, activation_alpha, activation_beta); };
}
else
{
@ -118,23 +117,23 @@ std::function<FunctionPtr(const Variable&)> ActivationMap(const std::string &act
}
}
std::tuple<std::function<FunctionPtr(const Variable&)>, std::function<FunctionPtr(const Variable&)>, std::function<FunctionPtr(const Variable&)>>
std::tuple<std::function<FunctionPtr(const Variable &)>, std::function<FunctionPtr(const Variable &)>, std::function<FunctionPtr(const Variable &)>>
GetActivations(const std::vector<std::string> &activations, const std::vector<float> &activation_alpha, const std::vector<float> &activation_beta, int direction)
{
if (activations.size() < (direction + 1) * LSTMActivationCount)
CNTK::LogicError("LSTM activations shall be a list of strings of size %d or %d ", LSTMActivationCount, LSTMActivationCount * 2);
//
//
int iofActivationIndex = direction * LSTMActivationCount + LSTMActivationFIndex;
int cellActivation = direction * LSTMActivationCount + LSTMActivationGIndex;
int hiddenActivationIndex = direction * LSTMActivationCount + LSTMActivationHIndex;
// ONNX spec is not clear on how activation alpha and beta is set.
// Here we assume that if they are set, they are set for all activations, regardless whether
// ONNX spec is not clear on how activation alpha and beta is set.
// Here we assume that if they are set, they are set for all activations, regardless whether
// an activation needs those values or not.
bool hasAlpha = activation_alpha.size() == (direction + 1) * LSTMActivationCount;
bool hasAlphaBeta = hasAlpha && activation_beta.size() == (direction + 1) * LSTMActivationCount;
std::function<FunctionPtr(const Variable&)> iofActivationOp, cellActivationOp, hiddenActivationOp;
std::function<FunctionPtr(const Variable &)> iofActivationOp, cellActivationOp, hiddenActivationOp;
if (hasAlphaBeta)
{
iofActivationOp = ActivationMap(activations[iofActivationIndex], activation_alpha[iofActivationIndex], activation_beta[iofActivationIndex]);
@ -155,22 +154,21 @@ GetActivations(const std::vector<std::string> &activations, const std::vector<fl
}
return std::make_tuple(iofActivationOp, cellActivationOp, hiddenActivationOp);
}
std::tuple<std::function<FunctionPtr(const Variable&)>, std::function<FunctionPtr(const Variable&)>>
std::tuple<std::function<FunctionPtr(const Variable &)>, std::function<FunctionPtr(const Variable &)>>
GetGRUActivations(const std::vector<std::string> &activations, const std::vector<float> &activation_alpha, const std::vector<float> &activation_beta, int direction)
{
if (activations.size() < (direction + 1) * GRUActivationCount)
CNTK::LogicError("GRU activations shall be a list of strings of size %d or %d", GRUActivationCount, GRUActivationCount * 2);
//
//
int fActivationIndex = direction * GRUActivationCount + GRUActivationFIndex;
int gActivationIndex = direction * GRUActivationCount + GRUActivationGIndex;
bool hasAlpha = activation_alpha.size() == (direction + 1) * GRUActivationCount;
bool hasAlphaBeta = hasAlpha && activation_beta.size() == (direction + 1) * GRUActivationCount;
std::function<FunctionPtr(const Variable&)> fActivationOp, gActivationOp;
std::function<FunctionPtr(const Variable &)> fActivationOp, gActivationOp;
if (hasAlphaBeta)
{
fActivationOp = ActivationMap(activations[fActivationIndex], activation_alpha[fActivationIndex], activation_beta[fActivationIndex]);
@ -190,18 +188,18 @@ GetGRUActivations(const std::vector<std::string> &activations, const std::vector
return std::make_tuple(fActivationOp, gActivationOp);
}
std::function<FunctionPtr(const Variable&)>
std::function<FunctionPtr(const Variable &)>
GetRNNActivations(const std::vector<std::string> &activations, const std::vector<float> &activation_alpha, const std::vector<float> &activation_beta, int direction)
{
if (activations.size() < (direction + 1))
CNTK::LogicError("RNN activations shall be a list of strings of size 1 or 2");
//
//
int activationIndex = direction;
bool hasAlpha = activation_alpha.size() == (direction + 1);
bool hasAlphaBeta = hasAlpha && activation_beta.size() == (direction + 1);
std::function<FunctionPtr(const Variable&)> activationOp;
std::function<FunctionPtr(const Variable &)> activationOp;
if (hasAlphaBeta)
{
activationOp = ActivationMap(activations[activationIndex], activation_alpha[activationIndex], activation_beta[activationIndex]);
@ -219,14 +217,14 @@ GetRNNActivations(const std::vector<std::string> &activations, const std::vector
}
std::pair<FunctionPtr, FunctionPtr> LSTMPCell(Variable input,
const std::function<FunctionPtr(const Variable&)> &iofActivationOp,
const std::function<FunctionPtr(const Variable&)> &cellActivationOp,
const std::function<FunctionPtr(const Variable&)> &hiddenActivationOp,
Variable prevOutput, Variable prevCellState,
Constant &W, Constant &R, Constant &B, Constant &Ci, Constant &Cf, Constant &Co)
const std::function<FunctionPtr(const Variable &)> &iofActivationOp,
const std::function<FunctionPtr(const Variable &)> &cellActivationOp,
const std::function<FunctionPtr(const Variable &)> &hiddenActivationOp,
Variable prevOutput, Variable prevCellState,
Constant &W, Constant &R, Constant &B, Constant &Ci, Constant &Cf, Constant &Co)
{
size_t outputDim = prevOutput.Shape()[0];
int stacked_dim = (int)outputDim;
int stacked_dim = (int) outputDim;
FunctionPtr proj4;
if (B.IsInitialized())
@ -238,13 +236,13 @@ std::pair<FunctionPtr, FunctionPtr> LSTMPCell(Variable input,
proj4 = Plus(Times(W, input), Times(R, prevOutput));
}
// CNTK weight and bias are in icfo order.
std::vector<Axis> stack_axis({ Axis(-1) });
// CNTK weight and bias are in icfo order.
std::vector<Axis> stack_axis({Axis(-1)});
const int IGateIndex = 0, CGateIndex = 1, FGateIndex = 2, OGateIndex = 3;
FunctionPtr it_proj = Slice(proj4, stack_axis, { IGateIndex * stacked_dim }, { (IGateIndex + 1) * stacked_dim });
FunctionPtr bit_proj = Slice(proj4, stack_axis, { CGateIndex * stacked_dim }, { (CGateIndex + 1) * stacked_dim });
FunctionPtr ft_proj = Slice(proj4, stack_axis, { FGateIndex * stacked_dim }, { (FGateIndex + 1) * stacked_dim });
FunctionPtr ot_proj = Slice(proj4, stack_axis, { OGateIndex * stacked_dim }, { (OGateIndex + 1) * stacked_dim });
FunctionPtr it_proj = Slice(proj4, stack_axis, {IGateIndex * stacked_dim}, {(IGateIndex + 1) * stacked_dim});
FunctionPtr bit_proj = Slice(proj4, stack_axis, {CGateIndex * stacked_dim}, {(CGateIndex + 1) * stacked_dim});
FunctionPtr ft_proj = Slice(proj4, stack_axis, {FGateIndex * stacked_dim}, {(FGateIndex + 1) * stacked_dim});
FunctionPtr ot_proj = Slice(proj4, stack_axis, {OGateIndex * stacked_dim}, {(OGateIndex + 1) * stacked_dim});
bool hasPeephole = Ci.IsInitialized();
@ -263,38 +261,38 @@ std::pair<FunctionPtr, FunctionPtr> LSTMPCell(Variable input,
auto c = ct;
auto h = ht;
return{ h, c };
return {h, c};
}
FunctionPtr GRUCell(Variable input,
const std::function<FunctionPtr(const Variable&)> &fActivationOp,
const std::function<FunctionPtr(const Variable&)> &gActivationOp,
Variable prevOutput,
Constant &W, Constant &R, Constant &H1, Constant &B)
const std::function<FunctionPtr(const Variable &)> &fActivationOp,
const std::function<FunctionPtr(const Variable &)> &gActivationOp,
Variable prevOutput,
Constant &W, Constant &R, Constant &H1, Constant &B)
{
size_t outputDim = prevOutput.Shape()[0];
int stacked_dim = (int)outputDim;
int stacked_dim = (int) outputDim;
FunctionPtr projx3;
if (B.IsInitialized())
projx3 = Plus(B, Times(W, input));
else
else
projx3 = Times(W, input);
FunctionPtr projh2 = Times(R, prevOutput);
// both CNTK and ONNX weight and bias are in zrh order.
std::vector<Axis> stack_axis({ Axis(-1) });
FunctionPtr zt_proj =
Slice(projx3, stack_axis, { 0 * stacked_dim }, { 1 * stacked_dim }) +
Slice(projh2, stack_axis, { 0 * stacked_dim }, { 1 * stacked_dim });
// both CNTK and ONNX weight and bias are in zrh order.
std::vector<Axis> stack_axis({Axis(-1)});
FunctionPtr zt_proj =
Slice(projx3, stack_axis, {0 * stacked_dim}, {1 * stacked_dim}) +
Slice(projh2, stack_axis, {0 * stacked_dim}, {1 * stacked_dim});
FunctionPtr rt_proj =
Slice(projx3, stack_axis, { 1 * stacked_dim }, { 2 * stacked_dim }) +
Slice(projh2, stack_axis, { 1 * stacked_dim }, { 2 * stacked_dim });
Slice(projx3, stack_axis, {1 * stacked_dim}, {2 * stacked_dim}) +
Slice(projh2, stack_axis, {1 * stacked_dim}, {2 * stacked_dim});
FunctionPtr ct_proj =
Slice(projx3, stack_axis, { 2 * stacked_dim }, { 3 * stacked_dim });
Slice(projx3, stack_axis, {2 * stacked_dim}, {3 * stacked_dim});
FunctionPtr zt = fActivationOp(zt_proj);
@ -307,18 +305,19 @@ FunctionPtr GRUCell(Variable input,
Constant one = W.GetDataType() == DataType::Float ? Constant::Scalar<float>(1.0f) : Constant::Scalar<double>(1.0);
FunctionPtr ht = ElementTimes(one - zt, ct) + ElementTimes(zt, prevOutput);
FunctionPtr h = ht;
return ht;
}
FunctionPtr RNNCell(Variable input,
const std::function<FunctionPtr(const Variable&)> &activationOp,
Variable prevOutput,
Constant &W, Constant &R, Constant &B)
const std::function<FunctionPtr(const Variable &)> &activationOp,
Variable prevOutput,
Constant &W, Constant &R, Constant &B)
{
FunctionPtr proj = Times(W, input) + Times(R, prevOutput);;
FunctionPtr proj = Times(W, input) + Times(R, prevOutput);
;
if (B.IsInitialized())
proj = B + proj;
@ -326,19 +325,18 @@ FunctionPtr RNNCell(Variable input,
return h;
}
#include "PrimitiveFunction.h"
#include "BlockFunction.h"
std::tuple<FunctionPtr, FunctionPtr> LSTMPComponent(Variable input,
const NDShape& cellShape,
const std::function<FunctionPtr(const Variable&)> &iofActivationOp,
const std::function<FunctionPtr(const Variable&)> &cellActivationOp,
const std::function<FunctionPtr(const Variable&)> &hiddenActivationOp,
const std::function<FunctionPtr(const Variable&)>& recurrenceHookH,
const std::function<FunctionPtr(const Variable&)>& recurrenceHookC,
Constant &W, Constant &R, Constant &B,
Constant &Ci, Constant &Cf, Constant &Co)
const NDShape &cellShape,
const std::function<FunctionPtr(const Variable &)> &iofActivationOp,
const std::function<FunctionPtr(const Variable &)> &cellActivationOp,
const std::function<FunctionPtr(const Variable &)> &hiddenActivationOp,
const std::function<FunctionPtr(const Variable &)> &recurrenceHookH,
const std::function<FunctionPtr(const Variable &)> &recurrenceHookC,
Constant &W, Constant &R, Constant &B,
Constant &Ci, Constant &Cf, Constant &Co)
{
auto dh = PlaceholderVariable(cellShape, input.DynamicAxes());
auto dc = PlaceholderVariable(cellShape, input.DynamicAxes());
@ -353,16 +351,16 @@ std::tuple<FunctionPtr, FunctionPtr> LSTMPComponent(Variable input,
auto actualDc = recurrenceHookC(LSTMCell.second);
// Form the recurrence loop by replacing the dh and dc placeholders with the actualDh and actualDc
LSTMCell.first->ReplacePlaceholders({ { inputPlaceholder , input}, { dh, actualDh },{ dc, actualDc } });
LSTMCell.first->ReplacePlaceholders({{inputPlaceholder, input}, {dh, actualDh}, {dc, actualDc}});
return std::make_tuple(LSTMCell.first, LSTMCell.second);
}
FunctionPtr GRUComponent(Variable input,
const NDShape& cellShape,
const std::function<FunctionPtr(const Variable&)> &fActivationOp,
const std::function<FunctionPtr(const Variable&)> &gActivationOp,
const std::function<FunctionPtr(const Variable&)>& recurrenceHookH,
Constant &W, Constant &R, Constant &H1, Constant &B)
const NDShape &cellShape,
const std::function<FunctionPtr(const Variable &)> &fActivationOp,
const std::function<FunctionPtr(const Variable &)> &gActivationOp,
const std::function<FunctionPtr(const Variable &)> &recurrenceHookH,
Constant &W, Constant &R, Constant &H1, Constant &B)
{
auto dh = PlaceholderVariable(cellShape, input.DynamicAxes());
auto inputPlaceholder = PlaceholderVariable(input.Shape(), input.DynamicAxes());
@ -374,17 +372,17 @@ FunctionPtr GRUComponent(Variable input,
auto actualDh = recurrenceHookH(gruCell);
gruCell->ReplacePlaceholders({ { dh, actualDh } });
gruCell->ReplacePlaceholders({{dh, actualDh}});
auto gruBlock = AsBlock(std::move(gruCell), { { inputPlaceholder , input } }, L"GRU", L"");
auto gruBlock = AsBlock(std::move(gruCell), {{inputPlaceholder, input}}, L"GRU", L"");
return gruBlock;
}
FunctionPtr RNNComponent(Variable input,
const NDShape& cellShape,
const std::function<FunctionPtr(const Variable&)> &activationOp,
const std::function<FunctionPtr(const Variable&)>& recurrenceHookH,
Constant &W, Constant &R, Constant &B)
const NDShape &cellShape,
const std::function<FunctionPtr(const Variable &)> &activationOp,
const std::function<FunctionPtr(const Variable &)> &recurrenceHookH,
Constant &W, Constant &R, Constant &B)
{
auto dh = PlaceholderVariable(cellShape, input.DynamicAxes());
auto inputPlaceholder = PlaceholderVariable(input.Shape(), input.DynamicAxes());
@ -396,7 +394,7 @@ FunctionPtr RNNComponent(Variable input,
auto actualDh = recurrenceHookH(rnnCell);
rnnCell->ReplacePlaceholders({ { inputPlaceholder , input },{ dh, actualDh } });
rnnCell->ReplacePlaceholders({{inputPlaceholder, input}, {dh, actualDh}});
return rnnCell;
}
@ -414,7 +412,7 @@ const std::vector<Variable> FindByNameHint(const std::vector<Variable> &inputs,
}
Variable GetInitialStateVariable(const std::vector<Variable> &inputs, int numDirections,
const std::string &nameHint, DataType datatype)
const std::string &nameHint, DataType datatype)
{
Variable initialVariable = datatype == DataType::Double ? Constant::Scalar(0.0) : Constant::Scalar(0.0f);
const std::vector<Variable> initialVariables = FindByNameHint(inputs, nameHint);
@ -431,15 +429,14 @@ Variable GetInitialStateVariable(const std::vector<Variable> &inputs, int numDir
}
FunctionPtr CreateLSTM(const ONNXIR::Node *node, const std::vector<Variable> &inputs, const std::string &direction,
const std::vector<string> &activations, const std::vector<float> &activation_alpha, const std::vector<float> &activation_beta)
const std::vector<string> &activations, const std::vector<float> &activation_alpha, const std::vector<float> &activation_beta)
{
int numDirections = direction == RNNDirectionBidirection ? 2 : 1;
std::vector<FunctionPtr> outputHs;
for (int dir = 0; dir < numDirections; dir++)
{
std::function<FunctionPtr(const Variable&)> iofActivationOp, cellActivationOp, hiddenActivationOp;
std::tie<std::function<FunctionPtr(const Variable&)>, std::function<FunctionPtr(const Variable&)>, std::function<FunctionPtr(const Variable&)>>
(iofActivationOp, cellActivationOp, hiddenActivationOp) = GetActivations(activations, activation_alpha, activation_beta, dir);
std::function<FunctionPtr(const Variable &)> iofActivationOp, cellActivationOp, hiddenActivationOp;
std::tie<std::function<FunctionPtr(const Variable &)>, std::function<FunctionPtr(const Variable &)>, std::function<FunctionPtr(const Variable &)>>(iofActivationOp, cellActivationOp, hiddenActivationOp) = GetActivations(activations, activation_alpha, activation_beta, dir);
// the first a few inputs are (in order): X, numDirections * W, numDirections * R
Variable X = inputs[0];
@ -460,7 +457,7 @@ FunctionPtr CreateLSTM(const ONNXIR::Node *node, const std::vector<Variable> &in
if (peepholeVariables.size() != 0 && peepholeVariables.size() != LSTMPeepholeCount && peepholeVariables.size() != 2 * LSTMPeepholeCount)
{
CNTK::LogicError("Peephole Variable count (%d) should be 0, 1 or 2 times the number of peephole factors (%d).",
(int)(peepholeVariables.size()), (int)LSTMPeepholeCount);
(int) (peepholeVariables.size()), (int) LSTMPeepholeCount);
}
else if (numDirections == 1 && peepholeVariables.size() >= LSTMPeepholeCount)
{
@ -477,8 +474,8 @@ FunctionPtr CreateLSTM(const ONNXIR::Node *node, const std::vector<Variable> &in
// ONNX spec https://github.com/onnx/onnx/blob/master/docs/Operators.md#inputs-3---8
// tells that weight has shape [num_directions, 4*hidden_size, input_size]
// here in CNTK, there is no direction axis because CNTK treats bidirectional LSTM
// as two separate LSTM. Therefore we can divide the dimension of the first axis
// here in CNTK, there is no direction axis because CNTK treats bidirectional LSTM
// as two separate LSTM. Therefore we can divide the dimension of the first axis
// by 4 to get the hidden size.
int hiddenDim = W.Shape()[0] / 4;
@ -488,43 +485,42 @@ FunctionPtr CreateLSTM(const ONNXIR::Node *node, const std::vector<Variable> &in
// if it is bidirectional LSTM, the second one will be the backword one.
bool go_backwards = direction == RNNDirectionReverse || (numDirections == 2 && dir == 1);
std::function<FunctionPtr(const Variable&)> recurrenceHookH, recurrenceHookC;
std::function<FunctionPtr(const Variable &)> recurrenceHookH, recurrenceHookC;
if (go_backwards)
{
recurrenceHookH = [initHVariable](const Variable& x) { return FutureValue(x, initHVariable); };
recurrenceHookC = [initCVariable](const Variable& x) { return FutureValue(x, initCVariable); };
recurrenceHookH = [initHVariable](const Variable &x) { return FutureValue(x, initHVariable); };
recurrenceHookC = [initCVariable](const Variable &x) { return FutureValue(x, initCVariable); };
}
else
{
recurrenceHookH = [initHVariable](const Variable& x) { return PastValue(x, initHVariable); };
recurrenceHookC = [initCVariable](const Variable& x) { return PastValue(x, initCVariable); };
recurrenceHookH = [initHVariable](const Variable &x) { return PastValue(x, initHVariable); };
recurrenceHookC = [initCVariable](const Variable &x) { return PastValue(x, initCVariable); };
}
std::tie<FunctionPtr, FunctionPtr>(outputH, outputC) = LSTMPComponent(
X, { (size_t)hiddenDim }, iofActivationOp, cellActivationOp, hiddenActivationOp,
recurrenceHookH, recurrenceHookC, (Constant &)W, (Constant &)R, (Constant &)B,
(Constant &)Ci, (Constant &)Cf, (Constant &)Co);
X, {(size_t) hiddenDim}, iofActivationOp, cellActivationOp, hiddenActivationOp,
recurrenceHookH, recurrenceHookC, (Constant &) W, (Constant &) R, (Constant &) B,
(Constant &) Ci, (Constant &) Cf, (Constant &) Co);
outputHs.push_back(outputH);
}
if (outputHs.size() == 1)
return outputHs[0];
else
{
std::vector<Variable> operands({ outputHs[0], outputHs[1] });
std::vector<Variable> operands({outputHs[0], outputHs[1]});
return Splice(operands, Axis(0), ToWString(node->Name()));
}
}
FunctionPtr CreateGRU(const ONNXIR::Node *node, const std::vector<Variable> &inputs, const std::string &direction,
const std::vector<string> &activations, const std::vector<float> &activation_alpha, const std::vector<float> &activation_beta)
const std::vector<string> &activations, const std::vector<float> &activation_alpha, const std::vector<float> &activation_beta)
{
int numDirections = direction == RNNDirectionBidirection ? 2 : 1;
std::vector<FunctionPtr> outputHs;
for (int dir = 0; dir < numDirections; dir++)
{
std::function<FunctionPtr(const Variable&)> fActivationOp, gActivationOp;
std::tie<std::function<FunctionPtr(const Variable&)>, std::function<FunctionPtr(const Variable&)>>
(fActivationOp, gActivationOp) = GetGRUActivations(activations, activation_alpha, activation_beta, dir);
std::function<FunctionPtr(const Variable &)> fActivationOp, gActivationOp;
std::tie<std::function<FunctionPtr(const Variable &)>, std::function<FunctionPtr(const Variable &)>>(fActivationOp, gActivationOp) = GetGRUActivations(activations, activation_alpha, activation_beta, dir);
// the first a few inputs are (in order): X, numDirections * W, numDirections * R, numDirections * H1
Variable X = inputs[0];
@ -537,16 +533,16 @@ FunctionPtr CreateGRU(const ONNXIR::Node *node, const std::vector<Variable> &inp
Variable B;
std::vector<Variable> biasVariables = FindByNameHint(inputs, LSTMInputBiasNameHint);
if (numDirections == 1 && biasVariables.size() >= 1)
B = biasVariables[0];
B = biasVariables[dir];
else if (numDirections == 2 && biasVariables.size() == 2)
B = biasVariables[1];
B = biasVariables[dir];
Variable initHVariable = GetInitialStateVariable(inputs, numDirections, GRUInputInitialHNameHint, X.GetDataType());
// ONNX spec https://github.com/onnx/onnx/blob/master/docs/Operators.md#inputs-3---8
// tells that weight has shape [num_directions, 4*hidden_size, input_size]
// here in CNTK, there is no direction axis because CNTK treats bidirectional LSTM
// as two separate LSTM. Therefore we can divide the dimension of the first axis
// here in CNTK, there is no direction axis because CNTK treats bidirectional LSTM
// as two separate LSTM. Therefore we can divide the dimension of the first axis
// by 4 to get the hidden size.
int hiddenDim = W.Shape()[0] / GRUWeightDimensionHiddenMultiplier;
@ -555,34 +551,34 @@ FunctionPtr CreateGRU(const ONNXIR::Node *node, const std::vector<Variable> &inp
// if it is bidirectional LSTM, the second one will be the backword one.
bool go_backwards = direction == RNNDirectionReverse || (numDirections == 2 && dir == 1);
std::function<FunctionPtr(const Variable&)> recurrenceHook;
std::function<FunctionPtr(const Variable &)> recurrenceHook;
if (go_backwards)
recurrenceHook = [initHVariable](const Variable& x) { return FutureValue(x, initHVariable); };
recurrenceHook = [initHVariable](const Variable &x) { return FutureValue(x, initHVariable); };
else
recurrenceHook = [initHVariable](const Variable& x) { return PastValue(x, initHVariable); };
recurrenceHook = [initHVariable](const Variable &x) { return PastValue(x, initHVariable); };
outputH = GRUComponent(
X, { (size_t)hiddenDim }, fActivationOp, gActivationOp,
recurrenceHook, (Constant &)W, (Constant &)R, (Constant &)H1, (Constant &)B);
X, {(size_t) hiddenDim}, fActivationOp, gActivationOp,
recurrenceHook, (Constant &) W, (Constant &) R, (Constant &) H1, (Constant &) B);
outputHs.push_back(outputH);
}
if (outputHs.size() == 1)
return outputHs[0];
else
{
std::vector<Variable> operands({ outputHs[0], outputHs[1] });
std::vector<Variable> operands({outputHs[0], outputHs[1]});
return Splice(operands, Axis(0), ToWString(node->Name()));
}
}
FunctionPtr CreateRNN(const ONNXIR::Node *node, const std::vector<Variable> &inputs, const std::string &direction,
const std::vector<string> &activations, const std::vector<float> &activation_alpha, const std::vector<float> &activation_beta)
const std::vector<string> &activations, const std::vector<float> &activation_alpha, const std::vector<float> &activation_beta)
{
int numDirections = direction == RNNDirectionBidirection ? 2 : 1;
std::vector<FunctionPtr> outputHs;
for (int dir = 0; dir < numDirections; dir++)
{
std::function<FunctionPtr(const Variable&)> activationOp =
std::function<FunctionPtr(const Variable &)> activationOp =
GetRNNActivations(activations, activation_alpha, activation_beta, dir);
// the first a few inputs are (in order): X, numDirections * W, numDirections * R, numDirections * H1
@ -592,9 +588,9 @@ FunctionPtr CreateRNN(const ONNXIR::Node *node, const std::vector<Variable> &inp
Variable B;
std::vector<Variable> biasVariables = FindByNameHint(inputs, LSTMInputBiasNameHint);
if (numDirections == 1 && biasVariables.size() >= 1)
B = biasVariables[0];
B = biasVariables[dir];
else if (numDirections == 2 && biasVariables.size() == 2)
B = biasVariables[1];
B = biasVariables[dir];
Variable initHVariable = GetInitialStateVariable(inputs, numDirections, GRUInputInitialHNameHint, X.GetDataType());
@ -605,39 +601,39 @@ FunctionPtr CreateRNN(const ONNXIR::Node *node, const std::vector<Variable> &inp
// if it is bidirectional LSTM, the second one will be the backword one.
bool go_backwards = direction == RNNDirectionReverse || (numDirections == 2 && dir == 1);
std::function<FunctionPtr(const Variable&)> recurrenceHook;
std::function<FunctionPtr(const Variable &)> recurrenceHook;
if (go_backwards)
recurrenceHook = [initHVariable](const Variable& x) { return FutureValue(x, initHVariable); };
recurrenceHook = [initHVariable](const Variable &x) { return FutureValue(x, initHVariable); };
else
recurrenceHook = [initHVariable](const Variable& x) { return PastValue(x, initHVariable); };
recurrenceHook = [initHVariable](const Variable &x) { return PastValue(x, initHVariable); };
outputH = RNNComponent(
X, { (size_t)hiddenDim }, activationOp,
recurrenceHook, (Constant &)W, (Constant &)R, (Constant &)B);
X, {(size_t) hiddenDim}, activationOp,
recurrenceHook, (Constant &) W, (Constant &) R, (Constant &) B);
outputHs.push_back(outputH);
}
if (outputHs.size() == 1)
return outputHs[0];
else
{
std::vector<Variable> operands({ outputHs[0], outputHs[1] });
std::vector<Variable> operands({outputHs[0], outputHs[1]});
return Splice(operands, Axis(0), ToWString(node->Name()));
}
}
template <typename FunctionType>
void TraverseGraphWithPrePostActions(FunctionPtr cntkFunction, std::unordered_set<FunctionPtr>& visitedFunctions,
FunctionType preFunctor, FunctionType postFunctor)
void TraverseGraphWithPrePostActions(FunctionPtr cntkFunction, std::unordered_set<FunctionPtr> &visitedFunctions,
FunctionType preFunctor, FunctionType postFunctor)
{
visitedFunctions.insert(cntkFunction);
preFunctor(cntkFunction);
std::vector<Variable> functionInputs = cntkFunction->Inputs();
for (const auto& input : functionInputs)
for (const auto &input : functionInputs)
{
if (input.IsOutput() && visitedFunctions.find(input.Owner()) == visitedFunctions.end())
{
const auto& inputFunction = input.Owner();
const auto &inputFunction = input.Owner();
TraverseGraphWithPrePostActions(inputFunction, visitedFunctions, preFunctor, postFunctor);
}
}
@ -648,13 +644,11 @@ void TraverseGraphWithPrePostActions(FunctionPtr cntkFunction, std::unordered_se
bool IsSupportedRNNActivation(const std::wstring &cntkOpName)
{
static std::vector<std::wstring> supportedRNNActivations(
{
L"ReLU",
L"Tanh",
L"StableSigmoid"
});
{L"ReLU",
L"Tanh",
L"StableSigmoid"});
return std::find(supportedRNNActivations.cbegin(), supportedRNNActivations.cend(), cntkOpName) !=
supportedRNNActivations.cend();
supportedRNNActivations.cend();
}
std::string FindActivation(const std::vector<FunctionPtr> &path, int nth)
@ -684,7 +678,7 @@ std::string FindActivation(const std::vector<FunctionPtr> &path, int nth)
Variable GetPeepholeVariableFromOp(FunctionPtr peepholeOp)
{
// peephole variable is that child of peepholeOp that is neither stabilizer nor place holder
// peephole variable is that child of peepholeOp that is neither stabilizer nor place holder
if (peepholeOp->OpName() != L"ElementTimes")
CNTK::LogicError("Peephole operation must be ElementTimes");
@ -721,7 +715,7 @@ FunctionPtr GetStabilizerOp(FunctionPtr parentOp)
{
if (parentOp->Inputs()[i].Owner() &&
(parentOp->Inputs()[i].Owner()->OpName() == L"Times" ||
parentOp->Inputs()[i].Owner()->OpName() == L"ElementTimes"))
parentOp->Inputs()[i].Owner()->OpName() == L"ElementTimes"))
{
timesOp = parentOp->Inputs()[i].Owner();
break;
@ -758,9 +752,9 @@ double GetScaler(Variable variable)
switch (variable.GetDataType())
{
case DataType::Float:
return *((float *)cpuV->DataBuffer<float>());
return *((float *) cpuV->DataBuffer<float>());
case DataType::Double:
return *((double *)cpuV->DataBuffer<double>());
return *((double *) cpuV->DataBuffer<double>());
default:
NOT_IMPLEMENTED;
}
@ -773,8 +767,8 @@ double GetStabilizerCoef(const FunctionPtr stabilizerDhOp)
return (log(exp(alpha * steepness) + 1.0F) / steepness);
}
void GetDelayOps(const std::vector<Variable> &inputVars,
std::vector<FunctionPtr> &pastValueOps, std::vector<FunctionPtr> &futureValueOps)
void GetDelayOps(const std::vector<Variable> &inputVars,
std::vector<FunctionPtr> &pastValueOps, std::vector<FunctionPtr> &futureValueOps)
{
for (std::vector<Variable>::const_iterator it = inputVars.cbegin(); it != inputVars.cend(); ++it)
{
@ -785,25 +779,25 @@ void GetDelayOps(const std::vector<Variable> &inputVars,
}
}
// A CNTK LSTM op is created with stacked matmul followed by a slice op for 4 gates.
// Slice order tells which graph path is for which gate. This method is
// to traverse the graph to find the 4 paths along the 4 gates. It helps to
// A CNTK LSTM op is created with stacked matmul followed by a slice op for 4 gates.
// Slice order tells which graph path is for which gate. This method is
// to traverse the graph to find the 4 paths along the 4 gates. It helps to
// subsequently find needed attributes in order to build an ONNX LSTM op.
void TraceLSTMPathes(const FunctionPtr& src,
string &f_activation,
string &g_activation,
string &h_activation,
RNNDirection &direction,
Variable &initStateH,
Variable &initStateC,
Variable &peepholeCi,
Variable &peepholeCo,
Variable &peepholeCf,
double &stabilizer_dh,
double &stabilizer_dc,
double &stabilizer_c)
void TraceLSTMPathes(const FunctionPtr &src,
string &f_activation,
string &g_activation,
string &h_activation,
RNNDirection &direction,
Variable &initStateH,
Variable &initStateC,
Variable &peepholeCi,
Variable &peepholeCo,
Variable &peepholeCf,
double &stabilizer_dh,
double &stabilizer_dc,
double &stabilizer_c)
{
// src has to be an LSTM node.
// src has to be an LSTM node.
std::vector<Variable> inputVars = src->Inputs();
std::vector<FunctionPtr> pastValueOps, futureValueOps;
GetDelayOps(inputVars, pastValueOps, futureValueOps);
@ -834,35 +828,34 @@ void TraceLSTMPathes(const FunctionPtr& src,
visitedFunctions.insert(it->Owner());
}
// First find the peephole op node.
// First find the peephole op node.
// see CNTK\bindings\python\cntk\layers\blocks.py node references.
std::vector<std::vector<FunctionPtr>> pathesBitBftJoint;
{
std::vector<FunctionPtr> currentPeepholePath;
// make a copy of traverse boundary
// make a copy of traverse boundary
std::unordered_set<FunctionPtr> peepHoleVisitedFunctions = visitedFunctions;
// traverse to find the joint of bit and bft
TraverseGraphWithPrePostActions(src->BlockRoot(),
peepHoleVisitedFunctions,
(std::function<void(const FunctionPtr&)>)[
&peepHoleVisitedFunctions, &pathesBitBftJoint, &currentPeepholePath](const FunctionPtr& function)
{
currentPeepholePath.push_back(function);
if (function->OpName() == L"Plus" &&
function->Inputs()[0].Owner() && function->Inputs()[0].Owner()->OpName() == L"ElementTimes" &&
function->Inputs()[1].Owner() && function->Inputs()[1].Owner()->OpName() == L"ElementTimes")
{
pathesBitBftJoint.push_back(currentPeepholePath);
peepHoleVisitedFunctions.erase(std::find_if(peepHoleVisitedFunctions.begin(), peepHoleVisitedFunctions.end(),
[function](FunctionPtr f) {return function == f; }));
}
},
(std::function<void(const FunctionPtr&)>)[&currentPeepholePath](const FunctionPtr& function)
{
currentPeepholePath.pop_back();
});
peepHoleVisitedFunctions,
(std::function<void(const FunctionPtr &)>) [
&peepHoleVisitedFunctions, &pathesBitBftJoint, &currentPeepholePath
](const FunctionPtr &function) {
currentPeepholePath.push_back(function);
if (function->OpName() == L"Plus" &&
function->Inputs()[0].Owner() && function->Inputs()[0].Owner()->OpName() == L"ElementTimes" &&
function->Inputs()[1].Owner() && function->Inputs()[1].Owner()->OpName() == L"ElementTimes")
{
pathesBitBftJoint.push_back(currentPeepholePath);
peepHoleVisitedFunctions.erase(std::find_if(peepHoleVisitedFunctions.begin(), peepHoleVisitedFunctions.end(),
[function](FunctionPtr f) { return function == f; }));
}
},
(std::function<void(const FunctionPtr &)>) [&currentPeepholePath](const FunctionPtr &function) {
currentPeepholePath.pop_back();
});
}
FunctionPtr peepholeCoOp;
@ -871,9 +864,9 @@ void TraceLSTMPathes(const FunctionPtr& src,
{
// the last ElementTimes op is the peephole op
std::vector<FunctionPtr> &peepholePath = *std::max_element(pathesBitBftJoint.begin(), pathesBitBftJoint.end(),
[](std::vector<FunctionPtr> &p1, std::vector<FunctionPtr> &p2) {return p1.size() < p2.size(); });
[](std::vector<FunctionPtr> &p1, std::vector<FunctionPtr> &p2) { return p1.size() < p2.size(); });
std::vector<FunctionPtr>::reverse_iterator itPeepholeOp = std::find_if(peepholePath.rbegin(), peepholePath.rend(),
[](FunctionPtr function) {return function->OpName() == L"ElementTimes"; });
[](FunctionPtr function) { return function->OpName() == L"ElementTimes"; });
if (itPeepholeOp == peepholePath.rend())
{
CNTK::LogicError("Cannot find peephole op from a LSTM graph");
@ -897,39 +890,36 @@ void TraceLSTMPathes(const FunctionPtr& src,
visitedFunctions.insert(peepholeCoOp);
TraverseGraphWithPrePostActions(src->BlockRoot(),
visitedFunctions,
(std::function<void(const FunctionPtr&)>)[&pathesToPlusSlice, &currentPath](const FunctionPtr& function)
{
currentPath.push_back(function);
if (function->OpName() == L"Slice")
{
FunctionPtr functionSource = function->Inputs()[0].Owner();
if (functionSource->OpName() == L"Plus")
{
pathesToPlusSlice.push_back(currentPath);
}
}
},
(std::function<void(const FunctionPtr&)>)[&currentPath](const FunctionPtr& function)
{
currentPath.pop_back();
});
visitedFunctions,
(std::function<void(const FunctionPtr &)>) [&pathesToPlusSlice, &currentPath ](const FunctionPtr &function) {
currentPath.push_back(function);
if (function->OpName() == L"Slice")
{
FunctionPtr functionSource = function->Inputs()[0].Owner();
if (functionSource->OpName() == L"Plus")
{
pathesToPlusSlice.push_back(currentPath);
}
}
},
(std::function<void(const FunctionPtr &)>) [&currentPath](const FunctionPtr &function) {
currentPath.pop_back();
});
// 4 gates of LSTM shall be traced.
// 4 gates of LSTM shall be traced.
if (pathesToPlusSlice.size() != 4)
{
CNTK::LogicError("pathesToPlusSlice.size() != 4");
}
std::sort(pathesToPlusSlice.begin(), pathesToPlusSlice.end(),
[](const std::vector<FunctionPtr>& path1, const std::vector<FunctionPtr>& path2)
{
FunctionPtr slice1 = *path1.rbegin();
FunctionPtr slice2 = *path2.rbegin();
int beginIndex1 = slice1->Attributes()[PrimitiveFunction::AttributeNameBeginIndex].Value<int>();
int beginIndex2 = slice2->Attributes()[PrimitiveFunction::AttributeNameBeginIndex].Value<int>();
return beginIndex1 < beginIndex2;
});
[](const std::vector<FunctionPtr> &path1, const std::vector<FunctionPtr> &path2) {
FunctionPtr slice1 = *path1.rbegin();
FunctionPtr slice2 = *path2.rbegin();
int beginIndex1 = slice1->Attributes()[PrimitiveFunction::AttributeNameBeginIndex].Value<int>();
int beginIndex2 = slice2->Attributes()[PrimitiveFunction::AttributeNameBeginIndex].Value<int>();
return beginIndex1 < beginIndex2;
});
// This code is heavily coupled with CNTK python layer code:
// https://github.com/Microsoft/CNTK/blob/44c626a483edeaff97b4f7a46847b055a1d483aa/bindings/python/cntk/layers/blocks.py#L261
@ -956,16 +946,13 @@ void TraceLSTMPathes(const FunctionPtr& src,
{
// Ci merges to ht_it_path via element-wise time
FunctionPtr plusOp = ht_it_path[ht_it_path.size() - 2];
FunctionPtr peepholeOp = plusOp->Inputs()[0].Owner()->OpName() != L"Slice" ?
plusOp->Inputs()[0].Owner() : plusOp->Inputs()[1].Owner();
FunctionPtr peepholeOp = plusOp->Inputs()[0].Owner()->OpName() != L"Slice" ? plusOp->Inputs()[0].Owner() : plusOp->Inputs()[1].Owner();
peepholeCi = GetPeepholeVariableFromOp(peepholeOp);
}
{
// Cf merges to ht_ft_path via element-wise time
FunctionPtr plusOp = ht_ft_path[ht_ft_path.size() - 2];
FunctionPtr peepholeOp = plusOp->Inputs()[0].Owner()->OpName() != L"Slice" ?
plusOp->Inputs()[0].Owner() : plusOp->Inputs()[1].Owner();
FunctionPtr peepholeOp = plusOp->Inputs()[0].Owner()->OpName() != L"Slice" ? plusOp->Inputs()[0].Owner() : plusOp->Inputs()[1].Owner();
peepholeCf = GetPeepholeVariableFromOp(peepholeOp);
FunctionPtr stabilizerDcOp = GetStabilizerOp(peepholeOp);
@ -998,8 +985,8 @@ FunctionPtr TraverseGraphFindFirstRNNOp(FunctionPtr src)
return nullptr;
}
void TraceGRUPathes(const FunctionPtr& src, string &f_activation, string &g_activation,
RNNDirection &direction, Variable &initStateH)
void TraceGRUPathes(const FunctionPtr &src, string &f_activation, string &g_activation,
RNNDirection &direction, Variable &initStateH)
{
std::vector<Variable> inputVars = src->Inputs();
std::vector<FunctionPtr> pastValueOps, futureValueOps;
@ -1037,8 +1024,8 @@ void TraceGRUPathes(const FunctionPtr& src, string &f_activation, string &g_acti
g_activation = MapActivationNameCNTKToONNX(ToString(gActivation->OpName()));
}
void TraceRNNPathes(const FunctionPtr& src, string &activation,
RNNDirection &direction, Variable &initStateH)
void TraceRNNPathes(const FunctionPtr &src, string &activation,
RNNDirection &direction, Variable &initStateH)
{
std::vector<Variable> inputVars = src->Inputs();
std::vector<FunctionPtr> pastValueOps, futureValueOps;
@ -1083,10 +1070,10 @@ std::vector<FunctionPtr> GetRNNBlocksFromSingleOrBidirectionalRNN(const Function
CNTK::LogicError("An %s op should start with an GRU op (single direction) or a Splice op (bidirectional).", RNNStepOpName.c_str());
}
// For single direction RNN, rnns.size() == 1. For bidirectional RNN, rnns.size() == 2.
// For single direction RNN, rnns.size() == 1. For bidirectional RNN, rnns.size() == 2.
// It is an error otherwise.
if (rnns.size() == 0 || rnns.size() > 2 ||
std::any_of(rnns.cbegin(), rnns.cend(), [RNNStepOpName](const FunctionPtr &f) {return ToString(f->OpName()) != RNNStepOpName; }))
std::any_of(rnns.cbegin(), rnns.cend(), [RNNStepOpName](const FunctionPtr &f) { return ToString(f->OpName()) != RNNStepOpName; }))
{
CNTK::LogicError("Invalid number of RNN ops to construct an ONNX %s node.", RNNStepOpName.c_str());
}

View File

@ -716,10 +716,12 @@ def test_Neg(tmpdir):
verify_no_input(model, tmpdir, 'Neg_0')
#OptimizedRNNStack
OPTIM_RNN_STACK_CONFIGS = ((True, 2, 2, 3), (True, 2, 4, 8), (True, 2, 6, 8),
(True, 4, 2, 3), (False, 2, 2, 3))
@pytest.mark.parametrize("bidirectional, num_layers, input_size, hidden_size", OPTIM_RNN_STACK_CONFIGS)
def test_OptimizedRNNStack(bidirectional, num_layers, input_size, hidden_size, tmpdir, device_id):
OPTIM_RNN_STACK_CONFIGS = ((True, 2, 2, 3, 'lstm'), (True, 2, 4, 8, 'lstm'), (True, 2, 6, 8, 'lstm'),
(True, 4, 2, 3, 'lstm'), (False, 2, 2, 3, 'lstm'),
(True, 1, 2, 3, 'rnnReLU'), (True, 4, 4, 8, 'rnnReLU'), (False, 2, 6, 8, 'rnnReLU'),
(True, 4, 2, 3, 'rnnTanh'), (False, 2, 2, 3, 'rnnTanh'), (True, 1, 2, 3, 'rnnTanh'))
@pytest.mark.parametrize("bidirectional, num_layers, input_size, hidden_size, recurrent_op", OPTIM_RNN_STACK_CONFIGS)
def test_OptimizedRNNStack(bidirectional, num_layers, input_size, hidden_size, recurrent_op, tmpdir, device_id):
if device_id == -1:
pytest.skip('Test only runs on GPU')
dev = cntk_device(device_id)
@ -728,7 +730,7 @@ def test_OptimizedRNNStack(bidirectional, num_layers, input_size, hidden_size, t
W = C.parameter((C.InferredDimension, input_size), constant_initializer(0.1), device=dev)
x = C.sequence.input_variable(shape=(input_size,))
s = np.asarray(np.random.uniform(-1, 1, (5,input_size)), dtype=np.float32)
f = C.optimized_rnnstack(x, W, hidden_size, num_layers, bidirectional=bidirectional, name='MyRnnStack')
f = C.optimized_rnnstack(x, W, hidden_size, num_layers, bidirectional=bidirectional, recurrent_op=recurrent_op, name='MyRnnStack')
f.parameters[0].value = np.reshape(np.arange(np.prod(f.parameters[0].value.shape), dtype=np.float32), f.parameters[0].value.shape)
verify_one_input(f, s, tmpdir, model_filename)
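
As a hypothetical follow-up check (not part of the test suite), the exported model can be loaded with the onnx package to confirm that the rnnReLU/rnnTanh configurations produce an ONNX 'RNN' node rather than 'LSTM', assuming a model was saved as 'rnn_relu.onnx' as in the earlier sketch:

import onnx

m = onnx.load('rnn_relu.onnx')
# An OptimizedRNNStack with recurrent_op='rnnReLU' should export as an ONNX RNN
# node whose activations attribute is ['Relu'] per direction.
assert any(n.op_type == 'RNN' for n in m.graph.node)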