decode for blstm
This commit is contained in:
Родитель
9800670b2b
Коммит
58abf438b2
|
@ -255,6 +255,8 @@ void DoWriteOutput(const ConfigParameters& config)
|
|||
size_t decodeExpandBeam = writerConfig(L"decode_expand_beam", 20);
|
||||
string indictfile = writerConfig(L"DictFile", L"");
|
||||
ElemType thresh = writerConfig(L"Thresh", 0.68f);
|
||||
size_t rightsplice = writerConfig(L"rightsplice", 20);
|
||||
size_t encoderdim = writerConfig(L"encoderdim", 640);
|
||||
DataWriter testDataWriter(writerConfig);
|
||||
// ConfigParameters
|
||||
if (decodeType == 0)
|
||||
|
@ -262,7 +264,7 @@ void DoWriteOutput(const ConfigParameters& config)
|
|||
else if (decodeType == 1)
|
||||
writer.WriteOutput_greedy(testDataReader, mbSize[0], testDataWriter, outputNodeNamesVector, epochSize, writerUnittest);
|
||||
else if (decodeType == 2)
|
||||
writer.WriteOutput_beam(testDataReader, mbSize[0], testDataWriter, outputNodeNamesVector, epochSize, writerUnittest, decodeBeam, decodeExpandBeam, indictfile,thresh);
|
||||
writer.WriteOutput_beam(testDataReader, mbSize[0], testDataWriter, outputNodeNamesVector, epochSize, writerUnittest, decodeBeam, decodeExpandBeam, indictfile,thresh, rightsplice,encoderdim);
|
||||
|
||||
}
|
||||
else if (config.Exists("outputPath"))
|
||||
|
|
|
@ -13,7 +13,12 @@
|
|||
#include "Basics.h"
|
||||
#include "Matrix.h"
|
||||
|
||||
namespace Microsoft { namespace MSR { namespace CNTK {
|
||||
namespace Microsoft
|
||||
{
|
||||
namespace MSR
|
||||
{
|
||||
namespace CNTK
|
||||
{
|
||||
|
||||
// Forward declarations
|
||||
class FrameRange;
|
||||
|
@ -101,7 +106,10 @@ struct MBLayout
|
|||
{
|
||||
return seqId == other.seqId && s == other.s && tBegin == other.tBegin && tEnd == other.tEnd;
|
||||
}
|
||||
size_t GetNumTimeSteps() const { return (size_t)(tEnd - tBegin); }
|
||||
size_t GetNumTimeSteps() const
|
||||
{
|
||||
return (size_t)(tEnd - tBegin);
|
||||
}
|
||||
};
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
|
@ -121,9 +129,9 @@ struct MBLayout
|
|||
|
||||
// copy the content of another MBLayoutPtr over
|
||||
// Use this instead of actual assignment to make it super-obvious that this is not copying the pointer but actual content. The pointer is kept fixed.
|
||||
// Use "keepName" if the "identity" of the target is to be preserved, e.g.
|
||||
// Use "keepName" if the "identity" of the target is to be preserved, e.g.
|
||||
// while copying from reader space to network space.
|
||||
void CopyFrom(const MBLayoutPtr& other, bool keepName=false)
|
||||
void CopyFrom(const MBLayoutPtr &other, bool keepName = false)
|
||||
{
|
||||
m_numTimeSteps = other->m_numTimeSteps;
|
||||
m_numParallelSequences = other->m_numParallelSequences;
|
||||
|
@ -148,7 +156,7 @@ struct MBLayout
|
|||
}
|
||||
|
||||
// Destructive copy that steals ownership if the content, like std::move()
|
||||
// Note: For some reason the VC++ compiler does not generate the
|
||||
// Note: For some reason the VC++ compiler does not generate the
|
||||
// move assignment and we have to do this ourselves
|
||||
void MoveFrom(MBLayoutPtr other)
|
||||
{
|
||||
|
@ -173,8 +181,8 @@ struct MBLayout
|
|||
m_axisName = std::move(other->m_axisName);
|
||||
}
|
||||
|
||||
MBLayout(const MBLayout&) = delete;
|
||||
MBLayout& operator=(const MBLayout&) = delete;
|
||||
MBLayout(const MBLayout &) = delete;
|
||||
MBLayout &operator=(const MBLayout &) = delete;
|
||||
|
||||
public:
|
||||
// resize and reset all frames to None (note: this is an invalid state and must be fixed by caller afterwards)
|
||||
|
@ -183,7 +191,7 @@ public:
|
|||
// remember the dimensions
|
||||
m_numParallelSequences = numParallelSequences;
|
||||
m_numTimeSteps = numTimeSteps;
|
||||
if (deepInit)
|
||||
if (deepInit)
|
||||
{
|
||||
m_distanceToStart.Resize(m_numParallelSequences, m_numTimeSteps);
|
||||
m_distanceToEnd.Resize(m_numParallelSequences, m_numTimeSteps);
|
||||
|
@ -211,10 +219,10 @@ public:
|
|||
// - width: maximum width of structure; set to maximum over sequence lengths
|
||||
// - inputSequences: vector of input SequenceInfo records (only seqId and GetNumTimeSteps() are used)
|
||||
// - placement, rowAllocations: temp buffers (passed in to be able to optimize memory allocations)
|
||||
template<typename SequenceInfoVector>
|
||||
void InitAsPackedSequences(const SequenceInfoVector& inputSequences,
|
||||
/*temp buffer*/std::vector<std::pair<size_t, size_t>>& placement,
|
||||
/*temp buffer*/std::vector<size_t> rowAllocations)
|
||||
template <typename SequenceInfoVector>
|
||||
void InitAsPackedSequences(const SequenceInfoVector &inputSequences,
|
||||
/*temp buffer*/ std::vector<std::pair<size_t, size_t>> &placement,
|
||||
/*temp buffer*/ std::vector<size_t> rowAllocations)
|
||||
{
|
||||
placement.resize(inputSequences.size()); // [sequence index] result goes here (entries are invalid for gaps)
|
||||
// determine width of MBLayout
|
||||
|
@ -227,7 +235,7 @@ public:
|
|||
width = inputSequences[i].GetNumTimeSteps();
|
||||
}
|
||||
// allocate
|
||||
rowAllocations.clear(); // [row] we build rows one by one
|
||||
rowAllocations.clear(); // [row] we build rows one by one
|
||||
for (size_t i = 0; i < inputSequences.size(); i++)
|
||||
{
|
||||
if (inputSequences[i].seqId == GAP_SEQUENCE_ID)
|
||||
|
@ -253,36 +261,48 @@ public:
|
|||
{
|
||||
if (inputSequences[i].seqId == GAP_SEQUENCE_ID)
|
||||
continue;
|
||||
size_t s, tBegin; tie
|
||||
(s, tBegin) = placement[i];
|
||||
AddSequence(inputSequences[i].seqId, s, (ptrdiff_t)tBegin, tBegin + inputSequences[i].GetNumTimeSteps());
|
||||
size_t s, tBegin;
|
||||
tie(s, tBegin) = placement[i];
|
||||
AddSequence(inputSequences[i].seqId, s, (ptrdiff_t) tBegin, tBegin + inputSequences[i].GetNumTimeSteps());
|
||||
}
|
||||
// need to fill the gaps as well
|
||||
for (size_t s = 0; s < rowAllocations.size(); s++)
|
||||
AddGap(s, (size_t)rowAllocations[s], width);
|
||||
AddGap(s, (size_t) rowAllocations[s], width);
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
// accessors
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
size_t GetNumTimeSteps() const { return m_numTimeSteps; }
|
||||
size_t GetNumParallelSequences() const { return m_numParallelSequences; }
|
||||
size_t GetNumTimeSteps() const
|
||||
{
|
||||
return m_numTimeSteps;
|
||||
}
|
||||
size_t GetNumParallelSequences() const
|
||||
{
|
||||
return m_numParallelSequences;
|
||||
}
|
||||
size_t GetNumSequences() const
|
||||
{
|
||||
return std::count_if(m_sequences.begin(), m_sequences.end(), [](const SequenceInfo& sequence) {
|
||||
return std::count_if(m_sequences.begin(), m_sequences.end(), [](const SequenceInfo &sequence) {
|
||||
return sequence.seqId != GAP_SEQUENCE_ID;
|
||||
});
|
||||
}
|
||||
|
||||
// axis names are for now only a debugging aid
|
||||
// In the future, there will be a mechanism to denote that axes are meant to be the same.
|
||||
const wchar_t* GetAxisName() const { return m_axisName.c_str(); }
|
||||
void SetAxisName(const std::wstring& name) { m_axisName = name; }
|
||||
const wchar_t *GetAxisName() const
|
||||
{
|
||||
return m_axisName.c_str();
|
||||
}
|
||||
void SetAxisName(const std::wstring &name)
|
||||
{
|
||||
m_axisName = name;
|
||||
}
|
||||
void SetUniqueAxisName(std::wstring name) // helper for constructing
|
||||
{
|
||||
// Unfortunatelly, initialization of local static variables is not thread-safe in VS2013.
|
||||
// As workaround, it is moved to the struct level.
|
||||
// As workaround, it is moved to the struct level.
|
||||
// Todo: when upgraded to VS2013, change back to use the local static mutex, and remove also Sequences.cpp.
|
||||
// The mutex is need to make access to nameIndices be thread-safe.
|
||||
// static std::mutex nameIndiciesMutex;
|
||||
|
@ -297,7 +317,7 @@ public:
|
|||
}
|
||||
|
||||
if (index > 0)
|
||||
name += msra::strfun::wstrprintf(L"%d", (int)index);
|
||||
name += msra::strfun::wstrprintf(L"%d", (int) index);
|
||||
SetAxisName(name);
|
||||
}
|
||||
|
||||
|
@ -309,13 +329,13 @@ public:
|
|||
|
||||
// Get the number of frames of the input sequence that belong to the MB, i.e. disregarding sequence elements that are outside of the MB boundaries
|
||||
// Input sequence is expected to belong to this MBLayout
|
||||
size_t GetNumSequenceFramesInCurrentMB(const SequenceInfo& sequenceInfo) const
|
||||
size_t GetNumSequenceFramesInCurrentMB(const SequenceInfo &sequenceInfo) const
|
||||
{
|
||||
return min(sequenceInfo.tEnd, GetNumTimeSteps()) - max(sequenceInfo.tBegin, (ptrdiff_t)0);
|
||||
return min(sequenceInfo.tEnd, GetNumTimeSteps()) - max(sequenceInfo.tBegin, (ptrdiff_t) 0);
|
||||
}
|
||||
|
||||
// return all sequences stored in this minibatch
|
||||
const vector<SequenceInfo>& GetAllSequences() const
|
||||
const vector<SequenceInfo> &GetAllSequences() const
|
||||
{
|
||||
return m_sequences;
|
||||
}
|
||||
|
@ -324,10 +344,10 @@ public:
|
|||
// This is used by MeanNode and InvStdDevNode, and by statistics reporting.
|
||||
size_t GetActualNumSamples() const;
|
||||
|
||||
const Matrix<char>& GetColumnsValidityMask(DEVICEID_TYPE deviceId) const;
|
||||
const Matrix<char> &GetColumnsValidityMask(DEVICEID_TYPE deviceId) const;
|
||||
|
||||
// compare whether two layouts are the same
|
||||
bool operator==(const MBLayout& other) const
|
||||
bool operator==(const MBLayout &other) const
|
||||
{
|
||||
if (this == &other)
|
||||
return true;
|
||||
|
@ -352,7 +372,7 @@ public:
|
|||
{
|
||||
if (!first)
|
||||
s << ", ";
|
||||
s << "{seqId:" << seq.seqId << ", s:" << seq.s <<", begin:" << seq.tBegin << ", end:" << seq.tEnd << "}";
|
||||
s << "{seqId:" << seq.seqId << ", s:" << seq.s << ", begin:" << seq.tBegin << ", end:" << seq.tEnd << "}";
|
||||
first = false;
|
||||
}
|
||||
s << "]}";
|
||||
|
@ -403,7 +423,7 @@ public:
|
|||
if (seqId == GAP_SEQUENCE_ID)
|
||||
{
|
||||
m_numGapFrames += (e - b);
|
||||
if (initDistances)
|
||||
if (initDistances)
|
||||
{
|
||||
for (size_t t = b; t < e; t++)
|
||||
{
|
||||
|
@ -412,24 +432,23 @@ public:
|
|||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
if (initDistances)
|
||||
else if (initDistances)
|
||||
{
|
||||
for (size_t t = b; t < e; t++)
|
||||
{
|
||||
for (size_t t = b; t < e; t++)
|
||||
{
|
||||
// update the nearest sentence boundaries, minimum over all parallel sequences
|
||||
// If 0, then we are on a boundary. If not 0, we can still test in presence of FrameRange.m_timeOffset.
|
||||
ptrdiff_t distanceToStart = (ptrdiff_t)t - beginTime;
|
||||
ptrdiff_t distanceToEnd = (ptrdiff_t)(endTime - 1 - t);
|
||||
m_distanceToStart(s, t) = (float)distanceToStart;
|
||||
m_distanceToEnd(s, t) = (float)distanceToEnd;
|
||||
// and the aggregate
|
||||
if (m_distanceToNearestStart[t] > distanceToStart)
|
||||
m_distanceToNearestStart[t] = distanceToStart;
|
||||
if (m_distanceToNearestEnd[t] > distanceToEnd)
|
||||
m_distanceToNearestEnd[t] = distanceToEnd;
|
||||
}
|
||||
// update the nearest sentence boundaries, minimum over all parallel sequences
|
||||
// If 0, then we are on a boundary. If not 0, we can still test in presence of FrameRange.m_timeOffset.
|
||||
ptrdiff_t distanceToStart = (ptrdiff_t) t - beginTime;
|
||||
ptrdiff_t distanceToEnd = (ptrdiff_t)(endTime - 1 - t);
|
||||
m_distanceToStart(s, t) = (float) distanceToStart;
|
||||
m_distanceToEnd(s, t) = (float) distanceToEnd;
|
||||
// and the aggregate
|
||||
if (m_distanceToNearestStart[t] > distanceToStart)
|
||||
m_distanceToNearestStart[t] = distanceToStart;
|
||||
if (m_distanceToNearestEnd[t] > distanceToEnd)
|
||||
m_distanceToNearestEnd[t] = distanceToEnd;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// short-hand to initialize an MBLayout for the common case of frame mode
|
||||
|
@ -471,7 +490,7 @@ public:
|
|||
}
|
||||
|
||||
// find a sequence by its id
|
||||
const SequenceInfo& FindSequence(UniqueSequenceId seqId) const
|
||||
const SequenceInfo &FindSequence(UniqueSequenceId seqId) const
|
||||
{
|
||||
for (const auto &seqInfo : m_sequences)
|
||||
if (seqInfo.seqId == seqId)
|
||||
|
@ -481,13 +500,13 @@ public:
|
|||
|
||||
// find a sequence by SequenceInfo array and position
|
||||
// Use this if sequences may be matching 1:1.
|
||||
const SequenceInfo& FindMatchingSequence(const vector<SequenceInfo>& querySequences, size_t i) const
|
||||
const SequenceInfo &FindMatchingSequence(const vector<SequenceInfo> &querySequences, size_t i) const
|
||||
{
|
||||
// TODO: What are our sorted-ness guarantees?
|
||||
let seqId = querySequences[i].seqId; // the seq id we are looking for
|
||||
if (seqId == GAP_SEQUENCE_ID)
|
||||
LogicError("FindMatchingSequence: Cannot be applied go gaps.");
|
||||
if (seqId == m_sequences[i].seqId) // if both sequence arrays match 1:1 then we found it
|
||||
if (seqId == m_sequences[i].seqId) // if both sequence arrays match 1:1 then we found it
|
||||
return m_sequences[i];
|
||||
else
|
||||
return FindSequence(seqId);
|
||||
|
@ -508,10 +527,15 @@ public:
|
|||
return m_rightSplice;
|
||||
}
|
||||
|
||||
void setRightSplice(int rightsplice)
|
||||
{
|
||||
m_rightSplice = rightsplice;
|
||||
}
|
||||
|
||||
// test boundary flags for a specific condition
|
||||
bool IsBeyondStartOrEnd(const FrameRange& fr) const;
|
||||
bool IsGap(const FrameRange& fr) const;
|
||||
bool IsBeyondMinibatch(const FrameRange& fr) const;
|
||||
bool IsBeyondStartOrEnd(const FrameRange &fr) const;
|
||||
bool IsGap(const FrameRange &fr) const;
|
||||
bool IsBeyondMinibatch(const FrameRange &fr) const;
|
||||
|
||||
// test whether at least one sequence crosses the bounds of this minibatch
|
||||
bool HasSequenceBeyondBegin() const
|
||||
|
@ -540,29 +564,29 @@ public:
|
|||
// -------------------------------------------------------------------
|
||||
|
||||
// get the matrix-column index for a given time step in a given sequence
|
||||
size_t GetColumnIndex(const SequenceInfo& seq, size_t t) const
|
||||
size_t GetColumnIndex(const SequenceInfo &seq, size_t t) const
|
||||
{
|
||||
if (t > seq.GetNumTimeSteps())
|
||||
LogicError("GetColumnIndex: t out of sequence bounds.");
|
||||
if (seq.s > GetNumParallelSequences())
|
||||
LogicError("GetColumnIndex: seq.s out of sequence bounds."); // can only happen if 'seq' does not come out of our own m_sequences array, which is verboten
|
||||
ptrdiff_t tIn = (ptrdiff_t)t + seq.tBegin; // shifted time index
|
||||
if (tIn < 0 || (size_t)tIn >= GetNumTimeSteps())
|
||||
ptrdiff_t tIn = (ptrdiff_t) t + seq.tBegin; // shifted time index
|
||||
if (tIn < 0 || (size_t) tIn >= GetNumTimeSteps())
|
||||
LogicError("GetColumnIndex: Attempted to access a time step that is accessing a portion of a sequence that is not included in current minibatch."); // we may encounter this for truncated BPTT
|
||||
size_t col = (size_t)tIn * GetNumParallelSequences() + seq.s;
|
||||
size_t col = (size_t) tIn * GetNumParallelSequences() + seq.s;
|
||||
assert(col < GetNumCols());
|
||||
return col;
|
||||
}
|
||||
|
||||
// get the matrix-column indices for a given sequence
|
||||
// sequence is expected to belong to this MB
|
||||
vector<size_t> GetColumnIndices(const SequenceInfo& seq) const
|
||||
vector<size_t> GetColumnIndices(const SequenceInfo &seq) const
|
||||
{
|
||||
size_t numFrames = GetNumSequenceFramesInCurrentMB(seq);
|
||||
vector<size_t> res;
|
||||
res.reserve(numFrames);
|
||||
for (size_t i = 0; i < numFrames;++i)
|
||||
res.push_back(GetColumnIndex(seq,i));
|
||||
for (size_t i = 0; i < numFrames; ++i)
|
||||
res.push_back(GetColumnIndex(seq, i));
|
||||
return res;
|
||||
}
|
||||
|
||||
|
@ -653,7 +677,6 @@ private:
|
|||
static std::map<std::wstring, size_t> s_nameIndices;
|
||||
|
||||
public:
|
||||
|
||||
// special accessor for sequence training --TODO: must be replaced by a different mechanism
|
||||
bool IsEnd(size_t s, size_t t) const
|
||||
{
|
||||
|
@ -781,33 +804,33 @@ public:
|
|||
return ret;
|
||||
}
|
||||
|
||||
std::pair<size_t,size_t> GetSequenceRange() const
|
||||
std::pair<size_t, size_t> GetSequenceRange() const
|
||||
{
|
||||
if (!m_pMBLayout) return
|
||||
make_pair(0, 1);
|
||||
else if (seqIndex == SIZE_MAX) return
|
||||
make_pair(0, m_pMBLayout->GetNumParallelSequences());
|
||||
else return
|
||||
make_pair(seqIndex, seqIndex + 1);
|
||||
if (!m_pMBLayout)
|
||||
return make_pair(0, 1);
|
||||
else if (seqIndex == SIZE_MAX)
|
||||
return make_pair(0, m_pMBLayout->GetNumParallelSequences());
|
||||
else
|
||||
return make_pair(seqIndex, seqIndex + 1);
|
||||
}
|
||||
|
||||
std::pair<size_t, size_t> GetTimeRange() const
|
||||
{
|
||||
if (!m_pMBLayout) return
|
||||
make_pair(0, 1);
|
||||
else if (IsAllFrames()) return
|
||||
make_pair(0, m_pMBLayout->GetNumTimeSteps());
|
||||
else return
|
||||
make_pair(timeIdxInSeq + m_timeOffset, timeIdxInSeq + m_timeOffset + m_timeRange);
|
||||
if (!m_pMBLayout)
|
||||
return make_pair(0, 1);
|
||||
else if (IsAllFrames())
|
||||
return make_pair(0, m_pMBLayout->GetNumTimeSteps());
|
||||
else
|
||||
return make_pair(timeIdxInSeq + m_timeOffset, timeIdxInSeq + m_timeOffset + m_timeRange);
|
||||
}
|
||||
|
||||
bool IsOneColumnWrt(const shared_ptr<MBLayout> &pMBLayout) const
|
||||
{
|
||||
if (!pMBLayout) return
|
||||
true; // target has no layout: This would broadcast.
|
||||
else return
|
||||
(pMBLayout->GetNumTimeSteps() == 1 || (!IsAllFrames() && m_timeRange == 1)) &&
|
||||
(pMBLayout->GetNumParallelSequences() == 1 || seqIndex != SIZE_MAX);
|
||||
if (!pMBLayout)
|
||||
return true; // target has no layout: This would broadcast.
|
||||
else
|
||||
return (pMBLayout->GetNumTimeSteps() == 1 || (!IsAllFrames() && m_timeRange == 1)) &&
|
||||
(pMBLayout->GetNumParallelSequences() == 1 || seqIndex != SIZE_MAX);
|
||||
}
|
||||
|
||||
bool IsBatchMatmul(const shared_ptr<MBLayout> &pMBLayout) const
|
||||
|
@ -816,7 +839,7 @@ public:
|
|||
return false;
|
||||
else
|
||||
return (pMBLayout->GetNumTimeSteps() > 1 && (IsAllFrames() || m_timeRange > 1)) ||
|
||||
(pMBLayout->GetNumParallelSequences() > 1);
|
||||
(pMBLayout->GetNumParallelSequences() > 1);
|
||||
}
|
||||
|
||||
// code that can only handle single-frame ranges will call t() to get the time index, which will throw if numFrames != 1
|
||||
|
@ -882,16 +905,16 @@ inline bool MBLayout::IsGap(const FrameRange &fr) const
|
|||
}
|
||||
|
||||
// test whether frame is exceeding the bounds of the MB
|
||||
inline bool MBLayout::IsBeyondMinibatch(const FrameRange& fr) const
|
||||
inline bool MBLayout::IsBeyondMinibatch(const FrameRange &fr) const
|
||||
{
|
||||
CheckIsValid();
|
||||
|
||||
if (fr.IsAllFrames())
|
||||
LogicError("MBLayout::IsBeyondStartOrEnd() cannot be applied to FrameRange that specifies more than a single time step.");
|
||||
|
||||
const auto beginTime = (ptrdiff_t)fr.timeIdxInSeq + fr.m_timeOffset; // we test off the frame with offset
|
||||
const auto endTime = beginTime + (ptrdiff_t)fr.m_timeRange;
|
||||
return beginTime < 0 || endTime > (ptrdiff_t)GetNumTimeSteps();
|
||||
const auto beginTime = (ptrdiff_t) fr.timeIdxInSeq + fr.m_timeOffset; // we test off the frame with offset
|
||||
const auto endTime = beginTime + (ptrdiff_t) fr.m_timeRange;
|
||||
return beginTime < 0 || endTime > (ptrdiff_t) GetNumTimeSteps();
|
||||
}
|
||||
|
||||
// test whether frame is exceeding the sentence boundaries
|
||||
|
@ -938,14 +961,17 @@ inline bool MBLayout::IsBeyondStartOrEnd(const FrameRange &fr) const
|
|||
}
|
||||
|
||||
// TODO: Remove this version (with sanity checks) after this has been tested. Then the function can be inlined above.
|
||||
inline size_t MBLayout::GetActualNumSamples() const { return m_numFramesDeclared - m_numGapFrames; }
|
||||
inline size_t MBLayout::GetActualNumSamples() const
|
||||
{
|
||||
return m_numFramesDeclared - m_numGapFrames;
|
||||
}
|
||||
|
||||
// return m_columnsValidityMask(,), which is lazily created here upon first call
|
||||
// only called from MaskMissingColumnsTo()
|
||||
// Update: also called from GatherNode::BackpropToNonLooping().
|
||||
// Update: also called from GatherNode::BackpropToNonLooping().
|
||||
// TODO: Can probably be faster by using the sequence array directly.
|
||||
// TODO: Or should we just blast m_distanceToStart to GPU, and maks based on that? It is small compared to features.
|
||||
inline const Matrix<char>& MBLayout::GetColumnsValidityMask(DEVICEID_TYPE deviceId) const
|
||||
inline const Matrix<char> &MBLayout::GetColumnsValidityMask(DEVICEID_TYPE deviceId) const
|
||||
{
|
||||
CheckIsValid();
|
||||
// lazily compute the validity mask
|
||||
|
@ -1152,7 +1178,7 @@ static inline std::pair<DimensionVector, DimensionVector> TensorSliceWithMBLayou
|
|||
result.second = shape;
|
||||
|
||||
// get position of time and sequence index
|
||||
const size_t iterDim = shape.size() -1; // valid if data has MBLayout
|
||||
const size_t iterDim = shape.size() - 1; // valid if data has MBLayout
|
||||
|
||||
// MBLayout of data and of FrameRange must be identical pointers,
|
||||
// or in case of broadcasting, respective parent pointers.
|
||||
|
@ -1171,7 +1197,7 @@ static inline std::pair<DimensionVector, DimensionVector> TensorSliceWithMBLayou
|
|||
LogicError("TensorSliceWithMBLayoutFor: FrameRange has no layout, incompatible with data's layout: %s",
|
||||
static_cast<string>(*(pMBLayout)).c_str());
|
||||
else
|
||||
LogicError("TensorSliceWithMBLayoutFor: FrameRange's dynamic axis is inconsistent with data: %s vs. %s",
|
||||
LogicError("TensorSliceWithMBLayoutFor: FrameRange's dynamic axis is inconsistent with data: %s vs. %s",
|
||||
static_cast<string>(*(fr.m_pMBLayout)).c_str(), static_cast<string>(*(pMBLayout)).c_str());
|
||||
}
|
||||
// if FrameRange refers to whole minibatch (map mode)
|
||||
|
@ -1201,7 +1227,7 @@ static inline std::pair<DimensionVector, DimensionVector> TensorSliceWithMBLayou
|
|||
}
|
||||
|
||||
// sequence index
|
||||
if (fr.seqIndex != SIZE_MAX) // sequence requested?
|
||||
if (fr.seqIndex != SIZE_MAX) // sequence requested?
|
||||
{
|
||||
if (pMBLayout) // (if no layout then broadcast to all sequences)
|
||||
{
|
||||
|
@ -1211,8 +1237,8 @@ static inline std::pair<DimensionVector, DimensionVector> TensorSliceWithMBLayou
|
|||
size_t s = fr.seqIndex;
|
||||
if (s >= result.second[sequenceDim])
|
||||
LogicError("TensorSliceWithMBLayoutFor: FrameRange specifies a parallel-sequence index that is out of range.");
|
||||
result.first[sequenceDim] = (ElemType)s;
|
||||
result.second[sequenceDim] = (ElemType)s + 1;
|
||||
result.first[sequenceDim] = (ElemType) s;
|
||||
result.second[sequenceDim] = (ElemType) s + 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1233,13 +1259,13 @@ static inline std::pair<DimensionVector, DimensionVector> TensorSliceWithMBLayou
|
|||
// 'Reduce' style operations--the criterion nodes and gradient computation--call this.
|
||||
// Warning: The layout used here must match the matrix. E.g. don't pass a child's matrix from a criterion node (use Input(x)->MaskMissing{Values,Gradient}ColumnsToZero() instead.
|
||||
template <class ElemType>
|
||||
static inline void MaskMissingColumnsTo(Matrix<ElemType>& matrixToMask, const MBLayoutPtr& pMBLayout, const FrameRange& fr, ElemType val)
|
||||
static inline void MaskMissingColumnsTo(Matrix<ElemType> &matrixToMask, const MBLayoutPtr &pMBLayout, const FrameRange &fr, ElemType val)
|
||||
{
|
||||
if (pMBLayout && (pMBLayout->HasGaps(fr) || pMBLayout->HasRightSplice()))
|
||||
{
|
||||
const auto& maskMatrix = pMBLayout->GetColumnsValidityMask(matrixToMask.GetDeviceId());
|
||||
const auto &maskMatrix = pMBLayout->GetColumnsValidityMask(matrixToMask.GetDeviceId());
|
||||
|
||||
maskMatrix.TransferToDeviceIfNotThere(matrixToMask.GetDeviceId(), /*ismoved=*/ false, /*emptyTransfer=*/ false, /*updatePreferredDevice=*/ false);
|
||||
maskMatrix.TransferToDeviceIfNotThere(matrixToMask.GetDeviceId(), /*ismoved=*/false, /*emptyTransfer=*/false, /*updatePreferredDevice=*/false);
|
||||
auto maskSlice = DataWithMBLayoutFor(maskMatrix, fr, pMBLayout);
|
||||
|
||||
auto matrixSliceToMask = DataWithMBLayoutFor(matrixToMask, fr, pMBLayout);
|
||||
|
@ -1250,4 +1276,6 @@ static inline void MaskMissingColumnsTo(Matrix<ElemType>& matrixToMask, const MB
|
|||
}
|
||||
}
|
||||
|
||||
}}}
|
||||
} // namespace CNTK
|
||||
} // namespace MSR
|
||||
} // namespace Microsoft
|
||||
|
|
|
@ -394,7 +394,7 @@ public:
|
|||
vector<pair<size_t, ElemType>> datapair;
|
||||
typedef vector<pair<size_t, ElemType>>::value_type ValueType;
|
||||
ElemType* probdata = prob.CopyToArray();
|
||||
|
||||
|
||||
for (size_t n = 0; n < prob.GetNumRows(); n++)
|
||||
{
|
||||
datapair.push_back(ValueType(n, probdata[n]));
|
||||
|
@ -431,8 +431,10 @@ public:
|
|||
return true;
|
||||
}
|
||||
|
||||
void forwardmerged(Sequence a, size_t t, Matrix<ElemType>& sumofENandDE, Matrix<ElemType>& encodeOutput, Matrix<ElemType>& decodeOutput, ComputationNodeBasePtr PlusNode,
|
||||
ComputationNodeBasePtr PlusTransNode, std::vector<ComputationNodeBasePtr> Plusnodes, std::vector<ComputationNodeBasePtr> Plustransnodes, Matrix<ElemType>& Wm, Matrix<ElemType>& Wm2, Matrix<ElemType>& bm)
|
||||
|
||||
void forwardmerged(Sequence a, size_t t, Matrix<ElemType>& sumofENandDE, Matrix<ElemType>& encodeOutput, Matrix<ElemType>& decodeOutput, ComputationNodeBasePtr PlusNode,
|
||||
ComputationNodeBasePtr PlusTransNode, std::vector<ComputationNodeBasePtr> Plusnodes, std::vector<ComputationNodeBasePtr> Plustransnodes, Matrix<ElemType>& Wm, Matrix<ElemType>& bm)
|
||||
|
||||
{
|
||||
/*auto edNode = PlusNode->As<PlusBroadcastNode<ElemType>>();
|
||||
if (edNode->getCombineMode() == 1)
|
||||
|
@ -456,14 +458,14 @@ public:
|
|||
m_net->ForwardPropFromTo(Plusnodes, Plustransnodes);
|
||||
decodeOutput.SetValue(*(&dynamic_pointer_cast<ComputationNode<ElemType>>(PlusTransNode)->Value()));
|
||||
tempMatrix.AssignProductOf(Wm, true, decodeOutput, false);
|
||||
tempMatrix2.AssignProductOf(Wm2, true, tempMatrix, false);
|
||||
decodeOutput.AssignSumOf(tempMatrix2, bm);
|
||||
decodeOutput.AssignSumOf(tempMatrix, bm);
|
||||
|
||||
//decodeOutput.VectorMax(maxIdx, maxVal, true);
|
||||
decodeOutput.InplaceLogSoftmax(true);
|
||||
|
||||
decodeOutput.InplaceLogSoftmax(true);
|
||||
}
|
||||
void WriteOutput_beam(IDataReader& dataReader, size_t mbSize, IDataWriter& dataWriter, const std::vector<std::wstring>& outputNodeNames,
|
||||
size_t numOutputSamples = requestDataSize, bool doWriterUnitTest = false, size_t beamSize = 10, size_t expandBeam = 20, string dictfile = L"", ElemType thresh = 0.68)
|
||||
size_t numOutputSamples = requestDataSize, bool doWriterUnitTest = false, size_t beamSize = 10, size_t expandBeam = 20, string dictfile = L"", ElemType thresh = 0.68,
|
||||
size_t rightsplice = 20, size_t encoderdim = 640)
|
||||
{
|
||||
ScopedNetworkOperationMode modeGuard(m_net, NetworkOperationMode::inferring);
|
||||
|
||||
|
@ -509,11 +511,12 @@ public:
|
|||
//get merged input
|
||||
ComputationNodeBasePtr PlusNode = m_net->GetNodeFromName(outputNodeNames[2]);
|
||||
ComputationNodeBasePtr PlusTransNode = m_net->GetNodeFromName(outputNodeNames[3]);
|
||||
|
||||
|
||||
ComputationNodeBasePtr WmNode = m_net->GetNodeFromName(outputNodeNames[4]);
|
||||
ComputationNodeBasePtr Wm2Node = m_net->GetNodeFromName(outputNodeNames[5]);
|
||||
ComputationNodeBasePtr bmNode = m_net->GetNodeFromName(outputNodeNames[6]);
|
||||
|
||||
|
||||
|
||||
ComputationNodeBasePtr bmNode = m_net->GetNodeFromName(outputNodeNames[5]);
|
||||
|
||||
|
||||
//StreamMinibatchInputs PlusinputMatrices =
|
||||
std::vector<ComputationNodeBasePtr> Plusnodes, Plustransnodes;
|
||||
|
@ -539,10 +542,8 @@ public:
|
|||
//for merged RNNT node
|
||||
Matrix<ElemType> Wm(deviceid), Wm2(deviceid), bm(deviceid);
|
||||
Wm.SetValue(*(&dynamic_pointer_cast<ComputationNode<ElemType>>(WmNode)->Value()));
|
||||
Wm2.SetValue(*(&dynamic_pointer_cast<ComputationNode<ElemType>>(Wm2Node)->Value()));
|
||||
bm.SetValue(*(&dynamic_pointer_cast<ComputationNode<ElemType>>(bmNode)->Value()));
|
||||
|
||||
|
||||
//encodeOutput.GetDeviceId
|
||||
const size_t numIterationsBeforePrintingProgress = 100;
|
||||
//size_t numItersSinceLastPrintOfProgress = 0;
|
||||
|
@ -550,12 +551,51 @@ public:
|
|||
vector<Sequence> CurSequences, nextSequences;
|
||||
while (DataReaderHelpers::GetMinibatchIntoNetwork<ElemType>(dataReader, m_net, nullptr, false, false, encodeInputMatrices, actualMBSize, nullptr))
|
||||
{
|
||||
|
||||
//encode forward prop for whole utterance
|
||||
ComputationNetwork::BumpEvalTimeStamp(encodeInputNodes);
|
||||
|
||||
/*auto InputMBLayout = encodeInputNodes[0]->GetMBLayout();
|
||||
int rightsplice = InputMBLayout->RightSplice();
|
||||
fprintf(stderr, "right splice :%d\n", rightsplice);
|
||||
InputMBLayout->setRightSplice(20);*/
|
||||
//forward prop encoder network
|
||||
m_net->ForwardProp(encodeOutputNodes[0]);
|
||||
encodeOutput.SetValue(*(&dynamic_pointer_cast<ComputationNode<ElemType>>(encodeOutputNodes[0])->Value()));
|
||||
|
||||
//do truncated
|
||||
size_t step = rightsplice;
|
||||
size_t curt = 0;
|
||||
|
||||
auto feainput = encodeInputMatrices.begin();
|
||||
Matrix<ElemType> FeaMatrix(deviceid);
|
||||
//auto InputTruncatedMatrix = ;
|
||||
FeaMatrix.SetValue(feainput->second.GetMatrix<ElemType>());
|
||||
|
||||
size_t frameNum = FeaMatrix.GetNumCols();
|
||||
size_t numRows = FeaMatrix.GetNumRows();
|
||||
|
||||
encodeOutput.Resize(encoderdim, frameNum);
|
||||
while (curt < frameNum)
|
||||
{
|
||||
m_net->ResetEvalTimeStamps();
|
||||
//m_net->ResetMBLayouts();
|
||||
size_t actualSize = min(curt+step*2, frameNum );
|
||||
feainput->second.GetMatrix<ElemType>().Resize(numRows, actualSize);
|
||||
feainput->second.pMBLayout->Init(1, actualSize);
|
||||
feainput->second.pMBLayout->AddSequence(NEW_SEQUENCE_ID, 0, 0, actualSize);
|
||||
feainput->second.GetMatrix<ElemType>().SetValue(FeaMatrix.ColumnSlice(0, actualSize));
|
||||
ComputationNetwork::BumpEvalTimeStamp(encodeInputNodes);
|
||||
|
||||
m_net->ForwardProp(encodeOutputNodes[0]);
|
||||
Matrix<ElemType> encodeOutputSlice = (&dynamic_pointer_cast<ComputationNode<ElemType>>(encodeOutputNodes[0])->Value())->ColumnSlice(curt, min(step, frameNum - curt));
|
||||
|
||||
encodeOutput.SetColumnSlice(encodeOutputSlice, curt, min(step, frameNum - curt));
|
||||
curt += step;
|
||||
//
|
||||
}
|
||||
|
||||
|
||||
|
||||
//encodeOutput.SetValue(*(&dynamic_pointer_cast<ComputationNode<ElemType>>(encodeOutputNodes[0])->Value()));
|
||||
//encodeOutput.Print("encodeoutput");
|
||||
dataReader.DataEnd();
|
||||
|
||||
|
@ -622,7 +662,9 @@ public:
|
|||
deleteSeq(*maxSeq);
|
||||
CurSequences.erase(maxSeq);
|
||||
forward_decode(tempSeq, decodeinputMatrices, deviceid, decodeOutputNodes, decodeinputNodes, vocabSize, tempSeq.labelseq.size());
|
||||
forwardmerged(tempSeq, t, sumofENandDE, encodeOutput, decodeOutput, PlusNode, PlusTransNode, Plusnodes, Plustransnodes,Wm, Wm2, bm);
|
||||
|
||||
forwardmerged(tempSeq, t, sumofENandDE, encodeOutput, decodeOutput, PlusNode, PlusTransNode, Plusnodes, Plustransnodes, Wm, bm);
|
||||
|
||||
|
||||
//sumofENandDE.Print("sum");
|
||||
//sort log posterior and get best N labels
|
||||
|
@ -684,9 +726,8 @@ public:
|
|||
if (topN[iLabel].first != blankId)
|
||||
{
|
||||
extendSeq(seqK, topN[iLabel].first, newlogP);
|
||||
CurSequences.push_back(seqK);
|
||||
CurSequences.push_back(seqK);
|
||||
}
|
||||
|
||||
}
|
||||
vector<pair<size_t, ElemType>>().swap(topN);
|
||||
//delete topN;
|
||||
|
|
Загрузка…
Ссылка в новой задаче