local changes for alignment node for attention-based mechanism

Parent: 5a1175de68
Commit: 81f9ac56ee

CNTK.sln

@@ -241,6 +241,7 @@ Global
{014DA766-B37B-4581-BC26-963EA5507931} = {33EBFE78-A1A8-4961-8938-92A271941F94}
{D667AF32-028A-4A5D-BE19-F46776F0F6B2} = {33EBFE78-A1A8-4961-8938-92A271941F94}
{3ED0465D-23E7-4855-9694-F788717B6533} = {39E42C4B-A078-4CA4-9D92-B883D8129601}
{065AF55D-AF02-448B-BFCD-52619FDA4BD0} = {39E42C4B-A078-4CA4-9D92-B883D8129601}
{98D2C32B-0C1F-4E19-A626-65F7BA4600CF} = {065AF55D-AF02-448B-BFCD-52619FDA4BD0}
{EA67F51F-1FE8-462D-9F3E-01161685AD59} = {065AF55D-AF02-448B-BFCD-52619FDA4BD0}
{DE1A06BA-EC5C-4E0D-BCA8-3EA555310C58} = {065AF55D-AF02-448B-BFCD-52619FDA4BD0}

@@ -78,7 +78,7 @@ public:
virtual bool GetData(const std::wstring&, size_t, void*, size_t&, size_t) { NOT_IMPLEMENTED; };
virtual bool DataEnd(EndDataType) { NOT_IMPLEMENTED; };
virtual void SetSentenceSegBatch(Matrix<ElemType>&, Matrix<ElemType>&) { NOT_IMPLEMENTED; };
virtual void SetRandomSeed(int) { NOT_IMPLEMENTED; };
virtual void SetRandomSeed(unsigned seed = 0) { m_seed = seed; };
virtual bool GetProposalObs(std::map<std::wstring, Matrix<ElemType>*>&, const size_t, vector<size_t>&) { return false; }
virtual void InitProposals(std::map<std::wstring, Matrix<ElemType>*>&) { }
virtual bool CanReadFor(wstring /* nodeName */) {

@@ -543,8 +543,8 @@ void SequenceReader<ElemType>::Init(const ConfigParameters& readerConfig)
std::wstring m_file = readerConfig("file");
if (m_traceLevel > 0)
{
//fprintf(stderr, "reading sequence file %ls\n", m_file.c_str());
std::wcerr << "reading sequence file" << m_file.c_str() << endl;
fprintf(stderr, "reading sequence file %ls\n", m_file.c_str());
//std::wcerr << "reading sequence file" << m_file.c_str() << endl;
}

const LabelInfo& labelIn = m_labelInfo[labelInfoIn];

@@ -1503,8 +1503,8 @@ void BatchSequenceReader<ElemType>::Init(const ConfigParameters& readerConfig)
std::wstring m_file = readerConfig("file");
if (m_traceLevel > 0)
{
//fwprintf(stderr, L"reading sequence file %s\n", m_file.c_str());
std::wcerr << "reading sequence file " << m_file.c_str() << endl;
fwprintf(stderr, L"reading sequence file %s\n", m_file.c_str());
//std::wcerr << "reading sequence file " << m_file.c_str() << endl;
}

const LabelInfo& labelIn = m_labelInfo[labelInfoIn];

@@ -1986,8 +1986,8 @@ bool BatchSequenceReader<ElemType>::DataEnd(EndDataType endDataType)
/// notice that indices are defined as follows: [beginning ending_index) of the class
/// i.e., the ending_index is 1 plus the true ending index
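/// added example, not in the original: a class covering word ids {5, 6, 7, 8} is stored as [5, 9)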
template<class ElemType>
void BatchSequenceReader<ElemType>::GetLabelOutput(std::map<std::wstring,
Matrix<ElemType>*>& matrices,
void BatchSequenceReader<ElemType>::GetLabelOutput(std::map < std::wstring,
Matrix<ElemType>* > & matrices,
size_t m_mbStartSample, size_t actualmbsize)
{
size_t j = 0;

@@ -2007,51 +2007,47 @@ void BatchSequenceReader<ElemType>::GetLabelOutput(std::map<std::wstring,
labels->TransferFromDeviceToDevice(curDevId, CPUDEVICE, true, false, false);

if (labels->GetCurrentMatrixLocation() == CPU)
for (size_t jSample = m_mbStartSample; j < actualmbsize; ++j, ++jSample)
{
// pick the right sample with randomization if desired
size_t jRand = jSample;
int wrd = m_labelIdData[jRand];
labels->SetValue(0, j, (ElemType)wrd);
SetSentenceEnd(wrd, j, actualmbsize);

if (readerMode == ReaderMode::NCE)
for (size_t jSample = m_mbStartSample; j < actualmbsize; ++j, ++jSample)
{
labels->SetValue(1, j, (ElemType)m.logprob(wrd));
for (size_t noiseid = 0; noiseid < this->noise_sample_size; noiseid++)
// pick the right sample with randomization if desired
size_t jRand = jSample;
int wrd = m_labelIdData[jRand];
labels->SetValue(0, j, (ElemType)wrd);
SetSentenceEnd(wrd, j, actualmbsize);

if (readerMode == ReaderMode::NCE)
{
int wid = m.sample();
labels->SetValue(2 * (noiseid + 1), j, (ElemType)wid);
labels->SetValue(2 * (noiseid + 1) + 1, j, -(ElemType)m.logprob(wid));
}
}
else if (readerMode == ReaderMode::Class)
{
int clsidx = idx4class[wrd];
if (class_size > 0){

labels->SetValue(1, j, (ElemType)clsidx);

/// save the [beginning ending_index) of the class
size_t lft = (size_t) (*m_classInfoLocal)(0, clsidx);
size_t rgt = (size_t) (*m_classInfoLocal)(1, clsidx);
if (wrd < lft || lft > rgt || wrd >= rgt)
labels->SetValue(1, j, (ElemType)m.logprob(wrd));
for (size_t noiseid = 0; noiseid < this->noise_sample_size; noiseid++)
{
LogicError("LMSequenceReader::GetLabelOutput word %d should be at least equal to or larger than its class's left index %d; right index %d of its class should be larger or equal to left index %d of its class; word index %d should be smaller than its class's right index %d.\n", wrd, lft, rgt, lft, wrd, rgt);
int wid = m.sample();
labels->SetValue(2 * (noiseid + 1), j, (ElemType)wid);
labels->SetValue(2 * (noiseid + 1) + 1, j, -(ElemType)m.logprob(wid));
}
}
else if (readerMode == ReaderMode::Class)
{
int clsidx = idx4class[wrd];
if (class_size > 0){

labels->SetValue(1, j, (ElemType)clsidx);

/// save the [beginning ending_index) of the class
size_t lft = (size_t)(*m_classInfoLocal)(0, clsidx);
size_t rgt = (size_t)(*m_classInfoLocal)(1, clsidx);
if (wrd < lft || lft > rgt || wrd >= rgt)
{
LogicError("LMSequenceReader::GetLabelOutput word %d should be at least equal to or larger than its class's left index %d; right index %d of its class should be larger or equal to left index %d of its class; word index %d should be smaller than its class's right index %d.\n", wrd, lft, rgt, lft, wrd, rgt);
}
labels->SetValue(2, j, (*m_classInfoLocal)(0, clsidx)); /// beginning index of the class
labels->SetValue(3, j, (*m_classInfoLocal)(1, clsidx)); /// end index of the class
}
labels->SetValue(2, j, (*m_classInfoLocal)(0, clsidx)); /// beginning index of the class
labels->SetValue(3, j, (*m_classInfoLocal)(1, clsidx)); /// end index of the class
}
}
}
else // GPU
{
RuntimeError("GetLabelOutput::should use CPU for labels ");
}
if (curDevId != CPUDEVICE)
{
labels->TransferFromDeviceToDevice(CPUDEVICE, curDevId, true, false, false);
}
}

template<class ElemType>

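For orientation, a minimal, self-contained sketch of the label layout that the ReaderMode::Class branch above writes (the array names here are hypothetical, not from the diff): for each minibatch column j, row 0 holds the word id, row 1 its class id, and rows 2 and 3 the half-open word-id range [begin, end) of that class.

#include <cstdio>
#include <vector>

int main()
{
    // Per-class half-open word-id ranges [classBegin[c], classEnd[c]),
    // playing the role of m_classInfoLocal rows 0 and 1.
    const int classBegin[] = { 0, 3, 5 };
    const int classEnd[]   = { 3, 5, 9 };

    const size_t mbSize = 3;
    const int words[]   = { 7, 4, 1 };   // word id for each column
    const int classes[] = { 2, 1, 0 };   // class of each word (idx4class analogue)

    // 4 x mbSize label matrix, stored column-major like Matrix<ElemType>.
    std::vector<float> labels(4 * mbSize);
    for (size_t j = 0; j < mbSize; j++)
    {
        const int wrd = words[j], cls = classes[j];
        labels[0 + 4 * j] = (float)wrd;             // row 0: word id
        labels[1 + 4 * j] = (float)cls;             // row 1: class id
        labels[2 + 4 * j] = (float)classBegin[cls]; // row 2: class begin (inclusive)
        labels[3 + 4 * j] = (float)classEnd[cls];   // row 3: class end (exclusive)
    }

    for (size_t j = 0; j < mbSize; j++)
        printf("col %zu: word %g, class %g, range [%g, %g)\n",
               j, labels[4 * j], labels[4 * j + 1], labels[4 * j + 2], labels[4 * j + 3]);
    return 0;
}
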
@@ -252,7 +252,6 @@ public:
virtual bool GetData(const std::wstring& sectionName, size_t numRecords, void* data, size_t& dataBufferSize, size_t recordStart=0);

virtual bool DataEnd(EndDataType endDataType);
void SetRandomSeed(int) { NOT_IMPLEMENTED; }
};

template<class ElemType>

@@ -524,7 +524,7 @@ public:
}

virtual void LoadFromFile(const std::wstring& fileName, const FileOptions fileFormat = FileOptions::fileOptionsBinary,
const bool bAllowNoCriterionNode = false)
const bool bAllowNoCriterionNode = false, ComputationNetwork<ElemType>* anotherNetwork=nullptr)
{
ClearNet();

@@ -574,7 +574,7 @@ public:
std::vector<ComputationNodePtr> childrenNodes;
childrenNodes.resize(numChildren);
for (int j = 0; j < numChildren; j++)
childrenNodes[j] = GetNodeFromName(childrenNames[j]);
childrenNodes[j] = GetNodeFromName(childrenNames[j], anotherNetwork);

if (nodePtr->OperationName() == RowStackNode<ElemType>::TypeName()) //allow for variable input nodes
nodePtr->AttachInputs(childrenNodes);

@@ -1074,6 +1074,10 @@ public:
newNode = new TimeReverseNode<ElemType>(fstream, modelVersion, m_deviceId, nodeName);
else if (nodeType == ParallelNode<ElemType>::TypeName())
newNode = new ParallelNode<ElemType>(fstream, modelVersion, m_deviceId, nodeName);
else if (nodeType == AlignmentNode<ElemType>::TypeName())
newNode = new AlignmentNode<ElemType>(fstream, modelVersion, m_deviceId, nodeName);
else if (nodeType == PairNetworkNode<ElemType>::TypeName())
newNode = new PairNetworkNode<ElemType>(fstream, modelVersion, m_deviceId, nodeName);
else
{
fprintf(stderr, "Error creating new ComputationNode of type %ls, with name %ls\n", nodeType.c_str(), nodeName.c_str());

@@ -1106,6 +1110,14 @@ public:
return newNode;
}

ComputationNodePtr PairNetwork(const ComputationNodePtr & a, const std::wstring nodeName = L"")
{
ComputationNodePtr newNode(new PairNetworkNode<ElemType>(m_deviceId, nodeName));
newNode->AttachInputs(a);
AddNodeToNet(newNode);
return newNode;
}

ComputationNodePtr CreateSparseInputNode(const std::wstring inputName, const size_t rows, const size_t cols)
{
ComputationNodePtr newNode(new SparseInputValue<ElemType>(rows, cols, m_deviceId, inputName));

@@ -1128,7 +1140,14 @@ public:
return newNode;
}

ComputationNodePtr CreateConvolutionNode(const std::wstring nodeName,
ComputationNodePtr CreatePairNetworkNode(const std::wstring inputName, const size_t rows, const size_t cols)
{
ComputationNodePtr newNode(new PairNetworkNode<ElemType>(m_deviceId, rows, cols, inputName)); // argument order fixed to match the (deviceId, rows, cols, name) constructor
AddNodeToNet(newNode);
return newNode;
}

ComputationNodePtr CreateConvolutionNode(const std::wstring nodeName,
const size_t kernelWidth, const size_t kernelHeight, const size_t outputChannels,
const size_t horizontalSubsample, const size_t verticalSubsample,
const bool zeroPadding = false, const size_t maxTempMemSizeInSamples = 0)

@@ -1247,6 +1266,10 @@ public:
newNode = new ParallelNode<ElemType>(m_deviceId, nodeName);
else if (nodeType == RowStackNode<ElemType>::TypeName())
newNode = new RowStackNode<ElemType>(m_deviceId, nodeName);
else if (nodeType == AlignmentNode<ElemType>::TypeName())
newNode = new AlignmentNode<ElemType>(m_deviceId, nodeName);
else if (nodeType == PairNetworkNode<ElemType>::TypeName())
newNode = new PairNetworkNode<ElemType>(m_deviceId, nodeName);
else
{
fprintf(stderr, "Error creating new ComputationNode of type %ls, with name %ls\n", nodeType.c_str(), nodeName.c_str());

@@ -1653,19 +1676,29 @@ public:
return newNode;
}

ComputationNodePtr Alignment(const ComputationNodePtr a, const ComputationNodePtr b, const ComputationNodePtr c, const std::wstring nodeName = L"")
{
ComputationNodePtr newNode(new AlignmentNode<ElemType>(m_deviceId, nodeName));
newNode->AttachInputs(a, b, c);
AddNodeToNet(newNode);
return newNode;
}

bool NodeNameExist(const std::wstring& name) const
{
auto iter = m_nameToNodeMap.find(name);
return (iter != m_nameToNodeMap.end());
}

ComputationNodePtr GetNodeFromName(const std::wstring& name) const
ComputationNodePtr GetNodeFromName(const std::wstring& name, ComputationNetwork<ElemType>* anotherNetwork = nullptr) const
{
auto iter = m_nameToNodeMap.find(name);
if (iter != m_nameToNodeMap.end()) //found
return iter->second;
else //should never try to get a node from a nonexistent name
throw std::runtime_error("GetNodeFromName: Node name does not exist.");
if (anotherNetwork != nullptr)
return anotherNetwork->GetNodeFromName(name);

RuntimeError("GetNodeFromName: Node name %ls does not exist.", name.c_str());
}

// GetNodesFromName - Get all the nodes from a name that may match a wildcard '*' pattern

@@ -867,7 +867,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
return false;
}

void EnumerateNodesForEval(std::unordered_set<ComputationNodePtr>& visited, std::list<ComputationNodePtr>& result,
virtual void EnumerateNodesForEval(std::unordered_set<ComputationNodePtr>& visited, std::list<ComputationNodePtr>& result,
std::vector<ComputationNodePtr>& sourceRecurrentNodePtr, const bool bFromDelayNode)
{
if (visited.find(this) == visited.end()) //not visited

@@ -15,8 +15,8 @@ namespace Microsoft { namespace MSR { namespace CNTK {
{
public:
virtual ComputationNetwork<ElemType>& LoadNetworkFromFile(const std::wstring& modelFileName, bool forceLoad = true,
bool bAllowNoCriterion = false) = 0;
virtual ComputationNetwork<ElemType>& BuildNetworkFromDescription() = 0;
bool bAllowNoCriterion = false, ComputationNetwork<ElemType>* = nullptr) = 0;
virtual ComputationNetwork<ElemType>& BuildNetworkFromDescription(ComputationNetwork<ElemType>* = nullptr) = 0;
virtual ~IComputationNetBuilder() {};
};
}}}

@@ -586,4 +586,157 @@ namespace Microsoft { namespace MSR { namespace CNTK {
template class LookupTableNode<float>;
template class LookupTableNode<double>;

/**
pair this node to a node in another network
*/
template<class ElemType>
class PairNetworkNode : public ComputationNode<ElemType>
{
UsingComputationNodeMembers;
public:
PairNetworkNode(const DEVICEID_TYPE deviceId = AUTOPLACEMATRIX, const std::wstring name = L"")
: ComputationNode<ElemType>(deviceId)
{
m_nodeName = (name == L"" ? CreateUniqNodeName() : name);
m_deviceId = deviceId;
MoveMatricesToDevice(deviceId);
m_reqMultiSeqHandling = true;
m_functionValues.Resize(1, 1);
InitRecurrentNode();
}

PairNetworkNode(File& fstream, const size_t modelVersion, const DEVICEID_TYPE deviceId = AUTOPLACEMATRIX, const std::wstring name = L"")
: ComputationNode<ElemType>(deviceId)
{
m_nodeName = (name == L"" ? CreateUniqNodeName() : name);

m_functionValues.Resize(1, 1);
m_reqMultiSeqHandling = true;

LoadFromFile(fstream, modelVersion, deviceId);
}

PairNetworkNode(const DEVICEID_TYPE deviceId, size_t row_size, size_t col_size, const std::wstring name = L"") : ComputationNode<ElemType>(deviceId)
{
m_nodeName = (name == L"" ? CreateUniqNodeName() : name);
m_deviceId = deviceId;
MoveMatricesToDevice(deviceId);
m_reqMultiSeqHandling = true;

m_functionValues.Resize(row_size, col_size);

m_gradientValues.Resize(row_size, col_size);
m_gradientValues.SetValue(0.0f);

InitRecurrentNode();
}

virtual const std::wstring OperationName() const { return TypeName(); }

/// to-do: need to change to the new way of resetting state
virtual void ComputeInputPartial(const size_t inputIndex)
{
if (inputIndex > 0)
throw std::invalid_argument("PairNetwork operation only takes one input.");

Matrix<ElemType>::ScaleAndAdd(1.0, GradientValues(), Inputs(inputIndex)->GradientValues());
}

virtual void ComputeInputPartial(const size_t inputIndex, const size_t timeIdxInSeq)
{
if (inputIndex > 0)
throw std::invalid_argument("PairNetwork operation only takes one input.");
assert(m_functionValues.GetNumRows() == GradientValues().GetNumRows()); // original used m_functionValues.GetNumRows() for loop dimension
assert(m_sentenceSeg != nullptr);
assert(m_existsSentenceBeginOrNoLabels != nullptr);

Matrix<ElemType> mTmp = Inputs(inputIndex)->GradientValues().ColumnSlice(timeIdxInSeq * m_samplesInRecurrentStep, m_samplesInRecurrentStep);
Matrix<ElemType>::ScaleAndAdd(1.0, GradientValues().ColumnSlice(timeIdxInSeq * m_samplesInRecurrentStep, m_samplesInRecurrentStep),
mTmp);
}

virtual void EvaluateThisNode()
{
m_functionValues.SetValue(Inputs(0)->FunctionValues());
}

virtual void EvaluateThisNode(const size_t timeIdxInSeq)
{
Matrix<ElemType> mTmp = FunctionValues().ColumnSlice(timeIdxInSeq * m_samplesInRecurrentStep, m_samplesInRecurrentStep);
mTmp.SetValue(Inputs(0)->FunctionValues().ColumnSlice(timeIdxInSeq * m_samplesInRecurrentStep, m_samplesInRecurrentStep));
}

virtual void Validate()
{
PrintSelfBeforeValidation(true);

if (m_children.size() != 1)
throw std::logic_error("PairNetwork operation should have one input.");

if (!(Inputs(0) == nullptr))
{
size_t rows0 = Inputs(0)->FunctionValues().GetNumRows(), cols0 = Inputs(0)->FunctionValues().GetNumCols();

if (rows0 > 0 && cols0 > 0) FunctionValues().Resize(rows0, cols0);
}
CopyImageSizeFromInputs();
}

virtual void AttachInputs(const ComputationNodePtr inputNode)
{
m_children.resize(1);
m_children[0] = inputNode;
}

void EnumerateNodesForEval(std::unordered_set<ComputationNodePtr>& visited, std::list<ComputationNodePtr>& result,
std::vector<ComputationNodePtr>& sourceRecurrentNodePtr, const bool bFromDelayNode)
{
if (visited.find(this) == visited.end()) //not visited
{
visited.insert(this); // mark as visited here to avoid an infinite loop over children, children's children, etc.

//children first for function evaluation
if (!IsLeaf())
{
if (ChildrenNeedGradient()) //only nodes that require gradient calculation are included in gradient calculation
m_needGradient = true;
else
m_needGradient = false;
}

result.push_back(ComputationNodePtr(this)); //we put this in the list even if it's a leaf since we need it to determine learnable params
this->m_visitedOrder = result.size();
}
else
{
if (!IsLeaf() && bFromDelayNode)
sourceRecurrentNodePtr.push_back(this);
}
}

static const std::wstring TypeName() { return L"PairNetwork"; }

// copy constructor
PairNetworkNode(const PairNetworkNode<ElemType>* node, const std::wstring& newName, const CopyNodeFlags flags)
: ComputationNode<ElemType>(node->m_deviceId)
{
node->CopyTo(this, newName, flags);
}

virtual ComputationNodePtr Duplicate(const std::wstring& newName, const CopyNodeFlags flags) const
{
const std::wstring& name = (newName == L"") ? NodeName() : newName;

ComputationNodePtr node = new PairNetworkNode<ElemType>(this, name, flags);
return node;
}

protected:
virtual bool UseCustomizedMultiSeqHandling() { return true; }

};

template class PairNetworkNode<float>;
template class PairNetworkNode<double>;

}}}
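A note added for orientation (a reading of EvaluateThisNode and ComputeInputPartial above, not text from the commit): PairNetworkNode is an identity pass-through whose input is fed by a node in another network. Per time step $t$ it computes

$y_t = x_t, \qquad \partial L / \partial x_t \mathrel{+}= \partial L / \partial y_t,$

where $x_t$ and $y_t$ are the column slices of width m_samplesInRecurrentStep for step $t$. Because the backward step uses ScaleAndAdd, the gradient is accumulated rather than overwritten, so the paired source node can collect contributions from every consumer in the target network.
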
@@ -772,7 +772,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
}
}

static void WINAPI ComputeInputPartialLeft(Matrix<ElemType>& inputFunctionValues, Matrix<ElemType>& inputGradientValues, const Matrix<ElemType>& gradientValues)
static void WINAPI ComputeInputPartialLeft(const Matrix<ElemType>& inputFunctionValues, Matrix<ElemType>& inputGradientValues, const Matrix<ElemType>& gradientValues)
{
#if DUMPOUTPUT
gradientValues.Print("Gradient-in");

@@ -2578,8 +2578,8 @@ namespace Microsoft { namespace MSR { namespace CNTK {
inputFunctionValues.Print("child Function values");
#endif

if (ones.GetNumRows() != inputGradientValues.GetNumRows() || ones.GetNumCols() != inputGradientValues.GetNumCols())
ones = Matrix<ElemType>::Ones(inputGradientValues.GetNumRows(), inputGradientValues.GetNumCols(), inputGradientValues.GetDeviceId());
if (ones.GetNumRows() != inputGradientValues.GetNumRows() || ones.GetNumCols() != inputGradientValues.GetNumRows())
ones = Matrix<ElemType>::Ones(inputGradientValues.GetNumRows(), inputGradientValues.GetNumRows(), inputGradientValues.GetDeviceId());
Matrix<ElemType>::MultiplyAndAdd(ones, false, gradientValues, true, inputGradientValues);
#if DUMPOUTPUT
inputGradientValues.Print("child Gradient-out");

@@ -158,7 +158,7 @@ namespace Microsoft {
fprintf(stderr, "Starting from checkpoint. Load Decoder Network From File %ws.\n", modelFileName.c_str());

ComputationNetwork<ElemType>& decoderNet =
startEpoch<0 ? decoderNetBuilder->BuildNetworkFromDescription() : decoderNetBuilder->LoadNetworkFromFile(modelFileName);
startEpoch<0 ? decoderNetBuilder->BuildNetworkFromDescription(&encoderNet) : decoderNetBuilder->LoadNetworkFromFile(modelFileName, true, false, &encoderNet);

startEpoch = max(startEpoch, 0);

@@ -373,6 +373,264 @@ namespace Microsoft {
fprintf(stderr, "Finished Epoch[%lu]: Evaluation Node [%ws] Per Sample = %.8g\n", i + 1, evalNodeNames[j].c_str(), epochEvalErrors[j]);
}

if (decoderValidationSetDataReader != decoderTrainSetDataReader && decoderValidationSetDataReader != nullptr &&
encoderValidationSetDataReader != encoderTrainSetDataReader && encoderValidationSetDataReader != nullptr)
{
SimpleEvaluator<ElemType> evalforvalidation(decoderNet);
vector<wstring> cvEncoderSetTrainAndEvalNodes;
cvEncoderSetTrainAndEvalNodes.push_back(encoderEvaluationNodes[0]->NodeName());

vector<wstring> cvDecoderSetTrainAndEvalNodes;
cvDecoderSetTrainAndEvalNodes.push_back(decoderCriterionNodes[0]->NodeName());
cvDecoderSetTrainAndEvalNodes.push_back(decoderEvaluationNodes[0]->NodeName());

vector<ElemType> vScore = evalforvalidation.EvaluateEncoderDecoderWithHiddenStates(
encoderNet, decoderNet,
*encoderValidationSetDataReader,
*decoderValidationSetDataReader, cvEncoderSetTrainAndEvalNodes,
cvDecoderSetTrainAndEvalNodes, m_mbSize[i]);
fprintf(stderr, "Finished Epoch[%lu]: [Validation Set] Train Loss Per Sample = %.8g EvalErr Per Sample = %.8g\n",
i + 1, vScore[0], vScore[1]);

epochCriterion[0] = vScore[0]; //the first one is the decoder training criterion.
}

bool loadedPrevModel = false;
size_t epochsSinceLastLearnRateAdjust = i % m_learnRateAdjustInterval + 1;
if (avgCriterion == std::numeric_limits<ElemType>::infinity())
avgCriterion = epochCriterion[0];
else
avgCriterion = ((epochsSinceLastLearnRateAdjust - 1 - epochsNotCountedInAvgCriterion)* avgCriterion + epochCriterion[0]) / (epochsSinceLastLearnRateAdjust - epochsNotCountedInAvgCriterion);
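// Added note, not in the original: writing n = epochsSinceLastLearnRateAdjust and
// m = epochsNotCountedInAvgCriterion, the update above is
// avgCriterion <- ((n - 1 - m) * avgCriterion + epochCriterion[0]) / (n - m),
// which keeps avgCriterion equal to the running mean of the criterion over the
// n - m epochs counted since the last learning-rate adjustment.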

if (m_autoLearnRateSearchType == LearningRateSearchAlgorithm::AdjustAfterEpoch && m_learningRatesPerSample.size() <= i && epochsSinceLastLearnRateAdjust == m_learnRateAdjustInterval)
{
if (prevCriterion - avgCriterion < 0 && prevCriterion != std::numeric_limits<ElemType>::infinity())
{
if (m_loadBestModel)
{
encoderNet.LoadPersistableParametersFromFile(GetEncoderModelNameForEpoch(i - 1),
false);
decoderNet.LoadPersistableParametersFromFile(GetDecoderModelNameForEpoch(i - 1),
m_validateAfterModelReloading);
encoderNet.ResetEvalTimeStamp();
decoderNet.ResetEvalTimeStamp();
LoadCheckPointInfo(i - 1, totalSamplesSeen, learnRatePerSample, smoothedGradients, prevCriterion);
fprintf(stderr, "Loaded the previous model which has better training criterion.\n");
loadedPrevModel = true;
}
}

if (m_continueReduce)
{
if (prevCriterion - avgCriterion <= m_reduceLearnRateIfImproveLessThan * prevCriterion && prevCriterion != std::numeric_limits<ElemType>::infinity())
{
if (learnRateReduced == false)
{
learnRateReduced = true;
}
else
{
decoderNet.SaveToFile(GetDecoderModelNameForEpoch(i, true));
encoderNet.SaveToFile(GetEncoderModelNameForEpoch(i, true));
fprintf(stderr, "Finished training and saved final model\n\n");
break;
}
}
if (learnRateReduced)
{
learnRatePerSample *= m_learnRateDecreaseFactor;
fprintf(stderr, "learnRatePerSample reduced to %.8g\n", learnRatePerSample);
}
}
else
{
if (prevCriterion - avgCriterion <= m_reduceLearnRateIfImproveLessThan * prevCriterion && prevCriterion != std::numeric_limits<ElemType>::infinity())
{

learnRatePerSample *= m_learnRateDecreaseFactor;
fprintf(stderr, "learnRatePerSample reduced to %.8g\n", learnRatePerSample);
}
else if (prevCriterion - avgCriterion > m_increaseLearnRateIfImproveMoreThan*prevCriterion && prevCriterion != std::numeric_limits<ElemType>::infinity())
{
learnRatePerSample *= m_learnRateIncreaseFactor;
fprintf(stderr, "learnRatePerSample increased to %.8g\n", learnRatePerSample);
}
}
}

if (!loadedPrevModel && epochsSinceLastLearnRateAdjust == m_learnRateAdjustInterval) //not loading previous values then set them
{
prevCriterion = avgCriterion;
epochsNotCountedInAvgCriterion = 0;
}

//persist model and check-point info
decoderNet.SaveToFile(GetDecoderModelNameForEpoch(i));
encoderNet.SaveToFile(GetEncoderModelNameForEpoch(i));
SaveCheckPointInfo(i, totalSamplesSeen, learnRatePerSample, smoothedGradients, prevCriterion);
if (!m_keepCheckPointFiles)
_wunlink(GetCheckPointFileNameForEpoch(i - 1).c_str()); //delete previous checkpoint file to save space

if (learnRatePerSample < 1e-12)
fprintf(stderr, "learnRate per sample is reduced to %.8g which is below 1e-12. stop training.\n", learnRatePerSample);
}
}

void sfbTrainEncoderDecoderModel(int startEpoch, ComputationNetwork<ElemType>& encoderNet,
ComputationNetwork<ElemType>& decoderNet,
IDataReader<ElemType>* encoderTrainSetDataReader,
IDataReader<ElemType>* decoderTrainSetDataReader,
IDataReader<ElemType>* encoderValidationSetDataReader,
IDataReader<ElemType>* decoderValidationSetDataReader)
{
std::vector<ComputationNodePtr> & encoderFeatureNodes = encoderNet.FeatureNodes();
std::vector<ComputationNodePtr> & encoderEvaluationNodes = encoderNet.OutputNodes();

std::vector<ComputationNodePtr> & decoderFeatureNodes = decoderNet.FeatureNodes();
std::vector<ComputationNodePtr> & decoderLabelNodes = decoderNet.LabelNodes();
std::vector<ComputationNodePtr> decoderCriterionNodes = GetTrainCriterionNodes(decoderNet);
std::vector<ComputationNodePtr> decoderEvaluationNodes = GetEvalCriterionNodes(decoderNet);

std::map<std::wstring, Matrix<ElemType>*> encoderInputMatrices, decoderInputMatrices;
for (size_t i = 0; i<encoderFeatureNodes.size(); i++)
{
encoderInputMatrices[encoderFeatureNodes[i]->NodeName()] =
&encoderFeatureNodes[i]->FunctionValues();
}
for (size_t i = 0; i<decoderFeatureNodes.size(); i++)
{
decoderInputMatrices[decoderFeatureNodes[i]->NodeName()] =
&decoderFeatureNodes[i]->FunctionValues();
}
for (size_t i = 0; i<decoderLabelNodes.size(); i++)
{
decoderInputMatrices[decoderLabelNodes[i]->NodeName()] = &decoderLabelNodes[i]->FunctionValues();
}

//initializing weights and gradient holder
std::list<ComputationNodePtr>& encoderLearnableNodes = encoderNet.LearnableNodes(encoderEvaluationNodes[0]); //only one criterion so far TODO: support multiple ones?
std::list<ComputationNodePtr>& decoderLearnableNodes = decoderNet.LearnableNodes(decoderCriterionNodes[0]);
std::list<ComputationNodePtr> learnableNodes;
for (auto nodeIter = encoderLearnableNodes.begin(); nodeIter != encoderLearnableNodes.end(); nodeIter++)
{
ComputationNodePtr node = (*nodeIter);
learnableNodes.push_back(node);
}
for (auto nodeIter = decoderLearnableNodes.begin(); nodeIter != decoderLearnableNodes.end(); nodeIter++)
{
ComputationNodePtr node = (*nodeIter);
learnableNodes.push_back(node);
}

std::list<Matrix<ElemType>> smoothedGradients;
for (auto nodeIter = learnableNodes.begin(); nodeIter != learnableNodes.end(); nodeIter++)
{
ComputationNodePtr node = (*nodeIter);
smoothedGradients.push_back(Matrix<ElemType>(node->FunctionValues().GetNumRows(), node->FunctionValues().GetNumCols(), node->FunctionValues().GetDeviceId()));
}

vector<ElemType> epochCriterion;
ElemType avgCriterion, prevCriterion;
for (size_t i = 0; i < 2; i++)
epochCriterion.push_back(std::numeric_limits<ElemType>::infinity());
avgCriterion = prevCriterion = std::numeric_limits<ElemType>::infinity();

size_t epochsNotCountedInAvgCriterion = startEpoch % m_learnRateAdjustInterval;

std::vector<ElemType> epochEvalErrors(decoderEvaluationNodes.size(), std::numeric_limits<ElemType>::infinity());

std::vector<wstring> evalNodeNames;
for (size_t i = 0; i<decoderEvaluationNodes.size(); i++)
evalNodeNames.push_back(decoderEvaluationNodes[i]->NodeName());

size_t totalSamplesSeen = 0;
ElemType learnRatePerSample = 0.5f / m_mbSize[startEpoch];

int m_numPrevLearnRates = 5; //used to control the upper learning rate in LR search to reduce computation
vector<ElemType> prevLearnRates;
prevLearnRates.resize(m_numPrevLearnRates);
for (int i = 0; i<m_numPrevLearnRates; i++)
prevLearnRates[i] = std::numeric_limits<ElemType>::infinity();

//precompute mean and invStdDev nodes and save initial model
if (/// to-do doesn't support pre-compute such as MVN here
/// PreCompute(net, encoderTrainSetDataReader, encoderFeatureNodes, encoderlabelNodes, encoderInputMatrices) ||
startEpoch == 0)
{
encoderNet.SaveToFile(GetEncoderModelNameForEpoch(int(startEpoch) - 1));
decoderNet.SaveToFile(GetDecoderModelNameForEpoch(int(startEpoch) - 1));
}

bool learnRateInitialized = false;
if (startEpoch > 0)
{
learnRateInitialized = LoadCheckPointInfo(startEpoch - 1, totalSamplesSeen, learnRatePerSample, smoothedGradients, prevCriterion);
setMomentum(m_momentumInputPerMB[m_momentumInputPerMB.size() - 1]);
}

if (m_autoLearnRateSearchType == LearningRateSearchAlgorithm::AdjustAfterEpoch && !learnRateInitialized && m_learningRatesPerSample.size() <= startEpoch)
throw std::invalid_argument("When using \"AdjustAfterEpoch\", there must either exist a checkpoint file, or an explicit learning rate must be specified in config for the starting epoch.");

ULONG dropOutSeed = 1;
ElemType prevDropoutRate = 0;

bool learnRateReduced = false;

for (int i = int(startEpoch); i<int(m_maxEpochs); i++)
{
auto t_start_epoch = clock();

//set dropout rate
SetDropoutRate(encoderNet, encoderEvaluationNodes[0], m_dropoutRates[i], prevDropoutRate, dropOutSeed);
SetDropoutRate(decoderNet, decoderCriterionNodes[0], m_dropoutRates[i], prevDropoutRate, dropOutSeed);

//learning rate adjustment
if (m_autoLearnRateSearchType == LearningRateSearchAlgorithm::None || (m_learningRatesPerSample.size() > 0 && m_learningRatesPerSample.size() > i))
{
learnRatePerSample = m_learningRatesPerSample[i];
setMomentum(m_momentumInputPerMB[i]);
}
else if (m_autoLearnRateSearchType == LearningRateSearchAlgorithm::SearchBeforeEpoch)
{
NOT_IMPLEMENTED;
}

learnRateInitialized = true;

if (learnRatePerSample < m_minLearnRate)
{
fprintf(stderr, "Learn Rate Per Sample for Epoch[%lu] = %.8g is less than minLearnRate %.8g. Training stops.\n", i + 1, learnRatePerSample, m_minLearnRate);
break;
}

TrainOneEpochEncoderDecoderWithHiddenStates(encoderNet, decoderNet, i, m_epochSize, encoderTrainSetDataReader,
decoderTrainSetDataReader, learnRatePerSample,
encoderFeatureNodes, encoderEvaluationNodes, encoderInputMatrices,
decoderFeatureNodes, decoderLabelNodes, decoderCriterionNodes, decoderEvaluationNodes,
decoderInputMatrices, learnableNodes, smoothedGradients,
epochCriterion, epochEvalErrors, totalSamplesSeen);


auto t_end_epoch = clock();
ElemType epochTime = ElemType(1.0)*(t_end_epoch - t_start_epoch) / (CLOCKS_PER_SEC);

// fprintf(stderr, "Finished Epoch[%lu]: [Training Set] Train Loss Per Sample = %.8g ", i + 1, epochCriterion);
fprintf(stderr, "Finished Epoch[%lu]: [Training Set] Decoder Train Loss Per Sample = %.8g ", i + 1, epochCriterion[0]);
if (epochEvalErrors.size() == 1)
{
fprintf(stderr, "EvalErr Per Sample = %.8g Ave Learn Rate Per Sample = %.10g Epoch Time=%.8g\n", epochEvalErrors[0], learnRatePerSample, epochTime);
}
else
{
fprintf(stderr, "EvalErr Per Sample ");
for (size_t j = 0; j<epochEvalErrors.size(); j++)
fprintf(stderr, "[%lu]=%.8g ", j, epochEvalErrors[j]);
fprintf(stderr, "Ave Learn Rate Per Sample = %.10g Epoch Time=%.8g\n", learnRatePerSample, epochTime);
fprintf(stderr, "Finished Epoch[%lu]: Criterion Node [%ls] Per Sample = %.8g\n", i + 1, decoderCriterionNodes[0]->NodeName().c_str(), epochCriterion[i + 1]);
for (size_t j = 0; j<epochEvalErrors.size(); j++)
fprintf(stderr, "Finished Epoch[%lu]: Evaluation Node [%ws] Per Sample = %.8g\n", i + 1, evalNodeNames[j].c_str(), epochEvalErrors[j]);
}

if (decoderValidationSetDataReader != decoderTrainSetDataReader && decoderValidationSetDataReader != nullptr &&
encoderValidationSetDataReader != encoderTrainSetDataReader && encoderValidationSetDataReader != nullptr)
{

@@ -518,6 +776,198 @@ namespace Microsoft {

int numMBsRun = 0;

size_t numEvalNodes = epochEvalErrors.size();

// NOTE: the following two local matrices are not used in PTask path
Matrix<ElemType> localEpochCriterion(1, 2, decoderNet.GetDeviceID()); //assume only one training criterion node for each epoch
Matrix<ElemType> localEpochEvalErrors(1, numEvalNodes, decoderNet.GetDeviceID());

localEpochCriterion.SetValue(0);
localEpochEvalErrors.SetValue(0);

encoderTrainSetDataReader->StartMinibatchLoop(m_mbSize[epochNumber], epochNumber, m_epochSize);
decoderTrainSetDataReader->StartMinibatchLoop(m_mbSize[epochNumber], epochNumber, m_epochSize);

startReadMBTime = clock();
Matrix<ElemType> mEncoderOutput(encoderEvaluationNodes[0]->FunctionValues().GetDeviceId());
Matrix<ElemType> mDecoderInput(decoderEvaluationNodes[0]->FunctionValues().GetDeviceId());

unsigned uSeedForDataReader = epochNumber;

bool bContinueDecoding = true;
while (bContinueDecoding)
{
try{
encoderTrainSetDataReader->SetRandomSeed(uSeedForDataReader);
encoderTrainSetDataReader->GetMinibatch(encoderInputMatrices);

/// now gradients on decoder network
decoderTrainSetDataReader->SetRandomSeed(uSeedForDataReader);
if (decoderTrainSetDataReader->GetMinibatch(decoderInputMatrices) == false)
break;
}
catch (...)
{
RuntimeError("Errors in reading features ");
}

size_t actualMBSize = decoderNet.GetActualMBSize();
if (actualMBSize == 0)
LogicError("decoderTrainSetDataReader read data but decoderNet reports no data read");

UpdateEvalTimeStamps(encoderFeatureNodes);
UpdateEvalTimeStamps(decoderFeatureNodes);
UpdateEvalTimeStamps(decoderLabelNodes);

endReadMBTime = clock();
startComputeMBTime = clock();

/// not the sentence beginning, because the initial hidden layer activity is from the encoder network
// decoderTrainSetDataReader->SetSentenceBegin(false);
// decoderTrainSetDataReader->SetSentenceSegBatch(decoderNet.m_sentenceSeg);
// decoderTrainSetDataReader->SetSentenceSegBatch(decoderNet.m_sentenceBegin);

if (m_doGradientCheck)
{
if (EncoderDecoderGradientCheck(encoderNet,
decoderNet, encoderTrainSetDataReader,
decoderTrainSetDataReader, encoderEvaluationNodes,
decoderFeatureNodes, decoderCriterionNodes, decoderEvaluationNodes, historyMat, localEpochCriterion, localEpochEvalErrors) == false)
{
throw runtime_error("SGD::TrainOneEpochEncoderDecoderWithHiddenStates gradient check not passed!");
}
localEpochCriterion.SetValue(0);
localEpochEvalErrors.SetValue(0);
}

EncoderDecoderWithHiddenStatesForwardPass(encoderNet,
decoderNet, encoderTrainSetDataReader,
decoderTrainSetDataReader, encoderEvaluationNodes,
decoderFeatureNodes, decoderCriterionNodes, decoderEvaluationNodes, localEpochCriterion, localEpochEvalErrors);

EncoderDecoderWithHiddenStatesErrorProp(encoderNet,
decoderNet, encoderEvaluationNodes,
decoderCriterionNodes,
historyMat, m_lst_pair_encoder_decoder_nodes);

//update model parameters
if (learnRatePerSample > m_minLearnRate * 0.01)
{
auto smoothedGradientIter = smoothedGradients.begin();
for (auto nodeIter = learnableNodes.begin(); nodeIter != learnableNodes.end(); nodeIter++, smoothedGradientIter++)
{
ComputationNodePtr node = (*nodeIter);
Matrix<ElemType>& smoothedGradient = (*smoothedGradientIter);

UpdateWeights(node, smoothedGradient, learnRatePerSample, actualMBSize, m_mbSize[epochNumber], m_L2RegWeight, m_L1RegWeight, m_needAveMultiplier);
}
}


endComputeMBTime = clock();
numMBsRun++;
if (m_traceLevel > 0)
{
ElemType MBReadTime = (ElemType)(endReadMBTime - startReadMBTime) / (CLOCKS_PER_SEC);
ElemType MBComputeTime = (ElemType)(endComputeMBTime - startComputeMBTime) / CLOCKS_PER_SEC;

readTimeInMBs += MBReadTime;
ComputeTimeInMBs += MBComputeTime;
numSamplesLastMBs += int(actualMBSize);

if (numMBsRun % m_numMBsToShowResult == 0)
{

epochCriterion[0] = localEpochCriterion.Get00Element();
for (size_t i = 0; i< numEvalNodes; i++)
epochEvalErrors[i] = (const ElemType)localEpochEvalErrors(0, i);

ElemType llk = (epochCriterion[0] - epochCriterionLastMBs[0]) / numSamplesLastMBs;
ElemType ppl = exp(llk);
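// Added note, not in the original: llk is the change in the accumulated training
// criterion per sample over the last m_numMBsToShowResult minibatches; when the
// criterion is a per-word negative log-likelihood, ppl = exp(llk) is the usual
// perplexity reported for language models.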
fprintf(stderr, "Epoch[%d]-Minibatch[%d-%d]: Samples Seen = %d Decoder Train Loss Per Sample = %.8g PPL = %.4e ", epochNumber + 1, numMBsRun - m_numMBsToShowResult + 1, numMBsRun, numSamplesLastMBs,
llk, ppl);
for (size_t i = 0; i<numEvalNodes; i++){
fprintf(stderr, "EvalErr[%lu] Per Sample = %.8g ", i, (epochEvalErrors[i] - epochEvalErrorsLastMBs[i]) / numSamplesLastMBs);
}
fprintf(stderr, "ReadData Time = %.8g Computing Time=%.8g Total Time Per Sample=%.8g\n", readTimeInMBs, ComputeTimeInMBs, (readTimeInMBs + ComputeTimeInMBs) / numSamplesLastMBs);

//reset statistics
readTimeInMBs = ComputeTimeInMBs = 0;
numSamplesLastMBs = 0;

epochCriterionLastMBs = epochCriterion;
for (size_t i = 0; i< numEvalNodes; i++)
epochEvalErrorsLastMBs[i] = epochEvalErrors[i];
}
}
startReadMBTime = clock();
totalEpochSamples += actualMBSize;
totalSamplesSeen += actualMBSize;

if (totalEpochSamples >= epochSize)
break;

/// call DataEnd function
/// DataEnd does reader-specific processing if sentence ending is reached
// encoderTrainSetDataReader->SetSentenceEnd(true);
// decoderTrainSetDataReader->SetSentenceEnd(true);
encoderTrainSetDataReader->DataEnd(endDataSentence);
decoderTrainSetDataReader->DataEnd(endDataSentence);

uSeedForDataReader++;
}

localEpochCriterion /= float(totalEpochSamples);
localEpochEvalErrors /= float(totalEpochSamples);

epochCriterion[0] = localEpochCriterion.Get00Element();
for (size_t i = 0; i < numEvalNodes; i++)
{
epochEvalErrors[i] = (const ElemType)localEpochEvalErrors(0, i);
}
fprintf(stderr, "total samples in epoch[%d] = %d\n", epochNumber, totalEpochSamples);
}

/// use hidden states between encoder and decoder to communicate between two networks
void sfbTrainOneEpochEncoderDecoderWithHiddenStates(
ComputationNetwork<ElemType>& encoderNet, /// encoder network
ComputationNetwork<ElemType>& decoderNet,
const int epochNumber, const size_t epochSize,
IDataReader<ElemType>* encoderTrainSetDataReader,
IDataReader<ElemType>* decoderTrainSetDataReader,
const ElemType learnRatePerSample,
const std::vector<ComputationNodePtr>& encoderFeatureNodes,
const std::vector<ComputationNodePtr>& encoderEvaluationNodes,
std::map<std::wstring, Matrix<ElemType>*>& encoderInputMatrices,
const std::vector<ComputationNodePtr>& decoderFeatureNodes,
const std::vector<ComputationNodePtr>& decoderLabelNodes,
const std::vector<ComputationNodePtr>& decoderCriterionNodes,
const std::vector<ComputationNodePtr>& decoderEvaluationNodes,
std::map<std::wstring, Matrix<ElemType>*>& decoderInputMatrices,
const std::list<ComputationNodePtr>& learnableNodes,
std::list<Matrix<ElemType>>& smoothedGradients,
vector<ElemType>& epochCriterion, std::vector<ElemType>& epochEvalErrors, size_t& totalSamplesSeen)
{
assert(encoderEvaluationNodes.size() == 1);

Matrix<ElemType> historyMat(encoderNet.GetDeviceID());

ElemType readTimeInMBs = 0, ComputeTimeInMBs = 0;
vector<ElemType> epochCriterionLastMBs;
for (size_t i = 0; i < epochCriterion.size(); i++)
epochCriterionLastMBs.push_back(0);

int numSamplesLastMBs = 0;
std::vector<ElemType> epochEvalErrorsLastMBs(epochEvalErrors.size(), 0);

clock_t startReadMBTime = 0, startComputeMBTime = 0;
clock_t endReadMBTime = 0, endComputeMBTime = 0;

//initialize statistics
size_t totalEpochSamples = 0;

int numMBsRun = 0;

/// get the pair of encoder and decoder nodes
if (m_lst_pair_encoder_decoder_nodes.size() == 0 && m_lst_pair_encoder_decode_node_names.size() > 0)
{

@@ -600,14 +1050,14 @@ namespace Microsoft {
}

EncoderDecoderWithHiddenStatesForwardPass(encoderNet,
decoderNet, encoderTrainSetDataReader,
decoderTrainSetDataReader, encoderEvaluationNodes,
decoderFeatureNodes, decoderCriterionNodes, decoderEvaluationNodes, historyMat, localEpochCriterion, localEpochEvalErrors);
decoderNet, encoderTrainSetDataReader,
decoderTrainSetDataReader, encoderEvaluationNodes,
decoderFeatureNodes, decoderCriterionNodes, decoderEvaluationNodes, localEpochCriterion, localEpochEvalErrors);

EncoderDecoderWithHiddenStatesErrorProp(encoderNet,
decoderNet, encoderEvaluationNodes,
decoderCriterionNodes,
historyMat, m_lst_pair_encoder_decoder_nodes);
decoderNet, encoderEvaluationNodes,
decoderCriterionNodes,
historyMat, m_lst_pair_encoder_decoder_nodes);

//update model parameters
if (learnRatePerSample > m_minLearnRate * 0.01)

@@ -741,7 +1191,7 @@ namespace Microsoft {
EncoderDecoderWithHiddenStatesForwardPass(encoderNet,
decoderNet, encoderTrainSetDataReader,
decoderTrainSetDataReader, encoderEvaluationNodes,
decoderFeatureNodes, decoderCriterionNodes, decoderEvaluationNodes, historyMat, localEpochCriterion, localEpochEvalErrors);
decoderFeatureNodes, decoderCriterionNodes, decoderEvaluationNodes, localEpochCriterion, localEpochEvalErrors);

ElemType score1 = localEpochCriterion.Get00Element();

@@ -759,7 +1209,7 @@ namespace Microsoft {
EncoderDecoderWithHiddenStatesForwardPass(encoderNet,
decoderNet, encoderTrainSetDataReader,
decoderTrainSetDataReader, encoderEvaluationNodes,
decoderFeatureNodes, decoderCriterionNodes, decoderEvaluationNodes, historyMat, localEpochCriterion, localEpochEvalErrors);
decoderFeatureNodes, decoderCriterionNodes, decoderEvaluationNodes, localEpochCriterion, localEpochEvalErrors);

ElemType score2 = localEpochCriterion.Get00Element();

@@ -776,7 +1226,7 @@ namespace Microsoft {
EncoderDecoderWithHiddenStatesForwardPass(encoderNet,
decoderNet, encoderTrainSetDataReader,
decoderTrainSetDataReader, encoderEvaluationNodes,
decoderFeatureNodes, decoderCriterionNodes, decoderEvaluationNodes, historyMat, localEpochCriterion, localEpochEvalErrors);
decoderFeatureNodes, decoderCriterionNodes, decoderEvaluationNodes, localEpochCriterion, localEpochEvalErrors);

EncoderDecoderWithHiddenStatesErrorProp(encoderNet,
decoderNet, encoderEvaluationNodes,

@@ -836,7 +1286,7 @@ namespace Microsoft {
EncoderDecoderWithHiddenStatesForwardPass(encoderNet,
decoderNet, encoderTrainSetDataReader,
decoderTrainSetDataReader, encoderEvaluationNodes,
decoderFeatureNodes, decoderCriterionNodes, decoderEvaluationNodes, historyMat, localEpochCriterion, localEpochEvalErrors);
decoderFeatureNodes, decoderCriterionNodes, decoderEvaluationNodes, localEpochCriterion, localEpochEvalErrors);

ElemType score1 = localEpochCriterion.Get00Element();

@@ -852,7 +1302,7 @@ namespace Microsoft {
EncoderDecoderWithHiddenStatesForwardPass(encoderNet,
decoderNet, encoderTrainSetDataReader,
decoderTrainSetDataReader, encoderEvaluationNodes,
decoderFeatureNodes, decoderCriterionNodes, decoderEvaluationNodes, historyMat, localEpochCriterion, localEpochEvalErrors);
decoderFeatureNodes, decoderCriterionNodes, decoderEvaluationNodes, localEpochCriterion, localEpochEvalErrors);

ElemType score1r = localEpochCriterion.Get00Element();

@@ -869,7 +1319,7 @@ namespace Microsoft {
EncoderDecoderWithHiddenStatesForwardPass(encoderNet,
decoderNet, encoderTrainSetDataReader,
decoderTrainSetDataReader, encoderEvaluationNodes,
decoderFeatureNodes, decoderCriterionNodes, decoderEvaluationNodes, historyMat, localEpochCriterion, localEpochEvalErrors);
decoderFeatureNodes, decoderCriterionNodes, decoderEvaluationNodes, localEpochCriterion, localEpochEvalErrors);

EncoderDecoderWithHiddenStatesErrorProp(encoderNet,
decoderNet, encoderEvaluationNodes,

@@ -906,7 +1356,6 @@ namespace Microsoft {
const std::vector<ComputationNodePtr>& decoderFeatureNodes,
const std::vector<ComputationNodePtr>& decoderCriterionNodes,
const std::vector<ComputationNodePtr>& decoderEvaluationNodes,
Matrix<ElemType>& historyMat,
Matrix<ElemType>& localEpochCriterion,
Matrix<ElemType>& localEpochEvalErrors
)

@@ -928,21 +1377,6 @@ namespace Microsoft {
/// not the sentence beginning, because the initial hidden layer activity is from the encoder network
decoderTrainSetDataReader->SetSentenceSegBatch(decoderNet.mSentenceBoundary, decoderNet.mExistsBeginOrNoLabels);

/// get the pair of encoder and decoder nodes
for (typename list<pair<ComputationNodePtr, ComputationNodePtr>>::iterator iter = m_lst_pair_encoder_decoder_nodes.begin(); iter != m_lst_pair_encoder_decoder_nodes.end(); iter++)
{
/// pass hidden layer activity from the encoder network to the decoder network
ComputationNodePtr encoderNode = iter->first;
ComputationNodePtr decoderNode = iter->second;

encoderNode->GetHistory(historyMat, true); /// get the last state activity
decoderNode->SetHistory(historyMat);
#ifdef DEBUG_DECODER
fprintf(stderr, "LSTM past output norm = %.8e\n", historyMat.ColumnSlice(0, nstreams).FrobeniusNorm());
fprintf(stderr, "LSTM past state norm = %.8e\n", historyMat.ColumnSlice(nstreams, nstreams).FrobeniusNorm());
#endif
}

UpdateEvalTimeStamps(decoderFeatureNodes);
decoderNet.Evaluate(decoderCriterionNodes[0]);

@@ -153,10 +153,10 @@ namespace Microsoft { namespace MSR { namespace CNTK {
delete m_executionEngine;
}
virtual ComputationNetwork<ElemType>& LoadNetworkFromFile(const wstring& modelFileName, bool forceLoad = true,
bool bAllowNoCriterionNode = false)
bool bAllowNoCriterionNode = false, ComputationNetwork<ElemType>* anotherNetwork = nullptr)
{
if (m_net->GetTotalNumberOfNodes() == 0 || forceLoad) //not built or force load
m_net->LoadFromFile(modelFileName, FileOptions::fileOptionsBinary, bAllowNoCriterionNode);
m_net->LoadFromFile(modelFileName, FileOptions::fileOptionsBinary, bAllowNoCriterionNode, anotherNetwork);

m_net->ResetEvalTimeStamp();
return *m_net;

@@ -211,7 +211,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
ndlUtil.ProcessNDLConfig(config, true);
}

virtual ComputationNetwork<ElemType>& BuildNetworkFromDescription()
virtual ComputationNetwork<ElemType>& BuildNetworkFromDescription(ComputationNetwork<ElemType>* = nullptr)
{
if (m_net->GetTotalNumberOfNodes() < 1) //not built yet
{

@@ -238,6 +238,10 @@ bool CheckFunction(std::string& p_nodeType, bool* allowUndeterminedVariable)
ret = true;
else if (EqualInsensitive(nodeType, LSTMNode<ElemType>::TypeName(), L"LSTM"))
ret = true;
else if (EqualInsensitive(nodeType, AlignmentNode<ElemType>::TypeName(), L"Alignment"))
ret = true;
else if (EqualInsensitive(nodeType, PairNetworkNode<ElemType>::TypeName(), L"PairNetwork")) // fixed: compared against AlignmentNode's type name in the original
ret = true;

// return the actual node name in the parameter if we found something
if (ret)

@@ -27,7 +27,7 @@ to-dos:
delay_node : has another input that points to additional observations.
memory_node: M x N node, with an argument telling whether to save the last observation, or save a window size of observations, or save all observations
pair_node : copy function values and gradient values from one node in a source network to a target network

sequential_alignment_node: compute similarity of the previous time or any matrix, versus a block of input, and output a weighted average from the input
decoder delay_node -> memory_node -> pair(source, target) pair(source, target) -> memory_node -> encoder output node

@@ -581,7 +581,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {

Developed by Kaisheng Yao
Used in the following works:
K. Yao, G. Zweig, "Sequence to sequence neural net models for grapheme-to-phoneme conversion", submitted to Interspeech 2015
K. Yao, G. Zweig, "Sequence to sequence neural net models for grapheme-to-phoneme conversion", in Interspeech 2015
*/
template<class ElemType>
class LSTMNode : public ComputationNode<ElemType>

@ -1793,4 +1793,270 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
|
||||
template class LSTMNode<float>;
|
||||
template class LSTMNode<double>;
|
||||
|
||||
/**
|
||||
This node uses softmax to compute the similarity of an input versus the second input, which is a block of memory, and outputs
|
||||
the weighed average of the second input.
|
||||
*/
|
||||
template<class ElemType>
|
||||
class AlignmentNode : public ComputationNode<ElemType>
|
||||
{
|
||||
UsingComputationNodeMembers;
|
||||
public:
|
||||
AlignmentNode(const DEVICEID_TYPE deviceId = AUTOPLACEMATRIX, const std::wstring name = L"")
|
||||
: ComputationNode<ElemType>(deviceId), m_memoryBlk4EachUtt(deviceId), m_softmax(deviceId), m_ones(deviceId)
|
||||
{
|
||||
m_nodeName = (name == L"" ? CreateUniqNodeName() : name);
|
||||
m_deviceId = deviceId;
|
||||
MoveMatricesToDevice(deviceId);
|
||||
InitRecurrentNode();
|
||||
}
|
||||
|
||||
AlignmentNode(File& fstream, const size_t modelVersion, const DEVICEID_TYPE deviceId = AUTOPLACEMATRIX, const std::wstring name = L"")
|
||||
: ComputationNode<ElemType>(deviceId), m_memoryBlk4EachUtt(deviceId), m_softmax(deviceId), m_ones(deviceId)
|
||||
{
|
||||
m_nodeName = (name == L"" ? CreateUniqNodeName() : name);
|
||||
LoadFromFile(fstream, modelVersion, deviceId);
|
||||
}
|
||||
|
||||
virtual const std::wstring OperationName() const { return TypeName(); }
|
||||
static const std::wstring TypeName() { return L"Alignment"; }
|
||||
|
||||
virtual void ComputeInputPartial(const size_t )
|
||||
{
|
||||
NOT_IMPLEMENTED;
|
||||
}
virtual void ComputeInputPartial(const size_t inputIndex, const size_t timeIdxInSeq)
{
    if (inputIndex == 0)
        return;

    if (inputIndex > 2)
        throw std::invalid_argument("Alignment has three inputs.");

    Matrix<ElemType> sliceOutputGrad = GradientValues().ColumnSlice(timeIdxInSeq * m_samplesInRecurrentStep, m_samplesInRecurrentStep);
    Matrix<ElemType> mTmp(m_deviceId);
    Matrix<ElemType> mTmp1(m_deviceId);
    Matrix<ElemType> mTmp2(m_deviceId);
    Matrix<ElemType> mTmp3(m_deviceId);
    Matrix<ElemType> mGBeforeSoftmax(m_deviceId);
    Matrix<ElemType> mTmp4(m_deviceId);
    Matrix<ElemType> mGToMemblk(m_deviceId);
    size_t T = Inputs(1)->FunctionValues().GetNumCols() / m_samplesInRecurrentStep;
    size_t e = Inputs(0)->FunctionValues().GetNumRows();
    size_t d = Inputs(1)->FunctionValues().GetNumRows();
    Matrix<ElemType> mGBeforeSoftmaxTimes(m_deviceId);
    mGBeforeSoftmaxTimes.Resize(T, e);
    mGBeforeSoftmaxTimes.SetValue(0);
    mGToMemblk.Resize(d, T);
    mGToMemblk.SetValue(0);

    if (m_ones.GetNumRows() != e || m_ones.GetNumCols() != e)
    {
        m_ones = Matrix<ElemType>::Ones(e, e, m_deviceId);
    }

    mGBeforeSoftmax.Resize(m_softmax.GetNumRows(), 1);
    mGBeforeSoftmax.SetValue(0);
    for (size_t k = 0; k < m_samplesInRecurrentStep; k++)
    {
        size_t i = timeIdxInSeq * m_samplesInRecurrentStep + k;

        /// right branch with softmax
        mTmp4 = m_memoryBlk4EachUtt.ColumnSlice(k * T, T);
        TimesNode<ElemType>::ComputeInputPartialRight(mTmp4, mTmp, sliceOutputGrad.ColumnSlice(k, 1)); /// before times
        SoftmaxNode<ElemType>::ComputeInputPartialS(mTmp1, mTmp2, mGBeforeSoftmax, mTmp, m_softmax.ColumnSlice(k, 1)); /// before softmax
        TimesNode<ElemType>::ComputeInputPartialLeft(Inputs(0)->FunctionValues().ColumnSlice(i, 1), mGBeforeSoftmaxTimes, mGBeforeSoftmax); /// before times

        switch (inputIndex)
        {
        case 0:
            LogicError("no gradients should be backpropagated to past observation");
        case 1: // derivative to memory block
            TimesNode<ElemType>::ComputeInputPartialLeft(m_softmax.ColumnSlice(k, 1), mGToMemblk,
                sliceOutputGrad.ColumnSlice(k, 1));

            mTmp4.Resize(T, e);
            mTmp4.SetValue(0);
            TimesNode<ElemType>::ComputeInputPartialLeft(Inputs(2)->FunctionValues(), mTmp4, mGBeforeSoftmaxTimes);
            TransposeNode<ElemType>::ComputeInputPartial(mGToMemblk, m_ones, mTmp4);

            for (size_t j = 0; j < T; j++)
                Inputs(1)->GradientValues().ColumnSlice(j * m_samplesInRecurrentStep + k, 1) += mGToMemblk.ColumnSlice(j, 1);

            break;
        case 2: // derivative to similarity matrix
            mTmp2 = m_memoryBlk4EachUtt.ColumnSlice(k * T, T);
            Matrix<ElemType>::MultiplyAndAdd(mTmp2, false, mGBeforeSoftmaxTimes, false, Inputs(2)->GradientValues()); /// before times

            break;
        }
    }
}
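
/// Gradient-flow summary for the routine above: with o = M w^T and w = softmax(r^T W M),
/// the incoming gradient splits into (a) a direct path into the memory block M (case 1)
/// and (b) a path back through the softmax to s = r^T W M and into the similarity
/// matrix W (case 2); input 0 (r) is treated as a given observation, so case 0 errors out.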
virtual void EvaluateThisNode()
{
    EvaluateThisNodeS(m_functionValues, Inputs(0)->FunctionValues(), Inputs(1)->FunctionValues(), Inputs(2)->FunctionValues(),
        m_memoryBlk4EachUtt, m_softmax);
}

virtual void EvaluateThisNode(const size_t timeIdxInSeq)
{
    Matrix<ElemType> sliceInputValue = Inputs(0)->FunctionValues().ColumnSlice(timeIdxInSeq * m_samplesInRecurrentStep, m_samplesInRecurrentStep);
    Matrix<ElemType> sliceOutputValue = m_functionValues.ColumnSlice(timeIdxInSeq * m_samplesInRecurrentStep, m_samplesInRecurrentStep);

    EvaluateThisNodeS(sliceOutputValue, sliceInputValue, Inputs(1)->FunctionValues(), Inputs(2)->FunctionValues(), m_memoryBlk4EachUtt, m_softmax);
}

static void WINAPI EvaluateThisNodeS(Matrix<ElemType>& functionValues,
    const Matrix<ElemType>& refFunction, const Matrix<ElemType>& memoryBlk,
    const Matrix<ElemType>& wgtMatrix,
    Matrix<ElemType>& tmpMemoryBlk4EachUtt, Matrix<ElemType>& tmpSoftMax)
{
    size_t e = wgtMatrix.GetNumCols();
    size_t nbrUttPerSample = refFunction.GetNumCols();
    size_t T = memoryBlk.GetNumCols() / nbrUttPerSample;

    tmpMemoryBlk4EachUtt.Resize(memoryBlk.GetNumRows(), memoryBlk.GetNumCols());
    tmpSoftMax.Resize(T, nbrUttPerSample);
    Matrix<ElemType> tmpMat(tmpMemoryBlk4EachUtt.GetDeviceId());
    Matrix<ElemType> tmpMat3(tmpMemoryBlk4EachUtt.GetDeviceId());
    tmpMat3.Resize(e, T);
    Matrix<ElemType> tmpMat4(tmpMemoryBlk4EachUtt.GetDeviceId());
    Matrix<ElemType> tmpMat2(tmpMemoryBlk4EachUtt.GetDeviceId());

    for (size_t k = 0; k < nbrUttPerSample; k++)
    {
        for (size_t t = 0; t < T; t++)
        {
            size_t i = t * nbrUttPerSample + k;
            tmpMat3.ColumnSlice(t, 1).SetValue(memoryBlk.ColumnSlice(i, 1));
        }
        /// d x T
        tmpMemoryBlk4EachUtt.ColumnSlice(k * T, T) = tmpMat3;

        Matrix<ElemType>::Multiply(tmpMat3, true, wgtMatrix, false, tmpMat);
        /// (T x d) x (d x e) = T x e

        Matrix<ElemType>::Multiply(tmpMat, false, refFunction.ColumnSlice(k, 1), false, tmpMat2);
        /// (T x e) x (e x 1) = T x 1

        tmpSoftMax.ColumnSlice(k, 1) = tmpMat2;
        tmpMat2.InplaceLogSoftmax(true);
        tmpMat2.InplaceExp();

        Matrix<ElemType>::Multiply(tmpMat3, false, tmpMat2, false, tmpMat4);
        functionValues.ColumnSlice(k, 1).SetValue(tmpMat4);
        /// d x 1
    }
    /// d x k

#if NANCHECK
    functionValues.HasNan("Alignment");
#endif
}

/**
input 0, denoted as r (in d x k) : this input is treated as a given observation, so no gradient is backpropagated into it.
input 1, denoted as M (in e x k x T) : this is a block of memory
input 2, denoted as W (in d x e) : this is a matrix to compute similarity
d : input 0 feature dimension
k : number of utterances per minibatch
T : input 1 time dimension
e : input 1 feature dimension
the operation is
    s = r^T W M    in k x T
    w = softmax(s) in k x T
    o = M w^T      in e x k
*/
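
/// Illustrative sketch (not part of the node): the forward computation above for a
/// single utterance (k = 1), written with plain loops over row-major arrays.
/// Assumes <vector>, <cmath>, and <algorithm> are included; r is d x 1, W is d x e,
/// M is e x T, and the result o is e x 1.
static void AlignmentForwardSketch(const std::vector<ElemType>& r,
                                   const std::vector<ElemType>& W,
                                   const std::vector<ElemType>& M,
                                   size_t d, size_t e, size_t T,
                                   std::vector<ElemType>& o)
{
    std::vector<ElemType> q(e, 0);                  // q = W^T r, e x 1
    for (size_t j = 0; j < e; j++)
        for (size_t i = 0; i < d; i++)
            q[j] += W[i * e + j] * r[i];
    std::vector<ElemType> s(T, 0);                  // s = q^T M = r^T W M, 1 x T
    for (size_t t = 0; t < T; t++)
        for (size_t j = 0; j < e; j++)
            s[t] += q[j] * M[j * T + t];
    ElemType smax = *std::max_element(s.begin(), s.end());
    ElemType z = 0;                                 // w = softmax(s), 1 x T
    for (size_t t = 0; t < T; t++) { s[t] = std::exp(s[t] - smax); z += s[t]; }
    o.assign(e, 0);                                 // o = M w^T, e x 1
    for (size_t j = 0; j < e; j++)
        for (size_t t = 0; t < T; t++)
            o[j] += M[j * T + t] * s[t] / z;
}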
virtual void Validate()
{
    PrintSelfBeforeValidation();
    size_t k, T, e, d, i;

    if (m_children.size() != 3)
        throw std::logic_error("AlignmentNode operation should have three inputs.");

    if (Inputs(0)->FunctionValues().GetNumElements() == 0 ||
        Inputs(1)->FunctionValues().GetNumElements() == 0 ||
        Inputs(2)->FunctionValues().GetNumElements() == 0)
        throw std::logic_error("AlignmentNode operation: the input nodes have 0 elements.");

    d = Inputs(0)->FunctionValues().GetNumRows();
    k = Inputs(0)->FunctionValues().GetNumCols();
    i = Inputs(1)->FunctionValues().GetNumCols();
    e = Inputs(1)->FunctionValues().GetNumRows();
    T = i / k;
    if (Inputs(2)->FunctionValues().GetNumRows() != d ||
        Inputs(2)->FunctionValues().GetNumCols() != e)
        LogicError("AlignmentNode operation: the weight matrix dimension doesn't match the input feature dimensions.");

    FunctionValues().Resize(e, k);

    CopyImageSizeFromInputs();
}

virtual void AttachInputs(const ComputationNodePtr refFeature, const ComputationNodePtr memoryBlk, const ComputationNodePtr wgtMatrix)
{
    m_children.resize(3);
    m_children[0] = refFeature;
    m_children[1] = memoryBlk;
    m_children[2] = wgtMatrix;
}
virtual void MoveMatricesToDevice(const DEVICEID_TYPE deviceId)
{
    ComputationNode<ElemType>::MoveMatricesToDevice(deviceId);

    if (deviceId != AUTOPLACEMATRIX)
    {
        if (m_memoryBlk4EachUtt.GetDeviceId() != deviceId)
            m_memoryBlk4EachUtt.TransferFromDeviceToDevice(m_memoryBlk4EachUtt.GetDeviceId(), deviceId);
        if (m_softmax.GetDeviceId() != deviceId)
            m_softmax.TransferFromDeviceToDevice(m_softmax.GetDeviceId(), deviceId);
        if (m_weight.GetDeviceId() != deviceId)
            m_weight.TransferFromDeviceToDevice(m_weight.GetDeviceId(), deviceId);
    }
}

virtual void CopyTo(const ComputationNodePtr nodeP, const std::wstring& newName, const CopyNodeFlags flags) const
{
    ComputationNode<ElemType>::CopyTo(nodeP, newName, flags);
    AlignmentNode<ElemType>* node = (AlignmentNode<ElemType>*) nodeP;

    if (flags & CopyNodeFlags::copyNodeValue)
    {
        node->m_memoryBlk4EachUtt = m_memoryBlk4EachUtt;
        node->m_softmax = m_softmax;
        node->m_weight = m_weight;
        node->m_ones = m_ones;
    }
}

// copy constructor
AlignmentNode(const AlignmentNode<ElemType>* node, const std::wstring& newName, const CopyNodeFlags flags)
    : ComputationNode<ElemType>(node->m_deviceId), m_memoryBlk4EachUtt(node->m_deviceId), m_softmax(node->m_deviceId), m_ones(node->m_deviceId)
{
    node->CopyTo(this, newName, flags);
}

virtual ComputationNodePtr Duplicate(const std::wstring& newName, const CopyNodeFlags flags) const
{
    const std::wstring& name = (newName == L"") ? NodeName() : newName;

    ComputationNodePtr node = new AlignmentNode<ElemType>(this, name, flags);
    return node;
}

private:
    Matrix<ElemType> m_memoryBlk4EachUtt;
    Matrix<ElemType> m_softmax;
    Matrix<ElemType> m_weight;
    Matrix<ElemType> m_ones;
};

template class AlignmentNode<float>;
template class AlignmentNode<double>;

}}}

@ -431,6 +431,202 @@ namespace Microsoft { namespace MSR { namespace CNTK {
    }
}

// initialize eval results
std::vector<ElemType> evalResults;
for (int i = 0; i < decoderEvalNodes.size(); i++)
{
    evalResults.push_back((ElemType)0);
}

// prepare features and labels
std::vector<ComputationNodePtr>& encoderFeatureNodes = encoderNet.FeatureNodes();

std::vector<ComputationNodePtr>& decoderFeatureNodes = decoderNet.FeatureNodes();
std::vector<ComputationNodePtr>& decoderLabelNodes = decoderNet.LabelNodes();

std::map<std::wstring, Matrix<ElemType>*> encoderInputMatrices;
for (size_t i = 0; i < encoderFeatureNodes.size(); i++)
{
    encoderInputMatrices[encoderFeatureNodes[i]->NodeName()] = &encoderFeatureNodes[i]->FunctionValues();
}

std::map<std::wstring, Matrix<ElemType>*> decoderInputMatrices;
for (size_t i = 0; i < decoderFeatureNodes.size(); i++)
{
    decoderInputMatrices[decoderFeatureNodes[i]->NodeName()] = &decoderFeatureNodes[i]->FunctionValues();
}
for (size_t i = 0; i < decoderLabelNodes.size(); i++)
{
    decoderInputMatrices[decoderLabelNodes[i]->NodeName()] = &decoderLabelNodes[i]->FunctionValues();
}

// evaluate through minibatches
size_t totalEpochSamples = 0;
size_t numMBsRun = 0;
size_t actualMBSize = 0;
size_t numSamplesLastMBs = 0;
size_t lastMBsRun = 0; // MBs run before this display

std::vector<ElemType> evalResultsLastMBs;
for (int i = 0; i < evalResults.size(); i++)
    evalResultsLastMBs.push_back((ElemType)0);

encoderDataReader.StartMinibatchLoop(mbSize, 0, testSize);
decoderDataReader.StartMinibatchLoop(mbSize, 0, testSize);

Matrix<ElemType> mEncoderOutput(encoderEvalNodes[0]->FunctionValues().GetDeviceId());
Matrix<ElemType> historyMat(encoderEvalNodes[0]->FunctionValues().GetDeviceId());

bool bContinueDecoding = true;
while (bContinueDecoding)
{
    /// first evaluate the encoder network
    if (encoderDataReader.GetMinibatch(encoderInputMatrices) == false)
        break;
    if (decoderDataReader.GetMinibatch(decoderInputMatrices) == false)
        break;
    UpdateEvalTimeStamps(encoderFeatureNodes);
    UpdateEvalTimeStamps(decoderFeatureNodes);

    actualMBSize = decoderNet.GetActualMBSize();
    if (actualMBSize == 0)
        LogicError("decoderDataReader read data but decoderNet reports no data read");

    encoderNet.SetActualMiniBatchSize(actualMBSize);
    encoderNet.SetActualNbrSlicesInEachRecIter(encoderDataReader.NumberSlicesInEachRecurrentIter());
    encoderDataReader.SetSentenceSegBatch(encoderNet.mSentenceBoundary, encoderNet.mExistsBeginOrNoLabels);

    assert(encoderEvalNodes.size() == 1);
    for (int i = 0; i < encoderEvalNodes.size(); i++)
    {
        encoderNet.Evaluate(encoderEvalNodes[i]);
    }

    /// not the sentence beginning, because the initial hidden layer activity comes from the encoder network
    decoderNet.SetActualNbrSlicesInEachRecIter(decoderDataReader.NumberSlicesInEachRecurrentIter());
    decoderDataReader.SetSentenceSegBatch(decoderNet.mSentenceBoundary, decoderNet.mExistsBeginOrNoLabels);

    for (int i = 0; i < decoderEvalNodes.size(); i++)
    {
        decoderNet.Evaluate(decoderEvalNodes[i]);
        evalResults[i] += decoderEvalNodes[i]->FunctionValues().Get00Element(); // criterion node should be a scalar
    }

    totalEpochSamples += actualMBSize;
    numMBsRun++;

    if (m_traceLevel > 0)
    {
        numSamplesLastMBs += actualMBSize;

        if (numMBsRun % m_numMBsToShowResult == 0)
        {
            DisplayEvalStatistics(lastMBsRun + 1, numMBsRun, numSamplesLastMBs, decoderEvalNodes, evalResults, evalResultsLastMBs);

            for (int i = 0; i < evalResults.size(); i++)
            {
                evalResultsLastMBs[i] = evalResults[i];
            }
            numSamplesLastMBs = 0;
            lastMBsRun = numMBsRun;
        }
    }

    /// call DataEnd to check whether the end of a sentence is reached;
    /// the data reader does its reader-specific processing for sentence endings
    encoderDataReader.DataEnd(endDataSentence);
    decoderDataReader.DataEnd(endDataSentence);
}
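
/// (the loop above exits when either reader runs out of minibatches, i.e., GetMinibatch returns false)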

// show the last batch of results
if (m_traceLevel > 0 && numSamplesLastMBs > 0)
{
    DisplayEvalStatistics(lastMBsRun + 1, numMBsRun, numSamplesLastMBs, decoderEvalNodes, evalResults, evalResultsLastMBs);
}

// final statistics
for (int i = 0; i < evalResultsLastMBs.size(); i++)
{
    evalResultsLastMBs[i] = 0;
}

fprintf(stderr, "Final Results: ");
DisplayEvalStatistics(1, numMBsRun, totalEpochSamples, decoderEvalNodes, evalResults, evalResultsLastMBs);

for (int i = 0; i < evalResults.size(); i++)
{
    evalResults[i] /= totalEpochSamples;
}

return evalResults;
}

/// this evaluates the encoder network and the decoder network
vector<ElemType> sfbEvaluateEncoderDecoderWithHiddenStates(
    ComputationNetwork<ElemType>& encoderNet,
    ComputationNetwork<ElemType>& decoderNet,
    IDataReader<ElemType>& encoderDataReader,
    IDataReader<ElemType>& decoderDataReader,
    const vector<wstring>& encoderEvalNodeNames,
    const vector<wstring>& decoderEvalNodeNames,
    const size_t mbSize,
    const size_t testSize = requestDataSize)
{
    // specify evaluation nodes
    std::vector<ComputationNodePtr> encoderEvalNodes;
    std::vector<ComputationNodePtr> decoderEvalNodes;

    if (encoderEvalNodeNames.size() == 0)
    {
        fprintf(stderr, "evalNodeNames are not specified, using all the default evalnodes and training criterion nodes.\n");
        if (encoderNet.EvaluationNodes().size() == 0)
            throw std::logic_error("There is no default evaluation criterion node specified in the network.");

        for (int i = 0; i < encoderNet.EvaluationNodes().size(); i++)
            encoderEvalNodes.push_back(encoderNet.EvaluationNodes()[i]);
    }
    else
    {
        for (int i = 0; i < encoderEvalNodeNames.size(); i++)
        {
            ComputationNodePtr node = encoderNet.GetNodeFromName(encoderEvalNodeNames[i]);
            encoderNet.BuildAndValidateNetwork(node);
            if (node->FunctionValues().GetNumElements() != 1)
            {
                throw std::logic_error("The nodes passed to SimpleEvaluator::Evaluate function must be either eval or training criterion nodes (which evaluate to a 1x1 value).");
            }
            encoderEvalNodes.push_back(node);
        }
    }

    if (decoderEvalNodeNames.size() == 0)
    {
        fprintf(stderr, "evalNodeNames are not specified, using all the default evalnodes and training criterion nodes.\n");
        if (decoderNet.EvaluationNodes().size() == 0)
            throw std::logic_error("There is no default evaluation criterion node specified in the network.");
        if (decoderNet.FinalCriterionNodes().size() == 0)
            throw std::logic_error("There is no default training criterion node specified in the network.");

        for (int i = 0; i < decoderNet.EvaluationNodes().size(); i++)
            decoderEvalNodes.push_back(decoderNet.EvaluationNodes()[i]);

        for (int i = 0; i < decoderNet.FinalCriterionNodes().size(); i++)
            decoderEvalNodes.push_back(decoderNet.FinalCriterionNodes()[i]);
    }
    else
    {
        for (int i = 0; i < decoderEvalNodeNames.size(); i++)
        {
            ComputationNodePtr node = decoderNet.GetNodeFromName(decoderEvalNodeNames[i]);
            decoderNet.BuildAndValidateNetwork(node);
            if (node->FunctionValues().GetNumElements() != 1)
            {
                throw std::logic_error("The nodes passed to SimpleEvaluator::Evaluate function must be either eval or training criterion nodes (which evaluate to a 1x1 value).");
            }
            decoderEvalNodes.push_back(node);
        }
    }

    if (m_lst_pair_encoder_decoder_nodes.size() == 0)
        throw runtime_error("sfbEvaluateEncoderDecoderWithHiddenStates: no encoder and decoder node pairs");

@ -355,6 +355,125 @@ namespace Microsoft { namespace MSR { namespace CNTK {
    return *m_net;
}

/**
this builds an alignment-based LM generator;
the alignment node takes a variable-length input and relates each element to a variable-length output
*/
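
/// Hypothetical usage sketch (the builder objects and variable names below are illustrative,
/// not part of this commit): the decoder builder receives the already-built encoder network
/// so that the alignment node can pair with the encoder's single output node.
///
///     ComputationNetwork<float>& encoderNet = encoderBuilder.BuildNetworkFromDescription(nullptr);
///     ComputationNetwork<float>& decoderNet =
///         decoderBuilder.BuildAlignmentDecoderNetworkFromDescription(&encoderNet, /*mbSize*/ 1);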
template<class ElemType>
ComputationNetwork<ElemType>& SimpleNetworkBuilder<ElemType>::BuildAlignmentDecoderNetworkFromDescription(ComputationNetwork<ElemType>* encoderNet, size_t mbSize)
{
    if (m_net->GetTotalNumberOfNodes() < 1) // not built yet
    {
        unsigned long randomSeed = 1;

        size_t numHiddenLayers = m_layerSizes.size() - 2;

        size_t numRecurrentLayers = m_recurrentLayers.size();

        ComputationNodePtr input = nullptr, encoderOutput = nullptr, e = nullptr,
            b = nullptr, w = nullptr, u = nullptr, delay = nullptr, output = nullptr, label = nullptr, alignoutput = nullptr;
        ComputationNodePtr clslogpostprob = nullptr;
        ComputationNodePtr clsweight = nullptr;

        input = m_net->CreateSparseInputNode(L"features", m_layerSizes[0], mbSize);
        m_net->FeatureNodes().push_back(input);

        if (m_lookupTableOrder > 0)
        {
            e = m_net->CreateLearnableParameter(msra::strfun::wstrprintf(L"E%d", 0), m_layerSizes[1], m_layerSizes[0] / m_lookupTableOrder);
            m_net->InitLearnableParameters(e, m_uniformInit, randomSeed++, m_initValueScale);
            output = m_net->LookupTable(e, input, L"LookupTable");

            if (m_addDropoutNodes)
                input = m_net->Dropout(output);
            else
                input = output;
        }
        else
        {
            LogicError("BuildAlignmentDecoderNetworkFromDescription: LSTMNode cannot take sparse input. Need to project the sparse input to a continuous vector using LookupTable. Suggested setup:\n layerSizes=$VOCABSIZE$:100:$HIDDIM$:$VOCABSIZE$ \nfor a 100-dimensional projection, with lookupTableOrder=1\n to project to a single window. To use a larger context window, set e.g. lookupTableOrder=3 for a width-3 context window.\n");
        }

        int recur_idx = 0;
        int offset = m_lookupTableOrder > 0 ? 1 : 0;

        /// the source-network-side output dimension needs to match the first layer dimension in the decoder network
        std::vector<ComputationNodePtr>& encoderEvaluationNodes = encoderNet->OutputNodes();
        if (encoderEvaluationNodes.size() != 1)
            LogicError("BuildAlignmentDecoderNetworkFromDescription: the encoder network should have exactly one output node to serve as the source node for the decoder network.");

        encoderOutput = m_net->PairNetwork(encoderEvaluationNodes[0], L"pairNetwork");

        if (numHiddenLayers > 0)
        {
            int i = 1 + offset;
            u = m_net->CreateLearnableParameter(msra::strfun::wstrprintf(L"U%d", i), m_layerSizes[i], m_layerSizes[offset] * (offset ? m_lookupTableOrder : 1));
            m_net->InitLearnableParameters(u, m_uniformInit, randomSeed++, m_initValueScale);
            w = m_net->CreateLearnableParameter(msra::strfun::wstrprintf(L"W%d", i), m_layerSizes[i], m_layerSizes[i]);
            m_net->InitLearnableParameters(w, m_uniformInit, randomSeed++, m_initValueScale);

            delay = m_net->Delay(NULL, m_defaultHiddenActivity, (size_t)m_layerSizes[i], mbSize);
            //output = (ComputationNodePtr)BuildLSTMNodeComponent(randomSeed, 0, m_layerSizes[offset] * (offset ? m_lookupTableOrder : 1), m_layerSizes[offset + 1], input);
            //output = (ComputationNodePtr)BuildLSTMComponent(randomSeed, mbSize, 0, m_layerSizes[offset] * (offset ? m_lookupTableOrder : 1), m_layerSizes[offset + 1], input);

            /// alignment node to get weights from source to target:
            /// this alignment node computes weights relating the current hidden state (after the special
            /// encoder ending symbol) to all states before the special encoder ending symbol. The weights
            /// are used to summarize all encoder inputs; the weighted sum of inputs is then used as an
            /// additional input to the LSTM in the next layer.
            e = m_net->CreateLearnableParameter(msra::strfun::wstrprintf(L"MatForSimilarity%d", i), m_layerSizes[i], m_layerSizes[i]);
            m_net->InitLearnableParameters(e, m_uniformInit, randomSeed++, m_initValueScale);

            alignoutput = (ComputationNodePtr)m_net->Alignment(delay, encoderOutput, e, L"alignment");

            output = ApplyNonlinearFunction(
                m_net->Plus(
                    m_net->Times(u, input), m_net->Times(w, alignoutput)), 0);
            delay->AttachInputs(output);
            input = output;
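
            /// (the Delay node closes the attention loop: the alignment node attends over the
            /// encoder memory using the previous hidden state, and the new hidden state is fed
            /// back through the delay node for the next time step)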

            for (; i < numHiddenLayers; i++)
            {
                output = (ComputationNodePtr)BuildLSTMNodeComponent(randomSeed, i, m_layerSizes[i], m_layerSizes[i + 1], input);
                //output = (ComputationNodePtr)BuildLSTMComponent(randomSeed, mbSize, i, m_layerSizes[i], m_layerSizes[i + 1], input);

                if (m_addDropoutNodes)
                    input = m_net->Dropout(output);
                else
                    input = output;
            }

        }

        /// need to have an [input_dim x output_dim] matrix,
        /// e.g., [200 x 10000], where 10000 is the vocabulary size;
        /// this is for speed-up, since the per-word matrix can simply be obtained using a column slice
        w = m_net->CreateLearnableParameter(msra::strfun::wstrprintf(L"OW%d", numHiddenLayers), m_layerSizes[numHiddenLayers], m_layerSizes[numHiddenLayers + 1]);
        m_net->InitLearnableParameters(w, m_uniformInit, randomSeed++, m_initValueScale);

        /// the label is a dense matrix; each element is a word index
        label = m_net->CreateInputNode(L"labels", 4, mbSize);
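        /// (assumption, for illustration: the 4 label rows follow the class-based cross-entropy
        /// layout, i.e., word id, its class id, and the [begin, end) word-index range of the class)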

        clsweight = m_net->CreateLearnableParameter(L"WeightForClassPostProb", m_nbrCls, m_layerSizes[numHiddenLayers]);
        m_net->InitLearnableParameters(clsweight, m_uniformInit, randomSeed++, m_initValueScale);
        clslogpostprob = m_net->Times(clsweight, input, L"ClassPostProb");

        output = AddTrainAndEvalCriterionNodes(input, label, w, L"TrainNodeClassBasedCrossEntropy", L"EvalNodeClassBasedCrossEntrpy",
            clslogpostprob);

        output = m_net->Times(m_net->Transpose(w), input, L"outputs");

        m_net->OutputNodes().push_back(output);

        // add a softmax layer (if probabilities are needed or KL-regularized adaptation is needed)
        output = m_net->Softmax(output, L"PosteriorProb");
    }

    m_net->ResetEvalTimeStamp();

    return *m_net;
}

template<class ElemType>
ComputationNetwork<ElemType>& SimpleNetworkBuilder<ElemType>::BuildLogBilinearNetworkFromDescription(size_t mbSize)
{

@ -36,7 +36,9 @@ namespace Microsoft { namespace MSR { namespace CNTK {
    NPLM = 32, CLASSLSTM = 64, NCELSTM = 128,
    CLSTM = 256, RCRF = 512,
    UNIDIRECTIONALLSTM = 19,
    BIDIRECTIONALLSTM = 20} RNNTYPE;
    BIDIRECTIONALLSTM = 20,
    ALIGNMENTSIMILARITYGENERATOR = 21
} RNNTYPE;


enum class TrainingCriterion : int
@ -179,6 +181,8 @@ namespace Microsoft { namespace MSR { namespace CNTK {
if (std::find(strType.begin(), strType.end(), L"JOINTCONDITIONALBILSTMSTREAMS") != strType.end() ||
    std::find(strType.begin(), strType.end(), L"BIDIRECTIONALLSTMWITHPASTPREDICTION") != strType.end())
    m_rnnType = BIDIRECTIONALLSTM;
if (std::find(strType.begin(), strType.end(), L"ALIGNMENTSIMILARITYGENERATOR") != strType.end())
    m_rnnType = ALIGNMENTSIMILARITYGENERATOR;
}

// Init - Builder Initialize for multiple data sets

@ -235,7 +239,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
}

virtual ComputationNetwork<ElemType>& LoadNetworkFromFile(const wstring& modelFileName, bool forceLoad = true,
    bool bAllowNoCriterion = false)
    bool bAllowNoCriterion = false, ComputationNetwork<ElemType>* anotherNetwork = nullptr)
{
    if (m_net->GetTotalNumberOfNodes() == 0 || forceLoad) // not built or force load
    {
@ -252,7 +256,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
        }
        else
        {
            m_net->LoadFromFile(modelFileName, FileOptions::fileOptionsBinary, bAllowNoCriterion);
            m_net->LoadFromFile(modelFileName, FileOptions::fileOptionsBinary, bAllowNoCriterion, anotherNetwork);
        }
    }

@ -260,7 +264,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
    return *m_net;
}

ComputationNetwork<ElemType>& BuildNetworkFromDescription()
ComputationNetwork<ElemType>& BuildNetworkFromDescription(ComputationNetwork<ElemType>* encoderNet)
{
    size_t mbSize = 1;

@ -288,6 +292,8 @@ namespace Microsoft { namespace MSR { namespace CNTK {
        return BuildUnidirectionalLSTMNetworksFromDescription(mbSize);
    if (m_rnnType == BIDIRECTIONALLSTM)
        return BuildBiDirectionalLSTMNetworksFromDescription(mbSize);
    if (m_rnnType == ALIGNMENTSIMILARITYGENERATOR)
        return BuildAlignmentDecoderNetworkFromDescription(encoderNet, mbSize);

    if (m_net->GetTotalNumberOfNodes() < 1) // not built yet
    {
@ -421,7 +427,8 @@ namespace Microsoft { namespace MSR { namespace CNTK {

ComputationNetwork<ElemType>& BuildNCELSTMNetworkFromDescription(size_t mbSize = 1);

ComputationNetwork<ElemType>& BuildAlignmentDecoderNetworkFromDescription(ComputationNetwork<ElemType>* encoderNet, size_t mbSize = 1);

ComputationNetwork<ElemType>& BuildNetworkFromDbnFile(const std::wstring& dbnModelFileName)
{

@ -1408,7 +1408,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
size_t i = t % nS;
if (m_existsSentenceBeginOrNoLabels->ColumnSlice(j, 1).Get00Element() == EXISTS_SENTENCE_BEGIN_OR_NO_LABELS)
{
    if ((*m_sentenceSeg)(j, i) == NO_LABELS)
    if ((*m_sentenceSeg)(i, j) == NO_LABELS)
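    /// (the new line swaps the index order to (i, j); presumably this matches m_sentenceSeg's actual (row, column) layout)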
    {
        matrixToBeMasked.ColumnSlice(t, 1).SetValue(0);