This commit is contained in:
Rui Zhao (SPEECH) 2019-10-28 09:43:03 -07:00
Родитель 9800670b2b
Коммит 58abf438b2
3 изменённых файлов: 187 добавлений и 116 удалений

Просмотреть файл

@ -255,6 +255,8 @@ void DoWriteOutput(const ConfigParameters& config)
size_t decodeExpandBeam = writerConfig(L"decode_expand_beam", 20);
string indictfile = writerConfig(L"DictFile", L"");
ElemType thresh = writerConfig(L"Thresh", 0.68f);
size_t rightsplice = writerConfig(L"rightsplice", 20);
size_t encoderdim = writerConfig(L"encoderdim", 640);
DataWriter testDataWriter(writerConfig);
// ConfigParameters
if (decodeType == 0)
@ -262,7 +264,7 @@ void DoWriteOutput(const ConfigParameters& config)
else if (decodeType == 1)
writer.WriteOutput_greedy(testDataReader, mbSize[0], testDataWriter, outputNodeNamesVector, epochSize, writerUnittest);
else if (decodeType == 2)
writer.WriteOutput_beam(testDataReader, mbSize[0], testDataWriter, outputNodeNamesVector, epochSize, writerUnittest, decodeBeam, decodeExpandBeam, indictfile,thresh);
writer.WriteOutput_beam(testDataReader, mbSize[0], testDataWriter, outputNodeNamesVector, epochSize, writerUnittest, decodeBeam, decodeExpandBeam, indictfile,thresh, rightsplice,encoderdim);
}
else if (config.Exists("outputPath"))

Просмотреть файл

