Integrate sptiwari/add_rnn_to_ornn2 into master

This commit is contained in:
Project Philly 2018-03-22 23:24:40 +00:00 committed by CNTK Team
Parents: 65961c9c19 3f7246991e
Commit: 5b04f46aa4
6 changed files with 825 additions and 798 deletions

View file

@ -3429,6 +3429,9 @@ ONNXIR::Node* CNTKToONNXHelper::CreateONNXNodesForOptimizedRNNStack(const Functi
auto bidirectional = (bool)(src->Attributes()[L"bidirectional"].Value<bool>());
auto recurrentOp = (wstring)src->Attributes()[L"recurrentOp"].Value<wstring>();
if (!Operators::IsOptimizedRnnStackOp(recurrentOp))
InvalidArgument("Recurrent op used for OptimizedRNNStack is not supported for ONNX export.");
size_t numDirections = bidirectional ? 2 : 1;
size_t inputSize = src->Inputs()[0].Shape()[0];
auto Wcombined = src->Inputs()[1];
@ -3550,10 +3553,17 @@ ONNXIR::Node* CNTKToONNXHelper::CreateONNXNodesForOptimizedRNNStack(const Functi
outputs.push_back(outputArg_Yc);
// ==== Step 6. Add ONNX LSTM node ====
auto rnnOpNameLookup = Operators::OptimizedRnnToOnnxOpLookup();
auto rnnNodeName = (src->Name().empty() ? ToString(src->Uid()) : ToString(src->Name())) + std::to_string(i);
functionNode = graph->AddNode(rnnNodeName, "LSTM", "", inputs, outputs);
functionNode = graph->AddNode(rnnNodeName, rnnOpNameLookup[recurrentOp], "", inputs, outputs);
std::vector<std::string> singleDirectionActivation({ "Sigmoid", "Tanh", "Tanh" }); // REVIEW: Check this is the order.
std::vector<std::string> singleDirectionActivation;
if (recurrentOp == L"lstm")
singleDirectionActivation = { "Sigmoid", "Tanh", "Tanh" };
else if (recurrentOp == L"rnnReLU")
singleDirectionActivation = { "Relu" };
else if (recurrentOp == L"rnnTanh")
singleDirectionActivation = { "Tanh" };
std::vector<std::string> activations;
activations.insert(activations.end(), singleDirectionActivation.begin(), singleDirectionActivation.end());
if (bidirectional)
@ -3575,7 +3585,6 @@ std::tuple<std::vector<NDArrayViewPtr>, std::vector<NDArrayViewPtr>, std::vector
CNTKToONNXHelper::SplitOptimzedRnnWtoIndivMats(Matrix<float>& WbigIn,
size_t numLayers, size_t inputSize, size_t hiddenSize, bool bidirectional, wstring recurrentOp)
{
std::vector<NDArrayViewPtr> onnxInputTensor(3);
size_t numDirections = bidirectional ? 2 : 1;
size_t numGates;
if (recurrentOp == L"lstm")
@ -3716,7 +3725,6 @@ std::vector<NDArrayViewPtr> CNTKToONNXHelper::ToRnnWeightPerLayerOnnxFormat(std:
size_t offset = 0;
for (size_t j = 0; j < numDirections; ++j)
{
// Matrix<float> temp = W[i*numDirections + j].Transpose(); // Have to do this because Matrix::InplaceTranspose is not implemented.
Matrix<float> temp(W[i*numDirections + j].GetNumCols(), W[i*numDirections + j].GetNumRows(), W[i*numDirections + j].Data(), CPUDEVICE);
currLayerWeightMatrix.SetColumnSlice(temp, offset, layerInputSize);
offset += layerInputSize;
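
The hunks above extend CreateONNXNodesForOptimizedRNNStack so that the ONNX op name and activation list are derived from the recurrentOp attribute instead of being hard-coded to LSTM. The following is a minimal standalone sketch of that selection logic; the function name OnnxActivationsForRecurrentOp is illustrative and not part of CNTK.

#include <stdexcept>
#include <string>
#include <vector>

// Illustrative mirror of the activation selection added above: an activation
// triple for LSTM, a single activation for the plain RNN variants, and the
// list repeated once more when the stack is bidirectional.
std::vector<std::string> OnnxActivationsForRecurrentOp(const std::wstring& recurrentOp, bool bidirectional)
{
    std::vector<std::string> singleDirection;
    if (recurrentOp == L"lstm")
        singleDirection = { "Sigmoid", "Tanh", "Tanh" };
    else if (recurrentOp == L"rnnReLU")
        singleDirection = { "Relu" };
    else if (recurrentOp == L"rnnTanh")
        singleDirection = { "Tanh" };
    else
        throw std::invalid_argument("Recurrent op used for OptimizedRNNStack is not supported for ONNX export.");

    std::vector<std::string> activations(singleDirection);
    if (bidirectional)
        activations.insert(activations.end(), singleDirection.begin(), singleDirection.end());
    return activations;
}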

The diff for this file is not shown because it is too large.

View file

@ -490,5 +490,17 @@ namespace ONNX
{ L"Dropout" },
};
std::set<std::wstring> Operators::_optimizedRnnStackOpNames = {
{ L"lstm" },
{ L"rnnReLU" },
{ L"rnnTanh" },
};
std::unordered_map<std::wstring, std::string> Operators::_optimizedRnnOpNameToOnnxOpName = {
{ L"lstm", "LSTM" },
{ L"rnnReLU", "RNN" },
{ L"rnnTanh","RNN" },
};
}
}

View file

@ -53,6 +53,23 @@ namespace CNTK
return _cntkToONNXOpName;
}
//
// Method to check if a name is a valid optimizedRnnStack op name.
//
static inline bool IsOptimizedRnnStackOp(const std::wstring& opName)
{
return _optimizedRnnStackOpNames.find(opName) != _optimizedRnnStackOpNames.end();
}
//
// Return a lookup table that maps CNTK's optimizedRNNStack op to one of the
// ONNX RNN ops.
//
static inline const std::unordered_map<std::wstring, std::string>& OptimizedRnnToOnnxOpLookup()
{
return _optimizedRnnOpNameToOnnxOpName;
}
static std::tuple<int, int> GetElementWiseInputIndices(const std::wstring& opName);
//
@ -115,6 +132,8 @@ namespace CNTK
static std::unordered_multimap<std::wstring, AttributesMapping> _cntkToONNXOpName;
static std::unordered_map<std::wstring, std::set<size_t>> _cntkBlockOPInvalidIndices;
static std::unordered_map<std::wstring, std::vector<int>> _cntkToONNXInputIndices;
static std::set<std::wstring>_optimizedRnnStackOpNames;
static std::unordered_map<std::wstring, std::string> _optimizedRnnOpNameToOnnxOpName;
static std::set<std::wstring> _cntkLayerOPName;
};
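
Taken together with the tables added in Operators.cpp, these accessors let the exporter validate the recurrentOp attribute and translate it into an ONNX op type. Below is a self-contained sketch of that lookup; the local tables mirror the real ones and ResolveOnnxRnnOpName is a hypothetical helper, not CNTK code.

#include <set>
#include <stdexcept>
#include <string>
#include <unordered_map>

// Local copies of _optimizedRnnStackOpNames and _optimizedRnnOpNameToOnnxOpName, for illustration only.
static const std::set<std::wstring> optimizedRnnStackOps = { L"lstm", L"rnnReLU", L"rnnTanh" };
static const std::unordered_map<std::wstring, std::string> optimizedRnnToOnnxOp = {
    { L"lstm", "LSTM" }, { L"rnnReLU", "RNN" }, { L"rnnTanh", "RNN" } };

// Validate first, then map: "lstm" becomes an ONNX LSTM node, while the two
// plain RNN variants both become ONNX RNN nodes (only their activation differs).
std::string ResolveOnnxRnnOpName(const std::wstring& recurrentOp)
{
    if (optimizedRnnStackOps.find(recurrentOp) == optimizedRnnStackOps.end())
        throw std::invalid_argument("Recurrent op used for OptimizedRNNStack is not supported for ONNX export.");
    return optimizedRnnToOnnxOp.at(recurrentOp);
}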

View file

@ -41,8 +41,7 @@ std::string MapActivationNameCNTKToONNX(const std::string &cntkOp)
bool IsActivationOp(const std::string &activationName)
{
return
activationName == "Relu" || activationName == "ReLU" ||
return activationName == "Relu" || activationName == "ReLU" ||
activationName == "Tanh" ||
activationName == "Sigmoid" || activationName == "StableSigmoid" ||
activationName == "Affine" ||
@ -55,19 +54,19 @@ bool IsActivationOp(const std::string &activationName)
activationName == "Softplus";
}
std::function<FunctionPtr(const Variable&)> ActivationMap(const std::string &activationName)
std::function<FunctionPtr(const Variable &)> ActivationMap(const std::string &activationName)
{
if (activationName == "Relu")
{
return [](const Variable& x) { return ReLU(x); };
return [](const Variable &x) { return ReLU(x); };
}
else if (activationName == "Tanh")
{
return [](const Variable& x) { return Tanh(x); };
return [](const Variable &x) { return Tanh(x); };
}
else if (activationName == "Sigmoid")
{
return [](const Variable& x) { return Sigmoid(x); };
return [](const Variable &x) { return Sigmoid(x); };
}
// else if (activationName == "Affine")
// else if (activationName == "LeakyRelu")
@ -76,15 +75,15 @@ std::function<FunctionPtr(const Variable&)> ActivationMap(const std::string &act
// else if (activationName == "HardSigmoid")
else if (activationName == "Elu")
{
return [](const Variable& x) { return ELU(x); };
return [](const Variable &x) { return ELU(x); };
}
else if (activationName == "Softsign")
{
return [](const Variable& x) { return Softsign(x); };
return [](const Variable &x) { return Softsign(x); };
}
else if (activationName == "Softplus")
{
return [](const Variable& x) { return Softplus(x); };
return [](const Variable &x) { return Softplus(x); };
}
else
{
@ -92,12 +91,12 @@ std::function<FunctionPtr(const Variable&)> ActivationMap(const std::string &act
}
}
std::function<FunctionPtr(const Variable&)> ActivationMap(const std::string &activationName,
std::function<FunctionPtr(const Variable &)> ActivationMap(const std::string &activationName,
float activation_alpha)
{
if (activationName == "LeakyRelu")
{
return [activation_alpha](const Variable& x) { return LeakyReLU(x, activation_alpha); };
return [activation_alpha](const Variable &x) { return LeakyReLU(x, activation_alpha); };
}
else
{
@ -105,12 +104,12 @@ std::function<FunctionPtr(const Variable&)> ActivationMap(const std::string &act
}
}
std::function<FunctionPtr(const Variable&)> ActivationMap(const std::string &activationName,
std::function<FunctionPtr(const Variable &)> ActivationMap(const std::string &activationName,
float activation_alpha, float activation_beta)
{
if (activationName == "HardSigmoid")
{
return [activation_alpha, activation_beta](const Variable& x) { return HardSigmoid(x, activation_alpha, activation_beta); };
return [activation_alpha, activation_beta](const Variable &x) { return HardSigmoid(x, activation_alpha, activation_beta); };
}
else
{
@ -118,7 +117,7 @@ std::function<FunctionPtr(const Variable&)> ActivationMap(const std::string &act
}
}
std::tuple<std::function<FunctionPtr(const Variable&)>, std::function<FunctionPtr(const Variable&)>, std::function<FunctionPtr(const Variable&)>>
std::tuple<std::function<FunctionPtr(const Variable &)>, std::function<FunctionPtr(const Variable &)>, std::function<FunctionPtr(const Variable &)>>
GetActivations(const std::vector<std::string> &activations, const std::vector<float> &activation_alpha, const std::vector<float> &activation_beta, int direction)
{
if (activations.size() < (direction + 1) * LSTMActivationCount)
@ -134,7 +133,7 @@ GetActivations(const std::vector<std::string> &activations, const std::vector<fl
// an activation needs those values or not.
bool hasAlpha = activation_alpha.size() == (direction + 1) * LSTMActivationCount;
bool hasAlphaBeta = hasAlpha && activation_beta.size() == (direction + 1) * LSTMActivationCount;
std::function<FunctionPtr(const Variable&)> iofActivationOp, cellActivationOp, hiddenActivationOp;
std::function<FunctionPtr(const Variable &)> iofActivationOp, cellActivationOp, hiddenActivationOp;
if (hasAlphaBeta)
{
iofActivationOp = ActivationMap(activations[iofActivationIndex], activation_alpha[iofActivationIndex], activation_beta[iofActivationIndex]);
@ -155,10 +154,9 @@ GetActivations(const std::vector<std::string> &activations, const std::vector<fl
}
return std::make_tuple(iofActivationOp, cellActivationOp, hiddenActivationOp);
}
std::tuple<std::function<FunctionPtr(const Variable&)>, std::function<FunctionPtr(const Variable&)>>
std::tuple<std::function<FunctionPtr(const Variable &)>, std::function<FunctionPtr(const Variable &)>>
GetGRUActivations(const std::vector<std::string> &activations, const std::vector<float> &activation_alpha, const std::vector<float> &activation_beta, int direction)
{
if (activations.size() < (direction + 1) * GRUActivationCount)
@ -170,7 +168,7 @@ GetGRUActivations(const std::vector<std::string> &activations, const std::vector
bool hasAlpha = activation_alpha.size() == (direction + 1) * GRUActivationCount;
bool hasAlphaBeta = hasAlpha && activation_beta.size() == (direction + 1) * GRUActivationCount;
std::function<FunctionPtr(const Variable&)> fActivationOp, gActivationOp;
std::function<FunctionPtr(const Variable &)> fActivationOp, gActivationOp;
if (hasAlphaBeta)
{
fActivationOp = ActivationMap(activations[fActivationIndex], activation_alpha[fActivationIndex], activation_beta[fActivationIndex]);
@ -190,7 +188,7 @@ GetGRUActivations(const std::vector<std::string> &activations, const std::vector
return std::make_tuple(fActivationOp, gActivationOp);
}
std::function<FunctionPtr(const Variable&)>
std::function<FunctionPtr(const Variable &)>
GetRNNActivations(const std::vector<std::string> &activations, const std::vector<float> &activation_alpha, const std::vector<float> &activation_beta, int direction)
{
if (activations.size() < (direction + 1))
@ -201,7 +199,7 @@ GetRNNActivations(const std::vector<std::string> &activations, const std::vector
bool hasAlpha = activation_alpha.size() == (direction + 1);
bool hasAlphaBeta = hasAlpha && activation_beta.size() == (direction + 1);
std::function<FunctionPtr(const Variable&)> activationOp;
std::function<FunctionPtr(const Variable &)> activationOp;
if (hasAlphaBeta)
{
activationOp = ActivationMap(activations[activationIndex], activation_alpha[activationIndex], activation_beta[activationIndex]);
@ -219,14 +217,14 @@ GetRNNActivations(const std::vector<std::string> &activations, const std::vector
}
std::pair<FunctionPtr, FunctionPtr> LSTMPCell(Variable input,
const std::function<FunctionPtr(const Variable&)> &iofActivationOp,
const std::function<FunctionPtr(const Variable&)> &cellActivationOp,
const std::function<FunctionPtr(const Variable&)> &hiddenActivationOp,
const std::function<FunctionPtr(const Variable &)> &iofActivationOp,
const std::function<FunctionPtr(const Variable &)> &cellActivationOp,
const std::function<FunctionPtr(const Variable &)> &hiddenActivationOp,
Variable prevOutput, Variable prevCellState,
Constant &W, Constant &R, Constant &B, Constant &Ci, Constant &Cf, Constant &Co)
{
size_t outputDim = prevOutput.Shape()[0];
int stacked_dim = (int)outputDim;
int stacked_dim = (int) outputDim;
FunctionPtr proj4;
if (B.IsInitialized())
@ -239,12 +237,12 @@ std::pair<FunctionPtr, FunctionPtr> LSTMPCell(Variable input,
}
// CNTK weight and bias are in icfo order.
std::vector<Axis> stack_axis({ Axis(-1) });
std::vector<Axis> stack_axis({Axis(-1)});
const int IGateIndex = 0, CGateIndex = 1, FGateIndex = 2, OGateIndex = 3;
FunctionPtr it_proj = Slice(proj4, stack_axis, { IGateIndex * stacked_dim }, { (IGateIndex + 1) * stacked_dim });
FunctionPtr bit_proj = Slice(proj4, stack_axis, { CGateIndex * stacked_dim }, { (CGateIndex + 1) * stacked_dim });
FunctionPtr ft_proj = Slice(proj4, stack_axis, { FGateIndex * stacked_dim }, { (FGateIndex + 1) * stacked_dim });
FunctionPtr ot_proj = Slice(proj4, stack_axis, { OGateIndex * stacked_dim }, { (OGateIndex + 1) * stacked_dim });
FunctionPtr it_proj = Slice(proj4, stack_axis, {IGateIndex * stacked_dim}, {(IGateIndex + 1) * stacked_dim});
FunctionPtr bit_proj = Slice(proj4, stack_axis, {CGateIndex * stacked_dim}, {(CGateIndex + 1) * stacked_dim});
FunctionPtr ft_proj = Slice(proj4, stack_axis, {FGateIndex * stacked_dim}, {(FGateIndex + 1) * stacked_dim});
FunctionPtr ot_proj = Slice(proj4, stack_axis, {OGateIndex * stacked_dim}, {(OGateIndex + 1) * stacked_dim});
bool hasPeephole = Ci.IsInitialized();
@ -263,17 +261,17 @@ std::pair<FunctionPtr, FunctionPtr> LSTMPCell(Variable input,
auto c = ct;
auto h = ht;
return{ h, c };
return {h, c};
}
FunctionPtr GRUCell(Variable input,
const std::function<FunctionPtr(const Variable&)> &fActivationOp,
const std::function<FunctionPtr(const Variable&)> &gActivationOp,
const std::function<FunctionPtr(const Variable &)> &fActivationOp,
const std::function<FunctionPtr(const Variable &)> &gActivationOp,
Variable prevOutput,
Constant &W, Constant &R, Constant &H1, Constant &B)
{
size_t outputDim = prevOutput.Shape()[0];
int stacked_dim = (int)outputDim;
int stacked_dim = (int) outputDim;
FunctionPtr projx3;
if (B.IsInitialized())
@ -284,17 +282,17 @@ FunctionPtr GRUCell(Variable input,
FunctionPtr projh2 = Times(R, prevOutput);
// both CNTK and ONNX weight and bias are in zrh order.
std::vector<Axis> stack_axis({ Axis(-1) });
std::vector<Axis> stack_axis({Axis(-1)});
FunctionPtr zt_proj =
Slice(projx3, stack_axis, { 0 * stacked_dim }, { 1 * stacked_dim }) +
Slice(projh2, stack_axis, { 0 * stacked_dim }, { 1 * stacked_dim });
Slice(projx3, stack_axis, {0 * stacked_dim}, {1 * stacked_dim}) +
Slice(projh2, stack_axis, {0 * stacked_dim}, {1 * stacked_dim});
FunctionPtr rt_proj =
Slice(projx3, stack_axis, { 1 * stacked_dim }, { 2 * stacked_dim }) +
Slice(projh2, stack_axis, { 1 * stacked_dim }, { 2 * stacked_dim });
Slice(projx3, stack_axis, {1 * stacked_dim}, {2 * stacked_dim}) +
Slice(projh2, stack_axis, {1 * stacked_dim}, {2 * stacked_dim});
FunctionPtr ct_proj =
Slice(projx3, stack_axis, { 2 * stacked_dim }, { 3 * stacked_dim });
Slice(projx3, stack_axis, {2 * stacked_dim}, {3 * stacked_dim});
FunctionPtr zt = fActivationOp(zt_proj);
@ -314,11 +312,12 @@ FunctionPtr GRUCell(Variable input,
}
FunctionPtr RNNCell(Variable input,
const std::function<FunctionPtr(const Variable&)> &activationOp,
const std::function<FunctionPtr(const Variable &)> &activationOp,
Variable prevOutput,
Constant &W, Constant &R, Constant &B)
{
FunctionPtr proj = Times(W, input) + Times(R, prevOutput);;
FunctionPtr proj = Times(W, input) + Times(R, prevOutput);
;
if (B.IsInitialized())
proj = B + proj;
@ -326,17 +325,16 @@ FunctionPtr RNNCell(Variable input,
return h;
}
#include "PrimitiveFunction.h"
#include "BlockFunction.h"
std::tuple<FunctionPtr, FunctionPtr> LSTMPComponent(Variable input,
const NDShape& cellShape,
const std::function<FunctionPtr(const Variable&)> &iofActivationOp,
const std::function<FunctionPtr(const Variable&)> &cellActivationOp,
const std::function<FunctionPtr(const Variable&)> &hiddenActivationOp,
const std::function<FunctionPtr(const Variable&)>& recurrenceHookH,
const std::function<FunctionPtr(const Variable&)>& recurrenceHookC,
const NDShape &cellShape,
const std::function<FunctionPtr(const Variable &)> &iofActivationOp,
const std::function<FunctionPtr(const Variable &)> &cellActivationOp,
const std::function<FunctionPtr(const Variable &)> &hiddenActivationOp,
const std::function<FunctionPtr(const Variable &)> &recurrenceHookH,
const std::function<FunctionPtr(const Variable &)> &recurrenceHookC,
Constant &W, Constant &R, Constant &B,
Constant &Ci, Constant &Cf, Constant &Co)
{
@ -353,15 +351,15 @@ std::tuple<FunctionPtr, FunctionPtr> LSTMPComponent(Variable input,
auto actualDc = recurrenceHookC(LSTMCell.second);
// Form the recurrence loop by replacing the dh and dc placeholders with the actualDh and actualDc
LSTMCell.first->ReplacePlaceholders({ { inputPlaceholder , input}, { dh, actualDh },{ dc, actualDc } });
LSTMCell.first->ReplacePlaceholders({{inputPlaceholder, input}, {dh, actualDh}, {dc, actualDc}});
return std::make_tuple(LSTMCell.first, LSTMCell.second);
}
FunctionPtr GRUComponent(Variable input,
const NDShape& cellShape,
const std::function<FunctionPtr(const Variable&)> &fActivationOp,
const std::function<FunctionPtr(const Variable&)> &gActivationOp,
const std::function<FunctionPtr(const Variable&)>& recurrenceHookH,
const NDShape &cellShape,
const std::function<FunctionPtr(const Variable &)> &fActivationOp,
const std::function<FunctionPtr(const Variable &)> &gActivationOp,
const std::function<FunctionPtr(const Variable &)> &recurrenceHookH,
Constant &W, Constant &R, Constant &H1, Constant &B)
{
auto dh = PlaceholderVariable(cellShape, input.DynamicAxes());
@ -374,16 +372,16 @@ FunctionPtr GRUComponent(Variable input,
auto actualDh = recurrenceHookH(gruCell);
gruCell->ReplacePlaceholders({ { dh, actualDh } });
gruCell->ReplacePlaceholders({{dh, actualDh}});
auto gruBlock = AsBlock(std::move(gruCell), { { inputPlaceholder , input } }, L"GRU", L"");
auto gruBlock = AsBlock(std::move(gruCell), {{inputPlaceholder, input}}, L"GRU", L"");
return gruBlock;
}
FunctionPtr RNNComponent(Variable input,
const NDShape& cellShape,
const std::function<FunctionPtr(const Variable&)> &activationOp,
const std::function<FunctionPtr(const Variable&)>& recurrenceHookH,
const NDShape &cellShape,
const std::function<FunctionPtr(const Variable &)> &activationOp,
const std::function<FunctionPtr(const Variable &)> &recurrenceHookH,
Constant &W, Constant &R, Constant &B)
{
auto dh = PlaceholderVariable(cellShape, input.DynamicAxes());
@ -396,7 +394,7 @@ FunctionPtr RNNComponent(Variable input,
auto actualDh = recurrenceHookH(rnnCell);
rnnCell->ReplacePlaceholders({ { inputPlaceholder , input },{ dh, actualDh } });
rnnCell->ReplacePlaceholders({{inputPlaceholder, input}, {dh, actualDh}});
return rnnCell;
}
@ -437,9 +435,8 @@ FunctionPtr CreateLSTM(const ONNXIR::Node *node, const std::vector<Variable> &in
std::vector<FunctionPtr> outputHs;
for (int dir = 0; dir < numDirections; dir++)
{
std::function<FunctionPtr(const Variable&)> iofActivationOp, cellActivationOp, hiddenActivationOp;
std::tie<std::function<FunctionPtr(const Variable&)>, std::function<FunctionPtr(const Variable&)>, std::function<FunctionPtr(const Variable&)>>
(iofActivationOp, cellActivationOp, hiddenActivationOp) = GetActivations(activations, activation_alpha, activation_beta, dir);
std::function<FunctionPtr(const Variable &)> iofActivationOp, cellActivationOp, hiddenActivationOp;
std::tie<std::function<FunctionPtr(const Variable &)>, std::function<FunctionPtr(const Variable &)>, std::function<FunctionPtr(const Variable &)>>(iofActivationOp, cellActivationOp, hiddenActivationOp) = GetActivations(activations, activation_alpha, activation_beta, dir);
// the first a few inputs are (in order): X, numDirections * W, numDirections * R
Variable X = inputs[0];
@ -460,7 +457,7 @@ FunctionPtr CreateLSTM(const ONNXIR::Node *node, const std::vector<Variable> &in
if (peepholeVariables.size() != 0 && peepholeVariables.size() != LSTMPeepholeCount && peepholeVariables.size() != 2 * LSTMPeepholeCount)
{
CNTK::LogicError("Peephole Variable count (%d) should be 0, 1 or 2 times the number of peephole factors (%d).",
(int)(peepholeVariables.size()), (int)LSTMPeepholeCount);
(int) (peepholeVariables.size()), (int) LSTMPeepholeCount);
}
else if (numDirections == 1 && peepholeVariables.size() >= LSTMPeepholeCount)
{
@ -488,29 +485,29 @@ FunctionPtr CreateLSTM(const ONNXIR::Node *node, const std::vector<Variable> &in
// if it is bidirectional LSTM, the second one will be the backword one.
bool go_backwards = direction == RNNDirectionReverse || (numDirections == 2 && dir == 1);
std::function<FunctionPtr(const Variable&)> recurrenceHookH, recurrenceHookC;
std::function<FunctionPtr(const Variable &)> recurrenceHookH, recurrenceHookC;
if (go_backwards)
{
recurrenceHookH = [initHVariable](const Variable& x) { return FutureValue(x, initHVariable); };
recurrenceHookC = [initCVariable](const Variable& x) { return FutureValue(x, initCVariable); };
recurrenceHookH = [initHVariable](const Variable &x) { return FutureValue(x, initHVariable); };
recurrenceHookC = [initCVariable](const Variable &x) { return FutureValue(x, initCVariable); };
}
else
{
recurrenceHookH = [initHVariable](const Variable& x) { return PastValue(x, initHVariable); };
recurrenceHookC = [initCVariable](const Variable& x) { return PastValue(x, initCVariable); };
recurrenceHookH = [initHVariable](const Variable &x) { return PastValue(x, initHVariable); };
recurrenceHookC = [initCVariable](const Variable &x) { return PastValue(x, initCVariable); };
}
std::tie<FunctionPtr, FunctionPtr>(outputH, outputC) = LSTMPComponent(
X, { (size_t)hiddenDim }, iofActivationOp, cellActivationOp, hiddenActivationOp,
recurrenceHookH, recurrenceHookC, (Constant &)W, (Constant &)R, (Constant &)B,
(Constant &)Ci, (Constant &)Cf, (Constant &)Co);
X, {(size_t) hiddenDim}, iofActivationOp, cellActivationOp, hiddenActivationOp,
recurrenceHookH, recurrenceHookC, (Constant &) W, (Constant &) R, (Constant &) B,
(Constant &) Ci, (Constant &) Cf, (Constant &) Co);
outputHs.push_back(outputH);
}
if (outputHs.size() == 1)
return outputHs[0];
else
{
std::vector<Variable> operands({ outputHs[0], outputHs[1] });
std::vector<Variable> operands({outputHs[0], outputHs[1]});
return Splice(operands, Axis(0), ToWString(node->Name()));
}
}
@ -522,9 +519,8 @@ FunctionPtr CreateGRU(const ONNXIR::Node *node, const std::vector<Variable> &inp
std::vector<FunctionPtr> outputHs;
for (int dir = 0; dir < numDirections; dir++)
{
std::function<FunctionPtr(const Variable&)> fActivationOp, gActivationOp;
std::tie<std::function<FunctionPtr(const Variable&)>, std::function<FunctionPtr(const Variable&)>>
(fActivationOp, gActivationOp) = GetGRUActivations(activations, activation_alpha, activation_beta, dir);
std::function<FunctionPtr(const Variable &)> fActivationOp, gActivationOp;
std::tie<std::function<FunctionPtr(const Variable &)>, std::function<FunctionPtr(const Variable &)>>(fActivationOp, gActivationOp) = GetGRUActivations(activations, activation_alpha, activation_beta, dir);
// the first a few inputs are (in order): X, numDirections * W, numDirections * R, numDirections * H1
Variable X = inputs[0];
@ -537,9 +533,9 @@ FunctionPtr CreateGRU(const ONNXIR::Node *node, const std::vector<Variable> &inp
Variable B;
std::vector<Variable> biasVariables = FindByNameHint(inputs, LSTMInputBiasNameHint);
if (numDirections == 1 && biasVariables.size() >= 1)
B = biasVariables[0];
B = biasVariables[dir];
else if (numDirections == 2 && biasVariables.size() == 2)
B = biasVariables[1];
B = biasVariables[dir];
Variable initHVariable = GetInitialStateVariable(inputs, numDirections, GRUInputInitialHNameHint, X.GetDataType());
@ -555,22 +551,22 @@ FunctionPtr CreateGRU(const ONNXIR::Node *node, const std::vector<Variable> &inp
// if it is bidirectional LSTM, the second one will be the backword one.
bool go_backwards = direction == RNNDirectionReverse || (numDirections == 2 && dir == 1);
std::function<FunctionPtr(const Variable&)> recurrenceHook;
std::function<FunctionPtr(const Variable &)> recurrenceHook;
if (go_backwards)
recurrenceHook = [initHVariable](const Variable& x) { return FutureValue(x, initHVariable); };
recurrenceHook = [initHVariable](const Variable &x) { return FutureValue(x, initHVariable); };
else
recurrenceHook = [initHVariable](const Variable& x) { return PastValue(x, initHVariable); };
recurrenceHook = [initHVariable](const Variable &x) { return PastValue(x, initHVariable); };
outputH = GRUComponent(
X, { (size_t)hiddenDim }, fActivationOp, gActivationOp,
recurrenceHook, (Constant &)W, (Constant &)R, (Constant &)H1, (Constant &)B);
X, {(size_t) hiddenDim}, fActivationOp, gActivationOp,
recurrenceHook, (Constant &) W, (Constant &) R, (Constant &) H1, (Constant &) B);
outputHs.push_back(outputH);
}
if (outputHs.size() == 1)
return outputHs[0];
else
{
std::vector<Variable> operands({ outputHs[0], outputHs[1] });
std::vector<Variable> operands({outputHs[0], outputHs[1]});
return Splice(operands, Axis(0), ToWString(node->Name()));
}
}
@ -582,7 +578,7 @@ FunctionPtr CreateRNN(const ONNXIR::Node *node, const std::vector<Variable> &inp
std::vector<FunctionPtr> outputHs;
for (int dir = 0; dir < numDirections; dir++)
{
std::function<FunctionPtr(const Variable&)> activationOp =
std::function<FunctionPtr(const Variable &)> activationOp =
GetRNNActivations(activations, activation_alpha, activation_beta, dir);
// the first a few inputs are (in order): X, numDirections * W, numDirections * R, numDirections * H1
@ -592,9 +588,9 @@ FunctionPtr CreateRNN(const ONNXIR::Node *node, const std::vector<Variable> &inp
Variable B;
std::vector<Variable> biasVariables = FindByNameHint(inputs, LSTMInputBiasNameHint);
if (numDirections == 1 && biasVariables.size() >= 1)
B = biasVariables[0];
B = biasVariables[dir];
else if (numDirections == 2 && biasVariables.size() == 2)
B = biasVariables[1];
B = biasVariables[dir];
Variable initHVariable = GetInitialStateVariable(inputs, numDirections, GRUInputInitialHNameHint, X.GetDataType());
@ -605,39 +601,39 @@ FunctionPtr CreateRNN(const ONNXIR::Node *node, const std::vector<Variable> &inp
// if it is bidirectional LSTM, the second one will be the backword one.
bool go_backwards = direction == RNNDirectionReverse || (numDirections == 2 && dir == 1);
std::function<FunctionPtr(const Variable&)> recurrenceHook;
std::function<FunctionPtr(const Variable &)> recurrenceHook;
if (go_backwards)
recurrenceHook = [initHVariable](const Variable& x) { return FutureValue(x, initHVariable); };
recurrenceHook = [initHVariable](const Variable &x) { return FutureValue(x, initHVariable); };
else
recurrenceHook = [initHVariable](const Variable& x) { return PastValue(x, initHVariable); };
recurrenceHook = [initHVariable](const Variable &x) { return PastValue(x, initHVariable); };
outputH = RNNComponent(
X, { (size_t)hiddenDim }, activationOp,
recurrenceHook, (Constant &)W, (Constant &)R, (Constant &)B);
X, {(size_t) hiddenDim}, activationOp,
recurrenceHook, (Constant &) W, (Constant &) R, (Constant &) B);
outputHs.push_back(outputH);
}
if (outputHs.size() == 1)
return outputHs[0];
else
{
std::vector<Variable> operands({ outputHs[0], outputHs[1] });
std::vector<Variable> operands({outputHs[0], outputHs[1]});
return Splice(operands, Axis(0), ToWString(node->Name()));
}
}
template <typename FunctionType>
void TraverseGraphWithPrePostActions(FunctionPtr cntkFunction, std::unordered_set<FunctionPtr>& visitedFunctions,
void TraverseGraphWithPrePostActions(FunctionPtr cntkFunction, std::unordered_set<FunctionPtr> &visitedFunctions,
FunctionType preFunctor, FunctionType postFunctor)
{
visitedFunctions.insert(cntkFunction);
preFunctor(cntkFunction);
std::vector<Variable> functionInputs = cntkFunction->Inputs();
for (const auto& input : functionInputs)
for (const auto &input : functionInputs)
{
if (input.IsOutput() && visitedFunctions.find(input.Owner()) == visitedFunctions.end())
{
const auto& inputFunction = input.Owner();
const auto &inputFunction = input.Owner();
TraverseGraphWithPrePostActions(inputFunction, visitedFunctions, preFunctor, postFunctor);
}
}
@ -648,11 +644,9 @@ void TraverseGraphWithPrePostActions(FunctionPtr cntkFunction, std::unordered_se
bool IsSupportedRNNActivation(const std::wstring &cntkOpName)
{
static std::vector<std::wstring> supportedRNNActivations(
{
L"ReLU",
{L"ReLU",
L"Tanh",
L"StableSigmoid"
});
L"StableSigmoid"});
return std::find(supportedRNNActivations.cbegin(), supportedRNNActivations.cend(), cntkOpName) !=
supportedRNNActivations.cend();
}
@ -758,9 +752,9 @@ double GetScaler(Variable variable)
switch (variable.GetDataType())
{
case DataType::Float:
return *((float *)cpuV->DataBuffer<float>());
return *((float *) cpuV->DataBuffer<float>());
case DataType::Double:
return *((double *)cpuV->DataBuffer<double>());
return *((double *) cpuV->DataBuffer<double>());
default:
NOT_IMPLEMENTED;
}
@ -789,7 +783,7 @@ void GetDelayOps(const std::vector<Variable> &inputVars,
// Slice order tells which graph path is for which gate. This method is
// to traverse the graph to find the 4 paths along the 4 gates. It helps to
// subsequently find needed attributes in order to build an ONNX LSTM op.
void TraceLSTMPathes(const FunctionPtr& src,
void TraceLSTMPathes(const FunctionPtr &src,
string &f_activation,
string &g_activation,
string &h_activation,
@ -846,9 +840,9 @@ void TraceLSTMPathes(const FunctionPtr& src,
// traverse to find the joint of bit and bft
TraverseGraphWithPrePostActions(src->BlockRoot(),
peepHoleVisitedFunctions,
(std::function<void(const FunctionPtr&)>)[
&peepHoleVisitedFunctions, &pathesBitBftJoint, &currentPeepholePath](const FunctionPtr& function)
{
(std::function<void(const FunctionPtr &)>) [
&peepHoleVisitedFunctions, &pathesBitBftJoint, &currentPeepholePath
](const FunctionPtr &function) {
currentPeepholePath.push_back(function);
if (function->OpName() == L"Plus" &&
function->Inputs()[0].Owner() && function->Inputs()[0].Owner()->OpName() == L"ElementTimes" &&
@ -856,11 +850,10 @@ void TraceLSTMPathes(const FunctionPtr& src,
{
pathesBitBftJoint.push_back(currentPeepholePath);
peepHoleVisitedFunctions.erase(std::find_if(peepHoleVisitedFunctions.begin(), peepHoleVisitedFunctions.end(),
[function](FunctionPtr f) {return function == f; }));
[function](FunctionPtr f) { return function == f; }));
}
},
(std::function<void(const FunctionPtr&)>)[&currentPeepholePath](const FunctionPtr& function)
{
(std::function<void(const FunctionPtr &)>) [&currentPeepholePath](const FunctionPtr &function) {
currentPeepholePath.pop_back();
});
}
@ -871,9 +864,9 @@ void TraceLSTMPathes(const FunctionPtr& src,
{
// the last ElementTimes op is the peephole op
std::vector<FunctionPtr> &peepholePath = *std::max_element(pathesBitBftJoint.begin(), pathesBitBftJoint.end(),
[](std::vector<FunctionPtr> &p1, std::vector<FunctionPtr> &p2) {return p1.size() < p2.size(); });
[](std::vector<FunctionPtr> &p1, std::vector<FunctionPtr> &p2) { return p1.size() < p2.size(); });
std::vector<FunctionPtr>::reverse_iterator itPeepholeOp = std::find_if(peepholePath.rbegin(), peepholePath.rend(),
[](FunctionPtr function) {return function->OpName() == L"ElementTimes"; });
[](FunctionPtr function) { return function->OpName() == L"ElementTimes"; });
if (itPeepholeOp == peepholePath.rend())
{
CNTK::LogicError("Cannot find peephole op from a LSTM graph");
@ -898,8 +891,7 @@ void TraceLSTMPathes(const FunctionPtr& src,
TraverseGraphWithPrePostActions(src->BlockRoot(),
visitedFunctions,
(std::function<void(const FunctionPtr&)>)[&pathesToPlusSlice, &currentPath](const FunctionPtr& function)
{
(std::function<void(const FunctionPtr &)>) [&pathesToPlusSlice, &currentPath ](const FunctionPtr &function) {
currentPath.push_back(function);
if (function->OpName() == L"Slice")
{
@ -910,8 +902,7 @@ void TraceLSTMPathes(const FunctionPtr& src,
}
}
},
(std::function<void(const FunctionPtr&)>)[&currentPath](const FunctionPtr& function)
{
(std::function<void(const FunctionPtr &)>) [&currentPath](const FunctionPtr &function) {
currentPath.pop_back();
});
@ -922,8 +913,7 @@ void TraceLSTMPathes(const FunctionPtr& src,
}
std::sort(pathesToPlusSlice.begin(), pathesToPlusSlice.end(),
[](const std::vector<FunctionPtr>& path1, const std::vector<FunctionPtr>& path2)
{
[](const std::vector<FunctionPtr> &path1, const std::vector<FunctionPtr> &path2) {
FunctionPtr slice1 = *path1.rbegin();
FunctionPtr slice2 = *path2.rbegin();
int beginIndex1 = slice1->Attributes()[PrimitiveFunction::AttributeNameBeginIndex].Value<int>();
@ -956,16 +946,13 @@ void TraceLSTMPathes(const FunctionPtr& src,
{
// Ci merges to ht_it_path via element-wise time
FunctionPtr plusOp = ht_it_path[ht_it_path.size() - 2];
FunctionPtr peepholeOp = plusOp->Inputs()[0].Owner()->OpName() != L"Slice" ?
plusOp->Inputs()[0].Owner() : plusOp->Inputs()[1].Owner();
FunctionPtr peepholeOp = plusOp->Inputs()[0].Owner()->OpName() != L"Slice" ? plusOp->Inputs()[0].Owner() : plusOp->Inputs()[1].Owner();
peepholeCi = GetPeepholeVariableFromOp(peepholeOp);
}
{
// Cf merges to ht_ft_path via element-wise time
FunctionPtr plusOp = ht_ft_path[ht_ft_path.size() - 2];
FunctionPtr peepholeOp = plusOp->Inputs()[0].Owner()->OpName() != L"Slice" ?
plusOp->Inputs()[0].Owner() : plusOp->Inputs()[1].Owner();
FunctionPtr peepholeOp = plusOp->Inputs()[0].Owner()->OpName() != L"Slice" ? plusOp->Inputs()[0].Owner() : plusOp->Inputs()[1].Owner();
peepholeCf = GetPeepholeVariableFromOp(peepholeOp);
FunctionPtr stabilizerDcOp = GetStabilizerOp(peepholeOp);
@ -998,7 +985,7 @@ FunctionPtr TraverseGraphFindFirstRNNOp(FunctionPtr src)
return nullptr;
}
void TraceGRUPathes(const FunctionPtr& src, string &f_activation, string &g_activation,
void TraceGRUPathes(const FunctionPtr &src, string &f_activation, string &g_activation,
RNNDirection &direction, Variable &initStateH)
{
std::vector<Variable> inputVars = src->Inputs();
@ -1037,7 +1024,7 @@ void TraceGRUPathes(const FunctionPtr& src, string &f_activation, string &g_acti
g_activation = MapActivationNameCNTKToONNX(ToString(gActivation->OpName()));
}
void TraceRNNPathes(const FunctionPtr& src, string &activation,
void TraceRNNPathes(const FunctionPtr &src, string &activation,
RNNDirection &direction, Variable &initStateH)
{
std::vector<Variable> inputVars = src->Inputs();
@ -1086,7 +1073,7 @@ std::vector<FunctionPtr> GetRNNBlocksFromSingleOrBidirectionalRNN(const Function
// For single direction RNN, rnns.size() == 1. For bidirectional RNN, rnns.size() == 2.
// It is an error otherwise.
if (rnns.size() == 0 || rnns.size() > 2 ||
std::any_of(rnns.cbegin(), rnns.cend(), [RNNStepOpName](const FunctionPtr &f) {return ToString(f->OpName()) != RNNStepOpName; }))
std::any_of(rnns.cbegin(), rnns.cend(), [RNNStepOpName](const FunctionPtr &f) { return ToString(f->OpName()) != RNNStepOpName; }))
{
CNTK::LogicError("Invalid number of RNN ops to construct an ONNX %s node.", RNNStepOpName.c_str());
}
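
Aside from reformatting, the substantive change in this file is that CreateGRU and CreateRNN now pick the bias variable per direction (biasVariables[dir]) instead of using a fixed index, so a bidirectional import no longer applies the same bias to both directions. A small sketch of that selection follows, with an illustrative helper name and a generic variable type.

#include <optional>
#include <vector>

// Illustrative per-direction bias lookup matching the fix above:
// dir is 0 for the forward pass and 1 for the backward pass.
template <typename VariableT>
std::optional<VariableT> BiasForDirection(const std::vector<VariableT>& biasVariables,
                                          size_t numDirections, size_t dir)
{
    if (numDirections == 1 && biasVariables.size() >= 1)
        return biasVariables[dir];   // single direction: dir is always 0
    if (numDirections == 2 && biasVariables.size() == 2)
        return biasVariables[dir];   // each direction gets its own bias
    return std::nullopt;             // no bias provided for this direction
}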

View file

@ -716,10 +716,12 @@ def test_Neg(tmpdir):
verify_no_input(model, tmpdir, 'Neg_0')
#OptimizedRNNStack
OPTIM_RNN_STACK_CONFIGS = ((True, 2, 2, 3), (True, 2, 4, 8), (True, 2, 6, 8),
(True, 4, 2, 3), (False, 2, 2, 3))
@pytest.mark.parametrize("bidirectional, num_layers, input_size, hidden_size", OPTIM_RNN_STACK_CONFIGS)
def test_OptimizedRNNStack(bidirectional, num_layers, input_size, hidden_size, tmpdir, device_id):
OPTIM_RNN_STACK_CONFIGS = ((True, 2, 2, 3, 'lstm'), (True, 2, 4, 8, 'lstm'), (True, 2, 6, 8, 'lstm'),
(True, 4, 2, 3, 'lstm'), (False, 2, 2, 3, 'lstm'),
(True, 1, 2, 3, 'rnnReLU'), (True, 4, 4, 8, 'rnnReLU'), (False, 2, 6, 8, 'rnnReLU'),
(True, 4, 2, 3, 'rnnTanh'), (False, 2, 2, 3, 'rnnTanh'), (True, 1, 2, 3, 'rnnTanh'))
@pytest.mark.parametrize("bidirectional, num_layers, input_size, hidden_size, recurrent_op", OPTIM_RNN_STACK_CONFIGS)
def test_OptimizedRNNStack(bidirectional, num_layers, input_size, hidden_size, recurrent_op, tmpdir, device_id):
if device_id == -1:
pytest.skip('Test only runs on GPU')
dev = cntk_device(device_id)
@ -728,7 +730,7 @@ def test_OptimizedRNNStack(bidirectional, num_layers, input_size, hidden_size, t
W = C.parameter((C.InferredDimension, input_size), constant_initializer(0.1), device=dev)
x = C.sequence.input_variable(shape=(input_size,))
s = np.asarray(np.random.uniform(-1, 1, (5,input_size)), dtype=np.float32)
f = C.optimized_rnnstack(x, W, hidden_size, num_layers, bidirectional=bidirectional, name='MyRnnStack')
f = C.optimized_rnnstack(x, W, hidden_size, num_layers, bidirectional=bidirectional, recurrent_op=recurrent_op, name='MyRnnStack')
f.parameters[0].value = np.reshape(np.arange(np.prod(f.parameters[0].value.shape), dtype=np.float32), f.parameters[0].value.shape)
verify_one_input(f, s, tmpdir, model_filename)