LUSequenceReader changes: use unk word if a word is not i no observed list. Support read word class information from LU Sequence Reader (mode=class)
This commit is contained in:
Родитель
26c87bfcfc
Коммит
0ba2bc05c0
|
@ -95,13 +95,24 @@ long LUBatchLUSequenceParser<NumType, LabelType>::Parse(size_t recordsRequested,
|
|||
for (size_t i = 0; i < vstr.size() - 1; i++)
|
||||
{
|
||||
if (inputlabel2id.find(vstr[i]) == inputlabel2id.end())
|
||||
LogicError("cannot find item %ls in input label", vstr[i].c_str());
|
||||
|
||||
vtmp.push_back(inputlabel2id.find(vstr[i])->second);
|
||||
{
|
||||
if (inputlabel2id.find(mUnkStr) == inputlabel2id.end())
|
||||
{
|
||||
LogicError("cannot find item %ls and unk str %ls in input label", vstr[i].c_str(), mUnkStr.c_str());
|
||||
}
|
||||
vtmp.push_back(inputlabel2id.find(mUnkStr)->second);
|
||||
}
|
||||
else
|
||||
vtmp.push_back(inputlabel2id.find(vstr[i])->second);
|
||||
}
|
||||
if (outputlabel2id.find(vstr[vstr.size() - 1]) == outputlabel2id.end())
|
||||
LogicError("cannot find item %ls in output label", vstr[vstr.size() - 1].c_str());
|
||||
labels->push_back(outputlabel2id.find(vstr[vstr.size() - 1])->second);
|
||||
{
|
||||
if (outputlabel2id.find(mUnkStr) == outputlabel2id.end())
|
||||
LogicError("cannot find item %ls and unk str %ls in output label", vstr[vstr.size() - 1].c_str(), mUnkStr.c_str());
|
||||
labels->push_back(outputlabel2id.find(mUnkStr)->second);
|
||||
}
|
||||
else
|
||||
labels->push_back(outputlabel2id.find(vstr[vstr.size() - 1])->second);
|
||||
input->push_back(vtmp);
|
||||
if ((vstr[vstr.size() - 1] == m_endSequenceOut ||
|
||||
/// below is for backward support
|
||||
|
|
|
@ -133,6 +133,9 @@ typedef struct{
|
|||
template <typename NumType, typename LabelType = wstring>
|
||||
class LUBatchLUSequenceParser : public LUSequenceParser<NumType, LabelType>
|
||||
{
|
||||
public:
|
||||
wstring mUnkStr;
|
||||
|
||||
public:
|
||||
wifstream mFile;
|
||||
std::wstring mFileName;
|
||||
|
@ -160,7 +163,7 @@ public:
|
|||
mFile.close();
|
||||
}
|
||||
|
||||
void ParseInit(LPCWSTR fileName, size_t dimLabelsIn, size_t dimLabelsOut, wstring beginSequenceIn, wstring endSequenceIn, wstring beginSequenceOut, wstring endSequenceOut)
|
||||
void ParseInit(LPCWSTR fileName, size_t dimLabelsIn, size_t dimLabelsOut, wstring beginSequenceIn, wstring endSequenceIn, wstring beginSequenceOut, wstring endSequenceOut, wstring unkstr = "<UNK>")
|
||||
{
|
||||
assert(fileName != NULL);
|
||||
mFileName = fileName;
|
||||
|
@ -176,6 +179,8 @@ public:
|
|||
m_beginTag = m_beginSequenceIn;
|
||||
m_endTag = m_endSequenceIn;
|
||||
|
||||
mUnkStr = unkstr;
|
||||
|
||||
mFile.close();
|
||||
|
||||
mFile.open(fileName, wifstream::in);
|
||||
|
|
|
@ -34,9 +34,27 @@ long LUSequenceReader<ElemType>::GetIdFromLabel(const LabelType& labelValue, Lab
|
|||
}
|
||||
|
||||
template<class ElemType>
|
||||
void LUSequenceReader<ElemType>::ReadLabelInfo(const wstring & vocfile,
|
||||
BatchLUSequenceReader<ElemType>::~BatchLUSequenceReader()
|
||||
{
|
||||
if (m_labelTemp.size() > 0)
|
||||
m_labelTemp.clear();
|
||||
if (m_featureTemp.size() > 0)
|
||||
m_featureTemp.clear();
|
||||
for (int index = labelInfoMin; index < labelInfoMax; ++index)
|
||||
{
|
||||
delete[] m_labelInfo[index].m_id2classLocal;
|
||||
delete[] m_labelInfo[index].m_classInfoLocal;
|
||||
};
|
||||
}
|
||||
|
||||
template<class ElemType>
|
||||
void BatchLUSequenceReader<ElemType>::ReadLabelInfo(const wstring & vocfile,
|
||||
map<wstring, long> & word4idx,
|
||||
map<long, wstring>& idx4word)
|
||||
bool readClass,
|
||||
map<wstring, long>& word4cls,
|
||||
map<long, wstring>& idx4word,
|
||||
map<long, long>& idx4class,
|
||||
int & mNbrCls)
|
||||
{
|
||||
char strFileName[MAX_STRING];
|
||||
wstring strtmp;
|
||||
|
@ -50,17 +68,86 @@ void LUSequenceReader<ElemType>::ReadLabelInfo(const wstring & vocfile,
|
|||
if (!vin.good())
|
||||
LogicError("LUSequenceReader cannot open %ls \n", strFileName);
|
||||
|
||||
wstring wstr = L" ";
|
||||
b = 0;
|
||||
nwords = 0;
|
||||
int prevcls = -1;
|
||||
|
||||
mNbrCls = 0;
|
||||
while (vin.good())
|
||||
{
|
||||
getline(vin, strtmp);
|
||||
strtmp = wtrim(strtmp);
|
||||
if (strtmp.length() == 0)
|
||||
break;
|
||||
word4idx[strtmp] = b;
|
||||
idx4word[b++] = strtmp;
|
||||
if (readClass)
|
||||
{
|
||||
vector<wstring> wordandcls = wsep_string(strtmp, wstr);
|
||||
long cls = (long)_wtoi(wordandcls[1].c_str());
|
||||
word4cls[wordandcls[0]] = cls;
|
||||
|
||||
idx4class[b] = cls;
|
||||
|
||||
if (idx4class[b] != prevcls)
|
||||
{
|
||||
if (idx4class[b] < prevcls)
|
||||
LogicError("LUSequenceReader: the word list needs to be grouped into classes and the classes indices need to be ascending.");
|
||||
prevcls = idx4class[b];
|
||||
}
|
||||
|
||||
word4idx[wordandcls[0]] = b;
|
||||
idx4word[b++] = wordandcls[0];
|
||||
if (mNbrCls < cls)
|
||||
mNbrCls = cls;
|
||||
}
|
||||
else {
|
||||
word4idx[strtmp] = b;
|
||||
idx4word[b++] = strtmp;
|
||||
}
|
||||
nwords++;
|
||||
}
|
||||
vin.close();
|
||||
|
||||
if (readClass)
|
||||
mNbrCls++;
|
||||
}
|
||||
|
||||
template<class ElemType>
|
||||
void BatchLUSequenceReader<ElemType>::GetClassInfo(LabelInfo& lblInfo)
|
||||
{
|
||||
if (lblInfo.m_clsinfoRead || lblInfo.mNbrClasses == 0) return;
|
||||
|
||||
// populate local CPU matrix
|
||||
if (lblInfo.m_id2classLocal == nullptr)
|
||||
lblInfo.m_id2classLocal = new Matrix<ElemType>(CPUDEVICE);
|
||||
if (lblInfo.m_classInfoLocal == nullptr)
|
||||
lblInfo.m_classInfoLocal = new Matrix<ElemType>(CPUDEVICE);
|
||||
|
||||
lblInfo.m_classInfoLocal->SwitchToMatrixType(MatrixType::DENSE, matrixFormatDense, false);
|
||||
lblInfo.m_classInfoLocal->Resize(2, lblInfo.mNbrClasses);
|
||||
|
||||
//move to CPU since element-wise operation is expensive and can go wrong in GPU
|
||||
int curDevId = lblInfo.m_classInfoLocal->GetDeviceId();
|
||||
lblInfo.m_classInfoLocal->TransferFromDeviceToDevice(curDevId, CPUDEVICE, true, false, false);
|
||||
|
||||
int clsidx;
|
||||
int prvcls = -1;
|
||||
for (size_t j = 0; j < nwords; j++)
|
||||
{
|
||||
clsidx = lblInfo.idx4class[(long)j];
|
||||
if (prvcls != clsidx)
|
||||
{
|
||||
if (prvcls >= 0)
|
||||
(*lblInfo.m_classInfoLocal)(1, prvcls) = (float)j;
|
||||
prvcls = clsidx;
|
||||
(*lblInfo.m_classInfoLocal)(0, prvcls) = (float)j;
|
||||
}
|
||||
}
|
||||
(*lblInfo.m_classInfoLocal)(1, prvcls) = (float)nwords;
|
||||
|
||||
lblInfo.m_classInfoLocal->TransferFromDeviceToDevice(CPUDEVICE, curDevId, true, false, false);
|
||||
|
||||
lblInfo.m_clsinfoRead = true;
|
||||
}
|
||||
|
||||
// GetIdFromLabel - get an Id from a Label
|
||||
|
@ -278,6 +365,8 @@ void BatchLUSequenceReader<ElemType>::Init(const ConfigParameters& readerConfig)
|
|||
|
||||
m_labelInfo[index].isproposal = labelConfig[index]("isproposal", "false");
|
||||
|
||||
m_labelInfo[index].m_clsinfoRead = false;
|
||||
|
||||
// determine label type desired
|
||||
std::string labelType(labelConfig[index]("labelType", "Category"));
|
||||
if (labelType == "Category")
|
||||
|
@ -290,9 +379,24 @@ void BatchLUSequenceReader<ElemType>::Init(const ConfigParameters& readerConfig)
|
|||
// if we have labels, we need a label Mapping file, it will be a file with one label per line
|
||||
if (m_labelInfo[index].type != labelNone)
|
||||
{
|
||||
string mode = labelConfig[index]("mode", "plain");//plain, class
|
||||
|
||||
m_labelInfo[index].m_classInfoLocal = nullptr;
|
||||
m_labelInfo[index].m_id2classLocal = nullptr;
|
||||
|
||||
if (mode == "class")
|
||||
{
|
||||
m_labelInfo[index].readerMode = ReaderMode::Class;
|
||||
}
|
||||
|
||||
std::wstring wClassFile = labelConfig[index]("token", "");
|
||||
if (wClassFile != L""){
|
||||
ReadLabelInfo(wClassFile, m_labelInfo[index].word4idx, m_labelInfo[index].idx4word);
|
||||
ReadLabelInfo(wClassFile, m_labelInfo[index].word4idx,
|
||||
m_labelInfo[index].readerMode == ReaderMode::Class,
|
||||
m_labelInfo[index].word4cls,
|
||||
m_labelInfo[index].idx4word, m_labelInfo[index].idx4class, m_labelInfo[index].mNbrClasses);
|
||||
|
||||
GetClassInfo(m_labelInfo[index]);
|
||||
}
|
||||
if (m_labelInfo[index].busewordmap)
|
||||
ChangeMaping(mWordMapping, mUnkStr, m_labelInfo[index].word4idx);
|
||||
|
@ -322,7 +426,7 @@ void BatchLUSequenceReader<ElemType>::Init(const ConfigParameters& readerConfig)
|
|||
|
||||
const LabelInfo& labelIn = m_labelInfo[labelInfoIn];
|
||||
const LabelInfo& labelOut = m_labelInfo[labelInfoOut];
|
||||
m_parser.ParseInit(m_file.c_str(), labelIn.dim, labelOut.dim, labelIn.beginSequence, labelIn.endSequence, labelOut.beginSequence, labelOut.endSequence);
|
||||
m_parser.ParseInit(m_file.c_str(), labelIn.dim, labelOut.dim, labelIn.beginSequence, labelIn.endSequence, labelOut.beginSequence, labelOut.endSequence, mUnkStr);
|
||||
|
||||
mBlgSize = readerConfig("nbruttsineachrecurrentiter", "1");
|
||||
|
||||
|
@ -344,6 +448,7 @@ void BatchLUSequenceReader<ElemType>::Init(const ConfigParameters& readerConfig)
|
|||
mAllowMultPassData = readerConfig("dataMultiPass", "false");
|
||||
|
||||
mIgnoreSentenceBeginTag = readerConfig("ignoresentencebegintag", "false");
|
||||
|
||||
}
|
||||
|
||||
template<class ElemType>
|
||||
|
@ -610,12 +715,13 @@ bool BatchLUSequenceReader<ElemType>::EnsureDataAvailable(size_t /*mbStartSample
|
|||
|
||||
for (i = (int)mLastPosInSentence; j < (int)mMaxSentenceLength; i++, j++)
|
||||
{
|
||||
mtSentenceBegin.SetValue(0, j, (ElemType)0);
|
||||
for (int k = 0; k < mToProcess.size(); k++)
|
||||
{
|
||||
size_t seq = mToProcess[k];
|
||||
|
||||
if (i == mLastPosInSentence)
|
||||
if (
|
||||
i == mLastPosInSentence /// the first time instance has sentence begining
|
||||
)
|
||||
{
|
||||
mSentenceBeginAt[k] = i;
|
||||
if (mIgnoreSentenceBeginTag == false) /// ignore sentence begin, this is used for decoder network reader, which carries activities from the encoder networks
|
||||
|
@ -794,7 +900,7 @@ bool BatchLUSequenceReader<ElemType>::GetMinibatch(std::map<std::wstring, Matrix
|
|||
|
||||
features.SetValue(locObs);
|
||||
|
||||
lablsize = GetLabelOutput(matrices, actualmbsize);
|
||||
lablsize = GetLabelOutput(matrices, m_labelInfo[labelInfoOut], actualmbsize);
|
||||
|
||||
// go to the next sequence
|
||||
m_seqIndex++;
|
||||
|
@ -814,9 +920,8 @@ bool BatchLUSequenceReader<ElemType>::GetMinibatch(std::map<std::wstring, Matrix
|
|||
|
||||
template<class ElemType>
|
||||
size_t BatchLUSequenceReader<ElemType>::GetLabelOutput(std::map<std::wstring,
|
||||
Matrix<ElemType>*>& matrices, size_t actualmbsize)
|
||||
Matrix<ElemType>*>& matrices, LabelInfo& labelInfo, size_t actualmbsize)
|
||||
{
|
||||
const LabelInfo& labelInfo = m_labelInfo[labelInfoOut];
|
||||
Matrix<ElemType>* labels = matrices[m_labelsName[labelInfoOut]];
|
||||
if (labels == nullptr) return 0;
|
||||
|
||||
|
@ -835,11 +940,29 @@ size_t BatchLUSequenceReader<ElemType>::GetLabelOutput(std::map<std::wstring,
|
|||
size_t utt_t = (size_t) floor(j / mSentenceBeginAt.size());
|
||||
|
||||
if (utt_t > mSentenceEndAt[utt_id]) continue;
|
||||
labels->SetValue(wrd, j, 1);
|
||||
if (labelInfo.readerMode == ReaderMode::Plain)
|
||||
labels->SetValue(wrd, j, 1);
|
||||
else if (labelInfo.readerMode == ReaderMode::Class && labelInfo.mNbrClasses > 0)
|
||||
{
|
||||
labels->SetValue(0, j, (ElemType)wrd);
|
||||
|
||||
long clsidx = -1;
|
||||
clsidx = labelInfo.idx4class[wrd];
|
||||
|
||||
labels->SetValue(1, j, (ElemType)clsidx);
|
||||
/// save the [begining ending_indx) of the class
|
||||
ElemType lft = (*labelInfo.m_classInfoLocal)(0, clsidx);
|
||||
ElemType rgt = (*labelInfo.m_classInfoLocal)(1, clsidx);
|
||||
if (rgt <= lft)
|
||||
LogicError("LUSequenceReader : right is equal or smaller than the left, which is wrong.");
|
||||
labels->SetValue(2, j, lft); /// begining index of the class
|
||||
labels->SetValue(3, j, rgt); /// end index of the class
|
||||
}
|
||||
else
|
||||
LogicError("LUSequenceReader: reader mode is not set to Plain. Or in the case of setting it to Class, the class number is 0. ");
|
||||
nbrLabl++;
|
||||
}
|
||||
|
||||
labels->TransferFromDeviceToDevice(CPUDEVICE, device, true);
|
||||
return nbrLabl;
|
||||
}
|
||||
|
||||
|
@ -1158,6 +1281,8 @@ void MultiIOBatchLUSequenceReader<ElemType>::SetSentenceSegBatch(Matrix<ElemType
|
|||
/// run for each reader
|
||||
vector<size_t> col;
|
||||
size_t rows = 0, cols = 0;
|
||||
if (mReader.size() > 1)
|
||||
LogicError("MultiIOBatchLUSequenceReader::SetSentenceSegBatch only supports processing from one BatchLUSequenceReader");
|
||||
for (map<wstring, BatchLUSequenceReader<ElemType>*>::iterator p = mReader.begin(); p != mReader.end(); p++)
|
||||
{
|
||||
(p->second)->SetSentenceSegBatch(sentenceBegin, sentenceExistBeginOrNolabels);
|
||||
|
@ -1170,21 +1295,6 @@ void MultiIOBatchLUSequenceReader<ElemType>::SetSentenceSegBatch(Matrix<ElemType
|
|||
col.push_back(this_col);
|
||||
cols += this_col;
|
||||
}
|
||||
|
||||
sentenceBegin.Resize(rows, cols);
|
||||
sentenceExistBeginOrNolabels.Resize(rows, cols);
|
||||
size_t i = 0, t = 0;
|
||||
for (map<wstring, BatchLUSequenceReader<ElemType>*>::iterator p = mReader.begin(); p != mReader.end(); p++)
|
||||
{
|
||||
Matrix<ElemType> mtmp(sentenceBegin.GetDeviceId());
|
||||
Matrix<ElemType> mtmp2(sentenceExistBeginOrNolabels.GetDeviceId());
|
||||
|
||||
(p->second)->SetSentenceSegBatch(mtmp, mtmp2);
|
||||
sentenceBegin.ColumnSlice(i, col[t]).SetValue(mtmp);
|
||||
sentenceExistBeginOrNolabels.ColumnSlice(i, col[t]).SetValue(mtmp2);
|
||||
i += col[t];
|
||||
t++;
|
||||
}
|
||||
}
|
||||
|
||||
template<class ElemType>
|
||||
|
|
|
@ -40,6 +40,12 @@ enum LabelKind
|
|||
labelOther = 3, // some other type of label
|
||||
};
|
||||
|
||||
enum ReaderMode
|
||||
{
|
||||
Plain = 0, // no class info
|
||||
Class = 1, // category labels, creates mapping tables
|
||||
};
|
||||
|
||||
template<class ElemType>
|
||||
class LUSequenceReader : public IDataReader<ElemType>
|
||||
{
|
||||
|
@ -129,6 +135,20 @@ protected:
|
|||
|
||||
bool isproposal; /// whether this is for proposal generation
|
||||
|
||||
ReaderMode readerMode;
|
||||
/**
|
||||
word class info saved in file in format below
|
||||
! 29
|
||||
# 58
|
||||
$ 26
|
||||
where the first column is the word and the second column is the class id, base 0
|
||||
*/
|
||||
map<wstring, long> word4cls;
|
||||
map<long, long> idx4class;
|
||||
Matrix<ElemType>* m_id2classLocal; // CPU version
|
||||
Matrix<ElemType>* m_classInfoLocal; // CPU version
|
||||
int mNbrClasses;
|
||||
bool m_clsinfoRead;
|
||||
} m_labelInfo[labelInfoMax];
|
||||
|
||||
// caching support
|
||||
|
@ -152,8 +172,6 @@ protected:
|
|||
|
||||
public:
|
||||
void Init(const ConfigParameters& ){};
|
||||
void ReadLabelInfo(const wstring & vocfile, map<LabelType, LabelIdType> & word4idx,
|
||||
map<LabelIdType, LabelType>& idx4word);
|
||||
void ChangeMaping(const map<LabelType, LabelType>& maplist,
|
||||
const LabelType& unkstr,
|
||||
map<LabelType, LabelIdType> & word4idx);
|
||||
|
@ -227,7 +245,6 @@ public:
|
|||
using LUSequenceReader<ElemType>::ChangeMaping;
|
||||
using LUSequenceReader<ElemType>::GetIdFromLabel;
|
||||
using LUSequenceReader<ElemType>::InitCache;
|
||||
using LUSequenceReader<ElemType>::ReadLabelInfo;
|
||||
using LUSequenceReader<ElemType>::mRandomize;
|
||||
using LUSequenceReader<ElemType>::m_seed;
|
||||
using LUSequenceReader<ElemType>::mTotalSentenceSofar;
|
||||
|
@ -249,7 +266,7 @@ private:
|
|||
public:
|
||||
vector<bool> mProcessed;
|
||||
LUBatchLUSequenceParser<ElemType, LabelType> m_parser;
|
||||
BatchLUSequenceReader() {
|
||||
BatchLUSequenceReader() : mtSentenceBegin(CPUDEVICE), mtExistsSentenceBeginOrNoLabels(CPUDEVICE){
|
||||
mLastProcssedSentenceId = 0;
|
||||
mBlgSize = 1;
|
||||
mLastPosInSentence = 0;
|
||||
|
@ -259,12 +276,7 @@ public:
|
|||
mIgnoreSentenceBeginTag = false;
|
||||
}
|
||||
|
||||
~BatchLUSequenceReader() {
|
||||
if (m_labelTemp.size() > 0)
|
||||
m_labelTemp.clear();
|
||||
if (m_featureTemp.size() > 0)
|
||||
m_featureTemp.clear();
|
||||
};
|
||||
~BatchLUSequenceReader();
|
||||
|
||||
void Init(const ConfigParameters& readerConfig);
|
||||
void Reset();
|
||||
|
@ -279,7 +291,8 @@ public:
|
|||
void SetSentenceBegin(size_t wrd, size_t pos, size_t actualMbSize) { SetSentenceBegin((int)wrd, (int)pos, (int)actualMbSize); }
|
||||
void SetSentenceEnd(size_t wrd, size_t pos, size_t actualMbSize) { SetSentenceEnd((int)wrd, (int)pos, (int)actualMbSize); }
|
||||
|
||||
size_t GetLabelOutput(std::map<std::wstring, Matrix<ElemType>*>& matrices, size_t actualmbsize);
|
||||
size_t GetLabelOutput(std::map<std::wstring,
|
||||
Matrix<ElemType>*>& matrices, LabelInfo& labelInfo, size_t actualmbsize);
|
||||
|
||||
void StartMinibatchLoop(size_t mbSize, size_t epoch, size_t requestedEpochSamples=requestDataSize);
|
||||
bool GetMinibatch(std::map<std::wstring, Matrix<ElemType>*>& matrices);
|
||||
|
@ -291,6 +304,15 @@ public:
|
|||
void SetSentenceSegBatch(Matrix<ElemType> & sentenceBegin, Matrix<ElemType>& sentenceExistsBeginOrNoLabels);
|
||||
|
||||
public:
|
||||
void GetClassInfo(LabelInfo& lblInfo);
|
||||
void ReadLabelInfo(const wstring & vocfile,
|
||||
map<wstring, long> & word4idx,
|
||||
bool readClass,
|
||||
map<wstring, long>& word4cls,
|
||||
map<long, wstring>& idx4word,
|
||||
map<long, long>& idx4class,
|
||||
int & mNbrCls);
|
||||
|
||||
void LoadWordMapping(const ConfigParameters& readerConfig);
|
||||
bool CanReadFor(wstring nodeName); /// return true if this reader can output for a node with name nodeName
|
||||
|
||||
|
|
Загрузка…
Ссылка в новой задаче