This commit is contained in:
Linquan Liu 2019-01-03 13:58:32 +08:00
Родитель 09a3f4c1b7
Коммит f486e3bad5
2 изменённых файлов: 85 добавлений и 81 удалений

Просмотреть файл

@ -1597,8 +1597,9 @@ public:
static void convert(const std::wstring& intocpath, const std::wstring& intocpath2, const std::wstring& outpath,
const msra::asr::simplesenonehmm& hset, std::unordered_map<size_t, std::wstring>& id2wordmapping, std::set<size_t>& specialwordids);
static std::vector<std::wstring> splitword2character(const std::wstring& s);
/*static std::vector<std::wstring> splitword2character(const std::wstring& s);
static bool istagword(const std::wstring& s);
static float computewerandcer(const std::vector<size_t>& wids, const std::vector<size_t>& path_ids, const std::unordered_map<size_t, std::wstring>* ptr_id2wordmap4node);*/
};
};
};

Просмотреть файл

@ -15,7 +15,7 @@
#include <unordered_map>
#include <list>
#include <stdexcept>
#include<regex>
#include <regex>
using namespace std;
@ -860,7 +860,7 @@ std::vector<std::wstring> splitword2character(const std::wstring &s)
return char_array;*/
std::wregex words_regex(L"([\u4e00-\u9fa5]|[^\u4e00-\u9fa5\\s]+)");
std::wregex words_regex(L"([\u4e00-\u9fa5]|[^\u4e00-\u9fa5\\s]+)");
auto words_begin = std::wsregex_iterator(s.begin(), s.end(), words_regex);
auto words_end = std::wsregex_iterator();
std::vector<std::wstring> tgt;
@ -873,7 +873,6 @@ std::vector<std::wstring> splitword2character(const std::wstring &s)
}
return tgt;
}
bool istagword(const std::wstring &s)
@ -887,6 +886,83 @@ bool istagword(const std::wstring &s)
return false;
}
float computewerandcer(std::vector<size_t> &wids, std::vector<size_t> &path_ids, const std::unordered_map<size_t, std::wstring> *ptr_id2wordmap4node)
{
float wer;
if (ptr_id2wordmap4node->size() > 0)
{
std::vector<std::wstring> refwords;
std::vector<std::wstring> regwords;
std::vector<size_t> refid;
std::vector<size_t> regid;
std::wstring temp_string;
std::vector<std::wstring> character_array;
std::unordered_map<std::wstring, size_t> idmappingtable;
refwords.clear();
regwords.clear();
character_array.clear();
refid.clear();
regid.clear();
idmappingtable.clear();
std::unordered_map<size_t, std::wstring>::const_iterator maptable_itr;
for (std::vector<size_t>::const_iterator it = wids.begin(); it != wids.end(); ++it)
{
maptable_itr = ptr_id2wordmap4node->find(*it);
temp_string = (maptable_itr != ptr_id2wordmap4node->end()) ? maptable_itr->second : std::to_wstring(*it);
character_array = splitword2character(temp_string);
foreach_index (_i, character_array)
{
refwords.push_back(character_array[_i]);
if (idmappingtable.find(character_array[_i]) == idmappingtable.end())
{
idmappingtable.insert(pair<std::wstring, size_t>(character_array[_i], idmappingtable.size() + 1));
}
}
}
for (std::vector<size_t>::const_iterator it = path_ids.begin(); it != path_ids.end(); ++it)
{
maptable_itr = ptr_id2wordmap4node->find(*it);
temp_string = (maptable_itr != ptr_id2wordmap4node->end()) ? maptable_itr->second : std::to_wstring(*it);
character_array = splitword2character(temp_string);
foreach_index (_i, character_array)
{
regwords.push_back(character_array[_i]);
if (idmappingtable.find(character_array[_i]) == idmappingtable.end())
{
idmappingtable.insert(pair<std::wstring, size_t>(character_array[_i], idmappingtable.size() + 1));
}
}
}
//map characters to id to be compatiable with egacy code
//skip tag words
foreach_index (_k, refwords)
{
if (!istagword(refwords[_k]))
refid.push_back(idmappingtable.find(refwords[_k])->second);
}
foreach_index (_k, regwords)
{
if (!istagword(regwords[_k]))
regid.push_back(idmappingtable.find(regwords[_k])->second);
}
wer = compute_wer(refid, regid);
}
else
{
wer = compute_wer(wids, path_ids);
}
return wer;
}
double lattice::nbestlatticeEMBR(const std::vector<float> &edgeacscores, parallelstate &parallelstate, std::vector<NBestToken> &tokenlattice, const size_t numtokens, const bool enforceValidPathEMBR, const bool excludeSpecialWords,
const float lmf, const float wp, const float amf, const bool wordNbest, const bool useAccInNbest, const float accWeightInNbest, const size_t numPathsEMBR, std::vector<size_t> wids
/*, std::unordered_map<int, std::wstring> wordipmap /*added by linquan*/) const
@ -1011,82 +1087,8 @@ double lattice::nbestlatticeEMBR(const std::vector<float> &edgeacscores, paralle
path_ids.push_back(nodes[edges[path[k]].E].wid);
}
float wer;
if (ptr_id2wordmap4node->size() > 0)
{
//Linquan added
std::vector<std::wstring> refwords;
std::vector<std::wstring> regwords;
std::vector<size_t> refid;
std::vector<size_t> regid;
std::wstring temp_string;
std::vector<std::wstring> character_array;
std::unordered_map<std::wstring, size_t> idmappingtable;
refwords.clear();
regwords.clear();
character_array.clear();
refid.clear();
regid.clear();
idmappingtable.clear();
std::unordered_map<size_t, std::wstring>::const_iterator maptable_itr;
for (std::vector<size_t>::const_iterator it = wids.begin(); it != wids.end(); ++it)
{
maptable_itr = ptr_id2wordmap4node->find(*it);
/* if (maptable_itr != ptr_id2wordmap4node->end())
refwords.push_back(maptable_itr->second);
else
refwords.push_back(std::to_wstring(*it));*/
temp_string = (maptable_itr != ptr_id2wordmap4node->end()) ? maptable_itr->second : std::to_wstring(*it);
character_array = splitword2character(temp_string);
foreach_index (_i, character_array)
{
refwords.push_back(character_array[_i]);
if (idmappingtable.find(character_array[_i]) == idmappingtable.end())
{
idmappingtable.insert(pair<std::wstring, size_t>(character_array[_i], idmappingtable.size() + 1));
}
}
}
for (std::vector<size_t>::const_iterator it = path_ids.begin(); it != path_ids.end(); ++it)
{
maptable_itr = ptr_id2wordmap4node->find(*it);
temp_string = (maptable_itr != ptr_id2wordmap4node->end()) ? maptable_itr->second : std::to_wstring(*it);
character_array = splitword2character(temp_string);
foreach_index (_i, character_array)
{
regwords.push_back(character_array[_i]);
if (idmappingtable.find(character_array[_i]) == idmappingtable.end())
{
idmappingtable.insert(pair<std::wstring, size_t>(character_array[_i], idmappingtable.size() + 1));
}
}
}
//map characters to id to be compatiable with egacy code
//skip tag words
foreach_index (_k, refwords)
{
if (!istagword(refwords[_k]))
refid.push_back(idmappingtable.find(refwords[_k])->second);
}
foreach_index (_k, regwords)
{
if (!istagword(regwords[_k]))
regid.push_back(idmappingtable.find(regwords[_k])->second);
}
wer = compute_wer(refid, regid);
}
else
wer = compute_wer(wids, path_ids);
//linquan
float wer = computewerandcer(wids, path_ids, ptr_id2wordmap4node);
// will favor the path with better WER
pathscore -= double(accWeightInNbest * wer);
@ -1827,7 +1829,8 @@ double lattice::get_edge_weights(std::vector<size_t> &wids, std::vector<std::vec
nodes[edges[vt_paths[i][j]].E].wid;
}
vt_path_weights[i] = compute_wer(wids, path_ids);
//linquan
vt_path_weights[i] = computewerandcer(wids, path_ids, ptr_id2wordmap4node);
string pathidstr = "$";
for (size_t j = 0; j < path_ids.size(); j++)