@ -13,7 +13,12 @@
#include "Basics.h"
#include "Matrix.h"
namespace Microsoft { namespace MSR { namespace CNTK {
namespace Microsoft
{
namespace MSR
{
namespace CNTK
{
// Forward declarations
class FrameRange;
@ -101,7 +106,10 @@ struct MBLayout
{
return seqId == other.seqId && s == other.s && tBegin == other.tBegin && tEnd == other.tEnd;
}
size_t GetNumTimeSteps() const { return (size_t)(tEnd - tBegin); }
size_t GetNumTimeSteps() const
{
return (size_t)(tEnd - tBegin);
}
};
// -------------------------------------------------------------------
@ -121,9 +129,9 @@ struct MBLayout
// copy the content of another MBLayoutPtr over
// Use this instead of actual assignment to make it super-obvious that this is not copying the pointer but actual content. The pointer is kept fixed.
// Use "keepName" if the "identity" of the target is to be preserved, e.g.
// Use "keepName" if the "identity" of the target is to be preserved, e.g.
// while copying from reader space to network space.
void CopyFrom(const MBLayoutPtr& other, bool keepName=false)
void CopyFrom(const MBLayoutPtr &other, bool keepName = false)
{
m_numTimeSteps = other->m_numTimeSteps;
m_numParallelSequences = other->m_numParallelSequences;
@ -148,7 +156,7 @@ struct MBLayout
}
// Destructive copy that steals ownership if the content, like std::move()
// Note: For some reason the VC++ compiler does not generate the
// Note: For some reason the VC++ compiler does not generate the
// move assignment and we have to do this ourselves
void MoveFrom(MBLayoutPtr other)
{
@ -173,8 +181,8 @@ struct MBLayout
m_axisName = std::move(other->m_axisName);
}
MBLayout(const MBLayout&) = delete;
MBLayout& operator=(const MBLayout&) = delete;
MBLayout(const MBLayout &) = delete;
MBLayout &operator=(const MBLayout &) = delete;
public:
// resize and reset all frames to None (note: this is an invalid state and must be fixed by caller afterwards)
@ -183,7 +191,7 @@ public:
// remember the dimensions
m_numParallelSequences = numParallelSequences;
m_numTimeSteps = numTimeSteps;
if (deepInit)
if (deepInit)
{
m_distanceToStart.Resize(m_numParallelSequences, m_numTimeSteps);
m_distanceToEnd.Resize(m_numParallelSequences, m_numTimeSteps);
@ -211,10 +219,10 @@ public:
// - width: maximum width of structure; set to maximum over sequence lengths
// - inputSequences: vector of input SequenceInfo records (only seqId and GetNumTimeSteps() are used)
// - placement, rowAllocations: temp buffers (passed in to be able to optimize memory allocations)
template<typename SequenceInfoVector>
void InitAsPackedSequences(const SequenceInfoVector& inputSequences,
/*temp buffer*/std::vector<std::pair<size_t, size_t>>& placement,
/*temp buffer*/std::vector<size_t> rowAllocations)
template <typename SequenceInfoVector>
void InitAsPackedSequences(const SequenceInfoVector &inputSequences,
/*temp buffer*/ std::vector<std::pair<size_t, size_t>> &placement,
/*temp buffer*/ std::vector<size_t> rowAllocations)
{
placement.resize(inputSequences.size()); // [sequence index] result goes here (entries are invalid for gaps)
// determine width of MBLayout
@ -227,7 +235,7 @@ public:
width = inputSequences[i].GetNumTimeSteps();
}
// allocate
rowAllocations.clear(); // [row] we build rows one by one
rowAllocations.clear(); // [row] we build rows one by one
for (size_t i = 0; i < inputSequences.size(); i++)
{
if (inputSequences[i].seqId == GAP_SEQUENCE_ID)
@ -253,36 +261,48 @@ public:
{
if (inputSequences[i].seqId == GAP_SEQUENCE_ID)
continue;
size_t s, tBegin; tie
(s, tBegin) = placement[i];
AddSequence(inputSequences[i].seqId, s, (ptrdiff_t)tBegin, tBegin + inputSequences[i].GetNumTimeSteps());
size_t s, tBegin;
tie(s, tBegin) = placement[i];
AddSequence(inputSequences[i].seqId, s, (ptrdiff_t) tBegin, tBegin + inputSequences[i].GetNumTimeSteps());
}
// need to fill the gaps as well
for (size_t s = 0; s < rowAllocations.size(); s++)
AddGap(s, (size_t)rowAllocations[s], width);
AddGap(s, (size_t) rowAllocations[s], width);
}
// -------------------------------------------------------------------
// accessors
// -------------------------------------------------------------------
size_t GetNumTimeSteps() const { return m_numTimeSteps; }
size_t GetNumParallelSequences() const { return m_numParallelSequences; }
size_t GetNumTimeSteps() const
{
return m_numTimeSteps;
}
size_t GetNumParallelSequences() const
{
return m_numParallelSequences;
}
size_t GetNumSequences() const
{
return std::count_if(m_sequences.begin(), m_sequences.end(), [](const SequenceInfo& sequence) {
return std::count_if(m_sequences.begin(), m_sequences.end(), [](const SequenceInfo &sequence) {
return sequence.seqId != GAP_SEQUENCE_ID;
});
}
// axis names are for now only a debugging aid
// In the future, there will be a mechanism to denote that axes are meant to be the same.
const wchar_t* GetAxisName() const { return m_axisName.c_str(); }
void SetAxisName(const std::wstring& name) { m_axisName = name; }
const wchar_t *GetAxisName() const
{
return m_axisName.c_str();
}
void SetAxisName(const std::wstring &name)
{
m_axisName = name;
}
void SetUniqueAxisName(std::wstring name) // helper for constructing
{
// Unfortunatelly, initialization of local static variables is not thread-safe in VS2013.
// As workaround, it is moved to the struct level.
// As workaround, it is moved to the struct level.
// Todo: when upgraded to VS2013, change back to use the local static mutex, and remove also Sequences.cpp.
// The mutex is need to make access to nameIndices be thread-safe.
// static std::mutex nameIndiciesMutex;
@ -297,7 +317,7 @@ public:
}
if (index > 0)
name += msra::strfun::wstrprintf(L"%d", (int)index);
name += msra::strfun::wstrprintf(L"%d", (int) index);
SetAxisName(name);
}
@ -309,13 +329,13 @@ public:
// Get the number of frames of the input sequence that belong to the MB, i.e. disregarding sequence elements that are outside of the MB boundaries
// Input sequence is expected to belong to this MBLayout
size_t GetNumSequenceFramesInCurrentMB(const SequenceInfo& sequenceInfo) const
size_t GetNumSequenceFramesInCurrentMB(const SequenceInfo &sequenceInfo) const
{
return min(sequenceInfo.tEnd, GetNumTimeSteps()) - max(sequenceInfo.tBegin, (ptrdiff_t)0);
return min(sequenceInfo.tEnd, GetNumTimeSteps()) - max(sequenceInfo.tBegin, (ptrdiff_t) 0);
}
// return all sequences stored in this minibatch
const vector<SequenceInfo>& GetAllSequences() const
const vector<SequenceInfo> &GetAllSequences() const
{
return m_sequences;
}
@ -324,10 +344,10 @@ public:
// This is used by MeanNode and InvStdDevNode, and by statistics reporting.
size_t GetActualNumSamples() const;
const Matrix<char>& GetColumnsValidityMask(DEVICEID_TYPE deviceId) const;
const Matrix<char> &GetColumnsValidityMask(DEVICEID_TYPE deviceId) const;
// compare whether two layouts are the same
bool operator==(const MBLayout& other) const
bool operator==(const MBLayout &other) const
{
if (this == &other)
return true;
@ -352,7 +372,7 @@ public:
{
if (!first)
s << ", ";
s << "{seqId:" << seq.seqId << ", s:" << seq.s <<", begin:" << seq.tBegin << ", end:" << seq.tEnd << "}";
s << "{seqId:" << seq.seqId << ", s:" << seq.s << ", begin:" << seq.tBegin << ", end:" << seq.tEnd << "}";
first = false;
}
s << "]}";
@ -403,7 +423,7 @@ public:
if (seqId == GAP_SEQUENCE_ID)
{
m_numGapFrames += (e - b);
if (initDistances)
if (initDistances)
{
for (size_t t = b; t < e; t++)
{
@ -412,24 +432,23 @@ public:
}
}
}
else
if (initDistances)
else if (initDistances)
{
for (size_t t = b; t < e; t++)
{
for (size_t t = b; t < e; t++)
{
// update the nearest sentence boundaries, minimum over all parallel sequences
// If 0, then we are on a boundary. If not 0, we can still test in presence of FrameRange.m_timeOffset.
ptrdiff_t distanceToStart = (ptrdiff_t)t - beginTime;
ptrdiff_t distanceToEnd = (ptrdiff_t)(endTime - 1 - t);
m_distanceToStart(s, t) = (float)distanceToStart;
m_distanceToEnd(s, t) = (float)distanceToEnd;
// and the aggregate
if (m_distanceToNearestStart[t] > distanceToStart)
m_distanceToNearestStart[t] = distanceToStart;
if (m_distanceToNearestEnd[t] > distanceToEnd)
m_distanceToNearestEnd[t] = distanceToEnd;
}
// update the nearest sentence boundaries, minimum over all parallel sequences
// If 0, then we are on a boundary. If not 0, we can still test in presence of FrameRange.m_timeOffset.
ptrdiff_t distanceToStart = (ptrdiff_t) t - beginTime;
ptrdiff_t distanceToEnd = (ptrdiff_t)(endTime - 1 - t);
m_distanceToStart(s, t) = (float) distanceToStart;
m_distanceToEnd(s, t) = (float) distanceToEnd;
// and the aggregate
if (m_distanceToNearestStart[t] > distanceToStart)
m_distanceToNearestStart[t] = distanceToStart;
if (m_distanceToNearestEnd[t] > distanceToEnd)
m_distanceToNearestEnd[t] = distanceToEnd;
}
}
}
// short-hand to initialize an MBLayout for the common case of frame mode
@ -471,7 +490,7 @@ public:
}
// find a sequence by its id
const SequenceInfo& FindSequence(UniqueSequenceId seqId) const
const SequenceInfo &FindSequence(UniqueSequenceId seqId) const
{
for (const auto &seqInfo : m_sequences)
if (seqInfo.seqId == seqId)
@ -481,13 +500,13 @@ public:
// find a sequence by SequenceInfo array and position
// Use this if sequences may be matching 1:1.
const SequenceInfo& FindMatchingSequence(const vector<SequenceInfo>& querySequences, size_t i) const
const SequenceInfo &FindMatchingSequence(const vector<SequenceInfo> &querySequences, size_t i) const
{
// TODO: What are our sorted-ness guarantees?
let seqId = querySequences[i].seqId; // the seq id we are looking for
if (seqId == GAP_SEQUENCE_ID)
LogicError("FindMatchingSequence: Cannot be applied go gaps.");
if (seqId == m_sequences[i].seqId) // if both sequence arrays match 1:1 then we found it
if (seqId == m_sequences[i].seqId) // if both sequence arrays match 1:1 then we found it
return m_sequences[i];
else
return FindSequence(seqId);
@ -508,10 +527,15 @@ public:
return m_rightSplice;
}
void setRightSplice(int rightsplice)
{
m_rightSplice = rightsplice;
}
// test boundary flags for a specific condition
bool IsBeyondStartOrEnd(const FrameRange& fr) const;
bool IsGap(const FrameRange& fr) const;
bool IsBeyondMinibatch(const FrameRange& fr) const;
bool IsBeyondStartOrEnd(const FrameRange &fr) const;
bool IsGap(const FrameRange &fr) const;
bool IsBeyondMinibatch(const FrameRange &fr) const;
// test whether at least one sequence crosses the bounds of this minibatch
bool HasSequenceBeyondBegin() const
@ -540,29 +564,29 @@ public:
// -------------------------------------------------------------------
// get the matrix-column index for a given time step in a given sequence
size_t GetColumnIndex(const SequenceInfo& seq, size_t t) const
size_t GetColumnIndex(const SequenceInfo &seq, size_t t) const
{
if (t > seq.GetNumTimeSteps())
LogicError("GetColumnIndex: t out of sequence bounds.");
if (seq.s > GetNumParallelSequences())
LogicError("GetColumnIndex: seq.s out of sequence bounds."); // can only happen if 'seq' does not come out of our own m_sequences array, which is verboten
ptrdiff_t tIn = (ptrdiff_t)t + seq.tBegin; // shifted time index
if (tIn < 0 || (size_t)tIn >= GetNumTimeSteps())
ptrdiff_t tIn = (ptrdiff_t) t + seq.tBegin; // shifted time index
if (tIn < 0 || (size_t) tIn >= GetNumTimeSteps())
LogicError("GetColumnIndex: Attempted to access a time step that is accessing a portion of a sequence that is not included in current minibatch."); // we may encounter this for truncated BPTT
size_t col = (size_t)tIn * GetNumParallelSequences() + seq.s;
size_t col = (size_t) tIn * GetNumParallelSequences() + seq.s;
assert(col < GetNumCols());
return col;
}
// get the matrix-column indices for a given sequence
// sequence is expected to belong to this MB
vector<size_t> GetColumnIndices(const SequenceInfo& seq) const
vector<size_t> GetColumnIndices(const SequenceInfo &seq) const
{
size_t numFrames = GetNumSequenceFramesInCurrentMB(seq);
vector<size_t> res;
res.reserve(numFrames);
for (size_t i = 0; i < numFrames;++i)
res.push_back(GetColumnIndex(seq,i));
for (size_t i = 0; i < numFrames; ++i)
res.push_back(GetColumnIndex(seq, i));
return res;
}
@ -653,7 +677,6 @@ private:
static std::map<std::wstring, size_t> s_nameIndices;
public:
// special accessor for sequence training --TODO: must be replaced by a different mechanism
bool IsEnd(size_t s, size_t t) const
{
@ -781,33 +804,33 @@ public:
return ret;
}
std::pair<size_t,size_t> GetSequenceRange() const
std::pair<size_t, size_t> GetSequenceRange() const
{
if (!m_pMBLayout) return
make_pair(0, 1);
else if (seqIndex == SIZE_MAX) return
make_pair(0, m_pMBLayout->GetNumParallelSequences());
else return
make_pair(seqIndex, seqIndex + 1);
if (!m_pMBLayout)
return make_pair(0, 1);
else if (seqIndex == SIZE_MAX)
return make_pair(0, m_pMBLayout->GetNumParallelSequences());
else
return make_pair(seqIndex, seqIndex + 1);
}
std::pair<size_t, size_t> GetTimeRange() const
{
if (!m_pMBLayout) return
make_pair(0, 1);
else if (IsAllFrames()) return
make_pair(0, m_pMBLayout->GetNumTimeSteps());
else return
make_pair(timeIdxInSeq + m_timeOffset, timeIdxInSeq + m_timeOffset + m_timeRange);
if (!m_pMBLayout)
return make_pair(0, 1);
else if (IsAllFrames())
return make_pair(0, m_pMBLayout->GetNumTimeSteps());
else
return make_pair(timeIdxInSeq + m_timeOffset, timeIdxInSeq + m_timeOffset + m_timeRange);
}
bool IsOneColumnWrt(const shared_ptr<MBLayout> &pMBLayout) const
{
if (!pMBLayout) return
true; // target has no layout: This would broadcast.
else return
(pMBLayout->GetNumTimeSteps() == 1 || (!IsAllFrames() && m_timeRange == 1)) &&
(pMBLayout->GetNumParallelSequences() == 1 || seqIndex != SIZE_MAX);
if (!pMBLayout)
return true; // target has no layout: This would broadcast.
else
return (pMBLayout->GetNumTimeSteps() == 1 || (!IsAllFrames() && m_timeRange == 1)) &&
(pMBLayout->GetNumParallelSequences() == 1 || seqIndex != SIZE_MAX);
}
bool IsBatchMatmul(const shared_ptr<MBLayout> &pMBLayout) const
@ -816,7 +839,7 @@ public:
return false;
else
return (pMBLayout->GetNumTimeSteps() > 1 && (IsAllFrames() || m_timeRange > 1)) ||
(pMBLayout->GetNumParallelSequences() > 1);
(pMBLayout->GetNumParallelSequences() > 1);
}
// code that can only handle single-frame ranges will call t() to get the time index, which will throw if numFrames != 1
@ -882,16 +905,16 @@ inline bool MBLayout::IsGap(const FrameRange &fr) const
}
// test whether frame is exceeding the bounds of the MB
inline bool MBLayout::IsBeyondMinibatch(const FrameRange& fr) const
inline bool MBLayout::IsBeyondMinibatch(const FrameRange &fr) const
{
CheckIsValid();
if (fr.IsAllFrames())
LogicError("MBLayout::IsBeyondStartOrEnd() cannot be applied to FrameRange that specifies more than a single time step.");
const auto beginTime = (ptrdiff_t)fr.timeIdxInSeq + fr.m_timeOffset; // we test off the frame with offset
const auto endTime = beginTime + (ptrdiff_t)fr.m_timeRange;
return beginTime < 0 || endTime > (ptrdiff_t)GetNumTimeSteps();
const auto beginTime = (ptrdiff_t) fr.timeIdxInSeq + fr.m_timeOffset; // we test off the frame with offset
const auto endTime = beginTime + (ptrdiff_t) fr.m_timeRange;
return beginTime < 0 || endTime > (ptrdiff_t) GetNumTimeSteps();
}
// test whether frame is exceeding the sentence boundaries
@ -938,14 +961,17 @@ inline bool MBLayout::IsBeyondStartOrEnd(const FrameRange &fr) const
}
// TODO: Remove this version (with sanity checks) after this has been tested. Then the function can be inlined above.
inline size_t MBLayout::GetActualNumSamples() const { return m_numFramesDeclared - m_numGapFrames; }
inline size_t MBLayout::GetActualNumSamples() const
{
return m_numFramesDeclared - m_numGapFrames;
}
// return m_columnsValidityMask(,), which is lazily created here upon first call
// only called from MaskMissingColumnsTo()
// Update: also called from GatherNode::BackpropToNonLooping().
// Update: also called from GatherNode::BackpropToNonLooping().
// TODO: Can probably be faster by using the sequence array directly.
// TODO: Or should we just blast m_distanceToStart to GPU, and maks based on that? It is small compared to features.
inline const Matrix<char>& MBLayout::GetColumnsValidityMask(DEVICEID_TYPE deviceId) const
inline const Matrix<char> &MBLayout::GetColumnsValidityMask(DEVICEID_TYPE deviceId) const
{
CheckIsValid();
// lazily compute the validity mask
@ -1152,7 +1178,7 @@ static inline std::pair<DimensionVector, DimensionVector> TensorSliceWithMBLayou
result.second = shape;
// get position of time and sequence index
const size_t iterDim = shape.size() -1; // valid if data has MBLayout
const size_t iterDim = shape.size() - 1; // valid if data has MBLayout
// MBLayout of data and of FrameRange must be identical pointers,
// or in case of broadcasting, respective parent pointers.
@ -1171,7 +1197,7 @@ static inline std::pair<DimensionVector, DimensionVector> TensorSliceWithMBLayou
LogicError("TensorSliceWithMBLayoutFor: FrameRange has no layout, incompatible with data's layout: %s",
static_cast<string>(*(pMBLayout)).c_str());
else
LogicError("TensorSliceWithMBLayoutFor: FrameRange's dynamic axis is inconsistent with data: %s vs. %s",
LogicError("TensorSliceWithMBLayoutFor: FrameRange's dynamic axis is inconsistent with data: %s vs. %s",
static_cast<string>(*(fr.m_pMBLayout)).c_str(), static_cast<string>(*(pMBLayout)).c_str());
}
// if FrameRange refers to whole minibatch (map mode)
@ -1201,7 +1227,7 @@ static inline std::pair<DimensionVector, DimensionVector> TensorSliceWithMBLayou
}
// sequence index
if (fr.seqIndex != SIZE_MAX) // sequence requested?
if (fr.seqIndex != SIZE_MAX) // sequence requested?
{
if (pMBLayout) // (if no layout then broadcast to all sequences)
{
@ -1211,8 +1237,8 @@ static inline std::pair<DimensionVector, DimensionVector> TensorSliceWithMBLayou
size_t s = fr.seqIndex;
if (s >= result.second[sequenceDim])
LogicError("TensorSliceWithMBLayoutFor: FrameRange specifies a parallel-sequence index that is out of range.");
result.first[sequenceDim] = (ElemType)s;
result.second[sequenceDim] = (ElemType)s + 1;
result.first[sequenceDim] = (ElemType) s;
result.second[sequenceDim] = (ElemType) s + 1;
}
}
}
@ -1233,13 +1259,13 @@ static inline std::pair<DimensionVector, DimensionVector> TensorSliceWithMBLayou
// 'Reduce' style operations--the criterion nodes and gradient computation--call this.
// Warning: The layout used here must match the matrix. E.g. don't pass a child's matrix from a criterion node (use Input(x)->MaskMissing{Values,Gradient}ColumnsToZero() instead.
template <class ElemType>
static inline void MaskMissingColumnsTo(Matrix<ElemType>& matrixToMask, const MBLayoutPtr& pMBLayout, const FrameRange& fr, ElemType val)
static inline void MaskMissingColumnsTo(Matrix<ElemType> &matrixToMask, const MBLayoutPtr &pMBLayout, const FrameRange &fr, ElemType val)
{
if (pMBLayout && (pMBLayout->HasGaps(fr) || pMBLayout->HasRightSplice()))
{
const auto& maskMatrix = pMBLayout->GetColumnsValidityMask(matrixToMask.GetDeviceId());
const auto &maskMatrix = pMBLayout->GetColumnsValidityMask(matrixToMask.GetDeviceId());
maskMatrix.TransferToDeviceIfNotThere(matrixToMask.GetDeviceId(), /*ismoved=*/ false, /*emptyTransfer=*/ false, /*updatePreferredDevice=*/ false);
maskMatrix.TransferToDeviceIfNotThere(matrixToMask.GetDeviceId(), /*ismoved=*/false, /*emptyTransfer=*/false, /*updatePreferredDevice=*/false);
auto maskSlice = DataWithMBLayoutFor(maskMatrix, fr, pMBLayout);
auto matrixSliceToMask = DataWithMBLayoutFor(matrixToMask, fr, pMBLayout);
@ -1250,4 +1276,6 @@ static inline void MaskMissingColumnsTo(Matrix<ElemType>& matrixToMask, const MB
}
}
}}}
} // namespace CNTK
} // namespace MSR
} // namespace Microsoft

