Enable unigram support on Linux
This commit is contained in:
Родитель
b1ccf91f85
Коммит
cf0d15a01b
|
@ -126,7 +126,9 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
soName = soName + ".so";
|
||||
void *handle = dlopen(soName.c_str(), RTLD_LAZY);
|
||||
if (handle == NULL)
|
||||
RuntimeError("Plugin not found: %s", soName.c_str());
|
||||
{
|
||||
RuntimeError("Plugin not found: %s (error: %s)", soName.c_str(), dlerror());
|
||||
}
|
||||
return dlsym(handle, proc.c_str());
|
||||
}
|
||||
|
||||
|
|
|
@ -14,9 +14,7 @@
|
|||
#include "latticearchive.h" // for reading HTK phoneme lattices (MMI training)
|
||||
#endif
|
||||
#include "simplesenonehmm.h" // for MMI scoring
|
||||
#ifdef _WIN32
|
||||
#include "msra_mgram.h" // for unigram scores of ground-truth path in sequence training
|
||||
#endif
|
||||
|
||||
#include "rollingwindowsource.h" // minibatch sources
|
||||
#include "utterancesource.h"
|
||||
|
|
|
@ -15,9 +15,7 @@
|
|||
#include "htkfeatio.h" // for reading HTK features
|
||||
#include "latticearchive.h" // for reading HTK phoneme lattices (MMI training)
|
||||
#include "simplesenonehmm.h" // for MMI scoring
|
||||
#ifdef _WIN32
|
||||
#include "msra_mgram.h" // for unigram scores of ground-truth path in sequence training
|
||||
#endif
|
||||
|
||||
#include "rollingwindowsource.h" // minibatch sources
|
||||
#include "utterancesourcemulti.h"
|
||||
|
@ -48,6 +46,10 @@ typedef unsigned int UNINT32;
|
|||
int msra::numa::node_override = -1; // for numahelpers.h
|
||||
#endif
|
||||
|
||||
namespace msra { namespace lm {
|
||||
/*static*/ const mgram_map::index_t mgram_map::nindex = (mgram_map::index_t) -1; // invalid index
|
||||
}}
|
||||
|
||||
namespace Microsoft { namespace MSR { namespace CNTK {
|
||||
|
||||
// Create a Data Reader
|
||||
|
@ -337,9 +339,8 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
infilesmulti.push_back(filelist);
|
||||
}
|
||||
|
||||
#ifdef _WIN32
|
||||
if (readerConfig.Exists("unigram"))
|
||||
unigrampath = readerConfig("unigram");
|
||||
unigrampath = (wstring)readerConfig("unigram");
|
||||
|
||||
// load a unigram if needed (this is used for MMI training)
|
||||
msra::lm::CSymbolSet unigramsymbols;
|
||||
|
@ -358,7 +359,6 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
|
||||
if (!unigram)
|
||||
fprintf (stderr, "trainlayer: OOV-exclusion code enabled, but no unigram specified to derive the word set from, so you won't get OOV exclusion\n");
|
||||
#endif
|
||||
|
||||
// currently assumes all mlfs will have same root name (key)
|
||||
set<wstring> restrictmlftokeys; // restrict MLF reader to these files--will make stuff much faster without having to use shortened input files
|
||||
|
@ -382,11 +382,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
//std::vector<std::wstring> pagepath;
|
||||
foreach_index(i, mlfpathsmulti)
|
||||
{
|
||||
#ifdef WIN32
|
||||
const msra::lm::CSymbolSet* wordmap = unigram ? &unigramsymbols : NULL;
|
||||
#else
|
||||
const map<string, size_t>* wordmap = NULL;
|
||||
#endif
|
||||
msra::asr::htkmlfreader<msra::asr::htkmlfentry,msra::lattices::lattice::htkmlfwordsequence>
|
||||
labels(mlfpathsmulti[i], restrictmlftokeys, statelistpaths[i], wordmap, (map<string,size_t>*) NULL, htktimetoframe); // label MLF
|
||||
// get the temp file name for the page file
|
||||
|
|
|
@ -1288,7 +1288,9 @@ public:
|
|||
soName = soName + ".so";
|
||||
void *handle = dlopen(soName.c_str(), RTLD_LAZY);
|
||||
if (handle == NULL)
|
||||
RuntimeError("Plugin not found: %s", soName.c_str());
|
||||
{
|
||||
RuntimeError("Plugin not found: %s (error: %s)", soName.c_str(), dlerror());
|
||||
}
|
||||
return dlsym(handle, proc.c_str());
|
||||
}
|
||||
|
||||
|
|
|
@ -22,8 +22,9 @@ namespace msra { namespace lm {
|
|||
// core LM interface -- LM scores are accessed through this exclusively
|
||||
// ===========================================================================
|
||||
|
||||
interface ILM // generic interface -- mostly the score() function
|
||||
class ILM // generic interface -- mostly the score() function
|
||||
{
|
||||
public:
|
||||
virtual double score (const int * mgram, int m) const = 0;
|
||||
virtual bool oov (int w) const = 0; // needed for perplexity calculation
|
||||
// ... TODO (?): return true/false to indicate whether anything changed.
|
||||
|
@ -31,8 +32,9 @@ interface ILM // generic interface -- mostly the score() function
|
|||
virtual void adapt (const int * data, size_t m) = 0; // (NULL,M) to reset, (!NULL,0) to flush
|
||||
|
||||
// iterator for composing models --iterates in increasing order w.r.t. w
|
||||
interface IIter
|
||||
class IIter
|
||||
{
|
||||
public:
|
||||
virtual operator bool() const = 0; // has iterator not yet reached end?
|
||||
// ... TODO: ensure iterators do not return OOVs w.r.t. user symbol table
|
||||
// (It needs to be checked which LM type's iterator currently does.)
|
||||
|
@ -128,7 +130,11 @@ public:
|
|||
// create
|
||||
const char * p = _strdup (key);
|
||||
if (!p)
|
||||
#ifdef _WIN32
|
||||
throw std::bad_exception ("CSymbolSet:id string allocation failure");
|
||||
#else
|
||||
throw std::bad_exception ();
|
||||
#endif
|
||||
try
|
||||
{
|
||||
int id = (int) symbols.size();
|
||||
|
@ -274,7 +280,7 @@ class mgram_map
|
|||
{
|
||||
typedef unsigned int index_t; // (-> size_t when we really need it)
|
||||
//typedef size_t index_t; // (tested once, seems to work)
|
||||
static const index_t nindex = (index_t) -1; // invalid index
|
||||
static const index_t nindex; // invalid index
|
||||
// entry [m][i] is first index of children in level m+1, entry[m][i+1] the end.
|
||||
int M; // order, e.g. M=3 for trigram
|
||||
std::vector<std::vector<index_t>> firsts; // [M][i] ([0] = zerogram = root)
|
||||
|
@ -1124,7 +1130,7 @@ public:
|
|||
void read (const std::wstring & pathname, SYMMAP & userSymMap, bool filterVocabulary, int maxM)
|
||||
{
|
||||
int lineNo = 0;
|
||||
msra::basetypes::auto_file_ptr f = fopenOrDie (pathname, L"rbS");
|
||||
msra::basetypes::auto_file_ptr f(fopenOrDie (pathname, L"rbS"));
|
||||
fprintf (stderr, "read: reading %S", pathname.c_str());
|
||||
filename = pathname; // (keep this info for debugging)
|
||||
|
||||
|
@ -1769,7 +1775,7 @@ protected:
|
|||
mcounts.push_back (mmap.create (newkey, mmapCache), count); // store 'count' under 'key'
|
||||
}
|
||||
}
|
||||
fprintf (stderr, " %d %d-grams", mcounts.size (m), m);
|
||||
fprintf (stderr, " %d %d-grams", (int)mcounts.size (m), m);
|
||||
}
|
||||
|
||||
// remove used up tokens from the buffer
|
||||
|
@ -2027,7 +2033,7 @@ public:
|
|||
while (M > 0 && counts.size (M) == 0) resize (M-1);
|
||||
|
||||
for (int m = 1; m <= M; m++)
|
||||
fprintf (stderr, "estimate: read %d %d-grams\n", counts.size (m), m);
|
||||
fprintf (stderr, "estimate: read %d %d-grams\n", (int)counts.size (m), m);
|
||||
|
||||
// === Kneser-Ney smoothing
|
||||
// This is a strange algorithm.
|
||||
|
@ -2197,8 +2203,8 @@ public:
|
|||
for (int m = 1; m <= M; m++)
|
||||
{
|
||||
fprintf (stderr, "estimate: %d-grams after pruning: %d out of %d (%.1f%%)\n", m,
|
||||
numMGrams[m], counts.size (m),
|
||||
100.0 * numMGrams[m] / max (counts.size (m), 1));
|
||||
numMGrams[m], (int)counts.size (m),
|
||||
100.0 * numMGrams[m] / max (counts.size (m), size_t(1)));
|
||||
}
|
||||
|
||||
// ensure M reflects the actual order of read data after pruning
|
||||
|
@ -2282,6 +2288,9 @@ public:
|
|||
}
|
||||
}
|
||||
|
||||
double dcount;
|
||||
double dP;
|
||||
|
||||
// pruned case
|
||||
if (count == 0) // this entry was pruned before
|
||||
goto skippruned;
|
||||
|
@ -2314,7 +2323,7 @@ public:
|
|||
}
|
||||
|
||||
// estimate discounted probability
|
||||
double dcount = count; // "modified Kneser-Ney" discounting
|
||||
dcount = count; // "modified Kneser-Ney" discounting
|
||||
if (count >= 3) dcount -= d3[m];
|
||||
else if (count == 2) dcount -= d2[m];
|
||||
else if (count == 1) dcount -= d1[m];
|
||||
|
@ -2323,7 +2332,7 @@ public:
|
|||
|
||||
if (histCount == 0)
|
||||
RuntimeError ("estimate: unexpected 0 denominator");
|
||||
double dP = dcount / histCount;
|
||||
dP = dcount / histCount;
|
||||
// and this is the discounted probability value
|
||||
{
|
||||
// Actually, 'key' uses a "mapped" word ids, while create()
|
||||
|
@ -2412,7 +2421,7 @@ skippruned:; // m-gram was pruned
|
|||
updateOOVScore();
|
||||
|
||||
fprintf (stderr, "estimate: done");
|
||||
for (int m = 1; m <= M; m++) fprintf (stderr, ", %d %d-grams", logP.size (m), m);
|
||||
for (int m = 1; m <= M; m++) fprintf (stderr, ", %d %d-grams", (int)logP.size (m), m);
|
||||
fprintf (stderr, "\n");
|
||||
}
|
||||
};
|
||||
|
@ -2521,7 +2530,7 @@ skipMGram:
|
|||
wstring dir, file;
|
||||
splitpath (clonepath, dir, file); // we allow relative paths in the file
|
||||
|
||||
msra::basetypes::auto_file_ptr f = fopenOrDie (clonepath, L"rbS");
|
||||
msra::basetypes::auto_file_ptr f(fopenOrDie (clonepath, L"rbS"));
|
||||
std::string line = fgetline (f);
|
||||
if (line != "#clone")
|
||||
throw runtime_error ("read: invalid header line " + line);
|
||||
|
|
Загрузка…
Ссылка в новой задаче