Integrate sptiwari/add_rnn_to_ornn2 into master
Commit 5b04f46aa4
@@ -3429,6 +3429,9 @@ ONNXIR::Node* CNTKToONNXHelper::CreateONNXNodesForOptimizedRNNStack(const Functi
     auto bidirectional = (bool)(src->Attributes()[L"bidirectional"].Value<bool>());
     auto recurrentOp = (wstring)src->Attributes()[L"recurrentOp"].Value<wstring>();
 
+    if (!Operators::IsOptimizedRnnStackOp(recurrentOp))
+        InvalidArgument("Recurrent op used for OptimizedRNNStack is not supported for ONNX export.");
+
     size_t numDirections = bidirectional ? 2 : 1;
     size_t inputSize = src->Inputs()[0].Shape()[0];
     auto Wcombined = src->Inputs()[1];
@@ -3550,10 +3553,17 @@ ONNXIR::Node* CNTKToONNXHelper::CreateONNXNodesForOptimizedRNNStack(const Functi
         outputs.push_back(outputArg_Yc);
 
         // ==== Step 6. Add ONNX LSTM node ====
+        auto rnnOpNameLookup = Operators::OptimizedRnnToOnnxOpLookup();
         auto rnnNodeName = (src->Name().empty() ? ToString(src->Uid()) : ToString(src->Name())) + std::to_string(i);
-        functionNode = graph->AddNode(rnnNodeName, "LSTM", "", inputs, outputs);
+        functionNode = graph->AddNode(rnnNodeName, rnnOpNameLookup[recurrentOp], "", inputs, outputs);
 
-        std::vector<std::string> singleDirectionActivation({ "Sigmoid", "Tanh", "Tanh" }); // REVIEW: Check this is the order.
+        std::vector<std::string> singleDirectionActivation;
+        if (recurrentOp == L"lstm")
+            singleDirectionActivation = { "Sigmoid", "Tanh", "Tanh" };
+        else if (recurrentOp == L"rnnReLU")
+            singleDirectionActivation = { "Relu" };
+        else if (recurrentOp == L"rnnTanh")
+            singleDirectionActivation = { "Tanh" };
         std::vector<std::string> activations;
         activations.insert(activations.end(), singleDirectionActivation.begin(), singleDirectionActivation.end());
         if (bidirectional)
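The hunk above is the heart of the change: the exporter now derives both the ONNX node type and the per-direction activation list from the recurrentOp attribute instead of hard-coding "LSTM". A minimal standalone sketch of that mapping, using plain standard-library types rather than the actual CNTKToONNXHelper/ONNXIR classes:

    #include <iostream>
    #include <map>
    #include <string>
    #include <vector>

    int main()
    {
        // ONNX op name and single-direction activations for each optimizedRNNStack variant.
        const std::map<std::wstring, std::pair<std::string, std::vector<std::string>>> rnnExport = {
            { L"lstm",    { "LSTM", { "Sigmoid", "Tanh", "Tanh" } } },
            { L"rnnReLU", { "RNN",  { "Relu" } } },
            { L"rnnTanh", { "RNN",  { "Tanh" } } },
        };

        const bool bidirectional = true;
        const auto &info = rnnExport.at(L"rnnReLU");

        // A bidirectional node repeats the single-direction activations once per direction,
        // mirroring the activations.insert(...) call in the hunk above.
        std::vector<std::string> activations = info.second;
        if (bidirectional)
            activations.insert(activations.end(), info.second.begin(), info.second.end());

        std::cout << "ONNX op: " << info.first << ", activation count: " << activations.size() << "\n";
        return 0;
    }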
@@ -3575,7 +3585,6 @@ std::tuple<std::vector<NDArrayViewPtr>, std::vector<NDArrayViewPtr>, std::vector
 CNTKToONNXHelper::SplitOptimzedRnnWtoIndivMats(Matrix<float>& WbigIn,
     size_t numLayers, size_t inputSize, size_t hiddenSize, bool bidirectional, wstring recurrentOp)
 {
-    std::vector<NDArrayViewPtr> onnxInputTensor(3);
     size_t numDirections = bidirectional ? 2 : 1;
     size_t numGates;
     if (recurrentOp == L"lstm")
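SplitOptimzedRnnWtoIndivMats needs the number of gate blocks stacked into the combined weight matrix before it can slice it per layer and direction; the truncated if (recurrentOp == L"lstm") above presumably selects that count. A hedged sketch of the conventional gate counts and the resulting row arithmetic (illustrative only, not the exact CNTK code):

    #include <cstddef>
    #include <iostream>
    #include <string>

    // Conventional gate-block counts: an LSTM stacks i/f/c/o blocks, a vanilla RNN has one block.
    // This mirrors the shape arithmetic an exporter needs, not the exact CNTK implementation.
    std::size_t NumGates(const std::wstring &recurrentOp)
    {
        if (recurrentOp == L"lstm")
            return 4;
        if (recurrentOp == L"rnnReLU" || recurrentOp == L"rnnTanh")
            return 1;
        return 0; // unsupported op
    }

    int main()
    {
        const std::size_t hiddenSize = 512;
        // Each direction of each layer contributes numGates * hiddenSize rows of weights.
        std::cout << "LSTM rows per direction: " << NumGates(L"lstm") * hiddenSize << "\n";       // 2048
        std::cout << "rnnTanh rows per direction: " << NumGates(L"rnnTanh") * hiddenSize << "\n"; // 512
        return 0;
    }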
@@ -3716,7 +3725,6 @@ std::vector<NDArrayViewPtr> CNTKToONNXHelper::ToRnnWeightPerLayerOnnxFormat(std:
         size_t offset = 0;
         for (size_t j = 0; j < numDirections; ++j)
         {
-            // Matrix<float> temp = W[i*numDirections + j].Transpose(); // Have to do this because Matrix::InplaceTranspose is not implemented.
             Matrix<float> temp(W[i*numDirections + j].GetNumCols(), W[i*numDirections + j].GetNumRows(), W[i*numDirections + j].Data(), CPUDEVICE);
             currLayerWeightMatrix.SetColumnSlice(temp, offset, layerInputSize);
             offset += layerInputSize;
File diff suppressed because it is too large
Load Diff
@@ -490,5 +490,17 @@ namespace ONNX
         { L"Dropout" },
     };
 
+    std::set<std::wstring> Operators::_optimizedRnnStackOpNames = {
+        { L"lstm" },
+        { L"rnnReLU" },
+        { L"rnnTanh" },
+    };
+
+    std::unordered_map<std::wstring, std::string> Operators::_optimizedRnnOpNameToOnnxOpName = {
+        { L"lstm", "LSTM" },
+        { L"rnnReLU", "RNN" },
+        { L"rnnTanh","RNN" },
+    };
+
 }
 }
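The two tables added here are what the exporter consults to first validate recurrentOp and then translate it into an ONNX op name. A small self-contained sketch of how they are queried, with local copies standing in for the Operators statics:

    #include <iostream>
    #include <set>
    #include <string>
    #include <unordered_map>

    int main()
    {
        // Local stand-ins for Operators::_optimizedRnnStackOpNames and
        // Operators::_optimizedRnnOpNameToOnnxOpName defined above.
        const std::set<std::wstring> optimizedRnnStackOpNames = { L"lstm", L"rnnReLU", L"rnnTanh" };
        const std::unordered_map<std::wstring, std::string> optimizedRnnToOnnxOpName = {
            { L"lstm", "LSTM" }, { L"rnnReLU", "RNN" }, { L"rnnTanh", "RNN" },
        };

        const std::wstring recurrentOp = L"rnnTanh";

        // Validation step, as in Operators::IsOptimizedRnnStackOp.
        if (optimizedRnnStackOpNames.find(recurrentOp) == optimizedRnnStackOpNames.end())
        {
            std::cerr << "Recurrent op used for OptimizedRNNStack is not supported for ONNX export.\n";
            return 1;
        }

        // Translation step, as in Operators::OptimizedRnnToOnnxOpLookup.
        std::cout << "ONNX op: " << optimizedRnnToOnnxOpName.at(recurrentOp) << "\n"; // ONNX op: RNN
        return 0;
    }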
@@ -53,6 +53,23 @@ namespace CNTK
             return _cntkToONNXOpName;
         }
 
+        //
+        // Method to check if a name is a valid optimizedRnnStack op name.
+        //
+        static inline bool IsOptimizedRnnStackOp(const std::wstring& opName)
+        {
+            return _optimizedRnnStackOpNames.find(opName) != _optimizedRnnStackOpNames.end();
+        }
+
+        //
+        // Return a lookup table that maps CNTK's optimizedRNNStack op to one of the
+        // ONNX RNN ops.
+        //
+        static inline const std::unordered_map<std::wstring, std::string>& OptimizedRnnToOnnxOpLookup()
+        {
+            return _optimizedRnnOpNameToOnnxOpName;
+        }
+
         static std::tuple<int, int> GetElementWiseInputIndices(const std::wstring& opName);
 
         //
@@ -115,6 +132,8 @@ namespace CNTK
         static std::unordered_multimap<std::wstring, AttributesMapping> _cntkToONNXOpName;
         static std::unordered_map<std::wstring, std::set<size_t>> _cntkBlockOPInvalidIndices;
         static std::unordered_map<std::wstring, std::vector<int>> _cntkToONNXInputIndices;
+        static std::set<std::wstring>_optimizedRnnStackOpNames;
+        static std::unordered_map<std::wstring, std::string> _optimizedRnnOpNameToOnnxOpName;
         static std::set<std::wstring> _cntkLayerOPName;
     };
 
@ -41,33 +41,32 @@ std::string MapActivationNameCNTKToONNX(const std::string &cntkOp)
|
||||||
|
|
||||||
bool IsActivationOp(const std::string &activationName)
|
bool IsActivationOp(const std::string &activationName)
|
||||||
{
|
{
|
||||||
return
|
return activationName == "Relu" || activationName == "ReLU" ||
|
||||||
activationName == "Relu" || activationName == "ReLU" ||
|
activationName == "Tanh" ||
|
||||||
activationName == "Tanh" ||
|
activationName == "Sigmoid" || activationName == "StableSigmoid" ||
|
||||||
activationName == "Sigmoid" || activationName == "StableSigmoid" ||
|
activationName == "Affine" ||
|
||||||
activationName == "Affine" ||
|
activationName == "LeakyRelu" || activationName == "LeakyReLU" ||
|
||||||
activationName == "LeakyRelu" || activationName == "LeakyReLU" ||
|
activationName == "ThresholdedRelu" || activationName == "ThresholdedReLU" ||
|
||||||
activationName == "ThresholdedRelu" || activationName == "ThresholdedReLU" ||
|
activationName == "ScaledTanh" ||
|
||||||
activationName == "ScaledTanh" ||
|
activationName == "HardSigmoid" ||
|
||||||
activationName == "HardSigmoid" ||
|
activationName == "Elu" || activationName == "ELU" ||
|
||||||
activationName == "Elu" || activationName == "ELU" ||
|
activationName == "Softsign" ||
|
||||||
activationName == "Softsign" ||
|
activationName == "Softplus";
|
||||||
activationName == "Softplus";
|
|
||||||
}
|
}
|
||||||
|
|
||||||
std::function<FunctionPtr(const Variable&)> ActivationMap(const std::string &activationName)
|
std::function<FunctionPtr(const Variable &)> ActivationMap(const std::string &activationName)
|
||||||
{
|
{
|
||||||
if (activationName == "Relu")
|
if (activationName == "Relu")
|
||||||
{
|
{
|
||||||
return [](const Variable& x) { return ReLU(x); };
|
return [](const Variable &x) { return ReLU(x); };
|
||||||
}
|
}
|
||||||
else if (activationName == "Tanh")
|
else if (activationName == "Tanh")
|
||||||
{
|
{
|
||||||
return [](const Variable& x) { return Tanh(x); };
|
return [](const Variable &x) { return Tanh(x); };
|
||||||
}
|
}
|
||||||
else if (activationName == "Sigmoid")
|
else if (activationName == "Sigmoid")
|
||||||
{
|
{
|
||||||
return [](const Variable& x) { return Sigmoid(x); };
|
return [](const Variable &x) { return Sigmoid(x); };
|
||||||
}
|
}
|
||||||
// else if (activationName == "Affine")
|
// else if (activationName == "Affine")
|
||||||
// else if (activationName == "LeakyRelu")
|
// else if (activationName == "LeakyRelu")
|
||||||
|
@ -76,15 +75,15 @@ std::function<FunctionPtr(const Variable&)> ActivationMap(const std::string &act
|
||||||
// else if (activationName == "HardSigmoid")
|
// else if (activationName == "HardSigmoid")
|
||||||
else if (activationName == "Elu")
|
else if (activationName == "Elu")
|
||||||
{
|
{
|
||||||
return [](const Variable& x) { return ELU(x); };
|
return [](const Variable &x) { return ELU(x); };
|
||||||
}
|
}
|
||||||
else if (activationName == "Softsign")
|
else if (activationName == "Softsign")
|
||||||
{
|
{
|
||||||
return [](const Variable& x) { return Softsign(x); };
|
return [](const Variable &x) { return Softsign(x); };
|
||||||
}
|
}
|
||||||
else if (activationName == "Softplus")
|
else if (activationName == "Softplus")
|
||||||
{
|
{
|
||||||
return [](const Variable& x) { return Softplus(x); };
|
return [](const Variable &x) { return Softplus(x); };
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
|
@ -92,12 +91,12 @@ std::function<FunctionPtr(const Variable&)> ActivationMap(const std::string &act
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
std::function<FunctionPtr(const Variable&)> ActivationMap(const std::string &activationName,
|
std::function<FunctionPtr(const Variable &)> ActivationMap(const std::string &activationName,
|
||||||
float activation_alpha)
|
float activation_alpha)
|
||||||
{
|
{
|
||||||
if (activationName == "LeakyRelu")
|
if (activationName == "LeakyRelu")
|
||||||
{
|
{
|
||||||
return [activation_alpha](const Variable& x) { return LeakyReLU(x, activation_alpha); };
|
return [activation_alpha](const Variable &x) { return LeakyReLU(x, activation_alpha); };
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
|
@ -105,12 +104,12 @@ std::function<FunctionPtr(const Variable&)> ActivationMap(const std::string &act
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
std::function<FunctionPtr(const Variable&)> ActivationMap(const std::string &activationName,
|
std::function<FunctionPtr(const Variable &)> ActivationMap(const std::string &activationName,
|
||||||
float activation_alpha, float activation_beta)
|
float activation_alpha, float activation_beta)
|
||||||
{
|
{
|
||||||
if (activationName == "HardSigmoid")
|
if (activationName == "HardSigmoid")
|
||||||
{
|
{
|
||||||
return [activation_alpha, activation_beta](const Variable& x) { return HardSigmoid(x, activation_alpha, activation_beta); };
|
return [activation_alpha, activation_beta](const Variable &x) { return HardSigmoid(x, activation_alpha, activation_beta); };
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
|
@ -118,23 +117,23 @@ std::function<FunctionPtr(const Variable&)> ActivationMap(const std::string &act
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
std::tuple<std::function<FunctionPtr(const Variable&)>, std::function<FunctionPtr(const Variable&)>, std::function<FunctionPtr(const Variable&)>>
|
std::tuple<std::function<FunctionPtr(const Variable &)>, std::function<FunctionPtr(const Variable &)>, std::function<FunctionPtr(const Variable &)>>
|
||||||
GetActivations(const std::vector<std::string> &activations, const std::vector<float> &activation_alpha, const std::vector<float> &activation_beta, int direction)
|
GetActivations(const std::vector<std::string> &activations, const std::vector<float> &activation_alpha, const std::vector<float> &activation_beta, int direction)
|
||||||
{
|
{
|
||||||
if (activations.size() < (direction + 1) * LSTMActivationCount)
|
if (activations.size() < (direction + 1) * LSTMActivationCount)
|
||||||
CNTK::LogicError("LSTM activations shall be a list of strings of size %d or %d ", LSTMActivationCount, LSTMActivationCount * 2);
|
CNTK::LogicError("LSTM activations shall be a list of strings of size %d or %d ", LSTMActivationCount, LSTMActivationCount * 2);
|
||||||
|
|
||||||
//
|
//
|
||||||
int iofActivationIndex = direction * LSTMActivationCount + LSTMActivationFIndex;
|
int iofActivationIndex = direction * LSTMActivationCount + LSTMActivationFIndex;
|
||||||
int cellActivation = direction * LSTMActivationCount + LSTMActivationGIndex;
|
int cellActivation = direction * LSTMActivationCount + LSTMActivationGIndex;
|
||||||
int hiddenActivationIndex = direction * LSTMActivationCount + LSTMActivationHIndex;
|
int hiddenActivationIndex = direction * LSTMActivationCount + LSTMActivationHIndex;
|
||||||
|
|
||||||
// ONNX spec is not clear on how activation alpha and beta is set.
|
// ONNX spec is not clear on how activation alpha and beta is set.
|
||||||
// Here we assume that if they are set, they are set for all activations, regardless whether
|
// Here we assume that if they are set, they are set for all activations, regardless whether
|
||||||
// an activation needs those values or not.
|
// an activation needs those values or not.
|
||||||
bool hasAlpha = activation_alpha.size() == (direction + 1) * LSTMActivationCount;
|
bool hasAlpha = activation_alpha.size() == (direction + 1) * LSTMActivationCount;
|
||||||
bool hasAlphaBeta = hasAlpha && activation_beta.size() == (direction + 1) * LSTMActivationCount;
|
bool hasAlphaBeta = hasAlpha && activation_beta.size() == (direction + 1) * LSTMActivationCount;
|
||||||
std::function<FunctionPtr(const Variable&)> iofActivationOp, cellActivationOp, hiddenActivationOp;
|
std::function<FunctionPtr(const Variable &)> iofActivationOp, cellActivationOp, hiddenActivationOp;
|
||||||
if (hasAlphaBeta)
|
if (hasAlphaBeta)
|
||||||
{
|
{
|
||||||
iofActivationOp = ActivationMap(activations[iofActivationIndex], activation_alpha[iofActivationIndex], activation_beta[iofActivationIndex]);
|
iofActivationOp = ActivationMap(activations[iofActivationIndex], activation_alpha[iofActivationIndex], activation_beta[iofActivationIndex]);
|
||||||
|
@ -155,22 +154,21 @@ GetActivations(const std::vector<std::string> &activations, const std::vector<fl
|
||||||
}
|
}
|
||||||
|
|
||||||
return std::make_tuple(iofActivationOp, cellActivationOp, hiddenActivationOp);
|
return std::make_tuple(iofActivationOp, cellActivationOp, hiddenActivationOp);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
std::tuple<std::function<FunctionPtr(const Variable&)>, std::function<FunctionPtr(const Variable&)>>
|
std::tuple<std::function<FunctionPtr(const Variable &)>, std::function<FunctionPtr(const Variable &)>>
|
||||||
GetGRUActivations(const std::vector<std::string> &activations, const std::vector<float> &activation_alpha, const std::vector<float> &activation_beta, int direction)
|
GetGRUActivations(const std::vector<std::string> &activations, const std::vector<float> &activation_alpha, const std::vector<float> &activation_beta, int direction)
|
||||||
{
|
{
|
||||||
if (activations.size() < (direction + 1) * GRUActivationCount)
|
if (activations.size() < (direction + 1) * GRUActivationCount)
|
||||||
CNTK::LogicError("GRU activations shall be a list of strings of size %d or %d", GRUActivationCount, GRUActivationCount * 2);
|
CNTK::LogicError("GRU activations shall be a list of strings of size %d or %d", GRUActivationCount, GRUActivationCount * 2);
|
||||||
|
|
||||||
//
|
//
|
||||||
int fActivationIndex = direction * GRUActivationCount + GRUActivationFIndex;
|
int fActivationIndex = direction * GRUActivationCount + GRUActivationFIndex;
|
||||||
int gActivationIndex = direction * GRUActivationCount + GRUActivationGIndex;
|
int gActivationIndex = direction * GRUActivationCount + GRUActivationGIndex;
|
||||||
|
|
||||||
bool hasAlpha = activation_alpha.size() == (direction + 1) * GRUActivationCount;
|
bool hasAlpha = activation_alpha.size() == (direction + 1) * GRUActivationCount;
|
||||||
bool hasAlphaBeta = hasAlpha && activation_beta.size() == (direction + 1) * GRUActivationCount;
|
bool hasAlphaBeta = hasAlpha && activation_beta.size() == (direction + 1) * GRUActivationCount;
|
||||||
std::function<FunctionPtr(const Variable&)> fActivationOp, gActivationOp;
|
std::function<FunctionPtr(const Variable &)> fActivationOp, gActivationOp;
|
||||||
if (hasAlphaBeta)
|
if (hasAlphaBeta)
|
||||||
{
|
{
|
||||||
fActivationOp = ActivationMap(activations[fActivationIndex], activation_alpha[fActivationIndex], activation_beta[fActivationIndex]);
|
fActivationOp = ActivationMap(activations[fActivationIndex], activation_alpha[fActivationIndex], activation_beta[fActivationIndex]);
|
||||||
|
@ -190,18 +188,18 @@ GetGRUActivations(const std::vector<std::string> &activations, const std::vector
|
||||||
return std::make_tuple(fActivationOp, gActivationOp);
|
return std::make_tuple(fActivationOp, gActivationOp);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::function<FunctionPtr(const Variable&)>
|
std::function<FunctionPtr(const Variable &)>
|
||||||
GetRNNActivations(const std::vector<std::string> &activations, const std::vector<float> &activation_alpha, const std::vector<float> &activation_beta, int direction)
|
GetRNNActivations(const std::vector<std::string> &activations, const std::vector<float> &activation_alpha, const std::vector<float> &activation_beta, int direction)
|
||||||
{
|
{
|
||||||
if (activations.size() < (direction + 1))
|
if (activations.size() < (direction + 1))
|
||||||
CNTK::LogicError("RNN activations shall be a list of strings of size 1 or 2");
|
CNTK::LogicError("RNN activations shall be a list of strings of size 1 or 2");
|
||||||
|
|
||||||
//
|
//
|
||||||
int activationIndex = direction;
|
int activationIndex = direction;
|
||||||
|
|
||||||
bool hasAlpha = activation_alpha.size() == (direction + 1);
|
bool hasAlpha = activation_alpha.size() == (direction + 1);
|
||||||
bool hasAlphaBeta = hasAlpha && activation_beta.size() == (direction + 1);
|
bool hasAlphaBeta = hasAlpha && activation_beta.size() == (direction + 1);
|
||||||
std::function<FunctionPtr(const Variable&)> activationOp;
|
std::function<FunctionPtr(const Variable &)> activationOp;
|
||||||
if (hasAlphaBeta)
|
if (hasAlphaBeta)
|
||||||
{
|
{
|
||||||
activationOp = ActivationMap(activations[activationIndex], activation_alpha[activationIndex], activation_beta[activationIndex]);
|
activationOp = ActivationMap(activations[activationIndex], activation_alpha[activationIndex], activation_beta[activationIndex]);
|
||||||
|
@ -219,14 +217,14 @@ GetRNNActivations(const std::vector<std::string> &activations, const std::vector
|
||||||
}
|
}
|
||||||
|
|
||||||
std::pair<FunctionPtr, FunctionPtr> LSTMPCell(Variable input,
|
std::pair<FunctionPtr, FunctionPtr> LSTMPCell(Variable input,
|
||||||
const std::function<FunctionPtr(const Variable&)> &iofActivationOp,
|
const std::function<FunctionPtr(const Variable &)> &iofActivationOp,
|
||||||
const std::function<FunctionPtr(const Variable&)> &cellActivationOp,
|
const std::function<FunctionPtr(const Variable &)> &cellActivationOp,
|
||||||
const std::function<FunctionPtr(const Variable&)> &hiddenActivationOp,
|
const std::function<FunctionPtr(const Variable &)> &hiddenActivationOp,
|
||||||
Variable prevOutput, Variable prevCellState,
|
Variable prevOutput, Variable prevCellState,
|
||||||
Constant &W, Constant &R, Constant &B, Constant &Ci, Constant &Cf, Constant &Co)
|
Constant &W, Constant &R, Constant &B, Constant &Ci, Constant &Cf, Constant &Co)
|
||||||
{
|
{
|
||||||
size_t outputDim = prevOutput.Shape()[0];
|
size_t outputDim = prevOutput.Shape()[0];
|
||||||
int stacked_dim = (int)outputDim;
|
int stacked_dim = (int) outputDim;
|
||||||
|
|
||||||
FunctionPtr proj4;
|
FunctionPtr proj4;
|
||||||
if (B.IsInitialized())
|
if (B.IsInitialized())
|
||||||
|
@ -238,13 +236,13 @@ std::pair<FunctionPtr, FunctionPtr> LSTMPCell(Variable input,
|
||||||
proj4 = Plus(Times(W, input), Times(R, prevOutput));
|
proj4 = Plus(Times(W, input), Times(R, prevOutput));
|
||||||
}
|
}
|
||||||
|
|
||||||
// CNTK weight and bias are in icfo order.
|
// CNTK weight and bias are in icfo order.
|
||||||
std::vector<Axis> stack_axis({ Axis(-1) });
|
std::vector<Axis> stack_axis({Axis(-1)});
|
||||||
const int IGateIndex = 0, CGateIndex = 1, FGateIndex = 2, OGateIndex = 3;
|
const int IGateIndex = 0, CGateIndex = 1, FGateIndex = 2, OGateIndex = 3;
|
||||||
FunctionPtr it_proj = Slice(proj4, stack_axis, { IGateIndex * stacked_dim }, { (IGateIndex + 1) * stacked_dim });
|
FunctionPtr it_proj = Slice(proj4, stack_axis, {IGateIndex * stacked_dim}, {(IGateIndex + 1) * stacked_dim});
|
||||||
FunctionPtr bit_proj = Slice(proj4, stack_axis, { CGateIndex * stacked_dim }, { (CGateIndex + 1) * stacked_dim });
|
FunctionPtr bit_proj = Slice(proj4, stack_axis, {CGateIndex * stacked_dim}, {(CGateIndex + 1) * stacked_dim});
|
||||||
FunctionPtr ft_proj = Slice(proj4, stack_axis, { FGateIndex * stacked_dim }, { (FGateIndex + 1) * stacked_dim });
|
FunctionPtr ft_proj = Slice(proj4, stack_axis, {FGateIndex * stacked_dim}, {(FGateIndex + 1) * stacked_dim});
|
||||||
FunctionPtr ot_proj = Slice(proj4, stack_axis, { OGateIndex * stacked_dim }, { (OGateIndex + 1) * stacked_dim });
|
FunctionPtr ot_proj = Slice(proj4, stack_axis, {OGateIndex * stacked_dim}, {(OGateIndex + 1) * stacked_dim});
|
||||||
|
|
||||||
bool hasPeephole = Ci.IsInitialized();
|
bool hasPeephole = Ci.IsInitialized();
|
||||||
|
|
||||||
|
@ -263,38 +261,38 @@ std::pair<FunctionPtr, FunctionPtr> LSTMPCell(Variable input,
|
||||||
auto c = ct;
|
auto c = ct;
|
||||||
auto h = ht;
|
auto h = ht;
|
||||||
|
|
||||||
return{ h, c };
|
return {h, c};
|
||||||
}
|
}
|
||||||
|
|
||||||
FunctionPtr GRUCell(Variable input,
|
FunctionPtr GRUCell(Variable input,
|
||||||
const std::function<FunctionPtr(const Variable&)> &fActivationOp,
|
const std::function<FunctionPtr(const Variable &)> &fActivationOp,
|
||||||
const std::function<FunctionPtr(const Variable&)> &gActivationOp,
|
const std::function<FunctionPtr(const Variable &)> &gActivationOp,
|
||||||
Variable prevOutput,
|
Variable prevOutput,
|
||||||
Constant &W, Constant &R, Constant &H1, Constant &B)
|
Constant &W, Constant &R, Constant &H1, Constant &B)
|
||||||
{
|
{
|
||||||
size_t outputDim = prevOutput.Shape()[0];
|
size_t outputDim = prevOutput.Shape()[0];
|
||||||
int stacked_dim = (int)outputDim;
|
int stacked_dim = (int) outputDim;
|
||||||
|
|
||||||
FunctionPtr projx3;
|
FunctionPtr projx3;
|
||||||
if (B.IsInitialized())
|
if (B.IsInitialized())
|
||||||
projx3 = Plus(B, Times(W, input));
|
projx3 = Plus(B, Times(W, input));
|
||||||
else
|
else
|
||||||
projx3 = Times(W, input);
|
projx3 = Times(W, input);
|
||||||
|
|
||||||
FunctionPtr projh2 = Times(R, prevOutput);
|
FunctionPtr projh2 = Times(R, prevOutput);
|
||||||
|
|
||||||
// both CNTK and ONNX weight and bias are in zrh order.
|
// both CNTK and ONNX weight and bias are in zrh order.
|
||||||
std::vector<Axis> stack_axis({ Axis(-1) });
|
std::vector<Axis> stack_axis({Axis(-1)});
|
||||||
FunctionPtr zt_proj =
|
FunctionPtr zt_proj =
|
||||||
Slice(projx3, stack_axis, { 0 * stacked_dim }, { 1 * stacked_dim }) +
|
Slice(projx3, stack_axis, {0 * stacked_dim}, {1 * stacked_dim}) +
|
||||||
Slice(projh2, stack_axis, { 0 * stacked_dim }, { 1 * stacked_dim });
|
Slice(projh2, stack_axis, {0 * stacked_dim}, {1 * stacked_dim});
|
||||||
|
|
||||||
FunctionPtr rt_proj =
|
FunctionPtr rt_proj =
|
||||||
Slice(projx3, stack_axis, { 1 * stacked_dim }, { 2 * stacked_dim }) +
|
Slice(projx3, stack_axis, {1 * stacked_dim}, {2 * stacked_dim}) +
|
||||||
Slice(projh2, stack_axis, { 1 * stacked_dim }, { 2 * stacked_dim });
|
Slice(projh2, stack_axis, {1 * stacked_dim}, {2 * stacked_dim});
|
||||||
|
|
||||||
FunctionPtr ct_proj =
|
FunctionPtr ct_proj =
|
||||||
Slice(projx3, stack_axis, { 2 * stacked_dim }, { 3 * stacked_dim });
|
Slice(projx3, stack_axis, {2 * stacked_dim}, {3 * stacked_dim});
|
||||||
|
|
||||||
FunctionPtr zt = fActivationOp(zt_proj);
|
FunctionPtr zt = fActivationOp(zt_proj);
|
||||||
|
|
||||||
|
@ -307,18 +305,19 @@ FunctionPtr GRUCell(Variable input,
|
||||||
Constant one = W.GetDataType() == DataType::Float ? Constant::Scalar<float>(1.0f) : Constant::Scalar<double>(1.0);
|
Constant one = W.GetDataType() == DataType::Float ? Constant::Scalar<float>(1.0f) : Constant::Scalar<double>(1.0);
|
||||||
|
|
||||||
FunctionPtr ht = ElementTimes(one - zt, ct) + ElementTimes(zt, prevOutput);
|
FunctionPtr ht = ElementTimes(one - zt, ct) + ElementTimes(zt, prevOutput);
|
||||||
|
|
||||||
FunctionPtr h = ht;
|
FunctionPtr h = ht;
|
||||||
|
|
||||||
return ht;
|
return ht;
|
||||||
}
|
}
|
||||||
|
|
||||||
FunctionPtr RNNCell(Variable input,
|
FunctionPtr RNNCell(Variable input,
|
||||||
const std::function<FunctionPtr(const Variable&)> &activationOp,
|
const std::function<FunctionPtr(const Variable &)> &activationOp,
|
||||||
Variable prevOutput,
|
Variable prevOutput,
|
||||||
Constant &W, Constant &R, Constant &B)
|
Constant &W, Constant &R, Constant &B)
|
||||||
{
|
{
|
||||||
FunctionPtr proj = Times(W, input) + Times(R, prevOutput);;
|
FunctionPtr proj = Times(W, input) + Times(R, prevOutput);
|
||||||
|
;
|
||||||
if (B.IsInitialized())
|
if (B.IsInitialized())
|
||||||
proj = B + proj;
|
proj = B + proj;
|
||||||
|
|
||||||
|
@ -326,19 +325,18 @@ FunctionPtr RNNCell(Variable input,
|
||||||
return h;
|
return h;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
#include "PrimitiveFunction.h"
|
#include "PrimitiveFunction.h"
|
||||||
#include "BlockFunction.h"
|
#include "BlockFunction.h"
|
||||||
|
|
||||||
std::tuple<FunctionPtr, FunctionPtr> LSTMPComponent(Variable input,
|
std::tuple<FunctionPtr, FunctionPtr> LSTMPComponent(Variable input,
|
||||||
const NDShape& cellShape,
|
const NDShape &cellShape,
|
||||||
const std::function<FunctionPtr(const Variable&)> &iofActivationOp,
|
const std::function<FunctionPtr(const Variable &)> &iofActivationOp,
|
||||||
const std::function<FunctionPtr(const Variable&)> &cellActivationOp,
|
const std::function<FunctionPtr(const Variable &)> &cellActivationOp,
|
||||||
const std::function<FunctionPtr(const Variable&)> &hiddenActivationOp,
|
const std::function<FunctionPtr(const Variable &)> &hiddenActivationOp,
|
||||||
const std::function<FunctionPtr(const Variable&)>& recurrenceHookH,
|
const std::function<FunctionPtr(const Variable &)> &recurrenceHookH,
|
||||||
const std::function<FunctionPtr(const Variable&)>& recurrenceHookC,
|
const std::function<FunctionPtr(const Variable &)> &recurrenceHookC,
|
||||||
Constant &W, Constant &R, Constant &B,
|
Constant &W, Constant &R, Constant &B,
|
||||||
Constant &Ci, Constant &Cf, Constant &Co)
|
Constant &Ci, Constant &Cf, Constant &Co)
|
||||||
{
|
{
|
||||||
auto dh = PlaceholderVariable(cellShape, input.DynamicAxes());
|
auto dh = PlaceholderVariable(cellShape, input.DynamicAxes());
|
||||||
auto dc = PlaceholderVariable(cellShape, input.DynamicAxes());
|
auto dc = PlaceholderVariable(cellShape, input.DynamicAxes());
|
||||||
|
@ -353,16 +351,16 @@ std::tuple<FunctionPtr, FunctionPtr> LSTMPComponent(Variable input,
|
||||||
auto actualDc = recurrenceHookC(LSTMCell.second);
|
auto actualDc = recurrenceHookC(LSTMCell.second);
|
||||||
|
|
||||||
// Form the recurrence loop by replacing the dh and dc placeholders with the actualDh and actualDc
|
// Form the recurrence loop by replacing the dh and dc placeholders with the actualDh and actualDc
|
||||||
LSTMCell.first->ReplacePlaceholders({ { inputPlaceholder , input}, { dh, actualDh },{ dc, actualDc } });
|
LSTMCell.first->ReplacePlaceholders({{inputPlaceholder, input}, {dh, actualDh}, {dc, actualDc}});
|
||||||
return std::make_tuple(LSTMCell.first, LSTMCell.second);
|
return std::make_tuple(LSTMCell.first, LSTMCell.second);
|
||||||
}
|
}
|
||||||
|
|
||||||
FunctionPtr GRUComponent(Variable input,
|
FunctionPtr GRUComponent(Variable input,
|
||||||
const NDShape& cellShape,
|
const NDShape &cellShape,
|
||||||
const std::function<FunctionPtr(const Variable&)> &fActivationOp,
|
const std::function<FunctionPtr(const Variable &)> &fActivationOp,
|
||||||
const std::function<FunctionPtr(const Variable&)> &gActivationOp,
|
const std::function<FunctionPtr(const Variable &)> &gActivationOp,
|
||||||
const std::function<FunctionPtr(const Variable&)>& recurrenceHookH,
|
const std::function<FunctionPtr(const Variable &)> &recurrenceHookH,
|
||||||
Constant &W, Constant &R, Constant &H1, Constant &B)
|
Constant &W, Constant &R, Constant &H1, Constant &B)
|
||||||
{
|
{
|
||||||
auto dh = PlaceholderVariable(cellShape, input.DynamicAxes());
|
auto dh = PlaceholderVariable(cellShape, input.DynamicAxes());
|
||||||
auto inputPlaceholder = PlaceholderVariable(input.Shape(), input.DynamicAxes());
|
auto inputPlaceholder = PlaceholderVariable(input.Shape(), input.DynamicAxes());
|
||||||
|
@ -374,17 +372,17 @@ FunctionPtr GRUComponent(Variable input,
|
||||||
|
|
||||||
auto actualDh = recurrenceHookH(gruCell);
|
auto actualDh = recurrenceHookH(gruCell);
|
||||||
|
|
||||||
gruCell->ReplacePlaceholders({ { dh, actualDh } });
|
gruCell->ReplacePlaceholders({{dh, actualDh}});
|
||||||
|
|
||||||
auto gruBlock = AsBlock(std::move(gruCell), { { inputPlaceholder , input } }, L"GRU", L"");
|
auto gruBlock = AsBlock(std::move(gruCell), {{inputPlaceholder, input}}, L"GRU", L"");
|
||||||
return gruBlock;
|
return gruBlock;
|
||||||
}
|
}
|
||||||
|
|
||||||
FunctionPtr RNNComponent(Variable input,
|
FunctionPtr RNNComponent(Variable input,
|
||||||
const NDShape& cellShape,
|
const NDShape &cellShape,
|
||||||
const std::function<FunctionPtr(const Variable&)> &activationOp,
|
const std::function<FunctionPtr(const Variable &)> &activationOp,
|
||||||
const std::function<FunctionPtr(const Variable&)>& recurrenceHookH,
|
const std::function<FunctionPtr(const Variable &)> &recurrenceHookH,
|
||||||
Constant &W, Constant &R, Constant &B)
|
Constant &W, Constant &R, Constant &B)
|
||||||
{
|
{
|
||||||
auto dh = PlaceholderVariable(cellShape, input.DynamicAxes());
|
auto dh = PlaceholderVariable(cellShape, input.DynamicAxes());
|
||||||
auto inputPlaceholder = PlaceholderVariable(input.Shape(), input.DynamicAxes());
|
auto inputPlaceholder = PlaceholderVariable(input.Shape(), input.DynamicAxes());
|
||||||
|
@ -396,7 +394,7 @@ FunctionPtr RNNComponent(Variable input,
|
||||||
|
|
||||||
auto actualDh = recurrenceHookH(rnnCell);
|
auto actualDh = recurrenceHookH(rnnCell);
|
||||||
|
|
||||||
rnnCell->ReplacePlaceholders({ { inputPlaceholder , input },{ dh, actualDh } });
|
rnnCell->ReplacePlaceholders({{inputPlaceholder, input}, {dh, actualDh}});
|
||||||
return rnnCell;
|
return rnnCell;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -414,7 +412,7 @@ const std::vector<Variable> FindByNameHint(const std::vector<Variable> &inputs,
|
||||||
}
|
}
|
||||||
|
|
||||||
Variable GetInitialStateVariable(const std::vector<Variable> &inputs, int numDirections,
|
Variable GetInitialStateVariable(const std::vector<Variable> &inputs, int numDirections,
|
||||||
const std::string &nameHint, DataType datatype)
|
const std::string &nameHint, DataType datatype)
|
||||||
{
|
{
|
||||||
Variable initialVariable = datatype == DataType::Double ? Constant::Scalar(0.0) : Constant::Scalar(0.0f);
|
Variable initialVariable = datatype == DataType::Double ? Constant::Scalar(0.0) : Constant::Scalar(0.0f);
|
||||||
const std::vector<Variable> initialVariables = FindByNameHint(inputs, nameHint);
|
const std::vector<Variable> initialVariables = FindByNameHint(inputs, nameHint);
|
||||||
|
@ -431,15 +429,14 @@ Variable GetInitialStateVariable(const std::vector<Variable> &inputs, int numDir
|
||||||
}
|
}
|
||||||
|
|
||||||
FunctionPtr CreateLSTM(const ONNXIR::Node *node, const std::vector<Variable> &inputs, const std::string &direction,
|
FunctionPtr CreateLSTM(const ONNXIR::Node *node, const std::vector<Variable> &inputs, const std::string &direction,
|
||||||
const std::vector<string> &activations, const std::vector<float> &activation_alpha, const std::vector<float> &activation_beta)
|
const std::vector<string> &activations, const std::vector<float> &activation_alpha, const std::vector<float> &activation_beta)
|
||||||
{
|
{
|
||||||
int numDirections = direction == RNNDirectionBidirection ? 2 : 1;
|
int numDirections = direction == RNNDirectionBidirection ? 2 : 1;
|
||||||
std::vector<FunctionPtr> outputHs;
|
std::vector<FunctionPtr> outputHs;
|
||||||
for (int dir = 0; dir < numDirections; dir++)
|
for (int dir = 0; dir < numDirections; dir++)
|
||||||
{
|
{
|
||||||
std::function<FunctionPtr(const Variable&)> iofActivationOp, cellActivationOp, hiddenActivationOp;
|
std::function<FunctionPtr(const Variable &)> iofActivationOp, cellActivationOp, hiddenActivationOp;
|
||||||
std::tie<std::function<FunctionPtr(const Variable&)>, std::function<FunctionPtr(const Variable&)>, std::function<FunctionPtr(const Variable&)>>
|
std::tie<std::function<FunctionPtr(const Variable &)>, std::function<FunctionPtr(const Variable &)>, std::function<FunctionPtr(const Variable &)>>(iofActivationOp, cellActivationOp, hiddenActivationOp) = GetActivations(activations, activation_alpha, activation_beta, dir);
|
||||||
(iofActivationOp, cellActivationOp, hiddenActivationOp) = GetActivations(activations, activation_alpha, activation_beta, dir);
|
|
||||||
|
|
||||||
// the first a few inputs are (in order): X, numDirections * W, numDirections * R
|
// the first a few inputs are (in order): X, numDirections * W, numDirections * R
|
||||||
Variable X = inputs[0];
|
Variable X = inputs[0];
|
||||||
|
@ -460,7 +457,7 @@ FunctionPtr CreateLSTM(const ONNXIR::Node *node, const std::vector<Variable> &in
|
||||||
if (peepholeVariables.size() != 0 && peepholeVariables.size() != LSTMPeepholeCount && peepholeVariables.size() != 2 * LSTMPeepholeCount)
|
if (peepholeVariables.size() != 0 && peepholeVariables.size() != LSTMPeepholeCount && peepholeVariables.size() != 2 * LSTMPeepholeCount)
|
||||||
{
|
{
|
||||||
CNTK::LogicError("Peephole Variable count (%d) should be 0, 1 or 2 times the number of peephole factors (%d).",
|
CNTK::LogicError("Peephole Variable count (%d) should be 0, 1 or 2 times the number of peephole factors (%d).",
|
||||||
(int)(peepholeVariables.size()), (int)LSTMPeepholeCount);
|
(int) (peepholeVariables.size()), (int) LSTMPeepholeCount);
|
||||||
}
|
}
|
||||||
else if (numDirections == 1 && peepholeVariables.size() >= LSTMPeepholeCount)
|
else if (numDirections == 1 && peepholeVariables.size() >= LSTMPeepholeCount)
|
||||||
{
|
{
|
||||||
|
@ -477,8 +474,8 @@ FunctionPtr CreateLSTM(const ONNXIR::Node *node, const std::vector<Variable> &in
|
||||||
|
|
||||||
// ONNX spec https://github.com/onnx/onnx/blob/master/docs/Operators.md#inputs-3---8
|
// ONNX spec https://github.com/onnx/onnx/blob/master/docs/Operators.md#inputs-3---8
|
||||||
// tells that weight has shape [num_directions, 4*hidden_size, input_size]
|
// tells that weight has shape [num_directions, 4*hidden_size, input_size]
|
||||||
// here in CNTK, there is no direction axis because CNTK treats bidirectional LSTM
|
// here in CNTK, there is no direction axis because CNTK treats bidirectional LSTM
|
||||||
// as two separate LSTM. Therefore we can divide the dimension of the first axis
|
// as two separate LSTM. Therefore we can divide the dimension of the first axis
|
||||||
// by 4 to get the hidden size.
|
// by 4 to get the hidden size.
|
||||||
int hiddenDim = W.Shape()[0] / 4;
|
int hiddenDim = W.Shape()[0] / 4;
|
||||||
|
|
||||||
|
@ -488,43 +485,42 @@ FunctionPtr CreateLSTM(const ONNXIR::Node *node, const std::vector<Variable> &in
|
||||||
// if it is bidirectional LSTM, the second one will be the backword one.
|
// if it is bidirectional LSTM, the second one will be the backword one.
|
||||||
bool go_backwards = direction == RNNDirectionReverse || (numDirections == 2 && dir == 1);
|
bool go_backwards = direction == RNNDirectionReverse || (numDirections == 2 && dir == 1);
|
||||||
|
|
||||||
std::function<FunctionPtr(const Variable&)> recurrenceHookH, recurrenceHookC;
|
std::function<FunctionPtr(const Variable &)> recurrenceHookH, recurrenceHookC;
|
||||||
if (go_backwards)
|
if (go_backwards)
|
||||||
{
|
{
|
||||||
recurrenceHookH = [initHVariable](const Variable& x) { return FutureValue(x, initHVariable); };
|
recurrenceHookH = [initHVariable](const Variable &x) { return FutureValue(x, initHVariable); };
|
||||||
recurrenceHookC = [initCVariable](const Variable& x) { return FutureValue(x, initCVariable); };
|
recurrenceHookC = [initCVariable](const Variable &x) { return FutureValue(x, initCVariable); };
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
recurrenceHookH = [initHVariable](const Variable& x) { return PastValue(x, initHVariable); };
|
recurrenceHookH = [initHVariable](const Variable &x) { return PastValue(x, initHVariable); };
|
||||||
recurrenceHookC = [initCVariable](const Variable& x) { return PastValue(x, initCVariable); };
|
recurrenceHookC = [initCVariable](const Variable &x) { return PastValue(x, initCVariable); };
|
||||||
}
|
}
|
||||||
|
|
||||||
std::tie<FunctionPtr, FunctionPtr>(outputH, outputC) = LSTMPComponent(
|
std::tie<FunctionPtr, FunctionPtr>(outputH, outputC) = LSTMPComponent(
|
||||||
X, { (size_t)hiddenDim }, iofActivationOp, cellActivationOp, hiddenActivationOp,
|
X, {(size_t) hiddenDim}, iofActivationOp, cellActivationOp, hiddenActivationOp,
|
||||||
recurrenceHookH, recurrenceHookC, (Constant &)W, (Constant &)R, (Constant &)B,
|
recurrenceHookH, recurrenceHookC, (Constant &) W, (Constant &) R, (Constant &) B,
|
||||||
(Constant &)Ci, (Constant &)Cf, (Constant &)Co);
|
(Constant &) Ci, (Constant &) Cf, (Constant &) Co);
|
||||||
outputHs.push_back(outputH);
|
outputHs.push_back(outputH);
|
||||||
}
|
}
|
||||||
if (outputHs.size() == 1)
|
if (outputHs.size() == 1)
|
||||||
return outputHs[0];
|
return outputHs[0];
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
std::vector<Variable> operands({ outputHs[0], outputHs[1] });
|
std::vector<Variable> operands({outputHs[0], outputHs[1]});
|
||||||
return Splice(operands, Axis(0), ToWString(node->Name()));
|
return Splice(operands, Axis(0), ToWString(node->Name()));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
FunctionPtr CreateGRU(const ONNXIR::Node *node, const std::vector<Variable> &inputs, const std::string &direction,
|
FunctionPtr CreateGRU(const ONNXIR::Node *node, const std::vector<Variable> &inputs, const std::string &direction,
|
||||||
const std::vector<string> &activations, const std::vector<float> &activation_alpha, const std::vector<float> &activation_beta)
|
const std::vector<string> &activations, const std::vector<float> &activation_alpha, const std::vector<float> &activation_beta)
|
||||||
{
|
{
|
||||||
int numDirections = direction == RNNDirectionBidirection ? 2 : 1;
|
int numDirections = direction == RNNDirectionBidirection ? 2 : 1;
|
||||||
std::vector<FunctionPtr> outputHs;
|
std::vector<FunctionPtr> outputHs;
|
||||||
for (int dir = 0; dir < numDirections; dir++)
|
for (int dir = 0; dir < numDirections; dir++)
|
||||||
{
|
{
|
||||||
std::function<FunctionPtr(const Variable&)> fActivationOp, gActivationOp;
|
std::function<FunctionPtr(const Variable &)> fActivationOp, gActivationOp;
|
||||||
std::tie<std::function<FunctionPtr(const Variable&)>, std::function<FunctionPtr(const Variable&)>>
|
std::tie<std::function<FunctionPtr(const Variable &)>, std::function<FunctionPtr(const Variable &)>>(fActivationOp, gActivationOp) = GetGRUActivations(activations, activation_alpha, activation_beta, dir);
|
||||||
(fActivationOp, gActivationOp) = GetGRUActivations(activations, activation_alpha, activation_beta, dir);
|
|
||||||
|
|
||||||
// the first a few inputs are (in order): X, numDirections * W, numDirections * R, numDirections * H1
|
// the first a few inputs are (in order): X, numDirections * W, numDirections * R, numDirections * H1
|
||||||
Variable X = inputs[0];
|
Variable X = inputs[0];
|
||||||
|
@ -537,16 +533,16 @@ FunctionPtr CreateGRU(const ONNXIR::Node *node, const std::vector<Variable> &inp
|
||||||
Variable B;
|
Variable B;
|
||||||
std::vector<Variable> biasVariables = FindByNameHint(inputs, LSTMInputBiasNameHint);
|
std::vector<Variable> biasVariables = FindByNameHint(inputs, LSTMInputBiasNameHint);
|
||||||
if (numDirections == 1 && biasVariables.size() >= 1)
|
if (numDirections == 1 && biasVariables.size() >= 1)
|
||||||
B = biasVariables[0];
|
B = biasVariables[dir];
|
||||||
else if (numDirections == 2 && biasVariables.size() == 2)
|
else if (numDirections == 2 && biasVariables.size() == 2)
|
||||||
B = biasVariables[1];
|
B = biasVariables[dir];
|
||||||
|
|
||||||
Variable initHVariable = GetInitialStateVariable(inputs, numDirections, GRUInputInitialHNameHint, X.GetDataType());
|
Variable initHVariable = GetInitialStateVariable(inputs, numDirections, GRUInputInitialHNameHint, X.GetDataType());
|
||||||
|
|
||||||
// ONNX spec https://github.com/onnx/onnx/blob/master/docs/Operators.md#inputs-3---8
|
// ONNX spec https://github.com/onnx/onnx/blob/master/docs/Operators.md#inputs-3---8
|
||||||
// tells that weight has shape [num_directions, 4*hidden_size, input_size]
|
// tells that weight has shape [num_directions, 4*hidden_size, input_size]
|
||||||
// here in CNTK, there is no direction axis because CNTK treats bidirectional LSTM
|
// here in CNTK, there is no direction axis because CNTK treats bidirectional LSTM
|
||||||
// as two separate LSTM. Therefore we can divide the dimension of the first axis
|
// as two separate LSTM. Therefore we can divide the dimension of the first axis
|
||||||
// by 4 to get the hidden size.
|
// by 4 to get the hidden size.
|
||||||
int hiddenDim = W.Shape()[0] / GRUWeightDimensionHiddenMultiplier;
|
int hiddenDim = W.Shape()[0] / GRUWeightDimensionHiddenMultiplier;
|
||||||
|
|
||||||
|
@ -555,34 +551,34 @@ FunctionPtr CreateGRU(const ONNXIR::Node *node, const std::vector<Variable> &inp
|
||||||
// if it is bidirectional LSTM, the second one will be the backword one.
|
// if it is bidirectional LSTM, the second one will be the backword one.
|
||||||
bool go_backwards = direction == RNNDirectionReverse || (numDirections == 2 && dir == 1);
|
bool go_backwards = direction == RNNDirectionReverse || (numDirections == 2 && dir == 1);
|
||||||
|
|
||||||
std::function<FunctionPtr(const Variable&)> recurrenceHook;
|
std::function<FunctionPtr(const Variable &)> recurrenceHook;
|
||||||
if (go_backwards)
|
if (go_backwards)
|
||||||
recurrenceHook = [initHVariable](const Variable& x) { return FutureValue(x, initHVariable); };
|
recurrenceHook = [initHVariable](const Variable &x) { return FutureValue(x, initHVariable); };
|
||||||
else
|
else
|
||||||
recurrenceHook = [initHVariable](const Variable& x) { return PastValue(x, initHVariable); };
|
recurrenceHook = [initHVariable](const Variable &x) { return PastValue(x, initHVariable); };
|
||||||
|
|
||||||
outputH = GRUComponent(
|
outputH = GRUComponent(
|
||||||
X, { (size_t)hiddenDim }, fActivationOp, gActivationOp,
|
X, {(size_t) hiddenDim}, fActivationOp, gActivationOp,
|
||||||
recurrenceHook, (Constant &)W, (Constant &)R, (Constant &)H1, (Constant &)B);
|
recurrenceHook, (Constant &) W, (Constant &) R, (Constant &) H1, (Constant &) B);
|
||||||
outputHs.push_back(outputH);
|
outputHs.push_back(outputH);
|
||||||
}
|
}
|
||||||
if (outputHs.size() == 1)
|
if (outputHs.size() == 1)
|
||||||
return outputHs[0];
|
return outputHs[0];
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
std::vector<Variable> operands({ outputHs[0], outputHs[1] });
|
std::vector<Variable> operands({outputHs[0], outputHs[1]});
|
||||||
return Splice(operands, Axis(0), ToWString(node->Name()));
|
return Splice(operands, Axis(0), ToWString(node->Name()));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
FunctionPtr CreateRNN(const ONNXIR::Node *node, const std::vector<Variable> &inputs, const std::string &direction,
|
FunctionPtr CreateRNN(const ONNXIR::Node *node, const std::vector<Variable> &inputs, const std::string &direction,
|
||||||
const std::vector<string> &activations, const std::vector<float> &activation_alpha, const std::vector<float> &activation_beta)
|
const std::vector<string> &activations, const std::vector<float> &activation_alpha, const std::vector<float> &activation_beta)
|
||||||
{
|
{
|
||||||
int numDirections = direction == RNNDirectionBidirection ? 2 : 1;
|
int numDirections = direction == RNNDirectionBidirection ? 2 : 1;
|
||||||
std::vector<FunctionPtr> outputHs;
|
std::vector<FunctionPtr> outputHs;
|
||||||
for (int dir = 0; dir < numDirections; dir++)
|
for (int dir = 0; dir < numDirections; dir++)
|
||||||
{
|
{
|
||||||
std::function<FunctionPtr(const Variable&)> activationOp =
|
std::function<FunctionPtr(const Variable &)> activationOp =
|
||||||
GetRNNActivations(activations, activation_alpha, activation_beta, dir);
|
GetRNNActivations(activations, activation_alpha, activation_beta, dir);
|
||||||
|
|
||||||
// the first a few inputs are (in order): X, numDirections * W, numDirections * R, numDirections * H1
|
// the first a few inputs are (in order): X, numDirections * W, numDirections * R, numDirections * H1
|
||||||
|
@ -592,9 +588,9 @@ FunctionPtr CreateRNN(const ONNXIR::Node *node, const std::vector<Variable> &inp
|
||||||
Variable B;
|
Variable B;
|
||||||
std::vector<Variable> biasVariables = FindByNameHint(inputs, LSTMInputBiasNameHint);
|
std::vector<Variable> biasVariables = FindByNameHint(inputs, LSTMInputBiasNameHint);
|
||||||
if (numDirections == 1 && biasVariables.size() >= 1)
|
if (numDirections == 1 && biasVariables.size() >= 1)
|
||||||
B = biasVariables[0];
|
B = biasVariables[dir];
|
||||||
else if (numDirections == 2 && biasVariables.size() == 2)
|
else if (numDirections == 2 && biasVariables.size() == 2)
|
||||||
B = biasVariables[1];
|
B = biasVariables[dir];
|
||||||
|
|
||||||
Variable initHVariable = GetInitialStateVariable(inputs, numDirections, GRUInputInitialHNameHint, X.GetDataType());
|
Variable initHVariable = GetInitialStateVariable(inputs, numDirections, GRUInputInitialHNameHint, X.GetDataType());
|
||||||
|
|
||||||
|
@ -605,39 +601,39 @@ FunctionPtr CreateRNN(const ONNXIR::Node *node, const std::vector<Variable> &inp
|
||||||
// if it is bidirectional LSTM, the second one will be the backword one.
|
// if it is bidirectional LSTM, the second one will be the backword one.
|
||||||
bool go_backwards = direction == RNNDirectionReverse || (numDirections == 2 && dir == 1);
|
bool go_backwards = direction == RNNDirectionReverse || (numDirections == 2 && dir == 1);
|
||||||
|
|
||||||
std::function<FunctionPtr(const Variable&)> recurrenceHook;
|
std::function<FunctionPtr(const Variable &)> recurrenceHook;
|
||||||
if (go_backwards)
|
if (go_backwards)
|
||||||
recurrenceHook = [initHVariable](const Variable& x) { return FutureValue(x, initHVariable); };
|
recurrenceHook = [initHVariable](const Variable &x) { return FutureValue(x, initHVariable); };
|
||||||
else
|
else
|
||||||
recurrenceHook = [initHVariable](const Variable& x) { return PastValue(x, initHVariable); };
|
recurrenceHook = [initHVariable](const Variable &x) { return PastValue(x, initHVariable); };
|
||||||
|
|
||||||
outputH = RNNComponent(
|
outputH = RNNComponent(
|
||||||
X, { (size_t)hiddenDim }, activationOp,
|
X, {(size_t) hiddenDim}, activationOp,
|
||||||
recurrenceHook, (Constant &)W, (Constant &)R, (Constant &)B);
|
recurrenceHook, (Constant &) W, (Constant &) R, (Constant &) B);
|
||||||
outputHs.push_back(outputH);
|
outputHs.push_back(outputH);
|
||||||
}
|
}
|
||||||
if (outputHs.size() == 1)
|
if (outputHs.size() == 1)
|
||||||
return outputHs[0];
|
return outputHs[0];
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
std::vector<Variable> operands({ outputHs[0], outputHs[1] });
|
std::vector<Variable> operands({outputHs[0], outputHs[1]});
|
||||||
return Splice(operands, Axis(0), ToWString(node->Name()));
|
return Splice(operands, Axis(0), ToWString(node->Name()));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename FunctionType>
|
template <typename FunctionType>
|
||||||
void TraverseGraphWithPrePostActions(FunctionPtr cntkFunction, std::unordered_set<FunctionPtr>& visitedFunctions,
|
void TraverseGraphWithPrePostActions(FunctionPtr cntkFunction, std::unordered_set<FunctionPtr> &visitedFunctions,
|
||||||
FunctionType preFunctor, FunctionType postFunctor)
|
FunctionType preFunctor, FunctionType postFunctor)
|
||||||
{
|
{
|
||||||
visitedFunctions.insert(cntkFunction);
|
visitedFunctions.insert(cntkFunction);
|
||||||
preFunctor(cntkFunction);
|
preFunctor(cntkFunction);
|
||||||
|
|
||||||
std::vector<Variable> functionInputs = cntkFunction->Inputs();
|
std::vector<Variable> functionInputs = cntkFunction->Inputs();
|
||||||
for (const auto& input : functionInputs)
|
for (const auto &input : functionInputs)
|
||||||
{
|
{
|
||||||
if (input.IsOutput() && visitedFunctions.find(input.Owner()) == visitedFunctions.end())
|
if (input.IsOutput() && visitedFunctions.find(input.Owner()) == visitedFunctions.end())
|
||||||
{
|
{
|
||||||
const auto& inputFunction = input.Owner();
|
const auto &inputFunction = input.Owner();
|
||||||
TraverseGraphWithPrePostActions(inputFunction, visitedFunctions, preFunctor, postFunctor);
|
TraverseGraphWithPrePostActions(inputFunction, visitedFunctions, preFunctor, postFunctor);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -648,13 +644,11 @@ void TraverseGraphWithPrePostActions(FunctionPtr cntkFunction, std::unordered_se
|
||||||
bool IsSupportedRNNActivation(const std::wstring &cntkOpName)
|
bool IsSupportedRNNActivation(const std::wstring &cntkOpName)
|
||||||
{
|
{
|
||||||
static std::vector<std::wstring> supportedRNNActivations(
|
static std::vector<std::wstring> supportedRNNActivations(
|
||||||
{
|
{L"ReLU",
|
||||||
L"ReLU",
|
L"Tanh",
|
||||||
L"Tanh",
|
L"StableSigmoid"});
|
||||||
L"StableSigmoid"
|
|
||||||
});
|
|
||||||
return std::find(supportedRNNActivations.cbegin(), supportedRNNActivations.cend(), cntkOpName) !=
|
return std::find(supportedRNNActivations.cbegin(), supportedRNNActivations.cend(), cntkOpName) !=
|
||||||
supportedRNNActivations.cend();
|
supportedRNNActivations.cend();
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string FindActivation(const std::vector<FunctionPtr> &path, int nth)
|
std::string FindActivation(const std::vector<FunctionPtr> &path, int nth)
|
||||||
|
@ -684,7 +678,7 @@ std::string FindActivation(const std::vector<FunctionPtr> &path, int nth)
|
||||||
|
|
||||||
Variable GetPeepholeVariableFromOp(FunctionPtr peepholeOp)
|
Variable GetPeepholeVariableFromOp(FunctionPtr peepholeOp)
|
||||||
{
|
{
|
||||||
// peephole variable is that child of peepholeOp that is neither stabilizer nor place holder
|
// peephole variable is that child of peepholeOp that is neither stabilizer nor place holder
|
||||||
if (peepholeOp->OpName() != L"ElementTimes")
|
if (peepholeOp->OpName() != L"ElementTimes")
|
||||||
CNTK::LogicError("Peephole operation must be ElementTimes");
|
CNTK::LogicError("Peephole operation must be ElementTimes");
|
||||||
|
|
||||||
|
@ -721,7 +715,7 @@ FunctionPtr GetStabilizerOp(FunctionPtr parentOp)
|
||||||
{
|
{
|
||||||
if (parentOp->Inputs()[i].Owner() &&
|
if (parentOp->Inputs()[i].Owner() &&
|
||||||
(parentOp->Inputs()[i].Owner()->OpName() == L"Times" ||
|
(parentOp->Inputs()[i].Owner()->OpName() == L"Times" ||
|
||||||
parentOp->Inputs()[i].Owner()->OpName() == L"ElementTimes"))
|
parentOp->Inputs()[i].Owner()->OpName() == L"ElementTimes"))
|
||||||
{
|
{
|
||||||
timesOp = parentOp->Inputs()[i].Owner();
|
timesOp = parentOp->Inputs()[i].Owner();
|
||||||
break;
|
break;
|
||||||
|
@ -758,9 +752,9 @@ double GetScaler(Variable variable)
|
||||||
switch (variable.GetDataType())
|
switch (variable.GetDataType())
|
||||||
{
|
{
|
||||||
case DataType::Float:
|
case DataType::Float:
|
||||||
return *((float *)cpuV->DataBuffer<float>());
|
return *((float *) cpuV->DataBuffer<float>());
|
||||||
case DataType::Double:
|
case DataType::Double:
|
||||||
return *((double *)cpuV->DataBuffer<double>());
|
return *((double *) cpuV->DataBuffer<double>());
|
||||||
default:
|
default:
|
||||||
NOT_IMPLEMENTED;
|
NOT_IMPLEMENTED;
|
||||||
}
|
}
|
||||||
|
@ -773,8 +767,8 @@ double GetStabilizerCoef(const FunctionPtr stabilizerDhOp)
|
||||||
return (log(exp(alpha * steepness) + 1.0F) / steepness);
|
return (log(exp(alpha * steepness) + 1.0F) / steepness);
|
||||||
}
|
}
|
||||||
|
|
||||||
void GetDelayOps(const std::vector<Variable> &inputVars,
|
void GetDelayOps(const std::vector<Variable> &inputVars,
|
||||||
std::vector<FunctionPtr> &pastValueOps, std::vector<FunctionPtr> &futureValueOps)
|
std::vector<FunctionPtr> &pastValueOps, std::vector<FunctionPtr> &futureValueOps)
|
||||||
{
|
{
|
||||||
for (std::vector<Variable>::const_iterator it = inputVars.cbegin(); it != inputVars.cend(); ++it)
|
for (std::vector<Variable>::const_iterator it = inputVars.cbegin(); it != inputVars.cend(); ++it)
|
||||||
{
|
{
|
||||||
|
@ -785,25 +779,25 @@ void GetDelayOps(const std::vector<Variable> &inputVars,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// A CNTK LSTM op is created with stacked matmul followed by a slice op for 4 gates.
|
// A CNTK LSTM op is created with stacked matmul followed by a slice op for 4 gates.
|
||||||
// Slice order tells which graph path is for which gate. This method is
|
// Slice order tells which graph path is for which gate. This method is
|
||||||
// to traverse the graph to find the 4 paths along the 4 gates. It helps to
|
// to traverse the graph to find the 4 paths along the 4 gates. It helps to
|
||||||
// subsequently find needed attributes in order to build an ONNX LSTM op.
|
// subsequently find needed attributes in order to build an ONNX LSTM op.
|
||||||
void TraceLSTMPathes(const FunctionPtr& src,
|
void TraceLSTMPathes(const FunctionPtr &src,
|
||||||
string &f_activation,
|
string &f_activation,
|
||||||
string &g_activation,
|
string &g_activation,
|
||||||
string &h_activation,
|
string &h_activation,
|
||||||
RNNDirection &direction,
|
RNNDirection &direction,
|
||||||
Variable &initStateH,
|
Variable &initStateH,
|
||||||
Variable &initStateC,
|
Variable &initStateC,
|
||||||
Variable &peepholeCi,
|
Variable &peepholeCi,
|
||||||
Variable &peepholeCo,
|
Variable &peepholeCo,
|
||||||
Variable &peepholeCf,
|
Variable &peepholeCf,
|
||||||
double &stabilizer_dh,
|
double &stabilizer_dh,
|
||||||
double &stabilizer_dc,
|
double &stabilizer_dc,
|
||||||
double &stabilizer_c)
|
double &stabilizer_c)
|
||||||
{
|
{
|
||||||
// src has to be an LSTM node.
|
// src has to be an LSTM node.
|
||||||
std::vector<Variable> inputVars = src->Inputs();
|
std::vector<Variable> inputVars = src->Inputs();
|
||||||
std::vector<FunctionPtr> pastValueOps, futureValueOps;
|
std::vector<FunctionPtr> pastValueOps, futureValueOps;
|
||||||
GetDelayOps(inputVars, pastValueOps, futureValueOps);
|
GetDelayOps(inputVars, pastValueOps, futureValueOps);
|
||||||
|
@@ -834,35 +828,34 @@ void TraceLSTMPathes(const FunctionPtr& src,
         visitedFunctions.insert(it->Owner());
     }
 
     // First find the peephole op node.
     // see CNTK\bindings\python\cntk\layers\blocks.py node references.
     std::vector<std::vector<FunctionPtr>> pathesBitBftJoint;
     {
         std::vector<FunctionPtr> currentPeepholePath;
 
         // make a copy of traverse boundary
         std::unordered_set<FunctionPtr> peepHoleVisitedFunctions = visitedFunctions;
 
         // traverse to find the joint of bit and bft
         TraverseGraphWithPrePostActions(src->BlockRoot(),
             peepHoleVisitedFunctions,
-            (std::function<void(const FunctionPtr&)>)[
-            &peepHoleVisitedFunctions, &pathesBitBftJoint, &currentPeepholePath](const FunctionPtr& function)
-        {
+            (std::function<void(const FunctionPtr &)>) [
+                &peepHoleVisitedFunctions, &pathesBitBftJoint, &currentPeepholePath
+        ](const FunctionPtr &function) {
             currentPeepholePath.push_back(function);
             if (function->OpName() == L"Plus" &&
                 function->Inputs()[0].Owner() && function->Inputs()[0].Owner()->OpName() == L"ElementTimes" &&
                 function->Inputs()[1].Owner() && function->Inputs()[1].Owner()->OpName() == L"ElementTimes")
             {
                 pathesBitBftJoint.push_back(currentPeepholePath);
                 peepHoleVisitedFunctions.erase(std::find_if(peepHoleVisitedFunctions.begin(), peepHoleVisitedFunctions.end(),
-                    [function](FunctionPtr f) {return function == f; }));
+                    [function](FunctionPtr f) { return function == f; }));
             }
         },
-            (std::function<void(const FunctionPtr&)>)[&currentPeepholePath](const FunctionPtr& function)
-        {
-            currentPeepholePath.pop_back();
-        });
+            (std::function<void(const FunctionPtr &)>) [&currentPeepholePath](const FunctionPtr &function) {
+                currentPeepholePath.pop_back();
+            });
     }
 
     FunctionPtr peepholeCoOp;
@@ -871,9 +864,9 @@ void TraceLSTMPathes(const FunctionPtr& src,
     {
         // the last ElementTimes op is the peephole op
         std::vector<FunctionPtr> &peepholePath = *std::max_element(pathesBitBftJoint.begin(), pathesBitBftJoint.end(),
-            [](std::vector<FunctionPtr> &p1, std::vector<FunctionPtr> &p2) {return p1.size() < p2.size(); });
+            [](std::vector<FunctionPtr> &p1, std::vector<FunctionPtr> &p2) { return p1.size() < p2.size(); });
         std::vector<FunctionPtr>::reverse_iterator itPeepholeOp = std::find_if(peepholePath.rbegin(), peepholePath.rend(),
-            [](FunctionPtr function) {return function->OpName() == L"ElementTimes"; });
+            [](FunctionPtr function) { return function->OpName() == L"ElementTimes"; });
         if (itPeepholeOp == peepholePath.rend())
         {
             CNTK::LogicError("Cannot find peephole op from a LSTM graph");
@@ -897,39 +890,36 @@ void TraceLSTMPathes(const FunctionPtr& src,
     visitedFunctions.insert(peepholeCoOp);
 
     TraverseGraphWithPrePostActions(src->BlockRoot(),
         visitedFunctions,
-        (std::function<void(const FunctionPtr&)>)[&pathesToPlusSlice, &currentPath](const FunctionPtr& function)
-    {
-        currentPath.push_back(function);
-        if (function->OpName() == L"Slice")
-        {
-            FunctionPtr functionSource = function->Inputs()[0].Owner();
-            if (functionSource->OpName() == L"Plus")
-            {
-                pathesToPlusSlice.push_back(currentPath);
-            }
-        }
-    },
-        (std::function<void(const FunctionPtr&)>)[&currentPath](const FunctionPtr& function)
-    {
-        currentPath.pop_back();
-    });
+        (std::function<void(const FunctionPtr &)>) [&pathesToPlusSlice, &currentPath ](const FunctionPtr &function) {
+            currentPath.push_back(function);
+            if (function->OpName() == L"Slice")
+            {
+                FunctionPtr functionSource = function->Inputs()[0].Owner();
+                if (functionSource->OpName() == L"Plus")
+                {
+                    pathesToPlusSlice.push_back(currentPath);
+                }
+            }
+        },
+        (std::function<void(const FunctionPtr &)>) [&currentPath](const FunctionPtr &function) {
+            currentPath.pop_back();
+        });
 
     // 4 gates of LSTM shall be traced.
     if (pathesToPlusSlice.size() != 4)
     {
         CNTK::LogicError("pathesToPlusSlice.size() != 4");
     }
 
     std::sort(pathesToPlusSlice.begin(), pathesToPlusSlice.end(),
-        [](const std::vector<FunctionPtr>& path1, const std::vector<FunctionPtr>& path2)
-    {
-        FunctionPtr slice1 = *path1.rbegin();
-        FunctionPtr slice2 = *path2.rbegin();
-        int beginIndex1 = slice1->Attributes()[PrimitiveFunction::AttributeNameBeginIndex].Value<int>();
-        int beginIndex2 = slice2->Attributes()[PrimitiveFunction::AttributeNameBeginIndex].Value<int>();
-        return beginIndex1 < beginIndex2;
-    });
+        [](const std::vector<FunctionPtr> &path1, const std::vector<FunctionPtr> &path2) {
+            FunctionPtr slice1 = *path1.rbegin();
+            FunctionPtr slice2 = *path2.rbegin();
+            int beginIndex1 = slice1->Attributes()[PrimitiveFunction::AttributeNameBeginIndex].Value<int>();
+            int beginIndex2 = slice2->Attributes()[PrimitiveFunction::AttributeNameBeginIndex].Value<int>();
+            return beginIndex1 < beginIndex2;
+        });
 
     // This code is heavily coupled with CNTK python layer code:
     // https://github.com/Microsoft/CNTK/blob/44c626a483edeaff97b4f7a46847b055a1d483aa/bindings/python/cntk/layers/blocks.py#L261
@@ -956,16 +946,13 @@ void TraceLSTMPathes(const FunctionPtr& src,
     {
         // Ci merges to ht_it_path via element-wise time
        FunctionPtr plusOp = ht_it_path[ht_it_path.size() - 2];
-        FunctionPtr peepholeOp = plusOp->Inputs()[0].Owner()->OpName() != L"Slice" ?
-            plusOp->Inputs()[0].Owner() : plusOp->Inputs()[1].Owner();
+        FunctionPtr peepholeOp = plusOp->Inputs()[0].Owner()->OpName() != L"Slice" ? plusOp->Inputs()[0].Owner() : plusOp->Inputs()[1].Owner();
         peepholeCi = GetPeepholeVariableFromOp(peepholeOp);
-
     }
     {
         // Cf merges to ht_ft_path via element-wise time
         FunctionPtr plusOp = ht_ft_path[ht_ft_path.size() - 2];
-        FunctionPtr peepholeOp = plusOp->Inputs()[0].Owner()->OpName() != L"Slice" ?
-            plusOp->Inputs()[0].Owner() : plusOp->Inputs()[1].Owner();
+        FunctionPtr peepholeOp = plusOp->Inputs()[0].Owner()->OpName() != L"Slice" ? plusOp->Inputs()[0].Owner() : plusOp->Inputs()[1].Owner();
         peepholeCf = GetPeepholeVariableFromOp(peepholeOp);
 
         FunctionPtr stabilizerDcOp = GetStabilizerOp(peepholeOp);
@@ -998,8 +985,8 @@ FunctionPtr TraverseGraphFindFirstRNNOp(FunctionPtr src)
     return nullptr;
 }
 
-void TraceGRUPathes(const FunctionPtr& src, string &f_activation, string &g_activation,
+void TraceGRUPathes(const FunctionPtr &src, string &f_activation, string &g_activation,
     RNNDirection &direction, Variable &initStateH)
 {
     std::vector<Variable> inputVars = src->Inputs();
     std::vector<FunctionPtr> pastValueOps, futureValueOps;
@@ -1037,8 +1024,8 @@ void TraceGRUPathes(const FunctionPtr& src, string &f_activation, string &g_acti
     g_activation = MapActivationNameCNTKToONNX(ToString(gActivation->OpName()));
 }
 
-void TraceRNNPathes(const FunctionPtr& src, string &activation,
+void TraceRNNPathes(const FunctionPtr &src, string &activation,
     RNNDirection &direction, Variable &initStateH)
 {
     std::vector<Variable> inputVars = src->Inputs();
     std::vector<FunctionPtr> pastValueOps, futureValueOps;
@@ -1083,10 +1070,10 @@ std::vector<FunctionPtr> GetRNNBlocksFromSingleOrBidirectionalRNN(const Function
         CNTK::LogicError("An %s op should start with an GRU op (single direction) or a Splice op (bidirectional).", RNNStepOpName.c_str());
     }
 
     // For single direction RNN, rnns.size() == 1. For bidirectional RNN, rnns.size() == 2.
     // It is an error otherwise.
     if (rnns.size() == 0 || rnns.size() > 2 ||
-        std::any_of(rnns.cbegin(), rnns.cend(), [RNNStepOpName](const FunctionPtr &f) {return ToString(f->OpName()) != RNNStepOpName; }))
+        std::any_of(rnns.cbegin(), rnns.cend(), [RNNStepOpName](const FunctionPtr &f) { return ToString(f->OpName()) != RNNStepOpName; }))
     {
         CNTK::LogicError("Invalid number of RNN ops to construct an ONNX %s node.", RNNStepOpName.c_str());
     }
@@ -716,10 +716,12 @@ def test_Neg(tmpdir):
     verify_no_input(model, tmpdir, 'Neg_0')
 
 #OptimizedRNNStack
-OPTIM_RNN_STACK_CONFIGS = ((True, 2, 2, 3), (True, 2, 4, 8), (True, 2, 6, 8),
-    (True, 4, 2, 3), (False, 2, 2, 3))
-@pytest.mark.parametrize("bidirectional, num_layers, input_size, hidden_size", OPTIM_RNN_STACK_CONFIGS)
-def test_OptimizedRNNStack(bidirectional, num_layers, input_size, hidden_size, tmpdir, device_id):
+OPTIM_RNN_STACK_CONFIGS = ((True, 2, 2, 3, 'lstm'), (True, 2, 4, 8, 'lstm'), (True, 2, 6, 8, 'lstm'),
+    (True, 4, 2, 3, 'lstm'), (False, 2, 2, 3, 'lstm'),
+    (True, 1, 2, 3, 'rnnReLU'), (True, 4, 4, 8, 'rnnReLU'), (False, 2, 6, 8, 'rnnReLU'),
+    (True, 4, 2, 3, 'rnnTanh'), (False, 2, 2, 3, 'rnnTanh'), (True, 1, 2, 3, 'rnnTanh'))
+@pytest.mark.parametrize("bidirectional, num_layers, input_size, hidden_size, recurrent_op", OPTIM_RNN_STACK_CONFIGS)
+def test_OptimizedRNNStack(bidirectional, num_layers, input_size, hidden_size, recurrent_op, tmpdir, device_id):
     if device_id == -1:
         pytest.skip('Test only runs on GPU')
     dev = cntk_device(device_id)
@@ -728,7 +730,7 @@ def test_OptimizedRNNStack(bidirectional, num_layers, input_size, hidden_size, t
     W = C.parameter((C.InferredDimension, input_size), constant_initializer(0.1), device=dev)
     x = C.sequence.input_variable(shape=(input_size,))
     s = np.asarray(np.random.uniform(-1, 1, (5,input_size)), dtype=np.float32)
-    f = C.optimized_rnnstack(x, W, hidden_size, num_layers, bidirectional=bidirectional, name='MyRnnStack')
+    f = C.optimized_rnnstack(x, W, hidden_size, num_layers, bidirectional=bidirectional, recurrent_op=recurrent_op, name='MyRnnStack')
     f.parameters[0].value = np.reshape(np.arange(np.prod(f.parameters[0].value.shape), dtype=np.float32), f.parameters[0].value.shape)
     verify_one_input(f, s, tmpdir, model_filename)
 
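Usage note (illustrative sketch, not part of the commit's diff): the new parametrized cases above exercise exporting an OptimizedRNNStack built with a non-LSTM recurrent_op to ONNX. A minimal standalone version of one such case is sketched below; the output file name and initializer value are assumptions for illustration, and optimized_rnnstack requires a GPU-enabled (cuDNN) CNTK build.

    import cntk as C

    input_size, hidden_size, num_layers = 2, 3, 1
    # Packed cuDNN weight parameter; the leading dimension is inferred from the op.
    W = C.parameter((C.InferredDimension, input_size), init=0.1)
    x = C.sequence.input_variable(shape=(input_size,))
    # 'rnnReLU' and 'rnnTanh' are now exportable alongside the existing 'lstm' path.
    f = C.optimized_rnnstack(x, W, hidden_size, num_layers,
                             bidirectional=True, recurrent_op='rnnReLU', name='MyRnnStack')
    f.save('optimized_rnn_relu.onnx', format=C.ModelFormat.ONNX)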