Просмотреть файл

@ -394,7 +394,7 @@ public:
vector<pair<size_t, ElemType>> datapair;
typedef vector<pair<size_t, ElemType>>::value_type ValueType;
ElemType* probdata = prob.CopyToArray();
for (size_t n = 0; n < prob.GetNumRows(); n++)
{
datapair.push_back(ValueType(n, probdata[n]));
@ -431,8 +431,10 @@ public:
return true;
}
void forwardmerged(Sequence a, size_t t, Matrix<ElemType>& sumofENandDE, Matrix<ElemType>& encodeOutput, Matrix<ElemType>& decodeOutput, ComputationNodeBasePtr PlusNode,
ComputationNodeBasePtr PlusTransNode, std::vector<ComputationNodeBasePtr> Plusnodes, std::vector<ComputationNodeBasePtr> Plustransnodes, Matrix<ElemType>& Wm, Matrix<ElemType>& Wm2, Matrix<ElemType>& bm)
void forwardmerged(Sequence a, size_t t, Matrix<ElemType>& sumofENandDE, Matrix<ElemType>& encodeOutput, Matrix<ElemType>& decodeOutput, ComputationNodeBasePtr PlusNode,
ComputationNodeBasePtr PlusTransNode, std::vector<ComputationNodeBasePtr> Plusnodes, std::vector<ComputationNodeBasePtr> Plustransnodes, Matrix<ElemType>& Wm, Matrix<ElemType>& bm)
{
/*auto edNode = PlusNode->As<PlusBroadcastNode<ElemType>>();
if (edNode->getCombineMode() == 1)
@ -456,14 +458,14 @@ public:
m_net->ForwardPropFromTo(Plusnodes, Plustransnodes);
decodeOutput.SetValue(*(&dynamic_pointer_cast<ComputationNode<ElemType>>(PlusTransNode)->Value()));
tempMatrix.AssignProductOf(Wm, true, decodeOutput, false);
tempMatrix2.AssignProductOf(Wm2, true, tempMatrix, false);
decodeOutput.AssignSumOf(tempMatrix2, bm);
decodeOutput.AssignSumOf(tempMatrix, bm);
//decodeOutput.VectorMax(maxIdx, maxVal, true);
decodeOutput.InplaceLogSoftmax(true);
decodeOutput.InplaceLogSoftmax(true);
}
void WriteOutput_beam(IDataReader& dataReader, size_t mbSize, IDataWriter& dataWriter, const std::vector<std::wstring>& outputNodeNames,
size_t numOutputSamples = requestDataSize, bool doWriterUnitTest = false, size_t beamSize = 10, size_t expandBeam = 20, string dictfile = L"", ElemType thresh = 0.68)
size_t numOutputSamples = requestDataSize, bool doWriterUnitTest = false, size_t beamSize = 10, size_t expandBeam = 20, string dictfile = L"", ElemType thresh = 0.68,
size_t rightsplice = 20, size_t encoderdim = 640)
{
ScopedNetworkOperationMode modeGuard(m_net, NetworkOperationMode::inferring);
@ -509,11 +511,12 @@ public:
//get merged input
ComputationNodeBasePtr PlusNode = m_net->GetNodeFromName(outputNodeNames[2]);
ComputationNodeBasePtr PlusTransNode = m_net->GetNodeFromName(outputNodeNames[3]);
ComputationNodeBasePtr WmNode = m_net->GetNodeFromName(outputNodeNames[4]);
ComputationNodeBasePtr Wm2Node = m_net->GetNodeFromName(outputNodeNames[5]);
ComputationNodeBasePtr bmNode = m_net->GetNodeFromName(outputNodeNames[6]);
ComputationNodeBasePtr bmNode = m_net->GetNodeFromName(outputNodeNames[5]);
//StreamMinibatchInputs PlusinputMatrices =
std::vector<ComputationNodeBasePtr> Plusnodes, Plustransnodes;
@ -539,10 +542,8 @@ public:
//for merged RNNT node
Matrix<ElemType> Wm(deviceid), Wm2(deviceid), bm(deviceid);
Wm.SetValue(*(&dynamic_pointer_cast<ComputationNode<ElemType>>(WmNode)->Value()));
Wm2.SetValue(*(&dynamic_pointer_cast<ComputationNode<ElemType>>(Wm2Node)->Value()));
bm.SetValue(*(&dynamic_pointer_cast<ComputationNode<ElemType>>(bmNode)->Value()));
//encodeOutput.GetDeviceId
const size_t numIterationsBeforePrintingProgress = 100;
//size_t numItersSinceLastPrintOfProgress = 0;
@ -550,12 +551,51 @@ public:
vector<Sequence> CurSequences, nextSequences;
while (DataReaderHelpers::GetMinibatchIntoNetwork<ElemType>(dataReader, m_net, nullptr, false, false, encodeInputMatrices, actualMBSize, nullptr))
{
//encode forward prop for whole utterance
ComputationNetwork::BumpEvalTimeStamp(encodeInputNodes);
/*auto InputMBLayout = encodeInputNodes[0]->GetMBLayout();
int rightsplice = InputMBLayout->RightSplice();
fprintf(stderr, "right splice :%d\n", rightsplice);
InputMBLayout->setRightSplice(20);*/
//forward prop encoder network
m_net->ForwardProp(encodeOutputNodes[0]);
encodeOutput.SetValue(*(&dynamic_pointer_cast<ComputationNode<ElemType>>(encodeOutputNodes[0])->Value()));
//do truncated
size_t step = rightsplice;
size_t curt = 0;
auto feainput = encodeInputMatrices.begin();
Matrix<ElemType> FeaMatrix(deviceid);
//auto InputTruncatedMatrix = ;
FeaMatrix.SetValue(feainput->second.GetMatrix<ElemType>());
size_t frameNum = FeaMatrix.GetNumCols();
size_t numRows = FeaMatrix.GetNumRows();
encodeOutput.Resize(encoderdim, frameNum);
while (curt < frameNum)
{
m_net->ResetEvalTimeStamps();
//m_net->ResetMBLayouts();
size_t actualSize = min(curt+step*2, frameNum );
feainput->second.GetMatrix<ElemType>().Resize(numRows, actualSize);
feainput->second.pMBLayout->Init(1, actualSize);
feainput->second.pMBLayout->AddSequence(NEW_SEQUENCE_ID, 0, 0, actualSize);
feainput->second.GetMatrix<ElemType>().SetValue(FeaMatrix.ColumnSlice(0, actualSize));
ComputationNetwork::BumpEvalTimeStamp(encodeInputNodes);
m_net->ForwardProp(encodeOutputNodes[0]);
Matrix<ElemType> encodeOutputSlice = (&dynamic_pointer_cast<ComputationNode<ElemType>>(encodeOutputNodes[0])->Value())->ColumnSlice(curt, min(step, frameNum - curt));
encodeOutput.SetColumnSlice(encodeOutputSlice, curt, min(step, frameNum - curt));
curt += step;
//
}
//encodeOutput.SetValue(*(&dynamic_pointer_cast<ComputationNode<ElemType>>(encodeOutputNodes[0])->Value()));
//encodeOutput.Print("encodeoutput");
dataReader.DataEnd();
@ -622,7 +662,9 @@ public:
deleteSeq(*maxSeq);
CurSequences.erase(maxSeq);
forward_decode(tempSeq, decodeinputMatrices, deviceid, decodeOutputNodes, decodeinputNodes, vocabSize, tempSeq.labelseq.size());
forwardmerged(tempSeq, t, sumofENandDE, encodeOutput, decodeOutput, PlusNode, PlusTransNode, Plusnodes, Plustransnodes,Wm, Wm2, bm);
forwardmerged(tempSeq, t, sumofENandDE, encodeOutput, decodeOutput, PlusNode, PlusTransNode, Plusnodes, Plustransnodes, Wm, bm);
//sumofENandDE.Print("sum");
//sort log posterior and get best N labels
@ -684,9 +726,8 @@ public:
if (topN[iLabel].first != blankId)
{
extendSeq(seqK, topN[iLabel].first, newlogP);
CurSequences.push_back(seqK);
CurSequences.push_back(seqK);
}
}
vector<pair<size_t, ElemType>>().swap(topN);
//delete topN;