Guoli Ye 2017-12-07 17:47:35 -08:00
Parent 830b6f94b4
Commit 4018e1e724
65 changed files: 30250 additions and 475 deletions

View file

@ -2216,6 +2216,7 @@ Global
{292FF4EE-D9DD-4BA7-85F7-6A22148D1E01}.Debug_CpuOnly|x64.ActiveCfg = Debug|Any CPU
{292FF4EE-D9DD-4BA7-85F7-6A22148D1E01}.Debug_UWP|x64.ActiveCfg = Debug|Any CPU
{292FF4EE-D9DD-4BA7-85F7-6A22148D1E01}.Debug|x64.ActiveCfg = Debug|Any CPU
{292FF4EE-D9DD-4BA7-85F7-6A22148D1E01}.Debug|x64.Build.0 = Debug|Any CPU
{292FF4EE-D9DD-4BA7-85F7-6A22148D1E01}.Release_CpuOnly|x64.ActiveCfg = Release|Any CPU
{292FF4EE-D9DD-4BA7-85F7-6A22148D1E01}.Release_NoOpt|x64.ActiveCfg = Release|Any CPU
{292FF4EE-D9DD-4BA7-85F7-6A22148D1E01}.Release_UWP|x64.ActiveCfg = Release|Any CPU

View file

@ -784,6 +784,9 @@ HTKMLFREADER_SRC =\
$(SOURCEDIR)/Readers/HTKMLFReader/DataWriterLocal.cpp \
$(SOURCEDIR)/Readers/HTKMLFReader/HTKMLFReader.cpp \
$(SOURCEDIR)/Readers/HTKMLFReader/HTKMLFWriter.cpp \
# $(SOURCEDIR)/Common/File.cpp \
# $(SOURCEDIR)/Common/fileutil.cpp \
# $(SOURCEDIR)/Common/ExceptionWithCallStack.cpp \
HTKMLFREADER_OBJ := $(patsubst %.cpp, $(OBJDIR)/%.o, $(HTKMLFREADER_SRC))

View file

@ -246,7 +246,7 @@ TIMIT_TrainSimple = new TrainAction [ // new: added TrainAction; t
//L3 = SBFF(L2,hiddenDim,hiddenDim)
//CE = SMBFF(L3,labelDim,hiddenDim,myLabels,tag=Criteria)
//Err = ClassificationError(myLabels,CE.BFF.FF.P,tag=Eval)
//logPrior = LogPrior(myLabels)
//ScaledLogLikelihood=Minus(CE.BFF.FF.P,logPrior,tag=Output)
// new:
@ -282,7 +282,7 @@ TIMIT_TrainSimple = new TrainAction [ // new: added TrainAction; t
Err = ClassificationError(myLabels, outZ)
// define output node for decoding
logPrior = LogPrior(myLabels)
ScaledLogLikelihood = outZ - logPrior // before: Minus(CE.BFF.FF.P,logPrior,tag=Output)
]
]
@ -395,6 +395,6 @@ network = new NDL [
Err = ClassificationError(myLabels, outZ)
// define output node for decoding
logPrior = LogPrior(myLabels)
ScaledLogLikelihood = outZ - logPrior // before: Minus(CE.BFF.FF.P,logPrior,tag=Output)
]

View file

@ -606,7 +606,7 @@ static void PrintBanner(int argc, wchar_t* argv[], const string& timestamp)
fprintf(stderr, "%s %.6s, ", _BUILDBRANCH_, _BUILDSHA1_);
#endif
fprintf(stderr, "%s %s", __DATE__, __TIME__); // build time
fprintf(stderr, ") at %s\n\n", timestamp.c_str());
fprintf(stderr, ") on %s at %s\n\n", GetHostName().c_str(), timestamp.c_str());
for (int i = 0; i < argc; i++)
fprintf(stderr, "%*s%ls", i > 0 ? 2 : 0, "", argv[i]); // use 2 spaces for better visual separability
fprintf(stderr, "\n");
@ -617,8 +617,7 @@ int wmainOldCNTKConfig(int argc, wchar_t* argv[])
{
std::string timestamp = TimeDateStamp();
PrintBanner(argc, argv, timestamp);
ConfigParameters config;
std::string rawConfigString = ConfigParameters::ParseCommandLine(argc, argv, config); // get the command param set they want
int traceLevel = config(L"traceLevel", 0);

View file

@ -488,11 +488,11 @@ ndlMacroUseCNNAuto=[
]
ndlRnnNetwork=[
#define basic i/o
featDim=1845
labelDim=183
hiddenDim=2048
features=Input(featDim, tag=feature)
labels=Input(labelDim, tag=label)
MeanVarNorm(x)=[
@ -502,9 +502,9 @@ ndlRnnNetwork=[
]
# define network
featNorm = MeanVarNorm(features)
W0 = Parameter(hiddenDim, featDim)
L1 = Times(W0,featNorm)
W = Parameter(hiddenDim, hiddenDim)
@ -515,8 +515,8 @@ ndlRnnNetwork=[
Output = Times(W2, Dout)
criterion = CrossEntropyWithSoftmax(labels, Output, tag=Criteria)
#CE = SMBFF(Dout,labelDim,hiddenDim,labels,tag=Criteria)
#Err = ErrorPrediction(labels,CE.BFF.FF.P,tag=Eval)
LogPrior(labels)
{
@ -525,8 +525,8 @@ ndlRnnNetwork=[
}
# define output (scaled loglikelihood)
logPrior = LogPrior(labels)
#ScaledLogLikelihood=Minus(CE.BFF.FF.P,logPrior,tag=Output)
# rootNodes defined here temporarily so we pass
OutputNodes=(criterion)
EvalNodes=(criterion)

View file

@ -35,6 +35,7 @@
<Keyword>Win32Proj</Keyword>
<RootNamespace>CNTKv2LibraryDll</RootNamespace>
<ProjectName>CNTKv2LibraryDll</ProjectName>
<WindowsTargetPlatformVersion>8.1</WindowsTargetPlatformVersion>
</PropertyGroup>
<Import Project="$(SolutionDir)\CNTK.Cpp.props" />
<Import Project="$(VCTargetsPath)\Microsoft.Cpp.Default.props" />

View file

@ -274,11 +274,17 @@ bool DataReader::GetMinibatch(StreamMinibatchInputs& matrices)
// uids - labels stored in a size_t vector instead of an ElemType matrix
// boundary - phone boundaries
// returns - true if there are more minibatches, false if no more minibatches remain
bool DataReader::GetMinibatch4SE(std::vector<shared_ptr<const msra::dbn::latticepair>>& latticeinput, vector<size_t>& uids, vector<size_t>& boundaries, vector<size_t>& extrauttmap)
/* guoye: start */
// bool DataReader::GetMinibatch4SE(std::vector<shared_ptr<const msra::dbn::latticepair>>& latticeinput, vector<size_t>& uids, vector<size_t>& boundaries, vector<size_t>& extrauttmap)
bool DataReader::GetMinibatch4SE(std::vector<shared_ptr<const msra::dbn::latticepair>>& latticeinput, vector<size_t>& uids, vector<size_t>& wids, vector<short>& nws, vector<size_t>& boundaries, vector<size_t>& extrauttmap)
/* guoye: end */
{
bool bRet = true;
for (size_t i = 0; i < m_ioNames.size(); i++)
bRet &= m_dataReaders[m_ioNames[i]]->GetMinibatch4SE(latticeinput, uids, boundaries, extrauttmap);
/* guoye: start */
// bRet &= m_dataReaders[m_ioNames[i]]->GetMinibatch4SE(latticeinput, uids, boundaries, extrauttmap);
bRet &= m_dataReaders[m_ioNames[i]]->GetMinibatch4SE(latticeinput, uids, wids, nws, boundaries, extrauttmap);
/* guoye: end */
return bRet;
}
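// Illustrative calling sketch for the widened interface above (variable names are
// assumptions for illustration, not part of the commit): the new outputs are the
// word ids (wids) and the per-utterance word counts (nws).
//
//   std::vector<shared_ptr<const msra::dbn::latticepair>> lattices;
//   std::vector<size_t> uids, wids, boundaries, extrauttmap;
//   std::vector<short> nws;
//   bool more = reader.GetMinibatch4SE(lattices, uids, wids, nws, boundaries, extrauttmap);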
@ -288,8 +294,14 @@ bool DataReader::GetMinibatch4SE(std::vector<shared_ptr<const msra::dbn::lattice
bool DataReader::GetHmmData(msra::asr::simplesenonehmm* hmm)
{
bool bRet = true;
// fprintf(stderr, "DataReader::GetHmmData: debug 1, m_ioNames.size() = %d \n", int(m_ioNames.size()));
for (size_t i = 0; i < m_ioNames.size(); i++)
{
//fprintf(stderr, "DataReader::GetHmmData: debug 2, i = %d , m_ioNames[i] = %ls \n", int(i), m_ioNames[i].c_str());
bRet &= m_dataReaders[m_ioNames[i]]->GetHmmData(hmm);
// fprintf(stderr, "DataReader::GetHmmData: debug 3, i = %d \n", int(i));
}
return bRet;
}

View file

@ -264,7 +264,10 @@ public:
}
virtual bool GetMinibatch(StreamMinibatchInputs& matrices) = 0;
virtual bool GetMinibatch4SE(std::vector<shared_ptr<const msra::dbn::latticepair>>& /*latticeinput*/, vector<size_t>& /*uids*/, vector<size_t>& /*boundaries*/, vector<size_t>& /*extrauttmap*/)
/* guoye: start */
// virtual bool GetMinibatch4SE(std::vector<shared_ptr<const msra::dbn::latticepair>>& /*latticeinput*/, vector<size_t>& /*uids*/, vector<size_t>& /*boundaries*/, vector<size_t>& /*extrauttmap*/)
virtual bool GetMinibatch4SE(std::vector<shared_ptr<const msra::dbn::latticepair>>& /*latticeinput*/, vector<size_t>& /*uids*/, vector<size_t>& /*wids*/, vector<short>& /*nws*/, vector<size_t>& /*boundaries*/, vector<size_t>& /*extrauttmap*/)
/* guoye: end */
{
NOT_IMPLEMENTED;
};
@ -444,7 +447,11 @@ public:
// [out] each matrix resized if necessary containing data.
// returns - true if there are more minibatches, false if no more minibatches remain
virtual bool GetMinibatch(StreamMinibatchInputs& matrices);
virtual bool GetMinibatch4SE(std::vector<shared_ptr<const msra::dbn::latticepair>>& latticeinput, vector<size_t>& uids, vector<size_t>& boundaries, vector<size_t>& extrauttmap);
/* guoye: start */
// virtual bool GetMinibatch4SE(std::vector<shared_ptr<const msra::dbn::latticepair>>& latticeinput, vector<size_t>& uids, vector<size_t>& boundaries, vector<size_t>& extrauttmap);
virtual bool GetMinibatch4SE(std::vector<shared_ptr<const msra::dbn::latticepair>>& latticeinput, vector<size_t>& uids, vector<size_t>& wids, vector<short>& nws, vector<size_t>& boundaries, vector<size_t>& extrauttmap);
/* guoye: end */
virtual bool GetHmmData(msra::asr::simplesenonehmm* hmm);
size_t GetNumParallelSequencesForFixingBPTTMode();

View file

@ -34,6 +34,22 @@
#include <fcntl.h>
#define FCLOSE_SUCCESS 0
/* guoye: start */
/*
#include "basetypes.h" //for attemp()
#include "ProgressTracing.h"
#include <unistd.h>
#include <glob.h>
#include <dirent.h>
#include <sys/sendfile.h>
#include <stdio.h>
#include <ctype.h>
#include <limits.h>
#include <memory>
#include <cwctype>
*/
// using namespace Microsoft::MSR::CNTK;
/* guoye: end */
// ----------------------------------------------------------------------------
// fopenOrDie(): like fopen() but terminate with err msg in case of error.
@ -77,6 +93,7 @@ void freadOrDie(_T& data, size_t num, FILE* f) // template for std::vector<>
freadOrDie(&data[0], sizeof(data[0]), data.size(), f);
}
#ifdef _WIN32
template <class _T>
void freadOrDie(_T& data, int num, const HANDLE f) // template for std::vector<>
@ -229,11 +246,129 @@ void fputstring(FILE* f, const wchar_t*);
void fputstring(FILE* f, const std::wstring&);
template <class CHAR>
CHAR* fgetline(FILE* f, CHAR* buf, int size);
CHAR* fgetline(FILE* f, CHAR* buf, int size)
{
// TODO: we should redefine this to write UTF-16 (which matters on GCC which defines wchar_t as 32 bit)
/* guoye: start */
// fprintf(stderr, "\n fileutil.cpp: fgetline: debug 0\n");
/* guoye: end */
CHAR* p = fgets(buf, size, f);
/* guoye: start */
// fprintf(stderr, "\n fileutil.cpp: fgetline: debug 1\n");
/* guoye: end */
if (p == NULL) // EOF reached: next time feof() = true
{
if (ferror(f))
RuntimeError("error reading line: %s", strerror(errno));
buf[0] = 0;
return buf;
}
size_t n = strnlen(p, size);
// check for buffer overflow
if (n >= (size_t)size - 1)
{
/* guoye: start */
// basic_string<CHAR> example(p, n < 100 ? n : 100);
std::basic_string<CHAR> example(p, n < 100 ? n : 100);
/* guoye: end */
uint64_t filepos = fgetpos(f); // (for error message only)
RuntimeError("input line too long at file offset %d (max. %d characters allowed) [%s ...]", (int)filepos, (int)size - 1, msra::strfun::utf8(example).c_str());
}
// remove newline at end
if (n > 0 && p[n - 1] == '\n') // UNIX and Windows style
{
n--;
p[n] = 0;
if (n > 0 && p[n - 1] == '\r') // Windows style
{
n--;
p[n] = 0;
}
}
else if (n > 0 && p[n - 1] == '\r') // Mac style
{
n--;
p[n] = 0;
}
return buf;
}
// this is added to fix a code bug: without it, the code does not support wchar_t
template <class CHAR>
CHAR* fgetlinew(FILE* f, CHAR* buf, int size)
{
// TODO: we should redefine this to write UTF-16 (which matters on GCC which defines wchar_t as 32 bit)
/* guoye: start */
// fprintf(stderr, "\n fileutil.cpp: fgetline: debug 0\n");
/* guoye: end */
CHAR* p = fgetws(buf, size, f);
/* guoye: start */
// fprintf(stderr, "\n fileutil.cpp: fgetline: debug 1\n");
/* guoye: end */
if (p == NULL) // EOF reached: next time feof() = true
{
if (ferror(f))
RuntimeError("error reading line: %s", strerror(errno));
buf[0] = L'\0';
return buf;
}
size_t n = wcsnlen(p, size);
// check for buffer overflow
if (n >= (size_t)size - 1)
{
/* guoye: start */
// basic_string<CHAR> example(p, n < 100 ? n : 100);
std::basic_string<CHAR> example(p, n < 100 ? n : 100);
/* guoye: end */
uint64_t filepos = fgetpos(f); // (for error message only)
RuntimeError("input line too long at file offset %d (max. %d characters allowed) [%s ...]", (int)filepos, (int)size - 1, msra::strfun::utf8(example).c_str());
}
// remove newline at end
if (n > 0 && p[n - 1] == L'\n') // UNIX and Windows style
{
n--;
p[n] = L'\0';
if (n > 0 && p[n - 1] == L'\r') // Windows style
{
n--;
p[n] = L'\0';
}
}
else if (n > 0 && p[n - 1] == L'\r') // Mac style
{
n--;
p[n] = L'\0';
}
return buf;
}
template <class CHAR, size_t n>
CHAR* fgetlinew(FILE* f, CHAR(&buf)[n])
{
/* guoye: start */
// fprintf(stderr, "\n fileutil.h: fgetline(FILE* f, CHAR(&buf)[n]): debug 0\n");
return fgetlinew(f, buf, n);
/* guoye: end */
}
/* guoye: end */
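// Usage sketch for the wide-character overloads above (buffer size and the FILE*
// handle 'f' are illustrative assumptions, not part of the commit):
//
//   wchar_t wbuf[4096];
//   fgetlinew(f, wbuf); // reads one line, strips \r / \n, leaves wbuf 0-terminated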
template <class CHAR, size_t n>
CHAR* fgetline(FILE* f, CHAR(&buf)[n])
{
/* guoye: start */
// fprintf(stderr, "\n fileutil.h: fgetline(FILE* f, CHAR(&buf)[n]): debug 0\n");
return fgetline(f, buf, n);
/* guoye: end */
}
std::string fgetline(FILE* f);
std::wstring fgetlinew(FILE* f);
@ -902,9 +1037,40 @@ static inline String& trim(String& s)
{
return ltrim(rtrim(s));
}
/* guoye: start */
template<class String>
std::vector<String> SplitString(const String& str, const String& sep);
// moved from fileutil.h: the definition and declaration should be in the same file.
// vector<String> SplitString(const String& str, const String& sep)
std::vector<String> SplitString(const String& str, const String& sep)
/* guoye: end */
{
/* guoye: start */
// vector<String> vstr;
std::vector<String> vstr;
/* guoye: end */
String csub;
size_t ifound = 0;
size_t ifoundlast = ifound;
ifound = str.find_first_of(sep, ifound);
while (ifound != String::npos)
{
csub = str.substr(ifoundlast, ifound - ifoundlast);
if (!csub.empty())
vstr.push_back(csub);
ifoundlast = ifound + 1;
ifound = str.find_first_of(sep, ifoundlast);
}
ifound = str.length();
csub = str.substr(ifoundlast, ifound - ifoundlast);
if (!csub.empty())
vstr.push_back(csub);
return vstr;
}
/* guoye: end */
template<class String, class Char>
std::vector<String> SplitString(const String& str, const Char* sep) { return SplitString(str, String(sep)); }
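// Illustrative behaviour of SplitString (example values only): the string is split on
// any character of 'sep' and empty pieces are dropped.
//
//   std::vector<std::string> parts = SplitString(std::string("a::b:c"), ":");
//   // parts == { "a", "b", "c" }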
@ -912,4 +1078,8 @@ std::wstring s2ws(const std::string& str);
std::string ws2s(const std::wstring& wstr);
/* guoye: start */
// #include "../fileutil.cpp"
/* guoye: end */
#endif // _FILEUTIL_

View file

@ -0,0 +1,938 @@
//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
//
// fileutil.h - file I/O with error checking
//
#pragma once
#ifndef _FILEUTIL_
#define _FILEUTIL_
#include "Basics.h"
#ifdef __WINDOWS__
#ifndef NOMINMAX
#define NOMINMAX
#endif // NOMINMAX
#include "Windows.h" // for mmreg.h and FILETIME
#include <mmreg.h>
#endif
#ifdef __unix__
#include <sys/types.h>
#include <sys/stat.h>
#endif
#include <algorithm> // for std::find
#include <vector>
#include <string>
#include <map>
#include <functional>
#include <cctype>
#include <errno.h>
#include <stdint.h>
#include <assert.h>
#include <string.h> // for strerror()
#include <stdexcept> // for exception
#include <fcntl.h>
#define FCLOSE_SUCCESS 0
/* guoye: start */
/*
#include "basetypes.h" //for attemp()
#include "ProgressTracing.h"
#include <unistd.h>
#include <glob.h>
#include <dirent.h>
#include <sys/sendfile.h>
#include <stdio.h>
#include <ctype.h>
#include <limits.h>
#include <memory>
#include <cwctype>
*/
// using namespace Microsoft::MSR::CNTK;
/* guoye: end */
// ----------------------------------------------------------------------------
// fopenOrDie(): like fopen() but terminate with err msg in case of error.
// A pathname of "-" returns stdout or stdin, depending on mode, and it will
// change the binary mode if 'b' or 't' are given. If you use this, make sure
// not to fclose() such a handle.
// ----------------------------------------------------------------------------
FILE* fopenOrDie(const std::string& pathname, const char* mode);
FILE* fopenOrDie(const std::wstring& pathname, const wchar_t* mode);
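// Usage sketch (path is illustrative): on failure the call terminates with an error
// message instead of returning NULL; a pathname of "-" maps to stdin/stdout.
//
//   FILE* f = fopenOrDie(std::string("features.scp"), "rb");
//   // ... use f ...
//   fclose(f); // fine for regular files; never fclose() the "-" handle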
#ifndef __unix__
// ----------------------------------------------------------------------------
// fsetmode(): set mode to binary or text
// ----------------------------------------------------------------------------
void fsetmode(FILE* f, char type);
#endif
// ----------------------------------------------------------------------------
// freadOrDie(): like fread() but terminate with err msg in case of error
// ----------------------------------------------------------------------------
void freadOrDie(void* ptr, size_t size, size_t count, FILE* f);
#ifdef _WIN32
void freadOrDie(void* ptr, size_t size, size_t count, const HANDLE f);
#endif
template <class _T>
void freadOrDie(_T& data, int num, FILE* f) // template for std::vector<>
{
data.resize(num);
if (data.size() > 0)
freadOrDie(&data[0], sizeof(data[0]), data.size(), f);
}
template <class _T>
void freadOrDie(_T& data, size_t num, FILE* f) // template for std::vector<>
{
data.resize(num);
if (data.size() > 0)
freadOrDie(&data[0], sizeof(data[0]), data.size(), f);
}
#ifdef _WIN32
template <class _T>
void freadOrDie(_T& data, int num, const HANDLE f) // template for std::vector<>
{
data.resize(num);
if (data.size() > 0)
freadOrDie(&data[0], sizeof(data[0]), data.size(), f);
}
template <class _T>
void freadOrDie(_T& data, size_t num, const HANDLE f) // template for std::vector<>
{
data.resize(num);
if (data.size() > 0)
freadOrDie(&data[0], sizeof(data[0]), data.size(), f);
}
#endif
// ----------------------------------------------------------------------------
// fwriteOrDie(): like fwrite() but terminate with err msg in case of error
// ----------------------------------------------------------------------------
void fwriteOrDie(const void* ptr, size_t size, size_t count, FILE* f);
#ifdef _WIN32
void fwriteOrDie(const void* ptr, size_t size, size_t count, const HANDLE f);
#endif
template <class _T>
void fwriteOrDie(const _T& data, FILE* f) // template for std::vector<>
{
if (data.size() > 0)
fwriteOrDie(&data[0], sizeof(data[0]), data.size(), f);
}
#ifdef _WIN32
template <class _T>
void fwriteOrDie(const _T& data, const HANDLE f) // template for std::vector<>
{
if (data.size() > 0)
fwriteOrDie(&data[0], sizeof(data[0]), data.size(), f);
}
#endif
// ----------------------------------------------------------------------------
// fprintfOrDie(): like fprintf() but terminate with err msg in case of error
// ----------------------------------------------------------------------------
void fprintfOrDie(FILE* f, const char* format, ...);
// ----------------------------------------------------------------------------
// fcloseOrDie(): like fclose() but terminate with err msg in case of error
// not yet implemented, but we should
// ----------------------------------------------------------------------------
#define fcloseOrDie fclose
// ----------------------------------------------------------------------------
// fflushOrDie(): like fflush() but terminate with err msg in case of error
// ----------------------------------------------------------------------------
void fflushOrDie(FILE* f);
// ----------------------------------------------------------------------------
// filesize(): determine size of the file in bytes
// ----------------------------------------------------------------------------
size_t filesize(const wchar_t* pathname);
size_t filesize(FILE* f);
int64_t filesize64(const wchar_t* pathname);
// ----------------------------------------------------------------------------
// fseekOrDie(),ftellOrDie(), fget/setpos(): seek functions with error handling
// ----------------------------------------------------------------------------
// 32-bit offsets only
long fseekOrDie(FILE* f, long offset, int mode = SEEK_SET);
#define ftellOrDie ftell
// ----------------------------------------------------------------------------
// fget/setpos(): seek functions with error handling
// ----------------------------------------------------------------------------
uint64_t fgetpos(FILE* f);
void fsetpos(FILE* f, uint64_t pos);
// ----------------------------------------------------------------------------
// unlinkOrDie(): unlink() with error handling
// ----------------------------------------------------------------------------
void unlinkOrDie(const std::string& pathname);
void unlinkOrDie(const std::wstring& pathname);
// ----------------------------------------------------------------------------
// renameOrDie(): rename() with error handling
// ----------------------------------------------------------------------------
void renameOrDie(const std::string& from, const std::string& to);
void renameOrDie(const std::wstring& from, const std::wstring& to);
// ----------------------------------------------------------------------------
// copyOrDie(): copy file with error handling.
// ----------------------------------------------------------------------------
void copyOrDie(const std::string& from, const std::string& to);
void copyOrDie(const std::wstring& from, const std::wstring& to);
// ----------------------------------------------------------------------------
// fexists(): test if a file exists
// ----------------------------------------------------------------------------
bool fexists(const char* pathname);
bool fexists(const wchar_t* pathname);
inline bool fexists(const std::string& pathname)
{
return fexists(pathname.c_str());
}
inline bool fexists(const std::wstring& pathname)
{
return fexists(pathname.c_str());
}
// ----------------------------------------------------------------------------
// funicode(): test if a file uses unicode
// ----------------------------------------------------------------------------
bool funicode(FILE* f);
// ----------------------------------------------------------------------------
// fskipspace(): skip space characters
// ----------------------------------------------------------------------------
bool fskipspace(FILE* F);
bool fskipwspace(FILE* F);
// ----------------------------------------------------------------------------
// fgetline(): like fgets() but terminate with err msg in case of error;
// removes the newline character at the end (like gets()), returned buffer is
// always 0-terminated; has second version that returns an STL std::string instead
// fgetstring(): read a 0-terminated std::string (terminate if error)
// fgetword(): read a space-terminated token (terminate if error)
// fskipNewLine(): skip all white space until end of line incl. the newline
// ----------------------------------------------------------------------------
// ----------------------------------------------------------------------------
// fputstring(): write a 0-terminated std::string (terminate if error)
// ----------------------------------------------------------------------------
void fputstring(FILE* f, const char*);
void fputstring(const HANDLE f, const char* str);
void fputstring(FILE* f, const std::string&);
void fputstring(FILE* f, const wchar_t*);
void fputstring(FILE* f, const std::wstring&);
template <class CHAR>
CHAR* fgetline(FILE* f, CHAR* buf, int size);
template <class CHAR, size_t n>
CHAR* fgetline(FILE* f, CHAR(&buf)[n])
{
/* guoye: start */
// fprintf(stderr, "\n fileutil.h: fgetline(FILE* f, CHAR(&buf)[n]): debug 0\n");
return fgetline(f, buf, n);
/* guoye: end */
}
std::string fgetline(FILE* f);
std::wstring fgetlinew(FILE* f);
void fgetline(FILE* f, std::string& s, std::vector<char>& buf);
void fgetline(FILE* f, std::wstring& s, std::vector<char>& buf);
void fgetline(FILE* f, std::vector<char>& buf);
void fgetline(FILE* f, std::vector<wchar_t>& buf);
const char* fgetstring(FILE* f, char* buf, int size);
template <size_t n>
const char* fgetstring(FILE* f, char(&buf)[n])
{
return fgetstring(f, buf, n);
}
const char* fgetstring(const HANDLE f, char* buf, int size);
template <size_t n>
const char* fgetstring(const HANDLE f, char(&buf)[n])
{
return fgetstring(f, buf, n);
}
const wchar_t* fgetstring(FILE* f, wchar_t* buf, int size);
std::wstring fgetwstring(FILE* f);
std::string fgetstring(FILE* f);
const char* fgettoken(FILE* f, char* buf, int size);
template <size_t n>
const char* fgettoken(FILE* f, char(&buf)[n])
{
return fgettoken(f, buf, n);
}
std::string fgettoken(FILE* f);
const wchar_t* fgettoken(FILE* f, wchar_t* buf, int size);
std::wstring fgetwtoken(FILE* f);
int fskipNewline(FILE* f, bool skip = true);
// ----------------------------------------------------------------------------
// fputstring(): write a 0-terminated std::string (terminate if error)
// ----------------------------------------------------------------------------
void fputstring(FILE* f, const char*);
#ifdef _WIN32
void fputstring(const HANDLE f, const char* str);
#endif
void fputstring(FILE* f, const std::string&);
void fputstring(FILE* f, const wchar_t*);
void fputstring(FILE* f, const std::wstring&);
// ----------------------------------------------------------------------------
// fgetTag(): read a 4-byte tag & return as a std::string
// ----------------------------------------------------------------------------
std::string fgetTag(FILE* f);
// ----------------------------------------------------------------------------
// fcheckTag(): read a 4-byte tag & verify it; terminate if wrong tag
// ----------------------------------------------------------------------------
void fcheckTag(FILE* f, const char* expectedTag);
#ifdef _WIN32
void fcheckTag(const HANDLE f, const char* expectedTag);
#endif
void fcheckTag_ascii(FILE* f, const std::string& expectedTag);
// ----------------------------------------------------------------------------
// fcompareTag(): compare two tags; terminate if wrong tag
// ----------------------------------------------------------------------------
void fcompareTag(const std::string& readTag, const std::string& expectedTag);
// ----------------------------------------------------------------------------
// fputTag(): write a 4-byte tag
// ----------------------------------------------------------------------------
void fputTag(FILE* f, const char* tag);
#ifdef _WIN32
void fputTag(const HANDLE f, const char* tag);
#endif
// ----------------------------------------------------------------------------
// fskipstring(): skip a 0-terminated std::string, such as a pad std::string
// ----------------------------------------------------------------------------
void fskipstring(FILE* f);
// ----------------------------------------------------------------------------
// fpad(): write a 0-terminated std::string to pad file to a n-byte boundary
// ----------------------------------------------------------------------------
void fpad(FILE* f, int n);
// ----------------------------------------------------------------------------
// fgetbyte(): read a byte value
// ----------------------------------------------------------------------------
char fgetbyte(FILE* f);
// ----------------------------------------------------------------------------
// fgetshort(): read a short value
// ----------------------------------------------------------------------------
short fgetshort(FILE* f);
short fgetshort_bigendian(FILE* f);
// ----------------------------------------------------------------------------
// fgetint24(): read a 3-byte (24-bit) int value
// ----------------------------------------------------------------------------
int fgetint24(FILE* f);
// ----------------------------------------------------------------------------
// fgetint(): read an int value
// ----------------------------------------------------------------------------
int fgetint(FILE* f);
#ifdef _WIN32
int fgetint(const HANDLE f);
#endif
int fgetint_bigendian(FILE* f);
int fgetint_ascii(FILE* f);
// ----------------------------------------------------------------------------
// fgetlong(): read a long value
// ----------------------------------------------------------------------------
long fgetlong(FILE* f);
// ----------------------------------------------------------------------------
// fgetfloat(): read a float value
// ----------------------------------------------------------------------------
float fgetfloat(FILE* f);
float fgetfloat_bigendian(FILE* f);
float fgetfloat_ascii(FILE* f);
// ----------------------------------------------------------------------------
// fgetdouble(): read a double value
// ----------------------------------------------------------------------------
double fgetdouble(FILE* f);
#ifdef _WIN32
// ----------------------------------------------------------------------------
// fgetwav(): read an entire .wav file
// ----------------------------------------------------------------------------
void fgetwav(FILE* f, std::vector<short>& wav, int& sampleRate);
void fgetwav(const std::wstring& fn, std::vector<short>& wav, int& sampleRate);
// ----------------------------------------------------------------------------
// fputwav(): save data into a .wav file
// ----------------------------------------------------------------------------
void fputwav(FILE* f, const std::vector<short>& wav, int sampleRate, int nChannels = 1);
void fputwav(const std::wstring& fn, const std::vector<short>& wav, int sampleRate, int nChannels = 1);
#endif
// ----------------------------------------------------------------------------
// fputbyte(): write a byte value
// ----------------------------------------------------------------------------
void fputbyte(FILE* f, char val);
// ----------------------------------------------------------------------------
// fputshort(): write a short value
// ----------------------------------------------------------------------------
void fputshort(FILE* f, short val);
// ----------------------------------------------------------------------------
// fputint24(): write a 3-byte (24-bit) int value
// ----------------------------------------------------------------------------
void fputint24(FILE* f, int v);
// ----------------------------------------------------------------------------
// fputint(): write an int value
// ----------------------------------------------------------------------------
void fputint(FILE* f, int val);
// ----------------------------------------------------------------------------
// fputlong(): write a long value
// ----------------------------------------------------------------------------
void fputlong(FILE* f, long val);
#ifdef _WIN32
void fputint(const HANDLE f, int v);
#endif
// ----------------------------------------------------------------------------
// fputfloat(): write a float value
// ----------------------------------------------------------------------------
void fputfloat(FILE* f, float val);
// ----------------------------------------------------------------------------
// fputdouble(): write a double value
// ----------------------------------------------------------------------------
void fputdouble(FILE* f, double val);
// template versions of put/get functions for binary files
template <typename T>
void fput(FILE* f, T v)
{
fwriteOrDie(&v, sizeof(v), 1, f);
}
// template versions of put/get functions for binary files
template <typename T>
void fget(FILE* f, T& v)
{
freadOrDie((void*) &v, sizeof(v), 1, f);
}
// GetFormatString - get the format std::string for a particular type
template <typename T>
const wchar_t* GetFormatString(T /*t*/)
{
// if this _ASSERT goes off it means that you are using a type that doesn't have
// a read and/or write routine.
// If the type is a user defined class, you need to create some global functions that handles file in/out.
// for example:
// File& operator>>(File& stream, MyClass& test);
// File& operator<<(File& stream, MyClass& test);
//
// in your class you will probably want to add these functions as friends so you can access any private members
// friend File& operator>>(File& stream, MyClass& test);
// friend File& operator<<(File& stream, MyClass& test);
//
// if you are using wchar_t* or char* types, these use other methods because they require buffers to be passed
// either use std::string and std::wstring, or use the WriteString() and ReadString() methods
assert(false); // need a specialization
return NULL;
}
// GetFormatString - specializations to get the format std::string for a particular type
template <>
const wchar_t* GetFormatString(char);
template <>
const wchar_t* GetFormatString(wchar_t);
template <>
const wchar_t* GetFormatString(short);
template <>
const wchar_t* GetFormatString(int);
template <>
const wchar_t* GetFormatString(long);
template <>
const wchar_t* GetFormatString(unsigned short);
template <>
const wchar_t* GetFormatString(unsigned int);
template <>
const wchar_t* GetFormatString(unsigned long);
template <>
const wchar_t* GetFormatString(float);
template <>
const wchar_t* GetFormatString(double);
template <>
const wchar_t* GetFormatString(unsigned long long);
template <>
const wchar_t* GetFormatString(long long);
template <>
const wchar_t* GetFormatString(const char*);
template <>
const wchar_t* GetFormatString(const wchar_t*);
// GetScanFormatString - get the format std::string for a particular type
template <typename T>
const wchar_t* GetScanFormatString(T)
{
assert(false); // need a specialization
return NULL;
}
// GetScanFormatString - specializations to get the format std::string for a particular type
template <>
const wchar_t* GetScanFormatString(char);
template <>
const wchar_t* GetScanFormatString(wchar_t);
template <>
const wchar_t* GetScanFormatString(short);
template <>
const wchar_t* GetScanFormatString(int);
template <>
const wchar_t* GetScanFormatString(long);
template <>
const wchar_t* GetScanFormatString(unsigned short);
template <>
const wchar_t* GetScanFormatString(unsigned int);
template <>
const wchar_t* GetScanFormatString(unsigned long);
template <>
const wchar_t* GetScanFormatString(float);
template <>
const wchar_t* GetScanFormatString(double);
template <>
const wchar_t* GetScanFormatString(unsigned long long);
template <>
const wchar_t* GetScanFormatString(long long);
// ----------------------------------------------------------------------------
// fgetText(): get a value from a text file
// ----------------------------------------------------------------------------
template <typename T>
void fgetText(FILE* f, T& v)
{
int rc = ftrygetText(f, v);
if (rc == 0)
Microsoft::MSR::CNTK::RuntimeError("error reading value from file (invalid format)");
else if (rc == EOF)
Microsoft::MSR::CNTK::RuntimeError("error reading from file: %s", strerror(errno));
assert(rc == 1);
}
// version to try and get a std::string, and not throw exceptions if contents don't match
template <typename T>
int ftrygetText(FILE* f, T& v)
{
const wchar_t* formatString = GetScanFormatString<T>(v);
int rc = fwscanf(f, formatString, &v);
assert(rc == 1 || rc == 0);
return rc;
}
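// Sketch of the non-throwing read path (the FILE* 'f' and value type are illustrative):
// ftrygetText() returns 1 on success, 0 on a format mismatch, EOF at end of file.
//
//   int v = 0;
//   int rc = ftrygetText(f, v);
//   if (rc != 1) { /* v is untouched; handle mismatch or EOF */ }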
template <>
int ftrygetText<bool>(FILE* f, bool& v);
// ----------------------------------------------------------------------------
// fgetText() specializations for fwscanf_s differences: get a value from a text file
// ----------------------------------------------------------------------------
void fgetText(FILE* f, char& v);
void fgetText(FILE* f, wchar_t& v);
// ----------------------------------------------------------------------------
// fputText(): write a value out as text
// ----------------------------------------------------------------------------
template <typename T>
void fputText(FILE* f, T v)
{
const wchar_t* formatString = GetFormatString(v);
int rc = fwprintf(f, formatString, v);
if (rc == 0)
Microsoft::MSR::CNTK::RuntimeError("error writing value to file, no values written");
else if (rc < 0)
Microsoft::MSR::CNTK::RuntimeError("error writing to file: %s", strerror(errno));
}
// ----------------------------------------------------------------------------
// fputText(): write a bool out as character
// ----------------------------------------------------------------------------
template <>
void fputText<bool>(FILE* f, bool v);
// ----------------------------------------------------------------------------
// fputfile(): write a binary block or a std::string as a file
// ----------------------------------------------------------------------------
void fputfile(const std::wstring& pathname, const std::vector<char>& buffer);
void fputfile(const std::wstring& pathname, const std::wstring&);
void fputfile(const std::wstring& pathname, const std::string&);
// ----------------------------------------------------------------------------
// fgetfile(): load a file as a binary block
// ----------------------------------------------------------------------------
void fgetfile(const std::wstring& pathname, std::vector<char>& buffer);
void fgetfile(FILE* f, std::vector<char>& buffer);
namespace msra { namespace files {
void fgetfilelines(const std::wstring& pathname, std::vector<char>& readbuffer, std::vector<std::string>& lines, int numberOfTries = 1);
static inline std::vector<std::string> fgetfilelines(const std::wstring& pathname)
{
std::vector<char> buffer;
std::vector<std::string> lines;
fgetfilelines(pathname, buffer, lines);
return lines;
}
std::vector<char*> fgetfilelines(const std::wstring& pathname, std::vector<char>& readbuffer, int numberOfTries = 1);
}}
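// Illustrative: slurp a whole list file into memory in one call (path is an example).
//
//   std::vector<std::string> lines = msra::files::fgetfilelines(L"train.scp");
//   for (const auto& line : lines) { /* ... */ }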
#ifdef _WIN32
// ----------------------------------------------------------------------------
// getfiletime(), setfiletime(): access modification time
// ----------------------------------------------------------------------------
bool getfiletime(const std::wstring& path, FILETIME& time);
void setfiletime(const std::wstring& path, const FILETIME& time);
#endif
// ----------------------------------------------------------------------------
// expand_wildcards() -- expand a path with wildcards (also intermediate ones)
// ----------------------------------------------------------------------------
void expand_wildcards(const std::wstring& path, std::vector<std::wstring>& paths);
// ----------------------------------------------------------------------------
// make_intermediate_dirs() -- make all intermediate dirs on a path
// ----------------------------------------------------------------------------
namespace msra { namespace files {
void make_intermediate_dirs(const std::wstring& filepath);
std::vector<std::wstring> get_all_files_from_directory(const std::wstring& directory);
}}
// ----------------------------------------------------------------------------
// fuptodate() -- test whether an output file is at least as new as an input file
// ----------------------------------------------------------------------------
namespace msra { namespace files {
bool fuptodate(const std::wstring& target, const std::wstring& input, bool inputrequired = true);
};
};
#ifdef _WIN32
// ----------------------------------------------------------------------------
// simple support for WAV file I/O
// ----------------------------------------------------------------------------
typedef struct wavehder
{
char riffchar[4];
unsigned int RiffLength;
char wavechar[8];
unsigned int FmtLength;
signed short wFormatTag;
signed short nChannels;
unsigned int nSamplesPerSec;
unsigned int nAvgBytesPerSec;
signed short nBlockAlign;
signed short wBitsPerSample;
char datachar[4];
unsigned int DataLength;
private:
void prepareRest(int SampleCount);
public:
void prepare(unsigned int Fs, int Bits, int Channels, int SampleCount);
void prepare(const WAVEFORMATEX& wfx, int SampleCount);
unsigned int read(FILE* f, signed short& wRealFormatTag, int& bytesPerSample);
void write(FILE* f);
static void update(FILE* f);
} WAVEHEADER;
// ----------------------------------------------------------------------------
// fgetwfx(), fputwfx(): I/O of wave file headers only
// ----------------------------------------------------------------------------
unsigned int fgetwfx(FILE* f, WAVEFORMATEX& wfx);
void fputwfx(FILE* f, const WAVEFORMATEX& wfx, unsigned int numSamples);
// ----------------------------------------------------------------------------
// fgetraw(): read data of .wav file, and separate data of multiple channels.
// For example, data[i][j]: i is channel index, 0 means the first
// channel. j is sample index.
// ----------------------------------------------------------------------------
void fgetraw(FILE* f, std::vector<std::vector<short>>& data, const WAVEHEADER& wavhd);
#endif
// ----------------------------------------------------------------------------
// auto_file_ptr -- FILE* with auto-close; use auto_file_ptr instead of FILE*.
// Warning: do not pass an auto_file_ptr to a function that calls fclose(),
// except for fclose() itself.
// ----------------------------------------------------------------------------
class auto_file_ptr
{
FILE* f;
FILE* operator=(auto_file_ptr&); // can't ref-count: no assignment
auto_file_ptr(auto_file_ptr&);
void close()
{
if (f && f != stdin && f != stdout && f != stderr)
{
int rc = ::fclose(f);
if ((rc != FCLOSE_SUCCESS) && !std::uncaught_exception())
RuntimeError("auto_file_ptr: failed to close file: %s", strerror(errno));
f = NULL;
}
}
#pragma warning(push)
#pragma warning(disable : 4996)
void openfailed(const std::string& path)
{
Microsoft::MSR::CNTK::RuntimeError("auto_file_ptr: error opening file '%s': %s", path.c_str(), strerror(errno));
}
#pragma warning(pop)
protected:
friend int fclose(auto_file_ptr&); // explicit close (note: may fail)
int fclose()
{
int rc = ::fclose(f);
if (rc == 0)
f = NULL;
return rc;
}
public:
auto_file_ptr()
: f(NULL)
{
}
~auto_file_ptr()
{
close();
}
#pragma warning(push)
#pragma warning(disable : 4996)
auto_file_ptr(const char* path, const char* mode)
{
f = fopen(path, mode);
if (f == NULL)
openfailed(path);
}
auto_file_ptr(const wchar_t* wpath, const char* mode)
{
f = _wfopen(wpath, msra::strfun::utf16(mode).c_str());
if (f == NULL)
openfailed(msra::strfun::utf8(wpath));
}
#pragma warning(pop)
FILE* operator=(FILE* other)
{
close();
f = other;
return f;
}
auto_file_ptr(FILE* other)
: f(other)
{
}
operator FILE*() const
{
return f;
}
FILE* operator->() const
{
return f;
}
void swap(auto_file_ptr& other) throw()
{
std::swap(f, other.f);
}
};
inline int fclose(auto_file_ptr& af)
{
return af.fclose();
}
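// Usage sketch (file name illustrative): the handle closes itself when it leaves scope,
// so early returns and exceptions are safe; do not hand it to ::fclose() yourself.
//
//   {
//       auto_file_ptr f("model.bin", "rb");
//       int tag = fgetint(f); // auto_file_ptr converts to FILE*
//   } // closed here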
namespace msra { namespace files {
// ----------------------------------------------------------------------------
// textreader -- simple reader for text files --we need this all the time!
// Currently reads 8-bit files, but can return as wstring, in which case
// they are interpreted as UTF-8 (without BOM).
// Note: Not suitable for pipes or typed input due to readahead (fixable if needed).
// ----------------------------------------------------------------------------
class textreader
{
auto_file_ptr f;
std::vector<char> buf; // read buffer (will only grow, never shrink)
int ch; // next character (we need to read ahead by one...)
char getch()
{
char prevch = (char) ch;
ch = fgetc(f);
return prevch;
}
public:
textreader(const std::wstring& path)
: f(path.c_str(), "rb")
{
buf.reserve(10000);
ch = fgetc(f);
}
operator bool() const
{
return ch != EOF;
} // true if still a line to read
std::string getline() // get and consume the next line
{
if (ch == EOF)
LogicError("textreader: attempted to read beyond EOF");
assert(buf.empty());
// get all line's characters --we recognize UNIX (LF), DOS (CRLF), and Mac (CR) convention
while (ch != EOF && ch != '\n' && ch != '\r')
buf.push_back(getch());
if (ch != EOF && getch() == '\r' && ch == '\n')
getch(); // consume EOLN char
std::string line(buf.begin(), buf.end());
buf.clear();
return line;
}
std::wstring wgetline()
{
return msra::strfun::utf16(getline());
}
};
}
}
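// Usage sketch (path illustrative): iterate a UTF-8 text file line by line.
//
//   msra::files::textreader reader(L"transcript.txt");
//   while (reader)
//   {
//       std::wstring line = reader.wgetline();
//       // ... process line ...
//   }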
// ----------------------------------------------------------------------------
// temp functions -- clean these up
// ----------------------------------------------------------------------------
// split a pathname into directory and filename
static inline void splitpath(const std::wstring& path, std::wstring& dir, std::wstring& file)
{
size_t pos = path.find_last_of(L"\\:/"); // DOS drives, UNIX, Windows
if (pos == path.npos) // no directory found
{
dir.clear();
file = path;
}
else
{
dir = path.substr(0, pos);
file = path.substr(pos + 1);
}
}
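// Illustrative: splitpath(L"models/am/final.mdl", dir, file) leaves dir == L"models/am"
// and file == L"final.mdl"; if no separator is found, dir is empty and file is the whole path.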
// test if a pathname is a relative path
// A relative path is one that can be appended to a directory.
// Drive-relative paths, such as D:file, are considered non-relative.
static inline bool relpath(const wchar_t* path)
{ // this is a wild collection of pathname conventions in Windows
if (path[0] == '/' || path[0] == '\\') // e.g. \WINDOWS
return false;
if (path[0] && path[1] == ':') // drive syntax
return false;
// ... TODO: handle long NT paths
return true; // all others
}
template <class Char>
static inline bool relpath(const std::basic_string<Char>& s)
{
return relpath(s.c_str());
}
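// Illustrative: relpath(L"data\\utt1.wav") == true, relpath(L"\\WINDOWS") == false,
// relpath(L"D:file") == false (drive-relative paths count as non-relative).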
// trim from start
template<class String>
static inline String& ltrim(String& s)
{
s.erase(s.begin(), std::find_if(s.begin(), s.end(), [](typename String::value_type c){ return !iscspace(c); }));
return s;
}
// trim from end
template<class String>
static inline String& rtrim(String& s)
{
s.erase(std::find_if(s.rbegin(), s.rend(), [](typename String::value_type c){ return !iscspace(c); }).base(), s.end());
return s;
}
// trim from both ends
template<class String>
static inline String& trim(String& s)
{
return ltrim(rtrim(s));
}
template<class String>
std::vector<String> SplitString(const String& str, const String& sep);
template<class String, class Char>
std::vector<String> SplitString(const String& str, const Char* sep) { return SplitString(str, String(sep)); }
std::wstring s2ws(const std::string& str);
std::string ws2s(const std::wstring& wstr);
/* guoye: start */
#include "../fileutil.cpp"
/* guoye: end */
#endif // _FILEUTIL_

View file

@ -23,7 +23,9 @@
#include <algorithm> // for find()
#include "simplesenonehmm.h"
#include "Matrix.h"
/* guoye: start */
#include <set>
/* guoye: end */
namespace msra { namespace math {
class ssematrixbase;
@ -67,7 +69,33 @@ enum mbrclassdefinition // used to identify definition of class in minimum bayes
// ===========================================================================
class lattice
{
public:
mutable int verbosity;
/* guoye: start */
// define structures for N-best EMBR
struct TokenInfo
{
double score; // the score of the token
size_t prev_edge_index; // edge ending with this token, edge start points to the previous node
size_t prev_token_index; // the token index in the previous node
};
struct PrevTokenInfo
{
size_t prev_edge_index;
size_t prev_token_index;
double path_score; // pure path score: it does not take the WER of the path into account
};
struct NBestToken
{
// for sorting purpose
// make sure the map is stored with keys in descending order
std::map<double, std::vector<PrevTokenInfo>, std::greater <double>> mp_score_token_infos; // for sorting the tokens in map
std::vector<TokenInfo> vt_nbest_tokens; // stores the nbest tokens in the node
};
/* guoye: end */
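// Hedged sketch of how these structures fit together during N-best token passing
// (names and indices below are illustrative, not part of the commit):
//
//   NBestToken tok;
//   PrevTokenInfo prev = { /*prev_edge_index*/ e, /*prev_token_index*/ t, /*path_score*/ s };
//   tok.mp_score_token_infos[s].push_back(prev); // keyed by score, iterated in descending order
//   // so walking mp_score_token_infos visits the best-scoring predecessors first,
//   // from which the node's kept vt_nbest_tokens entries are built.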
struct header_v1_v2
{
size_t numnodes : 32;
@ -90,12 +118,22 @@ private:
static const unsigned int NOEDGE = 0xffffff; // 24 bits
// static_assert (sizeof (nodeinfo) == 8, "unexpected size of nodeeinfo"); // note: int64_t required to allow going across 32-bit boundary
// ensure type size as these are expected to be of this size in the files we read
static_assert(sizeof(nodeinfo) == 2, "unexpected size of nodeinfo"); // note: int64_t required to allow going across 32-bit boundary
/* guoye: start */
static_assert(sizeof(nodeinfo) == 16, "unexpected size of nodeinfo"); // note: int64_t required to allow going across 32-bit boundary
/* guoye: end */
static_assert(sizeof(edgeinfowithscores) == 16, "unexpected size of edgeinfowithscores");
static_assert(sizeof(aligninfo) == 4, "unexpected size of aligninfo");
std::vector<nodeinfo> nodes;
/* guoye: start */
mutable std::vector<std::vector<uint64_t>> vt_node_out_edge_indices; // vt_node_out_edge_indices[i]: it stores the outgoing edge indices starting from node i
std::vector<bool> is_special_words; // true if the node is a special word that does not count toward WER computation, false otherwise
/* guoye: end */
std::vector<edgeinfowithscores> edges;
std::vector<aligninfo> align;
// V2 lattices --for a while, we will store both in RAM, until all code is updated
static int fsgn(float f)
{
@ -217,6 +255,12 @@ private:
public: // TODO: make private again once
// construct from edges/align
// This is also used for merging, where the edges[] array is not correctly sorted. So don't assume this here.
/* guoye: start */
void erase_node_out_edges(size_t nodeidx, size_t edgeidx_start, size_t edgeidx_end) const
{
vt_node_out_edge_indices[nodeidx].erase(vt_node_out_edge_indices[nodeidx].begin() + edgeidx_start, vt_node_out_edge_indices[nodeidx].begin() + edgeidx_end);
}
/* guoye: end */
void builduniquealignments(size_t spunit = SIZE_MAX /*fix this later*/)
{
// infer /sp/ unit if not given
@ -701,6 +745,7 @@ private:
const float lmf, const float wp, const float amf, const_array_ref<size_t>& uids,
const edgealignments& thisedgealignments, std::vector<double>& Eframescorrect) const;
void sMBRerrorsignal(parallelstate& parallelstate,
msra::math::ssematrixbase& errorsignal, msra::math::ssematrixbase& errorsignalneg,
const std::vector<double>& logpps, const float amf, double minlogpp,
@ -736,7 +781,8 @@ private:
const std::vector<double>& logpps, const float amf,
const std::vector<double>& logEframescorrect, const double logEframescorrecttotal,
msra::math::ssematrixbase& errorsignal, msra::math::ssematrixbase& errorsignalneg) const;
void parallelEMBRerrorsignal(parallelstate& parallelstate, const edgealignments& thisedgealignments,
const std::vector<double>& edgeweights, msra::math::ssematrixbase& errorsignal) const;
void parallelmmierrorsignal(parallelstate& parallelstate, const edgealignments& thisedgealignments,
const std::vector<double>& logpps, msra::math::ssematrixbase& errorsignal) const;
@ -747,6 +793,20 @@ private:
const_array_ref<size_t>& uids, std::vector<double>& logEframescorrect,
std::vector<double>& Eframescorrectbuf, double& logEframescorrecttotal) const;
/* guoye: start */
double parallelbackwardlatticeEMBR(parallelstate& parallelstate, const std::vector<float>& edgeacscores,
const float lmf, const float wp,
const float amf, std::vector<double>& edgelogbetas,
std::vector<double>& logbetas) const;
void EMBRsamplepaths(const std::vector<double> &edgelogbetas,
const std::vector<double> &logbetas, const size_t numPathsEMBR, const bool enforceValidPathEMBR, const bool excludeSpecialWords, std::vector< std::vector<size_t> > & vt_paths) const;
void EMBRnbestpaths(std::vector<NBestToken>& tokenlattice, std::vector<std::vector<size_t>> & vt_paths, std::vector<double>& path_posterior_probs) const;
double get_edge_weights(std::vector<size_t>& wids, std::vector<std::vector<size_t>>& vt_paths, std::vector<double>& vt_edge_weights, std::vector<double>& vt_path_posterior_probs, std::string getPathMethodEMBR, double& onebestwer) const;
/* guoye: end */
static double scoregroundtruth(const_array_ref<size_t> uids, const_array_ref<htkmlfwordsequence::word> transcript,
const std::vector<float>& transcriptunigrams, const msra::math::ssematrixbase& logLLs,
const msra::asr::simplesenonehmm& hset, const float lmf, const float wp, const float amf);
@ -762,6 +822,16 @@ private:
std::vector<double>& logEframescorrect, std::vector<double>& Eframescorrectbuf,
double& logEframescorrecttotal) const;
/* guoye: start */
double backwardlatticeEMBR(const std::vector<float>& edgeacscores, parallelstate& parallelstate, std::vector<double> &edgelogbetas,
std::vector<double>& logbetas,
const float lmf, const float wp, const float amf) const;
void constructnodenbestoken(std::vector<NBestToken> &tokenlattice, const bool wordNbest, size_t numtokens2keep, size_t nidx) const;
double nbestlatticeEMBR(const std::vector<float> &edgeacscores, parallelstate &parallelstate, std::vector<NBestToken> &vt_nbesttokens, const size_t numtokens, const bool enforceValidPathEMBR, const bool excludeSpecialWords,
const float lmf, const float wp, const float amf, const bool wordNbest, const bool useAccInNbest, const float accWeightInNbest, const size_t numPathsEMBR, std::vector<size_t> wids) const;
/* guoye: end */
public:
// construct from a HTK lattice file
void fromhtklattice(const std::wstring& path, const std::unordered_map<std::string, size_t>& unitmap);
@ -770,6 +840,156 @@ public:
void frommlf(const std::wstring& key, const std::unordered_map<std::string, size_t>& unitmap, const msra::asr::htkmlfreader<msra::asr::htkmlfentry, lattice::htkmlfwordsequence>& labels,
const msra::lm::CMGramLM& lm, const msra::lm::CSymbolSet& unigramsymbols);
/* guoye: start */
// template <class IDMAP>
// void fread(FILE* f, const IDMAP& idmap, size_t spunit, std::set<int>& specialwordids);
// moved from latticearchive.cpp to the header: template definition and declaration must both be in the .h file
template <class IDMAP>
void fread(FILE* f, const IDMAP& idmap, size_t spunit, std::set<int>& specialwordids)
/* guoye: end */
{
size_t version = freadtag(f, "LAT ");
if (version == 1)
{
freadOrDie(&info, sizeof(info), 1, f);
freadvector(f, "NODE", nodes, info.numnodes);
if (nodes.back().t != info.numframes)
/* guoye: start */
{
// RuntimeError("fread: mismatch between info.numframes and last node's time");
// sometimes the data is corrupted; let's try to live with it
fprintf(stderr, "fread: mismatch between info.numframes and last node's time: nodes.back().t = %d vs. info.numframes = %d \n", int(nodes.back().t), int(info.numframes));
}
/* guoye: end */
freadvector(f, "EDGE", edges, info.numedges);
freadvector(f, "ALIG", align);
fcheckTag(f, "END ");
// map align ids to user's symmap --the lattice gets updated in place here
foreach_index(k, align)
align[k].updateunit(idmap); // updates itself
}
else if (version == 2)
{
freadOrDie(&info, sizeof(info), 1, f);
freadvector(f, "NODS", nodes, info.numnodes);
if (nodes.back().t != info.numframes)
{
/* guoye: start */
{
// RuntimeError("fread: mismatch between info.numframes and last node's time");
// sometimes the data is corrupted; let's try to live with it
fprintf(stderr, "fread: mismatch between info.numframes and last node's time: nodes.back().t = %d vs. info.numframes = %d \n", int(nodes.back().t), int(info.numframes));
}
/* guoye: end */
}
freadvector(f, "EDGS", edges2, info.numedges); // uniqued edges
freadvector(f, "ALNS", uniquededgedatatokens); // uniqued alignments
fcheckTag(f, "END ");
/* guoye: start */
vt_node_out_edge_indices.resize(info.numnodes);
for (size_t j = 0; j < info.numedges; j++)
{
// an edge starting from !NULL that does not point to <s>
// this code makes sure you always start from <s> in the sampled path.
// masked here: we delay the processing to EMBRsamplepaths, controlled by the flag enforceValidPathEMBR
// if (edges2[j].S == 0 && nodes[edges2[j].E].wid != 1) continue;
vt_node_out_edge_indices[edges2[j].S].push_back(j);
}
is_special_words.resize(info.numnodes);
for (size_t i = 0; i < info.numnodes; i++)
{
/*
if (nodes[i].wid == 0xfffff)
{
nodes[i].wid;
}
*/
if (specialwordids.find(int(nodes[i].wid)) != specialwordids.end()) is_special_words[i] = true;
else is_special_words[i] = false;
}
/* guoye: end */
// check if we need to map
#if 1 // post-bugfix for incorrect inference of spunit
if (info.impliedspunitid != SIZE_MAX && info.impliedspunitid >= idmap.size()) // we have buggy lattices like that--what do they mean??
{
fprintf(stderr, "fread: detected buggy spunit id %d which is out of range (%d entries in map)\n", (int)info.impliedspunitid, (int)idmap.size());
RuntimeError("fread: out of bounds spunitid");
}
#endif
// This is critical--we have a buggy lattice set that requires no mapping where mapping would fail
bool needsmapping = false;
foreach_index(k, idmap)
{
if (idmap[k] != (size_t)k
#if 1
&& (k != (int)idmap.size() - 1 || idmap[k] != spunit) // that HACK that we add one more /sp/ entry at the end...
#endif
)
{
needsmapping = true;
break;
}
}
// map align ids to user's symmap --the lattice gets updated in place here
if (needsmapping)
{
if (info.impliedspunitid != SIZE_MAX)
info.impliedspunitid = idmap[info.impliedspunitid];
// deal with broken (zero-token) edges
std::vector<bool> isendworkaround;
if (info.impliedspunitid != spunit)
{
fprintf(stderr, "fread: lattice with broken spunit, using workaround to handle potentially broken zero-token edges\n");
inferends(isendworkaround);
}
size_t uniquealignments = 1;
const size_t skipscoretokens = info.hasacscores ? 2 : 1;
for (size_t k = skipscoretokens; k < uniquededgedatatokens.size(); k++)
{
if (!isendworkaround.empty() && isendworkaround[k]) // secondary criterion to detect ends in broken lattices
{
k--; // don't advance, since nothing to advance over
}
else
{
// this is a regular token: update it in-place
auto& ai = uniquededgedatatokens[k];
if (ai.unit >= idmap.size())
RuntimeError("fread: broken-file heuristics failed");
ai.updateunit(idmap); // updates itself
if (!ai.last)
continue;
}
// if last then skip over the lm and ac scores
k += skipscoretokens;
uniquealignments++;
}
fprintf(stderr, "fread: mapped %d unique alignments\n", (int)uniquealignments);
}
if (info.impliedspunitid != spunit)
{
// fprintf (stderr, "fread: inconsistent spunit id in file %d vs. expected %d; due to erroneous heuristic\n", info.impliedspunitid, spunit); // [v-hansu] comment out becaues it takes up most of the log
// it's actually OK, we can live with this, since we only decompress and then move on without any assumptions
// RuntimeError("fread: mismatching /sp/ units");
}
// reconstruct old lattice format from this --TODO: remove once we change to new data representation
rebuildedges(info.impliedspunitid != spunit /*to be able to read somewhat broken V2 lattice archives*/);
}
else
RuntimeError("fread: unsupported lattice format version");
}
/* guoye: end */
// check consistency
// - only one end node
// - only forward edges
@ -995,6 +1215,7 @@ public:
}
}
/* guoye: start */
// read from a stream
// This can be used on an existing structure and will replace its content. May be useful to avoid memory allocations (resize() will not shrink memory).
// For efficiency, we will not check the inner consistency of the file here, but rather when we further process it.
@ -1003,7 +1224,7 @@ public:
// This will also map the aligninfo entries to the new symbol table, through idmap.
// V1 lattices will be converted. 'spsenoneid' is used in that process.
template <class IDMAP>
void fread(FILE* f, const IDMAP& idmap, size_t spunit)
void fread(FILE* f, const IDMAP& idmap, size_t spunit, std::set<int>& specialwordids)
{
size_t version = freadtag(f, "LAT ");
if (version == 1)
@ -1011,7 +1232,11 @@ public:
freadOrDie(&info, sizeof(info), 1, f);
freadvector(f, "NODE", nodes, info.numnodes);
if (nodes.back().t != info.numframes)
RuntimeError("fread: mismatch between info.numframes and last node's time");
{
// RuntimeError("fread: mismatch between info.numframes and last node's time");
// sometimes the data is corrupted; try to live with it
fprintf(stderr, "fread: mismatch between info.numframes and last node's time: nodes.back().t = %d vs. info.numframes = %d \n", int(nodes.back().t), int(info.numframes));
}
freadvector(f, "EDGE", edges, info.numedges);
freadvector(f, "ALIG", align);
fcheckTag(f, "END ");
@ -1024,11 +1249,15 @@ public:
freadOrDie(&info, sizeof(info), 1, f);
freadvector(f, "NODS", nodes, info.numnodes);
if (nodes.back().t != info.numframes)
RuntimeError("fread: mismatch between info.numframes and last node's time");
{
// RuntimeError("fread: mismatch between info.numframes and last node's time");
// sometimes the data is corrupted; try to live with it
fprintf(stderr, "fread: mismatch between info.numframes and last node's time: nodes.back().t = %d vs. info.numframes = %d \n", int(nodes.back().t), int(info.numframes));
}
freadvector(f, "EDGS", edges2, info.numedges); // uniqued edges
freadvector(f, "ALNS", uniquededgedatatokens); // uniqued alignments
fcheckTag(f, "END ");
ProcessV2Lattice(spunit, info, uniquededgedatatokens, idmap);
ProcessV2Lattice(spunit, info, uniquededgedatatokens, idmap, specialwordids);
}
else
RuntimeError("fread: unsupported lattice format version");
@ -1055,8 +1284,35 @@ public:
// Helper method to process v2 Lattice format
template <class IDMAP>
void ProcessV2Lattice(size_t spunit, header_v1_v2& info, std::vector<aligninfo>& uniquededgedatatokens, const IDMAP& idmap)
void ProcessV2Lattice(size_t spunit, header_v1_v2& info, std::vector<aligninfo>& uniquededgedatatokens, const IDMAP& idmap, const std::set<int>& specialwordids = {})
{
/* guoye: start */
vt_node_out_edge_indices.resize(info.numnodes);
for (size_t j = 0; j < info.numedges; j++)
{
// an edge from the !NULL start node that does not point to <s>
// this check makes sure a sampled path always starts from <s>
// disabled here: the processing is deferred to EMBRsamplepaths, controlled by the enforceValidPathEMBR flag
// if (edges2[j].S == 0 && nodes[edges2[j].E].wid != 1) continue;
vt_node_out_edge_indices[edges2[j].S].push_back(j);
}
is_special_words.resize(info.numnodes);
for (size_t i = 0; i < info.numnodes; i++)
{
/*
if (nodes[i].wid == 0xfffff)
{
nodes[i].wid;
}
*/
if (specialwordids.find(int(nodes[i].wid)) != specialwordids.end()) is_special_words[i] = true;
else is_special_words[i] = false;
}
/* guoye: end */
// check if we need to map
if (info.impliedspunitid != SIZE_MAX && info.impliedspunitid >= idmap.size()) // we have buggy lattices like that--what do they mean??
{
@ -1124,7 +1380,9 @@ public:
rebuildedges(info.impliedspunitid != spunit /*to be able to read somewhat broken V2 lattice archives*/);
}
/* guoye: end */
// parallel versions (defined in parallelforwardbackward.cpp)
class parallelstate
{
@ -1152,6 +1410,14 @@ public:
const size_t getsilunitid();
void getedgeacscores(std::vector<float>& edgeacscores);
void getedgealignments(std::vector<unsigned short>& edgealignments);
/* guoye: start */
void getlogbetas(std::vector<double>& logbetas);
void getedgelogbetas(std::vector<double>& edgelogbetas);
void getedgeweights(std::vector<double>& edgeweights);
void setedgeweights(const std::vector<double>& edgeweights);
/* guoye: end */
// to work with CNTK's GPU memory
void setdevice(size_t DeviceId);
size_t getdevice();
@ -1166,11 +1432,30 @@ public:
// forward-backward function
// Note: logLLs and posteriors may be the same matrix (aliased).
/* start: guoye */
/*
double forwardbackward(parallelstate& parallelstate, const class msra::math::ssematrixbase& logLLs, const class msra::asr::simplesenonehmm& hmms,
class msra::math::ssematrixbase& result, class msra::math::ssematrixbase& errorsignalbuf,
const float lmf, const float wp, const float amf, const float boostingfactor, const bool sMBRmode, array_ref<size_t> uids, const_array_ref<size_t> bounds = const_array_ref<size_t>(),
const_array_ref<htkmlfwordsequence::word> transcript = const_array_ref<htkmlfwordsequence::word>(), const std::vector<float>& transcriptunigrams = std::vector<float>()) const;
*/
double forwardbackward(parallelstate& parallelstate, const class msra::math::ssematrixbase& logLLs, const class msra::asr::simplesenonehmm& hmms,
class msra::math::ssematrixbase& result, class msra::math::ssematrixbase& errorsignalbuf,
const float lmf, const float wp, const float amf, const float boostingfactor, const bool sMBRmode, const bool EMBR, const std::string EMBRUnit, const size_t numPathsEMBR, const bool enforceValidPathEMBR, const std::string getPathMethodEMBR, const std::string showWERMode,
const bool excludeSpecialWords, const bool wordNbest, const bool useAccInNbest, const float accWeightInNbest, const size_t numRawPathsEMBR,
array_ref<size_t> uids, std::vector<size_t> wids, const_array_ref<size_t> bounds = const_array_ref<size_t>(),
const_array_ref<htkmlfwordsequence::word> transcript = const_array_ref<htkmlfwordsequence::word>(), const std::vector<float>& transcriptunigrams = std::vector<float>()) const;
/*
void embrerrorsignal(parallelstate &parallelstate,
std::vector<msra::math::ssematrixbase *> &abcs, const bool softalignstates, const msra::asr::simplesenonehmm &hset,
const edgealignments &thisedgealignments, std::vector<std::vector<size_t>>& vt_paths, std::vector<float>& path_weight, msra::math::ssematrixbase &errorsignal) const;
*/
void EMBRerrorsignal(parallelstate &parallelstate,
const edgealignments &thisedgealignments, std::vector<double>& edge_weights, msra::math::ssematrixbase &errorsignal) const;
/* end: guoye */
std::wstring key; // (keep our own name (key) so we can identify ourselves for diagnostics messages)
const wchar_t* getkey() const
{
@ -1358,8 +1643,14 @@ public:
if (sscanf(q, "[%" PRIu64 "]%c", &offset, &c) != 1)
#endif
RuntimeError("open: invalid TOC line (bad [] expression): %s", line);
if (!toc.insert(make_pair(key, latticeref(offset, archiveindex))).second)
RuntimeError("open: TOC entry leads to duplicate key: %s", line);
/* guoye: start */
// sometimes training hits this error; it is believed to come from minor data corruption and is safe to continue, so the error is downgraded to a warning
// RuntimeError("open: TOC entry leads to duplicate key: %s", line);
fprintf(stderr, " open: TOC entry leads to duplicate key: %s\n", line);
/* guoye: end */
}
// initialize symmaps --alloc the array, but actually read the symmap on demand
@ -1390,7 +1681,10 @@ public:
// Lattices will have unit ids updated according to the modelsymmap.
// V1 lattices will be converted. 'spsenoneid' is used in the conversion for optimizing storing 0-frame /sp/ aligns.
void getlattice(const std::wstring& key, lattice& L,
size_t expectedframes = SIZE_MAX /*if unknown*/) const
/* guoye: start */
// size_t expectedframes = SIZE_MAX /*if unknown*/) const
std::set<int>& specialwordids, size_t expectedframes = SIZE_MAX) const
/* guoye: end */
{
auto iter = toc.find(key);
if (iter == toc.end())
@ -1417,7 +1711,11 @@ public:
// seek to start
fsetpos(f, offset);
// get it
L.fread(f, idmap, spunit);
/* guoye: start */
// L.fread(f, idmap, spunit);
L.fread(f, idmap, spunit, specialwordids);
/* guoye: end */
L.setverbosity(verbosity);
#ifdef HACK_IN_SILENCE // hack to simulate DEL in the lattice
const size_t silunit = getid(modelsymmap, "sil");
@ -1451,7 +1749,11 @@ public:
// - dump to stdout
// - merge two lattices (for merging numer into denom lattices)
static void convert(const std::wstring& intocpath, const std::wstring& intocpath2, const std::wstring& outpath,
const msra::asr::simplesenonehmm& hset);
/* guoye: start */
// const msra::asr::simplesenonehmm& hset);
const msra::asr::simplesenonehmm& hset, std::set<int>& specialwordids);
/* guoye: end */
};
};
};
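The fread()/ProcessV2Lattice() changes above build, for every lattice node, the list of indices of its outgoing edges; the EMBR path-sampling code can then walk those lists. A minimal self-contained sketch of that structure, using a simplified stand-in for the real edge type:

#include <cstddef>
#include <vector>

struct SimpleEdge { std::size_t S; std::size_t E; }; // stand-in for the archive's edge type (start/end node index)

// mirrors vt_node_out_edge_indices: outEdges[n] holds the indices of all edges leaving node n
std::vector<std::vector<std::size_t>> buildOutEdgeIndices(std::size_t numNodes,
                                                          const std::vector<SimpleEdge>& edges)
{
    std::vector<std::vector<std::size_t>> outEdges(numNodes);
    for (std::size_t j = 0; j < edges.size(); ++j)
        outEdges[edges[j].S].push_back(j); // edge j starts at node edges[j].S
    return outEdges;
}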

View File

@ -62,10 +62,17 @@ public:
#endif
}
void getlattices(const std::wstring& key, std::shared_ptr<const latticepair>& L, size_t expectedframes) const
/* guoye: start */
// void getlattices(const std::wstring& key, std::shared_ptr<const latticepair>& L, size_t expectedframes) const
void getlattices(const std::wstring& key, std::shared_ptr<const latticepair>& L, size_t expectedframes, std::set<int>& specialwordids) const
/* guoye: end */
{
std::shared_ptr<latticepair> LP(new latticepair);
denlattices.getlattice(key, LP->second, expectedframes); // this loads the lattice from disk, using the existing L.second object
/* guoye: start */
// denlattices.getlattice(key, LP->second, expectedframes); // this loads the lattice from disk, using the existing L.second object
denlattices.getlattice(key, LP->second, specialwordids, expectedframes); // this loads the lattice from disk, using the existing L.second object
// fprintf(stderr, "latticesource.h:getlattices: %ls \n", key.c_str());
/* guoye: end */
L = LP;
}

View File

@ -12,6 +12,9 @@
#include <stdexcept>
#include <stdint.h>
#include <cstdio>
/* guoye: start */
#include <vector>
/* guoye: end */
#undef INITIAL_STRANGE // [v-hansu] initialize structs to strange values
#define PARALLEL_SIL // [v-hansu] process sil on CUDA, used in other files, please search this
@ -30,11 +33,22 @@ struct nodeinfo
// uint64_t firstinedge : 24; // index of first incoming edge
// uint64_t firstoutedge : 24; // index of first outgoing edge
// uint64_t t : 16; // time associated with this
/* guoye: start */
uint64_t wid; // word ID associated with the node
/* guoye: end */
unsigned short t; // time associated with this
nodeinfo(size_t pt)
: t((unsigned short) pt) // , firstinedge (NOEDGE), firstoutedge (NOEDGE)
nodeinfo(size_t pt, size_t pwid)
/* guoye: start */
// : t((unsigned short) pt) // , firstinedge (NOEDGE), firstoutedge (NOEDGE)
: t((unsigned short)pt), wid(pwid)
/* guoye: end */
{
checkoverflow(t, pt, "nodeinfo::t");
/* guoye: start */
checkoverflow(wid, pwid, "nodeinfo::wid");
/* guoye: end */
// checkoverflow (firstinedge, NOEDGE, "nodeinfo::firstinedge");
// checkoverflow (firstoutedge, NOEDGE, "nodeinfo::firstoutedge");
}
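The wid field added to nodeinfo above is what the special-word marking in fread()/ProcessV2Lattice() reads. A self-contained sketch of that marking pass, with a simplified node type and placeholder word IDs (the real IDs come from the caller's symbol table):

#include <cstddef>
#include <cstdint>
#include <set>
#include <vector>

struct SimpleNode { unsigned short t; std::uint64_t wid; }; // stand-in for nodeinfo: time plus word ID

// mirrors the is_special_words pass: flag every node whose word ID is in specialwordids
std::vector<bool> markSpecialWords(const std::vector<SimpleNode>& nodes,
                                   const std::set<int>& specialwordids)
{
    std::vector<bool> isSpecial(nodes.size());
    for (std::size_t i = 0; i < nodes.size(); ++i)
        isSpecial[i] = specialwordids.find(static_cast<int>(nodes[i].wid)) != specialwordids.end();
    return isSpecial;
}

// usage sketch with placeholder IDs, e.g. for tokens such as !NULL, <s>, </s>:
// std::set<int> specialwordids = { 0, 1, 2 };
// auto flags = markSpecialWords(nodes, specialwordids);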

Diff content not shown because it is too large. Load Diff

Diff content not shown because it is too large. Load Diff

View File

@ -646,12 +646,47 @@ void ComputationNetwork::SetSeqParam(ComputationNetworkPtr net,
const double& lmf /*= 14.0f*/,
const double& wp /*= 0.0f*/,
const double& bMMIfactor /*= 0.0f*/,
const bool& sMBR /*= false*/
/* guoye: start */
// const bool& sMBR /*= false*/
const bool& sMBR /*= false */,
const bool& EMBR /*= false */,
const string& EMBRUnit /* = "word" */,
const size_t& numPathsEMBR,
const bool& enforceValidPathEMBR,
const string& getPathMethodEMBR,
const string& showWERMode,
const bool& excludeSpecialWords,
const bool& wordNbest,
const bool& useAccInNbest,
const float& accWeightInNbest,
const size_t& numRawPathsEMBR
/* guoye: end */
)
{
fprintf(stderr, "Setting Hsmoothing weight to %.8g and frame-dropping threshhold to %.8g\n", hsmoothingWeight, frameDropThresh);
/* guoye: start */
/*
fprintf(stderr, "Setting SeqGammar-related parameters: amf=%.2f, lmf=%.2f, wp=%.2f, bMMIFactor=%.2f, usesMBR=%s\n",
amf, lmf, wp, bMMIfactor, sMBR ? "true" : "false");
*/
if(EMBR)
{
fprintf(stderr, "Setting SeqGammar-related parameters: amf=%.2f, lmf=%.2f, wp=%.2f, bMMIFactor=%.2f, useEMBR=true, EMBRUnit=%s, numPathsEMBR=%d, enforceValidPathEMBR = %d, getPathMethodEMBR = %s, showWERMode = %s, excludeSpecialWords = %d, wordNbest = %d, useAccInNbest = %d, accWeightInNbest = %f, numRawPathsEMBR = %d \n",
amf, lmf, wp, bMMIfactor, EMBRUnit.c_str(), int(numPathsEMBR), int(enforceValidPathEMBR), getPathMethodEMBR.c_str(), showWERMode.c_str(), int(excludeSpecialWords), int(wordNbest), int(useAccInNbest), float(accWeightInNbest), int(numRawPathsEMBR));
}
else if(sMBR)
{
fprintf(stderr, "Setting SeqGammar-related parameters: amf=%.2f, lmf=%.2f, wp=%.2f, bMMIFactor=%.2f, usesMBR=true \n",
amf, lmf, wp, bMMIfactor);
}
else
{
fprintf(stderr, "Setting SeqGammar-related parameters: amf=%.2f, lmf=%.2f, wp=%.2f, bMMIFactor=%.2f, useMMI=true \n",
amf, lmf, wp, bMMIfactor);
}
/* guoye: end */
list<ComputationNodeBasePtr> seqNodes = net->GetNodesWithType(OperationNameOf(SequenceWithSoftmaxNode), criterionNode);
if (seqNodes.size() == 0)
{
@ -665,7 +700,11 @@ void ComputationNetwork::SetSeqParam(ComputationNetworkPtr net,
node->SetSmoothWeight(hsmoothingWeight);
node->SetFrameDropThresh(frameDropThresh);
node->SetReferenceAlign(doreferencealign);
node->SetGammarCalculationParam(amf, lmf, wp, bMMIfactor, sMBR);
/* guoye: start */
// node->SetGammarCalculationParam(amf, lmf, wp, bMMIfactor, sMBR);
node->SetMBR(sMBR || EMBR);
node->SetGammarCalculationParam(amf, lmf, wp, bMMIfactor, sMBR, EMBR, EMBRUnit, numPathsEMBR, enforceValidPathEMBR, getPathMethodEMBR, showWERMode, excludeSpecialWords, wordNbest, useAccInNbest, accWeightInNbest, numRawPathsEMBR);
/* guoye: end */
}
}
}
@ -1549,17 +1588,35 @@ template void ComputationNetwork::Read<float>(const wstring& fileName);
template void ComputationNetwork::ReadPersistableParameters<float>(size_t modelVersion, File& fstream, bool create);
template void ComputationNetwork::PerformSVDecomposition<float>(const map<wstring, float>& SVDConfig, size_t alignedsize);
template /*static*/ void ComputationNetwork::SetBatchNormalizationTimeConstants<float>(ComputationNetworkPtr net, const ComputationNodeBasePtr& criterionNode, const double normalizationTimeConstant, double& prevNormalizationTimeConstant, double blendTimeConstant, double& prevBlendTimeConstant);
/* guoye: start */
/*
template void ComputationNetwork::SetSeqParam<float>(ComputationNetworkPtr net, const ComputationNodeBasePtr criterionNode, const double& hsmoothingWeight, const double& frameDropThresh, const bool& doreferencealign,
const double& amf, const double& lmf, const double& wp, const double& bMMIfactor, const bool& sMBR);
*/
template void ComputationNetwork::SetSeqParam<float>(ComputationNetworkPtr net, const ComputationNodeBasePtr criterionNode, const double& hsmoothingWeight, const double& frameDropThresh, const bool& doreferencealign, const double& amf, const double& lmf, const double& wp, const double& bMMIfactor, const bool& sMBR, const bool& EMBR, const string& EMBRUnit, const size_t& numPathsEMBR, const bool& enforceValidPathEMBR, const string& getPathMethodEMBR, const string& showWERMode, const bool& excludeSpecialWords, const bool& wordNbest, const bool& useAccInNbest, const float& accWeightInNbest, const size_t& numRawPathsEMBR);
/* guoye: end */
template void ComputationNetwork::SaveToDbnFile<float>(ComputationNetworkPtr net, const std::wstring& fileName) const;
template void ComputationNetwork::InitLearnableParametersWithBilinearFill<double>(const ComputationNodeBasePtr& node, size_t kernelWidth, size_t kernelHeight);
template void ComputationNetwork::Read<double>(const wstring& fileName);
template void ComputationNetwork::ReadPersistableParameters<double>(size_t modelVersion, File& fstream, bool create);
template void ComputationNetwork::PerformSVDecomposition<double>(const map<wstring, float>& SVDConfig, size_t alignedsize);
template /*static*/ void ComputationNetwork::SetBatchNormalizationTimeConstants<double>(ComputationNetworkPtr net, const ComputationNodeBasePtr& criterionNode, const double normalizationTimeConstant, double& prevNormalizationTimeConstant, double blendTimeConstant, double& prevBlendTimeConstant);
/* guoye: start */
/*
template void ComputationNetwork::SetSeqParam<double>(ComputationNetworkPtr net, const ComputationNodeBasePtr criterionNode, const double& hsmoothingWeight, const double& frameDropThresh, const bool& doreferencealign,
const double& amf, const double& lmf, const double& wp, const double& bMMIfactor, const bool& sMBR);
*/
template void ComputationNetwork::SetSeqParam<double>(ComputationNetworkPtr net, const ComputationNodeBasePtr criterionNode, const double& hsmoothingWeight, const double& frameDropThresh, const bool& doreferencealign,
const double& amf, const double& lmf, const double& wp, const double& bMMIfactor, const bool& sMBR, const bool& EMBR, const string& EMBRUnit, const size_t& numPathsEMBR,
const bool& enforceValidPathEMBR, const string& getPathMethodEMBR, const string& showWERMode, const bool& excludeSpecialWords, const bool& wordNbest, const bool& useAccInNbest, const float& accWeightInNbest, const size_t& numRawPathsEMBR);
/* guoye: end */
template void ComputationNetwork::SaveToDbnFile<double>(ComputationNetworkPtr net, const std::wstring& fileName) const;
template void ComputationNetwork::InitLearnableParametersWithBilinearFill<half>(const ComputationNodeBasePtr& node, size_t kernelWidth, size_t kernelHeight);

View File

@ -554,7 +554,22 @@ public:
const double& lmf = 14.0f,
const double& wp = 0.0f,
const double& bMMIfactor = 0.0f,
const bool& sMBR = false);
/* guoye: start */
// const bool& sMBR = false);
const bool& sMBR = false,
const bool& EMBR = false,
const string& EMBRUnit = "word",
const size_t& numPathsEMBR = 100,
const bool& enforceValidPathEMBR = false,
const string& getPathMethodEMBR = "sampling",
const string& showWERMode = "average",
const bool& excludeSpecialWords = false,
const bool& wordNbest = false,
const bool& useAccInNbest = false,
const float& accWeightInNbest = 1.0f,
const size_t& numRawPathsEMBR = 100
);
/* guoye: end */
static void SetMaxTempMemSizeForCNN(ComputationNetworkPtr net, const ComputationNodeBasePtr& criterionNode, const size_t maxTempMemSizeInSamples);
// -----------------------------------------------------------------------

View File

@ -371,9 +371,21 @@ static bool DumpNode(ComputationNodeBasePtr nodep, bool dumpGradient)
/*virtual*/ void ComputationNetwork::SEQTraversalFlowControlNode::AllocateGradientMatricesForInputs(MatrixPool& matrixPool) /*override*/
{
// TODO: should we deallocate in opposite order?
/* guoye: start */
// fprintf(stderr, "\n AllocateGradientMatricesForInputs: debug 0, m_nestedNodes.size() = %d \n", int(m_nestedNodes.size()));
int count = 0;
/* guoye: end */
for (auto nodeIter = m_nestedNodes.rbegin(); nodeIter != m_nestedNodes.rend(); ++nodeIter)
{
/* guoye: start */
// fprintf(stderr, "\n AllocateGradientMatricesForInputs: debug 1, count = %d \n", int(count));
/* guoye: end */
(*nodeIter)->AllocateGradientMatricesForInputs(matrixPool);
count++;
// fprintf(stderr, "\n AllocateGradientMatricesForInputs: debug 2, count = %d \n", int(count));
}
}
/*virtual*/ void ComputationNetwork::SEQTraversalFlowControlNode::RequestMatricesBeforeBackprop(MatrixPool& matrixPool) /*override*/
@ -1061,36 +1073,53 @@ void ComputationNetwork::AllocateAllMatrices(const std::vector<ComputationNodeBa
const std::vector<ComputationNodeBasePtr>& outValueRootNodes,
ComputationNodeBasePtr trainRootNode)
{
/* guoye: start */
// fprintf(stderr, "\nAllocateAllMatrices: debug 1\n");
/* guoye: end */
if (AreMatricesAllocated())
return;
/* guoye: start */
// fprintf(stderr, "\nAllocateAllMatrices: debug 2\n");
/* guoye: end */
// Allocate memory for forward/backward computation
if (TraceLevel() > 0)
fprintf(stderr, "\n\nAllocating matrices for forward and/or backward propagation.\n");
VerifyIsCompiled("AllocateAllMatrices");
/* guoye: start */
// fprintf(stderr, "\nAllocateAllMatrices: debug 3\n");
/* guoye: end */
std::vector<ComputationNodeBasePtr> forwardPropRoots;
forwardPropRoots.insert(forwardPropRoots.end(), evalRootNodes.begin(), evalRootNodes.end());
forwardPropRoots.insert(forwardPropRoots.end(), outValueRootNodes.begin(), outValueRootNodes.end());
if (trainRootNode != nullptr)
forwardPropRoots.push_back(trainRootNode);
/* guoye: start */
// fprintf(stderr, "\nAllocateAllMatrices: debug 4\n");
/* guoye: end */
// Mark all the eval, output and criterion roots as non-shareable
for (auto& rootNode : forwardPropRoots)
rootNode->MarkValueNonSharable();
/* guoye: start */
// fprintf(stderr, "\nAllocateAllMatrices: debug 5\n");
/* guoye: end */
// Due to special topology, if a node is solely induced by parameters, its function value should not be shared
MarkValueNonSharableNodes();
bool performingBackPropagation = (trainRootNode != nullptr);
/* guoye: start */
// fprintf(stderr, "\nAllocateAllMatrices: debug 6\n");
/* guoye: end */
// Create a composite Eval order with the specified nodes as roots
// For each node determine parents and whether the output of the
// node is needed during back propagation
std::unordered_map<ComputationNodeBasePtr, bool> outputValueNeededDuringBackProp;
std::unordered_map<ComputationNodeBasePtr, std::unordered_set<ComputationNodeBasePtr>> parentsMap;
std::unordered_set<ComputationNodeBasePtr> uniqueForwardPropEvalNodes;
/* guoye: start */
// fprintf(stderr, "\nAllocateAllMatrices: debug 7\n");
/* guoye: end */
for (auto& rootNode : forwardPropRoots)
{
for (const auto& node : GetEvalOrder(rootNode))
@ -1115,7 +1144,9 @@ void ComputationNetwork::AllocateAllMatrices(const std::vector<ComputationNodeBa
}
}
}
/* guoye: start */
// fprintf(stderr, "\nAllocateAllMatrices: debug 8\n");
/* guoye: end */
// gradient reuse maps
std::unordered_map<MatrixPool::AliasNodePtr, std::unordered_set<MatrixPool::AliasNodePtr>> gradientReuseChildrenMap;
std::unordered_map<MatrixPool::AliasNodePtr, MatrixPool::AliasNodePtr> gradientReuseParentMap;
@ -1151,6 +1182,9 @@ void ComputationNetwork::AllocateAllMatrices(const std::vector<ComputationNodeBa
}
}
}
/* guoye: start */
// fprintf(stderr, "\nAllocateAllMatrices: debug 9\n");
/* guoye: end */
m_matrixPool.Reset();
@ -1175,7 +1209,9 @@ void ComputationNetwork::AllocateAllMatrices(const std::vector<ComputationNodeBa
ReleaseMatricesAfterEvalForChildren(node, parentsMap);
}
});
/* guoye: start */
// fprintf(stderr, "\nAllocateAllMatrices: debug 10\n");
/* guoye: end */
if (trainRootNode != nullptr)
{
const std::list<ComputationNodeBasePtr>& backPropNodes = GetEvalOrder(trainRootNode);
@ -1184,10 +1220,11 @@ void ComputationNetwork::AllocateAllMatrices(const std::vector<ComputationNodeBa
std::unordered_map<MatrixPool::AliasNodePtr, std::unordered_set<MatrixPool::AliasNodePtr>> compactGradientAliasMap;
std::unordered_map<MatrixPool::AliasNodePtr, MatrixPool::AliasNodePtr> compactGradientAliasRootMap;
// fprintf(stderr, "\nAllocateAllMatrices: debug 10.1\n");
for (const auto& gradientReuseKeyValue : gradientReuseChildrenMap)
{
// keep searching parent until reaching root
// fprintf(stderr, "\nAllocateAllMatrices: debug 10.2\n");
auto parent = gradientReuseKeyValue.first;
auto parentIter = gradientReuseParentMap.find(parent);
while (parentIter != gradientReuseParentMap.end())
@ -1214,20 +1251,22 @@ void ComputationNetwork::AllocateAllMatrices(const std::vector<ComputationNodeBa
compactGradientAliasMap[parent].insert(parent);
compactGradientAliasRootMap[parent] = parent;
}
// fprintf(stderr, "\nAllocateAllMatrices: debug 10.3\n");
// print the memory aliasing info
if (TraceLevel() > 0 && compactGradientAliasRootMap.size() > 0)
{
// fprintf(stderr, "\nAllocateAllMatrices: debug 10.4\n");
fprintf(stderr, "\nGradient Memory Aliasing: %d are aliased.\n", (int)compactGradientAliasRootMap.size());
for (const auto pair : compactGradientAliasRootMap)
{
// fprintf(stderr, "\nAllocateAllMatrices: debug 10.5\n");
auto child = (const ComputationNodeBase*)pair.first;
auto parent = (const ComputationNodeBase*)pair.second;
if (child != parent)
fprintf(stderr, "\t%S (gradient) reuses %S (gradient)\n", child->GetName().c_str(), parent->GetName().c_str());
}
}
// fprintf(stderr, "\nAllocateAllMatrices: debug 10.6\n");
m_matrixPool.SetAliasInfo(compactGradientAliasMap, compactGradientAliasRootMap);
// now, simulate the gradient computation order to determine how to allocate matrices
@ -1235,38 +1274,55 @@ void ComputationNetwork::AllocateAllMatrices(const std::vector<ComputationNodeBa
// we need to call it here since we always compute gradients for children and root node is not children of other node
trainRootNode->RequestMatricesBeforeBackprop(m_matrixPool);
// fprintf(stderr, "\nAllocateAllMatrices: debug 10.7\n");
for (auto iter = backPropNodes.rbegin(); iter != backPropNodes.rend(); iter++) // for gradient computation, traverse in reverse order
{
// fprintf(stderr, "\nAllocateAllMatrices: debug 10.8\n");
auto n = *iter;
// fprintf(stderr, "\nAllocateAllMatrices: debug 10.8.1\n");
if (n->IsPartOfLoop())
{
// fprintf(stderr, "\nAllocateAllMatrices: debug 10.8.2\n");
std::vector<ComputationNodeBasePtr> recurrentNodes;
shared_ptr<SEQTraversalFlowControlNode> recInfo = FindInRecurrentLoops(m_allSEQNodes, n);
// fprintf(stderr, "\nAllocateAllMatrices: debug 10.8.3\n");
if (completedGradient.insert(recInfo).second)
{
// SEQ mode: allocate all in loop first, then deallocate again
// TODO: next step: use PARTraversalFlowControlNode::AllocateGradientMatricesForInputs() and ReleaseMatricesAfterBackprop()...
// BUGBUG: naw, ^^ would not work! Wrong order! Need to rethink this. Need to make AllocateEvalMatrices() and AllocateGradientMatrices() the virtual functions.
// fprintf(stderr, "\nAllocateAllMatrices: debug 10.8.4\n");
recInfo->AllocateGradientMatricesForInputs(m_matrixPool);
// fprintf(stderr, "\nAllocateAllMatrices: debug 10.8.5\n");
// Loops are computed sample by sample so we have to allocate them all
recInfo->ReleaseMatricesAfterBackprop(m_matrixPool);
// fprintf(stderr, "\nAllocateAllMatrices: debug 10.8.6\n");
}
}
else
{
// PAR mode: we can allocate and immediately deallocate one by one
// fprintf(stderr, "\nAllocateAllMatrices: debug 10.8.7\n");
n->AllocateGradientMatricesForInputs(m_matrixPool);
// fprintf(stderr, "\nAllocateAllMatrices: debug 10.8.8\n");
// Root node's information will be used and should not be shared with others, also it's small (1x1)
if ((n != trainRootNode) && n->NeedsGradient())
n->ReleaseMatricesAfterBackprop(m_matrixPool);
// fprintf(stderr, "\nAllocateAllMatrices: debug 10.8.9\n");
}
// fprintf(stderr, "\nAllocateAllMatrices: debug 10.8.10\n");
}
// fprintf(stderr, "\nAllocateAllMatrices: debug 10.9\n");
}
/* guoye: start */
// fprintf(stderr, "\nAllocateAllMatrices: debug 11\n");
/* guoye: end */
m_matrixPool.OptimizedMemoryAllocation();
m_areMatricesAllocated = true;
/* guoye: start */
// fprintf(stderr, "\nAllocateAllMatrices: debug 12\n");
/* guoye: end */
// TO DO: At the time of AllocateAllMatrices we don't know the minibatch size. In theory one may allocate memory again once we start to receive
// data from the reader (and the minibatch size is known). For some problems, minibatch size can change constantly, and there needs to be a
// tradeoff in deciding how frequent to run optimized memory allocation. For now, we do it only once at the very beginning for speed concerns.

View File

@ -1874,24 +1874,54 @@ public:
virtual void AllocateGradientMatricesForInputs(MatrixPool& matrixPool) override
{
/* guoye: start */
// fprintf(stderr, "\n AllocateGradientMatricesForInputs: debug 0, m_inputs.size() = %d \n", int(m_inputs.size()));
/* guoye: end */
for (int i = 0; i < m_inputs.size(); i++)
{
/* guoye: start */
// fprintf(stderr, "\n AllocateGradientMatricesForInputs: debug 1, i = %d \n", int(i));
/* guoye: end */
if (m_inputs[i]->NeedsGradient())
m_inputs[i]->RequestMatricesBeforeBackprop(matrixPool);
// fprintf(stderr, "\n AllocateGradientMatricesForInputs: debug 2, i = %d \n", int(i));
}
}
// request matrices that are needed for gradient computation
virtual void RequestMatricesBeforeBackprop(MatrixPool& matrixPool) override
{
size_t matrixSize = m_sampleLayout.GetNumElements();
RequestMatrixFromPool(m_gradient, matrixPool, matrixSize, HasMBLayout(), /*isWorkSpace*/false, ParentGradientReused() || IsGradientReused());
/* guoye: start */
// fprintf(stderr, "\n computationnode.h: RequestMatricesBeforeBackprop: debug 1 \n");
/* guoye: end */
size_t matrixSize = m_sampleLayout.GetNumElements();
/* guoye: start */
// fprintf(stderr, "\n computationnode.h: RequestMatricesBeforeBackprop: debug 2, matrixSize = %d, mbscale = %d \n", int(matrixSize), int(HasMBLayout()));
/* guoye: end */
RequestMatrixFromPool(m_gradient, matrixPool, matrixSize, HasMBLayout(), /*isWorkSpace*/false, ParentGradientReused() || IsGradientReused());
/* guoye: start */
// fprintf(stderr, "\n computationnode.h: RequestMatricesBeforeBackprop: debug 3 \n");
/* guoye: end */
auto multiOutputNode = dynamic_cast<MultiOutputNode<ElemType>*>(this);
/* guoye: start */
// fprintf(stderr, "\n computationnode.h: RequestMatricesBeforeBackprop: debug 4 \n");
/* guoye: end */
if (multiOutputNode)
{
/* guoye: start */
// fprintf(stderr, "\n computationnode.h: RequestMatricesBeforeBackprop: debug 5, multiOutputNode->m_numOutputs = %d \n", int(multiOutputNode->m_numOutputs));
/* guoye: end */
for (size_t i = 1; i < multiOutputNode->m_numOutputs; ++i)
{
// fprintf(stderr, "\n computationnode.h: RequestMatricesBeforeBackprop: debug 6, i = %d \n", int(i));
RequestMatrixFromPool(multiOutputNode->m_outputsGradient[i], matrixPool, multiOutputNode->m_outputsShape[i].GetNumElements(), multiOutputNode->m_outputsMBLayout[i] != nullptr);
}
}
}
@ -1957,6 +1987,11 @@ protected:
template<typename ValueType>
void TypedRequestMatrixFromPool(shared_ptr<Matrix<ValueType>>& matrixPtr, MatrixPool& matrixPool, size_t matrixSize=0, bool mbScale=false, bool isWorkSpace=false, bool aliasing=false)
{
/* guoye: start */
// fprintf(stderr, "\n computationnode.h:RequestMatrixFromPool, debug 0 \n");
// fprintf(stderr, "\n computationnode.h:RequestMatrixFromPool, debug 1 \n");
/* guoye: end */
if (matrixPtr == nullptr)
{
if (aliasing)

View File

@ -616,8 +616,17 @@ public:
void RequestMatricesBeforeBackprop(MatrixPool& matrixPool) override
{
/* guoye: start */
// fprintf(stderr, "\n convolutionalnodes.h: RequestMatricesBeforeBackprop: debug 1 \n");
/* guoye: end */
Base::RequestMatricesBeforeBackprop(matrixPool);
/* guoye: start */
// fprintf(stderr, "\n convolutionalnodes.h: RequestMatricesBeforeBackprop: debug 2 \n");
/* guoye: end */
RequestMatrixFromPool(m_tempMatrixBackward, matrixPool, 0, false, true);
/* guoye: start */
// fprintf(stderr, "\n convolutionalnodes.h: RequestMatricesBeforeBackprop: debug 3 \n");
/* guoye: end */
}
void ReleaseMatricesAfterBackprop(MatrixPool& matrixPool) override

View File

@ -269,9 +269,21 @@ public:
// request matrices that are needed for gradient computation
virtual void RequestMatricesBeforeBackprop(MatrixPool& matrixPool)
{
/* guoye: start */
// fprintf(stderr, "\n deprecatednode.h: RequestMatricesBeforeBackprop: debug 1 \n");
/* guoye: end */
Base::RequestMatricesBeforeBackprop(matrixPool);
/* guoye: start */
// fprintf(stderr, "\n deprecatednode.h: RequestMatricesBeforeBackprop: debug 2 \n");
/* guoye: end */
RequestMatrixFromPool(m_innerproduct, matrixPool);
/* guoye: start */
// fprintf(stderr, "\n deprecatednode.h: RequestMatricesBeforeBackprop: debug 3 \n");
/* guoye: end */
RequestMatrixFromPool(m_rightGradient, matrixPool);
/* guoye: start */
// fprintf(stderr, "\n deprecatednode.h: RequestMatricesBeforeBackprop: debug 4 \n");
/* guoye: end */
}
// release gradient and temp matrices that no longer needed after all the children's gradients are computed.

View File

@ -1893,9 +1893,21 @@ public:
// request matrices that are needed for gradient computation
virtual void RequestMatricesBeforeBackprop(MatrixPool& matrixPool)
{
/* guoye: start */
// fprintf(stderr, "\n linearalgebranodes.h: RequestMatricesBeforeBackprop: debug 1 \n");
/* guoye: end */
Base::RequestMatricesBeforeBackprop(matrixPool);
/* guoye: start */
// fprintf(stderr, "\n linearalgebranodes.h: RequestMatricesBeforeBackprop: debug 2 \n");
/* guoye: end */
RequestMatrixFromPool(m_invNormSquare, matrixPool);
/* guoye: start */
// fprintf(stderr, "\n linearalgebranodes.h: RequestMatricesBeforeBackprop: debug 3 \n");
/* guoye: end */
RequestMatrixFromPool(m_temp, matrixPool);
/* guoye: start */
// fprintf(stderr, "\n linearalgebranodes.h: RequestMatricesBeforeBackprop: debug 4 \n");
/* guoye: end */
}
// release gradient and temp matrices that no longer needed after all the children's gradients are computed.

View File

@ -137,14 +137,52 @@ public:
template <class ElemType>
void RequestAllocate(DEVICEID_TYPE deviceId, shared_ptr<Matrix<ElemType>>*pMatrixPtr, size_t matrixSize, bool mbScale, bool isWorkSpace)
{
/* guoye: start */
// fprintf(stderr, "\n matrixpool.h:RequestAllocate, debug 1 \n");
/* guoye: end */
vector<MemRequestInfo<ElemType>>& memInfoVec = GetMemRequestInfoVec<ElemType>();
/* guoye: start */
// fprintf(stderr, "\n matrixpool.h:RequestAllocate, debug 2 \n");
/* guoye: end */
MemRequestInfo<ElemType> memInfo(deviceId, pMatrixPtr, matrixSize, mbScale, isWorkSpace, m_stepCounter);
memInfoVec.push_back(memInfo);
/* guoye: start */
// fprintf(stderr, "\n matrixpool.h:RequestAllocate, debug 3, memInfo.pMatrixPtrs.size() = %d, memInfoVec.size() = %d, sizeof(meminfo) = %d, \n", int(memInfo.pMatrixPtrs.size()), int(memInfoVec.size()), int(sizeof(memInfo)));
/*
if (memInfoVec.size() >= 256)
{
fprintf(stderr, "\n matrixpool.h:RequestAllocate, debug 3.5, sizeof(meminfo) is equal or large than 256, do no push \n");
memInfoVec.resize(memInfoVec.size() + 1, memInfo);
}
*/
/* guoye: end */
/* guoye: start */
/*
else
*/
{
memInfoVec.push_back(memInfo);
}
/* guoye: end */
/* guoye: start */
// fprintf(stderr, "\n matrixpool.h:RequestAllocate, debug 4 \n");
/* guoye: end */
m_deviceIDSet.insert(deviceId);
/* guoye: start */
// fprintf(stderr, "\n matrixpool.h:RequestAllocate, debug 5 \n");
/* guoye: end */
m_stepCounter++;
/* guoye: start */
// fprintf(stderr, "\n matrixpool.h:RequestAllocate, debug 6 \n");
/* guoye: end */
// assign some temporary pointer, they will be replaced later unless the matrix is sparse
*pMatrixPtr = make_shared<Matrix<ElemType>>(deviceId);
/* guoye: start */
// fprintf(stderr, "\n matrixpool.h:RequestAllocate, debug 7 \n");
/* guoye: end */
}
void OptimizedMemoryAllocation()
@ -207,24 +245,59 @@ public:
template <class ElemType>
void RequestAliasedAllocate(DEVICEID_TYPE deviceId, AliasNodePtr node, shared_ptr<Matrix<ElemType>>*pMatrixPtr, size_t matrixSize, bool mbScale)
{
/* guoye: start */
// fprintf(stderr, "\n matrixpool.h:RequestAliasedAllocate, debug 1 \n");
/* guoye: end */
const auto iter = m_aliasLookup.find(node);
/* guoye: start */
// fprintf(stderr, "\n matrixpool.h:RequestAliasedAllocate, debug 2 \n");
/* guoye: end */
if (iter == m_aliasLookup.end())
LogicError("node not aliased");
/* guoye: start */
// fprintf(stderr, "\n matrixpool.h:RequestAliasedAllocate, debug 3 \n");
/* guoye: end */
auto parent = iter->second;
/* guoye: start */
// fprintf(stderr, "\n matrixpool.h:RequestAliasedAllocate, debug 4 \n");
/* guoye: end */
auto& aliasInfo = m_aliasGroups[parent];
/* guoye: start */
// fprintf(stderr, "\n matrixpool.h:RequestAliasedAllocate, debug 5 \n");
/* guoye: end */
if (aliasInfo.pMatrixPtr == nullptr)
{
// first allocation for the group
/* guoye: start */
// fprintf(stderr, "\n matrixpool.h:RequestAliasedAllocate, debug 6 \n");
/* guoye: end */
aliasInfo.pMatrixPtr = pMatrixPtr;
/* guoye: start */
// fprintf(stderr, "\n matrixpool.h:RequestAliasedAllocate, debug 7 \n");
/* guoye: end */
RequestAllocate(deviceId, pMatrixPtr, matrixSize, mbScale, false);
/* guoye: start */
// fprintf(stderr, "\n matrixpool.h:RequestAliasedAllocate, debug 8 \n");
/* guoye: end */
}
else
{
/* guoye: start */
// fprintf(stderr, "\n matrixpool.h:RequestAliasedAllocate, debug 9 \n");
/* guoye: end */
auto aliasRootMatrixPtr = (shared_ptr<Matrix<ElemType>>*)aliasInfo.pMatrixPtr;
/* guoye: start */
// fprintf(stderr, "\n matrixpool.h:RequestAliasedAllocate, debug 10 \n");
/* guoye: end */
*pMatrixPtr = *aliasRootMatrixPtr;
GetMemInfo<ElemType>(aliasRootMatrixPtr)->pMatrixPtrs.push_back(pMatrixPtr);
/* guoye: start */
// fprintf(stderr, "\n matrixpool.h:RequestAliasedAllocate, debug 11 \n");
/* guoye: end */
}
/* guoye: start */
// fprintf(stderr, "\n matrixpool.h:RequestAliasedAllocate, debug 12 \n");
/* guoye: end */
}
private:

View File

@ -247,8 +247,17 @@ public:
// request matrices that are needed for gradient computation
virtual void RequestMatricesBeforeBackprop(MatrixPool& matrixPool)
{
/* guoye: start */
// fprintf(stderr, "\n nonlinearitynodes.h: RequestMatricesBeforeBackprop: debug 7 \n");
/* guoye: end */
Base::RequestMatricesBeforeBackprop(matrixPool);
/* guoye: start */
// fprintf(stderr, "\n nonlinearitynodes.h: RequestMatricesBeforeBackprop: debug 8 \n");
/* guoye: end */
RequestMatrixFromPool(m_gradientTemp, matrixPool);
/* guoye: start */
// fprintf(stderr, "\n nonlinearitynodes.h: RequestMatricesBeforeBackprop: debug 9 \n");
/* guoye: end */
}
// release gradient and temp matrices that no longer needed after all the children's gradients are computed.
@ -318,8 +327,17 @@ public:
// request matrices that are needed for gradient computation
virtual void RequestMatricesBeforeBackprop(MatrixPool& matrixPool)
{
/* guoye: start */
// fprintf(stderr, "\n nonlinearitynodes.h: RequestMatricesBeforeBackprop: debug 4 \n");
/* guoye: end */
Base::RequestMatricesBeforeBackprop(matrixPool);
/* guoye: start */
// fprintf(stderr, "\n nonlinearitynodes.h: RequestMatricesBeforeBackprop: debug 5 \n");
/* guoye: end */
RequestMatrixFromPool(m_diff, matrixPool);
/* guoye: start */
// fprintf(stderr, "\n nonlinearitynodes.h: RequestMatricesBeforeBackprop: debug 6 \n");
/* guoye: end */
}
// release gradient and temp matrices that no longer needed after all the children's gradients are computed.
@ -385,8 +403,17 @@ public:
// request matrices that are needed for gradient computation
virtual void RequestMatricesBeforeBackprop(MatrixPool& matrixPool)
{
/* guoye: start */
// fprintf(stderr, "\n nonlinearitynodes.h: RequestMatricesBeforeBackprop: debug 1 \n");
/* guoye: end */
Base::RequestMatricesBeforeBackprop(matrixPool);
/* guoye: start */
// fprintf(stderr, "\n nonlinearitynodes.h: RequestMatricesBeforeBackprop: debug 2 \n");
/* guoye: end */
RequestMatrixFromPool(m_softmax, matrixPool);
/* guoye: start */
// fprintf(stderr, "\n nonlinearitynodes.h: RequestMatricesBeforeBackprop: debug 3 \n");
/* guoye: end */
}
// release gradient and temp matrices that no longer needed after all the children's gradients are computed.

View File

@ -62,9 +62,21 @@ public:
// request matrices needed to do node derivative value evaluation
virtual void RequestMatricesBeforeBackprop(MatrixPool& matrixPool)
{
/* guoye: start */
// fprintf(stderr, "\n rnnnodes.h: RequestMatricesBeforeBackprop: debug 1 \n");
/* guoye: end */
Base::RequestMatricesBeforeBackprop(matrixPool);
/* guoye: start */
// fprintf(stderr, "\n rnnnodes.h: RequestMatricesBeforeBackprop: debug 2 \n");
/* guoye: end */
RequestMatrixFromPool(m_transposedDInput, matrixPool);
/* guoye: start */
// fprintf(stderr, "\n rnnnodes.h: RequestMatricesBeforeBackprop: debug 3 \n");
/* guoye: end */
RequestMatrixFromPool(m_transposedDOutput, matrixPool);
/* guoye: start */
// fprintf(stderr, "\n rnnnodes.h: RequestMatricesBeforeBackprop: debug 4 \n");
/* guoye: end */
}
// release gradient and temp matrices that no longer needed after all the children's gradients are computed.

View File

@ -342,8 +342,17 @@ public:
void RequestMatricesBeforeBackprop(MatrixPool& matrixPool) override
{
/* guoye: start */
// fprintf(stderr, "\n reshapingnodes.h: RequestMatricesBeforeBackprop: debug 5 \n");
/* guoye: end */
Base::RequestMatricesBeforeBackprop(matrixPool);
/* guoye: start */
// fprintf(stderr, "\n reshapingnodes.h: RequestMatricesBeforeBackprop: debug 6 \n");
/* guoye: end */
RequestMatrixFromPool(m_tempGatherIndices, matrixPool, 1, InputRef(0).HasMBLayout());
/* guoye: start */
// fprintf(stderr, "\n reshapingnodes.h: RequestMatricesBeforeBackprop: debug 7 \n");
/* guoye: end */
}
void ReleaseMatricesAfterBackprop(MatrixPool& matrixPool) override
@ -508,9 +517,21 @@ public:
void RequestMatricesBeforeBackprop(MatrixPool& matrixPool) override
{
/* guoye: start */
// fprintf(stderr, "\n reshapingnodes.h: RequestMatricesBeforeBackprop: debug 1 \n");
/* guoye: end */
Base::RequestMatricesBeforeBackprop(matrixPool);
/* guoye: start */
// fprintf(stderr, "\n reshapingnodes.h: RequestMatricesBeforeBackprop: debug 2 \n");
/* guoye: end */
RequestMatrixFromPool(m_tempScatterIndices, matrixPool, 1, HasMBLayout());
/* guoye: start */
// fprintf(stderr, "\n reshapingnodes.h: RequestMatricesBeforeBackprop: debug 3 \n");
/* guoye: end */
RequestMatrixFromPool(m_tempUnpackedData, matrixPool, GetSampleLayout().GetNumElements(), HasMBLayout());
/* guoye: start */
// fprintf(stderr, "\n reshapingnodes.h: RequestMatricesBeforeBackprop: debug 4 \n");
/* guoye: end */
}
void ReleaseMatricesAfterBackprop(MatrixPool& matrixPool) override

View File

@ -128,9 +128,21 @@ public:
void RequestMatricesBeforeBackprop(MatrixPool& matrixPool) override
{
/* guoye: start */
//fprintf(stderr, "\n sequencereshapenodes.h: RequestMatricesBeforeBackprop: debug 1 \n");
/* guoye: end */
Base::RequestMatricesBeforeBackprop(matrixPool);
/* guoye: start */
// fprintf(stderr, "\n sequencereshapenodes.h: RequestMatricesBeforeBackprop: debug 2 \n");
/* guoye: end */
RequestMatrixFromPool(m_tempScatterIndices, matrixPool, 1, HasMBLayout());
/* guoye: start */
// fprintf(stderr, "\n sequencereshapenodes.h: RequestMatricesBeforeBackprop: debug 3 \n");
/* guoye: end */
RequestMatrixFromPool(m_tempUnpackedData, matrixPool, InputRef(0).GetSampleLayout().GetNumElements(), InputRef(0).HasMBLayout());
/* guoye: start */
// fprintf(stderr, "\n sequencereshapenodes.h: RequestMatricesBeforeBackprop: debug 4 \n");
/* guoye: end */
}
void ReleaseMatricesAfterBackprop(MatrixPool& matrixPool) override
@ -464,9 +476,21 @@ public:
void RequestMatricesBeforeBackprop(MatrixPool& matrixPool) override
{
/* guoye: start */
// fprintf(stderr, "\n sequencereshapenodes.h: RequestMatricesBeforeBackprop: debug 5 \n");
/* guoye: end */
Base::RequestMatricesBeforeBackprop(matrixPool);
/* guoye: start */
// fprintf(stderr, "\n sequencereshapenodes.h: RequestMatricesBeforeBackprop: debug 6 \n");
/* guoye: end */
RequestMatrixFromPool(m_tempGatherIndices, matrixPool, 1, HasMBLayout());
/* guoye: start */
// fprintf(stderr, "\n sequencereshapenodes.h: RequestMatricesBeforeBackprop: debug 7 \n");
/* guoye: end */
RequestMatrixFromPool(m_tempPackedGradientData, matrixPool, InputRef(0).GetSampleLayout().GetNumElements(), InputRef(0).HasMBLayout());
/* guoye: start */
// fprintf(stderr, "\n sequencereshapenodes.h: RequestMatricesBeforeBackprop: debug 8 \n");
/* guoye: end */
}
void ReleaseMatricesAfterBackprop(MatrixPool& matrixPool) override

View File

@ -477,12 +477,15 @@ public:
{
Input(inputIndex)->Gradient().SetValue(0.0f);
Value().SetValue(1.0f);
}
}
else
{
FrameRange fr(Input(0)->GetMBLayout());
BackpropToRight(*m_softmaxOfRight, Input(0)->Value(), Input(inputIndex)->Gradient(),
Gradient(), *m_gammaFromLattice, m_fsSmoothingWeight, m_frameDropThreshold);
/* guoye: start */
// Gradient(), *m_gammaFromLattice, m_fsSmoothingWeight, m_frameDropThreshold);
Gradient(), *m_gammaFromLattice, m_fsSmoothingWeight, m_frameDropThreshold, m_MBR);
/* guoye: end */
MaskMissingColumnsToZero(Input(inputIndex)->Gradient(), Input(0)->GetMBLayout(), fr);
}
#ifdef _DEBUG
@ -518,7 +521,10 @@ public:
static void WINAPI BackpropToRight(const Matrix<ElemType>& softmaxOfRight, const Matrix<ElemType>& inputFunctionValues,
Matrix<ElemType>& inputGradientValues, const Matrix<ElemType>& gradientValues,
const Matrix<ElemType>& gammaFromLattice, double hsmoothingWeight, double frameDropThresh)
/* guoye: start */
//const Matrix<ElemType>& gammaFromLattice, double hsmoothingWeight, double frameDropThresh)
const Matrix<ElemType>& gammaFromLattice, double hsmoothingWeight, double frameDropThresh, bool MBR)
/* guoye: end */
{
#if DUMPOUTPUT
softmaxOfRight.Print("SequenceWithSoftmaxNode Partial-softmaxOfRight");
@ -526,8 +532,10 @@ public:
gradientValues.Print("SequenceWithSoftmaxNode Partial-gradientValues");
inputGradientValues.Print("SequenceWithSoftmaxNode Partial-Right-in");
#endif
inputGradientValues.AssignSequenceError((ElemType) hsmoothingWeight, inputFunctionValues, softmaxOfRight, gammaFromLattice, gradientValues.Get00Element());
/* guoye: start */
// inputGradientValues.AssignSequenceError((ElemType) hsmoothingWeight, inputFunctionValues, softmaxOfRight, gammaFromLattice, gradientValues.Get00Element());
inputGradientValues.AssignSequenceError((ElemType)hsmoothingWeight, inputFunctionValues, softmaxOfRight, gammaFromLattice, gradientValues.Get00Element(), MBR);
/* guoye: end */
inputGradientValues.DropFrame(inputFunctionValues, gammaFromLattice, (ElemType) frameDropThresh);
#if DUMPOUTPUT
inputGradientValues.Print("SequenceWithSoftmaxNode Partial-Right");
@ -561,9 +569,20 @@ public:
m_gammaFromLattice->SwitchToMatrixType(m_softmaxOfRight->GetMatrixType(), m_softmaxOfRight->GetFormat(), false);
m_gammaFromLattice->Resize(*m_softmaxOfRight);
// guoye: start
// fprintf(stderr, "guoye debug: calgammaformb, m_m_nws.size() = %d \n", int(m_nws.size()));
for (size_t i = 0; i < m_nws.size(); i++)
{
// fprintf(stderr, "guoye debug: calgammaformb, i = %d, m_nws[i] = %d \n", int(i), int(m_nws[i]));
}
// guoye: end
m_gammaCalculator.calgammaformb(Value(), m_lattices, Input(2)->Value() /*log LLs*/,
Input(0)->Value() /*labels*/, *m_gammaFromLattice,
m_uids, m_boundaries, Input(1)->GetNumParallelSequences(),
/* guoye: start */
// m_uids, m_boundaries, Input(1)->GetNumParallelSequences(),
m_uids, m_wids, m_nws, m_boundaries, Input(1)->GetNumParallelSequences(),
/* guoye: end */
Input(0)->GetMBLayout(), m_extraUttMap, m_doReferenceAlignment);
#if NANCHECK
@ -635,15 +654,25 @@ public:
// TODO: method names should be CamelCase
std::vector<shared_ptr<const msra::dbn::latticepair>>* getLatticePtr() { return &m_lattices; }
std::vector<size_t>* getuidprt() { return &m_uids; }
/* guoye: start */
std::vector<size_t>* getwidprt() { return &m_wids; }
std::vector<short>* getnwprt() { return &m_nws; }
/* guoye: end */
std::vector<size_t>* getboundaryprt() { return &m_boundaries; }
std::vector<size_t>* getextrauttmap() { return &m_extraUttMap; }
msra::asr::simplesenonehmm* gethmm() { return &m_hmm; }
void SetSmoothWeight(double fsSmoothingWeight) { m_fsSmoothingWeight = fsSmoothingWeight; }
/* guoye : start */
void SetMBR(bool MBR) { m_MBR = MBR; }
/* guoye : end */
void SetFrameDropThresh(double frameDropThresh) { m_frameDropThreshold = frameDropThresh; }
void SetReferenceAlign(const bool doreferencealign) { m_doReferenceAlignment = doreferencealign; }
void SetGammarCalculationParam(const double& amf, const double& lmf, const double& wp, const double& bMMIfactor, const bool& sMBR)
void SetGammarCalculationParam(const double& amf, const double& lmf, const double& wp, const double& bMMIfactor, const bool& sMBR, const bool& EMBR, const string& EMBRUnit, const size_t& numPathsEMBR,
const bool& enforceValidPathEMBR, const string& getPathMethodEMBR, const string& showWERMode, const bool& excludeSpecialWords, const bool& wordNbest, const bool& useAccInNbest, const float& accWeightInNbest, const size_t& numRawPathsEMBR)
{
msra::lattices::SeqGammarCalParam param;
param.amf = amf;
@ -651,6 +680,20 @@ public:
param.wp = wp;
param.bMMIfactor = bMMIfactor;
param.sMBRmode = sMBR;
/* guoye: start */
param.EMBR = EMBR;
param.EMBRUnit = EMBRUnit;
param.numPathsEMBR = numPathsEMBR;
param.enforceValidPathEMBR = enforceValidPathEMBR;
param.getPathMethodEMBR = getPathMethodEMBR;
param.showWERMode = showWERMode;
param.excludeSpecialWords = excludeSpecialWords;
param.wordNbest = wordNbest;
param.useAccInNbest = useAccInNbest;
param.accWeightInNbest = accWeightInNbest;
param.numRawPathsEMBR = numRawPathsEMBR;
/* guoye: end */
m_gammaCalculator.SetGammarCalculationParams(param);
}
@ -667,6 +710,9 @@ protected:
bool m_invalidMinibatch; // for single minibatch
double m_frameDropThreshold;
double m_fsSmoothingWeight; // frame-sequence criterion interpolation weight --TODO: can this be done outside?
/* guoye: start */
bool m_MBR;
/* guoye: end */
double m_seqGammarAMF;
double m_seqGammarLMF;
double m_seqGammarWP;
@ -678,6 +724,11 @@ protected:
msra::lattices::GammaCalculation<ElemType> m_gammaCalculator;
bool m_gammaCalcInitialized;
std::vector<size_t> m_uids;
/* guoye: start */
std::vector<size_t> m_wids;
std::vector<short> m_nws;
/* guoye: end */
std::vector<size_t> m_boundaries;
std::vector<size_t> m_extraUttMap;
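The new m_wids and m_nws members are not documented in this diff; judging from the calgammaformb() call above, they look like a flat word-ID buffer for the whole minibatch plus per-utterance word counts. Under that assumption only, recovering per-utterance word sequences from the two vectors would look roughly like this sketch:

#include <cstddef>
#include <vector>

// hypothetical helper, not part of the commit: split a flat ID buffer by per-utterance counts
std::vector<std::vector<std::size_t>> splitByCounts(const std::vector<std::size_t>& wids,
                                                    const std::vector<short>& nws)
{
    std::vector<std::vector<std::size_t>> perUtt;
    std::size_t pos = 0;
    for (short n : nws) // one count per utterance
    {
        perUtt.emplace_back(wids.begin() + pos, wids.begin() + pos + n);
        pos += static_cast<std::size_t>(n);
    }
    return perUtt;
}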
@ -806,7 +857,9 @@ public:
auto& currentLatticeSeq = latticeMBLayout->FindSequence(currentLabelSeq.seqId);
std::shared_ptr<msra::dbn::latticepair> latticePair(new msra::dbn::latticepair);
const char* buffer = bufferStart + latticeMBNumTimeSteps * sizeof(float) * currentLatticeSeq.s + currentLatticeSeq.tBegin;
latticePair->second.ReadFromBuffer(buffer, m_idmap, m_idmap.back());
latticePair->second.ReadFromBuffer(buffer, m_idmap, m_idmap.back(), specialwordids());
assert((currentLabelSeq.tEnd - currentLabelSeq.tBegin) == latticePair->second.info.numframes);
// The size of the vector is small -- the number of sequences in the minibatch.
// Iteration likely will be faster than the overhead with unordered_map

View File

@ -351,8 +351,17 @@ public:
// request matrices that are needed for gradient computation
virtual void RequestMatricesBeforeBackprop(MatrixPool& matrixPool)
{
/* guoye: start */
// fprintf(stderr, "\n TrainingNodes.h: RequestMatricesBeforeBackprop: debug 6 \n");
/* guoye: end */
Base::RequestMatricesBeforeBackprop(matrixPool);
/* guoye: start */
// fprintf(stderr, "\n TrainingNodes.h: RequestMatricesBeforeBackprop: debug 7 \n");
/* guoye: end */
RequestMatrixFromPool(m_leftDivRight, matrixPool);
/* guoye: start */
// fprintf(stderr, "\n TrainingNodes.h: RequestMatricesBeforeBackprop: debug 8 \n");
/* guoye: end */
}
// release gradient and temp matrices that no longer needed after all the children's gradients are computed.
@ -444,8 +453,17 @@ public:
// request matrices that are needed for gradient computation
virtual void RequestMatricesBeforeBackprop(MatrixPool& matrixPool)
{
/* guoye: start */
// fprintf(stderr, "\n TrainingNodes.h: RequestMatricesBeforeBackprop: debug 9 \n");
/* guoye: end */
Base::RequestMatricesBeforeBackprop(matrixPool);
/* guoye: start */
// fprintf(stderr, "\n TrainingNodes.h: RequestMatricesBeforeBackprop: debug 10 \n");
/* guoye: end */
RequestMatrixFromPool(m_gradientOfL1Norm, matrixPool);
/* guoye: start */
// fprintf(stderr, "\n TrainingNodes.h: RequestMatricesBeforeBackprop: debug 11 \n");
/* guoye: end */
}
// release gradient and temp matrices that no longer needed after all the children's gradients are computed.
@ -2925,7 +2943,13 @@ public:
void RequestMatricesBeforeBackprop(MatrixPool& matrixPool) override
{
/* guoye: start */
// fprintf(stderr, "\n TrainingNodes.h: RequestMatricesBeforeBackprop: debug 1 \n");
/* guoye: end */
Base::RequestMatricesBeforeBackprop(matrixPool);
/* guoye: start */
// fprintf(stderr, "\n TrainingNodes.h: RequestMatricesBeforeBackprop: debug 2 \n");
/* guoye: end */
RequestMatrixFromPool(m_dDataDummy, matrixPool);
this->template TypedRequestMatrixFromPool<StatType>(m_dScale, matrixPool);
this->template TypedRequestMatrixFromPool<StatType>(m_dBias, matrixPool);

Binary file not shown.

View File

@ -4675,7 +4675,10 @@ GPUMatrix<ElemType>& GPUMatrix<ElemType>::DropFrame(const GPUMatrix<ElemType>& l
template <class ElemType>
GPUMatrix<ElemType>& GPUMatrix<ElemType>::AssignSequenceError(const ElemType hsmoothingWeight, const GPUMatrix<ElemType>& label,
const GPUMatrix<ElemType>& dnnoutput, const GPUMatrix<ElemType>& gamma, ElemType alpha)
/* guoye: start */
// const GPUMatrix<ElemType>& dnnoutput, const GPUMatrix<ElemType>& gamma, ElemType alpha)
const GPUMatrix<ElemType>& dnnoutput, const GPUMatrix<ElemType>& gamma, ElemType alpha, bool MBR)
/* guoye: end */
{
if (IsEmpty())
LogicError("AssignSequenceError: Matrix is empty.");
@ -4685,7 +4688,10 @@ GPUMatrix<ElemType>& GPUMatrix<ElemType>::AssignSequenceError(const ElemType hsm
SyncGuard syncGuard;
long N = (LONG64) label.GetNumElements();
int blocksPerGrid = (int) ceil(1.0 * N / GridDim::maxThreadsPerBlock);
_AssignSequenceError<<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, t_stream>>>(hsmoothingWeight, Data(), label.Data(), dnnoutput.Data(), gamma.Data(), alpha, N);
/* guoye: start */
//_AssignSequenceError<<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, t_stream>>>(hsmoothingWeight, Data(), label.Data(), dnnoutput.Data(), gamma.Data(), alpha, N);
_AssignSequenceError << <blocksPerGrid, GridDim::maxThreadsPerBlock, 0, t_stream >> >(hsmoothingWeight, Data(), label.Data(), dnnoutput.Data(), gamma.Data(), alpha, N, MBR);
/* guoye: end */
return *this;
}

View File

@ -370,8 +370,10 @@ public:
// sequence training
GPUMatrix<ElemType>& DropFrame(const GPUMatrix<ElemType>& label, const GPUMatrix<ElemType>& gamma, const ElemType& threshhold);
GPUMatrix<ElemType>& AssignSequenceError(const ElemType hsmoothingWeight, const GPUMatrix<ElemType>& label, const GPUMatrix<ElemType>& dnnoutput, const GPUMatrix<ElemType>& gamma, ElemType alpha);
/* guoye: start */
//GPUMatrix<ElemType>& AssignSequenceError(const ElemType hsmoothingWeight, const GPUMatrix<ElemType>& label, const GPUMatrix<ElemType>& dnnoutput, const GPUMatrix<ElemType>& gamma, ElemType alpha);
GPUMatrix<ElemType>& AssignSequenceError(const ElemType hsmoothingWeight, const GPUMatrix<ElemType>& label, const GPUMatrix<ElemType>& dnnoutput, const GPUMatrix<ElemType>& gamma, ElemType alpha, bool MBR);
/* guoye: end */
GPUMatrix<ElemType>& AssignCTCScore(const GPUMatrix<ElemType>& prob, GPUMatrix<ElemType>& alpha, GPUMatrix<ElemType>& beta,
const GPUMatrix<ElemType> phoneSeq, const GPUMatrix<ElemType> phoneBoundary, GPUMatrix<ElemType> & totalScore, const vector<size_t>& uttMap, const vector<size_t> & uttBeginFrame, const vector<size_t> & uttFrameNum,
const vector<size_t> & uttPhoneNum, const size_t samplesInRecurrentStep, const size_t maxFrameNum, const size_t blankTokenId, const int delayConstraint, const bool isColWise);

View File

@ -5256,13 +5256,24 @@ __global__ void _DropFrame(
template <class ElemType>
__global__ void _AssignSequenceError(const ElemType hsmoothingWeight, ElemType* error, const ElemType* label,
const ElemType* dnnoutput, const ElemType* gamma, ElemType alpha, const long N)
/* guoye: start */
// const ElemType* dnnoutput, const ElemType* gamma, ElemType alpha, const long N)
const ElemType* dnnoutput, const ElemType* gamma, ElemType alpha, const long N, bool MBR)
/* guoye: end */
{
typedef typename TypeSelector<ElemType>::comp_t comp_t;
int id = blockDim.x * blockIdx.x + threadIdx.x;
if (id >= N)
return;
error[id] = (comp_t)error[id] - (comp_t)alpha * ((comp_t)label[id] - (1.0 - (comp_t)hsmoothingWeight) * (comp_t)dnnoutput[id] - (comp_t)hsmoothingWeight * (comp_t)gamma[id]);
/* guoye: start */
// error[id] -= alpha * (label[id] - (1.0 - hsmoothingWeight) * dnnoutput[id] - hsmoothingWeight * gamma[id]);
if(!MBR)
error[id] -= alpha * (label[id] - (1.0 - hsmoothingWeight) * dnnoutput[id] - hsmoothingWeight * gamma[id]);
else
error[id] -= alpha * ( (1.0 - hsmoothingWeight) * (label[id] - dnnoutput[id]) + hsmoothingWeight * gamma[id]);
/* guoye: end */
// change to ce
// error[id] -= alpha * (label[id] - dnnoutput[id] );
}
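A plain CPU restatement of the per-element update that the _AssignSequenceError kernel above performs, written only to make the two branches easier to compare (plain float arrays, no CUDA; not the code shipped in this commit):

#include <cstddef>

void assignSequenceErrorRef(float hsmoothingWeight, float* error, const float* label,
                            const float* dnnoutput, const float* gamma,
                            float alpha, std::size_t N, bool MBR)
{
    for (std::size_t id = 0; id < N; ++id)
    {
        if (!MBR) // original sequence-training combination of label, DNN output and lattice gamma
            error[id] -= alpha * (label[id] - (1.0f - hsmoothingWeight) * dnnoutput[id]
                                  - hsmoothingWeight * gamma[id]);
        else      // EMBR mode: interpolate the CE-style term and the lattice-derived gamma directly
            error[id] -= alpha * ((1.0f - hsmoothingWeight) * (label[id] - dnnoutput[id])
                                  + hsmoothingWeight * gamma[id]);
    }
}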

View File

@ -61,7 +61,7 @@
<ClCompile>
<AdditionalIncludeDirectories>$(MathIncludePath);$(BOOST_INCLUDE_PATH);$(SolutionDir)Source\Common\include;%(AdditionalIncludeDirectories)</AdditionalIncludeDirectories>
<DisableSpecificWarnings>4819</DisableSpecificWarnings>
<AdditionalOptions>/d2Zi+ /bigobj %(AdditionalOptions)</AdditionalOptions>
<AdditionalOptions>/d2Zi+ /bigobj /FS %(AdditionalOptions)</AdditionalOptions>
<TreatWarningAsError>true</TreatWarningAsError>
<WarningLevel>Level4</WarningLevel>
<SDLCheck>true</SDLCheck>

View file

@ -99,7 +99,7 @@ xcopy /D /Y "$(CuDnnDll)" "$(OutputPath)"
<ClCompile>
<EnableParallelCodeGeneration>true</EnableParallelCodeGeneration>
<FloatingPointExceptions>false</FloatingPointExceptions>
<AdditionalOptions>/d2Zi+ %(AdditionalOptions)</AdditionalOptions>
<AdditionalOptions>/d2Zi+ /FS </AdditionalOptions>
</ClCompile>
<CudaCompile>
<HostDebugInfo>false</HostDebugInfo>

View file

@ -6098,7 +6098,10 @@ Matrix<ElemType>& Matrix<ElemType>::DropFrame(const Matrix<ElemType>& label, con
/// <param name="c">Resulting matrix, user is responsible for allocating this</param>
template <class ElemType>
Matrix<ElemType>& Matrix<ElemType>::AssignSequenceError(const ElemType hsmoothingWeight, const Matrix<ElemType>& label,
const Matrix<ElemType>& dnnoutput, const Matrix<ElemType>& gamma, ElemType alpha)
/* guoye: start */
// const Matrix<ElemType>& dnnoutput, const Matrix<ElemType>& gamma, ElemType alpha)
const Matrix<ElemType>& dnnoutput, const Matrix<ElemType>& gamma, ElemType alpha, bool MBR)
/* guoye: end */
{
DecideAndMoveToRightDevice(label, dnnoutput, gamma);
@ -6106,11 +6109,16 @@ Matrix<ElemType>& Matrix<ElemType>::AssignSequenceError(const ElemType hsmoothin
NOT_IMPLEMENTED;
SwitchToMatrixType(label.GetMatrixType(), label.GetFormat(), false);
DISPATCH_MATRIX_ON_FLAG(this,
this,
m_CPUMatrix->AssignSequenceError(hsmoothingWeight, *label.m_CPUMatrix, *dnnoutput.m_CPUMatrix, *gamma.m_CPUMatrix, alpha),
m_GPUMatrix->AssignSequenceError(hsmoothingWeight, *label.m_GPUMatrix, *dnnoutput.m_GPUMatrix, *gamma.m_GPUMatrix, alpha),
/* guoye: start */
// m_GPUMatrix->AssignSequenceError(hsmoothingWeight, *label.m_GPUMatrix, *dnnoutput.m_GPUMatrix, *gamma.m_GPUMatrix, alpha),
m_GPUMatrix->AssignSequenceError(hsmoothingWeight, *label.m_GPUMatrix, *dnnoutput.m_GPUMatrix, *gamma.m_GPUMatrix, alpha, MBR),
/* guoye: end */
NOT_IMPLEMENTED,
NOT_IMPLEMENTED);
return *this;

View file

@ -402,8 +402,10 @@ public:
// sequence training
Matrix<ElemType>& DropFrame(const Matrix<ElemType>& label, const Matrix<ElemType>& gamma, const ElemType& threshhold);
Matrix<ElemType>& AssignSequenceError(const ElemType hsmoothingWeight, const Matrix<ElemType>& label, const Matrix<ElemType>& dnnoutput, const Matrix<ElemType>& gamma, ElemType alpha);
/* guoye: start */
// Matrix<ElemType>& AssignSequenceError(const ElemType hsmoothingWeight, const Matrix<ElemType>& label, const Matrix<ElemType>& dnnoutput, const Matrix<ElemType>& gamma, ElemType alpha);
Matrix<ElemType>& AssignSequenceError(const ElemType hsmoothingWeight, const Matrix<ElemType>& label, const Matrix<ElemType>& dnnoutput, const Matrix<ElemType>& gamma, ElemType alpha, bool MBR);
/* guoye: end */
Matrix<ElemType>& AssignCTCScore(const Matrix<ElemType>& prob, Matrix<ElemType>& alpha, Matrix<ElemType>& beta, const Matrix<ElemType>& phoneSeq, const Matrix<ElemType>& phoneBound, Matrix<ElemType>& totalScore,
const vector<size_t> & extraUttMap, const vector<size_t> & uttBeginFrame, const vector<size_t> & uttFrameNum, const vector<size_t> & uttPhoneNum, const size_t samplesInRecurrentStep,
const size_t mbSize, const size_t blankTokenId, const int delayConstraint, const bool isColWise);

View file

@ -162,6 +162,22 @@ private:
dynamic_cast<vectorbaseimpl<doublevector, vectorref<double>> &>(Eframescorrectbuf),
logEframescorrecttotal, totalfwscore);
}
/* guoye: start */
void backwardlatticeEMBR(const size_t *batchsizebackward, const size_t numlaunchbackward,
const floatvector &edgeacscores, const edgeinfowithscoresvector &edges,
const nodeinfovector &nodes, doublevector &edgelogbetas, doublevector &logbetas,
const float lmf, const float wp, const float amf, double &totalbwscore)
{
ondevice no(deviceid);
latticefunctionsops::backwardlatticeEMBR(batchsizebackward, numlaunchbackward,
dynamic_cast<const vectorbaseimpl<floatvector, vectorref<float>> &>(edgeacscores),
dynamic_cast<const vectorbaseimpl<edgeinfowithscoresvector, vectorref<msra::lattices::edgeinfowithscores>> &>(edges),
dynamic_cast<const vectorbaseimpl<nodeinfovector, vectorref<msra::lattices::nodeinfo>> &>(nodes),
dynamic_cast<vectorbaseimpl<doublevector, vectorref<double>> &>(edgelogbetas),
dynamic_cast<vectorbaseimpl<doublevector, vectorref<double>> &>(logbetas),
lmf, wp, amf, totalbwscore);
}
/* guoye: end */
void sMBRerrorsignal(const ushortvector &alignstateids,
const uintvector &alignoffsets,
@ -183,6 +199,25 @@ private:
logEframescorrecttotal, dengammasMatrixRef, dengammasbufMatrixRef);
}
/* guoye: start */
void EMBRerrorsignal(const ushortvector &alignstateids,
const uintvector &alignoffsets,
const edgeinfowithscoresvector &edges, const nodeinfovector &nodes,
const doublevector &edgeweights,
Microsoft::MSR::CNTK::Matrix<float> &dengammas)
{
ondevice no(deviceid);
matrixref<float> dengammasMatrixRef = tomatrixref(dengammas);
latticefunctionsops::EMBRerrorsignal(dynamic_cast<const vectorbaseimpl<ushortvector, vectorref<unsigned short>> &>(alignstateids),
dynamic_cast<const vectorbaseimpl<uintvector, vectorref<unsigned int>> &>(alignoffsets),
dynamic_cast<const vectorbaseimpl<edgeinfowithscoresvector, vectorref<msra::lattices::edgeinfowithscores>> &>(edges),
dynamic_cast<const vectorbaseimpl<nodeinfovector, vectorref<msra::lattices::nodeinfo>> &>(nodes),
dynamic_cast<const vectorbaseimpl<doublevector, vectorref<double>> &>(edgeweights),
dengammasMatrixRef);
}
/* guoye: end */
void mmierrorsignal(const ushortvector &alignstateids, const uintvector &alignoffsets,
const edgeinfowithscoresvector &edges, const nodeinfovector &nodes,
const doublevector &logpps, Microsoft::MSR::CNTK::Matrix<float> &dengammas)

View file

@ -99,6 +99,19 @@ struct latticefunctions : public vectorbase<msra::lattices::empty>
doublevector& logaccalphas, doublevector& logaccbetas,
doublevector& logframescorrectedge, doublevector& logEframescorrect,
doublevector& Eframescorrectbuf, double& logEframescorrecttotal, double& totalfwscore) = 0;
/* guoye: start */
virtual void backwardlatticeEMBR(const size_t* batchsizebackward, const size_t numlaunchbackward,
const floatvector& edgeacscores, const edgeinfowithscoresvector& edges,
const nodeinfovector& nodes, doublevector& edgelogbetas, doublevector& logbetas,
const float lmf, const float wp, const float amf, double& totalbwscore) = 0;
virtual void EMBRerrorsignal(const ushortvector& alignstateids, const uintvector& alignoffsets,
const edgeinfowithscoresvector& edges, const nodeinfovector& nodes,
const doublevector& edgeweights, Microsoft::MSR::CNTK::Matrix<float>& dengammas) = 0;
/* guoye: end */
virtual void sMBRerrorsignal(const ushortvector& alignstateids, const uintvector& alignoffsets,
const edgeinfowithscoresvector& edges, const nodeinfovector& nodes,
const doublevector& logpps, const float amf, const doublevector& logEframescorrect,

View file

@ -226,7 +226,25 @@ __global__ void backwardlatticej(const size_t batchsize, const size_t startindex
logEframescorrect, logaccbetas);
}
}
/* guoye: start */
__global__ void backwardlatticejEMBR(const size_t batchsize, const size_t startindex, const vectorref<float> edgeacscores,
vectorref<msra::lattices::edgeinfowithscores> edges, vectorref<msra::lattices::nodeinfo> nodes,
vectorref<double> edgelogbetas, vectorref<double> logbetas,
float lmf, float wp, float amf)
{
const size_t tpb = blockDim.x * blockDim.y; // total #threads in a block
const size_t jinblock = threadIdx.x + threadIdx.y * blockDim.x;
size_t j = jinblock + blockIdx.x * tpb;
if (j < batchsize) // note: will cause issues if we ever use __syncthreads()
{
msra::lattices::latticefunctionskernels::backwardlatticejEMBR(j + startindex, edgeacscores,
edges, nodes, edgelogbetas, logbetas,
lmf, wp, amf);
}
}
/* guoye: end */
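// Reading of the launch above: each thread handles one edge j of the current backward
// batch (jinblock within the block plus blockIdx.x * threads-per-block, offset by
// startindex on the host side); the batches are launched from the end of the edge
// array backwards by backwardlatticeEMBR below.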
void latticefunctionsops::forwardbackwardlattice(const size_t *batchsizeforward, const size_t *batchsizebackward,
const size_t numlaunchforward, const size_t numlaunchbackward,
const size_t spalignunitid, const size_t silalignunitid,
@ -326,6 +344,46 @@ void latticefunctionsops::forwardbackwardlattice(const size_t *batchsizeforward,
}
}
/* guoye: start */
void latticefunctionsops::backwardlatticeEMBR( const size_t *batchsizebackward, const size_t numlaunchbackward,
const vectorref<float> &edgeacscores,
const vectorref<msra::lattices::edgeinfowithscores> &edges,
const vectorref<msra::lattices::nodeinfo> &nodes, vectorref<double> &edgelogbetas, vectorref<double> &logbetas,
const float lmf, const float wp, const float amf, double &totalbwscore) const
{
// initialize logbetas to LOGZERO
dim3 t(32, 8);
const size_t tpb = t.x * t.y;
dim3 b((unsigned int)((logbetas.size() + tpb - 1) / tpb));
// TODO: is this really efficient? One thread per value?
setvaluej << <b, t, 0, GetCurrentStream() >> >(logbetas, LOGZERO, logbetas.size());
checklaunch("setvaluej");
// set the final node's beta to probability 1 (0 in log)
double log1 = 0.0;
memcpy(logbetas.get(), nodes.size() - 1, &log1, 1);
// backward pass
size_t startindex = edges.size();
for (size_t i = 0; i < numlaunchbackward; i++)
{
dim3 b2((unsigned int)((batchsizebackward[i] + tpb - 1) / tpb));
backwardlatticejEMBR << <b2, t, 0, GetCurrentStream() >> >(batchsizebackward[i], startindex - batchsizebackward[i],
edgeacscores, edges, nodes, edgelogbetas, logbetas,
lmf, wp, amf);
checklaunch("edgealignment");
startindex -= batchsizebackward[i];
}
memcpy<double>(&totalbwscore, logbetas.get(), 0, 1);
}
/* guoye: end */
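// In equation form, the routine above appears to implement the usual backward (beta)
// recursion over the lattice, with node nodes.size()-1 as the final node and node 0 as
// the start (scores taken directly from backwardlatticejEMBR):
//   logbeta(final) = 0
//   edgelogbeta(j) = logbeta(E_j) + (l_j * lmf + wp + ac_j) / amf
//   logbeta(S_j)   = logadd(logbeta(S_j), edgelogbeta(j))   // accumulated via atomicLogAdd
//   totalbwscore   = logbeta(0)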
// -----------------------------------------------------------------------
// sMBRerrorsignal -- accumulate difference of logEframescorrect and logEframescorrecttotal into errorsignal
// -----------------------------------------------------------------------
@ -342,6 +400,22 @@ __global__ void sMBRerrorsignalj(const vectorref<unsigned short> alignstateids,
msra::lattices::latticefunctionskernels::sMBRerrorsignalj(j, alignstateids, alignoffsets, edges, nodes, logpps, amf, logEframescorrect, logEframescorrecttotal, errorsignal, errorsignalneg);
}
}
/* guoye: start */
__global__ void EMBRerrorsignalj(const vectorref<unsigned short> alignstateids, const vectorref<unsigned int> alignoffsets,
const vectorref<msra::lattices::edgeinfowithscores> edges, const vectorref<msra::lattices::nodeinfo> nodes,
vectorref<double> edgeweights,
matrixref<float> errorsignal)
{
const size_t shufflemode = 1; // [v-hansu] this gives about a 100% speed-up over shufflemode = 0 (no shuffle)
const size_t j = msra::lattices::latticefunctionskernels::shuffle(threadIdx.x, blockDim.x, threadIdx.y, blockDim.y, blockIdx.x, gridDim.x, shufflemode);
if (j < edges.size()) // note: will cause issues if we ever use __syncthreads()
{
msra::lattices::latticefunctionskernels::EMBRerrorsignalj(j, alignstateids, alignoffsets, edges, nodes, edgeweights, errorsignal);
}
}
/* guoye: end */
// -----------------------------------------------------------------------
// stateposteriors --accumulate a per-edge quantity into the states that the edge is aligned with
@ -433,6 +507,28 @@ void latticefunctionsops::sMBRerrorsignal(const vectorref<unsigned short> &align
#endif
}
/* guoye: start */
void latticefunctionsops::EMBRerrorsignal(const vectorref<unsigned short> &alignstateids, const vectorref<unsigned int> &alignoffsets,
const vectorref<msra::lattices::edgeinfowithscores> &edges, const vectorref<msra::lattices::nodeinfo> &nodes,
const vectorref<double> &edgeweights,
matrixref<float> &errorsignal) const
{
// Layout: each thread block takes 256 threads (32 x 8), and we have #edges/256 blocks.
// This limits us to about 16 million edges. If you need more, either use wider thread blocks or add a second grid dimension, and adjust the kernel as well.
const size_t numedges = edges.size();
dim3 t(32, 8);
const size_t tpb = t.x * t.y;
dim3 b((unsigned int)((numedges + tpb - 1) / tpb));
setvaluei << <dim3((((unsigned int)errorsignal.rows()) + 31) / 32), 32, 0, GetCurrentStream() >> >(errorsignal, 0);
checklaunch("setvaluei");
EMBRerrorsignalj << <b, t, 0, GetCurrentStream() >> >(alignstateids, alignoffsets, edges, nodes, edgeweights, errorsignal);
checklaunch("EMBRerrorsignal");
}
/* guoye: end */
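// Launch shape used above: t = 32 x 8 = 256 threads per block and b = ceil(numedges / 256)
// blocks; errorsignal is zeroed by setvaluei before the per-edge weights are scattered in.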
void latticefunctionsops::mmierrorsignal(const vectorref<unsigned short> &alignstateids, const vectorref<unsigned int> &alignoffsets,
const vectorref<msra::lattices::edgeinfowithscores> &edges, const vectorref<msra::lattices::nodeinfo> &nodes,
const vectorref<double> &logpps, matrixref<float> &errorsignal) const

View file

@ -53,6 +53,17 @@ protected:
vectorref<double>& logframescorrectedge, vectorref<double>& logEframescorrect, vectorref<double>& Eframescorrectbuf,
double& logEframescorrecttotal, double& totalfwscore) const;
/* guoye: start */
void backwardlatticeEMBR(const size_t *batchsizebackward, const size_t numlaunchbackward,
const vectorref<float> &edgeacscores,
const vectorref<msra::lattices::edgeinfowithscores> &edges,
const vectorref<msra::lattices::nodeinfo> &nodes, vectorref<double> &edgelogbetas, vectorref<double> &logbetas,
const float lmf, const float wp, const float amf, double &totalbwscore) const;
void EMBRerrorsignal(const vectorref<unsigned short> &alignstateids, const vectorref<unsigned int> &alignoffsets,
const vectorref<msra::lattices::edgeinfowithscores> &edges, const vectorref<msra::lattices::nodeinfo> &nodes,
const vectorref<double> &edgeweights,
matrixref<float> &errorsignal) const;
/* guoye: end */
void sMBRerrorsignal(const vectorref<unsigned short>& alignstateids, const vectorref<unsigned int>& alignoffsets,
const vectorref<msra::lattices::edgeinfowithscores>& edges, const vectorref<msra::lattices::nodeinfo>& nodes,
const vectorref<double>& logpps, const float amf, const vectorref<double>& logEframescorrect, const double logEframescorrecttotal,

View file

@ -302,6 +302,25 @@ struct latticefunctionskernels
// note: critically, ^^ this comparison must compare the bits ('int') instead of the converted float values, since comparing floats would fail for NaNs (NaN != NaN is always true)
return bitsasfloat(old);
}
/* guoye: start */
template <typename FLOAT> // adapted from [http://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#ixzz32EuzZjxV]
static __device__ FLOAT atomicAdd(FLOAT *address, FLOAT val) // direct adaptation from NVidia source code
{
typedef decltype(floatasbits(val)) bitstype;
bitstype *address_as_ull = (bitstype *)address;
bitstype old = *address_as_ull, assumed;
do
{
assumed = old;
FLOAT sum = bitsasfloat(assumed);
sum = sum + val;
old = atomicCAS(address_as_ull, assumed, floatasbits(sum));
} while (assumed != old);
// note: critically, ^^ this comparison must compare the bits ('int') instead of the converted float values, since comparing floats would fail for NaNs (NaN != NaN is always true)
return bitsasfloat(old);
}
/* guoye: end */
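// Reading of the retry loop above: snapshot the current bit pattern of *address, compute
// sum = bitsasfloat(assumed) + val, and publish it with atomicCAS only if *address still
// holds the snapshot; otherwise retry with the freshly observed value. Comparing the raw
// bits rather than the converted float values keeps the loop correct in the presence of
// NaNs, and the CAS formulation avoids relying on a native double-precision atomicAdd,
// which older GPU architectures do not provide.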
#else // this code does not work because (assumed != old) will not compare correctly in case of NaNs
// same pattern as atomicAdd(), but performing the log-add operation instead
template <typename FLOAT>
@ -889,6 +908,66 @@ struct latticefunctionskernels
logEframescorrect[j] = logEframescorrectj;
}
/* guoye: start */
template <typename edgeinforvector, typename nodeinfovector, typename floatvector, typename doublevector>
static inline __device__ void backwardlatticejEMBR(size_t j, const floatvector &edgeacscores,
const edgeinforvector &edges, const nodeinfovector &nodes, doublevector & edgelogbetas,
doublevector &logbetas, float lmf, float wp, float amf)
{
// edge info
const edgeinfowithscores &e = edges[j];
double edgescore = (e.l * lmf + wp + edgeacscores[j]) / amf;
// zhaorui to deal with the abnormal score for sent start.
if (e.l < -200.0f)
edgescore = (0.0 * lmf + wp + edgeacscores[j]) / amf;
#ifdef FORBID_INVALID_SIL_PATHS
// original mode
const bool forbidinvalidsilpath = (logbetas.size() > nodes.size()); // we prune sil to sil path if alphabetablowup != 1
const bool isaddedsil = forbidinvalidsilpath && (e.unused == 1); // HACK: 'unused' indicates artificially added sil/sp edge
if (!isaddedsil) // original mode
#endif
{
const size_t S = e.S;
const size_t E = e.E;
// backward pass
const double inscore = logbetas[E];
const double pathscore = inscore + edgescore;
edgelogbetas[j] = pathscore;
atomicLogAdd(&logbetas[S], pathscore);
}
#ifdef FORBID_INVALID_SIL_PATHS
// silence edge or second speech edge
if ((isaddedsil && e.E != nodes.size() - 1) || (forbidinvalidsilpath && e.S != 0))
{
const size_t S = (size_t)(!isaddedsil ? e.S + nodes.size() : e.S); // second speech edge comes from special 'silence state' node
const size_t E = (size_t)(isaddedsil ? e.E + nodes.size() : e.E); // silence edge goes into special 'silence state' node
// remaining lines duplicate the block above, but operate on the remapped 'silence state' node indices
// backward pass
const double inscore = logbetas[E];
const double pathscore = inscore + edgescore;
edgelogbetas[j] = pathscore;
atomicLogAdd(&logbetas[S], pathscore);
}
#else
nodes;
#endif
}
/* guoye: end */
template <typename ushortvector, typename uintvector, typename edgeinfowithscoresvector, typename nodeinfovector, typename doublevector, typename matrix>
static inline __device__ void sMBRerrorsignalj(size_t j, const ushortvector &alignstateids, const uintvector &alignoffsets,
const edgeinfowithscoresvector &edges,
@ -930,6 +1009,63 @@ struct latticefunctionskernels
}
}
/* guoye: start */
template <typename ushortvector, typename uintvector, typename edgeinfowithscoresvector, typename nodeinfovector, typename doublevector, typename matrix>
static inline __device__ void EMBRerrorsignalj(size_t j, const ushortvector &alignstateids, const uintvector &alignoffsets,
const edgeinfowithscoresvector &edges,
const nodeinfovector &nodes, const doublevector &edgeweights,
matrix &errorsignal)
{
size_t ts = nodes[edges[j].S].t;
size_t te = nodes[edges[j].E].t;
if (ts != te)
{
const float weight = (float)(edgeweights[j]);
size_t offset = alignoffsets[j];
/*
if (weight <= 1)
{
*/
// size_t k = 0;
for (size_t t = ts; t < te; t++)
{
const size_t s = (size_t)alignstateids[t - ts + offset];
/*
errorsignal(0, k) = float(t);
k = k + 1;
errorsignal(0, k) = float(s);
k = k + 1;
*/
// atomicLogAdd(&errorsignal(s, t), weight);
// use atomic function for lock the value
atomicAdd(&errorsignal(s, t), weight);
//errorsignal(s, t) = errorsignal(s, t) + weight;
// errorsignal(s, t) = errorsignal(s, t) + (float)(ts);
}
// }
/* guoye: start */
/*
if (weight > 1)
{
for (size_t t = 63; t < 71; t++)
errorsignal(7, t) = errorsignal(7, t) + ts;
errorsignal(8, 71) = errorsignal(8, 71) + te;
}
*/
/* guoye: end */
}
}
/* guoye: end */
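// In effect, for every edge j whose alignment spans frames [ts, te), the kernel above adds
// the edge's weight into the aligned senone of each covered frame:
//   for t in [ts, te):  errorsignal(alignstateids[offset + t - ts], t) += edgeweights[j]   // via atomicAdd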
// accumulate a per-edge quantity into the states that the edge is aligned with
// Use this for MMI passing the edge posteriors logpps[] as logq, or for sMBR passing logEframescorrect[].
// j=edge index, alignment in (alignstateids, alignoffsets)

View file

@ -26,6 +26,10 @@
#include "ScriptableObjects.h"
#include "HTKMLFReader.h"
#include "TimerUtility.h"
/* guoye: start */
#include "fileutil.h"
#include <string>
/* guoye: end */
#ifdef LEAKDETECT
#include <vld.h> // for memory leak detection
#endif
@ -99,6 +103,33 @@ void HTKMLFReader<ElemType>::InitFromConfig(const ConfigRecordType& readerConfig
}
}
/* guoye: start */
void readwordidmap(const std::wstring &pathname, std::unordered_map<std::string, int>& wordidmap, int start_id)
{
std::unordered_map<std::string, int>::iterator mp_itr;
auto_file_ptr f(fopenOrDie(pathname, L"rbS"));
fprintf(stderr, "readwordidmap: reading %ls \n", pathname.c_str());
char buf[1024];
char word[1024];
int dumid;
while (!feof(f))
{
fgetline(f, buf);
if (sscanf(buf, "%s %d", word, &dumid) != 2)
{
fprintf(stderr, "readwordidmap: reaching the end of line, with content = %s", buf);
break;
}
if (wordidmap.find(std::string(word)) == wordidmap.end())
{
wordidmap.insert(pair<std::string, int>(string(word),start_id++));
}
}
fclose(f);
}
/* guoye: end */
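// readwordidmap expects one "WORD <integer>" pair per line and stops at the first line that
// does not parse; note that the integer column is only used to validate the line -- ids are
// assigned sequentially from start_id for words not already in the map. A hypothetical
// three-line input (contents made up for illustration; since parseentry upper-cases
// non-special words before lookup, the map presumably lists words in uppercase):
//   AARDVARK 0
//   ABACUS 1
//   ABANDON 2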
// Load all input and output data.
// Note that the terms features imply be real-valued quantities and
// labels imply categorical quantities, irrespective of whether they
@ -116,6 +147,9 @@ void HTKMLFReader<ElemType>::PrepareForTrainingOrTesting(const ConfigRecordType&
vector<vector<wstring>> infilesmulti;
size_t numFiles;
wstring unigrampath(L"");
/* guoye: start */
wstring wordidmappath(L"");
/* guoye: end */
size_t randomize = randomizeAuto;
size_t iFeat, iLabel;
@ -443,19 +477,146 @@ void HTKMLFReader<ElemType>::PrepareForTrainingOrTesting(const ConfigRecordType&
if (readerConfig.Exists(L"unigram"))
unigrampath = (const wstring&) readerConfig(L"unigram");
/*guoye: start */
if (readerConfig.Exists(L"wordidmap"))
wordidmappath = (const wstring&)readerConfig(L"wordidmap");
/*guoye: end */
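// A hypothetical reader-config fragment that enables this path (the path is made up;
// "wordidmap" is the key read above, and "unigram" remains the alternative used for MMI):
//   reader = [
//       ...
//       wordidmap = $DataDir$/wordidmap.txt
//   ]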
// load a unigram if needed (this is used for MMI training)
msra::lm::CSymbolSet unigramsymbols;
/* guoye: start */
std::set<int> specialwordids;
std::vector<string> specialwords;
std::unordered_map<std::string, int> wordidmap;
std::unordered_map<std::string, int>::iterator wordidmap_itr;
/* guoye: end */
std::unique_ptr<msra::lm::CMGramLM> unigram;
size_t silencewordid = SIZE_MAX;
size_t startwordid = SIZE_MAX;
size_t endwordid = SIZE_MAX;
/* guoye: debug */
if (unigrampath != L"")
// if(true)
{
// RuntimeError("should not come here.");
/* guoye: start (this code order must be consistent with dbn.exe in main.cpp) */
unigram.reset(new msra::lm::CMGramLM());
unigramsymbols["!NULL"];
unigramsymbols["<s>"];
unigramsymbols["</s>"];
unigramsymbols["!sent_start"];
unigramsymbols["!sent_end"];
unigramsymbols["!silence"];
/* guoye: end */
unigram->read(unigrampath, unigramsymbols, false /*filterVocabulary--false will build the symbol map*/, 1 /*maxM--unigram only*/);
/* guoye: end */
silencewordid = unigramsymbols["!silence"]; // give this an id (even if not in the LM vocabulary)
startwordid = unigramsymbols["<s>"];
endwordid = unigramsymbols["</s>"];
/* guoye: start */
specialwordids.clear();
specialwordids.insert(unigramsymbols["<s>"]);
specialwordids.insert(unigramsymbols["</s>"]);
specialwordids.insert(unigramsymbols["!NULL"]);
specialwordids.insert(unigramsymbols["!sent_start"]);
specialwordids.insert(unigramsymbols["!sent_end"]);
specialwordids.insert(unigramsymbols["!silence"]);
specialwordids.insert(unigramsymbols["[/CNON]"]);
specialwordids.insert(unigramsymbols["[/CSPN]"]);
specialwordids.insert(unigramsymbols["[/NPS]"]);
specialwordids.insert(unigramsymbols["[CNON/]"]);
specialwordids.insert(unigramsymbols["[CNON]"]);
specialwordids.insert(unigramsymbols["[CSPN]"]);
specialwordids.insert(unigramsymbols["[FILL/]"]);
specialwordids.insert(unigramsymbols["[NON/]"]);
specialwordids.insert(unigramsymbols["[NONNATIVE/]"]);
specialwordids.insert(unigramsymbols["[NPS]"]);
specialwordids.insert(unigramsymbols["[SB/]"]);
specialwordids.insert(unigramsymbols["[SBP/]"]);
specialwordids.insert(unigramsymbols["[SN/]"]);
specialwordids.insert(unigramsymbols["[SPN/]"]);
specialwordids.insert(unigramsymbols["[UNKNOWN/]"]);
specialwordids.insert(unigramsymbols[".]"]);
// this is to exclude the unknown words introduced into the lattice when merging the numerator lattice into the denominator lattice.
specialwordids.insert(0xfffff);
/* guoye: end */
}
else if (wordidmappath != L"")
// if(true)
{
wordidmap.insert(pair<std::string, int>("!NULL", 0));
wordidmap.insert(pair<std::string, int>("<s>", 1));
wordidmap.insert(pair<std::string, int>("</s>", 2));
wordidmap.insert(pair<std::string, int>("!sent_start", 3));
wordidmap.insert(pair<std::string, int>("!sent_end", 4));
wordidmap.insert(pair<std::string, int>("!silence", 5));
silencewordid = 5; // give this an id (even if not in the LM vocabulary)
startwordid = 1;
endwordid = 2;
int start_id = 6;
readwordidmap(wordidmappath, wordidmap, start_id);
/* guoye: start */
specialwordids.clear();
specialwords.clear();
specialwords.push_back("<s>");
specialwords.push_back("</s>");
specialwords.push_back("!NULL");
specialwords.push_back("!sent_start");
specialwords.push_back("!sent_end");
specialwords.push_back("!silence");
specialwords.push_back("[/CNON]");
specialwords.push_back("[/CSPN]");
specialwords.push_back("[/NPS]");
specialwords.push_back("[CNON/]");
specialwords.push_back("[CNON]");
specialwords.push_back("[CSPN]");
specialwords.push_back("[FILL/]");
specialwords.push_back("[NON/]");
specialwords.push_back("[NONNATIVE/]");
specialwords.push_back("[NPS]");
specialwords.push_back("[SB/]");
specialwords.push_back("[SBP/]");
specialwords.push_back("[SN/]");
specialwords.push_back("[SPN/]");
specialwords.push_back("[UNKNOWN/]");
specialwords.push_back(".]");
for (size_t i = 0; i < specialwords.size(); i++)
{
wordidmap_itr = wordidmap.find(specialwords[i]);
specialwordids.insert((wordidmap_itr == wordidmap.end()) ? -1 : wordidmap_itr->second);
}
// this is to exclude the unknown words introduced into the lattice when merging the numerator lattice into the denominator lattice.
specialwordids.insert(0xfffff);
/* guoye: end */
}
if (!unigram && latticetocs.second.size() > 0)
@ -497,19 +658,41 @@ void HTKMLFReader<ElemType>::PrepareForTrainingOrTesting(const ConfigRecordType&
double htktimetoframe = 100000.0; // default is 10ms
// std::vector<msra::asr::htkmlfreader<msra::asr::htkmlfentry,msra::lattices::lattice::htkmlfwordsequence>> labelsmulti;
std::vector<std::map<std::wstring, std::vector<msra::asr::htkmlfentry>>> labelsmulti;
/* guoye: start */
// std::vector<std::map<std::wstring, std::vector<msra::asr::htkmlfentry>>> labelsmulti;
std::vector<std::map<std::wstring, std::pair<std::vector<msra::asr::htkmlfentry>, std::vector<unsigned int>>>> labelsmulti;
// std::vector<std::map<std::wstring, msra::lattices::lattice::htkmlfwordsequence>> wordlabelsmulti;
/* debug to clean wordidmap */
// wordidmap.clear();
/* guoye: end */
// std::vector<std::wstring> pagepath;
foreach_index (i, mlfpathsmulti)
{
/* guoye: start */
/*
const msra::lm::CSymbolSet* wordmap = unigram ? &unigramsymbols : NULL;
msra::asr::htkmlfreader<msra::asr::htkmlfentry, msra::lattices::lattice::htkmlfwordsequence>
labels(mlfpathsmulti[i], restrictmlftokeys, statelistpaths[i], wordmap, (map<string, size_t>*) NULL, htktimetoframe); // label MLF
*/
msra::asr::htkmlfreader<msra::asr::htkmlfentry, msra::lattices::lattice::htkmlfwordsequence>
// msra::asr::htkmlfreader<msra::asr::htkmlfentry>
labels(mlfpathsmulti[i], restrictmlftokeys, statelistpaths[i], wordidmap, htktimetoframe); // label MLF
// labels(mlfpathsmulti[i], restrictmlftokeys, statelistpaths[i], wordidmap, (map<string, size_t>*) NULL, htktimetoframe); // label MLF
/* guoye: end */
// get the temp file name for the page file
// Make sure 'msra::asr::htkmlfreader' type has a move constructor
static_assert(std::is_move_constructible<msra::asr::htkmlfreader<msra::asr::htkmlfentry, msra::lattices::lattice::htkmlfwordsequence>>::value,
"Type 'msra::asr::htkmlfreader' should be move constructible!");
/* guoye: start */
// map<wstring, msra::lattices::lattice::htkmlfwordsequence> wordlabels = labels.get_wordlabels();
// guoye debug purpose
// fprintf(stderr, "debug to set wordlabels to empty");
// map<wstring, msra::lattices::lattice::htkmlfwordsequence> wordlabels;
// wordlabelsmulti.push_back(std::move(wordlabels));
/* guoye: end */
labelsmulti.push_back(std::move(labels));
}
@ -522,7 +705,11 @@ void HTKMLFReader<ElemType>::PrepareForTrainingOrTesting(const ConfigRecordType&
// now get the frame source. This has better randomization and doesn't create temp files
bool useMersenneTwisterRand = readerConfig(L"useMersenneTwisterRand", false);
m_frameSource.reset(new msra::dbn::minibatchutterancesourcemulti(useMersenneTwisterRand, infilesmulti, labelsmulti, m_featDims, m_labelDims,
/* guoye: start */
// m_frameSource.reset(new msra::dbn::minibatchutterancesourcemulti(useMersenneTwisterRand, infilesmulti, labelsmulti, m_featDims, m_labelDims,
// m_frameSource.reset(new msra::dbn::minibatchutterancesourcemulti(useMersenneTwisterRand, infilesmulti, labelsmulti, wordlabelsmulti, specialwordids, m_featDims, m_labelDims,
m_frameSource.reset(new msra::dbn::minibatchutterancesourcemulti(useMersenneTwisterRand, infilesmulti, labelsmulti, specialwordids, m_featDims, m_labelDims,
/* guoye: end */
numContextLeft, numContextRight, randomize,
*m_lattices, m_latticeMap, m_frameMode,
m_expandToUtt, m_maxUtteranceLength, m_truncated));
@ -756,6 +943,10 @@ void HTKMLFReader<ElemType>::StartDistributedMinibatchLoop(size_t requestedMBSiz
// for the multi-utterance process for lattice and phone boundary
m_latticeBufferMultiUtt.assign(m_numSeqsPerMB, nullptr);
m_labelsIDBufferMultiUtt.resize(m_numSeqsPerMB);
/* guoye: start */
m_wlabelsIDBufferMultiUtt.resize(m_numSeqsPerMB);
m_nwsBufferMultiUtt.resize(m_numSeqsPerMB);
/* guoye: end */
m_phoneboundaryIDBufferMultiUtt.resize(m_numSeqsPerMB);
if (m_frameMode && (m_numSeqsPerMB > 1))
@ -894,11 +1085,17 @@ void HTKMLFReader<ElemType>::StartMinibatchLoopToWrite(size_t mbSize, size_t /*e
template <class ElemType>
bool HTKMLFReader<ElemType>::GetMinibatch4SE(std::vector<shared_ptr<const msra::dbn::latticepair>>& latticeinput,
vector<size_t>& uids, vector<size_t>& boundaries, vector<size_t>& extrauttmap)
/* guoye: start */
vector<size_t>& uids, vector<size_t>& wids, vector<short>& nws, vector<size_t>& boundaries, vector<size_t>& extrauttmap)
// vector<size_t>& uids, vector<size_t>& boundaries, vector<size_t>& extrauttmap)
/* guoye: end */
{
if (m_trainOrTest)
{
return GetMinibatch4SEToTrainOrTest(latticeinput, uids, boundaries, extrauttmap);
/* guoye: start */
// return GetMinibatch4SEToTrainOrTest(latticeinput, uids, boundaries, extrauttmap);
return GetMinibatch4SEToTrainOrTest(latticeinput, uids, wids, nws, boundaries, extrauttmap);
/* guoye: end */
}
else
{
@ -907,16 +1104,31 @@ bool HTKMLFReader<ElemType>::GetMinibatch4SE(std::vector<shared_ptr<const msra::
}
template <class ElemType>
bool HTKMLFReader<ElemType>::GetMinibatch4SEToTrainOrTest(std::vector<shared_ptr<const msra::dbn::latticepair>>& latticeinput,
std::vector<size_t>& uids, std::vector<size_t>& boundaries, std::vector<size_t>& extrauttmap)
/* guoye: start */
std::vector<size_t>& uids, std::vector<size_t>& wids, std::vector<short>& nws, std::vector<size_t>& boundaries, std::vector<size_t>& extrauttmap)
// std::vector<size_t>& uids, std::vector<size_t>& boundaries, std::vector<size_t>& extrauttmap)
/* guoye: end */
{
latticeinput.clear();
uids.clear();
/* guoye: start */
wids.clear();
nws.clear();
/* guoye: end */
boundaries.clear();
extrauttmap.clear();
for (size_t i = 0; i < m_extraSeqsPerMB.size(); i++)
{
latticeinput.push_back(m_extraLatticeBufferMultiUtt[i]);
uids.insert(uids.end(), m_extraLabelsIDBufferMultiUtt[i].begin(), m_extraLabelsIDBufferMultiUtt[i].end());
/* guoye: start */
wids.insert(wids.end(), m_extraWLabelsIDBufferMultiUtt[i].begin(), m_extraWLabelsIDBufferMultiUtt[i].end());
nws.insert(nws.end(), m_extraNWsBufferMultiUtt[i].begin(), m_extraNWsBufferMultiUtt[i].end());
/* guoye: end */
boundaries.insert(boundaries.end(), m_extraPhoneboundaryIDBufferMultiUtt[i].begin(), m_extraPhoneboundaryIDBufferMultiUtt[i].end());
}
@ -984,6 +1196,11 @@ bool HTKMLFReader<ElemType>::GetMinibatchToTrainOrTest(StreamMinibatchInputs& ma
m_extraLabelsIDBufferMultiUtt.clear();
m_extraPhoneboundaryIDBufferMultiUtt.clear();
m_extraSeqsPerMB.clear();
/* guoye: start */
m_extraWLabelsIDBufferMultiUtt.clear();
m_extraNWsBufferMultiUtt.clear();
/* guoye: end */
if (m_noData && m_numFramesToProcess[0] == 0) // no data left for the first channel of this minibatch,
{
return false;
@ -1064,6 +1281,11 @@ bool HTKMLFReader<ElemType>::GetMinibatchToTrainOrTest(StreamMinibatchInputs& ma
{
m_extraLatticeBufferMultiUtt.push_back(m_latticeBufferMultiUtt[i]);
m_extraLabelsIDBufferMultiUtt.push_back(m_labelsIDBufferMultiUtt[i]);
/* guoye: start */
m_extraWLabelsIDBufferMultiUtt.push_back(m_wlabelsIDBufferMultiUtt[i]);
m_extraNWsBufferMultiUtt.push_back(m_nwsBufferMultiUtt[i]);
/* guoye: end */
m_extraPhoneboundaryIDBufferMultiUtt.push_back(m_phoneboundaryIDBufferMultiUtt[i]);
}
}
@ -1106,6 +1328,12 @@ bool HTKMLFReader<ElemType>::GetMinibatchToTrainOrTest(StreamMinibatchInputs& ma
{
m_extraLatticeBufferMultiUtt.push_back(m_latticeBufferMultiUtt[src]);
m_extraLabelsIDBufferMultiUtt.push_back(m_labelsIDBufferMultiUtt[src]);
/* guoye: start */
m_extraWLabelsIDBufferMultiUtt.push_back(m_wlabelsIDBufferMultiUtt[src]);
m_extraNWsBufferMultiUtt.push_back(m_nwsBufferMultiUtt[src]);
/* guoye: end */
m_extraPhoneboundaryIDBufferMultiUtt.push_back(m_phoneboundaryIDBufferMultiUtt[src]);
}
@ -1811,6 +2039,15 @@ bool HTKMLFReader<ElemType>::ReNewBufferForMultiIO(size_t i)
m_phoneboundaryIDBufferMultiUtt[i] = m_mbiter->bounds();
m_labelsIDBufferMultiUtt[i].clear();
m_labelsIDBufferMultiUtt[i] = m_mbiter->labels();
/* guoye: start */
m_wlabelsIDBufferMultiUtt[i].clear();
m_wlabelsIDBufferMultiUtt[i] = m_mbiter->wlabels();
m_nwsBufferMultiUtt[i].clear();
m_nwsBufferMultiUtt[i] = m_mbiter->nwords();
/* guoye: end */
}
m_processedFrame[i] = 0;
@ -2031,8 +2268,7 @@ unique_ptr<CUDAPageLockedMemAllocator>& HTKMLFReader<ElemType>::GetCUDAAllocator
if (m_cudaAllocator == nullptr)
{
m_cudaAllocator.reset(new CUDAPageLockedMemAllocator(deviceID));
}
}
return m_cudaAllocator;
}
@ -2049,6 +2285,7 @@ std::shared_ptr<ElemType> HTKMLFReader<ElemType>::AllocateIntermediateBuffer(int
this->GetCUDAAllocator(deviceID)->Free((char*) p);
});
}
else
{
return std::shared_ptr<ElemType>(new ElemType[numElements],
@ -2059,6 +2296,9 @@ std::shared_ptr<ElemType> HTKMLFReader<ElemType>::AllocateIntermediateBuffer(int
}
}
template class HTKMLFReader<float>;
template class HTKMLFReader<double>;
} } }

View file

@ -77,6 +77,16 @@ private:
std::vector<std::vector<size_t>> m_phoneboundaryIDBufferMultiUtt;
std::vector<shared_ptr<const msra::dbn::latticepair>> m_extraLatticeBufferMultiUtt;
std::vector<std::vector<size_t>> m_extraLabelsIDBufferMultiUtt;
/* guoye: start */
/* word labels */
std::vector<std::vector<size_t>> m_wlabelsIDBufferMultiUtt;
std::vector<std::vector<size_t>> m_extraWLabelsIDBufferMultiUtt;
std::vector<std::vector<short>> m_nwsBufferMultiUtt;
std::vector<std::vector<short>> m_extraNWsBufferMultiUtt;
/* guoye: end */
std::vector<std::vector<size_t>> m_extraPhoneboundaryIDBufferMultiUtt;
// hmm
@ -109,7 +119,10 @@ private:
void PrepareForWriting(const ConfigRecordType& config);
bool GetMinibatchToTrainOrTest(StreamMinibatchInputs& matrices);
bool GetMinibatch4SEToTrainOrTest(std::vector<shared_ptr<const msra::dbn::latticepair>>& latticeinput, vector<size_t>& uids, vector<size_t>& boundaries, std::vector<size_t>& extrauttmap);
/* guoye: start */
// bool GetMinibatch4SEToTrainOrTest(std::vector<shared_ptr<const msra::dbn::latticepair>>& latticeinput, vector<size_t>& uids, vector<size_t>& boundaries, std::vector<size_t>& extrauttmap);
bool GetMinibatch4SEToTrainOrTest(std::vector<shared_ptr<const msra::dbn::latticepair>>& latticeinput, vector<size_t>& uids, vector<size_t>& wids, vector<short>& nws, vector<size_t>& boundaries, std::vector<size_t>& extrauttmap);
/* guoye: end */
void fillOneUttDataforParallelmode(StreamMinibatchInputs& matrices, size_t startFr, size_t framenum, size_t channelIndex, size_t sourceChannelIndex); // TODO: PascalCase()
bool GetMinibatchToWrite(StreamMinibatchInputs& matrices);
@ -189,7 +202,10 @@ public:
virtual const std::map<LabelIdType, LabelType>& GetLabelMapping(const std::wstring& sectionName);
virtual void SetLabelMapping(const std::wstring& sectionName, const std::map<LabelIdType, LabelType>& labelMapping);
virtual bool GetData(const std::wstring& sectionName, size_t numRecords, void* data, size_t& dataBufferSize, size_t recordStart = 0);
virtual bool GetMinibatch4SE(std::vector<shared_ptr<const msra::dbn::latticepair>>& latticeinput, vector<size_t>& uids, vector<size_t>& boundaries, vector<size_t>& extrauttmap);
/* guoye: start */
// virtual bool GetMinibatch4SE(std::vector<shared_ptr<const msra::dbn::latticepair>>& latticeinput, vector<size_t>& uids, vector<size_t>& boundaries, vector<size_t>& extrauttmap);
virtual bool GetMinibatch4SE(std::vector<shared_ptr<const msra::dbn::latticepair>>& latticeinput, vector<size_t>& uids, vector<size_t>& wids, vector<short>& nws, vector<size_t>& boundaries, vector<size_t>& extrauttmap);
/* guoye: end */
virtual bool GetHmmData(msra::asr::simplesenonehmm* hmm);
virtual bool DataEnd();

View file

@ -71,6 +71,7 @@
<TreatWarningAsError>true</TreatWarningAsError>
<AdditionalIncludeDirectories Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">..\..\common\include;..\..\Math</AdditionalIncludeDirectories>
<AdditionalIncludeDirectories Condition="'$(Configuration)|$(Platform)'=='Debug_CpuOnly|x64'">..\..\common\include;..\..\Math</AdditionalIncludeDirectories>
<AdditionalOptions Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">/bigobj %(AdditionalOptions)</AdditionalOptions>
</ClCompile>
<Link>
<SubSystem>Console</SubSystem>

View file

@ -864,14 +864,23 @@ public:
setdata(ts, te, uid);
}
};
/* guoye: start */
template <class ENTRY, class WORDSEQUENCE>
class htkmlfreader : public map<wstring, vector<ENTRY>> // [key][i] the data
// vector<ENTRY> stores the state-level label, vector<size_t> stores the word-level label
// template <class ENTRY>
class htkmlfreader : public map<wstring, std::pair<vector<ENTRY>, vector<unsigned int>>> // [key][i] the data
// class htkmlfreader : public map<wstring, vector<ENTRY>> // [key][i] the data
/* guoye: end */
{
wstring curpath; // for error messages
unordered_map<std::string, size_t> statelistmap; // for state <=> index
/* guoye: start */
map<wstring, WORDSEQUENCE> wordsequences; // [key] word sequences (if we are building word entries as well, for MMI)
/* guoye: end */
std::unordered_map<std::string, size_t> symmap;
void strtok(char* s, const char* delim, vector<char*>& toks)
{
toks.resize(0);
@ -900,10 +909,14 @@ class htkmlfreader : public map<wstring, vector<ENTRY>> // [key][i] the data
return lines;
}
template <typename WORDSYMBOLTABLE, typename UNITSYMBOLTABLE>
// template <typename WORDSYMBOLTABLE, typename UNITSYMBOLTABLE>
template <typename WORDSYMBOLTABLE>
void parseentry(const vector<std::string>& lines, size_t line, const set<wstring>& restricttokeys,
const WORDSYMBOLTABLE* wordmap, const UNITSYMBOLTABLE* unitmap,
vector<typename WORDSEQUENCE::word>& wordseqbuffer, vector<typename WORDSEQUENCE::aligninfo>& alignseqbuffer,
/* guoye: start */
const WORDSYMBOLTABLE* wordmap, /* const UNITSYMBOLTABLE* unitmap, */
// vector<typename WORDSEQUENCE::word>& wordseqbuffer, vector<typename WORDSEQUENCE::aligninfo>& alignseqbuffer,
/* guoye: end */
const double htkTimeToFrame)
{
size_t idx = 0;
@ -936,13 +949,25 @@ class htkmlfreader : public map<wstring, vector<ENTRY>> // [key][i] the data
// don't parse unused entries (this is supposed to be used for very small debugging setups with huge MLFs)
if (!restricttokeys.empty() && restricttokeys.find(key) == restricttokeys.end())
return;
/* guoye: start */
// vector<ENTRY>& entries = (*this)[key]; // this creates a new entry
vector<ENTRY>& entries = (*this)[key]; // this creates a new entry
vector<ENTRY>& entries = (*this)[key].first; // this creates a new entry
if (!entries.empty())
malformed(msra::strfun::strprintf("duplicate entry '%ls'", key.c_str()));
/* guoye: start */
// malformed(msra::strfun::strprintf("duplicate entry '%ls'", key.c_str()));
// do not want to die immediately
fprintf(stderr,
"Warning: duplicate entry: %ls \n",
key.c_str());
/* guoye: end */
entries.resize(e - s);
wordseqbuffer.resize(0);
alignseqbuffer.resize(0);
// wordseqbuffer.resize(0);
// alignseqbuffer.resize(0);
vector<size_t>& wordids = (*this)[key].second;
wordids.resize(0);
/* guoye: end */
vector<char*> toks;
for (size_t i = s; i < e; i++)
{
@ -957,13 +982,100 @@ class htkmlfreader : public map<wstring, vector<ENTRY>> // [key][i] the data
{
if (toks.size() > 6 /*word entry are in this column*/)
{
// convert the word to uppercase (special tokens are left as-is)
if (strcmp(toks[6], "<s>") != 0
&& strcmp(toks[6], "</s>") != 0
&& strcmp(toks[6], "!sent_start") != 0
&& strcmp(toks[6], "!sent_end") != 0
&& strcmp(toks[6], "!silence") != 0)
{
for(size_t j = 0; j < strlen(toks[6]); j++)
{
if(toks[6][j] >= 'a' && toks[6][j] <= 'z')
{
toks[6][j] = toks[6][j] + 'A' - 'a';
}
}
}
const char* w = toks[6]; // the word name
int wid = (*wordmap)[w]; // map to word id --may be -1 for unseen words in the transcript (word list typically comes from a test LM)
size_t wordindex = (wid == -1) ? WORDSEQUENCE::word::unknownwordindex : (size_t) wid;
wordseqbuffer.push_back(typename WORDSEQUENCE::word(wordindex, entries[i - s].firstframe, alignseqbuffer.size()));
/* guoye: start */
// For some alignment MLFs the sentence start and end are both represented by <s>; change the sentence-end <s> to </s>
if (i > s && strcmp(w, "<s>") == 0)
{
w = "</s>";
}
/* guoye: end */
/*guoye: start */
/* skip the words that are not used in WER computation */
/* ugly hard code, will improve later */
if (strcmp(w, "<s>") != 0
&& strcmp(w, "</s>") != 0
&& strcmp(w, "!NULL") != 0
&& strcmp(w, "!sent_start") != 0
&& strcmp(w, "!sent_end") != 0
&& strcmp(w, "!silence") != 0
&& strcmp(w, "[/CNON]") != 0
&& strcmp(w, "[/CSPN]") != 0
&& strcmp(w, "[/NPS]") != 0
&& strcmp(w, "[CNON/]") != 0
&& strcmp(w, "[CNON]") != 0
&& strcmp(w, "[CSPN]") != 0
&& strcmp(w, "[FILL/]") != 0
&& strcmp(w, "[NON/]") != 0
&& strcmp(w, "[NONNATIVE/]") != 0
&& strcmp(w, "[NPS]") != 0
&& strcmp(w, "[SB/]") != 0
&& strcmp(w, "[SBP/]") != 0
&& strcmp(w, "[SN/]") != 0
&& strcmp(w, "[SPN/]") != 0
&& strcmp(w, "[UNKNOWN/]") != 0
&& strcmp(w, ".]") != 0
)
{
int wid = (*wordmap)[w]; // map to word id --may be -1 for unseen words in the transcript (word list typically comes from a test LM)
/* guoye: start */
// size_t wordindex = (wid == -1) ? WORDSEQUENCE::word::unknownwordindex : (size_t)wid;
// wordseqbuffer.push_back(typename WORDSEQUENCE::word(wordindex, entries[i - s].firstframe, alignseqbuffer.size()));
static const unsigned int unknownwordindex = 0xfffff;
// TNed the word, to try one more time if there is an OOV
/*
if (wid == -1)
{
// remove / \ * _ - to see if a match could be found
char tnw[200];
size_t i = 0, j = 0;
while (w[i] != '\0')
{
if (w[i] != '\\' && w[i] != '/' && w[i] != '*' && w[i] != '-' && w[i] != '_')
{
tnw[j] = w[i]; j++;
}
i++;
}
tnw[j] = '\0';
wid = (*wordmap)[tnw];
fprintf(stderr,
"Warning: parseentry: wid = %d, new wid = %d, w = %s, tnw = %s \n",
-1, wid, w, tnw);
}
*/
size_t wordindex = (wid == -1) ? unknownwordindex : (size_t)wid;
wordids.push_back(wordindex);
/* guoye: end */
}
/*guoye: end */
}
/* guoye: start */
/*
if (unitmap)
{
if (toks.size() > 4)
{
const char* u = toks[4]; // the triphone name
@ -971,41 +1083,346 @@ class htkmlfreader : public map<wstring, vector<ENTRY>> // [key][i] the data
if (iter == unitmap->end())
RuntimeError("parseentry: unknown unit %s in utterance %ls", u, key.c_str());
const size_t uid = iter->second;
alignseqbuffer.push_back(typename WORDSEQUENCE::aligninfo(uid, 0 /*#frames--we accumulate*/));
alignseqbuffer.push_back(typename WORDSEQUENCE::aligninfo(uid, 0 /*#frames--we accumulate*/ /* ));
}
if (alignseqbuffer.empty())
RuntimeError("parseentry: lonely senone entry at start without phone/word entry found, for utterance %ls", key.c_str());
alignseqbuffer.back().frames += entries[i - s].numframes; // (we do not have an overflow check here, but should...)
}
*/
/* guoye: end */
}
}
if (wordmap) // if reading word sequences as well (for MMI), then record it (in a separate map)
{
if (!entries.empty() && wordseqbuffer.empty())
RuntimeError("parseentry: got state alignment but no word-level info, although being requested, for utterance %ls", key.c_str());
/* guoye: start */
// if (!entries.empty() && wordseqbuffer.empty())
if (!entries.empty() && wordids.empty())
/* guoye: end */
// RuntimeError("parseentry: got state alignment but no word-level info, although being requested, for utterance %ls", key.c_str());
{
fprintf(stderr,
"Warning: parseentry: got state alignment but no word-level info, although being requested, for utterance %ls \n",
key.c_str());
}
// post-process silence
// - first !silence -> !sent_start
// - last !silence -> !sent_end
int silence = (*wordmap)["!silence"];
if (silence >= 0)
else
{
int sentstart = (*wordmap)["!sent_start"]; // these must have been created
int sentend = (*wordmap)["!sent_end"];
// map first and last !silence to !sent_start and !sent_end, respectively
if (sentstart >= 0 && wordseqbuffer.front().wordindex == (size_t) silence)
wordseqbuffer.front().wordindex = sentstart;
if (sentend >= 0 && wordseqbuffer.back().wordindex == (size_t) silence)
wordseqbuffer.back().wordindex = sentend;
int silence = (*wordmap)["!silence"];
if (silence >= 0)
{
int sentstart = (*wordmap)["!sent_start"]; // these must have been created
int sentend = (*wordmap)["!sent_end"];
// map first and last !silence to !sent_start and !sent_end, respectively
/* guoye: start */
/*
if (sentstart >= 0 && wordseqbuffer.front().wordindex == (size_t) silence)
wordseqbuffer.front().wordindex = sentstart;
if (sentend >= 0 && wordseqbuffer.back().wordindex == (size_t) silence)
wordseqbuffer.back().wordindex = sentend;
*/
if (sentstart >= 0 && wordids.front() == (size_t)silence)
wordids.front() = sentstart;
if (sentend >= 0 && wordids.back() == (size_t)silence)
wordids.back() = sentend;
/* guoye: end */
}
}
/* guoye: end */
// if (sentstart < 0 || sentend < 0 || silence < 0)
// LogicError("parseentry: word map must contain !silence, !sent_start, and !sent_end");
// implant
/* guoye: start */
/*
auto& wordsequence = wordsequences[key]; // this creates the map entry
wordsequence.words = wordseqbuffer; // makes a copy
wordsequence.align = alignseqbuffer;
*/
/* guoye: end */
}
}
/* guoye: start */
// template <typename UNITSYMBOLTABLE>
void parseentry(const vector<std::string>& lines, size_t line, const set<wstring>& restricttokeys,
/* guoye: start */
const std::unordered_map<std::string, int>& wordidmap, /* const UNITSYMBOLTABLE* unitmap,
vector<typename WORDSEQUENCE::word>& wordseqbuffer, vector<typename WORDSEQUENCE::aligninfo>& alignseqbuffer,
*/
/* guoye: end */
const double htkTimeToFrame)
{
std::unordered_map<std::string, int>::const_iterator mp_itr;
size_t idx = 0;
string filename = lines[idx++];
while (filename == "#!MLF!#") // skip embedded duplicate MLF headers (so user can 'cat' MLFs)
filename = lines[idx++];
// some mlf file have write errors, so skip malformed entry
if (filename.length() < 3 || filename[0] != '"' || filename[filename.length() - 1] != '"')
{
fprintf(stderr, "warning: filename entry (%s)\n", filename.c_str());
fprintf(stderr, "skip current mlf entry from line (%lu) until line (%lu).\n", (unsigned long)(line + idx), (unsigned long)(line + lines.size()));
return;
}
filename = filename.substr(1, filename.length() - 2); // strip quotes
if (filename.find("*/") == 0)
filename = filename.substr(2);
#ifdef _MSC_VER
wstring key = msra::strfun::utf16(regex_replace(filename, regex("\\.[^\\.\\\\/:]*$"), string())); // delete extension (or not if none)
#else
wstring key = msra::strfun::utf16(msra::dbn::removeExtension(filename)); // note that c++ 4.8 is incomplete for supporting regex
#endif
// determine lines range
size_t s = idx;
size_t e = lines.size() - 1;
// lines range: [s,e)
// don't parse unused entries (this is supposed to be used for very small debugging setups with huge MLFs)
if (!restricttokeys.empty() && restricttokeys.find(key) == restricttokeys.end())
return;
/* guoye: start */
// vector<ENTRY>& entries = (*this)[key]; // this creates a new entry
vector<ENTRY>& entries = (*this)[key].first;
if (!entries.empty())
//malformed(msra::strfun::strprintf("duplicate entry '%ls'", key.c_str()));
// do not want to die immediately
fprintf(stderr,
"Warning: duplicate entry : %ls \n",
key.c_str());
entries.resize(e - s);
vector<unsigned int>& wordids = (*this)[key].second;
wordids.resize(0);
/*
wordseqbuffer.resize(0);
alignseqbuffer.resize(0);
*/
/* guoye: end */
vector<char*> toks;
for (size_t i = s; i < e; i++)
{
// We can mutate the original string as it is no longer needed after tokenization
strtok(const_cast<char*>(lines[i].c_str()), " \t", toks);
if (statelistmap.size() == 0)
entries[i - s].parse(toks, htkTimeToFrame);
else
entries[i - s].parsewithstatelist(toks, statelistmap, htkTimeToFrame, symmap);
// if we also read word entries, do it here
if (wordidmap.size() != 0)
{
if (toks.size() > 6 /*word entry are in this column*/)
{
// convert the word to uppercase (special tokens are left as-is)
if (strcmp(toks[6], "<s>") != 0
&& strcmp(toks[6], "</s>") != 0
&& strcmp(toks[6], "!sent_start") != 0
&& strcmp(toks[6], "!sent_end") != 0
&& strcmp(toks[6], "!silence") != 0)
{
for(size_t j = 0; j < strlen(toks[6]); j++)
{
if(toks[6][j] >= 'a' && toks[6][j] <= 'z')
{
toks[6][j] = toks[6][j] + 'A' - 'a';
}
}
}
const char* w = toks[6]; // the word name
/* guoye: start */
// For some alignment MLFs the sentence start and end are both represented by <s>; change the sentence-end <s> to </s>
if (i > s && strcmp(w, "<s>") == 0)
{
w = "</s>";
}
/* guoye: end */
/*guoye: start */
/* skip the words that are not used in WER computation */
/* ugly hard code, will improve later */
if (strcmp(w, "<s>") != 0
&& strcmp(w, "</s>") != 0
&& strcmp(w, "!NULL") != 0
&& strcmp(w, "!sent_start") != 0
&& strcmp(w, "!sent_end") != 0
&& strcmp(w, "!silence") != 0
&& strcmp(w, "[/CNON]") != 0
&& strcmp(w, "[/CSPN]") != 0
&& strcmp(w, "[/NPS]") != 0
&& strcmp(w, "[CNON/]") != 0
&& strcmp(w, "[CNON]") != 0
&& strcmp(w, "[CSPN]") != 0
&& strcmp(w, "[FILL/]") != 0
&& strcmp(w, "[NON/]") != 0
&& strcmp(w, "[NONNATIVE/]") != 0
&& strcmp(w, "[NPS]") != 0
&& strcmp(w, "[SB/]") != 0
&& strcmp(w, "[SBP/]") != 0
&& strcmp(w, "[SN/]") != 0
&& strcmp(w, "[SPN/]") != 0
&& strcmp(w, "[UNKNOWN/]") != 0
&& strcmp(w, ".]") != 0
)
{
// int wid = (*wordmap)[w]; // map to word id --may be -1 for unseen words in the transcript (word list typically comes from a test LM)
mp_itr = wordidmap.find(std::string(w));
int wid = ((mp_itr == wordidmap.end()) ? -1: mp_itr->second);
// debug
// int wid = -1;
/* guoye: start */
// size_t wordindex = (wid == -1) ? WORDSEQUENCE::word::unknownwordindex : (size_t)wid;
// guoye: debug
// wordseqbuffer.push_back(typename WORDSEQUENCE::word(wordindex, entries[i - s].firstframe, alignseqbuffer.size()));
// TNed the word, to try one more time if there is an OOV
/**/
/*
if (wid == -1)
{
// remove / \ * _ - to see if a match could be found
char tnw[200];
size_t i1 = 0, j = 0;
while (w[i1] != '\0')
{
if (w[i1] != '\\' && w[i1] != '/' && w[i1] != '*' && w[i1] != '-' && w[i1] != '_')
{
tnw[j] = w[i1]; j++;
}
i1++;
}
tnw[j] = '\0';
mp_itr = wordidmap.find(std::string(tnw));
wid = ((mp_itr == wordidmap.end()) ? -1 : mp_itr->second);
}
*/
static const unsigned int unknownwordindex = 0xfffff;
unsigned int wordindex = (wid == -1) ? unknownwordindex : (unsigned int)wid;
wordids.push_back(wordindex);
/* guoye: end */
}
/*guoye: end */
}
/* guoye: start */
/*
if (unitmap)
{
if (toks.size() > 4)
{
const char* u = toks[4]; // the triphone name
auto iter = unitmap->find(u); // map to unit id
if (iter == unitmap->end())
RuntimeError("parseentry: unknown unit %s in utterance %ls", u, key.c_str());
const size_t uid = iter->second;
alignseqbuffer.push_back(typename WORDSEQUENCE::aligninfo(uid, 0 /*#frames--we accumulate*/ /* ));
}
if (alignseqbuffer.empty())
RuntimeError("parseentry: lonely senone entry at start without phone/word entry found, for utterance %ls", key.c_str());
alignseqbuffer.back().frames += entries[i - s].numframes; // (we do not have an overflow check here, but should...)
}
*/
/* guoye: end */
}
}
if (wordidmap.size() != 0) // if reading word sequences as well (for MMI), then record it (in a separate map)
{
/* guoye: start */
// if (!entries.empty() && wordseqbuffer.empty())
if (!entries.empty() && wordids.empty())
/* guoye: end */
// RuntimeError("parseentry: got state alignment but no word-level info, although being requested, for utterance %ls", key.c_str());
{
fprintf(stderr,
"Warning: parseentry: got state alignment but no word-level info, although being requested, for utterance %ls. Ignoring this utterance for EMBR \n",
key.c_str());
// delete this item
(*this).erase(key);
return;
}
// post-process silence
// - first !silence -> !sent_start
// - last !silence -> !sent_end
else
{
mp_itr = wordidmap.find("!silence");
int silence = ((mp_itr == wordidmap.end()) ? -1: mp_itr->second);
// debug
// int silence = -1;
if (silence >= 0)
{
mp_itr = wordidmap.find("!sent_start");
int sentstart = ((mp_itr == wordidmap.end()) ? -1: mp_itr->second);
mp_itr = wordidmap.find("!sent_end");
int sentend = ((mp_itr == wordidmap.end()) ? -1: mp_itr->second);
// map first and last !silence to !sent_start and !sent_end, respectively
/* guoye: start */
/*
if (sentstart >= 0 && wordseqbuffer.front().wordindex == (size_t)silence)
wordseqbuffer.front().wordindex = sentstart;
if (sentend >= 0 && wordseqbuffer.back().wordindex == (size_t)silence)
wordseqbuffer.back().wordindex = sentend;
*/
if (sentstart >= 0 && wordids.front() == (size_t)silence)
wordids.front() = sentstart;
if (sentend >= 0 && wordids.back() == (size_t)silence)
wordids.back() = sentend;
/* guoye: end */
}
}
/* guoye: end */
// if (sentstart < 0 || sentend < 0 || silence < 0)
// LogicError("parseentry: word map must contain !silence, !sent_start, and !sent/_end");
// implant
/* guoye: start */
/*
wordseqbuffer.resize(0);
alignseqbuffer.resize(0);
auto& wordsequence = wordsequences[key]; // this creates the map entry
wordsequence.words = wordseqbuffer; // makes a copy
wordsequence.align = alignseqbuffer;
*/
/* guoye: end */
}
}
/* guoye: end */
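// Summary of the overload above: it mirrors the wordmap-based parseentry, but resolves words
// through the std::unordered_map wordidmap (upper-casing non-special words first), maps unseen
// words to the sentinel 0xfffff, and erases utterances that have state alignments but no
// word-level info so that they are ignored for EMBR.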
public:
// return if input statename is sil state (hard code to compared first 3 chars with "sil")
bool issilstate(const string& statename) const // (later use some configuration table)
@ -1013,6 +1430,14 @@ public:
return (statename.size() > 3 && statename.at(0) == 's' && statename.at(1) == 'i' && statename.at(2) == 'l');
}
/* guoye: start */
/*
map<wstring, WORDSEQUENCE> get_wordlabels()
{
return wordsequences;
}
*/
/* guoye: end */
vector<bool> issilstatetable; // [state index] => true if is sil state (cached)
// return if input stateid represent sil state (by table lookup)
@ -1044,9 +1469,48 @@ public:
read(paths[i], restricttokeys, wordmap, unitmap, htkTimeToFrame);
}
// note: this function is not designed to be pretty but to be fast
/* guoye: start */
// alternate constructor that takes wordidmap
// template <typename UNITSYMBOLTABLE>
/* guoye: start */
// htkmlfreader(const vector<wstring>& paths, const set<wstring>& restricttokeys, const wstring& stateListPath, const std::unordered_map<std::string, int>& wordidmap, const UNITSYMBOLTABLE* unitmap, const double htkTimeToFrame)
htkmlfreader(const vector<wstring>& paths, const set<wstring>& restricttokeys, const wstring& stateListPath, const std::unordered_map<std::string, int>& wordidmap, const double htkTimeToFrame)
/* guoye: end */
{
// read state list
if (stateListPath != L"")
readstatelist(stateListPath);
// read MLF(s) --note: there can be multiple, so this is a loop
foreach_index(i, paths)
/* guoye: start */
// read(paths[i], restricttokeys, wordidmap, unitmap, htkTimeToFrame);
read(paths[i], restricttokeys, wordidmap, htkTimeToFrame);
/* guoye: end */
}
/* guoye: end */
// phone boundary
template <typename WORDSYMBOLTABLE, typename UNITSYMBOLTABLE>
void read(const wstring& path, const set<wstring>& restricttokeys, const WORDSYMBOLTABLE* wordmap, const UNITSYMBOLTABLE* unitmap, const double htkTimeToFrame)
htkmlfreader(const vector<wstring>& paths, const set<wstring>& restricttokeys, const wstring& stateListPath, const WORDSYMBOLTABLE* wordmap, const UNITSYMBOLTABLE* unitmap,
const double htkTimeToFrame, const msra::asr::simplesenonehmm& hset)
{
if (stateListPath != L"")
readstatelist(stateListPath);
symmap = hset.symmap;
foreach_index (i, paths)
read(paths[i], restricttokeys, wordmap, unitmap, htkTimeToFrame);
}
// note: this function is not designed to be pretty but to be fast
/* guoye: start */
// template <typename WORDSYMBOLTABLE, typename UNITSYMBOLTABLE>
template <typename WORDSYMBOLTABLE>
// void read(const wstring& path, const set<wstring>& restricttokeys, const WORDSYMBOLTABLE* wordmap, const UNITSYMBOLTABLE* unitmap, const double htkTimeToFrame)
void read(const wstring& path, const set<wstring>& restricttokeys, const WORDSYMBOLTABLE* wordmap, const double htkTimeToFrame)
/* guoye: end */
{
if (!restricttokeys.empty() && this->size() >= restricttokeys.size()) // no need to even read the file if we are there (we support multiple files)
return;
@ -1060,8 +1524,12 @@ public:
malformed("header missing");
// Read the file in blocks and parse MLF entries
/* guoye: start */
/*
std::vector<typename WORDSEQUENCE::word> wordsequencebuffer;
std::vector<typename WORDSEQUENCE::aligninfo> alignsequencebuffer;
*/
/* guoye: end */
size_t readBlockSize = 1000000;
std::vector<char> currBlockBuf(readBlockSize + 1);
size_t currLineNum = 1;
@ -1091,7 +1559,10 @@ public:
{
if (restricttokeys.empty() || (this->size() < restricttokeys.size()))
{
parseentry(currMLFLines, currLineNum - currMLFLines.size(), restricttokeys, wordmap, unitmap, wordsequencebuffer, alignsequencebuffer, htkTimeToFrame);
/* guoye: start */
// parseentry(currMLFLines, currLineNum - currMLFLines.size(), restricttokeys, wordmap, unitmap, wordsequencebuffer, alignsequencebuffer, htkTimeToFrame);
parseentry(currMLFLines, currLineNum - currMLFLines.size(), restricttokeys, wordmap, htkTimeToFrame);
/* guoye: end */
}
currMLFLines.clear();
@ -1134,6 +1605,104 @@ public:
fprintf(stderr, " total %lu entries\n", (unsigned long)this->size());
}
// note: this function is not designed to be pretty but to be fast
/* guoye: start */
// template <typename UNITSYMBOLTABLE>
// void read(const wstring& path, const set<wstring>& restricttokeys, const std::unordered_map<std::string, int>& wordidmap, const UNITSYMBOLTABLE* unitmap, const double htkTimeToFrame)
void read(const wstring& path, const set<wstring>& restricttokeys, const std::unordered_map<std::string, int>& wordidmap, const double htkTimeToFrame)
{
if (!restricttokeys.empty() && this->size() >= restricttokeys.size()) // no need to even read the file if we are there (we support multiple files)
return;
fprintf(stderr, "htkmlfreader: reading MLF file %ls ...", path.c_str());
curpath = path; // for error messages only
auto_file_ptr f(fopenOrDie(path, L"rb"));
std::string headerLine = fgetline(f);
if (headerLine != "#!MLF!#")
malformed("header missing");
// Read the file in blocks and parse MLF entries
/* guoye: start */
/*
std::vector<typename WORDSEQUENCE::word> wordsequencebuffer;
std::vector<typename WORDSEQUENCE::aligninfo> alignsequencebuffer;
*/
/* guoye: end */
size_t readBlockSize = 1000000;
std::vector<char> currBlockBuf(readBlockSize + 1);
size_t currLineNum = 1;
std::vector<string> currMLFLines;
bool reachedEOF = (feof(f) != 0);
char* nextReadPtr = currBlockBuf.data();
size_t nextReadSize = readBlockSize;
while (!reachedEOF)
{
size_t numBytesRead = fread(nextReadPtr, sizeof(char), nextReadSize, f);
reachedEOF = (numBytesRead != nextReadSize);
if (ferror(f))
RuntimeError("error reading from file: %s", strerror(errno));
// Add 0 at the end to make it a proper C string
nextReadPtr[numBytesRead] = 0;
// Now extract lines from the currBlockBuf and parse MLF entries
char* context = nullptr;
const char* delim = "\r\n";
auto consumeMLFLine = [&](const char* mlfLine)
{
currLineNum++;
currMLFLines.push_back(mlfLine);
if ((mlfLine[0] == '.') && (mlfLine[1] == 0)) // utterance end delimiter: a single dot on a line
{
if (restricttokeys.empty() || (this->size() < restricttokeys.size()))
{
// parseentry(currMLFLines, currLineNum - currMLFLines.size(), restricttokeys, wordidmap, unitmap, wordsequencebuffer, alignsequencebuffer, htkTimeToFrame);
parseentry(currMLFLines, currLineNum - currMLFLines.size(), restricttokeys, wordidmap, htkTimeToFrame);
}
currMLFLines.clear();
}
};
char* prevLine = strtok_s(currBlockBuf.data(), delim, &context);
for (char* currLine = strtok_s(NULL, delim, &context); currLine; currLine = strtok_s(NULL, delim, &context))
{
consumeMLFLine(prevLine);
prevLine = currLine;
}
// The last line read from the block may be a full line or part of a line
// We can tell by whether the terminating NULL for this line is the NULL
// we inserted after reading from the file
size_t prevLineLen = strlen(prevLine);
if ((prevLine + prevLineLen) == (nextReadPtr + numBytesRead))
{
// This is not a full line, but just a truncated part of a line.
// Let's copy this to the start of the currBlockBuf and read new data
// from there on
strcpy_s(currBlockBuf.data(), currBlockBuf.size(), prevLine);
nextReadPtr = currBlockBuf.data() + prevLineLen;
nextReadSize = readBlockSize - prevLineLen;
}
else
{
// A full line
consumeMLFLine(prevLine);
nextReadPtr = currBlockBuf.data();
nextReadSize = readBlockSize;
}
}
if (!currMLFLines.empty())
malformed("unexpected end in mid-utterance");
curpath.clear();
fprintf(stderr, " total %lu entries\n", (unsigned long)this->size());
}
/* guoye: end */
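Editor's note: the two read() overloads above share the same block-wise line-scanning idea. The following minimal sketch is not part of this commit and uses illustrative names; it shows the same technique in isolation: read fixed-size blocks, tokenize on CR/LF, and carry a trailing partial line over into the next block.
// Editor's sketch only -- not part of this commit. Same technique as read() above:
// fread() a block, split on "\r\n" (MSVC strtok_s), and keep an unfinished trailing
// line around so it can be completed by the next block.
#include <cstdio>
#include <cstring>
#include <string>
#include <vector>
static std::vector<std::string> readlinesblockwise(FILE* f, size_t blocksize = 1000000)
{
    std::vector<std::string> lines;
    std::vector<char> buf(blocksize + 1);
    size_t carry = 0; // bytes of an unfinished line carried over from the previous block
    for (;;)
    {
        size_t n = fread(buf.data() + carry, 1, blocksize - carry, f);
        bool eof = (n != blocksize - carry);
        buf[carry + n] = 0; // terminate so the block is a proper C string
        char* context = nullptr;
        char* prev = nullptr;
        for (char* line = strtok_s(buf.data(), "\r\n", &context); line; line = strtok_s(nullptr, "\r\n", &context))
        {
            if (prev)
                lines.push_back(prev); // every token except the last one is a complete line
            prev = line;
        }
        if (!prev) // block held only delimiters (or nothing at all)
        {
            if (eof) break;
            carry = 0;
            continue;
        }
        size_t prevLen = strlen(prev);
        if (!eof && prev + prevLen == buf.data() + carry + n)
        {
            memmove(buf.data(), prev, prevLen); // truncated line: move it to the front and read more
            carry = prevLen;
        }
        else
        {
            lines.push_back(prev); // a complete line (or the final line at EOF)
            carry = 0;
            if (eof) break;
        }
    }
    return lines;
}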
// read state list, index is from 0
void readstatelist(const wstring& stateListPath = L"")
{
@ -1168,10 +1737,14 @@ public:
}
// access to word sequences
/* guoye: start */
const map<wstring, WORDSEQUENCE>& allwordtranscripts() const
{
return wordsequences;
}
/* guoye: end */
};
};
}; // namespaces

View file

@ -405,8 +405,11 @@ void lattice::dedup()
// - empty ("") -> don't output, just check the format
// - dash ("-") -> dump lattice to stdout instead
/*static*/ void archive::convert(const std::wstring &intocpath, const std::wstring &intocpath2, const std::wstring &outpath,
const msra::asr::simplesenonehmm &hset)
{
/* guoye: start */
// const msra::asr::simplesenonehmm &hset)
const msra::asr::simplesenonehmm &hset, std::set<int>& specialwordids)
/* guoye: end */
{
const auto &modelsymmap = hset.getsymmap();
const std::wstring tocpath = outpath + L".toc";
@ -457,8 +460,10 @@ void lattice::dedup()
// fetch lattice --this performs any necessary format conversions already
lattice L;
archive.getlattice(key, L);
/* guoye: start */
// archive.getlattice(key, L);
archive.getlattice(key, L, specialwordids);
/* guoye: end */
lattice L2;
if (mergemode)
{
@ -468,8 +473,10 @@ void lattice::dedup()
skippedmerges++;
continue;
}
archive2.getlattice(key, L2);
/* guoye: start */
// archive2.getlattice(key, L2);
archive2.getlattice(key, L2, specialwordids);
/* guoye: end */
// merge it in
// This will connect each node with matching 1-phone context conditions; aimed at merging numer lattices.
L.removefinalnull(); // get rid of that final !NULL headache
@ -563,6 +570,9 @@ void lattice::fromhtklattice(const wstring &path, const std::unordered_map<std::
assert(info.numnodes > 0);
nodes.reserve(info.numnodes);
/* guoye: start */
vt_node_out_edge_indices.resize(info.numnodes);
/* guoye: end */
// parse the nodes
for (size_t i = 0; i < info.numnodes; i++, iter++)
{
@ -570,11 +580,24 @@ void lattice::fromhtklattice(const wstring &path, const std::unordered_map<std::
RuntimeError("lattice: not enough I lines in lattice");
unsigned long itest;
float t;
if (sscanf_s(*iter, "I=%lu t=%f%c", &itest, &t, &dummychar, (unsigned int)sizeof(dummychar)) < 2)
/* guoye: start */
char d[100];
// if (sscanf_s(*iter, "I=%lu t=%f%c", &itest, &t, &dummychar, (unsigned int)sizeof(dummychar)) < 2)
if (sscanf_s(*iter, "I=%lu t=%f W=%s", &itest, &t, &d, (unsigned int)sizeof(d)) < 3)
RuntimeError("lattice: mal-formed node line in lattice: %s", *iter);
/* guoye: end */
if (i != (size_t) itest)
RuntimeError("lattice: out-of-sequence node line in lattice: %s", *iter);
nodes.push_back(nodeinfo((unsigned int) (t / info.frameduration + 0.5)));
/* guoye: start */
// nodes.push_back(nodeinfo((unsigned int) (t / info.frameduration + 0.5)));
// To do: we need to map d to its word id; this is a P2 task.
// For the current speech production pipeline, we read from the lattice archive rather than from the raw lattice, so this code is actually not used.
nodes.push_back(nodeinfo((unsigned int)(t / info.frameduration + 0.5), 0));
/* guoye: end */
info.numframes = max(info.numframes, (size_t) nodes.back().t);
}
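Editor's note: the W= token parsed above is currently dropped (the comment flags the mapping as a P2 task). Purely as an illustration, and not part of this commit, a lookup of the kind that mapping might use, assuming a wordidmap like the one handed to the MLF reader:
// Editor's sketch only (hypothetical helper, not part of this commit).
#include <string>
#include <unordered_map>
static int lookupwordid(const std::unordered_map<std::string, int>& wordidmap, const char* w)
{
    auto iter = wordidmap.find(w);
    if (iter == wordidmap.end())
        return -1;       // unknown word; the real code would have to decide how to handle this
    return iter->second; // this id would replace the 0 placeholder stored in nodeinfo above
}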
// parse the edges
@ -600,6 +623,10 @@ void lattice::fromhtklattice(const wstring &path, const std::unordered_map<std::
if (j != (size_t) jtest)
RuntimeError("lattice: out-of-sequence edge line in lattice: %s", *iter);
edges.push_back(edgeinfowithscores(S, E, a, l, align.size()));
/* guoye: start */
vt_node_out_edge_indices[S].push_back(j);
/* guoye: end */
// build align array
size_t edgeframes = 0; // (for checking whether the alignment sums up right)
const char *p = d;
@ -731,5 +758,10 @@ void lattice::frommlf(const wstring &key2, const std::unordered_map<std::string,
showstats();
}
};
};

View file

@ -34,35 +34,64 @@ public:
// - lattices are returned as a shared_ptr
// Thus, getbatch() can be called in a thread-safe fashion, allowing for a 'minibatchsource' implementation that wraps another with a read-ahead thread.
// Return value is 'true' if it did read anything from disk, and 'false' if data came only from RAM cache. This is used for controlling the read-ahead thread.
virtual bool getbatch(const size_t globalts,
const size_t framesrequested, msra::dbn::matrix &feat, std::vector<size_t> &uids,
std::vector<const_array_ref<msra::lattices::lattice::htkmlfwordsequence::word>> &transcripts,
std::vector<std::shared_ptr<const latticesource::latticepair>> &lattices) = 0;
// alternate (updated) definition for multiple inputs/outputs - read as a vector of feature matrixes or a vector of label strings
virtual bool getbatch(const size_t globalts,
const size_t framesrequested, std::vector<msra::dbn::matrix> &feat, std::vector<std::vector<size_t>> &uids,
std::vector<const_array_ref<msra::lattices::lattice::htkmlfwordsequence::word>> &transcripts,
std::vector<std::shared_ptr<const latticesource::latticepair>> &lattices, std::vector<std::vector<size_t>> &sentendmark,
std::vector<std::vector<size_t>> &phoneboundaries) = 0;
// getbatch() overload to support subsetting of mini-batches for parallel training
// Default implementation does not support subsetting and throws an exception on
// calling this overload with a numsubsets value other than 1.
virtual bool getbatch(const size_t globalts,
const size_t framesrequested, const size_t subsetnum, const size_t numsubsets, size_t &framesadvanced,
std::vector<msra::dbn::matrix> &feat, std::vector<std::vector<size_t>> &uids,
std::vector<const_array_ref<msra::lattices::lattice::htkmlfwordsequence::word>> &transcripts,
std::vector<std::shared_ptr<const latticesource::latticepair>> &lattices, std::vector<std::vector<size_t>> &sentendmark,
std::vector<std::vector<size_t>> &phoneboundaries)
{
assert((subsetnum == 0) && (numsubsets == 1) && !supportsbatchsubsetting());
subsetnum;
numsubsets;
bool retVal = getbatch(globalts, framesrequested, feat, uids, transcripts, lattices, sentendmark, phoneboundaries);
framesadvanced = feat[0].cols();
return retVal;
}
/* guoye: start */
virtual bool getbatch(const size_t globalts,
const size_t framesrequested, const size_t subsetnum, const size_t numsubsets, size_t &framesadvanced,
std::vector<msra::dbn::matrix> &feat, std::vector<std::vector<size_t>> &uids, std::vector<std::vector<size_t>> &wids, std::vector<std::vector<short>> &nws,
std::vector<const_array_ref<msra::lattices::lattice::htkmlfwordsequence::word>> &transcripts,
std::vector<std::shared_ptr<const latticesource::latticepair>> &lattices, std::vector<std::vector<size_t>> &sentendmark,
std::vector<std::vector<size_t>> &phoneboundaries)
{
wids.resize(0);
nws.resize(0);
bool retVal = getbatch(globalts, framesrequested, subsetnum, numsubsets, framesadvanced, feat, uids, transcripts, lattices, sentendmark, phoneboundaries);
return retVal;
}
/* guoye: end */
virtual bool supportsbatchsubsetting() const
{
return false;
@ -102,6 +131,10 @@ class minibatchiterator
std::vector<msra::dbn::matrix> featbuf; // buffer for holding current minibatch's frames
std::vector<std::vector<size_t>> uids; // buffer for storing current minibatch's frame-level label sequence
/* guoye: start */
std::vector<std::vector<size_t>> wids; // buffer for storing current minibatch's word-level label sequence
std::vector<std::vector<short>> nws; // buffer for storing current minibatch's number of words for each utterance
/* guoye: end */
std::vector<const_array_ref<msra::lattices::lattice::htkmlfwordsequence::word>> transcripts; // buffer for storing current minibatch's word-level label sequences (if available and used; empty otherwise)
std::vector<std::shared_ptr<const latticesource::latticepair>> lattices; // lattices of the utterances in current minibatch (empty in frame mode)
@ -126,7 +159,13 @@ private:
foreach_index (i, uids)
uids[i].clear();
/* guoye: start */
foreach_index(i, wids)
wids[i].clear();
foreach_index(i, nws)
nws[i].clear();
/* guoye: end */
transcripts.clear();
actualmbframes = 0;
return;
@ -135,7 +174,10 @@ private:
assert(requestedmbframes > 0);
const size_t requestedframes = std::min(requestedmbframes, epochendframe - mbstartframe); // (< mbsize at end)
assert(requestedframes > 0);
source.getbatch(mbstartframe, requestedframes, subsetnum, numsubsets, mbframesadvanced, featbuf, uids, transcripts, lattices, sentendmark, phoneboundaries);
/* guoye: start */
// source.getbatch(mbstartframe, requestedframes, subsetnum, numsubsets, mbframesadvanced, featbuf, uids, transcripts, lattices, sentendmark, phoneboundaries);
source.getbatch(mbstartframe, requestedframes, subsetnum, numsubsets, mbframesadvanced, featbuf, uids, wids, nws, transcripts, lattices, sentendmark, phoneboundaries);
/* guoye: end */
timegetbatch = source.gettimegetbatch();
actualmbframes = featbuf[0].cols(); // for single i/o, featbuf has length 1
// note:
@ -314,6 +356,26 @@ public:
assert(uids.size() >= i + 1);
return uids[i];
}
/* guoye: start */
// return the reference transcript word labels (word labels) for current minibatch
/*const*/ std::vector<size_t> &wlabels()
{
checkhasdata();
assert(wids.size() == 1);
return wids[0];
}
// return the number of words for current minibatch
/*const*/ std::vector<short> &nwords()
{
checkhasdata();
assert(nws.size() == 1);
return nws[0];
}
/* guoye: end */
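Editor's note: a minimal usage sketch, not part of this commit, showing how a caller could slice the flat buffer from wlabels() back into utterances with the counts from nwords(); the iterator type is left as a template parameter since only these two accessors are assumed.
// Editor's sketch only (not part of this commit): slicing the flat word-label buffer
// returned by wlabels() with the per-utterance counts from nwords().
#include <cstdio>
#include <vector>
template <class MBITER>
void dumpwordlabels(MBITER& mbiter) // MBITER: a minibatchiterator already positioned on a minibatch
{
    const std::vector<size_t>& wids = mbiter.wlabels(); // all word ids of the minibatch, concatenated
    const std::vector<short>&  nws  = mbiter.nwords();  // one word count per utterance
    size_t offset = 0;
    for (size_t u = 0; u < nws.size(); u++)
    {
        for (short t = 0; t < nws[u]; t++)
            printf("utt %d word id %d\n", (int) u, (int) wids[offset + t]);
        offset += nws[u];
    }
}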
std::vector<size_t> &sentends()
{

View file

@ -194,6 +194,9 @@ static void augmentneighbors(const std::vector<std::vector<float>>& frames, cons
// TODO: This is currently being hardcoded to unsigned short for saving space, which means untied context-dependent phones
// will not work. This needs to be changed to dynamically choose what size to use based on the number of class ids.
typedef unsigned short CLASSIDTYPE;
/* guoye: start */
typedef unsigned int WORDIDTYPE;
/* guoye: end */
typedef unsigned short HMMIDTYPE;
#ifndef _MSC_VER

View file

@ -1,4 +1,4 @@
//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
//
@ -9,12 +9,16 @@
#include "Basics.h"
#include "fileutil.h" // for opening/reading the ARPA file
/* guoye: start */
// #include "fileutil.cpp"
/* guoye: end */
#include <vector>
#include <string>
#include <unordered_map>
#include <algorithm> // for various sort() calls
#include <math.h>
namespace msra { namespace lm {
// ===========================================================================
@ -92,15 +96,40 @@ static inline double invertlogprob(double logP)
// compare function to allow char* as keys (without, unordered_map will correctly
// compute a hash key from the actual strings, but then compare the pointers
// -- duh!)
struct less_strcmp : public std::binary_function<const char *, const char *, bool>
/* guoye: start */
// struct less_strcmp : public std::binary_function<const char *, const char *, bool>
struct equal_strcmp : public std::binary_function<const char *, const char *, bool>
{ // this implements operator<
bool operator()(const char *const &_Left, const char *const &_Right) const
{
return strcmp(_Left, _Right) < 0;
// return strcmp(_Left, _Right) < 0;
return strcmp(_Left, _Right) == 0;
}
};
/* guoye: end */
struct BKDRHash {
//BKDR hash algorithm
int operator()(const char * str)const
{
unsigned int seed = 131; // 31 131 1313 13131 131313 etc.
unsigned int hash = 0;
while (*str)
{
hash = (hash * seed) + (*str);
str++;
}
return hash & (0x7FFFFFFF);
}
};
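Editor's note: for context on the equal_strcmp / BKDRHash pair above: with const char* keys, std::hash<const char*> hashes the pointer value and the default equality compares pointers, so a content-based hash and a content-based equality must be supplied together. A standalone sketch (not part of this commit, illustrative names):
// Editor's sketch only: why hash and equality must both be content-based for char* keys.
#include <cstring>
#include <unordered_map>
struct CStrHash  { size_t operator()(const char* s) const { size_t h = 0; while (*s) h = h * 131 + (unsigned char) *s++; return h; } };
struct CStrEqual { bool operator()(const char* a, const char* b) const { return strcmp(a, b) == 0; } };
static int contentkeyeddemo()
{
    std::unordered_map<const char*, int, CStrHash, CStrEqual> ids;
    char a[] = "hello";
    char b[] = "hello";       // same contents, different address
    ids[a] = 7;
    return (int) ids.count(b); // 1 with content-based hash+equality; would be 0 with the defaults
}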
class CSymbolSet : public std::unordered_map<const char *, int, std::hash<const char *>, less_strcmp>
/* guoye: start */
/* bug fix: writing the custom compare functor the way the commented-out version below does is not right. The resulting behavior is very strange: it does not correctly build the map. So, fix it. */
// class CSymbolSet : public std::unordered_map<const char *, int, std::hash<const char *>, less_strcmp>
// class CSymbolSet : public std::unordered_map<const char *, int, std::hash<const char *>, equal_strcmp>
class CSymbolSet : public std::unordered_map<const char *, int, BKDRHash, equal_strcmp>
/* guoye: end */
{
std::vector<const char *> symbols; // the symbols
@ -128,7 +157,9 @@ public:
// get id for an existing word, returns -1 if not existing
int operator[](const char *key) const
{
unordered_map<const char *, int>::const_iterator iter = find(key);
/* guoye: start */
// unordered_map<const char *, int>::const_iterator iter = find(key);
unordered_map<const char *, int, BKDRHash, equal_strcmp>::const_iterator iter = find(key);
return (iter != end()) ? iter->second : -1;
}
@ -136,7 +167,10 @@ public:
// determine unique id for a word ('key')
int operator[](const char *key)
{
unordered_map<const char *, int>::const_iterator iter = find(key);
/* guoye: start */
// unordered_map<const char *, int>::const_iterator iter = find(key);
unordered_map<const char *, int, BKDRHash, equal_strcmp>::const_iterator iter = find(key);
if (iter != end())
return iter->second;
@ -149,7 +183,11 @@ public:
{
int id = (int) symbols.size();
symbols.push_back(p); // we own the memory--remember to free it
insert(std::make_pair(p, id));
/* guoye: start */
// insert(std::make_pair(p, id));
if(!insert(std::make_pair(p, id)).second)
RuntimeError("Insertion key %s into map failed in msra_mgram.h", p);
/* guoye: end */
return id;
}
catch (...)
@ -1450,30 +1488,67 @@ public:
int lineNo = 0;
auto_file_ptr f(fopenOrDie(pathname, L"rbS"));
fprintf(stderr, "read: reading %ls", pathname.c_str());
/* guoye: start */
//fprintf(stderr, "\n msra_mgram.h: read: debug 0\n");
/* guoye: end */
filename = pathname; // (keep this info for debugging)
/* guoye: start */
// fprintf(stderr, "\n msra_mgram.h: read: debug 0.1\n");
/* guoye: end */
// --- read header information
// search for header line
char buf[1024];
lineNo++, fgetline(f, buf);
/* guoye: start */
// fprintf(stderr, "\n msra_mgram.h: read: debug 0.2\n");
/* guoye: end */
/* guoye: start */
// lineNo++, fgetline(f, buf);
lineNo++;
// fprintf(stderr, "\n msra_mgram.h: read: debug 0.25\n");
fgetline(f, buf);
// fprintf(stderr, "\n msra_mgram.h: read: debug 0.3\n");
/* guoye: end */
while (strcmp(buf, "\\data\\") != 0 && !feof(f))
lineNo++, fgetline(f, buf);
lineNo++, fgetline(f, buf);
/* guoye: start */
{
// lineNo++, fgetline(f, buf);
lineNo++;
fgetline(f, buf);
}
/* guoye: end */
/* guoye: start */
// lineNo++, fgetline(f, buf);
lineNo++;
fgetline(f, buf);
// fprintf(stderr, "\n msra_mgram.h: read: debug 1\n");
/* guoye: end */
// get the dimensions
std::vector<int> dims;
dims.reserve(4);
while (buf[0] == 0 && !feof(f))
lineNo++, fgetline(f, buf);
/* guoye: start */
{
// lineNo++, fgetline(f, buf);
lineNo++;
fgetline(f, buf);
}
/* guoye: end */
int n, dim;
dims.push_back(1); // dummy zerogram entry
while (sscanf(buf, "ngram %d=%d", &n, &dim) == 2 && n == (int) dims.size())
{
dims.push_back(dim);
lineNo++, fgetline(f, buf);
/* guoye: start */
// lineNo++, fgetline(f, buf);
lineNo++;
fgetline(f, buf);
/* guoye: end */
}
M = (int) dims.size() - 1;
@ -1483,6 +1558,9 @@ public:
if (M > maxM)
M = maxM;
/* guoye: start */
// fprintf(stderr, "\n msra_mgram.h: read: debug 2\n");
/* guoye: end */
// allocate main storage
map.init(M);
logP.init(M);
@ -1502,18 +1580,34 @@ public:
std::vector<bool> skipWord; // true: skip entry containing this word
skipWord.reserve(lmSymbols.capacity());
/* guoye: start */
// fprintf(stderr, "\n msra_mgram.h: read: debug 3\n");
/* guoye: end */
// --- read main sections
const double ln10xLMF = log(10.0); // ARPA scores are strangely scaled
msra::strfun::tokenizer tokens(" \t\n\r", M + 1); // used in tokenizing the input line
/* guoye: start */
// fprintf(stderr, "\n msra_mgram.h: read: debug 4\n");
/* guoye: end */
for (int m = 1; m <= M; m++)
{
while (buf[0] == 0 && !feof(f))
lineNo++, fgetline(f, buf);
/* guoye: start */
{
// lineNo++, fgetline(f, buf);
lineNo++;
fgetline(f, buf);
}
/* guoye: end */
if (sscanf(buf, "\\%d-grams:", &n) != 1 || n != m)
RuntimeError("read: mal-formed LM file, bad section header (%d): %ls", lineNo, pathname.c_str());
lineNo++, fgetline(f, buf);
/* guoye: start */
//lineNo++, fgetline(f, buf);
lineNo++;
fgetline(f, buf);
/* guoye: end */
std::vector<int> mgram(m + 1, -1); // current mgram being read ([0]=dummy)
std::vector<int> prevmgram(m + 1, -1); // cache to speed up symbol lookup
@ -1524,7 +1618,11 @@ public:
{
if (buf[0] == 0)
{
lineNo++, fgetline(f, buf);
/* guoye: start */
// lineNo++, fgetline(f, buf);
lineNo++;
fgetline(f, buf);
/* guoye: end */
continue;
}
@ -1576,8 +1674,11 @@ public:
double boVal = atof(tokens[m + 1]); // ... use sscanf() instead for error checking?
thisLogB = boVal * ln10xLMF; // convert to natural log
}
lineNo++, fgetline(f, buf);
/* guoye: start */
// lineNo++, fgetline(f, buf);
lineNo++;
fgetline(f, buf);
/* guoye: end */
if (skipEntry) // word contained unknown vocabulary: skip entire entry
goto skipMGram;
@ -1615,17 +1716,29 @@ public:
fprintf(stderr, ", %d %d-grams", map.size(m), m);
}
/* guoye: start */
// fprintf(stderr, "\n msra_mgram.h: read: debug 5\n");
/* guoye: end */
fprintf(stderr, "\n");
// check end tag
if (M == fileM)
{ // only if caller did not restrict us to a lower order
while (buf[0] == 0 && !feof(f))
lineNo++, fgetline(f, buf);
/* guoye: start */
{
lineNo++;
fgetline(f, buf);
// lineNo++, fgetline(f, buf);
}
if (strcmp(buf, "\\end\\") != 0)
RuntimeError("read: mal-formed LM file, no \\end\\ tag (%d): %ls", lineNo, pathname.c_str());
}
/* guoye: start */
// fprintf(stderr, "\n msra_mgram.h: read: debug 6 \n");
/* guoye: end */
// update zerogram score by one appropriate for OOVs
updateOOVScore();
@ -1638,7 +1751,13 @@ public:
int id = symbolToId(sym); // may be -1 if not found
userToLMSymMap[i] = id;
}
/* guoye: start */
// fprintf(stderr, "\n msra_mgram.h: read: debug 7 \n");
/* guoye: end */
map.created(userToLMSymMap);
/* guoye: start */
// fprintf(stderr, "\n msra_mgram.h: read: debug 8 \n");
/* guoye: end */
}
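Editor's note: for readers unfamiliar with the ARPA layout that read() above walks through, a compact sketch (not part of this commit) of the header it expects and the log-base conversion it applies:
// Editor's sketch only -- the ARPA header shape consumed by read() above:
//   \data\
//   ngram 1=20000
//   ngram 2=500000
//   ngram 3=800000
// followed by "\1-grams:" ... "\N-grams:" sections and a final "\end\".
// ARPA stores log10 scores; the reader converts them via logP_natural = logP_10 * ln(10),
// which is what the ln10xLMF = log(10.0) factor above is for.
#include <cstdio>
#include <cstring>
#include <vector>
static std::vector<int> parsearpaheader(FILE* f)
{
    char buf[1024];
    std::vector<int> dims(1, 1); // dummy zerogram entry, mirroring read()
    while (fgets(buf, sizeof(buf), f) && strncmp(buf, "\\data\\", 6) != 0)
        ; // skip everything before the \data\ marker (the real reader also skips blank lines)
    int n, dim;
    while (fgets(buf, sizeof(buf), f) && sscanf(buf, "ngram %d=%d", &n, &dim) == 2 && n == (int) dims.size())
        dims.push_back(dim); // dims[m] = number of m-grams announced for order m
    return dims;
}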
protected:

View file

@ -561,7 +561,10 @@ class minibatchframesourcemulti : public minibatchsource
public:
// constructor
// Pass empty labels to denote unsupervised training (so getbatch() will not return uids).
minibatchframesourcemulti(const std::vector<std::vector<std::wstring>> &infiles, const std::vector<std::map<std::wstring, std::vector<msra::asr::htkmlfentry>>> &labels,
/* guoye: start */
// minibatchframesourcemulti(const std::vector<std::vector<std::wstring>> &infiles, const std::vector<std::map<std::wstring, std::vector<msra::asr::htkmlfentry>>> &labels,
minibatchframesourcemulti(const std::vector<std::vector<std::wstring>> &infiles, const std::vector<std::map<std::wstring, std::pair<std::vector<msra::asr::htkmlfentry>, std::vector<unsigned int>>>> &labels,
/* guoye: end */
std::vector<size_t> vdim, std::vector<size_t> udim, std::vector<size_t> leftcontext, std::vector<size_t> rightcontext, size_t randomizationrange, const std::vector<std::wstring> &pagepath, const bool mayhavenoframe = false, int addEnergy = 0)
: vdim(vdim), leftcontext(leftcontext), rightcontext(rightcontext), sampperiod(0), featdim(0), numframes(0), timegetbatch(0), verbosity(2), maxvdim(0)
{
@ -656,7 +659,10 @@ public:
// HVite occasionally generates mismatching output --skip such files
if (!key.empty()) // (we have a key if supervised mode)
{
const auto &labseq = labels[0].find(key)->second; // (we already checked above that it exists)
/* guoye: start */
// const auto &labseq = labels[0].find(key)->second; // (we already checked above that it exists)
const auto &labseq = labels[0].find(key)->second.first; // (we already checked above that it exists)
/* guoye: end */
size_t labframes = labseq.empty() ? 0 : (labseq[labseq.size() - 1].firstframe + labseq[labseq.size() - 1].numframes);
if (abs((int) labframes - (int) feat.cols()) > 0)
{
@ -695,7 +701,7 @@ public:
{
foreach_index (j, labels)
{
const auto &labseq = labels[j].find(key)->second; // (we already checked above that it exists)
const auto &labseq = labels[j].find(key)->second.first; // (we already checked above that it exists)
foreach_index (i2, labseq)
{
const auto &e = labseq[i2];

View file

@ -14,6 +14,9 @@
#include "minibatchiterator.h"
#include <unordered_set>
#include <random>
/* guoye: start */
#include <set>
/* guoye: end */
namespace msra { namespace dbn {
@ -36,6 +39,9 @@ class minibatchutterancesourcemulti : public minibatchsource
const bool truncated; //false -> truncated utterance or not within minibatch
size_t maxUtteranceLength; //10000 ->maximum utterance length in non-frame and non-truncated mode
/* guoye: start */
std::set<int> specialwordids; // stores the word ids that will not be counted for WER computation
/* guoye: end */
std::vector<std::vector<size_t>> counts; // [s] occurrence count for all states (used for priors)
int verbosity;
// lattice reader
@ -55,9 +61,18 @@ class minibatchutterancesourcemulti : public minibatchsource
{
msra::asr::htkfeatreader::parsedpath parsedpath; // archive filename and frame range in that file
size_t classidsbegin; // index into allclassids[] array (first frame)
/* guoye: start */
size_t wordidsbegin;
short numwords;
//utterancedesc(msra::asr::htkfeatreader::parsedpath &&ppath, size_t classidsbegin)
// : parsedpath(std::move(ppath)), classidsbegin(classidsbegin), framesToExpand(0), needsExpansion(false)
utterancedesc(msra::asr::htkfeatreader::parsedpath &&ppath, size_t classidsbegin, size_t wordidsbegin)
: parsedpath(std::move(ppath)), classidsbegin(classidsbegin), wordidsbegin(wordidsbegin), framesToExpand(0), needsExpansion(false)
/* guoye: end */
utterancedesc(msra::asr::htkfeatreader::parsedpath &&ppath, size_t classidsbegin)
: parsedpath(std::move(ppath)), classidsbegin(classidsbegin), framesToExpand(0), needsExpansion(false)
{
}
bool needsExpansion; // ivector type of feature
@ -73,6 +88,17 @@ class minibatchutterancesourcemulti : public minibatchsource
else
return parsedpath.numframes();
}
/* guoye: start */
short getnumwords() const
{
return numwords;
}
void setnumwords(short nw)
{
numwords = nw;
}
/* guoye: end */
std::wstring key() const // key used for looking up lattice (not stored to save space)
{
#ifdef _MSC_VER
@ -129,6 +155,18 @@ class minibatchutterancesourcemulti : public minibatchsource
{
return utteranceset[i].classidsbegin;
}
/* guoye: start */
size_t getwordidsbegin(size_t i) const
{
return utteranceset[i].wordidsbegin;
}
short numwords(size_t i) const
{
return utteranceset[i].numwords;
}
/* guoye: end */
msra::dbn::matrixstripe getutteranceframes(size_t i) const // return the frame set for a given utterance
{
if (!isinram())
@ -152,8 +190,9 @@ class minibatchutterancesourcemulti : public minibatchsource
}
// page in data for this chunk
// We pass in the feature info variables by ref which will be filled lazily upon first read
void requiredata(std::string &featkind, size_t &featdim, unsigned int &sampperiod, const latticesource &latticesource, int verbosity = 0) const
void requiredata(std::string &featkind, size_t &featdim, unsigned int &sampperiod, const latticesource &latticesource, std::set<int>& specialwordids, int verbosity = 0) const
{
if (numutterances() == 0)
LogicError("requiredata: cannot page in virgin block");
if (isinram())
@ -181,8 +220,12 @@ class minibatchutterancesourcemulti : public minibatchsource
auto uttframes = getutteranceframes(i); // matrix stripe for this utterance (currently unfilled)
reader.read(utteranceset[i].parsedpath, (const std::string &)featkind, sampperiod, uttframes, utteranceset[i].needsExpansion); // note: file info here used for checking only
// page in lattice data
/* guoye: start */
if (!latticesource.empty())
latticesource.getlattices(utteranceset[i].key(), lattices[i], uttframes.cols());
latticesource.getlattices(utteranceset[i].key(), lattices[i], uttframes.cols(), specialwordids);
/* guoye: end */
}
if (verbosity)
{
@ -234,6 +277,10 @@ class minibatchutterancesourcemulti : public minibatchsource
std::vector<std::vector<utterancechunkdata>> allchunks; // set of utterances organized in chunks, referred to by an iterator (not an index)
std::vector<std::unique_ptr<biggrowablevector<CLASSIDTYPE>>> classids; // [classidsbegin+t] concatenation of all state sequences
/* guoye: start */
std::vector<std::unique_ptr<biggrowablevector<WORDIDTYPE>>> wordids; // [wordidsbegin+t] concatenation of all word sequences
/* guoye: end */
bool m_generatePhoneBoundaries;
std::vector<std::unique_ptr<biggrowablevector<HMMIDTYPE>>> phoneboundaries;
bool issupervised() const
@ -299,6 +346,9 @@ class minibatchutterancesourcemulti : public minibatchsource
}
size_t numframes; // (cached since we cannot directly access the underlying data from here)
/* guoye: start */
short numwords;
/* guoye: end */
size_t globalts; // start frame in global space after randomization (for mapping frame index to utterance position)
size_t globalte() const
{
@ -850,6 +900,32 @@ class minibatchutterancesourcemulti : public minibatchsource
}
return allclassids; // nothing to return
}
/* guoye: start */
template <class UTTREF>
std::vector<shiftedvector<biggrowablevector<WORDIDTYPE>>> getwordids(const UTTREF &uttref) // return sub-vector of classids[] for a given utterance
{
std::vector<shiftedvector<biggrowablevector<WORDIDTYPE>>> allwordids;
if (!issupervised())
{
foreach_index(i, wordids)
allwordids.push_back(std::move(shiftedvector<biggrowablevector<WORDIDTYPE>>((*wordids[i]), 0, 0)));
return allwordids; // nothing to return
}
const auto &chunk = randomizedchunks[0][uttref.chunkindex];
const auto &chunkdata = chunk.getchunkdata();
const size_t wordidsbegin = chunkdata.getwordidsbegin(uttref.utteranceindex()); // index of first word label in the global concatenated wordids[] array
const size_t n = chunkdata.numwords(uttref.utteranceindex());
foreach_index(i, wordids)
{
if ((*wordids[i])[wordidsbegin + n] != (WORDIDTYPE)-1)
LogicError("getwordids: expected boundary marker not found, internal data structure screwed up");
allwordids.push_back(std::move(shiftedvector<biggrowablevector<WORDIDTYPE>>((*wordids[i]), wordidsbegin, n)));
}
return allwordids; // nothing to return
}
/* guoye: end */
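Editor's note: wordids mirrors the classids layout: every utterance's ids are appended into one concatenated array followed by a (WORDIDTYPE)-1 sentinel, and getwordids() returns a (begin, count) window over it. A tiny standalone sketch of that layout (not part of this commit):
// Editor's sketch only: append-with-sentinel storage and windowed read-back,
// the same layout wordids[] / getwordids() rely on above.
#include <cassert>
#include <vector>
typedef unsigned int WORDIDTYPE;
static void wordidlayoutsketch()
{
    std::vector<WORDIDTYPE> wordids;      // stands in for biggrowablevector<WORDIDTYPE>
    size_t wordidsbegin = wordids.size(); // remembered per utterance (cf. utterancedesc::wordidsbegin)
    const WORDIDTYPE utt[] = { 11, 42, 7 };
    const size_t numwords = sizeof(utt) / sizeof(utt[0]);
    for (size_t t = 0; t < numwords; t++)
        wordids.push_back(utt[t]);
    wordids.push_back((WORDIDTYPE) -1);   // boundary marker used for the consistency check
    // read back as a window, checking the sentinel exactly like getwordids() does
    assert(wordids[wordidsbegin + numwords] == (WORDIDTYPE) -1);
    for (size_t t = 0; t < numwords; t++)
        assert(wordids[wordidsbegin + t] == utt[t]);
}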
template <class UTTREF>
std::vector<shiftedvector<biggrowablevector<HMMIDTYPE>>> getphonebound(const UTTREF &uttref) // return sub-vector of classids[] for a given utterance
{
@ -882,13 +958,23 @@ public:
// constructor
// Pass empty labels to denote unsupervised training (so getbatch() will not return uids).
// This mode requires utterances with time stamps.
minibatchutterancesourcemulti(bool useMersenneTwister, const std::vector<std::vector<std::wstring>> &infiles, const std::vector<std::map<std::wstring, std::vector<msra::asr::htkmlfentry>>> &labels,
/* guoye: start */
// minibatchutterancesourcemulti(bool useMersenneTwister, const std::vector<std::vector<std::wstring>> &infiles, const std::vector<std::map<std::wstring, std::vector<msra::asr::htkmlfentry>>> &labels,
// minibatchutterancesourcemulti(bool useMersenneTwister, const std::vector<std::vector<std::wstring>> &infiles, const std::vector<std::map<std::wstring, std::vector<msra::asr::htkmlfentry>>> &labels,
minibatchutterancesourcemulti(bool useMersenneTwister, const std::vector<std::vector<std::wstring>> &infiles, const std::vector<std::map<std::wstring, std::pair<std::vector<msra::asr::htkmlfentry>, std::vector<unsigned int>>>> &labels,
// const std::vector<std::map<std::wstring, msra::lattices::lattice::htkmlfwordsequence>>& wordlabels,
std::set<int>& specialwordids,
/* guoye: end */
std::vector<size_t> vdim, std::vector<size_t> udim, std::vector<size_t> leftcontext, std::vector<size_t> rightcontext, size_t randomizationrange,
const latticesource &lattices, const std::map<std::wstring, msra::lattices::lattice::htkmlfwordsequence> &allwordtranscripts, const bool framemode, std::vector<bool> expandToUtt,
const size_t maxUtteranceLength, const bool truncated)
: vdim(vdim), leftcontext(leftcontext), rightcontext(rightcontext), sampperiod(0), featdim(0), randomizationrange(randomizationrange), currentsweep(SIZE_MAX),
lattices(lattices), allwordtranscripts(allwordtranscripts), framemode(framemode), chunksinram(0), timegetbatch(0), verbosity(2), m_generatePhoneBoundaries(!lattices.empty()),
m_frameRandomizer(randomizedchunks, useMersenneTwister), expandToUtt(expandToUtt), m_useMersenneTwister(useMersenneTwister), maxUtteranceLength(maxUtteranceLength), truncated(truncated)
/* guoye: start */
, specialwordids(specialwordids)
/* guoye: end */
// [v-hansu] change framemode (lattices.empty()) into framemode (false) to run utterance mode without lattice
// you also need to change another line, search : [v-hansu] comment out to run utterance mode without lattice
{
@ -905,6 +991,9 @@ public:
std::vector<size_t> uttduration; // track utterance durations to determine utterance validity
std::vector<size_t> classidsbegin;
/* guoye: start */
std::vector<size_t> wordidsbegin;
/* guoye: end */
allchunks = std::vector<std::vector<utterancechunkdata>>(infiles.size(), std::vector<utterancechunkdata>());
featdim = std::vector<size_t>(infiles.size(), 0);
@ -917,6 +1006,9 @@ public:
foreach_index (i, labels)
{
classids.push_back(std::unique_ptr<biggrowablevector<CLASSIDTYPE>>(new biggrowablevector<CLASSIDTYPE>()));
/* guoye: start */
wordids.push_back(std::unique_ptr<biggrowablevector<WORDIDTYPE>>(new biggrowablevector<WORDIDTYPE>()));
/* guoye: end */
if (m_generatePhoneBoundaries)
phoneboundaries.push_back(std::unique_ptr<biggrowablevector<HMMIDTYPE>>(new biggrowablevector<HMMIDTYPE>()));
@ -945,7 +1037,10 @@ public:
foreach_index (i, infiles[m])
{
utterancedesc utterance(msra::asr::htkfeatreader::parsedpath(infiles[m][i]), 0); // mseltzer - is this foolproof for multiio? is classids always non-empty?
/* guoye: start */
// utterancedesc utterance(msra::asr::htkfeatreader::parsedpath(infiles[m][i]), 0); // mseltzer - is this foolproof for multiio? is classids always non-empty?
utterancedesc utterance(msra::asr::htkfeatreader::parsedpath(infiles[m][i]), 0, 0);
/* guoye: end */
const size_t uttframes = utterance.numframes(); // will throw if frame bounds not given --required to be given in this mode
if (expandToUtt[m] && uttframes != 1)
RuntimeError("minibatchutterancesource: utterance-based features must be 1 frame in duration");
@ -1002,9 +1097,13 @@ public:
// else
// if (infiles[m].size()!=numutts)
// RuntimeError("minibatchutterancesourcemulti: all feature files must have same number of utterances\n");
/* guoye: start */
if (m == 0)
{
classidsbegin.clear();
wordidsbegin.clear();
}
/* guoye: end */
foreach_index (i, infiles[m])
{
if (i % (infiles[m].size() / 100 + 1) == 0)
@ -1013,12 +1112,21 @@ public:
fflush(stderr);
}
// build utterance descriptor
/* guoye: start */
if (m == 0 && !labels.empty())
{
classidsbegin.push_back(classids[0]->size());
wordidsbegin.push_back(wordids[0]->size());
}
/* guoye: end */
if (uttisvalid[i])
{
utterancedesc utterance(msra::asr::htkfeatreader::parsedpath(infiles[m][i]), labels.empty() ? 0 : classidsbegin[i]); // mseltzer - is this foolproof for multiio? is classids always non-empty?
/* guoye: start */
// utterancedesc utterance(msra::asr::htkfeatreader::parsedpath(infiles[m][i]), labels.empty() ? 0 : classidsbegin[i]); // mseltzer - is this foolproof for multiio? is classids always non-empty?
utterancedesc utterance(msra::asr::htkfeatreader::parsedpath(infiles[m][i]), labels.empty() ? 0 : classidsbegin[i], labels.empty() ? 0 : wordidsbegin[i]); // mseltzer - is this foolproof for multiio? is classids always non-empty?
/* guoye: end */
const size_t uttframes = utterance.numframes(); // will throw if frame bounds not given --required to be given in this mode
if (expandToUtt[m])
{
@ -1078,7 +1186,10 @@ public:
// first verify that all the label files have the proper duration
foreach_index (j, labels)
{
const auto &labseq = labels[j].find(key)->second;
/* guoye: start */
// const auto &labseq = labels[j].find(key)->second;
const auto &labseq = labels[j].find(key)->second.first;
/* guoye: end */
// check if durations match; skip if not
size_t labframes = labseq.empty() ? 0 : (labseq[labseq.size() - 1].firstframe + labseq[labseq.size() - 1].numframes);
if (labframes != uttframes)
@ -1092,12 +1203,19 @@ public:
}
if (uttisvalid[i])
{
utteranceset.push_back(std::move(utterance));
/* guoye: start */
// utteranceset.push_back(std::move(utterance));
/* guoye: end */
_totalframes += uttframes;
// then parse each mlf if the durations are consistent
foreach_index (j, labels)
{
const auto &labseq = labels[j].find(key)->second;
/* guoye: start */
// const auto &labseq = labels[j].find(key)->second;
const auto & seqs = labels[j].find(key)->second;
// const auto &labseq = labels[j].find(key)->second.first;
const auto &labseq = seqs.first;
/* guoye: end */
// expand classid sequence into flat array
foreach_index (i2, labseq)
{
@ -1126,18 +1244,75 @@ public:
}
classids[j]->push_back((CLASSIDTYPE) -1); // append a boundary marker for checking
if (m_generatePhoneBoundaries)
phoneboundaries[j]->push_back((HMMIDTYPE) -1); // append a boundary marker for checking
/* guoye: start */
/*
if (!labels[j].empty() && classids[j]->size() != _totalframes + utteranceset.size())
LogicError("minibatchutterancesource: label duration inconsistent with feature file in MLF label set: %ls", key.c_str());
assert(labels[j].empty() || classids[j]->size() == _totalframes + utteranceset.size());
*/
// guoye: because we do utteranceset.push_back(std::move(utterance)) at a later stage
if (!labels[j].empty() && classids[j]->size() != _totalframes + utteranceset.size() + 1)
LogicError("minibatchutterancesource: label duration inconsistent with feature file in MLF label set: %ls", key.c_str());
assert(labels[j].empty() || classids[j]->size() == _totalframes + utteranceset.size() + 1);
/* guoye: end */
const auto &wordlabseq = seqs.second;
if (j == 0)
utterance.setnumwords(short(wordlabseq.size()));
foreach_index(i2, wordlabseq)
{
const auto &e = wordlabseq[i2];
if (e != (WORDIDTYPE)e)
RuntimeError("WORDIDTYPE has too few bits");
wordids[j]->push_back(e);
}
wordids[j]->push_back((WORDIDTYPE)-1); // append a boundary marker for checking
}
/* guoye: start */
/* mask for guoye debug */
/*
foreach_index(j, wordlabels)
{
const auto &wordlabseq = wordlabels[j].find(key)->second.words;
// expand classid sequence into flat array
if (j == 0)
utterance.setnumwords(short(wordlabseq.size()));
foreach_index(i2, wordlabseq)
{
const auto &e = wordlabseq[i2];
if (e.wordindex != (WORDIDTYPE)e.wordindex)
RuntimeError("WORDIDTYPE has too few bits");
wordids[j]->push_back(e.wordindex);
}
wordids[j]->push_back((WORDIDTYPE)-1); // append a boundary marker for checking
}
*/
/* guoye: end */
utteranceset.push_back(std::move(utterance));
}
}
else
{
assert(classids.empty() && labels.empty());
/* guoye: start */
// assert(classids.empty() && labels.empty());
assert(classids.empty() && labels.empty() && wordids.empty());
/* guoye: end */
utteranceset.push_back(std::move(utterance));
_totalframes += uttframes;
}
@ -1424,6 +1599,9 @@ private:
auto &uttref = randomizedutterancerefs[i];
uttref.globalts = t;
uttref.numframes = randomizedchunks[0][uttref.chunkindex].getchunkdata().numframes(uttref.utteranceindex());
/* guoye: start */
uttref.numwords = randomizedchunks[0][uttref.chunkindex].getchunkdata().numwords(uttref.utteranceindex());
/* guoye: end */
t = uttref.globalte();
}
assert(t == sweepts + _totalframes);
@ -1486,6 +1664,7 @@ private:
// Returns true if we actually did read something.
bool requirerandomizedchunk(const size_t chunkindex, const size_t windowbegin, const size_t windowend)
{
size_t numinram = 0;
if (chunkindex < windowbegin || chunkindex >= windowend)
@ -1510,7 +1689,10 @@ private:
fprintf(stderr, "feature set %d: requirerandomizedchunk: paging in randomized chunk %d (frame range [%d..%d]), %d resident in RAM\n", m, (int) chunkindex, (int) chunk.globalts, (int) (chunk.globalte() - 1), (int) (chunksinram + 1));
msra::util::attempt(5, [&]() // (reading from network)
{
chunkdata.requiredata(featkind[m], featdim[m], sampperiod[m], this->lattices, verbosity);
/* guoye: start */
// chunkdata.requiredata(featkind[m], featdim[m], sampperiod[m], this->lattices, verbosity);
chunkdata.requiredata(featkind[m], featdim[m], sampperiod[m], this->lattices, specialwordids, verbosity);
/* guoye: end */
});
}
chunksinram++;
@ -1561,6 +1743,8 @@ public:
verbosity = newverbosity;
}
// get the next minibatch
// A minibatch is made up of one or more utterances.
// We will return less than 'framesrequested' unless the first utterance is too long.
@ -1569,16 +1753,21 @@ public:
// This is efficient since getbatch() is called with sequential 'globalts' except at epoch start.
// Note that the start of an epoch does not necessarily fall onto an utterance boundary. The caller must use firstvalidglobalts() to find the first valid globalts at or after a given time.
// Support for data parallelism: If mpinodes > 1 then we will
// - load only a subset of blocks from the disk
// - skip frames/utterances in not-loaded blocks in the returned data
// - 'framesadvanced' will still return the logical #frames; that is, by how much the global time index is advanced
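Editor's note: before the implementation that follows, a short sketch (not part of this commit) of the chunk-to-subset assignment described in the comment above; each MPI node only pages in and returns chunks whose index maps to its subset, while framesadvanced still reports the full logical advance.
// Editor's sketch only: which chunks a given MPI node owns under the scheme above.
#include <cstddef>
#include <vector>
static std::vector<size_t> chunksownedbynode(size_t numchunks, size_t subsetnum, size_t numsubsets)
{
    std::vector<size_t> owned;
    for (size_t chunkindex = 0; chunkindex < numchunks; chunkindex++)
        if ((chunkindex % numsubsets) == subsetnum) // the same test getbatch() applies per utterance/chunk
            owned.push_back(chunkindex);
    return owned;
}
// e.g. with 10 chunks and 4 nodes, node 1 owns chunks {1, 5, 9}; data from other chunks is
// skipped locally, but the global time index still advances by the full minibatch size.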
bool getbatch(const size_t globalts, const size_t framesrequested,
const size_t subsetnum, const size_t numsubsets, size_t &framesadvanced,
std::vector<msra::dbn::matrix> &feat, std::vector<std::vector<size_t>> &uids,
/* guoye: start */
// std::vector<msra::dbn::matrix> &feat, std::vector<std::vector<size_t>> &uids,
std::vector<msra::dbn::matrix> &feat, std::vector<std::vector<size_t>> &uids, std::vector<std::vector<size_t>> &wids, std::vector<std::vector<short>> &nws,
/* guoye: end */
std::vector<const_array_ref<msra::lattices::lattice::htkmlfwordsequence::word>> &transcripts,
std::vector<std::shared_ptr<const latticesource::latticepair>> &latticepairs, std::vector<std::vector<size_t>> &sentendmark,
std::vector<std::vector<size_t>> &phoneboundaries2) override
{
bool readfromdisk = false; // return value: shall be 'true' if we paged in anything
auto_timer timergetbatch;
@ -1624,6 +1813,9 @@ public:
// determine the true #frames we return, for allocation--it is less than mbframes in the case of MPI/data-parallel sub-set mode
size_t tspos = 0;
/* guoye: start */
size_t twrds = 0;
/* guoye: end */
for (size_t pos = spos; pos < epos; pos++)
{
const auto &uttref = randomizedutterancerefs[pos];
@ -1631,11 +1823,18 @@ public:
continue;
tspos += uttref.numframes;
/* guoye: start */
twrds += uttref.numwords;
/* guoye: end */
}
// resize feat and uids
feat.resize(vdim.size());
uids.resize(classids.size());
/* guoye: start */
wids.resize(wordids.size());
nws.resize(wordids.size());
/* guoye: end */
if (m_generatePhoneBoundaries)
phoneboundaries2.resize(classids.size());
sentendmark.resize(vdim.size());
@ -1649,15 +1848,29 @@ public:
{
foreach_index (j, uids)
{
/* guoye: start */
nws[j].clear();
/* guoye: end */
if (issupervised()) // empty means unsupervised training -> return empty uids
{
uids[j].resize(tspos);
/* guoye: start */
wids[j].resize(twrds);
/* guoye: end */
if (m_generatePhoneBoundaries)
phoneboundaries2[j].resize(tspos);
}
else
{
uids[i].clear();
/* guoye: start */
// uids[i].clear();
// guoye: I think this is a bug; the index should be j, not i
uids[j].clear();
wids[j].clear();
/* guoye: end */
if (m_generatePhoneBoundaries)
phoneboundaries2[i].clear();
}
@ -1674,6 +1887,9 @@ public:
if (verbosity > 0)
fprintf(stderr, "getbatch: getting utterances %d..%d (%d subset of %d frames out of %d requested) in sweep %d\n", (int) spos, (int) (epos - 1), (int) tspos, (int) mbframes, (int) framesrequested, (int) sweep);
tspos = 0; // relative start of utterance 'pos' within the returned minibatch
/* guoye: start */
twrds = 0;
/* guoye: end */
for (size_t pos = spos; pos < epos; pos++)
{
const auto &uttref = randomizedutterancerefs[pos];
@ -1681,6 +1897,9 @@ public:
continue;
size_t n = 0;
/* guoye: start */
size_t nw = 0;
/* guoye: end */
foreach_index (i, randomizedchunks)
{
const auto &chunk = randomizedchunks[i][uttref.chunkindex];
@ -1692,6 +1911,9 @@ public:
sentendmark[i].push_back(n + tspos);
assert(n == uttframes.cols() && uttref.numframes == n && chunkdata.numframes(uttref.utteranceindex()) == n);
/* guoye: start */
nw = uttref.numwords;
/* guoye: end */
// copy the frames and class labels
for (size_t t = 0; t < n; t++) // t = time index into source utterance
{
@ -1714,6 +1936,9 @@ public:
if (i == 0)
{
auto uttclassids = getclassids(uttref);
/* guoye: start */
auto uttwordids = getwordids(uttref);
/* guoye: end */
std::vector<shiftedvector<biggrowablevector<HMMIDTYPE>>> uttphoneboudaries;
if (m_generatePhoneBoundaries)
uttphoneboudaries = getphonebound(uttref);
@ -1742,9 +1967,28 @@ public:
}
}
}
/* guoye: start */
foreach_index(j, uttwordids)
{
nws[j].push_back(short(nw));
for (size_t t = 0; t < nw; t++) // t = time index into source utterance
{
if (issupervised())
{
wids[j][t + twrds] = uttwordids[j][t];
}
}
}
/* guoye: end */
}
}
tspos += n;
/* guoye: start */
twrds += nw;
/* guoye: end */
}
foreach_index (i, feat)
@ -1795,6 +2039,9 @@ public:
// resize feat and uids
feat.resize(vdim.size());
uids.resize(classids.size());
/* guoye: start */
// no need to care about wids for framemode = true
/* guoye: end */
assert(feat.size() == vdim.size());
assert(feat.size() == randomizedchunks.size());
foreach_index (i, feat)
@ -1878,31 +2125,360 @@ public:
return readfromdisk;
}
// get the next minibatch
// A minibatch is made up of one or more utterances.
// We will return less than 'framesrequested' unless the first utterance is too long.
// Note that this may return frames that are beyond the epoch end, but the first frame is always within the epoch.
// We specify the utterance by its global start time (in a space of a infinitely repeated training set).
// This is efficient since getbatch() is called with sequential 'globalts' except at epoch start.
// Note that the start of an epoch does not necessarily fall onto an utterance boundary. The caller must use firstvalidglobalts() to find the first valid globalts at or after a given time.
// Support for data parallelism: If mpinodes > 1 then we will
// - load only a subset of blocks from the disk
// - skip frames/utterances in not-loaded blocks in the returned data
// - 'framesadvanced' will still return the logical #frames; that is, by how much the global time index is advanced
/* guoye: start */
bool getbatch(const size_t globalts, const size_t framesrequested,
const size_t subsetnum, const size_t numsubsets, size_t &framesadvanced,
std::vector<msra::dbn::matrix> &feat, std::vector<std::vector<size_t>> &uids,
std::vector<const_array_ref<msra::lattices::lattice::htkmlfwordsequence::word>> &transcripts,
std::vector<std::shared_ptr<const latticesource::latticepair>> &latticepairs, std::vector<std::vector<size_t>> &sentendmark,
std::vector<std::vector<size_t>> &phoneboundaries2) override
{
bool readfromdisk = false; // return value: shall be 'true' if we paged in anything
auto_timer timergetbatch;
assert(_totalframes > 0);
// update randomization if a new sweep is entered --this is a complex operation that updates many of the data members used below
const size_t sweep = lazyrandomization(globalts);
size_t mbframes = 0;
const std::vector<char> noboundaryflags; // dummy
if (!framemode) // regular utterance mode
{
// find utterance position for globalts
// There must be a precise match; it is not possible to specify frames that are not on boundaries.
auto positer = randomizedutteranceposmap.find(globalts);
if (positer == randomizedutteranceposmap.end())
LogicError("getbatch: invalid 'globalts' parameter; must match an existing utterance boundary");
const size_t spos = positer->second;
// determine how many utterances will fit into the requested minibatch size
mbframes = randomizedutterancerefs[spos].numframes; // at least one utterance, even if too long
size_t epos;
for (epos = spos + 1; epos < numutterances && ((mbframes + randomizedutterancerefs[epos].numframes) < framesrequested); epos++) // add more utterances as long as they fit within requested minibatch size
mbframes += randomizedutterancerefs[epos].numframes;
// do some paging housekeeping
// This will also set the feature-kind information if it's the first time.
// Free all chunks left of the range.
// Page-in all chunks right of the range.
// We are a little more blunt for now: Free all outside the range, and page in only what is touched. We could save some loop iterations.
const size_t windowbegin = positionchunkwindows[spos].windowbegin();
const size_t windowend = positionchunkwindows[epos - 1].windowend();
for (size_t k = 0; k < windowbegin; k++)
releaserandomizedchunk(k);
for (size_t k = windowend; k < randomizedchunks[0].size(); k++)
releaserandomizedchunk(k);
for (size_t pos = spos; pos < epos; pos++)
if ((randomizedutterancerefs[pos].chunkindex % numsubsets) == subsetnum)
readfromdisk |= requirerandomizedchunk(randomizedutterancerefs[pos].chunkindex, windowbegin, windowend); // (window range passed in for checking only)
// Note that the above loop loops over all chunks incl. those that we already should have.
// This has an effect, e.g., if 'numsubsets' has changed (we will fill gaps).
// determine the true #frames we return, for allocation--it is less than mbframes in the case of MPI/data-parallel sub-set mode
size_t tspos = 0;
for (size_t pos = spos; pos < epos; pos++)
{
const auto &uttref = randomizedutterancerefs[pos];
if ((uttref.chunkindex % numsubsets) != subsetnum) // chunk not to be returned for this MPI node
continue;
tspos += uttref.numframes;
}
// resize feat and uids
feat.resize(vdim.size());
uids.resize(classids.size());
if (m_generatePhoneBoundaries)
phoneboundaries2.resize(classids.size());
sentendmark.resize(vdim.size());
assert(feat.size() == vdim.size());
assert(feat.size() == randomizedchunks.size());
foreach_index(i, feat)
{
feat[i].resize(vdim[i], tspos);
if (i == 0)
{
foreach_index(j, uids)
{
if (issupervised()) // empty means unsupervised training -> return empty uids
{
uids[j].resize(tspos);
if (m_generatePhoneBoundaries)
phoneboundaries2[j].resize(tspos);
}
else
{
uids[i].clear();
if (m_generatePhoneBoundaries)
phoneboundaries2[i].clear();
}
latticepairs.clear(); // will push_back() below
transcripts.clear();
}
foreach_index(j, sentendmark)
{
sentendmark[j].clear();
}
}
}
// return these utterances
if (verbosity > 0)
fprintf(stderr, "getbatch: getting utterances %d..%d (%d subset of %d frames out of %d requested) in sweep %d\n", (int)spos, (int)(epos - 1), (int)tspos, (int)mbframes, (int)framesrequested, (int)sweep);
tspos = 0; // relative start of utterance 'pos' within the returned minibatch
for (size_t pos = spos; pos < epos; pos++)
{
const auto &uttref = randomizedutterancerefs[pos];
if ((uttref.chunkindex % numsubsets) != subsetnum) // chunk not to be returned for this MPI node
continue;
size_t n = 0;
foreach_index(i, randomizedchunks)
{
const auto &chunk = randomizedchunks[i][uttref.chunkindex];
const auto &chunkdata = chunk.getchunkdata();
assert((numsubsets > 1) || (uttref.globalts == globalts + tspos));
auto uttframes = chunkdata.getutteranceframes(uttref.utteranceindex());
matrixasvectorofvectors uttframevectors(uttframes); // (wrapper that allows m[j].size() and m[j][i] as required by augmentneighbors())
n = uttframevectors.size();
sentendmark[i].push_back(n + tspos);
assert(n == uttframes.cols() && uttref.numframes == n && chunkdata.numframes(uttref.utteranceindex()) == n);
// copy the frames and class labels
for (size_t t = 0; t < n; t++) // t = time index into source utterance
{
size_t leftextent, rightextent;
// page in the needed range of frames
if (leftcontext[i] == 0 && rightcontext[i] == 0)
{
leftextent = rightextent = augmentationextent(uttframevectors[t].size(), vdim[i]);
}
else
{
leftextent = leftcontext[i];
rightextent = rightcontext[i];
}
augmentneighbors(uttframevectors, noboundaryflags, t, leftextent, rightextent, feat[i], t + tspos);
// augmentneighbors(uttframevectors, noboundaryflags, t, feat[i], t + tspos);
}
// copy the frames and class labels
if (i == 0)
{
auto uttclassids = getclassids(uttref);
std::vector<shiftedvector<biggrowablevector<HMMIDTYPE>>> uttphoneboudaries;
if (m_generatePhoneBoundaries)
uttphoneboudaries = getphonebound(uttref);
foreach_index(j, uttclassids)
{
for (size_t t = 0; t < n; t++) // t = time index into source utterance
{
if (issupervised())
{
uids[j][t + tspos] = uttclassids[j][t];
if (m_generatePhoneBoundaries)
phoneboundaries2[j][t + tspos] = uttphoneboudaries[j][t];
}
}
if (!this->lattices.empty())
{
auto latticepair = chunkdata.getutterancelattice(uttref.utteranceindex());
latticepairs.push_back(latticepair);
// look up reference
const auto &key = latticepair->getkey();
if (!allwordtranscripts.empty())
{
const auto &transcript = allwordtranscripts.find(key)->second;
transcripts.push_back(transcript.words);
}
}
}
}
}
tspos += n;
}
foreach_index(i, feat)
{
assert(tspos == feat[i].cols());
}
}
else
{
const size_t sweepts = sweep * _totalframes; // first global frame index for this sweep
const size_t sweepte = sweepts + _totalframes; // and its end
const size_t globalte = std::min(globalts + framesrequested, sweepte); // we return as much as requested, but not exceeding sweep end
mbframes = globalte - globalts; // that's our mb size
// Perform randomization of the desired frame range
m_frameRandomizer.randomizeFrameRange(globalts, globalte);
// determine window range
// We enumerate all frames--can this be done more efficiently?
const size_t firstchunk = chunkforframepos(globalts);
const size_t lastchunk = chunkforframepos(globalte - 1);
const size_t windowbegin = randomizedchunks[0][firstchunk].windowbegin;
const size_t windowend = randomizedchunks[0][lastchunk].windowend;
if (verbosity > 0)
fprintf(stderr, "getbatch: getting randomized frames [%d..%d] (%d frames out of %d requested) in sweep %d; chunks [%d..%d] -> chunk window [%d..%d)\n",
(int)globalts, (int)globalte, (int)mbframes, (int)framesrequested, (int)sweep, (int)firstchunk, (int)lastchunk, (int)windowbegin, (int)windowend);
// release all data outside, and page in all data inside
for (size_t k = 0; k < windowbegin; k++)
releaserandomizedchunk(k);
for (size_t k = windowbegin; k < windowend; k++)
if ((k % numsubsets) == subsetnum) // in MPI mode, we skip chunks this way
readfromdisk |= requirerandomizedchunk(k, windowbegin, windowend); // (window range passed in for checking only, redundant here)
for (size_t k = windowend; k < randomizedchunks[0].size(); k++)
releaserandomizedchunk(k);
// determine the true #frames we return--it is less than mbframes in the case of MPI/data-parallel sub-set mode
// First determine it for all nodes, then pick the min over all nodes, as to give all the same #frames for better load balancing.
// TODO: No, return all; and leave it to caller to redistribute them [Zhijie Yan]
std::vector<size_t> subsetsizes(numsubsets, 0);
for (size_t i = 0; i < mbframes; i++) // i is input frame index; j < i in case of MPI/data-parallel sub-set mode
{
const frameref &frameref = m_frameRandomizer.randomizedframeref(globalts + i);
subsetsizes[frameref.chunkindex % numsubsets]++;
}
size_t j = subsetsizes[subsetnum]; // return what we have --TODO: we can remove the above full computation again now
const size_t allocframes = std::max(j, (mbframes + numsubsets - 1) / numsubsets); // we leave space for the desired #frames, assuming caller will try to pad them later
// resize feat and uids
feat.resize(vdim.size());
uids.resize(classids.size());
assert(feat.size() == vdim.size());
assert(feat.size() == randomizedchunks.size());
foreach_index(i, feat)
{
feat[i].resize(vdim[i], allocframes);
feat[i].shrink(vdim[i], j);
if (i == 0)
{
foreach_index(k, uids)
{
if (issupervised()) // empty means unsupervised training -> return empty uids
uids[k].resize(j);
else
uids[k].clear();
latticepairs.clear(); // will push_back() below
transcripts.clear();
}
}
}
// return randomized frames for the time range of those utterances
size_t currmpinodeframecount = 0;
for (size_t j2 = 0; j2 < mbframes; j2++)
{
if (currmpinodeframecount >= feat[0].cols()) // MPI/data-parallel mode: all nodes return the same #frames, which is how feat(,) is allocated
break;
// map to time index inside arrays
const frameref &frameref = m_frameRandomizer.randomizedframeref(globalts + j2);
// in MPI/data-parallel mode, skip frames that are not in chunks loaded for this MPI node
if ((frameref.chunkindex % numsubsets) != subsetnum)
continue;
// random utterance
readfromdisk |= requirerandomizedchunk(frameref.chunkindex, windowbegin, windowend); // (this is just a check; should not actually page in anything)
foreach_index(i, randomizedchunks)
{
const auto &chunk = randomizedchunks[i][frameref.chunkindex];
const auto &chunkdata = chunk.getchunkdata();
auto uttframes = chunkdata.getutteranceframes(frameref.utteranceindex());
matrixasvectorofvectors uttframevectors(uttframes); // (wrapper that allows m[.].size() and m[.][.] as required by augmentneighbors())
const size_t n = uttframevectors.size();
assert(n == uttframes.cols() && chunkdata.numframes(frameref.utteranceindex()) == n);
n; // referenced so that release builds (where assert() compiles away) do not warn about an unused variable
// copy frame and class labels
const size_t t = frameref.frameindex();
size_t leftextent, rightextent;
// page in the needed range of frames
if (leftcontext[i] == 0 && rightcontext[i] == 0)
{
leftextent = rightextent = augmentationextent(uttframevectors[t].size(), vdim[i]);
}
else
{
leftextent = leftcontext[i];
rightextent = rightcontext[i];
}
augmentneighbors(uttframevectors, noboundaryflags, t, leftextent, rightextent, feat[i], currmpinodeframecount);
if (issupervised() && i == 0)
{
auto frameclassids = getclassids(frameref);
foreach_index(k, uids)
uids[k][currmpinodeframecount] = frameclassids[k][t];
}
}
currmpinodeframecount++;
}
}
timegetbatch = timergetbatch;
// this is the number of frames we actually moved ahead in time
framesadvanced = mbframes;
return readfromdisk;
}
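For readers unfamiliar with the feature layout: the frame-mode path above hands each selected frame to augmentneighbors(), which stacks a window of neighboring frames into one tall feature column of height vdim[i]. Below is a minimal, self-contained sketch of that idea; the function name and types are illustrative and are not the actual msra::dbn helpers.

#include <algorithm>
#include <cstddef>
#include <vector>

// Illustrative only: stack frames [t-leftextent .. t+rightextent] into one augmented
// column, clamping at the utterance boundaries so edge frames repeat their neighbor.
static std::vector<float> augmentWithContext(const std::vector<std::vector<float>>& uttFrames, // [t][dim]
                                             size_t t, size_t leftextent, size_t rightextent)
{
    if (uttFrames.empty())
        return {};
    const size_t dim = uttFrames[0].size();
    std::vector<float> augmented;
    augmented.reserve(dim * (leftextent + 1 + rightextent));
    for (ptrdiff_t tau = (ptrdiff_t)t - (ptrdiff_t)leftextent; tau <= (ptrdiff_t)(t + rightextent); tau++)
    {
        const size_t tc = (size_t)std::min<ptrdiff_t>(std::max<ptrdiff_t>(tau, 0), (ptrdiff_t)uttFrames.size() - 1);
        augmented.insert(augmented.end(), uttFrames[tc].begin(), uttFrames[tc].end());
    }
    return augmented; // length == dim * (leftextent + 1 + rightextent), i.e. vdim[i]
}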
bool supportsbatchsubsetting() const override
{
return true;
}
bool getbatch(const size_t globalts,
const size_t framesrequested, std::vector<msra::dbn::matrix> &feat, std::vector<std::vector<size_t>> &uids,
std::vector<const_array_ref<msra::lattices::lattice::htkmlfwordsequence::word>> &transcripts,
std::vector<std::shared_ptr<const latticesource::latticepair>> &lattices2, std::vector<std::vector<size_t>> &sentendmark,
std::vector<std::vector<size_t>> &phoneboundaries2)
/* guoye: start */
const size_t framesrequested, std::vector<msra::dbn::matrix> &feat, std::vector<std::vector<size_t>> &uids,
// const size_t framesrequested, std::vector<msra::dbn::matrix> &feat, std::vector<std::vector<size_t>> &uids, std::vector<std::vector<size_t>> &wids,
/* guoye: end */
std::vector<const_array_ref<msra::lattices::lattice::htkmlfwordsequence::word>> &transcripts,
std::vector<std::shared_ptr<const latticesource::latticepair>> &lattices2, std::vector<std::vector<size_t>> &sentendmark,
std::vector<std::vector<size_t>> &phoneboundaries2)
{
size_t dummy;
/* guoye: start */
return getbatch(globalts, framesrequested, 0, 1, dummy, feat, uids, transcripts, lattices2, sentendmark, phoneboundaries2);
// return getbatch(globalts, framesrequested, 0, 1, dummy, feat, uids, wids, transcripts, lattices, sentendmark, phoneboundaries);
/* guoye: end */
}
double gettimegetbatch()
{
return timegetbatch;
}
// alternate (updated) definition for multiple inputs/outputs - read as a vector of feature matrixes or a vector of label strings
bool getbatch(const size_t /*globalts*/,
const size_t /*framesrequested*/, msra::dbn::matrix & /*feat*/, std::vector<size_t> & /*uids*/,
std::vector<const_array_ref<msra::lattices::lattice::htkmlfwordsequence::word>> & /*transcripts*/,
std::vector<std::shared_ptr<const latticesource::latticepair>> & /*latticepairs*/)
/* guoye: start */
const size_t /*framesrequested*/, msra::dbn::matrix & /*feat*/, std::vector<size_t> & /*uids*/,
// const size_t /*framesrequested*/, msra::dbn::matrix & /*feat*/, std::vector<size_t> & /*uids*/, std::vector<size_t> & /*wids*/,
/* guoye: end */
std::vector<const_array_ref<msra::lattices::lattice::htkmlfwordsequence::word>> & /*transcripts*/,
std::vector<std::shared_ptr<const latticesource::latticepair>> & /*latticepairs*/) override
{
// should never get here
RuntimeError("minibatchframesourcemulti: getbatch() being called for single input feature and single output feature, should use minibatchutterancesource instead\n");
@ -1912,6 +2488,14 @@ public:
// uids.resize(1);
// return getbatch(globalts, framesrequested, feat[0], uids[0], transcripts, latticepairs);
}
double gettimegetbatch()
{
return timegetbatch;
}
size_t totalframes() const
{

View File

@ -36,6 +36,8 @@ class minibatchutterancesourcemulti : public minibatchsource
// const std::vector<std::unique_ptr<latticesource>> &lattices;
const latticesource &lattices;
// std::vector<latticesource> lattices;
// word-level transcripts (for MMI mode when adding best path to lattices)
const std::map<std::wstring, msra::lattices::lattice::htkmlfwordsequence> &allwordtranscripts; // (used for getting word-level transcripts)
@ -158,7 +160,15 @@ class minibatchutterancesourcemulti : public minibatchsource
reader.readNoAlloc(utteranceset[i].parsedpath, (const string &) featkind, sampperiod, uttframes); // note: file info here used for checking only
// page in lattice data
if (!latticesource.empty())
latticesource.getlattices(utteranceset[i].key(), lattices[i], uttframes.cols());
/* guoye: start */
// we currently don't care about the Kaldi format, so just keep the compiler happy
// latticesource.getlattices(utteranceset[i].key(), lattices[i], uttframes.cols());
{
std::set<int> specialwordids;
specialwordids.clear();
latticesource.getlattices(utteranceset[i].key(), lattices[i], uttframes.cols(), specialwordids);
}
/* guoye: end */
}
// fprintf (stderr, "\n");
if (verbosity)

View File

@ -1,13 +1,13 @@
== Authors of the Linux Building README ==
Kaisheng Yao
Kaisheng Yao
Microsoft Research
email: kaisheny@microsoft.com
Wengong Jin,
Shanghai Jiao Tong University
email: acmgokun@gmail.com
Wengong Jin,
Shanghai Jiao Tong University
email: acmgokun@gmail.com
Yu Zhang, Leo Liu, Scott Cyphers
CSAIL, Massachusetts Institute of Technology
@ -78,10 +78,10 @@ To clean
== Run ==
All executables are in bin directory:
cntk: The main executable for CNTK
*.so: shared library for corresponding reader, these readers will be linked and loaded dynamically at runtime.
cntk: The main executable for CNTK
*.so: shared library for corresponding reader, these readers will be linked and loaded dynamically at runtime.
./cntk configFile=${your cntk config file}
./cntk configFile=${your cntk config file}
== Kaldi Reader ==
This is an HTKMLF reader and Kaldi writer (for decoding)

View File

@ -46,7 +46,11 @@ namespace Microsoft { namespace MSR { namespace CNTK {
bool useParallelTrain,
StreamMinibatchInputs& inputMatrices,
size_t& actualMBSize,
const MPIWrapperPtr& mpi)
/* guoye: start */
// const MPIWrapperPtr& mpi)
const MPIWrapperPtr& mpi,
size_t& actualNumWords)
/* guoye: end */
{
// Reading consists of a sequence of Reader API calls:
// - GetMinibatch() --fills the inputMatrices and copies the MBLayout from Reader into inputMatrices
@ -71,8 +75,16 @@ namespace Microsoft { namespace MSR { namespace CNTK {
auto uids = node->getuidprt();
auto boundaries = node->getboundaryprt();
auto extrauttmap = node->getextrauttmap();
/* guoye: start */
auto wids = node->getwidprt();
auto nws = node->getnwprt();
// trainSetDataReader.GetMinibatch4SE(*latticeinput, *uids, *boundaries, *extrauttmap);
trainSetDataReader.GetMinibatch4SE(*latticeinput, *uids, *wids, *nws, *boundaries, *extrauttmap);
trainSetDataReader.GetMinibatch4SE(*latticeinput, *uids, *boundaries, *extrauttmap);
actualNumWords = 0;
for (size_t i = 0; i < (*nws).size(); i++)
actualNumWords += (*nws)[i];
/* guoye: end */
}
// TODO: move this into shim for the old readers.
@ -284,11 +296,20 @@ namespace Microsoft { namespace MSR { namespace CNTK {
private:
typedef std::vector<shared_ptr<const msra::dbn::latticesource::latticepair>> Lattice;
typedef std::vector<size_t> Uid;
/* guoye: start */
typedef std::vector<size_t> Wid;
typedef std::vector<short> Nw;
/* guoye: end */
typedef std::vector<size_t> ExtrauttMap;
typedef std::vector<size_t> Boundaries;
typedef std::vector<shared_ptr<const msra::dbn::latticesource::latticepair>>* LatticePtr;
typedef std::vector<size_t>* UidPtr;
/* guoye: start */
typedef std::vector<size_t>* WidPtr;
typedef std::vector<short>* NwPtr;
/* guoye: end */
typedef std::vector<size_t>* ExtrauttMapPtr;
typedef std::vector<size_t>* BoundariesPtr;
typedef StreamMinibatchInputs Matrices;
@ -298,6 +319,10 @@ namespace Microsoft { namespace MSR { namespace CNTK {
MBLayoutPtr m_MBLayoutCache;
Lattice m_LatticeCache;
Uid m_uidCache;
/* guoye: start */
Wid m_widCache;
Nw m_nwCache;
/* guoye: end */
ExtrauttMap m_extrauttmapCache;
Boundaries m_BoundariesCache;
shared_ptr<Matrix<ElemType>> m_netCriterionAccumulator;
@ -313,6 +338,10 @@ namespace Microsoft { namespace MSR { namespace CNTK {
Matrices m_netInputMatrixPtr;
LatticePtr m_netLatticePtr;
UidPtr m_netUidPtr;
/* guoye: start */
WidPtr m_netWidPtr;
NwPtr m_netNwPtr;
/* guoye: end */
ExtrauttMapPtr m_netExtrauttMapPtr;
BoundariesPtr m_netBoundariesPtr;
// we remember the pointer to the learnable Nodes so that we can accumulate the gradient once a sub-minibatch is done
@ -352,7 +381,10 @@ namespace Microsoft { namespace MSR { namespace CNTK {
public:
SubminibatchDispatcher()
: m_MBLayoutCache(nullptr), m_netLatticePtr(nullptr), m_netExtrauttMapPtr(nullptr), m_netUidPtr(nullptr), m_netBoundariesPtr(nullptr)
/* guoye: start */
// : m_MBLayoutCache(nullptr), m_netLatticePtr(nullptr), m_netExtrauttMapPtr(nullptr), m_netUidPtr(nullptr), m_netBoundariesPtr(nullptr)
: m_MBLayoutCache(nullptr), m_netLatticePtr(nullptr), m_netExtrauttMapPtr(nullptr), m_netUidPtr(nullptr), m_netBoundariesPtr(nullptr), m_netWidPtr(nullptr), m_netNwPtr(nullptr)
/* guoye: end */
{
}
@ -398,6 +430,10 @@ namespace Microsoft { namespace MSR { namespace CNTK {
m_netLatticePtr = node->getLatticePtr();
m_netExtrauttMapPtr = node->getextrauttmap();
m_netUidPtr = node->getuidprt();
/* guoye: start */
m_netWidPtr = node->getwidprt();
m_netNwPtr = node->getnwprt();
/* guoye: end */
m_netBoundariesPtr = node->getboundaryprt();
m_hasLattices = true;
}
@ -408,6 +444,10 @@ namespace Microsoft { namespace MSR { namespace CNTK {
m_netUidPtr = nullptr;
m_netBoundariesPtr = nullptr;
m_hasLattices = false;
/* guoye: start */
m_netWidPtr = nullptr;
m_netNwPtr = nullptr;
/* guoye: end */
}
}
@ -444,11 +484,20 @@ namespace Microsoft { namespace MSR { namespace CNTK {
m_uidCache.clear();
m_extrauttmapCache.clear();
m_BoundariesCache.clear();
/* guoye: start */
m_widCache.clear();
m_nwCache.clear();
/* guoye: end */
m_LatticeCache = *m_netLatticePtr;
m_uidCache = *m_netUidPtr;
m_extrauttmapCache = *m_netExtrauttMapPtr;
m_BoundariesCache = *m_netBoundariesPtr;
/* guoye: start */
m_widCache = *m_netWidPtr;
m_nwCache = *m_netNwPtr;
/* guoye: end */
}
// subminibatches are cutted at the parallel sequence level;
@ -495,10 +544,18 @@ namespace Microsoft { namespace MSR { namespace CNTK {
BoundariesPtr decimatedBoundaryPtr, /* output: boundary after decimation*/
ExtrauttMapPtr decimatedExtraMapPtr, /* output: extramap after decimation*/
UidPtr decimatedUidPtr, /* output: Uid after decimation*/
/* guoye: start */
WidPtr decimatedWidPtr, /* output: Wid after decimation*/
NwPtr decimatedNwPtr, /* output: Nw after decimation*/
/* guoye: end */
const Lattice lattices, /* input: lattices to be decimated */
const Boundaries boundaries, /* input: boundary to be decimated */
const ExtrauttMap extraMaps, /* input: extra map to be decimated */
const Uid uids, /* input: uid to be decimated*/
/* guoye: start */
const Wid wids, /* input: wids to be decimated */
const Nw nws, /* input: per-utterance word counts to be decimated */
/* guoye: end */
pair<size_t, size_t> parallelSeqRange /* input: what parallel sequence range we are looking at */
)
{
@ -509,12 +566,22 @@ namespace Microsoft { namespace MSR { namespace CNTK {
decimatedBoundaryPtr->clear();
decimatedExtraMapPtr->clear();
decimatedUidPtr->clear();
/* guoye: start */
decimatedWidPtr->clear();
decimatedNwPtr->clear();
/* guoye: end */
size_t stFrame = 0;
/* guoye: start */
size_t stWord = 0;
/* guoye: end */
for (size_t iUtt = 0; iUtt < extraMaps.size(); iUtt++)
{
size_t numFramesInThisUtterance = lattices[iUtt]->getnumframes();
size_t iParallelSeq = extraMaps[iUtt]; // i-th utterance belongs to iParallelSeq-th parallel sequence
/* guoye: start */
size_t numWordsInThisUtterance = nws[iUtt];
/* guoye: end */
if (iParallelSeq >= parallelSeqStId && iParallelSeq < parallelSeqEnId)
{
// this utterance has been selected
@ -522,8 +589,16 @@ namespace Microsoft { namespace MSR { namespace CNTK {
decimatedBoundaryPtr->insert(decimatedBoundaryPtr->end(), boundaries.begin() + stFrame, boundaries.begin() + stFrame + numFramesInThisUtterance);
decimatedUidPtr->insert(decimatedUidPtr->end(), uids.begin() + stFrame, uids.begin() + stFrame + numFramesInThisUtterance);
decimatedExtraMapPtr->push_back(extraMaps[iUtt] - parallelSeqStId);
/* guoye: start */
decimatedWidPtr->insert(decimatedWidPtr->end(), wids.begin() + stWord, wids.begin() + stWord + numWordsInThisUtterance);
decimatedNwPtr->push_back(numWordsInThisUtterance);
/* guoye: end */
}
stFrame += numFramesInThisUtterance;
/* guoye: start */
stWord += numWordsInThisUtterance;
/* guoye: end */
}
}
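A note on the new word-level buffers: like uids, the word ids (wids) of a minibatch are stored flat across utterances, with nws holding the per-utterance word counts, so decimation only needs a second running cursor (stWord) next to the frame cursor (stFrame). The following sketch illustrates that layout under those assumptions; the types are hypothetical, not the CNTK ones.

#include <cstddef>
#include <vector>

struct MinibatchRefs
{
    std::vector<size_t> uids;      // frame-level labels, all utterances concatenated
    std::vector<size_t> wids;      // word ids, all utterances concatenated
    std::vector<size_t> numFrames; // per-utterance frame counts
    std::vector<short> nws;        // per-utterance word counts
};

// Illustrative only: recover the uid/wid slices belonging to utterance iUtt by
// advancing the two running offsets, exactly as the decimation loop above does.
static void sliceUtterance(const MinibatchRefs& mb, size_t iUtt,
                           std::vector<size_t>& uttUids, std::vector<size_t>& uttWids)
{
    size_t stFrame = 0, stWord = 0;
    for (size_t k = 0; k < iUtt; k++)
    {
        stFrame += mb.numFrames[k];
        stWord += (size_t)mb.nws[k];
    }
    uttUids.assign(mb.uids.begin() + stFrame, mb.uids.begin() + stFrame + mb.numFrames[iUtt]);
    uttWids.assign(mb.wids.begin() + stWord, mb.wids.begin() + stWord + (size_t)mb.nws[iUtt]);
}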
@ -538,12 +613,16 @@ namespace Microsoft { namespace MSR { namespace CNTK {
if (m_hasLattices)
{
DecimateLattices(
/* guoye: start */
/*output */
m_netLatticePtr, m_netBoundariesPtr, m_netExtrauttMapPtr, m_netUidPtr,
// m_netLatticePtr, m_netBoundariesPtr, m_netExtrauttMapPtr, m_netUidPtr,
m_netLatticePtr, m_netBoundariesPtr, m_netExtrauttMapPtr, m_netUidPtr, m_netWidPtr, m_netNwPtr,
/*input to be decimated */
m_LatticeCache, m_BoundariesCache, m_extrauttmapCache, m_uidCache,
// m_LatticeCache, m_BoundariesCache, m_extrauttmapCache, m_uidCache,
m_LatticeCache, m_BoundariesCache, m_extrauttmapCache, m_uidCache, m_widCache, m_nwCache,
/* what range we want ? */
seqRange);
/* guoye: end */
}
// The following does m_netInputMatrixPtr = decimatedMatrices; with ownership shenanigans.

View File

@ -81,6 +81,9 @@ void PostComputingActions<ElemType>::BatchNormalizationStatistics(IDataReader *
let bnNode = static_pointer_cast<BatchNormalizationNode<ElemType>>(node);
size_t actualMBSize = 0;
/* guoye: start */
size_t actualNumWords = 0;
/* guoye: end */
LOGPRINTF(stderr, "Estimating Statistics --> %ls\n", bnNode->GetName().c_str());
@ -90,8 +93,10 @@ void PostComputingActions<ElemType>::BatchNormalizationStatistics(IDataReader *
{
// during the bn stat, dataRead must be ensured
bool wasDataRead = DataReaderHelpers::GetMinibatchIntoNetwork<ElemType>(*dataReader, m_net,
nullptr, useDistributedMBReading, useParallelTrain, inputMatrices, actualMBSize, m_mpi);
/* guoye: start */
// nullptr, useDistributedMBReading, useParallelTrain, inputMatrices, actualMBSize, m_mpi);
nullptr, useDistributedMBReading, useParallelTrain, inputMatrices, actualMBSize, m_mpi, actualNumWords);
/* guoye: end */
if (!wasDataRead) LogicError("DataRead Failure in batch normalization statistics");
ComputationNetwork::BumpEvalTimeStamp(featureNodes);

View File

@ -262,7 +262,9 @@ void SGD<ElemType>::TrainOrAdaptModel(int startEpoch, ComputationNetworkPtr net,
}
}
}
/* guoye: start */
// LOGPRINTF(stderr, "SGD debug 1 \n");
/* guoye: end */
std::vector<ComputationNodeBasePtr> additionalNodesToEvaluate;
// Do not include the output nodes in the matrix sharing structure when using forward value matrix
@ -273,13 +275,18 @@ void SGD<ElemType>::TrainOrAdaptModel(int startEpoch, ComputationNetworkPtr net,
auto& outputNodes = net->OutputNodes();
additionalNodesToEvaluate.insert(additionalNodesToEvaluate.end(), outputNodes.cbegin(), outputNodes.cend());
}
/* guoye: start */
// LOGPRINTF(stderr, "SGD debug 2 \n");
/* guoye: end */
auto preComputeNodesList = net->GetNodesRequiringPreComputation();
// LOGPRINTF(stderr, "SGD debug 2.1 \n");
additionalNodesToEvaluate.insert(additionalNodesToEvaluate.end(), preComputeNodesList.cbegin(), preComputeNodesList.cend());
// LOGPRINTF(stderr, "SGD debug 2.2 \n");
// allocate memory for forward and backward computation
net->AllocateAllMatrices(evaluationNodes, additionalNodesToEvaluate, criterionNodes[0]); // TODO: use criterionNodes.front() throughout
/* guoye: start */
// LOGPRINTF(stderr, "SGD debug 3 \n");
/* guoye: end */
// get feature and label nodes into an array of matrices that will be passed to GetMinibatch()
// TODO: instead, remember the nodes directly, to be able to handle both float and double nodes; current version will crash for mixed networks
StreamMinibatchInputs* inputMatrices = new StreamMinibatchInputs();
@ -287,23 +294,34 @@ void SGD<ElemType>::TrainOrAdaptModel(int startEpoch, ComputationNetworkPtr net,
let& featureNodes = net->FeatureNodes();
let& labelNodes = net->LabelNodes();
// BUGBUG: ^^ should not get all feature/label nodes, but only the ones referenced in a criterion
/* guoye: start */
// LOGPRINTF(stderr, "SGD debug 4 \n");
/* guoye: end */
for (size_t pass = 0; pass < 2; pass++)
{
auto& nodes = (pass == 0) ? featureNodes : labelNodes;
for (const auto & node : nodes)
inputMatrices->AddInput(node->NodeName(), node->ValuePtr(), node->GetMBLayout(), node->GetSampleLayout());
}
/* guoye: start */
// LOGPRINTF(stderr, "SGD debug 5 \n");
/* guoye: end */
// get hmm file for sequence training
bool isSequenceTrainingCriterion = (criterionNodes[0]->OperationName() == L"SequenceWithSoftmax");
// LOGPRINTF(stderr, "SGD debug 5.1 \n");
if (isSequenceTrainingCriterion)
{
// SequenceWithSoftmaxNode<ElemType>* node = static_cast<SequenceWithSoftmaxNode<ElemType>*>(criterionNodes[0]);
// LOGPRINTF(stderr, "SGD debug 5.2 \n");
auto node = dynamic_pointer_cast<SequenceWithSoftmaxNode<ElemType>>(criterionNodes[0]);
auto hmm = node->gethmm();
// LOGPRINTF(stderr, "SGD debug 5.4 \n");
trainSetDataReader->GetHmmData(hmm);
// LOGPRINTF(stderr, "SGD debug 5.5 \n");
}
/* guoye: start */
// LOGPRINTF(stderr, "SGD debug 6 \n");
/* guoye: end */
// used for KLD regularized adaptation. For all other adaptation techniques
// use MEL to edit the model and using normal training algorithm
@ -329,6 +347,9 @@ void SGD<ElemType>::TrainOrAdaptModel(int startEpoch, ComputationNetworkPtr net,
// allocate memory for forward computation
refNet->AllocateAllMatrices({refNode}, {}, nullptr);
}
/* guoye: start */
// LOGPRINTF(stderr, "SGD debug 7 \n");
/* guoye: end */
// initializing weights and gradient holder
// only one criterion so far TODO: support multiple ones?
@ -338,6 +359,9 @@ void SGD<ElemType>::TrainOrAdaptModel(int startEpoch, ComputationNetworkPtr net,
size_t numParameters = 0;
vector<wstring> nodesToUpdateDescriptions; // for logging only
/* guoye: start */
// LOGPRINTF(stderr, "SGD debug 8 \n");
/* guoye: end */
for (auto nodeIter = learnableNodes.begin(); nodeIter != learnableNodes.end(); nodeIter++)
{
ComputationNodePtr node = dynamic_pointer_cast<ComputationNode<ElemType>>(*nodeIter);
@ -354,12 +378,18 @@ void SGD<ElemType>::TrainOrAdaptModel(int startEpoch, ComputationNetworkPtr net,
numParameters += node->GetSampleLayout().GetNumElements();
}
}
/* guoye: start */
// LOGPRINTF(stderr, "SGD debug 9 \n");
/* guoye: end */
size_t numNeedsGradient = 0;
for (let node : net->GetEvalOrder(criterionNodes[0]))
{
if (node->NeedsGradient())
numNeedsGradient++;
}
/* guoye: start */
// LOGPRINTF(stderr, "SGD debug 10 \n");
/* guoye: end */
fprintf(stderr, "\n");
LOGPRINTF(stderr, "Training %.0f parameters in %d ",
(double)numParameters, (int)nodesToUpdateDescriptions.size());
@ -482,8 +512,14 @@ void SGD<ElemType>::TrainOrAdaptModel(int startEpoch, ComputationNetworkPtr net,
// likewise for sequence training parameters
if (isSequenceTrainingCriterion)
{
ComputationNetwork::SetSeqParam<ElemType>(net, criterionNodes[0], m_hSmoothingWeight, m_frameDropThresh, m_doReferenceAlign,
/* guoye: start */
/* ComputationNetwork::SetSeqParam<ElemType>(net, criterionNodes[0], m_hSmoothingWeight, m_frameDropThresh, m_doReferenceAlign,
m_seqGammarCalcAMF, m_seqGammarCalcLMF, m_seqGammarCalcWP, m_seqGammarCalcbMMIFactor, m_seqGammarCalcUsesMBR);
*/
ComputationNetwork::SetSeqParam<ElemType>(net, criterionNodes[0], m_hSmoothingWeight, m_frameDropThresh, m_doReferenceAlign,
m_seqGammarCalcAMF, m_seqGammarCalcLMF, m_seqGammarCalcWP, m_seqGammarCalcbMMIFactor, m_seqGammarCalcUsesMBR,
m_seqGammarCalcUseEMBR, m_EMBRUnit, m_numPathsEMBR, m_enforceValidPathEMBR, m_getPathMethodEMBR, m_showWERMode, m_excludeSpecialWords, m_wordNbest, m_useAccInNbest, m_accWeightInNbest, m_numRawPathsEMBR);
/* guoye: end */
}
// Multiverso Warpper for ASGD logic init
@ -660,8 +696,16 @@ void SGD<ElemType>::TrainOrAdaptModel(int startEpoch, ComputationNetworkPtr net,
learnableNodes, smoothedGradients, smoothedCounts,
epochCriterion, epochEvalErrors,
"", SIZE_MAX, totalMBsSeen, tensorBoardWriter, startEpoch);
totalTrainingSamplesSeen += epochCriterion.second; // aggregate #training samples, for logging purposes only
/* guoye: start */
// totalTrainingSamplesSeen += epochCriterion.second; // aggregate #training samples, for logging purposes only
if(!m_seqGammarCalcUseEMBR)
totalTrainingSamplesSeen += epochCriterion.second;
else
totalTrainingSamplesSeen += epochEvalErrors[0].second;
/* guoye: end */
timer.Stop();
double epochTime = timer.ElapsedSeconds();
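The branches above change what the criterion's sample count means: with seqGammarUseEMBR the accumulated pair counts words, otherwise it counts frames, and the averages reported later divide by that count. A minimal sketch of the accumulation, with hypothetical names:

#include <cstddef>
#include <utility>

using Criterion = std::pair<double, size_t>; // (accumulated criterion value, sample count)

// Illustrative only: under EMBR the count is words, otherwise labeled frames;
// the reported average is always sum / count.
static void accumulate(Criterion& epoch, double minibatchValue,
                       size_t numLabeledFrames, size_t numWords, bool useEMBR)
{
    epoch.first += minibatchValue;
    epoch.second += useEMBR ? numWords : numLabeledFrames;
}

static double average(const Criterion& epoch)
{
    return epoch.second != 0 ? epoch.first / (double)epoch.second : 0.0;
}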
@ -1167,10 +1211,14 @@ size_t SGD<ElemType>::TrainOneEpoch(ComputationNetworkPtr net,
// get minibatch
// TODO: is it guaranteed that the GPU is already completed at this point, is it safe to overwrite the buffers?
size_t actualMBSize = 0;
/* guoye_start */
size_t actualNumWords = 0;
auto profGetMinibatch = ProfilerTimeBegin();
bool wasDataRead = DataReaderHelpers::GetMinibatchIntoNetwork<ElemType>(*trainSetDataReader, net, criterionNodes[0],
useDistributedMBReading, useParallelTrain, *inputMatrices, actualMBSize, m_mpi);
// useDistributedMBReading, useParallelTrain, *inputMatrices, actualMBSize, m_mpi);
useDistributedMBReading, useParallelTrain, *inputMatrices, actualMBSize, m_mpi, actualNumWords);
/* guoye_end */
if (maxNumSamplesExceeded) // Dropping data.
wasDataRead = false;
@ -1294,7 +1342,15 @@ size_t SGD<ElemType>::TrainOneEpoch(ComputationNetworkPtr net,
// accumulate criterion values (objective, eval)
assert(wasDataRead || numSamplesWithLabelOfNetwork == 0);
// criteria are in Value()(0,0), we accumulate into another 1x1 Matrix (to avoid having to pull the values off the GPU)
localEpochCriterion.Add(0, numSamplesWithLabelOfNetwork);
// localEpochCriterion.Add(0, numSamplesWithLabelOfNetwork);
/* guoye: start */
if(!m_seqGammarCalcUseEMBR)
localEpochCriterion.Add(0, numSamplesWithLabelOfNetwork);
else
localEpochCriterion.Add(0, actualNumWords);
/* guoye: end */
for (size_t i = 0; i < evaluationNodes.size(); i++)
localEpochEvalErrors.Add(i, numSamplesWithLabelOfNetwork);
}
@ -1326,14 +1382,29 @@ size_t SGD<ElemType>::TrainOneEpoch(ComputationNetworkPtr net,
}
// hoist the criterion into CPU space for all-reduce
localEpochCriterion.Assign(0, numSamplesWithLabelOfNetwork);
/* guoye: start */
if (!m_seqGammarCalcUseEMBR)
localEpochCriterion.Assign(0, numSamplesWithLabelOfNetwork);
else
localEpochCriterion.Assign(0, actualNumWords);
// localEpochCriterion.Assign(0, numSamplesWithLabelOfNetwork);
/* guoye: end */
for (size_t i = 0; i < evaluationNodes.size(); i++)
localEpochEvalErrors.Assign(i, numSamplesWithLabelOfNetwork);
// copy all values to be aggregated into the header
m_gradHeader->numSamples = aggregateNumSamples;
m_gradHeader->criterion = localEpochCriterion.GetCriterion(0).first;
m_gradHeader->numSamplesWithLabel = localEpochCriterion.GetCriterion(0).second; // same as aggregateNumSamplesWithLabel
/* guoye: start */
// m_gradHeader->numSamplesWithLabel = localEpochCriterion.GetCriterion(0).second; // same as aggregateNumSamplesWithLabel
if (!m_seqGammarCalcUseEMBR)
m_gradHeader->numSamplesWithLabel = localEpochCriterion.GetCriterion(0).second; // same as aggregateNumSamplesWithLabel
else
m_gradHeader->numSamplesWithLabel = numSamplesWithLabelOfNetwork;
/* guoye: end */
assert(m_gradHeader->numSamplesWithLabel == aggregateNumSamplesWithLabel);
for (size_t i = 0; i < evaluationNodes.size(); i++)
m_gradHeader->evalErrors[i] = localEpochEvalErrors.GetCriterion(i);
@ -1482,8 +1553,14 @@ size_t SGD<ElemType>::TrainOneEpoch(ComputationNetworkPtr net,
// epochCriterion aggregates over entire epoch, but we only show difference to last time we logged
EpochCriterion epochCriterionSinceLastLogged = epochCriterion - epochCriterionLastLogged;
let trainLossSinceLastLogged = epochCriterionSinceLastLogged.Average(); // TODO: Check whether old trainSamplesSinceLastLogged matches this ^^ difference
let trainSamplesSinceLastLogged = (int)epochCriterionSinceLastLogged.second;
/* guoye: start */
// let trainSamplesSinceLastLogged = (int)epochCriterionSinceLastLogged.second;
// for EMBR, epochCriterionSinceLastLogged.second stores the #words rather than #frames
let trainSamplesSinceLastLogged = (m_seqGammarCalcUseEMBR? (int)(epochEvalErrors[0].second - epochEvalErrorsLastLogged[0].second) : (int)epochCriterionSinceLastLogged.second);
/* guoye: end */
// determine progress in percent
int mbProgNumPrecision = 2;
double mbProg = 0.0;
@ -1777,7 +1854,12 @@ bool SGD<ElemType>::PreCompute(ComputationNetworkPtr net,
const size_t numIterationsBeforePrintingProgress = 100;
size_t numItersSinceLastPrintOfProgress = 0;
size_t actualMBSizeDummy;
while (DataReaderHelpers::GetMinibatchIntoNetwork<ElemType>(*trainSetDataReader, net, nullptr, false, false, *inputMatrices, actualMBSizeDummy, m_mpi))
/* guoye: start */
size_t actualNumWordsDummy;
// while (DataReaderHelpers::GetMinibatchIntoNetwork<ElemType>(*trainSetDataReader, net, nullptr, false, false, *inputMatrices, actualMBSizeDummy, m_mpi))
while (DataReaderHelpers::GetMinibatchIntoNetwork<ElemType>(*trainSetDataReader, net, nullptr, false, false, *inputMatrices, actualMBSizeDummy, m_mpi, actualNumWordsDummy))
/* guoye: end */
{
// TODO: move these into GetMinibatchIntoNetwork() --but those are passed around; necessary? Can't we get them from 'net'?
ComputationNetwork::BumpEvalTimeStamp(featureNodes);
@ -2981,6 +3063,42 @@ SGDParams::SGDParams(const ConfigRecordType& configSGD, size_t sizeofElemType)
m_frameDropThresh = configSGD(L"frameDropThresh", 1e-10);
m_doReferenceAlign = configSGD(L"doReferenceAlign", false);
m_seqGammarCalcUsesMBR = configSGD(L"seqGammarUsesMBR", false);
/* guoye start */
m_seqGammarCalcUseEMBR = configSGD(L"seqGammarUseEMBR", false);
m_EMBRUnit = configSGD(L"EMBRUnit", "word");
m_numPathsEMBR = configSGD(L"numPathsEMBR", (size_t)100);
// enforce that each path starts with the sentence-start token
m_enforceValidPathEMBR = configSGD(L"enforceValidPathEMBR", false);
//could be sampling or nbest
m_getPathMethodEMBR = configSGD(L"getPathMethodEMBR", "sampling");
// could be average or onebest
m_showWERMode = configSGD(L"showWERMode", "average");
// if true, exclude paths that contain special words
m_excludeSpecialWords = configSGD(L"excludeSpecialWords", false);
// if true, force the n-best entries to have distinct word sequences
m_wordNbest = configSGD(L"wordNbest", false);
m_useAccInNbest = configSGD(L"useAccInNbest", false);
m_accWeightInNbest = configSGD(L"accWeightInNbest", 1.0f);
m_numRawPathsEMBR = configSGD(L"numRawPathsEMBR", (size_t)100);
if (!m_useAccInNbest)
{
if (m_numRawPathsEMBR > m_numPathsEMBR)
{
fprintf(stderr, "SGDParams: WARNING: we do not use acc in nbest, so no need to make numRawPathsEMBR = %d larger than numPathsEMBR = %d \n", (int)m_numRawPathsEMBR, (int)m_numPathsEMBR);
}
}
if (m_getPathMethodEMBR == "sampling" && m_showWERMode == "onebest")
{
RuntimeError("There is no way to show onebest WER in sampling based EMBR");
}
/* guoye end */
m_seqGammarCalcAMF = configSGD(L"seqGammarAMF", 14.0);
m_seqGammarCalcLMF = configSGD(L"seqGammarLMF", 14.0);
m_seqGammarCalcbMMIFactor = configSGD(L"seqGammarBMMIFactor", 0.0);

View File

@ -323,6 +323,20 @@ protected:
double m_seqGammarCalcbMMIFactor;
bool m_seqGammarCalcUsesMBR;
/* guoye start */
bool m_seqGammarCalcUseEMBR;
string m_EMBRUnit; // unit could be: word, phone, or state (edit distance is computed at this unit level)
bool m_enforceValidPathEMBR;
string m_getPathMethodEMBR;
size_t m_numPathsEMBR; // number of sampled paths
string m_showWERMode; // how to report WER over the sampled paths: average or onebest
bool m_excludeSpecialWords;
bool m_wordNbest;
bool m_useAccInNbest;
float m_accWeightInNbest;
size_t m_numRawPathsEMBR;
/* guoye end */
// decide whether should apply regularization into BatchNormalizationNode
// true: disable Regularization
// false: enable Regularization (default)

View File

@ -120,7 +120,11 @@ public:
for (;;)
{
size_t actualMBSize = 0;
bool wasDataRead = DataReaderHelpers::GetMinibatchIntoNetwork<ElemType>(*dataReader, m_net, nullptr, useDistributedMBReading, useParallelTrain, inputMatrices, actualMBSize, m_mpi);
/* guoye: start */
size_t actualNumWords = 0;
/* guoye: end */
// bool wasDataRead = DataReaderHelpers::GetMinibatchIntoNetwork<ElemType>(*dataReader, m_net, nullptr, useDistributedMBReading, useParallelTrain, inputMatrices, actualMBSize, m_mpi);
bool wasDataRead = DataReaderHelpers::GetMinibatchIntoNetwork<ElemType>(*dataReader, m_net, nullptr, useDistributedMBReading, useParallelTrain, inputMatrices, actualMBSize, m_mpi, actualNumWords);
// in case of distributed reading, we do a few more loops until all ranks have completed
// end of epoch
if (!wasDataRead && (!useDistributedMBReading || noMoreSamplesToProcess))

View File

@ -62,7 +62,11 @@ public:
const size_t numIterationsBeforePrintingProgress = 100;
size_t numItersSinceLastPrintOfProgress = 0;
size_t actualMBSize;
while (DataReaderHelpers::GetMinibatchIntoNetwork<ElemType>(dataReader, m_net, nullptr, false, false, inputMatrices, actualMBSize, nullptr))
/* guoye: start */
size_t actualNumWords;
// while (DataReaderHelpers::GetMinibatchIntoNetwork<ElemType>(dataReader, m_net, nullptr, false, false, inputMatrices, actualMBSize, nullptr))
while (DataReaderHelpers::GetMinibatchIntoNetwork<ElemType>(dataReader, m_net, nullptr, false, false, inputMatrices, actualMBSize, nullptr, actualNumWords))
/* guoye: end */
{
ComputationNetwork::BumpEvalTimeStamp(inputNodes);
m_net->ForwardProp(outputNodes);
@ -230,7 +234,11 @@ public:
char formatChar = !formattingOptions.isCategoryLabel ? 'f' : !formattingOptions.labelMappingFile.empty() ? 's' : 'u';
std::string valueFormatString = "%" + formattingOptions.precisionFormat + formatChar; // format string used in fprintf() for formatting the values
for (size_t numMBsRun = 0; DataReaderHelpers::GetMinibatchIntoNetwork<ElemType>(dataReader, m_net, nullptr, false, false, inputMatrices, actualMBSize, nullptr); numMBsRun++)
/* guoye: start */
size_t actualNumWords;
//for (size_t numMBsRun = 0; DataReaderHelpers::GetMinibatchIntoNetwork<ElemType>(dataReader, m_net, nullptr, false, false, inputMatrices, actualMBSize, nullptr); numMBsRun++)
for (size_t numMBsRun = 0; DataReaderHelpers::GetMinibatchIntoNetwork<ElemType>(dataReader, m_net, nullptr, false, false, inputMatrices, actualMBSize, nullptr, actualNumWords); numMBsRun++)
/* guoye: end */
{
ComputationNetwork::BumpEvalTimeStamp(inputNodes);
m_net->ForwardProp(outputNodes);

View File

@ -11,6 +11,11 @@
#include <memory>
#include <vector>
/* guoye: start */
#include <string>
/* guoye: end */
#pragma warning(disable : 4127) // conditional expression is constant
namespace msra { namespace lattices {
@ -22,6 +27,19 @@ struct SeqGammarCalParam
double wp;
double bMMIfactor;
bool sMBRmode;
/* guoye: start */
bool EMBR;
std::string EMBRUnit;
size_t numPathsEMBR;
bool enforceValidPathEMBR;
std::string getPathMethodEMBR;
std::string showWERMode;
bool excludeSpecialWords;
bool wordNbest;
bool useAccInNbest;
float accWeightInNbest;
size_t numRawPathsEMBR;
/* guoye: end */
SeqGammarCalParam()
{
amf = 14.0;
@ -29,6 +47,20 @@ struct SeqGammarCalParam
wp = 0.0;
bMMIfactor = 0.0;
sMBRmode = false;
/* guoye: start */
EMBR = false;
EMBRUnit = "word";
numPathsEMBR = 100;
enforceValidPathEMBR = false;
getPathMethodEMBR = "sampling";
showWERMode = "average";
excludeSpecialWords = false;
wordNbest = false;
useAccInNbest = false;
accWeightInNbest = 1.0;
numRawPathsEMBR = 100;
/* guoye: end*/
}
};
@ -81,6 +113,19 @@ public:
wp = (float) gammarParam.wp;
seqsMBRmode = gammarParam.sMBRmode;
boostmmifactor = (float) gammarParam.bMMIfactor;
/* guoye: start */
EMBR = gammarParam.EMBR;
EMBRUnit = gammarParam.EMBRUnit;
numPathsEMBR = gammarParam.numPathsEMBR;
enforceValidPathEMBR = gammarParam.enforceValidPathEMBR;
getPathMethodEMBR = gammarParam.getPathMethodEMBR;
showWERMode = gammarParam.showWERMode;
excludeSpecialWords = gammarParam.excludeSpecialWords;
wordNbest = gammarParam.wordNbest;
useAccInNbest = gammarParam.useAccInNbest;
accWeightInNbest = gammarParam.accWeightInNbest;
numRawPathsEMBR = gammarParam.numRawPathsEMBR;
/* guoye: end */
}
// ========================================
@ -91,7 +136,10 @@ public:
const Microsoft::MSR::CNTK::Matrix<ElemType>& loglikelihood,
Microsoft::MSR::CNTK::Matrix<ElemType>& labels,
Microsoft::MSR::CNTK::Matrix<ElemType>& gammafromlattice,
std::vector<size_t>& uids, std::vector<size_t>& boundaries,
/* guoye: start */
// std::vector<size_t>& uids, std::vector<size_t>& boundaries,
std::vector<size_t>& uids, std::vector<size_t>& wids, std::vector<short>& nws, std::vector<size_t>& boundaries,
/* guoye: end */
size_t samplesInRecurrentStep, /* numParallelUtterance ? */
std::shared_ptr<Microsoft::MSR::CNTK::MBLayout> pMBLayout,
std::vector<size_t>& extrauttmap,
@ -99,6 +147,13 @@ public:
{
// check total frame number to be added ?
// int deviceid = loglikelihood.GetDeviceId();
/* guoye: start */
/*
for (size_t i = 0; i < lattices.size(); i++)
{
// fprintf(stderr, "calgammaformb: i = %d, utt = %ls \n", int(i), lattices[i]->second.key.c_str());
}
*/
size_t boundaryframenum;
std::vector<size_t> validframes; // [s] cursor pointing to next utterance begin within a single parallel sequence [s]
validframes.assign(samplesInRecurrentStep, 0);
@ -128,9 +183,15 @@ public:
size_t mapi = 0; // parallel-sequence index for utterance [i]
// cal gamma for each utterance
size_t ts = 0;
/* guoye: start */
size_t ws = 0;
/* guoye: end */
for (size_t i = 0; i < lattices.size(); i++)
{
const size_t numframes = lattices[i]->getnumframes();
/* guoye: start */
const short numwords = nws[i];
/* guoye: end */
msra::dbn::matrixstripe predstripe(pred, ts, numframes); // logLLs for this utterance
msra::dbn::matrixstripe dengammasstripe(dengammas, ts, numframes); // denominator gammas
@ -186,6 +247,9 @@ public:
}
array_ref<size_t> uidsstripe(&uids[ts], numframes);
/* guoye: start */
std::vector<size_t> widsstripe(wids.begin() + ws, wids.begin() + ws + numwords);
/* guoye: end */
if (doreferencealign)
{
@ -204,12 +268,28 @@ public:
numavlogp /= numframes;
// auto_timer dengammatimer;
/* guoye: start */
// double denavlogp = lattices[i]->second.forwardbackward(parallellattice,
// (const msra::math::ssematrixbase&) predstripe, (const msra::asr::simplesenonehmm&) m_hset,
// (msra::math::ssematrixbase&) dengammasstripe, (msra::math::ssematrixbase&) gammasbuffer /*empty, not used*/,
// lmf, wp, amf, boostmmifactor, seqsMBRmode, uidsstripe, boundariesstripe);
// fprintf(stderr, "calgammaformb: i = %d, utt = %ls \n", int(i), lattices[i]->second.key.c_str());
double denavlogp = lattices[i]->second.forwardbackward(parallellattice,
(const msra::math::ssematrixbase&) predstripe, (const msra::asr::simplesenonehmm&) m_hset,
(msra::math::ssematrixbase&) dengammasstripe, (msra::math::ssematrixbase&) gammasbuffer /*empty, not used*/,
lmf, wp, amf, boostmmifactor, seqsMBRmode, uidsstripe, boundariesstripe);
objectValue += (ElemType)((numavlogp - denavlogp) * numframes);
lmf, wp, amf, boostmmifactor, seqsMBRmode, EMBR, EMBRUnit, numPathsEMBR, enforceValidPathEMBR, getPathMethodEMBR, showWERMode, excludeSpecialWords, wordNbest, useAccInNbest, accWeightInNbest, numRawPathsEMBR, uidsstripe, widsstripe, boundariesstripe);
/* guoye: end */
/* guoye: start */
// objectValue += (ElemType)((numavlogp - denavlogp) * numframes);
numavlogp; // referenced only to silence an unused-variable warning now that the objective above is commented out
denavlogp;
// objectValue += (ElemType)( 0 * numframes);
objectValue += (ElemType)(denavlogp*numwords);
/* guoye: end */
if (samplesInRecurrentStep == 1)
{
tempmatrix = gammafromlattice.ColumnSlice(ts, numframes);
@ -243,8 +323,13 @@ public:
}
if (samplesInRecurrentStep > 1)
validframes[mapi] += numframes; // advance the cursor within the parallel sequence
fprintf(stderr, "dengamma value %f\n", denavlogp);
/* guoye: start */
// fprintf(stderr, "dengamma value %f\n", denavlogp);
/* guoye: end */
ts += numframes;
/* guoye: start */
ws += numwords;
/* guoye: end */
}
functionValues.SetValue(objectValue);
}
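A hedged reading of the accumulation above: if, in EMBR mode, forwardbackward() returns an average per-word risk for each utterance, then objectValue is the word-weighted sum of those risks, and dividing by the total word count gives a per-word expected risk. Sketch (illustrative only, not the actual call chain):

#include <cstddef>
#include <vector>

static double expectedRiskPerWord(const std::vector<double>& riskPerUtt,  // assumed per-word risk per utterance
                                  const std::vector<size_t>& wordsPerUtt) // numwords per utterance (nws)
{
    double objective = 0.0;
    size_t totalWords = 0;
    for (size_t i = 0; i < riskPerUtt.size(); i++)
    {
        objective += riskPerUtt[i] * (double)wordsPerUtt[i];
        totalWords += wordsPerUtt[i];
    }
    return totalWords != 0 ? objective / (double)totalWords : 0.0;
}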
@ -509,6 +594,20 @@ protected:
float boostmmifactor;
bool seqsMBRmode;
/* guoye: start */
bool EMBR;
std::string EMBRUnit;
size_t numPathsEMBR;
bool enforceValidPathEMBR;
std::string getPathMethodEMBR;
std::string showWERMode;
bool excludeSpecialWords;
bool wordNbest;
bool useAccInNbest;
float accWeightInNbest;
size_t numRawPathsEMBR;
/* guoye: end */
private:
std::unique_ptr<Microsoft::MSR::CNTK::CUDAPageLockedMemAllocator> m_cudaAllocator;
std::shared_ptr<ElemType> m_intermediateCUDACopyBuffer;

Diff not shown because of its large size. Load diff

View File

@ -133,6 +133,27 @@ void backwardlatticej(const size_t batchsize, const size_t startindex, const std
}
}
/* guoye: start */
void backwardlatticejEMBR(const size_t batchsize, const size_t startindex, const std::vector<float>& edgeacscores,
const std::vector<msra::lattices::edgeinfowithscores>& edges,
const std::vector<msra::lattices::nodeinfo>& nodes,
std::vector<double>& edgelogbetas, std::vector<double>& logbetas,
float lmf, float wp, float amf)
{
const size_t tpb = blockDim.x * blockDim.y; // total #threads in a block
const size_t jinblock = threadIdx.x + threadIdx.y * blockDim.x;
const size_t j = jinblock + blockIdx.x * tpb;
if (j < batchsize) // note: will cause issues if we ever use __syncthreads() in backwardlatticej
{
msra::lattices::latticefunctionskernels::backwardlatticejEMBR(j + startindex, edgeacscores,
edges, nodes, edgelogbetas,
logbetas, lmf, wp, amf);
}
}
/* guoye: end */
void sMBRerrorsignalj(const std::vector<unsigned short>& alignstateids, const std::vector<unsigned int>& alignoffsets,
const std::vector<msra::lattices::edgeinfowithscores>& edges, const std::vector<msra::lattices::nodeinfo>& nodes,
const std::vector<double>& logpps, const float amf, const std::vector<double>& logEframescorrect,
@ -147,6 +168,19 @@ void sMBRerrorsignalj(const std::vector<unsigned short>& alignstateids, const st
}
}
/* guoye: start */
void EMBRerrorsignalj(const std::vector<unsigned short>& alignstateids, const std::vector<unsigned int>& alignoffsets,
const std::vector<msra::lattices::edgeinfowithscores>& edges, const std::vector<msra::lattices::nodeinfo>& nodes,
const std::vector<double>& edgeweights, msra::math::ssematrixbase& errorsignal)
{
const size_t shufflemode = 3;
const size_t j = msra::lattices::latticefunctionskernels::shuffle(threadIdx.x, blockDim.x, threadIdx.y, blockDim.y, blockIdx.x, gridDim.x, shufflemode);
if (j < edges.size()) // note: will cause issues if we ever use __syncthreads()
{
msra::lattices::latticefunctionskernels::EMBRerrorsignalj(j, alignstateids, alignoffsets, edges, nodes, edgeweights, errorsignal);
}
}
/* guoye: end */
void stateposteriorsj(const std::vector<unsigned short>& alignstateids, const std::vector<unsigned int>& alignoffsets,
const std::vector<msra::lattices::edgeinfowithscores>& edges, const std::vector<msra::lattices::nodeinfo>& nodes,
const std::vector<double>& logqs, msra::math::ssematrixbase& logacc)
@ -298,6 +332,50 @@ static double emulateforwardbackwardlattice(const size_t* batchsizeforward, cons
#endif
return totalfwscore;
}
/* guoye: start */
static double emulatebackwardlatticeEMBR(const size_t* batchsizebackward, const size_t numlaunchbackward,
const std::vector<float>& edgeacscores,
const std::vector<msra::lattices::edgeinfowithscores>& edges, const std::vector<msra::lattices::nodeinfo>& nodes,
std::vector<double>& edgelogbetas, std::vector<double>& logbetas,
const float lmf, const float wp, const float amf)
{
dim3 t(32, 8);
const size_t tpb = t.x * t.y;
dim3 b((unsigned int)((logbetas.size() + tpb - 1) / tpb));
emulatecuda(b, t, [&]()
{
setvaluej(logbetas, LOGZERO, logbetas.size());
});
logbetas[nodes.size() - 1] = 0;
// forward pass
// backward pass
size_t startindex = edges.size();
for (size_t i = 0; i < numlaunchbackward; i++)
{
dim3 b3((unsigned int)((batchsizebackward[i] + tpb - 1) / tpb));
emulatecuda(b3, t, [&]()
{
backwardlatticejEMBR(batchsizebackward[i], startindex - batchsizebackward[i], edgeacscores,
edges, nodes, edgelogbetas, logbetas, lmf, wp, amf);
});
startindex -= batchsizebackward[i];
}
double totalbwscore = logbetas.front();
return totalbwscore;
}
/* guoye: end */
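For reference, a single-threaded sketch of the backward recursion that backwardlatticejEMBR parallelizes. It assumes nodes are topologically ordered, every edge runs from a lower-numbered start node S to a higher-numbered end node E, edges are sorted by S, and the lmf/wp/amf combination of language-model score, word penalty and acoustic score is already folded into a single log-domain edge score.

#include <cmath>
#include <cstddef>
#include <limits>
#include <utility>
#include <vector>

static double logadd(double a, double b)
{
    if (a < b)
        std::swap(a, b); // now a >= b
    return (b == -std::numeric_limits<double>::infinity()) ? a : a + std::log1p(std::exp(b - a));
}

struct Edge { size_t S, E; double score; }; // start node, end node, combined edge score (log domain)

static void backwardBetas(const std::vector<Edge>& edges, size_t numNodes,
                          std::vector<double>& edgeLogBetas, std::vector<double>& logBetas)
{
    logBetas.assign(numNodes, -std::numeric_limits<double>::infinity());
    logBetas[numNodes - 1] = 0.0; // final node
    edgeLogBetas.resize(edges.size());
    for (size_t j = edges.size(); j-- > 0;) // reverse order => beta of each end node is final when read
    {
        edgeLogBetas[j] = edges[j].score + logBetas[edges[j].E];
        logBetas[edges[j].S] = logadd(logBetas[edges[j].S], edgeLogBetas[j]);
    }
}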
// this function behaves like its CUDA counterparts, except that it takes CPU-side std::vectors for everything
// this must be identical to CUDA kernel-launch function in -ops class (except for the input data types: vectorref -> std::vector)
static void emulatesMBRerrorsignal(const std::vector<unsigned short>& alignstateids, const std::vector<unsigned int>& alignoffsets,
@ -324,6 +402,29 @@ static void emulatesMBRerrorsignal(const std::vector<unsigned short>& alignstate
});
}
/* guoye: start */
// this function behaves like its CUDA counterparts, except that it takes CPU-side std::vectors for everything
// this must be identical to CUDA kernel-launch function in -ops class (except for the input data types: vectorref -> std::vector)
static void emulateEMBRerrorsignal(const std::vector<unsigned short>& alignstateids, const std::vector<unsigned int>& alignoffsets,
const std::vector<msra::lattices::edgeinfowithscores>& edges, const std::vector<msra::lattices::nodeinfo>& nodes,
const std::vector<double>& edgeweights,
msra::math::ssematrixbase& errorsignal)
{
const size_t numedges = edges.size();
dim3 t(32, 8);
const size_t tpb = t.x * t.y;
foreach_coord(i, j, errorsignal)
errorsignal(i, j) = 0;
dim3 b((unsigned int)((numedges + tpb - 1) / tpb));
emulatecuda(b, t, [&]()
{
EMBRerrorsignalj(alignstateids, alignoffsets, edges, nodes, edgeweights, errorsignal);
});
dim3 b1((((unsigned int)errorsignal.rows()) + 31) / 32);
}
/* guoye: end */
// this function behaves like its CUDA counterparts, except that it takes CPU-side std::vectors for everything
// this must be identical to CUDA kernel-launch function in -ops class (except for the input data types: vectorref -> std::vector)
static void emulatemmierrorsignal(const std::vector<unsigned short>& alignstateids, const std::vector<unsigned int>& alignoffsets,
@ -388,6 +489,11 @@ struct parallelstateimpl
logppsgpu(msra::cuda::newdoublevector(deviceid)),
logalphasgpu(msra::cuda::newdoublevector(deviceid)),
logbetasgpu(msra::cuda::newdoublevector(deviceid)),
/* guoye: start */
edgelogbetasgpu(msra::cuda::newdoublevector(deviceid)),
edgeweightsgpu(msra::cuda::newdoublevector(deviceid)),
/* guoye: end */
logaccalphasgpu(msra::cuda::newdoublevector(deviceid)),
logaccbetasgpu(msra::cuda::newdoublevector(deviceid)),
logframescorrectedgegpu(msra::cuda::newdoublevector(deviceid)),
@ -526,6 +632,10 @@ struct parallelstateimpl
std::unique_ptr<doublevector> logppsgpu;
std::unique_ptr<doublevector> logalphasgpu;
/* guoye: start */
std::unique_ptr<doublevector> edgelogbetasgpu;
std::unique_ptr<doublevector> edgeweightsgpu;
/* guoye: end */
std::unique_ptr<doublevector> logbetasgpu;
std::unique_ptr<doublevector> logaccalphasgpu;
std::unique_ptr<doublevector> logaccbetasgpu;
@ -619,6 +729,20 @@ struct parallelstateimpl
logEframescorrectgpu->allocate(edges.size());
}
}
/* guoye: start */
template <class edgestype, class nodestype>
void allocbwvectorsEMBR(const edgestype& edges, const nodestype& nodes)
{
#ifndef TWO_CHANNEL
const size_t alphabetanoderatio = 1;
#else
const size_t alphabetanoderatio = 2;
#endif
logbetasgpu->allocate(alphabetanoderatio * nodes.size());
edgelogbetasgpu->allocate(edges.size());
}
/* guoye: end */
// check if gpumatrixstorage supports size of cpumatrix, if not allocate. set gpumatrix to part of gpumatrixstorage
// This function checks the size of errorsignalgpustorage, and then sets errorsignalgpu to a columnslice of the
@ -664,6 +788,35 @@ struct parallelstateimpl
edgealignments.resize(alignresult->size());
alignresult->fetch(edgealignments, true);
}
/* guoye: start */
void getlogbetas(std::vector<double>& logbetas)
{
logbetas.resize(logbetasgpu->size());
logbetasgpu->fetch(logbetas, true);
}
void getedgelogbetas(std::vector<double>& edgelogbetas)
{
edgelogbetas.resize(edgelogbetasgpu->size());
edgelogbetasgpu->fetch(edgelogbetas, true);
}
void getedgeweights(std::vector<double>& edgeweights)
{
edgeweights.resize(edgeweightsgpu->size());
edgeweightsgpu->fetch(edgeweights, true);
}
void setedgeweights(const std::vector<double>& edgeweights)
{
edgeweightsgpu->assign(edgeweights, false);
}
/* guoye: end */
};
void lattice::parallelstate::setdevice(size_t deviceid)
@ -725,6 +878,29 @@ void lattice::parallelstate::getedgealignments(std::vector<unsigned short>& edge
{
pimpl->getedgealignments(edgealignments);
}
/* guoye: start */
void lattice::parallelstate::getlogbetas(std::vector<double>& logbetas)
{
pimpl->getlogbetas(logbetas);
}
void lattice::parallelstate::getedgelogbetas(std::vector<double>& edgelogbetas)
{
pimpl->getedgelogbetas(edgelogbetas);
}
void lattice::parallelstate::getedgeweights(std::vector<double>& edgeweights)
{
pimpl->getedgeweights(edgeweights);
}
void lattice::parallelstate::setedgeweights(const std::vector<double>& edgeweights)
{
pimpl->setedgeweights(edgeweights);
}
/* guoye: end */
//template<class ElemType>
void lattice::parallelstate::setloglls(const Microsoft::MSR::CNTK::Matrix<float>& loglls)
{
@ -909,6 +1085,73 @@ double lattice::parallelforwardbackwardlattice(parallelstate& parallelstate, con
return totalfwscore;
}
/* guoye: start */
// parallelbackwardlatticeEMBR() -- compute the lattice-level log betas (per node and per edge) with a backward pass, for EMBR
double lattice::parallelbackwardlatticeEMBR(parallelstate& parallelstate, const std::vector<float>& edgeacscores,
const float lmf, const float wp, const float amf,
std::vector<double>& edgelogbetas, std::vector<double>& logbetas) const
{ // ^^ TODO: remove this
vector<size_t> batchsizebackward; // record the batch size that exclude the data dependency for backward
size_t endindexbackward = edges.back().S;
size_t countbatchbackward = 0;
foreach_index (j, edges) // compute the batch size info for kernel launches
{
const size_t backj = edges.size() - 1 - j;
if (edges[backj].E > endindexbackward)
{
countbatchbackward++;
if (endindexbackward < edges[backj].S)
endindexbackward = edges[backj].S;
}
else
{
batchsizebackward.push_back(countbatchbackward);
countbatchbackward = 1;
endindexbackward = edges[backj].S;
}
}
batchsizebackward.push_back(countbatchbackward);
double totalbwscore = 0.0f;
if (!parallelstate->emulation)
{
if (verbosity >= 2)
fprintf(stderr, "parallelbackwardlatticeEMBR: %d launches for backward\n", (int) batchsizebackward.size());
parallelstate->allocbwvectorsEMBR(edges, nodes);
std::unique_ptr<latticefunctions> latticefunctions(msra::cuda::newlatticefunctions(parallelstate.getdevice())); // final CUDA call
latticefunctions->backwardlatticeEMBR(&batchsizebackward[0], batchsizebackward.size(),
*parallelstate->edgeacscoresgpu.get(), *parallelstate->edgesgpu.get(),
*parallelstate->nodesgpu.get(), *parallelstate->edgelogbetasgpu.get(),
*parallelstate->logbetasgpu.get(), lmf, wp, amf, totalbwscore);
}
else // emulation
{
#ifndef TWO_CHANNEL
fprintf(stderr, "forbid invalid sil path\n");
const size_t alphabetanoderatio = 1;
#else
const size_t alphabetanoderatio = 2;
#endif
logbetas.resize(alphabetanoderatio * nodes.size());
edgelogbetas.resize(edges.size());
totalbwscore = emulatebackwardlatticeEMBR(&batchsizebackward[0], batchsizebackward.size(),
edgeacscores, edges, nodes,
edgelogbetas, logbetas, lmf, wp, amf);
}
return totalbwscore;
}
/* guoye: end */
// ------------------------------------------------------------------------
// parallel implementations of sMBR error updating step
// ------------------------------------------------------------------------
@ -948,6 +1191,36 @@ void lattice::parallelsMBRerrorsignal(parallelstate& parallelstate, const edgeal
}
}
/* guoye: start */
// ------------------------------------------------------------------------
void lattice::parallelEMBRerrorsignal(parallelstate& parallelstate, const edgealignments& thisedgealignments,
const std::vector<double>& edgeweights,
msra::math::ssematrixbase& errorsignal) const
{
if (!parallelstate->emulation)
{
// no need negative buffer for EMBR
const bool cacheerrorsignalneg = false;
parallelstate->cacheerrorsignal(errorsignal, cacheerrorsignalneg);
std::unique_ptr<latticefunctions> latticefunctions(msra::cuda::newlatticefunctions(parallelstate.getdevice()));
latticefunctions->EMBRerrorsignal(*parallelstate->alignresult.get(), *parallelstate->alignoffsetsgpu.get(), *parallelstate->edgesgpu.get(),
*parallelstate->nodesgpu.get(), *parallelstate->edgeweightsgpu.get(),
*parallelstate->errorsignalgpu.get());
if (errorsignal.rows() > 0 && errorsignal.cols() > 0)
{
parallelstate->errorsignalgpu->CopySection(errorsignal.rows(), errorsignal.cols(), &errorsignal(0, 0), errorsignal.getcolstride());
}
}
else
{
emulateEMBRerrorsignal(thisedgealignments.getalignmentsbuffer(), thisedgealignments.getalignoffsets(), edges, nodes, edgeweights, errorsignal);
}
}
/* guoye: end */
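A hedged sketch of what the EMBR error-signal step is assumed to do, by analogy with the sMBR path: each edge carries one scalar weight, and that weight is added to errorsignal(state, frame) for every frame the edge's alignment covers. The names and types below are illustrative, not the actual msra::lattices structures.

#include <cstddef>
#include <vector>

struct EdgeAlignment
{
    size_t firstFrame;                  // first frame covered by the edge
    std::vector<unsigned short> states; // aligned senone state for each frame of the edge
};

// errorsignal is indexed [state][frame] and assumed to be pre-zeroed, as in the emulation above.
static void accumulateErrorSignal(const std::vector<EdgeAlignment>& alignments,
                                  const std::vector<double>& edgeWeights,
                                  std::vector<std::vector<float>>& errorsignal)
{
    for (size_t j = 0; j < alignments.size(); j++)
    {
        const EdgeAlignment& a = alignments[j];
        for (size_t k = 0; k < a.states.size(); k++)
            errorsignal[a.states[k]][a.firstFrame + k] += (float)edgeWeights[j];
    }
}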
// ------------------------------------------------------------------------
// parallel implementations of MMI error updating step
// ------------------------------------------------------------------------

Diff not shown because of its large size. Load diff