LUSequenceReader changes: use unk word if a word is not i no observed list. Support read word class information from LU Sequence Reader (mode=class)

This commit is contained in:
kaisheny 2015-05-24 23:14:58 -07:00
Родитель 26c87bfcfc
Коммит 0ba2bc05c0
4 изменённых файлов: 193 добавлений и 45 удалений

Просмотреть файл

@ -95,13 +95,24 @@ long LUBatchLUSequenceParser<NumType, LabelType>::Parse(size_t recordsRequested,
for (size_t i = 0; i < vstr.size() - 1; i++)
{
if (inputlabel2id.find(vstr[i]) == inputlabel2id.end())
LogicError("cannot find item %ls in input label", vstr[i].c_str());
vtmp.push_back(inputlabel2id.find(vstr[i])->second);
{
if (inputlabel2id.find(mUnkStr) == inputlabel2id.end())
{
LogicError("cannot find item %ls and unk str %ls in input label", vstr[i].c_str(), mUnkStr.c_str());
}
vtmp.push_back(inputlabel2id.find(mUnkStr)->second);
}
else
vtmp.push_back(inputlabel2id.find(vstr[i])->second);
}
if (outputlabel2id.find(vstr[vstr.size() - 1]) == outputlabel2id.end())
LogicError("cannot find item %ls in output label", vstr[vstr.size() - 1].c_str());
labels->push_back(outputlabel2id.find(vstr[vstr.size() - 1])->second);
{
if (outputlabel2id.find(mUnkStr) == outputlabel2id.end())
LogicError("cannot find item %ls and unk str %ls in output label", vstr[vstr.size() - 1].c_str(), mUnkStr.c_str());
labels->push_back(outputlabel2id.find(mUnkStr)->second);
}
else
labels->push_back(outputlabel2id.find(vstr[vstr.size() - 1])->second);
input->push_back(vtmp);
if ((vstr[vstr.size() - 1] == m_endSequenceOut ||
/// below is for backward support

Просмотреть файл

@ -133,6 +133,9 @@ typedef struct{
template <typename NumType, typename LabelType = wstring>
class LUBatchLUSequenceParser : public LUSequenceParser<NumType, LabelType>
{
public:
wstring mUnkStr;
public:
wifstream mFile;
std::wstring mFileName;
@ -160,7 +163,7 @@ public:
mFile.close();
}
void ParseInit(LPCWSTR fileName, size_t dimLabelsIn, size_t dimLabelsOut, wstring beginSequenceIn, wstring endSequenceIn, wstring beginSequenceOut, wstring endSequenceOut)
void ParseInit(LPCWSTR fileName, size_t dimLabelsIn, size_t dimLabelsOut, wstring beginSequenceIn, wstring endSequenceIn, wstring beginSequenceOut, wstring endSequenceOut, wstring unkstr = "<UNK>")
{
assert(fileName != NULL);
mFileName = fileName;
@ -176,6 +179,8 @@ public:
m_beginTag = m_beginSequenceIn;
m_endTag = m_endSequenceIn;
mUnkStr = unkstr;
mFile.close();
mFile.open(fileName, wifstream::in);

Просмотреть файл

@ -34,9 +34,27 @@ long LUSequenceReader<ElemType>::GetIdFromLabel(const LabelType& labelValue, Lab
}
template<class ElemType>
void LUSequenceReader<ElemType>::ReadLabelInfo(const wstring & vocfile,
BatchLUSequenceReader<ElemType>::~BatchLUSequenceReader()
{
if (m_labelTemp.size() > 0)
m_labelTemp.clear();
if (m_featureTemp.size() > 0)
m_featureTemp.clear();
for (int index = labelInfoMin; index < labelInfoMax; ++index)
{
delete[] m_labelInfo[index].m_id2classLocal;
delete[] m_labelInfo[index].m_classInfoLocal;
};
}
template<class ElemType>
void BatchLUSequenceReader<ElemType>::ReadLabelInfo(const wstring & vocfile,
map<wstring, long> & word4idx,
map<long, wstring>& idx4word)
bool readClass,
map<wstring, long>& word4cls,
map<long, wstring>& idx4word,
map<long, long>& idx4class,
int & mNbrCls)
{
char strFileName[MAX_STRING];
wstring strtmp;
@ -50,17 +68,86 @@ void LUSequenceReader<ElemType>::ReadLabelInfo(const wstring & vocfile,
if (!vin.good())
LogicError("LUSequenceReader cannot open %ls \n", strFileName);
wstring wstr = L" ";
b = 0;
nwords = 0;
int prevcls = -1;
mNbrCls = 0;
while (vin.good())
{
getline(vin, strtmp);
strtmp = wtrim(strtmp);
if (strtmp.length() == 0)
break;
word4idx[strtmp] = b;
idx4word[b++] = strtmp;
if (readClass)
{
vector<wstring> wordandcls = wsep_string(strtmp, wstr);
long cls = (long)_wtoi(wordandcls[1].c_str());
word4cls[wordandcls[0]] = cls;
idx4class[b] = cls;
if (idx4class[b] != prevcls)
{
if (idx4class[b] < prevcls)
LogicError("LUSequenceReader: the word list needs to be grouped into classes and the classes indices need to be ascending.");
prevcls = idx4class[b];
}
word4idx[wordandcls[0]] = b;
idx4word[b++] = wordandcls[0];
if (mNbrCls < cls)
mNbrCls = cls;
}
else {
word4idx[strtmp] = b;
idx4word[b++] = strtmp;
}
nwords++;
}
vin.close();
if (readClass)
mNbrCls++;
}
template<class ElemType>
void BatchLUSequenceReader<ElemType>::GetClassInfo(LabelInfo& lblInfo)
{
if (lblInfo.m_clsinfoRead || lblInfo.mNbrClasses == 0) return;
// populate local CPU matrix
if (lblInfo.m_id2classLocal == nullptr)
lblInfo.m_id2classLocal = new Matrix<ElemType>(CPUDEVICE);
if (lblInfo.m_classInfoLocal == nullptr)
lblInfo.m_classInfoLocal = new Matrix<ElemType>(CPUDEVICE);
lblInfo.m_classInfoLocal->SwitchToMatrixType(MatrixType::DENSE, matrixFormatDense, false);
lblInfo.m_classInfoLocal->Resize(2, lblInfo.mNbrClasses);
//move to CPU since element-wise operation is expensive and can go wrong in GPU
int curDevId = lblInfo.m_classInfoLocal->GetDeviceId();
lblInfo.m_classInfoLocal->TransferFromDeviceToDevice(curDevId, CPUDEVICE, true, false, false);
int clsidx;
int prvcls = -1;
for (size_t j = 0; j < nwords; j++)
{
clsidx = lblInfo.idx4class[(long)j];
if (prvcls != clsidx)
{
if (prvcls >= 0)
(*lblInfo.m_classInfoLocal)(1, prvcls) = (float)j;
prvcls = clsidx;
(*lblInfo.m_classInfoLocal)(0, prvcls) = (float)j;
}
}
(*lblInfo.m_classInfoLocal)(1, prvcls) = (float)nwords;
lblInfo.m_classInfoLocal->TransferFromDeviceToDevice(CPUDEVICE, curDevId, true, false, false);
lblInfo.m_clsinfoRead = true;
}
// GetIdFromLabel - get an Id from a Label
@ -278,6 +365,8 @@ void BatchLUSequenceReader<ElemType>::Init(const ConfigParameters& readerConfig)
m_labelInfo[index].isproposal = labelConfig[index]("isproposal", "false");
m_labelInfo[index].m_clsinfoRead = false;
// determine label type desired
std::string labelType(labelConfig[index]("labelType", "Category"));
if (labelType == "Category")
@ -290,9 +379,24 @@ void BatchLUSequenceReader<ElemType>::Init(const ConfigParameters& readerConfig)
// if we have labels, we need a label Mapping file, it will be a file with one label per line
if (m_labelInfo[index].type != labelNone)
{
string mode = labelConfig[index]("mode", "plain");//plain, class
m_labelInfo[index].m_classInfoLocal = nullptr;
m_labelInfo[index].m_id2classLocal = nullptr;
if (mode == "class")
{
m_labelInfo[index].readerMode = ReaderMode::Class;
}
std::wstring wClassFile = labelConfig[index]("token", "");
if (wClassFile != L""){
ReadLabelInfo(wClassFile, m_labelInfo[index].word4idx, m_labelInfo[index].idx4word);
ReadLabelInfo(wClassFile, m_labelInfo[index].word4idx,
m_labelInfo[index].readerMode == ReaderMode::Class,
m_labelInfo[index].word4cls,
m_labelInfo[index].idx4word, m_labelInfo[index].idx4class, m_labelInfo[index].mNbrClasses);
GetClassInfo(m_labelInfo[index]);
}
if (m_labelInfo[index].busewordmap)
ChangeMaping(mWordMapping, mUnkStr, m_labelInfo[index].word4idx);
@ -322,7 +426,7 @@ void BatchLUSequenceReader<ElemType>::Init(const ConfigParameters& readerConfig)
const LabelInfo& labelIn = m_labelInfo[labelInfoIn];
const LabelInfo& labelOut = m_labelInfo[labelInfoOut];
m_parser.ParseInit(m_file.c_str(), labelIn.dim, labelOut.dim, labelIn.beginSequence, labelIn.endSequence, labelOut.beginSequence, labelOut.endSequence);
m_parser.ParseInit(m_file.c_str(), labelIn.dim, labelOut.dim, labelIn.beginSequence, labelIn.endSequence, labelOut.beginSequence, labelOut.endSequence, mUnkStr);
mBlgSize = readerConfig("nbruttsineachrecurrentiter", "1");
@ -344,6 +448,7 @@ void BatchLUSequenceReader<ElemType>::Init(const ConfigParameters& readerConfig)
mAllowMultPassData = readerConfig("dataMultiPass", "false");
mIgnoreSentenceBeginTag = readerConfig("ignoresentencebegintag", "false");
}
template<class ElemType>
@ -610,12 +715,13 @@ bool BatchLUSequenceReader<ElemType>::EnsureDataAvailable(size_t /*mbStartSample
for (i = (int)mLastPosInSentence; j < (int)mMaxSentenceLength; i++, j++)
{
mtSentenceBegin.SetValue(0, j, (ElemType)0);
for (int k = 0; k < mToProcess.size(); k++)
{
size_t seq = mToProcess[k];
if (i == mLastPosInSentence)
if (
i == mLastPosInSentence /// the first time instance has sentence begining
)
{
mSentenceBeginAt[k] = i;
if (mIgnoreSentenceBeginTag == false) /// ignore sentence begin, this is used for decoder network reader, which carries activities from the encoder networks
@ -794,7 +900,7 @@ bool BatchLUSequenceReader<ElemType>::GetMinibatch(std::map<std::wstring, Matrix
features.SetValue(locObs);
lablsize = GetLabelOutput(matrices, actualmbsize);
lablsize = GetLabelOutput(matrices, m_labelInfo[labelInfoOut], actualmbsize);
// go to the next sequence
m_seqIndex++;
@ -814,9 +920,8 @@ bool BatchLUSequenceReader<ElemType>::GetMinibatch(std::map<std::wstring, Matrix
template<class ElemType>
size_t BatchLUSequenceReader<ElemType>::GetLabelOutput(std::map<std::wstring,
Matrix<ElemType>*>& matrices, size_t actualmbsize)
Matrix<ElemType>*>& matrices, LabelInfo& labelInfo, size_t actualmbsize)
{
const LabelInfo& labelInfo = m_labelInfo[labelInfoOut];
Matrix<ElemType>* labels = matrices[m_labelsName[labelInfoOut]];
if (labels == nullptr) return 0;
@ -835,11 +940,29 @@ size_t BatchLUSequenceReader<ElemType>::GetLabelOutput(std::map<std::wstring,
size_t utt_t = (size_t) floor(j / mSentenceBeginAt.size());
if (utt_t > mSentenceEndAt[utt_id]) continue;
labels->SetValue(wrd, j, 1);
if (labelInfo.readerMode == ReaderMode::Plain)
labels->SetValue(wrd, j, 1);
else if (labelInfo.readerMode == ReaderMode::Class && labelInfo.mNbrClasses > 0)
{
labels->SetValue(0, j, (ElemType)wrd);
long clsidx = -1;
clsidx = labelInfo.idx4class[wrd];
labels->SetValue(1, j, (ElemType)clsidx);
/// save the [begining ending_indx) of the class
ElemType lft = (*labelInfo.m_classInfoLocal)(0, clsidx);
ElemType rgt = (*labelInfo.m_classInfoLocal)(1, clsidx);
if (rgt <= lft)
LogicError("LUSequenceReader : right is equal or smaller than the left, which is wrong.");
labels->SetValue(2, j, lft); /// begining index of the class
labels->SetValue(3, j, rgt); /// end index of the class
}
else
LogicError("LUSequenceReader: reader mode is not set to Plain. Or in the case of setting it to Class, the class number is 0. ");
nbrLabl++;
}
labels->TransferFromDeviceToDevice(CPUDEVICE, device, true);
return nbrLabl;
}
@ -1158,6 +1281,8 @@ void MultiIOBatchLUSequenceReader<ElemType>::SetSentenceSegBatch(Matrix<ElemType
/// run for each reader
vector<size_t> col;
size_t rows = 0, cols = 0;
if (mReader.size() > 1)
LogicError("MultiIOBatchLUSequenceReader::SetSentenceSegBatch only supports processing from one BatchLUSequenceReader");
for (map<wstring, BatchLUSequenceReader<ElemType>*>::iterator p = mReader.begin(); p != mReader.end(); p++)
{
(p->second)->SetSentenceSegBatch(sentenceBegin, sentenceExistBeginOrNolabels);
@ -1170,21 +1295,6 @@ void MultiIOBatchLUSequenceReader<ElemType>::SetSentenceSegBatch(Matrix<ElemType
col.push_back(this_col);
cols += this_col;
}
sentenceBegin.Resize(rows, cols);
sentenceExistBeginOrNolabels.Resize(rows, cols);
size_t i = 0, t = 0;
for (map<wstring, BatchLUSequenceReader<ElemType>*>::iterator p = mReader.begin(); p != mReader.end(); p++)
{
Matrix<ElemType> mtmp(sentenceBegin.GetDeviceId());
Matrix<ElemType> mtmp2(sentenceExistBeginOrNolabels.GetDeviceId());
(p->second)->SetSentenceSegBatch(mtmp, mtmp2);
sentenceBegin.ColumnSlice(i, col[t]).SetValue(mtmp);
sentenceExistBeginOrNolabels.ColumnSlice(i, col[t]).SetValue(mtmp2);
i += col[t];
t++;
}
}
template<class ElemType>

Просмотреть файл

@ -40,6 +40,12 @@ enum LabelKind
labelOther = 3, // some other type of label
};
enum ReaderMode
{
Plain = 0, // no class info
Class = 1, // category labels, creates mapping tables
};
template<class ElemType>
class LUSequenceReader : public IDataReader<ElemType>
{
@ -129,6 +135,20 @@ protected:
bool isproposal; /// whether this is for proposal generation
ReaderMode readerMode;
/**
word class info saved in file in format below
! 29
# 58
$ 26
where the first column is the word and the second column is the class id, base 0
*/
map<wstring, long> word4cls;
map<long, long> idx4class;
Matrix<ElemType>* m_id2classLocal; // CPU version
Matrix<ElemType>* m_classInfoLocal; // CPU version
int mNbrClasses;
bool m_clsinfoRead;
} m_labelInfo[labelInfoMax];
// caching support
@ -152,8 +172,6 @@ protected:
public:
void Init(const ConfigParameters& ){};
void ReadLabelInfo(const wstring & vocfile, map<LabelType, LabelIdType> & word4idx,
map<LabelIdType, LabelType>& idx4word);
void ChangeMaping(const map<LabelType, LabelType>& maplist,
const LabelType& unkstr,
map<LabelType, LabelIdType> & word4idx);
@ -227,7 +245,6 @@ public:
using LUSequenceReader<ElemType>::ChangeMaping;
using LUSequenceReader<ElemType>::GetIdFromLabel;
using LUSequenceReader<ElemType>::InitCache;
using LUSequenceReader<ElemType>::ReadLabelInfo;
using LUSequenceReader<ElemType>::mRandomize;
using LUSequenceReader<ElemType>::m_seed;
using LUSequenceReader<ElemType>::mTotalSentenceSofar;
@ -249,7 +266,7 @@ private:
public:
vector<bool> mProcessed;
LUBatchLUSequenceParser<ElemType, LabelType> m_parser;
BatchLUSequenceReader() {
BatchLUSequenceReader() : mtSentenceBegin(CPUDEVICE), mtExistsSentenceBeginOrNoLabels(CPUDEVICE){
mLastProcssedSentenceId = 0;
mBlgSize = 1;
mLastPosInSentence = 0;
@ -259,12 +276,7 @@ public:
mIgnoreSentenceBeginTag = false;
}
~BatchLUSequenceReader() {
if (m_labelTemp.size() > 0)
m_labelTemp.clear();
if (m_featureTemp.size() > 0)
m_featureTemp.clear();
};
~BatchLUSequenceReader();
void Init(const ConfigParameters& readerConfig);
void Reset();
@ -279,7 +291,8 @@ public:
void SetSentenceBegin(size_t wrd, size_t pos, size_t actualMbSize) { SetSentenceBegin((int)wrd, (int)pos, (int)actualMbSize); }
void SetSentenceEnd(size_t wrd, size_t pos, size_t actualMbSize) { SetSentenceEnd((int)wrd, (int)pos, (int)actualMbSize); }
size_t GetLabelOutput(std::map<std::wstring, Matrix<ElemType>*>& matrices, size_t actualmbsize);
size_t GetLabelOutput(std::map<std::wstring,
Matrix<ElemType>*>& matrices, LabelInfo& labelInfo, size_t actualmbsize);
void StartMinibatchLoop(size_t mbSize, size_t epoch, size_t requestedEpochSamples=requestDataSize);
bool GetMinibatch(std::map<std::wstring, Matrix<ElemType>*>& matrices);
@ -291,6 +304,15 @@ public:
void SetSentenceSegBatch(Matrix<ElemType> & sentenceBegin, Matrix<ElemType>& sentenceExistsBeginOrNoLabels);
public:
void GetClassInfo(LabelInfo& lblInfo);
void ReadLabelInfo(const wstring & vocfile,
map<wstring, long> & word4idx,
bool readClass,
map<wstring, long>& word4cls,
map<long, wstring>& idx4word,
map<long, long>& idx4class,
int & mNbrCls);
void LoadWordMapping(const ConfigParameters& readerConfig);
bool CanReadFor(wstring nodeName); /// return true if this reader can output for a node with name nodeName