update eMBR with CER
This commit is contained in:
Родитель
09a3f4c1b7
Коммит
f486e3bad5
|
@ -1597,8 +1597,9 @@ public:
|
|||
static void convert(const std::wstring& intocpath, const std::wstring& intocpath2, const std::wstring& outpath,
|
||||
const msra::asr::simplesenonehmm& hset, std::unordered_map<size_t, std::wstring>& id2wordmapping, std::set<size_t>& specialwordids);
|
||||
|
||||
static std::vector<std::wstring> splitword2character(const std::wstring& s);
|
||||
/*static std::vector<std::wstring> splitword2character(const std::wstring& s);
|
||||
static bool istagword(const std::wstring& s);
|
||||
static float computewerandcer(const std::vector<size_t>& wids, const std::vector<size_t>& path_ids, const std::unordered_map<size_t, std::wstring>* ptr_id2wordmap4node);*/
|
||||
};
|
||||
};
|
||||
};
|
||||
|
|
|
@ -15,7 +15,7 @@
|
|||
#include <unordered_map>
|
||||
#include <list>
|
||||
#include <stdexcept>
|
||||
#include<regex>
|
||||
#include <regex>
|
||||
|
||||
using namespace std;
|
||||
|
||||
|
@ -860,7 +860,7 @@ std::vector<std::wstring> splitword2character(const std::wstring &s)
|
|||
|
||||
return char_array;*/
|
||||
|
||||
std::wregex words_regex(L"([\u4e00-\u9fa5]|[^\u4e00-\u9fa5\\s]+)");
|
||||
std::wregex words_regex(L"([\u4e00-\u9fa5]|[^\u4e00-\u9fa5\\s]+)");
|
||||
auto words_begin = std::wsregex_iterator(s.begin(), s.end(), words_regex);
|
||||
auto words_end = std::wsregex_iterator();
|
||||
std::vector<std::wstring> tgt;
|
||||
|
@ -873,7 +873,6 @@ std::vector<std::wstring> splitword2character(const std::wstring &s)
|
|||
}
|
||||
|
||||
return tgt;
|
||||
|
||||
}
|
||||
|
||||
bool istagword(const std::wstring &s)
|
||||
|
@ -887,6 +886,83 @@ bool istagword(const std::wstring &s)
|
|||
return false;
|
||||
}
|
||||
|
||||
float computewerandcer(std::vector<size_t> &wids, std::vector<size_t> &path_ids, const std::unordered_map<size_t, std::wstring> *ptr_id2wordmap4node)
|
||||
{
|
||||
float wer;
|
||||
if (ptr_id2wordmap4node->size() > 0)
|
||||
{
|
||||
std::vector<std::wstring> refwords;
|
||||
std::vector<std::wstring> regwords;
|
||||
std::vector<size_t> refid;
|
||||
std::vector<size_t> regid;
|
||||
std::wstring temp_string;
|
||||
std::vector<std::wstring> character_array;
|
||||
std::unordered_map<std::wstring, size_t> idmappingtable;
|
||||
|
||||
refwords.clear();
|
||||
regwords.clear();
|
||||
character_array.clear();
|
||||
refid.clear();
|
||||
regid.clear();
|
||||
idmappingtable.clear();
|
||||
std::unordered_map<size_t, std::wstring>::const_iterator maptable_itr;
|
||||
for (std::vector<size_t>::const_iterator it = wids.begin(); it != wids.end(); ++it)
|
||||
{
|
||||
maptable_itr = ptr_id2wordmap4node->find(*it);
|
||||
temp_string = (maptable_itr != ptr_id2wordmap4node->end()) ? maptable_itr->second : std::to_wstring(*it);
|
||||
character_array = splitword2character(temp_string);
|
||||
|
||||
foreach_index (_i, character_array)
|
||||
{
|
||||
refwords.push_back(character_array[_i]);
|
||||
if (idmappingtable.find(character_array[_i]) == idmappingtable.end())
|
||||
{
|
||||
idmappingtable.insert(pair<std::wstring, size_t>(character_array[_i], idmappingtable.size() + 1));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (std::vector<size_t>::const_iterator it = path_ids.begin(); it != path_ids.end(); ++it)
|
||||
{
|
||||
maptable_itr = ptr_id2wordmap4node->find(*it);
|
||||
|
||||
temp_string = (maptable_itr != ptr_id2wordmap4node->end()) ? maptable_itr->second : std::to_wstring(*it);
|
||||
character_array = splitword2character(temp_string);
|
||||
|
||||
foreach_index (_i, character_array)
|
||||
{
|
||||
regwords.push_back(character_array[_i]);
|
||||
if (idmappingtable.find(character_array[_i]) == idmappingtable.end())
|
||||
{
|
||||
idmappingtable.insert(pair<std::wstring, size_t>(character_array[_i], idmappingtable.size() + 1));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
//map characters to id to be compatiable with egacy code
|
||||
//skip tag words
|
||||
foreach_index (_k, refwords)
|
||||
{
|
||||
if (!istagword(refwords[_k]))
|
||||
refid.push_back(idmappingtable.find(refwords[_k])->second);
|
||||
}
|
||||
|
||||
foreach_index (_k, regwords)
|
||||
{
|
||||
if (!istagword(regwords[_k]))
|
||||
regid.push_back(idmappingtable.find(regwords[_k])->second);
|
||||
}
|
||||
|
||||
wer = compute_wer(refid, regid);
|
||||
}
|
||||
else
|
||||
{
|
||||
wer = compute_wer(wids, path_ids);
|
||||
}
|
||||
|
||||
return wer;
|
||||
}
|
||||
|
||||
double lattice::nbestlatticeEMBR(const std::vector<float> &edgeacscores, parallelstate ¶llelstate, std::vector<NBestToken> &tokenlattice, const size_t numtokens, const bool enforceValidPathEMBR, const bool excludeSpecialWords,
|
||||
const float lmf, const float wp, const float amf, const bool wordNbest, const bool useAccInNbest, const float accWeightInNbest, const size_t numPathsEMBR, std::vector<size_t> wids
|
||||
/*, std::unordered_map<int, std::wstring> wordipmap /*added by linquan*/) const
|
||||
|
@ -1011,82 +1087,8 @@ double lattice::nbestlatticeEMBR(const std::vector<float> &edgeacscores, paralle
|
|||
path_ids.push_back(nodes[edges[path[k]].E].wid);
|
||||
}
|
||||
|
||||
float wer;
|
||||
if (ptr_id2wordmap4node->size() > 0)
|
||||
{
|
||||
//Linquan added
|
||||
std::vector<std::wstring> refwords;
|
||||
std::vector<std::wstring> regwords;
|
||||
std::vector<size_t> refid;
|
||||
std::vector<size_t> regid;
|
||||
std::wstring temp_string;
|
||||
std::vector<std::wstring> character_array;
|
||||
std::unordered_map<std::wstring, size_t> idmappingtable;
|
||||
|
||||
refwords.clear();
|
||||
regwords.clear();
|
||||
character_array.clear();
|
||||
refid.clear();
|
||||
regid.clear();
|
||||
idmappingtable.clear();
|
||||
std::unordered_map<size_t, std::wstring>::const_iterator maptable_itr;
|
||||
for (std::vector<size_t>::const_iterator it = wids.begin(); it != wids.end(); ++it)
|
||||
{
|
||||
maptable_itr = ptr_id2wordmap4node->find(*it);
|
||||
/* if (maptable_itr != ptr_id2wordmap4node->end())
|
||||
refwords.push_back(maptable_itr->second);
|
||||
|
||||
else
|
||||
refwords.push_back(std::to_wstring(*it));*/
|
||||
|
||||
temp_string = (maptable_itr != ptr_id2wordmap4node->end()) ? maptable_itr->second : std::to_wstring(*it);
|
||||
character_array = splitword2character(temp_string);
|
||||
|
||||
foreach_index (_i, character_array)
|
||||
{
|
||||
refwords.push_back(character_array[_i]);
|
||||
if (idmappingtable.find(character_array[_i]) == idmappingtable.end())
|
||||
{
|
||||
idmappingtable.insert(pair<std::wstring, size_t>(character_array[_i], idmappingtable.size() + 1));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (std::vector<size_t>::const_iterator it = path_ids.begin(); it != path_ids.end(); ++it)
|
||||
{
|
||||
maptable_itr = ptr_id2wordmap4node->find(*it);
|
||||
|
||||
temp_string = (maptable_itr != ptr_id2wordmap4node->end()) ? maptable_itr->second : std::to_wstring(*it);
|
||||
character_array = splitword2character(temp_string);
|
||||
|
||||
foreach_index (_i, character_array)
|
||||
{
|
||||
regwords.push_back(character_array[_i]);
|
||||
if (idmappingtable.find(character_array[_i]) == idmappingtable.end())
|
||||
{
|
||||
idmappingtable.insert(pair<std::wstring, size_t>(character_array[_i], idmappingtable.size() + 1));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
//map characters to id to be compatiable with egacy code
|
||||
//skip tag words
|
||||
foreach_index (_k, refwords)
|
||||
{
|
||||
if (!istagword(refwords[_k]))
|
||||
refid.push_back(idmappingtable.find(refwords[_k])->second);
|
||||
}
|
||||
|
||||
foreach_index (_k, regwords)
|
||||
{
|
||||
if (!istagword(regwords[_k]))
|
||||
regid.push_back(idmappingtable.find(regwords[_k])->second);
|
||||
}
|
||||
|
||||
wer = compute_wer(refid, regid);
|
||||
}
|
||||
else
|
||||
wer = compute_wer(wids, path_ids);
|
||||
//linquan
|
||||
float wer = computewerandcer(wids, path_ids, ptr_id2wordmap4node);
|
||||
|
||||
// will favor the path with better WER
|
||||
pathscore -= double(accWeightInNbest * wer);
|
||||
|
@ -1827,7 +1829,8 @@ double lattice::get_edge_weights(std::vector<size_t> &wids, std::vector<std::vec
|
|||
nodes[edges[vt_paths[i][j]].E].wid;
|
||||
}
|
||||
|
||||
vt_path_weights[i] = compute_wer(wids, path_ids);
|
||||
//linquan
|
||||
vt_path_weights[i] = computewerandcer(wids, path_ids, ptr_id2wordmap4node);
|
||||
|
||||
string pathidstr = "$";
|
||||
for (size_t j = 0; j < path_ids.size(); j++)
|
||||
|
|
Загрузка…
Ссылка в новой задаче