Integrate sptiwari/add_rnn_to_ornn2 into master
This commit is contained in:
Commit 5b04f46aa4
@@ -3429,6 +3429,9 @@ ONNXIR::Node* CNTKToONNXHelper::CreateONNXNodesForOptimizedRNNStack(const Functi
auto bidirectional = (bool)(src->Attributes()[L"bidirectional"].Value<bool>());
auto recurrentOp = (wstring)src->Attributes()[L"recurrentOp"].Value<wstring>();

if (!Operators::IsOptimizedRnnStackOp(recurrentOp))
InvalidArgument("Recurrent op used for OptimizedRNNStack is not supported for ONNX export.");

size_t numDirections = bidirectional ? 2 : 1;
size_t inputSize = src->Inputs()[0].Shape()[0];
auto Wcombined = src->Inputs()[1];

@@ -3550,10 +3553,17 @@ ONNXIR::Node* CNTKToONNXHelper::CreateONNXNodesForOptimizedRNNStack(const Functi
outputs.push_back(outputArg_Yc);

// ==== Step 6. Add ONNX LSTM node ====
auto rnnOpNameLookup = Operators::OptimizedRnnToOnnxOpLookup();
auto rnnNodeName = (src->Name().empty() ? ToString(src->Uid()) : ToString(src->Name())) + std::to_string(i);
functionNode = graph->AddNode(rnnNodeName, "LSTM", "", inputs, outputs);
functionNode = graph->AddNode(rnnNodeName, rnnOpNameLookup[recurrentOp], "", inputs, outputs);

std::vector<std::string> singleDirectionActivation({ "Sigmoid", "Tanh", "Tanh" }); // REVIEW: Check this is the order.
std::vector<std::string> singleDirectionActivation;
if (recurrentOp == L"lstm")
singleDirectionActivation = { "Sigmoid", "Tanh", "Tanh" };
else if (recurrentOp == L"rnnReLU")
singleDirectionActivation = { "Relu" };
else if (recurrentOp == L"rnnTanh")
singleDirectionActivation = { "Tanh" };
std::vector<std::string> activations;
activations.insert(activations.end(), singleDirectionActivation.begin(), singleDirectionActivation.end());
if (bidirectional)

@@ -3575,7 +3585,6 @@ std::tuple<std::vector<NDArrayViewPtr>, std::vector<NDArrayViewPtr>, std::vector
CNTKToONNXHelper::SplitOptimzedRnnWtoIndivMats(Matrix<float>& WbigIn,
size_t numLayers, size_t inputSize, size_t hiddenSize, bool bidirectional, wstring recurrentOp)
{
std::vector<NDArrayViewPtr> onnxInputTensor(3);
size_t numDirections = bidirectional ? 2 : 1;
size_t numGates;
if (recurrentOp == L"lstm")

@@ -3716,7 +3725,6 @@ std::vector<NDArrayViewPtr> CNTKToONNXHelper::ToRnnWeightPerLayerOnnxFormat(std:
size_t offset = 0;
for (size_t j = 0; j < numDirections; ++j)
{
// Matrix<float> temp = W[i*numDirections + j].Transpose(); // Have to do this because Matrix::InplaceTranspose is not implemented.
Matrix<float> temp(W[i*numDirections + j].GetNumCols(), W[i*numDirections + j].GetNumRows(), W[i*numDirections + j].Data(), CPUDEVICE);
currLayerWeightMatrix.SetColumnSlice(temp, offset, layerInputSize);
offset += layerInputSize;
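Note: the hunks above derive both the exported ONNX node type and its activations attribute from the CNTK recurrentOp attribute. Below is a minimal standalone sketch of that activation-selection logic, assuming the op names shown in the diff; MakeOnnxActivations is a hypothetical helper used only for illustration, not the exporter code itself.

// Standalone sketch (assumption: op names "lstm"/"rnnReLU"/"rnnTanh" as in the diff).
#include <iostream>
#include <string>
#include <vector>

// Per-direction ONNX activation set for a CNTK optimizedRNNStack recurrentOp,
// repeated once more when the stack is bidirectional.
std::vector<std::string> MakeOnnxActivations(const std::wstring& recurrentOp, bool bidirectional)
{
    std::vector<std::string> singleDirectionActivation;
    if (recurrentOp == L"lstm")
        singleDirectionActivation = { "Sigmoid", "Tanh", "Tanh" };
    else if (recurrentOp == L"rnnReLU")
        singleDirectionActivation = { "Relu" };
    else if (recurrentOp == L"rnnTanh")
        singleDirectionActivation = { "Tanh" };

    std::vector<std::string> activations(singleDirectionActivation);
    if (bidirectional)
        activations.insert(activations.end(), singleDirectionActivation.begin(), singleDirectionActivation.end());
    return activations;
}

int main()
{
    // A bidirectional rnnTanh stack ends up with two Tanh entries.
    for (const auto& act : MakeOnnxActivations(L"rnnTanh", /*bidirectional=*/true))
        std::cout << act << " "; // prints: Tanh Tanh
    std::cout << std::endl;
}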
File diff suppressed because it is too large
@@ -490,5 +490,17 @@ namespace ONNX
{ L"Dropout" },
};

std::set<std::wstring> Operators::_optimizedRnnStackOpNames = {
{ L"lstm" },
{ L"rnnReLU" },
{ L"rnnTanh" },
};

std::unordered_map<std::wstring, std::string> Operators::_optimizedRnnOpNameToOnnxOpName = {
{ L"lstm", "LSTM" },
{ L"rnnReLU", "RNN" },
{ L"rnnTanh","RNN" },
};

}
}
@@ -53,6 +53,23 @@ namespace CNTK
return _cntkToONNXOpName;
}

//
// Method to check if a name is a valid optimizedRnnStack op name.
//
static inline bool IsOptimizedRnnStackOp(const std::wstring& opName)
{
return _optimizedRnnStackOpNames.find(opName) != _optimizedRnnStackOpNames.end();
}

//
// Return a lookup table that maps CNTK's optimizedRNNStack op to one of the
// ONNX RNN ops.
//
static inline const std::unordered_map<std::wstring, std::string>& OptimizedRnnToOnnxOpLookup()
{
return _optimizedRnnOpNameToOnnxOpName;
}

static std::tuple<int, int> GetElementWiseInputIndices(const std::wstring& opName);

//

@@ -115,6 +132,8 @@ namespace CNTK
static std::unordered_multimap<std::wstring, AttributesMapping> _cntkToONNXOpName;
static std::unordered_map<std::wstring, std::set<size_t>> _cntkBlockOPInvalidIndices;
static std::unordered_map<std::wstring, std::vector<int>> _cntkToONNXInputIndices;
static std::set<std::wstring>_optimizedRnnStackOpNames;
static std::unordered_map<std::wstring, std::string> _optimizedRnnOpNameToOnnxOpName;
static std::set<std::wstring> _cntkLayerOPName;
};
@@ -41,8 +41,7 @@ std::string MapActivationNameCNTKToONNX(const std::string &cntkOp)

bool IsActivationOp(const std::string &activationName)
{
return
activationName == "Relu" || activationName == "ReLU" ||
return activationName == "Relu" || activationName == "ReLU" ||
activationName == "Tanh" ||
activationName == "Sigmoid" || activationName == "StableSigmoid" ||
activationName == "Affine" ||

@@ -55,19 +54,19 @@ bool IsActivationOp(const std::string &activationName)
activationName == "Softplus";
}

std::function<FunctionPtr(const Variable&)> ActivationMap(const std::string &activationName)
std::function<FunctionPtr(const Variable &)> ActivationMap(const std::string &activationName)
{
if (activationName == "Relu")
{
return [](const Variable& x) { return ReLU(x); };
return [](const Variable &x) { return ReLU(x); };
}
else if (activationName == "Tanh")
{
return [](const Variable& x) { return Tanh(x); };
return [](const Variable &x) { return Tanh(x); };
}
else if (activationName == "Sigmoid")
{
return [](const Variable& x) { return Sigmoid(x); };
return [](const Variable &x) { return Sigmoid(x); };
}
// else if (activationName == "Affine")
// else if (activationName == "LeakyRelu")

@@ -76,15 +75,15 @@ std::function<FunctionPtr(const Variable&)> ActivationMap(const std::string &act
// else if (activationName == "HardSigmoid")
else if (activationName == "Elu")
{
return [](const Variable& x) { return ELU(x); };
return [](const Variable &x) { return ELU(x); };
}
else if (activationName == "Softsign")
{
return [](const Variable& x) { return Softsign(x); };
return [](const Variable &x) { return Softsign(x); };
}
else if (activationName == "Softplus")
{
return [](const Variable& x) { return Softplus(x); };
return [](const Variable &x) { return Softplus(x); };
}
else
{
@@ -92,12 +91,12 @@ std::function<FunctionPtr(const Variable&)> ActivationMap(const std::string &act
}
}

std::function<FunctionPtr(const Variable&)> ActivationMap(const std::string &activationName,
std::function<FunctionPtr(const Variable &)> ActivationMap(const std::string &activationName,
float activation_alpha)
{
if (activationName == "LeakyRelu")
{
return [activation_alpha](const Variable& x) { return LeakyReLU(x, activation_alpha); };
return [activation_alpha](const Variable &x) { return LeakyReLU(x, activation_alpha); };
}
else
{

@@ -105,12 +104,12 @@ std::function<FunctionPtr(const Variable&)> ActivationMap(const std::string &act
}
}

std::function<FunctionPtr(const Variable&)> ActivationMap(const std::string &activationName,
std::function<FunctionPtr(const Variable &)> ActivationMap(const std::string &activationName,
float activation_alpha, float activation_beta)
{
if (activationName == "HardSigmoid")
{
return [activation_alpha, activation_beta](const Variable& x) { return HardSigmoid(x, activation_alpha, activation_beta); };
return [activation_alpha, activation_beta](const Variable &x) { return HardSigmoid(x, activation_alpha, activation_beta); };
}
else
{

@@ -118,7 +117,7 @@ std::function<FunctionPtr(const Variable&)> ActivationMap(const std::string &act
}
}

std::tuple<std::function<FunctionPtr(const Variable&)>, std::function<FunctionPtr(const Variable&)>, std::function<FunctionPtr(const Variable&)>>
std::tuple<std::function<FunctionPtr(const Variable &)>, std::function<FunctionPtr(const Variable &)>, std::function<FunctionPtr(const Variable &)>>
GetActivations(const std::vector<std::string> &activations, const std::vector<float> &activation_alpha, const std::vector<float> &activation_beta, int direction)
{
if (activations.size() < (direction + 1) * LSTMActivationCount)
@@ -134,7 +133,7 @@ GetActivations(const std::vector<std::string> &activations, const std::vector<fl
// an activation needs those values or not.
bool hasAlpha = activation_alpha.size() == (direction + 1) * LSTMActivationCount;
bool hasAlphaBeta = hasAlpha && activation_beta.size() == (direction + 1) * LSTMActivationCount;
std::function<FunctionPtr(const Variable&)> iofActivationOp, cellActivationOp, hiddenActivationOp;
std::function<FunctionPtr(const Variable &)> iofActivationOp, cellActivationOp, hiddenActivationOp;
if (hasAlphaBeta)
{
iofActivationOp = ActivationMap(activations[iofActivationIndex], activation_alpha[iofActivationIndex], activation_beta[iofActivationIndex]);

@@ -155,10 +154,9 @@ GetActivations(const std::vector<std::string> &activations, const std::vector<fl
}

return std::make_tuple(iofActivationOp, cellActivationOp, hiddenActivationOp);

}

std::tuple<std::function<FunctionPtr(const Variable&)>, std::function<FunctionPtr(const Variable&)>>
std::tuple<std::function<FunctionPtr(const Variable &)>, std::function<FunctionPtr(const Variable &)>>
GetGRUActivations(const std::vector<std::string> &activations, const std::vector<float> &activation_alpha, const std::vector<float> &activation_beta, int direction)
{
if (activations.size() < (direction + 1) * GRUActivationCount)

@@ -170,7 +168,7 @@ GetGRUActivations(const std::vector<std::string> &activations, const std::vector

bool hasAlpha = activation_alpha.size() == (direction + 1) * GRUActivationCount;
bool hasAlphaBeta = hasAlpha && activation_beta.size() == (direction + 1) * GRUActivationCount;
std::function<FunctionPtr(const Variable&)> fActivationOp, gActivationOp;
std::function<FunctionPtr(const Variable &)> fActivationOp, gActivationOp;
if (hasAlphaBeta)
{
fActivationOp = ActivationMap(activations[fActivationIndex], activation_alpha[fActivationIndex], activation_beta[fActivationIndex]);

@@ -190,7 +188,7 @@ GetGRUActivations(const std::vector<std::string> &activations, const std::vector
return std::make_tuple(fActivationOp, gActivationOp);
}

std::function<FunctionPtr(const Variable&)>
std::function<FunctionPtr(const Variable &)>
GetRNNActivations(const std::vector<std::string> &activations, const std::vector<float> &activation_alpha, const std::vector<float> &activation_beta, int direction)
{
if (activations.size() < (direction + 1))

@@ -201,7 +199,7 @@ GetRNNActivations(const std::vector<std::string> &activations, const std::vector

bool hasAlpha = activation_alpha.size() == (direction + 1);
bool hasAlphaBeta = hasAlpha && activation_beta.size() == (direction + 1);
std::function<FunctionPtr(const Variable&)> activationOp;
std::function<FunctionPtr(const Variable &)> activationOp;
if (hasAlphaBeta)
{
activationOp = ActivationMap(activations[activationIndex], activation_alpha[activationIndex], activation_beta[activationIndex]);
@@ -219,14 +217,14 @@ GetRNNActivations(const std::vector<std::string> &activations, const std::vector
}

std::pair<FunctionPtr, FunctionPtr> LSTMPCell(Variable input,
const std::function<FunctionPtr(const Variable&)> &iofActivationOp,
const std::function<FunctionPtr(const Variable&)> &cellActivationOp,
const std::function<FunctionPtr(const Variable&)> &hiddenActivationOp,
const std::function<FunctionPtr(const Variable &)> &iofActivationOp,
const std::function<FunctionPtr(const Variable &)> &cellActivationOp,
const std::function<FunctionPtr(const Variable &)> &hiddenActivationOp,
Variable prevOutput, Variable prevCellState,
Constant &W, Constant &R, Constant &B, Constant &Ci, Constant &Cf, Constant &Co)
{
size_t outputDim = prevOutput.Shape()[0];
int stacked_dim = (int)outputDim;
int stacked_dim = (int) outputDim;

FunctionPtr proj4;
if (B.IsInitialized())

@@ -239,12 +237,12 @@ std::pair<FunctionPtr, FunctionPtr> LSTMPCell(Variable input,
}

// CNTK weight and bias are in icfo order.
std::vector<Axis> stack_axis({ Axis(-1) });
std::vector<Axis> stack_axis({Axis(-1)});
const int IGateIndex = 0, CGateIndex = 1, FGateIndex = 2, OGateIndex = 3;
FunctionPtr it_proj = Slice(proj4, stack_axis, { IGateIndex * stacked_dim }, { (IGateIndex + 1) * stacked_dim });
FunctionPtr bit_proj = Slice(proj4, stack_axis, { CGateIndex * stacked_dim }, { (CGateIndex + 1) * stacked_dim });
FunctionPtr ft_proj = Slice(proj4, stack_axis, { FGateIndex * stacked_dim }, { (FGateIndex + 1) * stacked_dim });
FunctionPtr ot_proj = Slice(proj4, stack_axis, { OGateIndex * stacked_dim }, { (OGateIndex + 1) * stacked_dim });
FunctionPtr it_proj = Slice(proj4, stack_axis, {IGateIndex * stacked_dim}, {(IGateIndex + 1) * stacked_dim});
FunctionPtr bit_proj = Slice(proj4, stack_axis, {CGateIndex * stacked_dim}, {(CGateIndex + 1) * stacked_dim});
FunctionPtr ft_proj = Slice(proj4, stack_axis, {FGateIndex * stacked_dim}, {(FGateIndex + 1) * stacked_dim});
FunctionPtr ot_proj = Slice(proj4, stack_axis, {OGateIndex * stacked_dim}, {(OGateIndex + 1) * stacked_dim});

bool hasPeephole = Ci.IsInitialized();

@@ -263,17 +261,17 @@ std::pair<FunctionPtr, FunctionPtr> LSTMPCell(Variable input,
auto c = ct;
auto h = ht;

return{ h, c };
return {h, c};
}

FunctionPtr GRUCell(Variable input,
const std::function<FunctionPtr(const Variable&)> &fActivationOp,
const std::function<FunctionPtr(const Variable&)> &gActivationOp,
const std::function<FunctionPtr(const Variable &)> &fActivationOp,
const std::function<FunctionPtr(const Variable &)> &gActivationOp,
Variable prevOutput,
Constant &W, Constant &R, Constant &H1, Constant &B)
{
size_t outputDim = prevOutput.Shape()[0];
int stacked_dim = (int)outputDim;
int stacked_dim = (int) outputDim;

FunctionPtr projx3;
if (B.IsInitialized())
@@ -284,17 +282,17 @@ FunctionPtr GRUCell(Variable input,
FunctionPtr projh2 = Times(R, prevOutput);

// both CNTK and ONNX weight and bias are in zrh order.
std::vector<Axis> stack_axis({ Axis(-1) });
std::vector<Axis> stack_axis({Axis(-1)});
FunctionPtr zt_proj =
Slice(projx3, stack_axis, { 0 * stacked_dim }, { 1 * stacked_dim }) +
Slice(projh2, stack_axis, { 0 * stacked_dim }, { 1 * stacked_dim });
Slice(projx3, stack_axis, {0 * stacked_dim}, {1 * stacked_dim}) +
Slice(projh2, stack_axis, {0 * stacked_dim}, {1 * stacked_dim});

FunctionPtr rt_proj =
Slice(projx3, stack_axis, { 1 * stacked_dim }, { 2 * stacked_dim }) +
Slice(projh2, stack_axis, { 1 * stacked_dim }, { 2 * stacked_dim });
Slice(projx3, stack_axis, {1 * stacked_dim}, {2 * stacked_dim}) +
Slice(projh2, stack_axis, {1 * stacked_dim}, {2 * stacked_dim});

FunctionPtr ct_proj =
Slice(projx3, stack_axis, { 2 * stacked_dim }, { 3 * stacked_dim });
Slice(projx3, stack_axis, {2 * stacked_dim}, {3 * stacked_dim});

FunctionPtr zt = fActivationOp(zt_proj);

@@ -314,11 +312,12 @@ FunctionPtr GRUCell(Variable input,
}

FunctionPtr RNNCell(Variable input,
const std::function<FunctionPtr(const Variable&)> &activationOp,
const std::function<FunctionPtr(const Variable &)> &activationOp,
Variable prevOutput,
Constant &W, Constant &R, Constant &B)
{
FunctionPtr proj = Times(W, input) + Times(R, prevOutput);;
FunctionPtr proj = Times(W, input) + Times(R, prevOutput);
;
if (B.IsInitialized())
proj = B + proj;

@@ -326,17 +325,16 @@ FunctionPtr RNNCell(Variable input,
return h;
}

#include "PrimitiveFunction.h"
#include "BlockFunction.h"

std::tuple<FunctionPtr, FunctionPtr> LSTMPComponent(Variable input,
const NDShape& cellShape,
const std::function<FunctionPtr(const Variable&)> &iofActivationOp,
const std::function<FunctionPtr(const Variable&)> &cellActivationOp,
const std::function<FunctionPtr(const Variable&)> &hiddenActivationOp,
const std::function<FunctionPtr(const Variable&)>& recurrenceHookH,
const std::function<FunctionPtr(const Variable&)>& recurrenceHookC,
const NDShape &cellShape,
const std::function<FunctionPtr(const Variable &)> &iofActivationOp,
const std::function<FunctionPtr(const Variable &)> &cellActivationOp,
const std::function<FunctionPtr(const Variable &)> &hiddenActivationOp,
const std::function<FunctionPtr(const Variable &)> &recurrenceHookH,
const std::function<FunctionPtr(const Variable &)> &recurrenceHookC,
Constant &W, Constant &R, Constant &B,
Constant &Ci, Constant &Cf, Constant &Co)
{
@@ -353,15 +351,15 @@ std::tuple<FunctionPtr, FunctionPtr> LSTMPComponent(Variable input,
auto actualDc = recurrenceHookC(LSTMCell.second);

// Form the recurrence loop by replacing the dh and dc placeholders with the actualDh and actualDc
LSTMCell.first->ReplacePlaceholders({ { inputPlaceholder , input}, { dh, actualDh },{ dc, actualDc } });
LSTMCell.first->ReplacePlaceholders({{inputPlaceholder, input}, {dh, actualDh}, {dc, actualDc}});
return std::make_tuple(LSTMCell.first, LSTMCell.second);
}

FunctionPtr GRUComponent(Variable input,
const NDShape& cellShape,
const std::function<FunctionPtr(const Variable&)> &fActivationOp,
const std::function<FunctionPtr(const Variable&)> &gActivationOp,
const std::function<FunctionPtr(const Variable&)>& recurrenceHookH,
const NDShape &cellShape,
const std::function<FunctionPtr(const Variable &)> &fActivationOp,
const std::function<FunctionPtr(const Variable &)> &gActivationOp,
const std::function<FunctionPtr(const Variable &)> &recurrenceHookH,
Constant &W, Constant &R, Constant &H1, Constant &B)
{
auto dh = PlaceholderVariable(cellShape, input.DynamicAxes());

@@ -374,16 +372,16 @@ FunctionPtr GRUComponent(Variable input,

auto actualDh = recurrenceHookH(gruCell);

gruCell->ReplacePlaceholders({ { dh, actualDh } });
gruCell->ReplacePlaceholders({{dh, actualDh}});

auto gruBlock = AsBlock(std::move(gruCell), { { inputPlaceholder , input } }, L"GRU", L"");
auto gruBlock = AsBlock(std::move(gruCell), {{inputPlaceholder, input}}, L"GRU", L"");
return gruBlock;
}

FunctionPtr RNNComponent(Variable input,
const NDShape& cellShape,
const std::function<FunctionPtr(const Variable&)> &activationOp,
const std::function<FunctionPtr(const Variable&)>& recurrenceHookH,
const NDShape &cellShape,
const std::function<FunctionPtr(const Variable &)> &activationOp,
const std::function<FunctionPtr(const Variable &)> &recurrenceHookH,
Constant &W, Constant &R, Constant &B)
{
auto dh = PlaceholderVariable(cellShape, input.DynamicAxes());

@@ -396,7 +394,7 @@ FunctionPtr RNNComponent(Variable input,

auto actualDh = recurrenceHookH(rnnCell);

rnnCell->ReplacePlaceholders({ { inputPlaceholder , input },{ dh, actualDh } });
rnnCell->ReplacePlaceholders({{inputPlaceholder, input}, {dh, actualDh}});
return rnnCell;
}
@@ -437,9 +435,8 @@ FunctionPtr CreateLSTM(const ONNXIR::Node *node, const std::vector<Variable> &in
std::vector<FunctionPtr> outputHs;
for (int dir = 0; dir < numDirections; dir++)
{
std::function<FunctionPtr(const Variable&)> iofActivationOp, cellActivationOp, hiddenActivationOp;
std::tie<std::function<FunctionPtr(const Variable&)>, std::function<FunctionPtr(const Variable&)>, std::function<FunctionPtr(const Variable&)>>
(iofActivationOp, cellActivationOp, hiddenActivationOp) = GetActivations(activations, activation_alpha, activation_beta, dir);
std::function<FunctionPtr(const Variable &)> iofActivationOp, cellActivationOp, hiddenActivationOp;
std::tie<std::function<FunctionPtr(const Variable &)>, std::function<FunctionPtr(const Variable &)>, std::function<FunctionPtr(const Variable &)>>(iofActivationOp, cellActivationOp, hiddenActivationOp) = GetActivations(activations, activation_alpha, activation_beta, dir);

// the first a few inputs are (in order): X, numDirections * W, numDirections * R
Variable X = inputs[0];

@@ -460,7 +457,7 @@ FunctionPtr CreateLSTM(const ONNXIR::Node *node, const std::vector<Variable> &in
if (peepholeVariables.size() != 0 && peepholeVariables.size() != LSTMPeepholeCount && peepholeVariables.size() != 2 * LSTMPeepholeCount)
{
CNTK::LogicError("Peephole Variable count (%d) should be 0, 1 or 2 times the number of peephole factors (%d).",
(int)(peepholeVariables.size()), (int)LSTMPeepholeCount);
(int) (peepholeVariables.size()), (int) LSTMPeepholeCount);
}
else if (numDirections == 1 && peepholeVariables.size() >= LSTMPeepholeCount)
{

@@ -488,29 +485,29 @@ FunctionPtr CreateLSTM(const ONNXIR::Node *node, const std::vector<Variable> &in
// if it is bidirectional LSTM, the second one will be the backword one.
bool go_backwards = direction == RNNDirectionReverse || (numDirections == 2 && dir == 1);

std::function<FunctionPtr(const Variable&)> recurrenceHookH, recurrenceHookC;
std::function<FunctionPtr(const Variable &)> recurrenceHookH, recurrenceHookC;
if (go_backwards)
{
recurrenceHookH = [initHVariable](const Variable& x) { return FutureValue(x, initHVariable); };
recurrenceHookC = [initCVariable](const Variable& x) { return FutureValue(x, initCVariable); };
recurrenceHookH = [initHVariable](const Variable &x) { return FutureValue(x, initHVariable); };
recurrenceHookC = [initCVariable](const Variable &x) { return FutureValue(x, initCVariable); };
}
else
{
recurrenceHookH = [initHVariable](const Variable& x) { return PastValue(x, initHVariable); };
recurrenceHookC = [initCVariable](const Variable& x) { return PastValue(x, initCVariable); };
recurrenceHookH = [initHVariable](const Variable &x) { return PastValue(x, initHVariable); };
recurrenceHookC = [initCVariable](const Variable &x) { return PastValue(x, initCVariable); };
}

std::tie<FunctionPtr, FunctionPtr>(outputH, outputC) = LSTMPComponent(
X, { (size_t)hiddenDim }, iofActivationOp, cellActivationOp, hiddenActivationOp,
recurrenceHookH, recurrenceHookC, (Constant &)W, (Constant &)R, (Constant &)B,
(Constant &)Ci, (Constant &)Cf, (Constant &)Co);
X, {(size_t) hiddenDim}, iofActivationOp, cellActivationOp, hiddenActivationOp,
recurrenceHookH, recurrenceHookC, (Constant &) W, (Constant &) R, (Constant &) B,
(Constant &) Ci, (Constant &) Cf, (Constant &) Co);
outputHs.push_back(outputH);
}
if (outputHs.size() == 1)
return outputHs[0];
else
{
std::vector<Variable> operands({ outputHs[0], outputHs[1] });
std::vector<Variable> operands({outputHs[0], outputHs[1]});
return Splice(operands, Axis(0), ToWString(node->Name()));
}
}
@@ -522,9 +519,8 @@ FunctionPtr CreateGRU(const ONNXIR::Node *node, const std::vector<Variable> &inp
std::vector<FunctionPtr> outputHs;
for (int dir = 0; dir < numDirections; dir++)
{
std::function<FunctionPtr(const Variable&)> fActivationOp, gActivationOp;
std::tie<std::function<FunctionPtr(const Variable&)>, std::function<FunctionPtr(const Variable&)>>
(fActivationOp, gActivationOp) = GetGRUActivations(activations, activation_alpha, activation_beta, dir);
std::function<FunctionPtr(const Variable &)> fActivationOp, gActivationOp;
std::tie<std::function<FunctionPtr(const Variable &)>, std::function<FunctionPtr(const Variable &)>>(fActivationOp, gActivationOp) = GetGRUActivations(activations, activation_alpha, activation_beta, dir);

// the first a few inputs are (in order): X, numDirections * W, numDirections * R, numDirections * H1
Variable X = inputs[0];

@@ -537,9 +533,9 @@ FunctionPtr CreateGRU(const ONNXIR::Node *node, const std::vector<Variable> &inp
Variable B;
std::vector<Variable> biasVariables = FindByNameHint(inputs, LSTMInputBiasNameHint);
if (numDirections == 1 && biasVariables.size() >= 1)
B = biasVariables[0];
B = biasVariables[dir];
else if (numDirections == 2 && biasVariables.size() == 2)
B = biasVariables[1];
B = biasVariables[dir];

Variable initHVariable = GetInitialStateVariable(inputs, numDirections, GRUInputInitialHNameHint, X.GetDataType());

@@ -555,22 +551,22 @@ FunctionPtr CreateGRU(const ONNXIR::Node *node, const std::vector<Variable> &inp
// if it is bidirectional LSTM, the second one will be the backword one.
bool go_backwards = direction == RNNDirectionReverse || (numDirections == 2 && dir == 1);

std::function<FunctionPtr(const Variable&)> recurrenceHook;
std::function<FunctionPtr(const Variable &)> recurrenceHook;
if (go_backwards)
recurrenceHook = [initHVariable](const Variable& x) { return FutureValue(x, initHVariable); };
recurrenceHook = [initHVariable](const Variable &x) { return FutureValue(x, initHVariable); };
else
recurrenceHook = [initHVariable](const Variable& x) { return PastValue(x, initHVariable); };
recurrenceHook = [initHVariable](const Variable &x) { return PastValue(x, initHVariable); };

outputH = GRUComponent(
X, { (size_t)hiddenDim }, fActivationOp, gActivationOp,
recurrenceHook, (Constant &)W, (Constant &)R, (Constant &)H1, (Constant &)B);
X, {(size_t) hiddenDim}, fActivationOp, gActivationOp,
recurrenceHook, (Constant &) W, (Constant &) R, (Constant &) H1, (Constant &) B);
outputHs.push_back(outputH);
}
if (outputHs.size() == 1)
return outputHs[0];
else
{
std::vector<Variable> operands({ outputHs[0], outputHs[1] });
std::vector<Variable> operands({outputHs[0], outputHs[1]});
return Splice(operands, Axis(0), ToWString(node->Name()));
}
}
@@ -582,7 +578,7 @@ FunctionPtr CreateRNN(const ONNXIR::Node *node, const std::vector<Variable> &inp
std::vector<FunctionPtr> outputHs;
for (int dir = 0; dir < numDirections; dir++)
{
std::function<FunctionPtr(const Variable&)> activationOp =
std::function<FunctionPtr(const Variable &)> activationOp =
GetRNNActivations(activations, activation_alpha, activation_beta, dir);

// the first a few inputs are (in order): X, numDirections * W, numDirections * R, numDirections * H1

@@ -592,9 +588,9 @@ FunctionPtr CreateRNN(const ONNXIR::Node *node, const std::vector<Variable> &inp
Variable B;
std::vector<Variable> biasVariables = FindByNameHint(inputs, LSTMInputBiasNameHint);
if (numDirections == 1 && biasVariables.size() >= 1)
B = biasVariables[0];
B = biasVariables[dir];
else if (numDirections == 2 && biasVariables.size() == 2)
B = biasVariables[1];
B = biasVariables[dir];

Variable initHVariable = GetInitialStateVariable(inputs, numDirections, GRUInputInitialHNameHint, X.GetDataType());

@@ -605,39 +601,39 @@ FunctionPtr CreateRNN(const ONNXIR::Node *node, const std::vector<Variable> &inp
// if it is bidirectional LSTM, the second one will be the backword one.
bool go_backwards = direction == RNNDirectionReverse || (numDirections == 2 && dir == 1);

std::function<FunctionPtr(const Variable&)> recurrenceHook;
std::function<FunctionPtr(const Variable &)> recurrenceHook;
if (go_backwards)
recurrenceHook = [initHVariable](const Variable& x) { return FutureValue(x, initHVariable); };
recurrenceHook = [initHVariable](const Variable &x) { return FutureValue(x, initHVariable); };
else
recurrenceHook = [initHVariable](const Variable& x) { return PastValue(x, initHVariable); };
recurrenceHook = [initHVariable](const Variable &x) { return PastValue(x, initHVariable); };

outputH = RNNComponent(
X, { (size_t)hiddenDim }, activationOp,
recurrenceHook, (Constant &)W, (Constant &)R, (Constant &)B);
X, {(size_t) hiddenDim}, activationOp,
recurrenceHook, (Constant &) W, (Constant &) R, (Constant &) B);
outputHs.push_back(outputH);
}
if (outputHs.size() == 1)
return outputHs[0];
else
{
std::vector<Variable> operands({ outputHs[0], outputHs[1] });
std::vector<Variable> operands({outputHs[0], outputHs[1]});
return Splice(operands, Axis(0), ToWString(node->Name()));
}
}

template <typename FunctionType>
void TraverseGraphWithPrePostActions(FunctionPtr cntkFunction, std::unordered_set<FunctionPtr>& visitedFunctions,
void TraverseGraphWithPrePostActions(FunctionPtr cntkFunction, std::unordered_set<FunctionPtr> &visitedFunctions,
FunctionType preFunctor, FunctionType postFunctor)
{
visitedFunctions.insert(cntkFunction);
preFunctor(cntkFunction);

std::vector<Variable> functionInputs = cntkFunction->Inputs();
for (const auto& input : functionInputs)
for (const auto &input : functionInputs)
{
if (input.IsOutput() && visitedFunctions.find(input.Owner()) == visitedFunctions.end())
{
const auto& inputFunction = input.Owner();
const auto &inputFunction = input.Owner();
TraverseGraphWithPrePostActions(inputFunction, visitedFunctions, preFunctor, postFunctor);
}
}
@@ -648,11 +644,9 @@ void TraverseGraphWithPrePostActions(FunctionPtr cntkFunction, std::unordered_se
bool IsSupportedRNNActivation(const std::wstring &cntkOpName)
{
static std::vector<std::wstring> supportedRNNActivations(
{
L"ReLU",
{L"ReLU",
L"Tanh",
L"StableSigmoid"
});
L"StableSigmoid"});
return std::find(supportedRNNActivations.cbegin(), supportedRNNActivations.cend(), cntkOpName) !=
supportedRNNActivations.cend();
}

@@ -758,9 +752,9 @@ double GetScaler(Variable variable)
switch (variable.GetDataType())
{
case DataType::Float:
return *((float *)cpuV->DataBuffer<float>());
return *((float *) cpuV->DataBuffer<float>());
case DataType::Double:
return *((double *)cpuV->DataBuffer<double>());
return *((double *) cpuV->DataBuffer<double>());
default:
NOT_IMPLEMENTED;
}

@@ -789,7 +783,7 @@ void GetDelayOps(const std::vector<Variable> &inputVars,
// Slice order tells which graph path is for which gate. This method is
// to traverse the graph to find the 4 paths along the 4 gates. It helps to
// subsequently find needed attributes in order to build an ONNX LSTM op.
void TraceLSTMPathes(const FunctionPtr& src,
void TraceLSTMPathes(const FunctionPtr &src,
string &f_activation,
string &g_activation,
string &h_activation,

@@ -846,9 +840,9 @@ void TraceLSTMPathes(const FunctionPtr& src,
// traverse to find the joint of bit and bft
TraverseGraphWithPrePostActions(src->BlockRoot(),
peepHoleVisitedFunctions,
(std::function<void(const FunctionPtr&)>)[
&peepHoleVisitedFunctions, &pathesBitBftJoint, &currentPeepholePath](const FunctionPtr& function)
{
(std::function<void(const FunctionPtr &)>) [
&peepHoleVisitedFunctions, &pathesBitBftJoint, &currentPeepholePath
](const FunctionPtr &function) {
currentPeepholePath.push_back(function);
if (function->OpName() == L"Plus" &&
function->Inputs()[0].Owner() && function->Inputs()[0].Owner()->OpName() == L"ElementTimes" &&
@@ -856,11 +850,10 @@ void TraceLSTMPathes(const FunctionPtr& src,
{
pathesBitBftJoint.push_back(currentPeepholePath);
peepHoleVisitedFunctions.erase(std::find_if(peepHoleVisitedFunctions.begin(), peepHoleVisitedFunctions.end(),
[function](FunctionPtr f) {return function == f; }));
[function](FunctionPtr f) { return function == f; }));
}
},
(std::function<void(const FunctionPtr&)>)[&currentPeepholePath](const FunctionPtr& function)
{
(std::function<void(const FunctionPtr &)>) [&currentPeepholePath](const FunctionPtr &function) {
currentPeepholePath.pop_back();
});
}

@@ -871,9 +864,9 @@ void TraceLSTMPathes(const FunctionPtr& src,
{
// the last ElementTimes op is the peephole op
std::vector<FunctionPtr> &peepholePath = *std::max_element(pathesBitBftJoint.begin(), pathesBitBftJoint.end(),
[](std::vector<FunctionPtr> &p1, std::vector<FunctionPtr> &p2) {return p1.size() < p2.size(); });
[](std::vector<FunctionPtr> &p1, std::vector<FunctionPtr> &p2) { return p1.size() < p2.size(); });
std::vector<FunctionPtr>::reverse_iterator itPeepholeOp = std::find_if(peepholePath.rbegin(), peepholePath.rend(),
[](FunctionPtr function) {return function->OpName() == L"ElementTimes"; });
[](FunctionPtr function) { return function->OpName() == L"ElementTimes"; });
if (itPeepholeOp == peepholePath.rend())
{
CNTK::LogicError("Cannot find peephole op from a LSTM graph");

@@ -898,8 +891,7 @@ void TraceLSTMPathes(const FunctionPtr& src,

TraverseGraphWithPrePostActions(src->BlockRoot(),
visitedFunctions,
(std::function<void(const FunctionPtr&)>)[&pathesToPlusSlice, &currentPath](const FunctionPtr& function)
{
(std::function<void(const FunctionPtr &)>) [&pathesToPlusSlice, &currentPath ](const FunctionPtr &function) {
currentPath.push_back(function);
if (function->OpName() == L"Slice")
{

@@ -910,8 +902,7 @@ void TraceLSTMPathes(const FunctionPtr& src,
}
}
},
(std::function<void(const FunctionPtr&)>)[&currentPath](const FunctionPtr& function)
{
(std::function<void(const FunctionPtr &)>) [&currentPath](const FunctionPtr &function) {
currentPath.pop_back();
});
@@ -922,8 +913,7 @@ void TraceLSTMPathes(const FunctionPtr& src,
}

std::sort(pathesToPlusSlice.begin(), pathesToPlusSlice.end(),
[](const std::vector<FunctionPtr>& path1, const std::vector<FunctionPtr>& path2)
{
[](const std::vector<FunctionPtr> &path1, const std::vector<FunctionPtr> &path2) {
FunctionPtr slice1 = *path1.rbegin();
FunctionPtr slice2 = *path2.rbegin();
int beginIndex1 = slice1->Attributes()[PrimitiveFunction::AttributeNameBeginIndex].Value<int>();

@@ -956,16 +946,13 @@ void TraceLSTMPathes(const FunctionPtr& src,
{
// Ci merges to ht_it_path via element-wise time
FunctionPtr plusOp = ht_it_path[ht_it_path.size() - 2];
FunctionPtr peepholeOp = plusOp->Inputs()[0].Owner()->OpName() != L"Slice" ?
plusOp->Inputs()[0].Owner() : plusOp->Inputs()[1].Owner();
FunctionPtr peepholeOp = plusOp->Inputs()[0].Owner()->OpName() != L"Slice" ? plusOp->Inputs()[0].Owner() : plusOp->Inputs()[1].Owner();
peepholeCi = GetPeepholeVariableFromOp(peepholeOp);

}
{
// Cf merges to ht_ft_path via element-wise time
FunctionPtr plusOp = ht_ft_path[ht_ft_path.size() - 2];
FunctionPtr peepholeOp = plusOp->Inputs()[0].Owner()->OpName() != L"Slice" ?
plusOp->Inputs()[0].Owner() : plusOp->Inputs()[1].Owner();
FunctionPtr peepholeOp = plusOp->Inputs()[0].Owner()->OpName() != L"Slice" ? plusOp->Inputs()[0].Owner() : plusOp->Inputs()[1].Owner();
peepholeCf = GetPeepholeVariableFromOp(peepholeOp);

FunctionPtr stabilizerDcOp = GetStabilizerOp(peepholeOp);

@@ -998,7 +985,7 @@ FunctionPtr TraverseGraphFindFirstRNNOp(FunctionPtr src)
return nullptr;
}

void TraceGRUPathes(const FunctionPtr& src, string &f_activation, string &g_activation,
void TraceGRUPathes(const FunctionPtr &src, string &f_activation, string &g_activation,
RNNDirection &direction, Variable &initStateH)
{
std::vector<Variable> inputVars = src->Inputs();

@@ -1037,7 +1024,7 @@ void TraceGRUPathes(const FunctionPtr& src, string &f_activation, string &g_acti
g_activation = MapActivationNameCNTKToONNX(ToString(gActivation->OpName()));
}

void TraceRNNPathes(const FunctionPtr& src, string &activation,
void TraceRNNPathes(const FunctionPtr &src, string &activation,
RNNDirection &direction, Variable &initStateH)
{
std::vector<Variable> inputVars = src->Inputs();

@@ -1086,7 +1073,7 @@ std::vector<FunctionPtr> GetRNNBlocksFromSingleOrBidirectionalRNN(const Function
// For single direction RNN, rnns.size() == 1. For bidirectional RNN, rnns.size() == 2.
// It is an error otherwise.
if (rnns.size() == 0 || rnns.size() > 2 ||
std::any_of(rnns.cbegin(), rnns.cend(), [RNNStepOpName](const FunctionPtr &f) {return ToString(f->OpName()) != RNNStepOpName; }))
std::any_of(rnns.cbegin(), rnns.cend(), [RNNStepOpName](const FunctionPtr &f) { return ToString(f->OpName()) != RNNStepOpName; }))
{
CNTK::LogicError("Invalid number of RNN ops to construct an ONNX %s node.", RNNStepOpName.c_str());
}
@@ -716,10 +716,12 @@ def test_Neg(tmpdir):
verify_no_input(model, tmpdir, 'Neg_0')

#OptimizedRNNStack
OPTIM_RNN_STACK_CONFIGS = ((True, 2, 2, 3), (True, 2, 4, 8), (True, 2, 6, 8),
(True, 4, 2, 3), (False, 2, 2, 3))
@pytest.mark.parametrize("bidirectional, num_layers, input_size, hidden_size", OPTIM_RNN_STACK_CONFIGS)
def test_OptimizedRNNStack(bidirectional, num_layers, input_size, hidden_size, tmpdir, device_id):
OPTIM_RNN_STACK_CONFIGS = ((True, 2, 2, 3, 'lstm'), (True, 2, 4, 8, 'lstm'), (True, 2, 6, 8, 'lstm'),
(True, 4, 2, 3, 'lstm'), (False, 2, 2, 3, 'lstm'),
(True, 1, 2, 3, 'rnnReLU'), (True, 4, 4, 8, 'rnnReLU'), (False, 2, 6, 8, 'rnnReLU'),
(True, 4, 2, 3, 'rnnTanh'), (False, 2, 2, 3, 'rnnTanh'), (True, 1, 2, 3, 'rnnTanh'))
@pytest.mark.parametrize("bidirectional, num_layers, input_size, hidden_size, recurrent_op", OPTIM_RNN_STACK_CONFIGS)
def test_OptimizedRNNStack(bidirectional, num_layers, input_size, hidden_size, recurrent_op, tmpdir, device_id):
if device_id == -1:
pytest.skip('Test only runs on GPU')
dev = cntk_device(device_id)

@@ -728,7 +730,7 @@ def test_OptimizedRNNStack(bidirectional, num_layers, input_size, hidden_size, t
W = C.parameter((C.InferredDimension, input_size), constant_initializer(0.1), device=dev)
x = C.sequence.input_variable(shape=(input_size,))
s = np.asarray(np.random.uniform(-1, 1, (5,input_size)), dtype=np.float32)
f = C.optimized_rnnstack(x, W, hidden_size, num_layers, bidirectional=bidirectional, name='MyRnnStack')
f = C.optimized_rnnstack(x, W, hidden_size, num_layers, bidirectional=bidirectional, recurrent_op=recurrent_op, name='MyRnnStack')
f.parameters[0].value = np.reshape(np.arange(np.prod(f.parameters[0].value.shape), dtype=np.float32), f.parameters[0].value.shape)
verify_one_input(f, s, tmpdir, model_filename)