This commit is contained in:
Родитель
830b6f94b4
Коммит
4018e1e724
1
CNTK.sln
1
CNTK.sln
|
@ -2216,6 +2216,7 @@ Global
|
|||
{292FF4EE-D9DD-4BA7-85F7-6A22148D1E01}.Debug_CpuOnly|x64.ActiveCfg = Debug|Any CPU
|
||||
{292FF4EE-D9DD-4BA7-85F7-6A22148D1E01}.Debug_UWP|x64.ActiveCfg = Debug|Any CPU
|
||||
{292FF4EE-D9DD-4BA7-85F7-6A22148D1E01}.Debug|x64.ActiveCfg = Debug|Any CPU
|
||||
{292FF4EE-D9DD-4BA7-85F7-6A22148D1E01}.Debug|x64.Build.0 = Debug|Any CPU
|
||||
{292FF4EE-D9DD-4BA7-85F7-6A22148D1E01}.Release_CpuOnly|x64.ActiveCfg = Release|Any CPU
|
||||
{292FF4EE-D9DD-4BA7-85F7-6A22148D1E01}.Release_NoOpt|x64.ActiveCfg = Release|Any CPU
|
||||
{292FF4EE-D9DD-4BA7-85F7-6A22148D1E01}.Release_UWP|x64.ActiveCfg = Release|Any CPU
|
||||
|
|
3
Makefile
3
Makefile
|
@ -784,6 +784,9 @@ HTKMLFREADER_SRC =\
|
|||
$(SOURCEDIR)/Readers/HTKMLFReader/DataWriterLocal.cpp \
|
||||
$(SOURCEDIR)/Readers/HTKMLFReader/HTKMLFReader.cpp \
|
||||
$(SOURCEDIR)/Readers/HTKMLFReader/HTKMLFWriter.cpp \
|
||||
# $(SOURCEDIR)/Common/File.cpp \
|
||||
# $(SOURCEDIR)/Common/fileutil.cpp \
|
||||
# $(SOURCEDIR)/Common/ExceptionWithCallStack.cpp \
|
||||
|
||||
HTKMLFREADER_OBJ := $(patsubst %.cpp, $(OBJDIR)/%.o, $(HTKMLFREADER_SRC))
|
||||
|
||||
|
|
Двоичный файл не отображается.
|
@ -246,7 +246,7 @@ TIMIT_TrainSimple = new TrainAction [ // new: added TrainAction; t
|
|||
//L3 = SBFF(L2,hiddenDim,hiddenDim)
|
||||
//CE = SMBFF(L3,labelDim,hiddenDim,myLabels,tag=Criteria)
|
||||
//Err = ClassificationError(myLabels,CE.BFF.FF.P,tag=Eval)
|
||||
//logPrior = LogPrior(myLabels)
|
||||
//logPrior = LogPrior(myLabels)
|
||||
//ScaledLogLikelihood=Minus(CE.BFF.FF.P,logPrior,tag=Output)
|
||||
|
||||
// new:
|
||||
|
@ -282,7 +282,7 @@ TIMIT_TrainSimple = new TrainAction [ // new: added TrainAction; t
|
|||
Err = ClassificationError(myLabels, outZ)
|
||||
|
||||
// define output node for decoding
|
||||
logPrior = LogPrior(myLabels)
|
||||
logPrior = LogPrior(myLabels)
|
||||
ScaledLogLikelihood = outZ - logPrior // before: Minus(CE.BFF.FF.P,logPrior,tag=Output)
|
||||
]
|
||||
]
|
||||
|
@ -395,6 +395,6 @@ network = new NDL [
|
|||
Err = ClassificationError(myLabels, outZ)
|
||||
|
||||
// define output node for decoding
|
||||
logPrior = LogPrior(myLabels)
|
||||
logPrior = LogPrior(myLabels)
|
||||
ScaledLogLikelihood = outZ - logPrior // before: Minus(CE.BFF.FF.P,logPrior,tag=Output)
|
||||
]
|
||||
|
|
|
@ -606,7 +606,7 @@ static void PrintBanner(int argc, wchar_t* argv[], const string& timestamp)
|
|||
fprintf(stderr, "%s %.6s, ", _BUILDBRANCH_, _BUILDSHA1_);
|
||||
#endif
|
||||
fprintf(stderr, "%s %s", __DATE__, __TIME__); // build time
|
||||
fprintf(stderr, ") at %s\n\n", timestamp.c_str());
|
||||
fprintf(stderr, ") on %s at %s\n\n", GetHostName().c_str(), timestamp.c_str());
|
||||
for (int i = 0; i < argc; i++)
|
||||
fprintf(stderr, "%*s%ls", i > 0 ? 2 : 0, "", argv[i]); // use 2 spaces for better visual separability
|
||||
fprintf(stderr, "\n");
|
||||
|
@ -617,8 +617,7 @@ int wmainOldCNTKConfig(int argc, wchar_t* argv[])
|
|||
{
|
||||
std::string timestamp = TimeDateStamp();
|
||||
PrintBanner(argc, argv, timestamp);
|
||||
|
||||
ConfigParameters config;
|
||||
ConfigParameters config;
|
||||
std::string rawConfigString = ConfigParameters::ParseCommandLine(argc, argv, config); // get the command param set they want
|
||||
|
||||
int traceLevel = config(L"traceLevel", 0);
|
||||
|
|
|
@ -488,11 +488,11 @@ ndlMacroUseCNNAuto=[
|
|||
]
|
||||
|
||||
ndlRnnNetwork=[
|
||||
#define basic i/o
|
||||
featDim=1845
|
||||
labelDim=183
|
||||
hiddenDim=2048
|
||||
features=Input(featDim, tag=feature)
|
||||
#define basic i/o
|
||||
featDim=1845
|
||||
labelDim=183
|
||||
hiddenDim=2048
|
||||
features=Input(featDim, tag=feature)
|
||||
labels=Input(labelDim, tag=label)
|
||||
|
||||
MeanVarNorm(x)=[
|
||||
|
@ -502,9 +502,9 @@ ndlRnnNetwork=[
|
|||
]
|
||||
|
||||
# define network
|
||||
featNorm = MeanVarNorm(features)
|
||||
featNorm = MeanVarNorm(features)
|
||||
W0 = Parameter(hiddenDim, featDim)
|
||||
L1 = Times(W0,featNorm)
|
||||
L1 = Times(W0,featNorm)
|
||||
|
||||
W = Parameter(hiddenDim, hiddenDim)
|
||||
|
||||
|
@ -515,8 +515,8 @@ ndlRnnNetwork=[
|
|||
Output = Times(W2, Dout)
|
||||
criterion = CrossEntropyWithSoftmax(labels, Output, tag=Criteria)
|
||||
|
||||
#CE = SMBFF(Dout,labelDim,hiddenDim,labels,tag=Criteria)
|
||||
#Err = ErrorPrediction(labels,CE.BFF.FF.P,tag=Eval)
|
||||
#CE = SMBFF(Dout,labelDim,hiddenDim,labels,tag=Criteria)
|
||||
#Err = ErrorPrediction(labels,CE.BFF.FF.P,tag=Eval)
|
||||
|
||||
LogPrior(labels)
|
||||
{
|
||||
|
@ -525,8 +525,8 @@ ndlRnnNetwork=[
|
|||
}
|
||||
|
||||
# define output (scaled loglikelihood)
|
||||
logPrior = LogPrior(labels)
|
||||
#ScaledLogLikelihood=Minus(CE.BFF.FF.P,logPrior,tag=Output)
|
||||
logPrior = LogPrior(labels)
|
||||
#ScaledLogLikelihood=Minus(CE.BFF.FF.P,logPrior,tag=Output)
|
||||
# rootNodes defined here temporarily so we pass
|
||||
OutputNodes=(criterion)
|
||||
EvalNodes=(criterion)
|
||||
|
|
|
@ -35,6 +35,7 @@
|
|||
<Keyword>Win32Proj</Keyword>
|
||||
<RootNamespace>CNTKv2LibraryDll</RootNamespace>
|
||||
<ProjectName>CNTKv2LibraryDll</ProjectName>
|
||||
<WindowsTargetPlatformVersion>8.1</WindowsTargetPlatformVersion>
|
||||
</PropertyGroup>
|
||||
<Import Project="$(SolutionDir)\CNTK.Cpp.props" />
|
||||
<Import Project="$(VCTargetsPath)\Microsoft.Cpp.Default.props" />
|
||||
|
|
|
@ -274,11 +274,17 @@ bool DataReader::GetMinibatch(StreamMinibatchInputs& matrices)
|
|||
// uids - labels stored in a size_t vector instead of an ElemType matrix
|
||||
// boundary - phone boundaries
|
||||
// returns - true if there are more minibatches, false if no more minibatches remain
|
||||
// Fetches the next sequence-training minibatch from every reader listed in
// m_ioNames. Each reader is always invoked (there is no short-circuit); the
// result is true only if every reader returned true.
// wids and nws are handed to the underlying readers unchanged
// (presumably word ids and per-utterance word counts -- TODO confirm against the readers).
bool DataReader::GetMinibatch4SE(std::vector<shared_ptr<const msra::dbn::latticepair>>& latticeinput, vector<size_t>& uids, vector<size_t>& wids, vector<short>& nws, vector<size_t>& boundaries, vector<size_t>& extrauttmap)
{
    bool bRet = true;
    for (size_t i = 0; i < m_ioNames.size(); i++)
        bRet &= m_dataReaders[m_ioNames[i]]->GetMinibatch4SE(latticeinput, uids, wids, nws, boundaries, extrauttmap);
    return bRet;
}
|
||||
|
||||
|
@ -288,8 +294,14 @@ bool DataReader::GetMinibatch4SE(std::vector<shared_ptr<const msra::dbn::lattice
|
|||
// Forwards the HMM request to every reader listed in m_ioNames.
// A failing reader does not stop the loop: all readers are called, and the
// result is true only when each of them returned true.
bool DataReader::GetHmmData(msra::asr::simplesenonehmm* hmm)
{
    bool allSucceeded = true;
    for (size_t idx = 0; idx < m_ioNames.size(); ++idx)
    {
        const bool ok = m_dataReaders[m_ioNames[idx]]->GetHmmData(hmm);
        allSucceeded = allSucceeded && ok;
    }
    return allSucceeded;
}
|
||||
|
||||
|
|
|
@ -264,7 +264,10 @@ public:
|
|||
}
|
||||
|
||||
virtual bool GetMinibatch(StreamMinibatchInputs& matrices) = 0;
|
||||
// Sequence-training minibatch hook. The base implementation only expands
// NOT_IMPLEMENTED; readers that support this call override it.
// All parameters are unused here, hence the commented-out names.
virtual bool GetMinibatch4SE(std::vector<shared_ptr<const msra::dbn::latticepair>>& /*latticeinput*/, vector<size_t>& /*uids*/, vector<size_t>& /*wids*/, vector<short>& /*nws*/, vector<size_t>& /*boundaries*/, vector<size_t>& /*extrauttmap*/)
{
    NOT_IMPLEMENTED;
}
|
||||
|
@ -444,7 +447,11 @@ public:
|
|||
// [out] each matrix resized if necessary containing data.
|
||||
// returns - true if there are more minibatches, false if no more minibatches remain
|
||||
virtual bool GetMinibatch(StreamMinibatchInputs& matrices);
|
||||
virtual bool GetMinibatch4SE(std::vector<shared_ptr<const msra::dbn::latticepair>>& latticeinput, vector<size_t>& uids, vector<size_t>& boundaries, vector<size_t>& extrauttmap);
|
||||
/* guoye: start */
|
||||
// virtual bool GetMinibatch4SE(std::vector<shared_ptr<const msra::dbn::latticepair>>& latticeinput, vector<size_t>& uids, vector<size_t>& boundaries, vector<size_t>& extrauttmap);
|
||||
virtual bool GetMinibatch4SE(std::vector<shared_ptr<const msra::dbn::latticepair>>& latticeinput, vector<size_t>& uids, vector<size_t>& wids, vector<short>& nws, vector<size_t>& boundaries, vector<size_t>& extrauttmap);
|
||||
|
||||
/* guoye: end */
|
||||
virtual bool GetHmmData(msra::asr::simplesenonehmm* hmm);
|
||||
|
||||
size_t GetNumParallelSequencesForFixingBPTTMode();
|
||||
|
|
|
@ -34,6 +34,22 @@
|
|||
#include <fcntl.h>
|
||||
|
||||
#define FCLOSE_SUCCESS 0
|
||||
/* guoye: start */
|
||||
/*
|
||||
#include "basetypes.h" // for attempt()
|
||||
#include "ProgressTracing.h"
|
||||
#include <unistd.h>
|
||||
#include <glob.h>
|
||||
#include <dirent.h>
|
||||
#include <sys/sendfile.h>
|
||||
#include <stdio.h>
|
||||
#include <ctype.h>
|
||||
#include <limits.h>
|
||||
#include <memory>
|
||||
#include <cwctype>
|
||||
*/
|
||||
// using namespace Microsoft::MSR::CNTK;
|
||||
/* guoye: end */
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// fopenOrDie(): like fopen() but terminate with err msg in case of error.
|
||||
|
@ -77,6 +93,7 @@ void freadOrDie(_T& data, size_t num, FILE* f) // template for std::vector<>
|
|||
freadOrDie(&data[0], sizeof(data[0]), data.size(), f);
|
||||
}
|
||||
|
||||
|
||||
#ifdef _WIN32
|
||||
template <class _T>
|
||||
void freadOrDie(_T& data, int num, const HANDLE f) // template for std::vector<>
|
||||
|
@ -229,11 +246,129 @@ void fputstring(FILE* f, const wchar_t*);
|
|||
void fputstring(FILE* f, const std::wstring&);
|
||||
|
||||
template <class CHAR>
|
||||
CHAR* fgetline(FILE* f, CHAR* buf, int size);
|
||||
CHAR* fgetline(FILE* f, CHAR* buf, int size)
|
||||
{
|
||||
// TODO: we should redefine this to write UTF-16 (which matters on GCC which defines wchar_t as 32 bit)
|
||||
/* guoye: start */
|
||||
// fprintf(stderr, "\n fileutil.cpp: fgetline: debug 0\n");
|
||||
/* guoye: end */
|
||||
CHAR* p = fgets(buf, size, f);
|
||||
/* guoye: start */
|
||||
// fprintf(stderr, "\n fileutil.cpp: fgetline: debug 1\n");
|
||||
/* guoye: end */
|
||||
if (p == NULL) // EOF reached: next time feof() = true
|
||||
{
|
||||
if (ferror(f))
|
||||
RuntimeError("error reading line: %s", strerror(errno));
|
||||
buf[0] = 0;
|
||||
return buf;
|
||||
}
|
||||
size_t n = strnlen(p, size);
|
||||
|
||||
// check for buffer overflow
|
||||
|
||||
if (n >= (size_t)size - 1)
|
||||
{
|
||||
/* guoye: start */
|
||||
// basic_string<CHAR> example(p, n < 100 ? n : 100);
|
||||
std::basic_string<CHAR> example(p, n < 100 ? n : 100);
|
||||
/* guoye: end */
|
||||
uint64_t filepos = fgetpos(f); // (for error message only)
|
||||
RuntimeError("input line too long at file offset %d (max. %d characters allowed) [%s ...]", (int)filepos, (int)size - 1, msra::strfun::utf8(example).c_str());
|
||||
}
|
||||
|
||||
// remove newline at end
|
||||
|
||||
if (n > 0 && p[n - 1] == '\n') // UNIX and Windows style
|
||||
{
|
||||
n--;
|
||||
p[n] = 0;
|
||||
if (n > 0 && p[n - 1] == '\r') // Windows style
|
||||
{
|
||||
n--;
|
||||
p[n] = 0;
|
||||
}
|
||||
}
|
||||
else if (n > 0 && p[n - 1] == '\r') // Mac style
|
||||
{
|
||||
n--;
|
||||
p[n] = 0;
|
||||
}
|
||||
|
||||
return buf;
|
||||
}
|
||||
|
||||
// fgetlinew() was added to fix a bug: without it, this code does not support wchar_t buffers
|
||||
template <class CHAR>
|
||||
CHAR* fgetlinew(FILE* f, CHAR* buf, int size)
|
||||
{
|
||||
// TODO: we should redefine this to write UTF-16 (which matters on GCC which defines wchar_t as 32 bit)
|
||||
/* guoye: start */
|
||||
// fprintf(stderr, "\n fileutil.cpp: fgetline: debug 0\n");
|
||||
/* guoye: end */
|
||||
CHAR* p = fgets(buf, size, f);
|
||||
/* guoye: start */
|
||||
// fprintf(stderr, "\n fileutil.cpp: fgetline: debug 1\n");
|
||||
/* guoye: end */
|
||||
if (p == NULL) // EOF reached: next time feof() = true
|
||||
{
|
||||
if (ferror(f))
|
||||
RuntimeError("error reading line: %s", strerror(errno));
|
||||
buf[0] = L'\0';
|
||||
return buf;
|
||||
}
|
||||
size_t n = wcsnlen(p, size);
|
||||
|
||||
// check for buffer overflow
|
||||
|
||||
if (n >= (size_t)size - 1)
|
||||
{
|
||||
/* guoye: start */
|
||||
// basic_string<CHAR> example(p, n < 100 ? n : 100);
|
||||
std::basic_string<CHAR> example(p, n < 100 ? n : 100);
|
||||
/* guoye: end */
|
||||
uint64_t filepos = fgetpos(f); // (for error message only)
|
||||
RuntimeError("input line too long at file offset %d (max. %d characters allowed) [%s ...]", (int)filepos, (int)size - 1, msra::strfun::utf8(example).c_str());
|
||||
}
|
||||
|
||||
// remove newline at end
|
||||
|
||||
if (n > 0 && p[n - 1] == L'\n') // UNIX and Windows style
|
||||
{
|
||||
n--;
|
||||
p[n] = L'\0';
|
||||
if (n > 0 && p[n - 1] == L'\r') // Windows style
|
||||
{
|
||||
n--;
|
||||
p[n] = L'\0';
|
||||
}
|
||||
}
|
||||
else if (n > 0 && p[n - 1] == L'\r') // Mac style
|
||||
{
|
||||
n--;
|
||||
p[n] = L'\0';
|
||||
}
|
||||
|
||||
return buf;
|
||||
}
|
||||
|
||||
// Convenience overload for fixed-size arrays: the capacity n is deduced from
// the array type and the call is forwarded to fgetlinew(f, buf, size).
template <class CHAR, size_t n>
CHAR* fgetlinew(FILE* f, CHAR(&buf)[n])
{
    CHAR* const first = buf; // the array decays to a pointer to its first element
    return fgetlinew(f, first, n);
}
|
||||
|
||||
/* guoye: end */
|
||||
// Convenience overload for fixed-size arrays: the capacity n is deduced from
// the array type and the call is forwarded to fgetline(f, buf, size).
template <class CHAR, size_t n>
CHAR* fgetline(FILE* f, CHAR(&buf)[n])
{
    CHAR* const first = buf; // the array decays to a pointer to its first element
    return fgetline(f, first, n);
}
|
||||
std::string fgetline(FILE* f);
|
||||
std::wstring fgetlinew(FILE* f);
|
||||
|
@ -902,9 +1037,40 @@ static inline String& trim(String& s)
|
|||
{
|
||||
return ltrim(rtrim(s));
|
||||
}
|
||||
/* guoye: start */
|
||||
|
||||
// SplitString(): split str into tokens.
//  - sep is a SET of separator characters: str is cut at every character that
//    occurs in sep (find_first_of), not at occurrences of sep as a whole.
//  - empty tokens (leading, trailing or between adjacent separators) are dropped.
// Returns the tokens in their original order; an empty vector if there are none.
// Moved here from the .cpp so that the template's declaration and definition
// live in the same file.
template<class String>
std::vector<String> SplitString(const String& str, const String& sep)
{
    std::vector<String> vstr;
    size_t ifoundlast = 0; // start of the token currently being scanned
    size_t ifound = str.find_first_of(sep, ifoundlast);
    while (ifound != String::npos)
    {
        if (ifound > ifoundlast) // non-empty token before this separator
            vstr.push_back(str.substr(ifoundlast, ifound - ifoundlast));

        ifoundlast = ifound + 1;
        ifound = str.find_first_of(sep, ifoundlast);
    }

    // remainder after the last separator (or the whole string if none was found)
    if (ifoundlast < str.length())
        vstr.push_back(str.substr(ifoundlast));

    return vstr;
}
|
||||
/* guoye: end */
|
||||
// Overload that takes the separator set as a C-style string; it is converted
// to String and forwarded to SplitString(str, String).
template<class String, class Char>
std::vector<String> SplitString(const String& str, const Char* sep)
{
    const String separators(sep);
    return SplitString(str, separators);
}
|
||||
|
||||
|
@ -912,4 +1078,8 @@ std::wstring s2ws(const std::string& str);
|
|||
|
||||
std::string ws2s(const std::wstring& wstr);
|
||||
|
||||
|
||||
/* guoye: start */
|
||||
// #include "../fileutil.cpp"
|
||||
/* guoye: end */
|
||||
#endif // _FILEUTIL_
|
||||
|
|
|
@ -0,0 +1,938 @@
|
|||
//
|
||||
// Copyright (c) Microsoft. All rights reserved.
|
||||
// Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
|
||||
//
|
||||
// fileutil.h - file I/O with error checking
|
||||
//
|
||||
#pragma once
|
||||
#ifndef _FILEUTIL_
|
||||
#define _FILEUTIL_
|
||||
|
||||
#include "Basics.h"
|
||||
#ifdef __WINDOWS__
|
||||
#ifndef NOMINMAX
|
||||
#define NOMINMAX
|
||||
#endif // NOMINMAX
|
||||
#include "Windows.h" // for mmreg.h and FILETIME
|
||||
#include <mmreg.h>
|
||||
#endif
|
||||
#ifdef __unix__
|
||||
#include <sys/types.h>
|
||||
#include <sys/stat.h>
|
||||
#endif
|
||||
#include <algorithm> // for std::find
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <map>
|
||||
#include <functional>
|
||||
#include <cctype>
|
||||
#include <errno.h>
|
||||
#include <stdint.h>
|
||||
#include <assert.h>
|
||||
#include <string.h> // for strerror()
|
||||
#include <stdexcept> // for exception
|
||||
#include <fcntl.h>
|
||||
|
||||
#define FCLOSE_SUCCESS 0
|
||||
/* guoye: start */
|
||||
/*
|
||||
#include "basetypes.h" // for attempt()
|
||||
#include "ProgressTracing.h"
|
||||
#include <unistd.h>
|
||||
#include <glob.h>
|
||||
#include <dirent.h>
|
||||
#include <sys/sendfile.h>
|
||||
#include <stdio.h>
|
||||
#include <ctype.h>
|
||||
#include <limits.h>
|
||||
#include <memory>
|
||||
#include <cwctype>
|
||||
*/
|
||||
// using namespace Microsoft::MSR::CNTK;
|
||||
/* guoye: end */
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// fopenOrDie(): like fopen() but terminate with err msg in case of error.
|
||||
// A pathname of "-" returns stdout or stdin, depending on mode, and it will
|
||||
// change the binary mode if 'b' or 't' are given. If you use this, make sure
|
||||
// not to fclose() such a handle.
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
FILE* fopenOrDie(const std::string& pathname, const char* mode);
|
||||
FILE* fopenOrDie(const std::wstring& pathname, const wchar_t* mode);
|
||||
|
||||
#ifndef __unix__
|
||||
// ----------------------------------------------------------------------------
|
||||
// fsetmode(): set mode to binary or text
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
void fsetmode(FILE* f, char type);
|
||||
#endif
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// freadOrDie(): like fread() but terminate with err msg in case of error
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
void freadOrDie(void* ptr, size_t size, size_t count, FILE* f);
|
||||
#ifdef _WIN32
|
||||
void freadOrDie(void* ptr, size_t size, size_t count, const HANDLE f);
|
||||
#endif
|
||||
|
||||
// Container overload (int count): resizes data to num elements, then fills it
// from f via freadOrDie(ptr, size, count, f). Nothing is read when the
// container ends up empty.
template <class _T>
void freadOrDie(_T& data, int num, FILE* f) // template for std::vector<>
{
    data.resize(num);
    const size_t count = data.size();
    if (count == 0)
        return; // also avoids taking &data[0] of an empty container
    freadOrDie(&data[0], sizeof(data[0]), count, f);
}
|
||||
// Container overload (size_t count): resizes data to num elements, then fills
// it from f via freadOrDie(ptr, size, count, f). Nothing is read when the
// container ends up empty.
template <class _T>
void freadOrDie(_T& data, size_t num, FILE* f) // template for std::vector<>
{
    data.resize(num);
    const size_t count = data.size();
    if (count == 0)
        return; // also avoids taking &data[0] of an empty container
    freadOrDie(&data[0], sizeof(data[0]), count, f);
}
|
||||
|
||||
#ifdef _WIN32
|
||||
template <class _T>
|
||||
void freadOrDie(_T& data, int num, const HANDLE f) // template for std::vector<>
|
||||
{
|
||||
data.resize(num);
|
||||
if (data.size() > 0)
|
||||
freadOrDie(&data[0], sizeof(data[0]), data.size(), f);
|
||||
}
|
||||
template <class _T>
|
||||
void freadOrDie(_T& data, size_t num, const HANDLE f) // template for std::vector<>
|
||||
{
|
||||
data.resize(num);
|
||||
if (data.size() > 0)
|
||||
freadOrDie(&data[0], sizeof(data[0]), data.size(), f);
|
||||
}
|
||||
#endif
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// fwriteOrDie(): like fwrite() but terminate with err msg in case of error
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
void fwriteOrDie(const void* ptr, size_t size, size_t count, FILE* f);
|
||||
#ifdef _WIN32
|
||||
void fwriteOrDie(const void* ptr, size_t size, size_t count, const HANDLE f);
|
||||
#endif
|
||||
|
||||
// Container overload: writes all elements of data to f via
// fwriteOrDie(ptr, size, count, f). An empty container writes nothing.
template <class _T>
void fwriteOrDie(const _T& data, FILE* f) // template for std::vector<>
{
    const size_t count = data.size();
    if (count == 0)
        return; // also avoids taking &data[0] of an empty container
    fwriteOrDie(&data[0], sizeof(data[0]), count, f);
}
|
||||
#ifdef _WIN32
|
||||
template <class _T>
|
||||
void fwriteOrDie(const _T& data, const HANDLE f) // template for std::vector<>
|
||||
{
|
||||
if (data.size() > 0)
|
||||
fwriteOrDie(&data[0], sizeof(data[0]), data.size(), f);
|
||||
}
|
||||
#endif
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// fprintfOrDie(): like fprintf() but terminate with err msg in case of error
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
void fprintfOrDie(FILE* f, const char* format, ...);
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// fcloseOrDie(): like fclose() but terminate with err msg in case of error
|
||||
// not yet implemented, but we should
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
#define fcloseOrDie fclose
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// fflushOrDie(): like fflush() but terminate with err msg in case of error
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
void fflushOrDie(FILE* f);
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// filesize(): determine size of the file in bytes
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
size_t filesize(const wchar_t* pathname);
|
||||
size_t filesize(FILE* f);
|
||||
int64_t filesize64(const wchar_t* pathname);
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// fseekOrDie(),ftellOrDie(), fget/setpos(): seek functions with error handling
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
// 32-bit offsets only
|
||||
long fseekOrDie(FILE* f, long offset, int mode = SEEK_SET);
|
||||
#define ftellOrDie ftell
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// fget/setpos(): seek functions with error handling
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
uint64_t fgetpos(FILE* f);
|
||||
void fsetpos(FILE* f, uint64_t pos);
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// unlinkOrDie(): unlink() with error handling
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
void unlinkOrDie(const std::string& pathname);
|
||||
void unlinkOrDie(const std::wstring& pathname);
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// renameOrDie(): rename() with error handling
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
void renameOrDie(const std::string& from, const std::string& to);
|
||||
void renameOrDie(const std::wstring& from, const std::wstring& to);
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// copyOrDie(): copy file with error handling.
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
void copyOrDie(const std::string& from, const std::string& to);
|
||||
void copyOrDie(const std::wstring& from, const std::wstring& to);
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// fexists(): test if a file exists
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
bool fexists(const char* pathname);
|
||||
bool fexists(const wchar_t* pathname);
|
||||
// std::string convenience overload: forwards to fexists(const char*).
inline bool fexists(const std::string& pathname)
{
    const char* const cpath = pathname.c_str();
    return fexists(cpath);
}
|
||||
// std::wstring convenience overload: forwards to fexists(const wchar_t*).
inline bool fexists(const std::wstring& pathname)
{
    const wchar_t* const wpath = pathname.c_str();
    return fexists(wpath);
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// funicode(): test if a file uses unicode
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
bool funicode(FILE* f);
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// fskipspace(): skip space characters
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
bool fskipspace(FILE* F);
|
||||
bool fskipwspace(FILE* F);
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// fgetline(): like fgets() but terminate with err msg in case of error;
|
||||
// removes the newline character at the end (like gets()), returned buffer is
|
||||
// always 0-terminated; has second version that returns an STL std::string instead
|
||||
// fgetstring(): read a 0-terminated std::string (terminate if error)
|
||||
// fgetword(): read a space-terminated token (terminate if error)
|
||||
// fskipNewLine(): skip all white space until end of line incl. the newline
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// fputstring(): write a 0-terminated std::string (terminate if error)
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
void fputstring(FILE* f, const char*);
|
||||
void fputstring(const HANDLE f, const char* str);
|
||||
void fputstring(FILE* f, const std::string&);
|
||||
void fputstring(FILE* f, const wchar_t*);
|
||||
void fputstring(FILE* f, const std::wstring&);
|
||||
|
||||
template <class CHAR>
|
||||
CHAR* fgetline(FILE* f, CHAR* buf, int size);
|
||||
// Convenience overload for fixed-size arrays: the capacity n is deduced from
// the array type and the call is forwarded to fgetline(f, buf, size).
template <class CHAR, size_t n>
CHAR* fgetline(FILE* f, CHAR(&buf)[n])
{
    CHAR* const first = buf; // the array decays to a pointer to its first element
    return fgetline(f, first, n);
}
|
||||
std::string fgetline(FILE* f);
|
||||
std::wstring fgetlinew(FILE* f);
|
||||
void fgetline(FILE* f, std::string& s, std::vector<char>& buf);
|
||||
void fgetline(FILE* f, std::wstring& s, std::vector<char>& buf);
|
||||
void fgetline(FILE* f, std::vector<char>& buf);
|
||||
void fgetline(FILE* f, std::vector<wchar_t>& buf);
|
||||
|
||||
const char* fgetstring(FILE* f, char* buf, int size);
|
||||
template <size_t n>
|
||||
const char* fgetstring(FILE* f, char(&buf)[n])
|
||||
{
|
||||
return fgetstring(f, buf, n);
|
||||
}
|
||||
const char* fgetstring(const HANDLE f, char* buf, int size);
|
||||
template <size_t n>
|
||||
const char* fgetstring(const HANDLE f, char(&buf)[n])
|
||||
{
|
||||
return fgetstring(f, buf, n);
|
||||
}
|
||||
|
||||
const wchar_t* fgetstring(FILE* f, wchar_t* buf, int size);
|
||||
std::wstring fgetwstring(FILE* f);
|
||||
std::string fgetstring(FILE* f);
|
||||
|
||||
const char* fgettoken(FILE* f, char* buf, int size);
|
||||
template <size_t n>
|
||||
const char* fgettoken(FILE* f, char(&buf)[n])
|
||||
{
|
||||
return fgettoken(f, buf, n);
|
||||
}
|
||||
std::string fgettoken(FILE* f);
|
||||
const wchar_t* fgettoken(FILE* f, wchar_t* buf, int size);
|
||||
std::wstring fgetwtoken(FILE* f);
|
||||
|
||||
int fskipNewline(FILE* f, bool skip = true);
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// fputstring(): write a 0-terminated std::string (terminate if error)
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
void fputstring(FILE* f, const char*);
|
||||
#ifdef _WIN32
|
||||
void fputstring(const HANDLE f, const char* str);
|
||||
#endif
|
||||
void fputstring(FILE* f, const std::string&);
|
||||
void fputstring(FILE* f, const wchar_t*);
|
||||
void fputstring(FILE* f, const std::wstring&);
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// fgetTag(): read a 4-byte tag & return as a std::string
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
std::string fgetTag(FILE* f);
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// fcheckTag(): read a 4-byte tag & verify it; terminate if wrong tag
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
void fcheckTag(FILE* f, const char* expectedTag);
|
||||
#ifdef _WIN32
|
||||
void fcheckTag(const HANDLE f, const char* expectedTag);
|
||||
#endif
|
||||
void fcheckTag_ascii(FILE* f, const std::string& expectedTag);
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// fcompareTag(): compare two tags; terminate if wrong tag
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
void fcompareTag(const std::string& readTag, const std::string& expectedTag);
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// fputTag(): write a 4-byte tag
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
void fputTag(FILE* f, const char* tag);
|
||||
#ifdef _WIN32
|
||||
void fputTag(const HANDLE f, const char* tag);
|
||||
#endif
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// fskipstring(): skip a 0-terminated std::string, such as a pad std::string
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
void fskipstring(FILE* f);
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// fpad(): write a 0-terminated std::string to pad file to a n-byte boundary
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
void fpad(FILE* f, int n);
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// fgetbyte(): read a byte value
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
char fgetbyte(FILE* f);
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// fgetshort(): read a short value
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
short fgetshort(FILE* f);
|
||||
short fgetshort_bigendian(FILE* f);
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// fgetint24(): read a 3-byte (24-bit) int value
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
int fgetint24(FILE* f);
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// fgetint(): read an int value
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
int fgetint(FILE* f);
|
||||
#ifdef _WIN32
|
||||
int fgetint(const HANDLE f);
|
||||
#endif
|
||||
int fgetint_bigendian(FILE* f);
|
||||
int fgetint_ascii(FILE* f);
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// fgetlong(): read a long value
|
||||
// ----------------------------------------------------------------------------
|
||||
long fgetlong(FILE* f);
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// fgetfloat(): read a float value
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
float fgetfloat(FILE* f);
|
||||
float fgetfloat_bigendian(FILE* f);
|
||||
float fgetfloat_ascii(FILE* f);
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// fgetdouble(): read a double value
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
double fgetdouble(FILE* f);
|
||||
#ifdef _WIN32
|
||||
// ----------------------------------------------------------------------------
|
||||
// fgetwav(): read an entire .wav file
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
void fgetwav(FILE* f, std::vector<short>& wav, int& sampleRate);
|
||||
void fgetwav(const std::wstring& fn, std::vector<short>& wav, int& sampleRate);
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// fputwav(): save data into a .wav file
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
void fputwav(FILE* f, const std::vector<short>& wav, int sampleRate, int nChannels = 1);
|
||||
void fputwav(const std::wstring& fn, const std::vector<short>& wav, int sampleRate, int nChannels = 1);
|
||||
#endif
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// fputbyte(): write a byte value
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
void fputbyte(FILE* f, char val);
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// fputshort(): write a short value
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
void fputshort(FILE* f, short val);
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// fputint24(): write a 3-byte (24-bit) int value
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
void fputint24(FILE* f, int v);
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// fputint(): write an int value
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
void fputint(FILE* f, int val);
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// fputlong(): write a long value
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
void fputlong(FILE* f, long val);
|
||||
|
||||
#ifdef _WIN32
|
||||
void fputint(const HANDLE f, int v);
|
||||
#endif
|
||||
// ----------------------------------------------------------------------------
|
||||
// fputfloat(): write a float value
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
void fputfloat(FILE* f, float val);
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// fputdouble(): write a double value
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
void fputdouble(FILE* f, double val);
|
||||
|
||||
// fput(): binary-file writer for a single value of arbitrary type T.
// The argument is taken by copy and its in-memory bytes are handed to
// fwriteOrDie() as one element of sizeof(T) bytes.
template <typename T>
void fput(FILE* f, T v)
{
    fwriteOrDie(&v, sizeof(T), 1, f);
}
|
||||
|
||||
// template versions of put/get functions for binary files
|
||||
template <typename T>
|
||||
void fget(FILE* f, T& v)
|
||||
{
|
||||
freadOrDie((void*) &v, sizeof(v), 1, f);
|
||||
}
|
||||
|
||||
// GetFormatString - get the format string for a particular type.
//
// This is the generic fallback; it only exists to flag a missing specialization.
// Reaching the assert below means the type has no read and/or write routine.
//
// For a user-defined class, provide global stream functions that handle file in/out, e.g.
//     File& operator>>(File& stream, MyClass& test);
//     File& operator<<(File& stream, MyClass& test);
// and, if they need access to private members, declare them as friends inside the class:
//     friend File& operator>>(File& stream, MyClass& test);
//     friend File& operator<<(File& stream, MyClass& test);
//
// wchar_t* and char* go through other methods because they require buffers to be passed:
// either use std::string / std::wstring, or the WriteString() and ReadString() methods.
template <typename T>
const wchar_t* GetFormatString(T /*t*/)
{
    assert(false); // need a specialization
    return nullptr;
}
|
||||
|
||||
// GetFormatString - specializations to get the format string for a particular type
|
||||
template <>
|
||||
const wchar_t* GetFormatString(char);
|
||||
template <>
|
||||
const wchar_t* GetFormatString(wchar_t);
|
||||
template <>
|
||||
const wchar_t* GetFormatString(short);
|
||||
template <>
|
||||
const wchar_t* GetFormatString(int);
|
||||
template <>
|
||||
const wchar_t* GetFormatString(long);
|
||||
template <>
|
||||
const wchar_t* GetFormatString(unsigned short);
|
||||
template <>
|
||||
const wchar_t* GetFormatString(unsigned int);
|
||||
template <>
|
||||
const wchar_t* GetFormatString(unsigned long);
|
||||
template <>
|
||||
const wchar_t* GetFormatString(float);
|
||||
template <>
|
||||
const wchar_t* GetFormatString(double);
|
||||
template <>
|
||||
const wchar_t* GetFormatString(unsigned long long);
|
||||
template <>
|
||||
const wchar_t* GetFormatString(long long);
|
||||
template <>
|
||||
const wchar_t* GetFormatString(const char*);
|
||||
template <>
|
||||
const wchar_t* GetFormatString(const wchar_t*);
|
||||
|
||||
// GetScanFormatString - get the scan format string for a particular type.
// Generic fallback only: asserts, because a type without a specialization
// has no scan format; returns a null pointer when asserts are compiled out.
template <typename T>
const wchar_t* GetScanFormatString(T)
{
    assert(false); // need a specialization
    return nullptr;
}
|
||||
|
||||
// GetScanFormatString - specializations to get the format string for a particular type
|
||||
template <>
|
||||
const wchar_t* GetScanFormatString(char);
|
||||
template <>
|
||||
const wchar_t* GetScanFormatString(wchar_t);
|
||||
template <>
|
||||
const wchar_t* GetScanFormatString(short);
|
||||
template <>
|
||||
const wchar_t* GetScanFormatString(int);
|
||||
template <>
|
||||
const wchar_t* GetScanFormatString(long);
|
||||
template <>
|
||||
const wchar_t* GetScanFormatString(unsigned short);
|
||||
template <>
|
||||
const wchar_t* GetScanFormatString(unsigned int);
|
||||
template <>
|
||||
const wchar_t* GetScanFormatString(unsigned long);
|
||||
template <>
|
||||
const wchar_t* GetScanFormatString(float);
|
||||
template <>
|
||||
const wchar_t* GetScanFormatString(double);
|
||||
template <>
|
||||
const wchar_t* GetScanFormatString(unsigned long long);
|
||||
template <>
|
||||
const wchar_t* GetScanFormatString(long long);
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// fgetText(): get a value from a text file
|
||||
// ----------------------------------------------------------------------------
|
||||
template <typename T>
|
||||
void fgetText(FILE* f, T& v)
|
||||
{
|
||||
int rc = ftrygetText(f, v);
|
||||
if (rc == 0)
|
||||
Microsoft::MSR::CNTK::RuntimeError("error reading value from file (invalid format)");
|
||||
else if (rc == EOF)
|
||||
Microsoft::MSR::CNTK::RuntimeError("error reading from file: %s", strerror(errno));
|
||||
assert(rc == 1);
|
||||
}
|
||||
|
||||
// version to try and get a std::string, and not throw exceptions if contents don't match
|
||||
template <typename T>
|
||||
int ftrygetText(FILE* f, T& v)
|
||||
{
|
||||
const wchar_t* formatString = GetScanFormatString<T>(v);
|
||||
int rc = fwscanf(f, formatString, &v);
|
||||
assert(rc == 1 || rc == 0);
|
||||
return rc;
|
||||
}
|
||||
|
||||
template <>
|
||||
int ftrygetText<bool>(FILE* f, bool& v);
|
||||
// ----------------------------------------------------------------------------
|
||||
// fgetText() specializations for fwscanf_s differences: get a value from a text file
|
||||
// ----------------------------------------------------------------------------
|
||||
void fgetText(FILE* f, char& v);
|
||||
void fgetText(FILE* f, wchar_t& v);
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// fputText(): write a value out as text
|
||||
// ----------------------------------------------------------------------------
|
||||
template <typename T>
|
||||
void fputText(FILE* f, T v)
|
||||
{
|
||||
const wchar_t* formatString = GetFormatString(v);
|
||||
int rc = fwprintf(f, formatString, v);
|
||||
if (rc == 0)
|
||||
Microsoft::MSR::CNTK::RuntimeError("error writing value to file, no values written");
|
||||
else if (rc < 0)
|
||||
Microsoft::MSR::CNTK::RuntimeError("error writing to file: %s", strerror(errno));
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// fputText(): write a bool out as character
|
||||
// ----------------------------------------------------------------------------
|
||||
template <>
|
||||
void fputText<bool>(FILE* f, bool v);
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// fputfile(): write a binary block or a std::string as a file
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
void fputfile(const std::wstring& pathname, const std::vector<char>& buffer);
|
||||
void fputfile(const std::wstring& pathname, const std::wstring&);
|
||||
void fputfile(const std::wstring& pathname, const std::string&);
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// fgetfile(): load a file as a binary block
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
void fgetfile(const std::wstring& pathname, std::vector<char>& buffer);
|
||||
void fgetfile(FILE* f, std::vector<char>& buffer);
|
||||
namespace msra { namespace files {
|
||||
|
||||
void fgetfilelines(const std::wstring& pathname, std::vector<char>& readbuffer, std::vector<std::string>& lines, int numberOfTries = 1);
|
||||
|
||||
static inline std::vector<std::string> fgetfilelines(const std::wstring& pathname)
|
||||
{
|
||||
std::vector<char> buffer;
|
||||
std::vector<std::string> lines;
|
||||
fgetfilelines(pathname, buffer, lines);
|
||||
return lines;
|
||||
}
|
||||
std::vector<char*> fgetfilelines(const std::wstring& pathname, std::vector<char>& readbuffer, int numberOfTries = 1);
|
||||
|
||||
}}
|
||||
|
||||
#ifdef _WIN32
|
||||
// ----------------------------------------------------------------------------
|
||||
// getfiletime(), setfiletime(): access modification time
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
bool getfiletime(const std::wstring& path, FILETIME& time);
|
||||
void setfiletime(const std::wstring& path, const FILETIME& time);
|
||||
|
||||
#endif
|
||||
// ----------------------------------------------------------------------------
|
||||
// expand_wildcards() -- expand a path with wildcards (also intermediate ones)
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
void expand_wildcards(const std::wstring& path, std::vector<std::wstring>& paths);
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// make_intermediate_dirs() -- make all intermediate dirs on a path
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
namespace msra { namespace files {
|
||||
|
||||
void make_intermediate_dirs(const std::wstring& filepath);
|
||||
|
||||
std::vector<std::wstring> get_all_files_from_directory(const std::wstring& directory);
|
||||
|
||||
}}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// fuptodate() -- test whether an output file is at least as new as an input file
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
namespace msra { namespace files {
|
||||
|
||||
bool fuptodate(const std::wstring& target, const std::wstring& input, bool inputrequired = true);
|
||||
};
|
||||
};
|
||||
|
||||
#ifdef _WIN32
|
||||
// ----------------------------------------------------------------------------
|
||||
// simple support for WAV file I/O
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
// WAV file header record used by the "simple support for WAV file I/O" helpers.
// The field names follow the RIFF/WAVE header fields; presumably the struct is
// read/written as raw bytes, so field order and sizes should not be changed
// -- TODO confirm against the implementation.
typedef struct wavehder
|
||||
{
|
||||
char riffchar[4];
|
||||
unsigned int RiffLength;
|
||||
char wavechar[8];
|
||||
unsigned int FmtLength;
|
||||
signed short wFormatTag;
|
||||
signed short nChannels;
|
||||
unsigned int nSamplesPerSec;
|
||||
unsigned int nAvgBytesPerSec;
|
||||
signed short nBlockAlign;
|
||||
signed short wBitsPerSample;
|
||||
char datachar[4];
|
||||
unsigned int DataLength;
|
||||
|
||||
private:
|
||||
// helper shared by the prepare() overloads (presumably fills the remaining fields -- TODO confirm)
void prepareRest(int SampleCount);
|
||||
|
||||
public:
|
||||
// two ways to set up the header: from explicit format parameters, or from a WAVEFORMATEX
void prepare(unsigned int Fs, int Bits, int Channels, int SampleCount);
|
||||
void prepare(const WAVEFORMATEX& wfx, int SampleCount);
|
||||
// header I/O on an open FILE*; read() also passes back the format tag and bytes per sample via its reference parameters
unsigned int read(FILE* f, signed short& wRealFormatTag, int& bytesPerSample);
|
||||
void write(FILE* f);
|
||||
static void update(FILE* f);
|
||||
} WAVEHEADER;
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// fgetwfx(), fputwfx(): I/O of wave file headers only
|
||||
// ----------------------------------------------------------------------------
|
||||
unsigned int fgetwfx(FILE* f, WAVEFORMATEX& wfx);
|
||||
void fputwfx(FILE* f, const WAVEFORMATEX& wfx, unsigned int numSamples);
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// fgetraw(): read data of .wav file, and separate data of multiple channels.
|
||||
// For example, data[i][j]: i is channel index, 0 means the first
|
||||
// channel. j is sample index.
|
||||
// ----------------------------------------------------------------------------
|
||||
void fgetraw(FILE* f, std::vector<std::vector<short>>& data, const WAVEHEADER& wavhd);
|
||||
#endif
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// auto_file_ptr -- FILE* with auto-close; use auto_file_ptr instead of FILE*.
|
||||
// Warning: do not pass an auto_file_ptr to a function that calls fclose(),
|
||||
// except for fclose() itself.
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
class auto_file_ptr
|
||||
{
|
||||
FILE* f;
|
||||
FILE* operator=(auto_file_ptr&); // can't ref-count: no assignment
|
||||
auto_file_ptr(auto_file_ptr&);
|
||||
void close()
|
||||
{
|
||||
if (f && f != stdin && f != stdout && f != stderr)
|
||||
{
|
||||
int rc = ::fclose(f);
|
||||
if ((rc != FCLOSE_SUCCESS) && !std::uncaught_exception())
|
||||
RuntimeError("auto_file_ptr: failed to close file: %s", strerror(errno));
|
||||
|
||||
f = NULL;
|
||||
}
|
||||
}
|
||||
#pragma warning(push)
|
||||
#pragma warning(disable : 4996)
|
||||
void openfailed(const std::string& path)
|
||||
{
|
||||
Microsoft::MSR::CNTK::RuntimeError("auto_file_ptr: error opening file '%s': %s", path.c_str(), strerror(errno));
|
||||
}
|
||||
#pragma warning(pop)
|
||||
protected:
|
||||
friend int fclose(auto_file_ptr&); // explicit close (note: may fail)
|
||||
int fclose()
|
||||
{
|
||||
int rc = ::fclose(f);
|
||||
if (rc == 0)
|
||||
f = NULL;
|
||||
return rc;
|
||||
}
|
||||
|
||||
public:
|
||||
auto_file_ptr()
|
||||
: f(NULL)
|
||||
{
|
||||
}
|
||||
~auto_file_ptr()
|
||||
{
|
||||
close();
|
||||
}
|
||||
#pragma warning(push)
|
||||
#pragma warning(disable : 4996)
|
||||
auto_file_ptr(const char* path, const char* mode)
|
||||
{
|
||||
f = fopen(path, mode);
|
||||
if (f == NULL)
|
||||
openfailed(path);
|
||||
}
|
||||
auto_file_ptr(const wchar_t* wpath, const char* mode)
|
||||
{
|
||||
f = _wfopen(wpath, msra::strfun::utf16(mode).c_str());
|
||||
if (f == NULL)
|
||||
openfailed(msra::strfun::utf8(wpath));
|
||||
}
|
||||
#pragma warning(pop)
|
||||
FILE* operator=(FILE* other)
|
||||
{
|
||||
close();
|
||||
f = other;
|
||||
return f;
|
||||
}
|
||||
auto_file_ptr(FILE* other)
|
||||
: f(other)
|
||||
{
|
||||
}
|
||||
operator FILE*() const
|
||||
{
|
||||
return f;
|
||||
}
|
||||
FILE* operator->() const
|
||||
{
|
||||
return f;
|
||||
}
|
||||
void swap(auto_file_ptr& other) throw()
|
||||
{
|
||||
std::swap(f, other.f);
|
||||
}
|
||||
};
|
||||
// Explicit close of an auto_file_ptr, spelled like the C library call.
// Forwards to the wrapper's own fclose() member and returns its result.
inline int fclose(auto_file_ptr& af)
{
    const int result = af.fclose();
    return result;
}
|
||||
|
||||
namespace msra { namespace files {
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// textreader -- simple reader for text files --we need this all the time!
|
||||
// Currently reads 8-bit files, but can return as wstring, in which case
|
||||
// they are interpreted as UTF-8 (without BOM).
|
||||
// Note: Not suitable for pipes or typed input due to readahead (fixable if needed).
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
class textreader
|
||||
{
|
||||
auto_file_ptr f;
|
||||
std::vector<char> buf; // read buffer (will only grow, never shrink)
|
||||
int ch; // next character (we need to read ahead by one...)
|
||||
char getch()
|
||||
{
|
||||
char prevch = (char) ch;
|
||||
ch = fgetc(f);
|
||||
return prevch;
|
||||
}
|
||||
|
||||
public:
|
||||
textreader(const std::wstring& path)
|
||||
: f(path.c_str(), "rb")
|
||||
{
|
||||
buf.reserve(10000);
|
||||
ch = fgetc(f);
|
||||
}
|
||||
operator bool() const
|
||||
{
|
||||
return ch != EOF;
|
||||
} // true if still a line to read
|
||||
std::string getline() // get and consume the next line
|
||||
{
|
||||
if (ch == EOF)
|
||||
LogicError("textreader: attempted to read beyond EOF");
|
||||
assert(buf.empty());
|
||||
// get all line's characters --we recognize UNIX (LF), DOS (CRLF), and Mac (CR) convention
|
||||
while (ch != EOF && ch != '\n' && ch != '\r')
|
||||
buf.push_back(getch());
|
||||
if (ch != EOF && getch() == '\r' && ch == '\n')
|
||||
getch(); // consume EOLN char
|
||||
std::string line(buf.begin(), buf.end());
|
||||
buf.clear();
|
||||
return line;
|
||||
}
|
||||
std::wstring wgetline()
|
||||
{
|
||||
return msra::strfun::utf16(getline());
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// temp functions -- clean these up
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
// Splits a pathname at its last separator ('\', ':' or '/') into the part before the
// separator (dir) and the part after it (file); the separator itself is dropped.
// Without any separator, dir becomes empty and file receives the whole path.
static inline void splitpath(const std::wstring& path, std::wstring& dir, std::wstring& file)
{
    const std::wstring::size_type sepPos = path.find_last_of(L"\\:/"); // DOS drives, UNIX, Windows
    if (sepPos != std::wstring::npos)
    {
        dir = path.substr(0, sepPos);
        file = path.substr(sepPos + 1);
        return;
    }

    // no directory found
    dir.clear();
    file = path;
}
|
||||
|
||||
// test if a pathname is a relative path
|
||||
// A relative path is one that can be appended to a directory.
|
||||
// Drive-relative paths, such as D:file, are considered non-relative.
|
||||
static inline bool relpath(const wchar_t* path)
{
    // Two shapes are treated as non-relative (a wild collection of Windows pathname conventions):
    //  - a leading slash or backslash, e.g. \WINDOWS
    //  - drive syntax, i.e. a colon in second position
    // ... TODO: handle long NT paths
    const bool rooted = (path[0] == '/') || (path[0] == '\\');
    const bool hasDrive = !rooted && path[0] && path[1] == ':'; // path[1] is only read for a non-empty path
    return !(rooted || hasDrive);
}
|
||||
// String-object convenience overload: forwards the string's character buffer
// to the pointer-based relpath().
template <class Char>
static inline bool relpath(const std::basic_string<Char>& s)
{
    const Char* rawPath = s.c_str();
    return relpath(rawPath);
}
|
||||
|
||||
// trim from start: erases, in place, the leading run of characters for which
// iscspace() is true, and returns the same string object
template<class String>
static inline String& ltrim(String& s)
{
    typedef typename String::value_type Char;
    auto firstKept = std::find_if(s.begin(), s.end(), [](Char c) { return !iscspace(c); });
    s.erase(s.begin(), firstKept);
    return s;
}
|
||||
|
||||
// trim from end: erases, in place, the trailing run of characters for which
// iscspace() is true, and returns the same string object
template<class String>
static inline String& rtrim(String& s)
{
    typedef typename String::value_type Char;
    // search backwards for the last character to keep; base() yields the position just after it
    auto lastKept = std::find_if(s.rbegin(), s.rend(), [](Char c) { return !iscspace(c); });
    s.erase(lastKept.base(), s.end());
    return s;
}
|
||||
|
||||
// trim from both ends: applies rtrim() and then ltrim() to s in place,
// and returns the same string object
template<class String>
static inline String& trim(String& s)
{
    String& rightTrimmed = rtrim(s);
    return ltrim(rightTrimmed);
}
|
||||
|
||||
template<class String>
|
||||
std::vector<String> SplitString(const String& str, const String& sep);
|
||||
// Convenience overload: takes the separator as a character pointer, builds a String
// from it, and forwards to SplitString(str, String).
template<class String, class Char>
std::vector<String> SplitString(const String& str, const Char* sep)
{
    return SplitString(str, String(sep));
}
|
||||
|
||||
std::wstring s2ws(const std::string& str);
|
||||
|
||||
std::string ws2s(const std::wstring& wstr);
|
||||
|
||||
|
||||
/* guoye: start */
|
||||
#include "../fileutil.cpp"
|
||||
/* guoye: end */
|
||||
#endif // _FILEUTIL_
|
|
@ -23,7 +23,9 @@
|
|||
#include <algorithm> // for find()
|
||||
#include "simplesenonehmm.h"
|
||||
#include "Matrix.h"
|
||||
|
||||
/* guoye: start */
|
||||
#include <set>
|
||||
/* guoye: end */
|
||||
namespace msra { namespace math {
|
||||
|
||||
class ssematrixbase;
|
||||
|
@ -67,7 +69,33 @@ enum mbrclassdefinition // used to identify definition of class in minimum bayes
|
|||
// ===========================================================================
|
||||
class lattice
|
||||
{
|
||||
public:
|
||||
mutable int verbosity;
|
||||
|
||||
/* guoye: start */
|
||||
|
||||
// define structures for n-best EMBR
|
||||
// One n-best token: its score plus a back-pointer (edge index and token index in the previous node).
struct TokenInfo
|
||||
{
|
||||
double score; // score of this token
|
||||
size_t prev_edge_index; // index of the edge that ends with this token; the edge's start is the previous node
|
||||
size_t prev_token_index; // index of the predecessor token within the previous node
|
||||
};
|
||||
// Back-pointer candidate stored per score in NBestToken::mp_score_token_infos.
struct PrevTokenInfo
|
||||
{
|
||||
size_t prev_edge_index; // presumably the same meaning as TokenInfo::prev_edge_index -- TODO confirm
|
||||
size_t prev_token_index; // presumably the same meaning as TokenInfo::prev_token_index -- TODO confirm
|
||||
double path_score; // "pure" path score: it does not take the WER of the path into account
|
||||
};
|
||||
|
||||
// Per-node n-best state: a score-ordered map of candidates and the n-best tokens of the node.
struct NBestToken
|
||||
{
|
||||
// used for sorting:
|
||||
// std::greater as comparator keeps the map's keys (scores) in descending order
|
||||
std::map<double, std::vector<PrevTokenInfo>, std::greater <double>> mp_score_token_infos; // candidates grouped by score, for sorting the tokens
|
||||
std::vector<TokenInfo> vt_nbest_tokens; // the n-best tokens stored for the node
|
||||
};
|
||||
|
||||
/* guoye: end */
|
||||
struct header_v1_v2
|
||||
{
|
||||
size_t numnodes : 32;
|
||||
|
@ -90,12 +118,22 @@ private:
|
|||
static const unsigned int NOEDGE = 0xffffff; // 24 bits
|
||||
// static_assert (sizeof (nodeinfo) == 8, "unexpected size of nodeeinfo"); // note: int64_t required to allow going across 32-bit boundary
|
||||
// ensure type size as these are expected to be of this size in the files we read
|
||||
static_assert(sizeof(nodeinfo) == 2, "unexpected size of nodeeinfo"); // note: int64_t required to allow going across 32-bit boundary
|
||||
/* guoye: start */
|
||||
static_assert(sizeof(nodeinfo) == 16, "unexpected size of nodeeinfo"); // note: int64_t required to allow going across 32-bit boundary
|
||||
|
||||
/* guoye: end */
|
||||
static_assert(sizeof(edgeinfowithscores) == 16, "unexpected size of edgeinfowithscores");
|
||||
static_assert(sizeof(aligninfo) == 4, "unexpected size of aligninfo");
|
||||
std::vector<nodeinfo> nodes;
|
||||
/* guoye: start */
|
||||
mutable std::vector<std::vector<uint64_t>> vt_node_out_edge_indices; // vt_node_out_edge_indices[i]: it stores the outgoing edge indices starting from node i
|
||||
std::vector<bool> is_special_words; // true if it is special words that do not count to WER computation, false if it is not
|
||||
|
||||
|
||||
/* guoye: end */
|
||||
std::vector<edgeinfowithscores> edges;
|
||||
std::vector<aligninfo> align;
|
||||
|
||||
// V2 lattices --for a while, we will store both in RAM, until all code is updated
|
||||
static int fsgn(float f)
|
||||
{
|
||||
|
@ -217,6 +255,12 @@ private:
|
|||
public: // TODO: make private again once
|
||||
// construct from edges/align
|
||||
// This is also used for merging, where the edges[] array is not correctly sorted. So don't assume this here.
|
||||
/* guoye: start */
|
||||
void erase_node_out_edges(size_t nodeidx, size_t edgeidx_start, size_t edgeidx_end) const
|
||||
{
|
||||
vt_node_out_edge_indices[nodeidx].erase(vt_node_out_edge_indices[nodeidx].begin() + edgeidx_start, vt_node_out_edge_indices[nodeidx].begin() + edgeidx_end);
|
||||
}
|
||||
/* guoye: end */
|
||||
void builduniquealignments(size_t spunit = SIZE_MAX /*fix this later*/)
|
||||
{
|
||||
// infer /sp/ unit if not given
|
||||
|
@ -701,6 +745,7 @@ private:
|
|||
const float lmf, const float wp, const float amf, const_array_ref<size_t>& uids,
|
||||
const edgealignments& thisedgealignments, std::vector<double>& Eframescorrect) const;
|
||||
|
||||
|
||||
void sMBRerrorsignal(parallelstate& parallelstate,
|
||||
msra::math::ssematrixbase& errorsignal, msra::math::ssematrixbase& errorsignalneg,
|
||||
const std::vector<double>& logpps, const float amf, double minlogpp,
|
||||
|
@ -736,7 +781,8 @@ private:
|
|||
const std::vector<double>& logpps, const float amf,
|
||||
const std::vector<double>& logEframescorrect, const double logEframescorrecttotal,
|
||||
msra::math::ssematrixbase& errorsignal, msra::math::ssematrixbase& errorsignalneg) const;
|
||||
|
||||
void parallelEMBRerrorsignal(parallelstate& parallelstate, const edgealignments& thisedgealignments,
|
||||
const std::vector<double>& edgeweights, msra::math::ssematrixbase& errorsignal) const;
|
||||
void parallelmmierrorsignal(parallelstate& parallelstate, const edgealignments& thisedgealignments,
|
||||
const std::vector<double>& logpps, msra::math::ssematrixbase& errorsignal) const;
|
||||
|
||||
|
@ -747,6 +793,20 @@ private:
|
|||
const_array_ref<size_t>& uids, std::vector<double>& logEframescorrect,
|
||||
std::vector<double>& Eframescorrectbuf, double& logEframescorrecttotal) const;
|
||||
|
||||
/* guoye: start */
|
||||
double parallelbackwardlatticeEMBR(parallelstate& parallelstate, const std::vector<float>& edgeacscores,
|
||||
const float lmf, const float wp,
|
||||
const float amf, std::vector<double>& edgelogbetas,
|
||||
std::vector<double>& logbetas) const;
|
||||
|
||||
void EMBRsamplepaths(const std::vector<double> &edgelogbetas,
|
||||
const std::vector<double> &logbetas, const size_t numPathsEMBR, const bool enforceValidPathEMBR, const bool excludeSpecialWords, std::vector< std::vector<size_t> > & vt_paths) const;
|
||||
|
||||
void EMBRnbestpaths(std::vector<NBestToken>& tokenlattice, std::vector<std::vector<size_t>> & vt_paths, std::vector<double>& path_posterior_probs) const;
|
||||
|
||||
double get_edge_weights(std::vector<size_t>& wids, std::vector<std::vector<size_t>>& vt_paths, std::vector<double>& vt_edge_weights, std::vector<double>& vt_path_posterior_probs, std::string getPathMethodEMBR, double& onebestwer) const;
|
||||
/* guoye: end */
|
||||
|
||||
static double scoregroundtruth(const_array_ref<size_t> uids, const_array_ref<htkmlfwordsequence::word> transcript,
|
||||
const std::vector<float>& transcriptunigrams, const msra::math::ssematrixbase& logLLs,
|
||||
const msra::asr::simplesenonehmm& hset, const float lmf, const float wp, const float amf);
|
||||
|
@ -762,6 +822,16 @@ private:
|
|||
std::vector<double>& logEframescorrect, std::vector<double>& Eframescorrectbuf,
|
||||
double& logEframescorrecttotal) const;
|
||||
|
||||
/* guoye: start */
|
||||
double backwardlatticeEMBR(const std::vector<float>& edgeacscores, parallelstate& parallelstate, std::vector<double> &edgelogbetas,
|
||||
std::vector<double>& logbetas,
|
||||
const float lmf, const float wp, const float amf) const;
|
||||
|
||||
void constructnodenbestoken(std::vector<NBestToken> &tokenlattice, const bool wordNbest, size_t numtokens2keep, size_t nidx) const;
|
||||
|
||||
double nbestlatticeEMBR(const std::vector<float> &edgeacscores, parallelstate ¶llelstate, std::vector<NBestToken> &vt_nbesttokens, const size_t numtokens, const bool enforceValidPathEMBR, const bool excludeSpecialWords,
|
||||
const float lmf, const float wp, const float amf, const bool wordNbest, const bool useAccInNbest, const float accWeightInNbest, const size_t numPathsEMBR, std::vector<size_t> wids) const;
|
||||
/* guoye: end */
|
||||
public:
|
||||
// construct from a HTK lattice file
|
||||
void fromhtklattice(const std::wstring& path, const std::unordered_map<std::string, size_t>& unitmap);
|
||||
|
@ -770,6 +840,156 @@ public:
|
|||
void frommlf(const std::wstring& key, const std::unordered_map<std::string, size_t>& unitmap, const msra::asr::htkmlfreader<msra::asr::htkmlfentry, lattice::htkmlfwordsequence>& labels,
|
||||
const msra::lm::CMGramLM& lm, const msra::lm::CSymbolSet& unigramsymbols);
|
||||
|
||||
/* guoye: start */
|
||||
// template <class IDMAP>
|
||||
// void fread(FILE* f, const IDMAP& idmap, size_t spunit, std::set<int>& specialwordids);
|
||||
// moved from latticearchive.cpp to this .h file: a template's definition and declaration both need to be in the .h file
|
||||
template <class IDMAP>
|
||||
void fread(FILE* f, const IDMAP& idmap, size_t spunit, std::set<int>& specialwordids)
|
||||
/* guoye: end */
|
||||
{
|
||||
|
||||
|
||||
|
||||
size_t version = freadtag(f, "LAT ");
|
||||
if (version == 1)
|
||||
{
|
||||
freadOrDie(&info, sizeof(info), 1, f);
|
||||
freadvector(f, "NODE", nodes, info.numnodes);
|
||||
if (nodes.back().t != info.numframes)
|
||||
/* guoye: start */
|
||||
{
|
||||
// RuntimeError("fread: mismatch between info.numframes and last node's time");
|
||||
// sometimes, the data is corrputed, let's try to live with it
|
||||
fprintf(stderr, "fread: mismatch between info.numframes and last node's time: nodes.back().t = %d vs. info.numframes = %d \n", int(nodes.back().t), int(info.numframes));
|
||||
}
|
||||
/* guoye: end */
|
||||
freadvector(f, "EDGE", edges, info.numedges);
|
||||
freadvector(f, "ALIG", align);
|
||||
fcheckTag(f, "END ");
|
||||
// map align ids to user's symmap --the lattice gets updated in place here
|
||||
foreach_index(k, align)
|
||||
align[k].updateunit(idmap); // updates itself
|
||||
}
|
||||
else if (version == 2)
|
||||
{
|
||||
freadOrDie(&info, sizeof(info), 1, f);
|
||||
freadvector(f, "NODS", nodes, info.numnodes);
|
||||
if (nodes.back().t != info.numframes)
|
||||
{
|
||||
/* guoye: start */
|
||||
{
|
||||
// RuntimeError("fread: mismatch between info.numframes and last node's time");
|
||||
// sometimes, the data is corrputed, let's try to live with it
|
||||
fprintf(stderr, "fread: mismatch between info.numframes and last node's time: nodes.back().t = %d vs. info.numframes = %d \n", int(nodes.back().t), int(info.numframes));
|
||||
}
|
||||
/* guoye: end */
|
||||
}
|
||||
freadvector(f, "EDGS", edges2, info.numedges); // uniqued edges
|
||||
freadvector(f, "ALNS", uniquededgedatatokens); // uniqued alignments
|
||||
fcheckTag(f, "END ");
|
||||
|
||||
/* guoye: start */
|
||||
vt_node_out_edge_indices.resize(info.numnodes);
|
||||
for (size_t j = 0; j < info.numedges; j++)
|
||||
{
|
||||
// an edge with !NULL pointing to not <s>
|
||||
// this code make sure if you always start from <s> in the sampled path.
|
||||
// mask here: we delay the processing in EMBRsamplepaths controlled by flag: enforceValidPathEMBR
|
||||
// if (edges2[j].S == 0 && nodes[edges2[j].E].wid != 1) continue;
|
||||
|
||||
vt_node_out_edge_indices[edges2[j].S].push_back(j);
|
||||
|
||||
}
|
||||
|
||||
is_special_words.resize(info.numnodes);
|
||||
for (size_t i = 0; i < info.numnodes; i++)
|
||||
{
|
||||
/*
|
||||
if (nodes[i].wid == 0xfffff)
|
||||
{
|
||||
nodes[i].wid;
|
||||
}
|
||||
*/
|
||||
if (specialwordids.find(int(nodes[i].wid)) != specialwordids.end()) is_special_words[i] = true;
|
||||
else is_special_words[i] = false;
|
||||
}
|
||||
|
||||
/* guoye: end */
|
||||
// check if we need to map
|
||||
#if 1 // post-bugfix for incorrect inference of spunit
|
||||
if (info.impliedspunitid != SIZE_MAX && info.impliedspunitid >= idmap.size()) // we have buggy lattices like that--what do they mean??
|
||||
{
|
||||
fprintf(stderr, "fread: detected buggy spunit id %d which is out of range (%d entries in map)\n", (int)info.impliedspunitid, (int)idmap.size());
|
||||
RuntimeError("fread: out of bounds spunitid");
|
||||
}
|
||||
#endif
|
||||
// This is critical--we have a buggy lattice set that requires no mapping where mapping would fail
|
||||
bool needsmapping = false;
|
||||
foreach_index(k, idmap)
|
||||
{
|
||||
if (idmap[k] != (size_t)k
|
||||
#if 1
|
||||
&& (k != (int)idmap.size() - 1 || idmap[k] != spunit) // that HACK that we add one more /sp/ entry at the end...
|
||||
#endif
|
||||
)
|
||||
{
|
||||
needsmapping = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
// map align ids to user's symmap --the lattice gets updated in place here
|
||||
if (needsmapping)
|
||||
{
|
||||
if (info.impliedspunitid != SIZE_MAX)
|
||||
info.impliedspunitid = idmap[info.impliedspunitid];
|
||||
|
||||
// deal with broken (zero-token) edges
|
||||
std::vector<bool> isendworkaround;
|
||||
if (info.impliedspunitid != spunit)
|
||||
{
|
||||
fprintf(stderr, "fread: lattice with broken spunit, using workaround to handle potentially broken zero-token edges\n");
|
||||
inferends(isendworkaround);
|
||||
}
|
||||
|
||||
size_t uniquealignments = 1;
|
||||
const size_t skipscoretokens = info.hasacscores ? 2 : 1;
|
||||
for (size_t k = skipscoretokens; k < uniquededgedatatokens.size(); k++)
|
||||
{
|
||||
if (!isendworkaround.empty() && isendworkaround[k]) // secondary criterion to detect ends in broken lattices
|
||||
{
|
||||
k--; // don't advance, since nothing to advance over
|
||||
}
|
||||
else
|
||||
{
|
||||
// this is a regular token: update it in-place
|
||||
auto& ai = uniquededgedatatokens[k];
|
||||
if (ai.unit >= idmap.size())
|
||||
RuntimeError("fread: broken-file heuristics failed");
|
||||
ai.updateunit(idmap); // updates itself
|
||||
if (!ai.last)
|
||||
continue;
|
||||
}
|
||||
// if last then skip over the lm and ac scores
|
||||
k += skipscoretokens;
|
||||
uniquealignments++;
|
||||
}
|
||||
fprintf(stderr, "fread: mapped %d unique alignments\n", (int)uniquealignments);
|
||||
}
|
||||
if (info.impliedspunitid != spunit)
|
||||
{
|
||||
// fprintf (stderr, "fread: inconsistent spunit id in file %d vs. expected %d; due to erroneous heuristic\n", info.impliedspunitid, spunit); // [v-hansu] comment out becaues it takes up most of the log
|
||||
// it's actually OK, we can live with this, since we only decompress and then move on without any assumptions
|
||||
// RuntimeError("fread: mismatching /sp/ units");
|
||||
}
|
||||
// reconstruct old lattice format from this --TODO: remove once we change to new data representation
|
||||
rebuildedges(info.impliedspunitid != spunit /*to be able to read somewhat broken V2 lattice archives*/);
|
||||
}
|
||||
else
|
||||
RuntimeError("fread: unsupported lattice format version");
|
||||
}
|
||||
|
||||
/* guoye: end */
|
||||
// check consistency
|
||||
// - only one end node
|
||||
// - only forward edges
|
||||
|
@ -995,6 +1215,7 @@ public:
|
|||
}
|
||||
}
|
||||
|
||||
/* guoye: start */
|
||||
// read from a stream
|
||||
// This can be used on an existing structure and will replace its content. May be useful to avoid memory allocations (resize() will not shrink memory).
|
||||
// For efficiency, we will not check the inner consistency of the file here, but rather when we further process it.
|
||||
|
@ -1003,7 +1224,7 @@ public:
|
|||
// This will also map the aligninfo entries to the new symbol table, through idmap.
|
||||
// V1 lattices will be converted. 'spsenoneid' is used in that process.
|
||||
template <class IDMAP>
|
||||
void fread(FILE* f, const IDMAP& idmap, size_t spunit)
|
||||
void fread(FILE* f, const IDMAP& idmap, size_t spunit, std::set<int>& specialwordids)
|
||||
{
|
||||
size_t version = freadtag(f, "LAT ");
|
||||
if (version == 1)
|
||||
|
@ -1011,7 +1232,11 @@ public:
|
|||
freadOrDie(&info, sizeof(info), 1, f);
|
||||
freadvector(f, "NODE", nodes, info.numnodes);
|
||||
if (nodes.back().t != info.numframes)
|
||||
RuntimeError("fread: mismatch between info.numframes and last node's time");
|
||||
{
|
||||
// RuntimeError("fread: mismatch between info.numframes and last node's time");
|
||||
// sometimes, the data is corrupted, let's try to live with it
|
||||
fprintf(stderr, "fread: mismatch between info.numframes and last node's time: nodes.back().t = %d vs. info.numframes = %d \n", int(nodes.back().t), int(info.numframes));
|
||||
}
|
||||
freadvector(f, "EDGE", edges, info.numedges);
|
||||
freadvector(f, "ALIG", align);
|
||||
fcheckTag(f, "END ");
|
||||
|
@ -1024,11 +1249,15 @@ public:
|
|||
freadOrDie(&info, sizeof(info), 1, f);
|
||||
freadvector(f, "NODS", nodes, info.numnodes);
|
||||
if (nodes.back().t != info.numframes)
|
||||
RuntimeError("fread: mismatch between info.numframes and last node's time");
|
||||
{
|
||||
// RuntimeError("fread: mismatch between info.numframes and last node's time");
|
||||
// sometimes, the data is corrupted, let's try to live with it
|
||||
fprintf(stderr, "fread: mismatch between info.numframes and last node's time: nodes.back().t = %d vs. info.numframes = %d \n", int(nodes.back().t), int(info.numframes));
|
||||
}
|
||||
freadvector(f, "EDGS", edges2, info.numedges); // uniqued edges
|
||||
freadvector(f, "ALNS", uniquededgedatatokens); // uniqued alignments
|
||||
fcheckTag(f, "END ");
|
||||
ProcessV2Lattice(spunit, info, uniquededgedatatokens, idmap);
|
||||
ProcessV2Lattice(spunit, info, uniquededgedatatokens, idmap, specialwordids);
|
||||
}
|
||||
else
|
||||
RuntimeError("fread: unsupported lattice format version");
|
||||
|
@ -1055,8 +1284,35 @@ public:
|
|||
|
||||
// Helper method to process v2 Lattice format
|
||||
template <class IDMAP>
|
||||
void ProcessV2Lattice(size_t spunit, header_v1_v2& info, std::vector<aligninfo>& uniquededgedatatokens, const IDMAP& idmap)
|
||||
void ProcessV2Lattice(size_t spunit, header_v1_v2& info, std::vector<aligninfo>& uniquededgedatatokens, const IDMAP& idmap, std::set<int>& specialwordids = {} )
|
||||
{
|
||||
/* guoye: start */
|
||||
vt_node_out_edge_indices.resize(info.numnodes);
|
||||
for (size_t j = 0; j < info.numedges; j++)
|
||||
{
|
||||
// an edge with !NULL pointing to not <s>
|
||||
// this code makes sure that you always start from <s> in the sampled path.
|
||||
// mask here: we delay the processing in EMBRsamplepaths controlled by flag: enforceValidPathEMBR
|
||||
// if (edges2[j].S == 0 && nodes[edges2[j].E].wid != 1) continue;
|
||||
|
||||
vt_node_out_edge_indices[edges2[j].S].push_back(j);
|
||||
|
||||
}
|
||||
|
||||
is_special_words.resize(info.numnodes);
|
||||
for (size_t i = 0; i < info.numnodes; i++)
|
||||
{
|
||||
/*
|
||||
if (nodes[i].wid == 0xfffff)
|
||||
{
|
||||
nodes[i].wid;
|
||||
}
|
||||
*/
|
||||
if (specialwordids.find(int(nodes[i].wid)) != specialwordids.end()) is_special_words[i] = true;
|
||||
else is_special_words[i] = false;
|
||||
}
|
||||
|
||||
/* guoye: end */
|
||||
// check if we need to map
|
||||
if (info.impliedspunitid != SIZE_MAX && info.impliedspunitid >= idmap.size()) // we have buggy lattices like that--what do they mean??
|
||||
{
|
||||
|
@ -1124,7 +1380,9 @@ public:
|
|||
rebuildedges(info.impliedspunitid != spunit /*to be able to read somewhat broken V2 lattice archives*/);
|
||||
|
||||
}
|
||||
/* guoye: end */
|
||||
|
||||
|
||||
// parallel versions (defined in parallelforwardbackward.cpp)
|
||||
class parallelstate
|
||||
{
|
||||
|
@ -1152,6 +1410,14 @@ public:
|
|||
const size_t getsilunitid();
|
||||
void getedgeacscores(std::vector<float>& edgeacscores);
|
||||
void getedgealignments(std::vector<unsigned short>& edgealignments);
|
||||
/* guoye: start */
|
||||
void getlogbetas(std::vector<double>& logbetas);
|
||||
void getedgelogbetas(std::vector<double>& edgelogbetas);
|
||||
void getedgeweights(std::vector<double>& edgeweights);
|
||||
|
||||
|
||||
void setedgeweights(const std::vector<double>& edgeweights);
|
||||
/* guoye: end */
|
||||
// to work with CNTK's GPU memory
|
||||
void setdevice(size_t DeviceId);
|
||||
size_t getdevice();
|
||||
|
@ -1166,11 +1432,30 @@ public:
|
|||
|
||||
// forward-backward function
|
||||
// Note: logLLs and posteriors may be the same matrix (aliased).
|
||||
/* start: guoye */
|
||||
|
||||
/*
|
||||
double forwardbackward(parallelstate& parallelstate, const class msra::math::ssematrixbase& logLLs, const class msra::asr::simplesenonehmm& hmms,
|
||||
class msra::math::ssematrixbase& result, class msra::math::ssematrixbase& errorsignalbuf,
|
||||
const float lmf, const float wp, const float amf, const float boostingfactor, const bool sMBRmode, array_ref<size_t> uids, const_array_ref<size_t> bounds = const_array_ref<size_t>(),
|
||||
const_array_ref<htkmlfwordsequence::word> transcript = const_array_ref<htkmlfwordsequence::word>(), const std::vector<float>& transcriptunigrams = std::vector<float>()) const;
|
||||
|
||||
|
||||
*/
|
||||
double forwardbackward(parallelstate& parallelstate, const class msra::math::ssematrixbase& logLLs, const class msra::asr::simplesenonehmm& hmms,
|
||||
class msra::math::ssematrixbase& result, class msra::math::ssematrixbase& errorsignalbuf,
|
||||
const float lmf, const float wp, const float amf, const float boostingfactor, const bool sMBRmode, const bool EMBR, const std::string EMBRUnit, const size_t numPathsEMBR, const bool enforceValidPathEMBR, const std::string getPathMethodEMBR, const std::string showWERMode,
|
||||
const bool excludeSpecialWords, const bool wordNbest, const bool useAccInNbest, const float accWeightInNbest, const size_t numRawPathsEMBR,
|
||||
array_ref<size_t> uids, std::vector<size_t> wids, const_array_ref<size_t> bounds = const_array_ref<size_t>(),
|
||||
const_array_ref<htkmlfwordsequence::word> transcript = const_array_ref<htkmlfwordsequence::word>(), const std::vector<float>& transcriptunigrams = std::vector<float>()) const;
|
||||
|
||||
/*
|
||||
void embrerrorsignal(parallelstate ¶llelstate,
|
||||
std::vector<msra::math::ssematrixbase *> &abcs, const bool softalignstates, const msra::asr::simplesenonehmm &hset,
|
||||
const edgealignments &thisedgealignments, std::vector<std::vector<size_t>>& vt_paths, std::vector<float>& path_weight, msra::math::ssematrixbase &errorsignal) const;
|
||||
*/
|
||||
void EMBRerrorsignal(parallelstate ¶llelstate,
|
||||
const edgealignments &thisedgealignments, std::vector<double>& edge_weights, msra::math::ssematrixbase &errorsignal) const;
|
||||
/* end: guoye */
|
||||
std::wstring key; // (keep our own name (key) so we can identify ourselves for diagnostics messages)
|
||||
const wchar_t* getkey() const
|
||||
{
|
||||
|
@ -1358,8 +1643,14 @@ public:
|
|||
if (sscanf(q, "[%" PRIu64 "]%c", &offset, &c) != 1)
|
||||
#endif
|
||||
RuntimeError("open: invalid TOC line (bad [] expression): %s", line);
|
||||
|
||||
if (!toc.insert(make_pair(key, latticeref(offset, archiveindex))).second)
|
||||
RuntimeError("open: TOC entry leads to duplicate key: %s", line);
|
||||
/* guoye: start */
|
||||
// Training sometimes reports this error. I believe it is caused by minor data corruption and is safe to ignore, so the error was changed to a warning.
|
||||
// RuntimeError("open: TOC entry leads to duplicate key: %s", line);
|
||||
|
||||
fprintf(stderr, " open: TOC entry leads to duplicate key: %s\n", line);
|
||||
/* guoye: end */
|
||||
}
|
||||
|
||||
// initialize symmaps --alloc the array, but actually read the symmap on demand
|
||||
|
@ -1390,7 +1681,10 @@ public:
|
|||
// Lattices will have unit ids updated according to the modelsymmap.
|
||||
// V1 lattices will be converted. 'spsenoneid' is used in the conversion for optimizing storing 0-frame /sp/ aligns.
|
||||
void getlattice(const std::wstring& key, lattice& L,
|
||||
size_t expectedframes = SIZE_MAX /*if unknown*/) const
|
||||
/* guoye: start */
|
||||
// size_t expectedframes = SIZE_MAX /*if unknown*/) const
|
||||
std::set<int>& specialwordids, size_t expectedframes = SIZE_MAX) const
|
||||
/* guoye: end */
|
||||
{
|
||||
auto iter = toc.find(key);
|
||||
if (iter == toc.end())
|
||||
|
@ -1417,7 +1711,11 @@ public:
|
|||
// seek to start
|
||||
fsetpos(f, offset);
|
||||
// get it
|
||||
L.fread(f, idmap, spunit);
|
||||
/* guoye: start */
|
||||
// L.fread(f, idmap, spunit);
|
||||
L.fread(f, idmap, spunit, specialwordids);
|
||||
|
||||
/* guoye: end */
|
||||
L.setverbosity(verbosity);
|
||||
#ifdef HACK_IN_SILENCE // hack to simulate DEL in the lattice
|
||||
const size_t silunit = getid(modelsymmap, "sil");
|
||||
|
@ -1451,7 +1749,11 @@ public:
|
|||
// - dump to stdout
|
||||
// - merge two lattices (for merging numer into denom lattices)
|
||||
static void convert(const std::wstring& intocpath, const std::wstring& intocpath2, const std::wstring& outpath,
|
||||
const msra::asr::simplesenonehmm& hset);
|
||||
/* guoye: start */
|
||||
// const msra::asr::simplesenonehmm& hset);
|
||||
const msra::asr::simplesenonehmm& hset, std::set<int>& specialwordids);
|
||||
/* guoye: end */
|
||||
};
|
||||
};
|
||||
};
|
||||
|
||||
|
|
|
@ -62,10 +62,17 @@ public:
|
|||
#endif
|
||||
}
|
||||
|
||||
void getlattices(const std::wstring& key, std::shared_ptr<const latticepair>& L, size_t expectedframes) const
|
||||
/* guoye: start */
|
||||
// void getlattices(const std::wstring& key, std::shared_ptr<const latticepair>& L, size_t expectedframes) const
|
||||
void getlattices(const std::wstring& key, std::shared_ptr<const latticepair>& L, size_t expectedframes, std::set<int>& specialwordids) const
|
||||
/* guoye: end */
|
||||
{
|
||||
std::shared_ptr<latticepair> LP(new latticepair);
|
||||
denlattices.getlattice(key, LP->second, expectedframes); // this loads the lattice from disk, using the existing L.second object
|
||||
/* guoye: start */
|
||||
// denlattices.getlattice(key, LP->second, expectedframes); // this loads the lattice from disk, using the existing L.second object
|
||||
denlattices.getlattice(key, LP->second, specialwordids, expectedframes); // this loads the lattice from disk, using the existing L.second object
|
||||
// fprintf(stderr, "latticesource.h:getlattices: %ls \n", key.c_str());
|
||||
/* guoye: end */
|
||||
L = LP;
|
||||
}
|
||||
|
||||
|
|
|
@ -12,6 +12,9 @@
|
|||
#include <stdexcept>
|
||||
#include <stdint.h>
|
||||
#include <cstdio>
|
||||
/* guoye: start */
|
||||
#include <vector>
|
||||
/* guoye: end */
|
||||
|
||||
#undef INITIAL_STRANGE // [v-hansu] initialize structs to strange values
|
||||
#define PARALLEL_SIL // [v-hansu] process sil on CUDA, used in other files, please search this
|
||||
|
@ -30,11 +33,22 @@ struct nodeinfo
|
|||
// uint64_t firstinedge : 24; // index of first incoming edge
|
||||
// uint64_t firstoutedge : 24; // index of first outgoing edge
|
||||
// uint64_t t : 16; // time associated with this
|
||||
|
||||
/* guoye: start */
|
||||
uint64_t wid; // word ID associated with the node
|
||||
/* guoye: end */
|
||||
unsigned short t; // time associated with this
|
||||
nodeinfo(size_t pt)
|
||||
: t((unsigned short) pt) // , firstinedge (NOEDGE), firstoutedge (NOEDGE)
|
||||
|
||||
nodeinfo(size_t pt, size_t pwid)
|
||||
/* guoye: start */
|
||||
// : t((unsigned short) pt) // , firstinedge (NOEDGE), firstoutedge (NOEDGE)
|
||||
: t((unsigned short)pt), wid(pwid)
|
||||
/* guoye: end */
|
||||
{
|
||||
checkoverflow(t, pt, "nodeinfo::t");
|
||||
/* guoye: start */
|
||||
checkoverflow(wid, pwid, "nodeinfo::wid");
|
||||
/* guoye: end */
|
||||
// checkoverflow (firstinedge, NOEDGE, "nodeinfo::firstinedge");
|
||||
// checkoverflow (firstoutedge, NOEDGE, "nodeinfo::firstoutedge");
|
||||
}
|
||||
|
|
Разница между файлами не показана из-за своего большого размера
Загрузить разницу
Разница между файлами не показана из-за своего большого размера
Загрузить разницу
|
@ -646,12 +646,47 @@ void ComputationNetwork::SetSeqParam(ComputationNetworkPtr net,
|
|||
const double& lmf /*= 14.0f*/,
|
||||
const double& wp /*= 0.0f*/,
|
||||
const double& bMMIfactor /*= 0.0f*/,
|
||||
const bool& sMBR /*= false*/
|
||||
/* guoye: start */
|
||||
// const bool& sMBR /*= false*/
|
||||
const bool& sMBR /*= false */,
|
||||
const bool& EMBR /*= false */,
|
||||
const string& EMBRUnit /* = "word" */,
|
||||
const size_t& numPathsEMBR,
|
||||
const bool& enforceValidPathEMBR,
|
||||
const string& getPathMethodEMBR,
|
||||
const string& showWERMode,
|
||||
const bool& excludeSpecialWords,
|
||||
const bool& wordNbest,
|
||||
const bool& useAccInNbest,
|
||||
const float& accWeightInNbest,
|
||||
const size_t& numRawPathsEMBR
|
||||
/* guoye: end */
|
||||
)
|
||||
{
|
||||
fprintf(stderr, "Setting Hsmoothing weight to %.8g and frame-dropping threshhold to %.8g\n", hsmoothingWeight, frameDropThresh);
|
||||
/* guoye: start */
|
||||
|
||||
/*
|
||||
fprintf(stderr, "Setting SeqGammar-related parameters: amf=%.2f, lmf=%.2f, wp=%.2f, bMMIFactor=%.2f, usesMBR=%s\n",
|
||||
amf, lmf, wp, bMMIfactor, sMBR ? "true" : "false");
|
||||
*/
|
||||
|
||||
if(EMBR)
|
||||
{
|
||||
fprintf(stderr, "Setting SeqGammar-related parameters: amf=%.2f, lmf=%.2f, wp=%.2f, bMMIFactor=%.2f, useEMBR=true, EMBRUnit=%s, numPathsEMBR=%d, enforceValidPathEMBR = %d, getPathMethodEMBR = %s, showWERMode = %s, excludeSpecialWords = %d, wordNbest = %d, useAccInNbest = %d, accWeightInNbest = %f, numRawPathsEMBR = %d \n",
|
||||
amf, lmf, wp, bMMIfactor, EMBRUnit.c_str(), int(numPathsEMBR), int(enforceValidPathEMBR), getPathMethodEMBR.c_str(), showWERMode.c_str(), int(excludeSpecialWords), int(wordNbest), int(useAccInNbest), float(accWeightInNbest), int(numRawPathsEMBR));
|
||||
}
|
||||
else if(sMBR)
|
||||
{
|
||||
fprintf(stderr, "Setting SeqGammar-related parameters: amf=%.2f, lmf=%.2f, wp=%.2f, bMMIFactor=%.2f, usesMBR=true \n",
|
||||
amf, lmf, wp, bMMIfactor);
|
||||
}
|
||||
else
|
||||
{
|
||||
fprintf(stderr, "Setting SeqGammar-related parameters: amf=%.2f, lmf=%.2f, wp=%.2f, bMMIFactor=%.2f, useMMI=true \n",
|
||||
amf, lmf, wp, bMMIfactor);
|
||||
}
|
||||
/* guoye: end */
|
||||
list<ComputationNodeBasePtr> seqNodes = net->GetNodesWithType(OperationNameOf(SequenceWithSoftmaxNode), criterionNode);
|
||||
if (seqNodes.size() == 0)
|
||||
{
|
||||
|
@ -665,7 +700,11 @@ void ComputationNetwork::SetSeqParam(ComputationNetworkPtr net,
|
|||
node->SetSmoothWeight(hsmoothingWeight);
|
||||
node->SetFrameDropThresh(frameDropThresh);
|
||||
node->SetReferenceAlign(doreferencealign);
|
||||
node->SetGammarCalculationParam(amf, lmf, wp, bMMIfactor, sMBR);
|
||||
/* guoye: start */
|
||||
// node->SetGammarCalculationParam(amf, lmf, wp, bMMIfactor, sMBR);
|
||||
node->SetMBR(sMBR || EMBR);
|
||||
node->SetGammarCalculationParam(amf, lmf, wp, bMMIfactor, sMBR, EMBR, EMBRUnit, numPathsEMBR, enforceValidPathEMBR, getPathMethodEMBR, showWERMode, excludeSpecialWords, wordNbest, useAccInNbest, accWeightInNbest, numRawPathsEMBR);
|
||||
/* guoye: end */
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1549,17 +1588,35 @@ template void ComputationNetwork::Read<float>(const wstring& fileName);
|
|||
template void ComputationNetwork::ReadPersistableParameters<float>(size_t modelVersion, File& fstream, bool create);
|
||||
template void ComputationNetwork::PerformSVDecomposition<float>(const map<wstring, float>& SVDConfig, size_t alignedsize);
|
||||
template /*static*/ void ComputationNetwork::SetBatchNormalizationTimeConstants<float>(ComputationNetworkPtr net, const ComputationNodeBasePtr& criterionNode, const double normalizationTimeConstant, double& prevNormalizationTimeConstant, double blendTimeConstant, double& prevBlendTimeConstant);
|
||||
/* guoye: start */
|
||||
/*
|
||||
template void ComputationNetwork::SetSeqParam<float>(ComputationNetworkPtr net, const ComputationNodeBasePtr criterionNode, const double& hsmoothingWeight, const double& frameDropThresh, const bool& doreferencealign,
|
||||
const double& amf, const double& lmf, const double& wp, const double& bMMIfactor, const bool& sMBR);
|
||||
*/
|
||||
|
||||
template void ComputationNetwork::SetSeqParam<float>(ComputationNetworkPtr net, const ComputationNodeBasePtr criterionNode, const double& hsmoothingWeight, const double& frameDropThresh, const bool& doreferencealign, const double& amf, const double& lmf, const double& wp, const double& bMMIfactor, const bool& sMBR, const bool& EMBR, const string& EMBRUnit, const size_t& numPathsEMBR, const bool& enforceValidPathEMBR, const string& getPathMethodEMBR, const string& showWERMode, const bool& excludeSpecialWords, const bool& wordNbest, const bool& useAccInNbest, const float& accWeightInNbest, const size_t& numRawPathsEMBR);
|
||||
|
||||
/* guoye: end */
|
||||
template void ComputationNetwork::SaveToDbnFile<float>(ComputationNetworkPtr net, const std::wstring& fileName) const;
|
||||
|
||||
template void ComputationNetwork::InitLearnableParametersWithBilinearFill<double>(const ComputationNodeBasePtr& node, size_t kernelWidth, size_t kernelHeight);
|
||||
template void ComputationNetwork::Read<double>(const wstring& fileName);
|
||||
template void ComputationNetwork::ReadPersistableParameters<double>(size_t modelVersion, File& fstream, bool create);
|
||||
template void ComputationNetwork::PerformSVDecomposition<double>(const map<wstring, float>& SVDConfig, size_t alignedsize);
|
||||
|
||||
template /*static*/ void ComputationNetwork::SetBatchNormalizationTimeConstants<double>(ComputationNetworkPtr net, const ComputationNodeBasePtr& criterionNode, const double normalizationTimeConstant, double& prevNormalizationTimeConstant, double blendTimeConstant, double& prevBlendTimeConstant);
|
||||
|
||||
/* guoye: start */
|
||||
/*
|
||||
template void ComputationNetwork::SetSeqParam<double>(ComputationNetworkPtr net, const ComputationNodeBasePtr criterionNode, const double& hsmoothingWeight, const double& frameDropThresh, const bool& doreferencealign,
|
||||
const double& amf, const double& lmf, const double& wp, const double& bMMIfactor, const bool& sMBR);
|
||||
*/
|
||||
template void ComputationNetwork::SetSeqParam<double>(ComputationNetworkPtr net, const ComputationNodeBasePtr criterionNode, const double& hsmoothingWeight, const double& frameDropThresh, const bool& doreferencealign,
|
||||
const double& amf, const double& lmf, const double& wp, const double& bMMIfactor, const bool& sMBR, const bool& EMBR, const string& EMBRUnit, const size_t& numPathsEMBR,
|
||||
const bool& enforceValidPathEMBR, const string& getPathMethodEMBR, const string& showWERMode, const bool& excludeSpecialWords, const bool& wordNbest, const bool& useAccInNbest, const float& accWeightInNbest, const size_t& numRawPathsEMBR);
|
||||
|
||||
/* guoye: end */
|
||||
|
||||
template void ComputationNetwork::SaveToDbnFile<double>(ComputationNetworkPtr net, const std::wstring& fileName) const;
|
||||
|
||||
template void ComputationNetwork::InitLearnableParametersWithBilinearFill<half>(const ComputationNodeBasePtr& node, size_t kernelWidth, size_t kernelHeight);
|
||||
|
|
|
@ -554,7 +554,22 @@ public:
|
|||
const double& lmf = 14.0f,
|
||||
const double& wp = 0.0f,
|
||||
const double& bMMIfactor = 0.0f,
|
||||
const bool& sMBR = false);
|
||||
/* guoye: start */
|
||||
// const bool& sMBR = false);
|
||||
const bool& sMBR = false,
|
||||
const bool& EMBR = false,
|
||||
const string& EMBRUnit = "word",
|
||||
const size_t& numPathsEMBR = 100,
|
||||
const bool& enforceValidPathEMBR = false,
|
||||
const string& getPathMethodEMBR = "sampling",
|
||||
const string& showWERMode = "average",
|
||||
const bool& excludeSpecialWords = false,
|
||||
const bool& wordNbest = false,
|
||||
const bool& useAccInNbest = false,
|
||||
const float& accWeightInNbest = 1.0f,
|
||||
const size_t& numRawPathsEMBR = 100
|
||||
);
|
||||
/* guoye: end */
|
||||
static void SetMaxTempMemSizeForCNN(ComputationNetworkPtr net, const ComputationNodeBasePtr& criterionNode, const size_t maxTempMemSizeInSamples);
|
||||
|
||||
// -----------------------------------------------------------------------
|
||||
|
|
|
@ -371,9 +371,21 @@ static bool DumpNode(ComputationNodeBasePtr nodep, bool dumpGradient)
|
|||
// Forwards the gradient-matrix allocation request to every nested node of this
// flow-control node, visiting m_nestedNodes from last to first.
/*virtual*/ void ComputationNetwork::SEQTraversalFlowControlNode::AllocateGradientMatricesForInputs(MatrixPool& matrixPool) /*override*/
{
    // TODO: should we deallocate in opposite order?
    for (auto nodeIter = m_nestedNodes.rbegin(); nodeIter != m_nestedNodes.rend(); ++nodeIter)
    {
        (*nodeIter)->AllocateGradientMatricesForInputs(matrixPool);
    }
}
|
||||
/*virtual*/ void ComputationNetwork::SEQTraversalFlowControlNode::RequestMatricesBeforeBackprop(MatrixPool& matrixPool) /*override*/
|
||||
|
@ -1061,36 +1073,53 @@ void ComputationNetwork::AllocateAllMatrices(const std::vector<ComputationNodeBa
|
|||
const std::vector<ComputationNodeBasePtr>& outValueRootNodes,
|
||||
ComputationNodeBasePtr trainRootNode)
|
||||
{
|
||||
/* guoye: start */
|
||||
// fprintf(stderr, "\nAllocateAllMatrices: debug 1\n");
|
||||
/* guoye: end */
|
||||
if (AreMatricesAllocated())
|
||||
return;
|
||||
/* guoye: start */
|
||||
// fprintf(stderr, "\nAllocateAllMatrices: debug 2\n");
|
||||
/* guoye: end */
|
||||
|
||||
// Allocate memory for forward/backward computation
|
||||
if (TraceLevel() > 0)
|
||||
fprintf(stderr, "\n\nAllocating matrices for forward and/or backward propagation.\n");
|
||||
|
||||
VerifyIsCompiled("AllocateAllMatrices");
|
||||
|
||||
/* guoye: start */
|
||||
// fprintf(stderr, "\nAllocateAllMatrices: debug 3\n");
|
||||
/* guoye: end */
|
||||
std::vector<ComputationNodeBasePtr> forwardPropRoots;
|
||||
forwardPropRoots.insert(forwardPropRoots.end(), evalRootNodes.begin(), evalRootNodes.end());
|
||||
forwardPropRoots.insert(forwardPropRoots.end(), outValueRootNodes.begin(), outValueRootNodes.end());
|
||||
if (trainRootNode != nullptr)
|
||||
forwardPropRoots.push_back(trainRootNode);
|
||||
|
||||
/* guoye: start */
|
||||
// fprintf(stderr, "\nAllocateAllMatrices: debug 4\n");
|
||||
/* guoye: end */
|
||||
// Mark all the eval, output and criterion roots as non-shareable
|
||||
for (auto& rootNode : forwardPropRoots)
|
||||
rootNode->MarkValueNonSharable();
|
||||
|
||||
/* guoye: start */
|
||||
// fprintf(stderr, "\nAllocateAllMatrices: debug 5\n");
|
||||
/* guoye: end */
|
||||
// Due to special topology, if a node is solely induced by parameters, its function value should not be shared
|
||||
MarkValueNonSharableNodes();
|
||||
|
||||
bool performingBackPropagation = (trainRootNode != nullptr);
|
||||
|
||||
/* guoye: start */
|
||||
// fprintf(stderr, "\nAllocateAllMatrices: debug 6\n");
|
||||
/* guoye: end */
|
||||
// Create a composite Eval order with the specified nodes as roots
|
||||
// For each node determine parents and whether the output of the
|
||||
// node is needed during back propagation
|
||||
std::unordered_map<ComputationNodeBasePtr, bool> outputValueNeededDuringBackProp;
|
||||
std::unordered_map<ComputationNodeBasePtr, std::unordered_set<ComputationNodeBasePtr>> parentsMap;
|
||||
std::unordered_set<ComputationNodeBasePtr> uniqueForwardPropEvalNodes;
|
||||
/* guoye: start */
|
||||
// fprintf(stderr, "\nAllocateAllMatrices: debug 7\n");
|
||||
/* guoye: end */
|
||||
for (auto& rootNode : forwardPropRoots)
|
||||
{
|
||||
for (const auto& node : GetEvalOrder(rootNode))
|
||||
|
@ -1115,7 +1144,9 @@ void ComputationNetwork::AllocateAllMatrices(const std::vector<ComputationNodeBa
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
/* guoye: start */
|
||||
// fprintf(stderr, "\nAllocateAllMatrices: debug 8\n");
|
||||
/* guoye: end */
|
||||
// gradient reuse maps
|
||||
std::unordered_map<MatrixPool::AliasNodePtr, std::unordered_set<MatrixPool::AliasNodePtr>> gradientReuseChildrenMap;
|
||||
std::unordered_map<MatrixPool::AliasNodePtr, MatrixPool::AliasNodePtr> gradientReuseParentMap;
|
||||
|
@ -1151,6 +1182,9 @@ void ComputationNetwork::AllocateAllMatrices(const std::vector<ComputationNodeBa
|
|||
}
|
||||
}
|
||||
}
|
||||
/* guoye: start */
|
||||
// fprintf(stderr, "\nAllocateAllMatrices: debug 9\n");
|
||||
/* guoye: end */
|
||||
|
||||
m_matrixPool.Reset();
|
||||
|
||||
|
@ -1175,7 +1209,9 @@ void ComputationNetwork::AllocateAllMatrices(const std::vector<ComputationNodeBa
|
|||
ReleaseMatricesAfterEvalForChildren(node, parentsMap);
|
||||
}
|
||||
});
|
||||
|
||||
/* guoye: start */
|
||||
// fprintf(stderr, "\nAllocateAllMatrices: debug 10\n");
|
||||
/* guoye: end */
|
||||
if (trainRootNode != nullptr)
|
||||
{
|
||||
const std::list<ComputationNodeBasePtr>& backPropNodes = GetEvalOrder(trainRootNode);
|
||||
|
@ -1184,10 +1220,11 @@ void ComputationNetwork::AllocateAllMatrices(const std::vector<ComputationNodeBa
|
|||
|
||||
std::unordered_map<MatrixPool::AliasNodePtr, std::unordered_set<MatrixPool::AliasNodePtr>> compactGradientAliasMap;
|
||||
std::unordered_map<MatrixPool::AliasNodePtr, MatrixPool::AliasNodePtr> compactGradientAliasRootMap;
|
||||
// fprintf(stderr, "\nAllocateAllMatrices: debug 10.1\n");
|
||||
for (const auto& gradientReuseKeyValue : gradientReuseChildrenMap)
|
||||
{
|
||||
// keep searching parent until reaching root
|
||||
|
||||
// fprintf(stderr, "\nAllocateAllMatrices: debug 10.2\n");
|
||||
auto parent = gradientReuseKeyValue.first;
|
||||
auto parentIter = gradientReuseParentMap.find(parent);
|
||||
while (parentIter != gradientReuseParentMap.end())
|
||||
|
@ -1214,20 +1251,22 @@ void ComputationNetwork::AllocateAllMatrices(const std::vector<ComputationNodeBa
|
|||
compactGradientAliasMap[parent].insert(parent);
|
||||
compactGradientAliasRootMap[parent] = parent;
|
||||
}
|
||||
|
||||
// fprintf(stderr, "\nAllocateAllMatrices: debug 10.3\n");
|
||||
// print the memory aliasing info
|
||||
if (TraceLevel() > 0 && compactGradientAliasRootMap.size() > 0)
|
||||
{
|
||||
// fprintf(stderr, "\nAllocateAllMatrices: debug 10.4\n");
|
||||
fprintf(stderr, "\nGradient Memory Aliasing: %d are aliased.\n", (int)compactGradientAliasRootMap.size());
|
||||
for (const auto pair : compactGradientAliasRootMap)
|
||||
{
|
||||
// fprintf(stderr, "\nAllocateAllMatrices: debug 10.5\n");
|
||||
auto child = (const ComputationNodeBase*)pair.first;
|
||||
auto parent = (const ComputationNodeBase*)pair.second;
|
||||
if (child != parent)
|
||||
fprintf(stderr, "\t%S (gradient) reuses %S (gradient)\n", child->GetName().c_str(), parent->GetName().c_str());
|
||||
}
|
||||
}
|
||||
|
||||
// fprintf(stderr, "\nAllocateAllMatrices: debug 10.6\n");
|
||||
m_matrixPool.SetAliasInfo(compactGradientAliasMap, compactGradientAliasRootMap);
|
||||
|
||||
// now, simulate the gradient computation order to determine how to allocate matrices
|
||||
|
@ -1235,38 +1274,55 @@ void ComputationNetwork::AllocateAllMatrices(const std::vector<ComputationNodeBa
|
|||
|
||||
// we need to call it here since we always compute gradients for children and root node is not children of other node
|
||||
trainRootNode->RequestMatricesBeforeBackprop(m_matrixPool);
|
||||
|
||||
// fprintf(stderr, "\nAllocateAllMatrices: debug 10.7\n");
|
||||
for (auto iter = backPropNodes.rbegin(); iter != backPropNodes.rend(); iter++) // for gradient computation, traverse in reverse order
|
||||
{
|
||||
// fprintf(stderr, "\nAllocateAllMatrices: debug 10.8\n");
|
||||
auto n = *iter;
|
||||
// fprintf(stderr, "\nAllocateAllMatrices: debug 10.8.1\n");
|
||||
if (n->IsPartOfLoop())
|
||||
{
|
||||
// fprintf(stderr, "\nAllocateAllMatrices: debug 10.8.2\n");
|
||||
std::vector<ComputationNodeBasePtr> recurrentNodes;
|
||||
shared_ptr<SEQTraversalFlowControlNode> recInfo = FindInRecurrentLoops(m_allSEQNodes, n);
|
||||
// fprintf(stderr, "\nAllocateAllMatrices: debug 10.8.3\n");
|
||||
if (completedGradient.insert(recInfo).second)
|
||||
{
|
||||
// SEQ mode: allocate all in loop first, then deallocate again
|
||||
// TODO: next step: use PARTraversalFlowControlNode::AllocateGradientMatricesForInputs() and ReleaseMatricesAfterBackprop()...
|
||||
// BUGBUG: naw, ^^ would not work! Wrong order! Need to rethink this. Need to make AllocateEvalMatrices() and AllocateGradientMatrices() the virtual functions.
|
||||
|
||||
// fprintf(stderr, "\nAllocateAllMatrices: debug 10.8.4\n");
|
||||
recInfo->AllocateGradientMatricesForInputs(m_matrixPool);
|
||||
// fprintf(stderr, "\nAllocateAllMatrices: debug 10.8.5\n");
|
||||
// Loops are computed sample by sample so we have to allocate them all
|
||||
recInfo->ReleaseMatricesAfterBackprop(m_matrixPool);
|
||||
// fprintf(stderr, "\nAllocateAllMatrices: debug 10.8.6\n");
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
// PAR mode: we can allocate and immediately deallocate one by one
|
||||
// fprintf(stderr, "\nAllocateAllMatrices: debug 10.8.7\n");
|
||||
n->AllocateGradientMatricesForInputs(m_matrixPool);
|
||||
// fprintf(stderr, "\nAllocateAllMatrices: debug 10.8.8\n");
|
||||
// Root node's information will be used and should not be shared with others, also it's small (1x1)
|
||||
if ((n != trainRootNode) && n->NeedsGradient())
|
||||
n->ReleaseMatricesAfterBackprop(m_matrixPool);
|
||||
// fprintf(stderr, "\nAllocateAllMatrices: debug 10.8.9\n");
|
||||
}
|
||||
// fprintf(stderr, "\nAllocateAllMatrices: debug 10.8.10\n");
|
||||
}
|
||||
// fprintf(stderr, "\nAllocateAllMatrices: debug 10.9\n");
|
||||
}
|
||||
|
||||
/* guoye: start */
|
||||
// fprintf(stderr, "\nAllocateAllMatrices: debug 11\n");
|
||||
/* guoye: end */
|
||||
m_matrixPool.OptimizedMemoryAllocation();
|
||||
m_areMatricesAllocated = true;
|
||||
|
||||
/* guoye: start */
|
||||
// fprintf(stderr, "\nAllocateAllMatrices: debug 12\n");
|
||||
/* guoye: end */
|
||||
// TO DO: At the time of AllocateAllMatrices we don't know the minibatch size. In theory one may allocate memory again once we start to receive
|
||||
// data from the reader (and the minibatch size is known). For some problems, minibatch size can change constantly, and there needs to be a
|
||||
// tradeoff in deciding how frequent to run optimized memory allocation. For now, we do it only once at the very beginning for speed concerns.
|
||||
|
|
|
@ -1874,24 +1874,54 @@ public:
|
|||
|
||||
// Requests gradient matrices from the pool on behalf of this node's inputs.
// Only inputs that report NeedsGradient() are asked; all others are skipped.
//
// matrixPool: pool that is handed on to each input's RequestMatricesBeforeBackprop().
virtual void AllocateGradientMatricesForInputs(MatrixPool& matrixPool) override
{
    // Inputs are visited in their stored order, same as an index loop over m_inputs.
    for (const auto& input : m_inputs)
    {
        if (!input->NeedsGradient())
            continue;

        input->RequestMatricesBeforeBackprop(matrixPool);
    }
}
|
||||
|
||||
// request matrices that are needed for gradient computation
|
||||
virtual void RequestMatricesBeforeBackprop(MatrixPool& matrixPool) override
|
||||
{
|
||||
size_t matrixSize = m_sampleLayout.GetNumElements();
|
||||
RequestMatrixFromPool(m_gradient, matrixPool, matrixSize, HasMBLayout(), /*isWorkSpace*/false, ParentGradientReused() || IsGradientReused());
|
||||
/* guoye: start */
|
||||
// fprintf(stderr, "\n computationnode.h: RequestMatricesBeforeBackprop: debug 1 \n");
|
||||
/* guoye: end */
|
||||
|
||||
size_t matrixSize = m_sampleLayout.GetNumElements();
|
||||
/* guoye: start */
|
||||
// fprintf(stderr, "\n computationnode.h: RequestMatricesBeforeBackprop: debug 2, matrixSize = %d, mbscale = %d \n", int(matrixSize), int(HasMBLayout()));
|
||||
/* guoye: end */
|
||||
RequestMatrixFromPool(m_gradient, matrixPool, matrixSize, HasMBLayout(), /*isWorkSpace*/false, ParentGradientReused() || IsGradientReused());
|
||||
/* guoye: start */
|
||||
// fprintf(stderr, "\n computationnode.h: RequestMatricesBeforeBackprop: debug 3 \n");
|
||||
/* guoye: end */
|
||||
auto multiOutputNode = dynamic_cast<MultiOutputNode<ElemType>*>(this);
|
||||
/* guoye: start */
|
||||
// fprintf(stderr, "\n computationnode.h: RequestMatricesBeforeBackprop: debug 4 \n");
|
||||
/* guoye: end */
|
||||
if (multiOutputNode)
|
||||
{
|
||||
/* guoye: start */
|
||||
// fprintf(stderr, "\n computationnode.h: RequestMatricesBeforeBackprop: debug 5, multiOutputNode->m_numOutputs = %d \n", int(multiOutputNode->m_numOutputs));
|
||||
/* guoye: end */
|
||||
for (size_t i = 1; i < multiOutputNode->m_numOutputs; ++i)
|
||||
{
|
||||
// fprintf(stderr, "\n computationnode.h: RequestMatricesBeforeBackprop: debug 6, i = %d \n", int(i));
|
||||
RequestMatrixFromPool(multiOutputNode->m_outputsGradient[i], matrixPool, multiOutputNode->m_outputsShape[i].GetNumElements(), multiOutputNode->m_outputsMBLayout[i] != nullptr);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1957,6 +1987,11 @@ protected:
|
|||
template<typename ValueType>
|
||||
void TypedRequestMatrixFromPool(shared_ptr<Matrix<ValueType>>& matrixPtr, MatrixPool& matrixPool, size_t matrixSize=0, bool mbScale=false, bool isWorkSpace=false, bool aliasing=false)
|
||||
{
|
||||
/* guoye: start */
|
||||
// fprintf(stderr, "\n computationnode.h:RequestMatrixFromPool, debug 0 \n");
|
||||
|
||||
// fprintf(stderr, "\n computationnode.h:RequestMatrixFromPool, debug 1 \n");
|
||||
/* guoye: end */
|
||||
if (matrixPtr == nullptr)
|
||||
{
|
||||
if (aliasing)
|
||||
|
|
|
@ -616,8 +616,17 @@ public:
|
|||
|
||||
// Requests the matrices needed for backprop: the base class's matrices first,
// then m_tempMatrixBackward as a zero-sized, non-minibatch-scaled workspace matrix.
void RequestMatricesBeforeBackprop(MatrixPool& matrixPool) override
{
    Base::RequestMatricesBeforeBackprop(matrixPool);
    RequestMatrixFromPool(m_tempMatrixBackward, matrixPool, /*matrixSize*/ 0, /*mbScale*/ false, /*isWorkSpace*/ true);
}
|
||||
|
||||
void ReleaseMatricesAfterBackprop(MatrixPool& matrixPool) override
|
||||
|
|
|
@ -269,9 +269,21 @@ public:
|
|||
// request matrices that are needed for gradient computation
|
||||
virtual void RequestMatricesBeforeBackprop(MatrixPool& matrixPool)
|
||||
{
|
||||
/* guoye: start */
|
||||
// fprintf(stderr, "\n deprecatednode.h: RequestMatricesBeforeBackprop: debug 1 \n");
|
||||
/* guoye: end */
|
||||
Base::RequestMatricesBeforeBackprop(matrixPool);
|
||||
/* guoye: start */
|
||||
// fprintf(stderr, "\n deprecatednode.h: RequestMatricesBeforeBackprop: debug 2 \n");
|
||||
/* guoye: end */
|
||||
RequestMatrixFromPool(m_innerproduct, matrixPool);
|
||||
/* guoye: start */
|
||||
// fprintf(stderr, "\n deprecatednode.h: RequestMatricesBeforeBackprop: debug 3 \n");
|
||||
/* guoye: end */
|
||||
RequestMatrixFromPool(m_rightGradient, matrixPool);
|
||||
/* guoye: start */
|
||||
// fprintf(stderr, "\n deprecatednode.h: RequestMatricesBeforeBackprop: debug 4 \n");
|
||||
/* guoye: end */
|
||||
}
|
||||
|
||||
// release gradient and temp matrices that no longer needed after all the children's gradients are computed.
|
||||
|
|
|
@ -1893,9 +1893,21 @@ public:
|
|||
// request matrices that are needed for gradient computation
|
||||
virtual void RequestMatricesBeforeBackprop(MatrixPool& matrixPool)
|
||||
{
|
||||
/* guoye: start */
|
||||
// fprintf(stderr, "\n linearalgebranodes.h: RequestMatricesBeforeBackprop: debug 1 \n");
|
||||
/* guoye: end */
|
||||
Base::RequestMatricesBeforeBackprop(matrixPool);
|
||||
/* guoye: start */
|
||||
// fprintf(stderr, "\n linearalgebranodes.h: RequestMatricesBeforeBackprop: debug 2 \n");
|
||||
/* guoye: end */
|
||||
RequestMatrixFromPool(m_invNormSquare, matrixPool);
|
||||
/* guoye: start */
|
||||
// fprintf(stderr, "\n linearalgebranodes.h: RequestMatricesBeforeBackprop: debug 3 \n");
|
||||
/* guoye: end */
|
||||
RequestMatrixFromPool(m_temp, matrixPool);
|
||||
/* guoye: start */
|
||||
// fprintf(stderr, "\n linearalgebranodes.h: RequestMatricesBeforeBackprop: debug 4 \n");
|
||||
/* guoye: end */
|
||||
}
|
||||
|
||||
// release gradient and temp matrices that no longer needed after all the children's gradients are computed.
|
||||
|
|
|
@ -137,14 +137,52 @@ public:
|
|||
template <class ElemType>
|
||||
void RequestAllocate(DEVICEID_TYPE deviceId, shared_ptr<Matrix<ElemType>>*pMatrixPtr, size_t matrixSize, bool mbScale, bool isWorkSpace)
|
||||
{
|
||||
/* guoye: start */
|
||||
// fprintf(stderr, "\n matrixpool.h:RequestAllocate, debug 1 \n");
|
||||
/* guoye: end */
|
||||
vector<MemRequestInfo<ElemType>>& memInfoVec = GetMemRequestInfoVec<ElemType>();
|
||||
/* guoye: start */
|
||||
// fprintf(stderr, "\n matrixpool.h:RequestAllocate, debug 2 \n");
|
||||
/* guoye: end */
|
||||
MemRequestInfo<ElemType> memInfo(deviceId, pMatrixPtr, matrixSize, mbScale, isWorkSpace, m_stepCounter);
|
||||
memInfoVec.push_back(memInfo);
|
||||
|
||||
/* guoye: start */
|
||||
// fprintf(stderr, "\n matrixpool.h:RequestAllocate, debug 3, memInfo.pMatrixPtrs.size() = %d, memInfoVec.size() = %d, sizeof(meminfo) = %d, \n", int(memInfo.pMatrixPtrs.size()), int(memInfoVec.size()), int(sizeof(memInfo)));
|
||||
/*
|
||||
if (memInfoVec.size() >= 256)
|
||||
{
|
||||
fprintf(stderr, "\n matrixpool.h:RequestAllocate, debug 3.5, sizeof(meminfo) is equal or large than 256, do no push \n");
|
||||
memInfoVec.resize(memInfoVec.size() + 1, memInfo);
|
||||
}
|
||||
*/
|
||||
/* guoye: end */
|
||||
/* guoye: start */
|
||||
/*
|
||||
else
|
||||
*/
|
||||
{
|
||||
|
||||
memInfoVec.push_back(memInfo);
|
||||
}
|
||||
|
||||
/* guoye: end */
|
||||
/* guoye: start */
|
||||
// fprintf(stderr, "\n matrixpool.h:RequestAllocate, debug 4 \n");
|
||||
/* guoye: end */
|
||||
m_deviceIDSet.insert(deviceId);
|
||||
/* guoye: start */
|
||||
// fprintf(stderr, "\n matrixpool.h:RequestAllocate, debug 5 \n");
|
||||
/* guoye: end */
|
||||
m_stepCounter++;
|
||||
/* guoye: start */
|
||||
// fprintf(stderr, "\n matrixpool.h:RequestAllocate, debug 6 \n");
|
||||
/* guoye: end */
|
||||
|
||||
// assign some temporary pointer, they will be replaced later unless the matrix is sparse
|
||||
*pMatrixPtr = make_shared<Matrix<ElemType>>(deviceId);
|
||||
/* guoye: start */
|
||||
// fprintf(stderr, "\n matrixpool.h:RequestAllocate, debug 7 \n");
|
||||
/* guoye: end */
|
||||
}
|
||||
|
||||
void OptimizedMemoryAllocation()
|
||||
|
@ -207,24 +245,59 @@ public:
|
|||
template <class ElemType>
|
||||
void RequestAliasedAllocate(DEVICEID_TYPE deviceId, AliasNodePtr node, shared_ptr<Matrix<ElemType>>*pMatrixPtr, size_t matrixSize, bool mbScale)
|
||||
{
|
||||
/* guoye: start */
|
||||
// fprintf(stderr, "\n matrixpool.h:RequestAliasedAllocate, debug 1 \n");
|
||||
/* guoye: end */
|
||||
const auto iter = m_aliasLookup.find(node);
|
||||
/* guoye: start */
|
||||
// fprintf(stderr, "\n matrixpool.h:RequestAliasedAllocate, debug 2 \n");
|
||||
/* guoye: end */
|
||||
if (iter == m_aliasLookup.end())
|
||||
LogicError("node not aliased");
|
||||
|
||||
/* guoye: start */
|
||||
// fprintf(stderr, "\n matrixpool.h:RequestAliasedAllocate, debug 3 \n");
|
||||
/* guoye: end */
|
||||
auto parent = iter->second;
|
||||
/* guoye: start */
|
||||
// fprintf(stderr, "\n matrixpool.h:RequestAliasedAllocate, debug 4 \n");
|
||||
/* guoye: end */
|
||||
auto& aliasInfo = m_aliasGroups[parent];
|
||||
/* guoye: start */
|
||||
// fprintf(stderr, "\n matrixpool.h:RequestAliasedAllocate, debug 5 \n");
|
||||
/* guoye: end */
|
||||
if (aliasInfo.pMatrixPtr == nullptr)
|
||||
{
|
||||
// first allocation for the group
|
||||
/* guoye: start */
|
||||
// fprintf(stderr, "\n matrixpool.h:RequestAliasedAllocate, debug 6 \n");
|
||||
/* guoye: end */
|
||||
aliasInfo.pMatrixPtr = pMatrixPtr;
|
||||
/* guoye: start */
|
||||
// fprintf(stderr, "\n matrixpool.h:RequestAliasedAllocate, debug 7 \n");
|
||||
/* guoye: end */
|
||||
RequestAllocate(deviceId, pMatrixPtr, matrixSize, mbScale, false);
|
||||
/* guoye: start */
|
||||
// fprintf(stderr, "\n matrixpool.h:RequestAliasedAllocate, debug 8 \n");
|
||||
/* guoye: end */
|
||||
}
|
||||
else
|
||||
{
|
||||
/* guoye: start */
|
||||
// fprintf(stderr, "\n matrixpool.h:RequestAliasedAllocate, debug 9 \n");
|
||||
/* guoye: end */
|
||||
auto aliasRootMatrixPtr = (shared_ptr<Matrix<ElemType>>*)aliasInfo.pMatrixPtr;
|
||||
/* guoye: start */
|
||||
// fprintf(stderr, "\n matrixpool.h:RequestAliasedAllocate, debug 10 \n");
|
||||
/* guoye: end */
|
||||
*pMatrixPtr = *aliasRootMatrixPtr;
|
||||
GetMemInfo<ElemType>(aliasRootMatrixPtr)->pMatrixPtrs.push_back(pMatrixPtr);
|
||||
/* guoye: start */
|
||||
// fprintf(stderr, "\n matrixpool.h:RequestAliasedAllocate, debug 11 \n");
|
||||
/* guoye: end */
|
||||
}
|
||||
/* guoye: start */
|
||||
// fprintf(stderr, "\n matrixpool.h:RequestAliasedAllocate, debug 12 \n");
|
||||
/* guoye: end */
|
||||
}
|
||||
|
||||
private:
|
||||
|
|
|
@ -247,8 +247,17 @@ public:
|
|||
// request matrices that are needed for gradient computation
|
||||
virtual void RequestMatricesBeforeBackprop(MatrixPool& matrixPool)
|
||||
{
|
||||
/* guoye: start */
|
||||
// fprintf(stderr, "\n nonlinearitynodes.h: RequestMatricesBeforeBackprop: debug 7 \n");
|
||||
/* guoye: end */
|
||||
Base::RequestMatricesBeforeBackprop(matrixPool);
|
||||
/* guoye: start */
|
||||
// fprintf(stderr, "\n nonlinearitynodes.h: RequestMatricesBeforeBackprop: debug 8 \n");
|
||||
/* guoye: end */
|
||||
RequestMatrixFromPool(m_gradientTemp, matrixPool);
|
||||
/* guoye: start */
|
||||
// fprintf(stderr, "\n nonlinearitynodes.h: RequestMatricesBeforeBackprop: debug 9 \n");
|
||||
/* guoye: end */
|
||||
}
|
||||
|
||||
// release gradient and temp matrices that no longer needed after all the children's gradients are computed.
|
||||
|
@ -318,8 +327,17 @@ public:
|
|||
// request matrices that are needed for gradient computation
|
||||
virtual void RequestMatricesBeforeBackprop(MatrixPool& matrixPool)
|
||||
{
|
||||
/* guoye: start */
|
||||
// fprintf(stderr, "\n nonlinearitynodes.h: RequestMatricesBeforeBackprop: debug 4 \n");
|
||||
/* guoye: end */
|
||||
Base::RequestMatricesBeforeBackprop(matrixPool);
|
||||
/* guoye: start */
|
||||
// fprintf(stderr, "\n nonlinearitynodes.h: RequestMatricesBeforeBackprop: debug 5 \n");
|
||||
/* guoye: end */
|
||||
RequestMatrixFromPool(m_diff, matrixPool);
|
||||
/* guoye: start */
|
||||
// fprintf(stderr, "\n nonlinearitynodes.h: RequestMatricesBeforeBackprop: debug 6 \n");
|
||||
/* guoye: end */
|
||||
}
|
||||
|
||||
// release gradient and temp matrices that no longer needed after all the children's gradients are computed.
|
||||
|
@ -385,8 +403,17 @@ public:
|
|||
// request matrices that are needed for gradient computation
|
||||
virtual void RequestMatricesBeforeBackprop(MatrixPool& matrixPool)
|
||||
{
|
||||
/* guoye: start */
|
||||
// fprintf(stderr, "\n nonlinearitynodes.h: RequestMatricesBeforeBackprop: debug 1 \n");
|
||||
/* guoye: end */
|
||||
Base::RequestMatricesBeforeBackprop(matrixPool);
|
||||
/* guoye: start */
|
||||
// fprintf(stderr, "\n nonlinearitynodes.h: RequestMatricesBeforeBackprop: debug 2 \n");
|
||||
/* guoye: end */
|
||||
RequestMatrixFromPool(m_softmax, matrixPool);
|
||||
/* guoye: start */
|
||||
// fprintf(stderr, "\n nonlinearitynodes.h: RequestMatricesBeforeBackprop: debug 3 \n");
|
||||
/* guoye: end */
|
||||
}
|
||||
|
||||
// release gradient and temp matrices that no longer needed after all the children's gradients are computed.
|
||||
|
|
|
@ -62,9 +62,21 @@ public:
|
|||
// request matrices needed to do node derivative value evaluation
|
||||
virtual void RequestMatricesBeforeBackprop(MatrixPool& matrixPool)
|
||||
{
|
||||
/* guoye: start */
|
||||
// fprintf(stderr, "\n rnnnodes.h: RequestMatricesBeforeBackprop: debug 1 \n");
|
||||
/* guoye: end */
|
||||
Base::RequestMatricesBeforeBackprop(matrixPool);
|
||||
/* guoye: start */
|
||||
// fprintf(stderr, "\n rnnnodes.h: RequestMatricesBeforeBackprop: debug 2 \n");
|
||||
/* guoye: end */
|
||||
RequestMatrixFromPool(m_transposedDInput, matrixPool);
|
||||
/* guoye: start */
|
||||
// fprintf(stderr, "\n rnnnodes.h: RequestMatricesBeforeBackprop: debug 3 \n");
|
||||
/* guoye: end */
|
||||
RequestMatrixFromPool(m_transposedDOutput, matrixPool);
|
||||
/* guoye: start */
|
||||
// fprintf(stderr, "\n rnnnodes.h: RequestMatricesBeforeBackprop: debug 4 \n");
|
||||
/* guoye: end */
|
||||
}
|
||||
|
||||
// release gradient and temp matrices that no longer needed after all the children's gradients are computed.
|
||||
|
|
|
@ -342,8 +342,17 @@ public:
|
|||
|
||||
// Requests the matrices needed for backprop: the base class's matrices first,
// then m_tempGatherIndices with one element per sample, minibatch-scaled
// exactly when input 0 has an MBLayout.
void RequestMatricesBeforeBackprop(MatrixPool& matrixPool) override
{
    Base::RequestMatricesBeforeBackprop(matrixPool);

    RequestMatrixFromPool(m_tempGatherIndices, matrixPool, /*matrixSize*/ 1, /*mbScale*/ InputRef(0).HasMBLayout());
}
|
||||
|
||||
void ReleaseMatricesAfterBackprop(MatrixPool& matrixPool) override
|
||||
|
@ -508,9 +517,21 @@ public:
|
|||
|
||||
// Requests the matrices needed for backprop, in this order:
//  1. the base class's matrices;
//  2. m_tempScatterIndices, one element per sample;
//  3. m_tempUnpackedData, sized by this node's own sample layout.
// Both temporaries are minibatch-scaled exactly when this node has an MBLayout.
void RequestMatricesBeforeBackprop(MatrixPool& matrixPool) override
{
    Base::RequestMatricesBeforeBackprop(matrixPool);

    RequestMatrixFromPool(m_tempScatterIndices, matrixPool, /*matrixSize*/ 1, /*mbScale*/ HasMBLayout());
    RequestMatrixFromPool(m_tempUnpackedData, matrixPool, GetSampleLayout().GetNumElements(), /*mbScale*/ HasMBLayout());
}
|
||||
|
||||
void ReleaseMatricesAfterBackprop(MatrixPool& matrixPool) override
|
||||
|
|
|
@ -128,9 +128,21 @@ public:
|
|||
|
||||
// Requests the matrices needed for backprop, in this order:
//  1. the base class's matrices;
//  2. m_tempScatterIndices, one element per sample, minibatch-scaled when this
//     node has an MBLayout;
//  3. m_tempUnpackedData, sized by input 0's sample layout and minibatch-scaled
//     when input 0 has an MBLayout.
void RequestMatricesBeforeBackprop(MatrixPool& matrixPool) override
{
    Base::RequestMatricesBeforeBackprop(matrixPool);

    RequestMatrixFromPool(m_tempScatterIndices, matrixPool, /*matrixSize*/ 1, /*mbScale*/ HasMBLayout());
    RequestMatrixFromPool(m_tempUnpackedData, matrixPool, InputRef(0).GetSampleLayout().GetNumElements(), /*mbScale*/ InputRef(0).HasMBLayout());
}
|
||||
|
||||
void ReleaseMatricesAfterBackprop(MatrixPool& matrixPool) override
|
||||
|
@ -464,9 +476,21 @@ public:
|
|||
|
||||
// Requests the matrices needed for backprop, in this order:
//  1. the base class's matrices;
//  2. m_tempGatherIndices, one element per sample, minibatch-scaled when this
//     node has an MBLayout;
//  3. m_tempPackedGradientData, sized by input 0's sample layout and
//     minibatch-scaled when input 0 has an MBLayout.
void RequestMatricesBeforeBackprop(MatrixPool& matrixPool) override
{
    Base::RequestMatricesBeforeBackprop(matrixPool);

    RequestMatrixFromPool(m_tempGatherIndices, matrixPool, /*matrixSize*/ 1, /*mbScale*/ HasMBLayout());
    RequestMatrixFromPool(m_tempPackedGradientData, matrixPool, InputRef(0).GetSampleLayout().GetNumElements(), /*mbScale*/ InputRef(0).HasMBLayout());
}
|
||||
|
||||
void ReleaseMatricesAfterBackprop(MatrixPool& matrixPool) override
|
||||
|
|
|
@ -477,12 +477,15 @@ public:
|
|||
{
|
||||
Input(inputIndex)->Gradient().SetValue(0.0f);
|
||||
Value().SetValue(1.0f);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
FrameRange fr(Input(0)->GetMBLayout());
|
||||
BackpropToRight(*m_softmaxOfRight, Input(0)->Value(), Input(inputIndex)->Gradient(),
|
||||
Gradient(), *m_gammaFromLattice, m_fsSmoothingWeight, m_frameDropThreshold);
|
||||
/* guoye: start */
|
||||
// Gradient(), *m_gammaFromLattice, m_fsSmoothingWeight, m_frameDropThreshold);
|
||||
Gradient(), *m_gammaFromLattice, m_fsSmoothingWeight, m_frameDropThreshold, m_MBR);
|
||||
/* guoye: end */
|
||||
MaskMissingColumnsToZero(Input(inputIndex)->Gradient(), Input(0)->GetMBLayout(), fr);
|
||||
}
|
||||
#ifdef _DEBUG
|
||||
|
@ -518,7 +521,10 @@ public:
|
|||
|
||||
static void WINAPI BackpropToRight(const Matrix<ElemType>& softmaxOfRight, const Matrix<ElemType>& inputFunctionValues,
|
||||
Matrix<ElemType>& inputGradientValues, const Matrix<ElemType>& gradientValues,
|
||||
const Matrix<ElemType>& gammaFromLattice, double hsmoothingWeight, double frameDropThresh)
|
||||
/* guoye: start */
|
||||
//const Matrix<ElemType>& gammaFromLattice, double hsmoothingWeight, double frameDropThresh)
|
||||
const Matrix<ElemType>& gammaFromLattice, double hsmoothingWeight, double frameDropThresh, bool MBR)
|
||||
/* guoye: end */
|
||||
{
|
||||
#if DUMPOUTPUT
|
||||
softmaxOfRight.Print("SequenceWithSoftmaxNode Partial-softmaxOfRight");
|
||||
|
@ -526,8 +532,10 @@ public:
|
|||
gradientValues.Print("SequenceWithSoftmaxNode Partial-gradientValues");
|
||||
inputGradientValues.Print("SequenceWithSoftmaxNode Partial-Right-in");
|
||||
#endif
|
||||
|
||||
inputGradientValues.AssignSequenceError((ElemType) hsmoothingWeight, inputFunctionValues, softmaxOfRight, gammaFromLattice, gradientValues.Get00Element());
|
||||
/* guoye: start */
|
||||
// inputGradientValues.AssignSequenceError((ElemType) hsmoothingWeight, inputFunctionValues, softmaxOfRight, gammaFromLattice, gradientValues.Get00Element());
|
||||
inputGradientValues.AssignSequenceError((ElemType)hsmoothingWeight, inputFunctionValues, softmaxOfRight, gammaFromLattice, gradientValues.Get00Element(), MBR);
|
||||
/* guoye: end */
|
||||
inputGradientValues.DropFrame(inputFunctionValues, gammaFromLattice, (ElemType) frameDropThresh);
|
||||
#if DUMPOUTPUT
|
||||
inputGradientValues.Print("SequenceWithSoftmaxNode Partial-Right");
|
||||
|
@ -561,9 +569,20 @@ public:
|
|||
|
||||
m_gammaFromLattice->SwitchToMatrixType(m_softmaxOfRight->GetMatrixType(), m_softmaxOfRight->GetFormat(), false);
|
||||
m_gammaFromLattice->Resize(*m_softmaxOfRight);
|
||||
// guoye: start
|
||||
// fprintf(stderr, "guoye debug: calgammaformb, m_m_nws.size() = %d \n", int(m_nws.size()));
|
||||
for (size_t i = 0; i < m_nws.size(); i++)
|
||||
{
|
||||
// fprintf(stderr, "guoye debug: calgammaformb, i = %d, m_nws[i] = %d \n", int(i), int(m_nws[i]));
|
||||
}
|
||||
// guoye: end
|
||||
|
||||
m_gammaCalculator.calgammaformb(Value(), m_lattices, Input(2)->Value() /*log LLs*/,
|
||||
Input(0)->Value() /*labels*/, *m_gammaFromLattice,
|
||||
m_uids, m_boundaries, Input(1)->GetNumParallelSequences(),
|
||||
/* guoye: start */
|
||||
// m_uids, m_boundaries, Input(1)->GetNumParallelSequences(),
|
||||
m_uids, m_wids, m_nws, m_boundaries, Input(1)->GetNumParallelSequences(),
|
||||
/* guoye: end */
|
||||
Input(0)->GetMBLayout(), m_extraUttMap, m_doReferenceAlignment);
|
||||
|
||||
#if NANCHECK
|
||||
|
@ -635,15 +654,25 @@ public:
|
|||
// TODO: method names should be CamelCase
|
||||
std::vector<shared_ptr<const msra::dbn::latticepair>>* getLatticePtr() { return &m_lattices; }
|
||||
std::vector<size_t>* getuidprt() { return &m_uids; }
|
||||
/* guoye: start */
|
||||
std::vector<size_t>* getwidprt() { return &m_wids; }
|
||||
|
||||
std::vector<short>* getnwprt() { return &m_nws; }
|
||||
|
||||
/* guoye: end */
|
||||
std::vector<size_t>* getboundaryprt() { return &m_boundaries; }
|
||||
std::vector<size_t>* getextrauttmap() { return &m_extraUttMap; }
|
||||
msra::asr::simplesenonehmm* gethmm() { return &m_hmm; }
|
||||
|
||||
void SetSmoothWeight(double fsSmoothingWeight) { m_fsSmoothingWeight = fsSmoothingWeight; }
|
||||
/* guoye : start */
|
||||
void SetMBR(bool MBR) { m_MBR = MBR; }
|
||||
/* guoye : end */
|
||||
void SetFrameDropThresh(double frameDropThresh) { m_frameDropThreshold = frameDropThresh; }
|
||||
void SetReferenceAlign(const bool doreferencealign) { m_doReferenceAlignment = doreferencealign; }
|
||||
|
||||
void SetGammarCalculationParam(const double& amf, const double& lmf, const double& wp, const double& bMMIfactor, const bool& sMBR)
|
||||
void SetGammarCalculationParam(const double& amf, const double& lmf, const double& wp, const double& bMMIfactor, const bool& sMBR, const bool& EMBR, const string& EMBRUnit, const size_t& numPathsEMBR,
|
||||
const bool& enforceValidPathEMBR, const string& getPathMethodEMBR, const string& showWERMode, const bool& excludeSpecialWords, const bool& wordNbest, const bool& useAccInNbest, const float& accWeightInNbest, const size_t& numRawPathsEMBR)
|
||||
{
|
||||
msra::lattices::SeqGammarCalParam param;
|
||||
param.amf = amf;
|
||||
|
@ -651,6 +680,20 @@ public:
|
|||
param.wp = wp;
|
||||
param.bMMIfactor = bMMIfactor;
|
||||
param.sMBRmode = sMBR;
|
||||
|
||||
/* guoye: start */
|
||||
param.EMBR = EMBR;
|
||||
param.EMBRUnit = EMBRUnit;
|
||||
param.numPathsEMBR = numPathsEMBR;
|
||||
param.enforceValidPathEMBR = enforceValidPathEMBR;
|
||||
param.getPathMethodEMBR = getPathMethodEMBR;
|
||||
param.showWERMode = showWERMode;
|
||||
param.excludeSpecialWords = excludeSpecialWords;
|
||||
param.wordNbest = wordNbest;
|
||||
param.useAccInNbest = useAccInNbest;
|
||||
param.accWeightInNbest = accWeightInNbest;
|
||||
param.numRawPathsEMBR = numRawPathsEMBR;
|
||||
/* guoye: end */
|
||||
m_gammaCalculator.SetGammarCalculationParams(param);
|
||||
}
|
||||
|
||||
|
@ -667,6 +710,9 @@ protected:
|
|||
bool m_invalidMinibatch; // for single minibatch
|
||||
double m_frameDropThreshold;
|
||||
double m_fsSmoothingWeight; // frame-sequence criterion interpolation weight --TODO: can this be done outside?
|
||||
/* guoye: start */
|
||||
bool m_MBR;
|
||||
/* guoye: end */
|
||||
double m_seqGammarAMF;
|
||||
double m_seqGammarLMF;
|
||||
double m_seqGammarWP;
|
||||
|
@ -678,6 +724,11 @@ protected:
|
|||
msra::lattices::GammaCalculation<ElemType> m_gammaCalculator;
|
||||
bool m_gammaCalcInitialized;
|
||||
std::vector<size_t> m_uids;
|
||||
/* guoye: start */
|
||||
std::vector<size_t> m_wids;
|
||||
|
||||
std::vector<short> m_nws;
|
||||
/* guoye: end */
|
||||
std::vector<size_t> m_boundaries;
|
||||
std::vector<size_t> m_extraUttMap;
|
||||
|
||||
|
@ -806,7 +857,9 @@ public:
|
|||
auto& currentLatticeSeq = latticeMBLayout->FindSequence(currentLabelSeq.seqId);
|
||||
std::shared_ptr<msra::dbn::latticepair> latticePair(new msra::dbn::latticepair);
|
||||
const char* buffer = bufferStart + latticeMBNumTimeSteps * sizeof(float) * currentLatticeSeq.s + currentLatticeSeq.tBegin;
|
||||
latticePair->second.ReadFromBuffer(buffer, m_idmap, m_idmap.back());
|
||||
|
||||
|
||||
latticePair->second.ReadFromBuffer(buffer, m_idmap, m_idmap.back(), specialwordids());
|
||||
assert((currentLabelSeq.tEnd - currentLabelSeq.tBegin) == latticePair->second.info.numframes);
|
||||
// The size of the vector is small -- the number of sequences in the minibatch.
|
||||
// Iteration likely will be faster than the overhead with unordered_map
|
||||
|
|
|
@ -351,8 +351,17 @@ public:
|
|||
// request matrices that are needed for gradient computation
|
||||
virtual void RequestMatricesBeforeBackprop(MatrixPool& matrixPool)
|
||||
{
|
||||
/* guoye: start */
|
||||
// fprintf(stderr, "\n TrainingNodes.h: RequestMatricesBeforeBackprop: debug 6 \n");
|
||||
/* guoye: end */
|
||||
Base::RequestMatricesBeforeBackprop(matrixPool);
|
||||
/* guoye: start */
|
||||
// fprintf(stderr, "\n TrainingNodes.h: RequestMatricesBeforeBackprop: debug 7 \n");
|
||||
/* guoye: end */
|
||||
RequestMatrixFromPool(m_leftDivRight, matrixPool);
|
||||
/* guoye: start */
|
||||
// fprintf(stderr, "\n TrainingNodes.h: RequestMatricesBeforeBackprop: debug 8 \n");
|
||||
/* guoye: end */
|
||||
}
|
||||
|
||||
// release gradient and temp matrices that no longer needed after all the children's gradients are computed.
|
||||
|
@ -444,8 +453,17 @@ public:
|
|||
// request matrices that are needed for gradient computation
|
||||
// Called before backprop: lets the base class request its matrices first,
// then requests m_gradientOfL1Norm from the same pool.
virtual void RequestMatricesBeforeBackprop(MatrixPool& matrixPool)
{
    Base::RequestMatricesBeforeBackprop(matrixPool);
    RequestMatrixFromPool(m_gradientOfL1Norm, matrixPool);
}
|
||||
|
||||
// release gradient and temp matrices that are no longer needed after all the children's gradients are computed.
|
||||
|
@ -2925,7 +2943,13 @@ public:
|
|||
|
||||
void RequestMatricesBeforeBackprop(MatrixPool& matrixPool) override
|
||||
{
|
||||
/* guoye: start */
|
||||
// fprintf(stderr, "\n TrainingNodes.h: RequestMatricesBeforeBackprop: debug 1 \n");
|
||||
/* guoye: end */
|
||||
Base::RequestMatricesBeforeBackprop(matrixPool);
|
||||
/* guoye: start */
|
||||
// fprintf(stderr, "\n TrainingNodes.h: RequestMatricesBeforeBackprop: debug 2 \n");
|
||||
/* guoye: end */
|
||||
RequestMatrixFromPool(m_dDataDummy, matrixPool);
|
||||
this->template TypedRequestMatrixFromPool<StatType>(m_dScale, matrixPool);
|
||||
this->template TypedRequestMatrixFromPool<StatType>(m_dBias, matrixPool);
|
||||
|
|
Двоичные данные
Source/Extensibility/EvalWrapper/EvalWrapperKeyPair.snk
Двоичные данные
Source/Extensibility/EvalWrapper/EvalWrapperKeyPair.snk
Двоичный файл не отображается.
|
@ -4675,7 +4675,10 @@ GPUMatrix<ElemType>& GPUMatrix<ElemType>::DropFrame(const GPUMatrix<ElemType>& l
|
|||
|
||||
template <class ElemType>
|
||||
GPUMatrix<ElemType>& GPUMatrix<ElemType>::AssignSequenceError(const ElemType hsmoothingWeight, const GPUMatrix<ElemType>& label,
|
||||
const GPUMatrix<ElemType>& dnnoutput, const GPUMatrix<ElemType>& gamma, ElemType alpha)
|
||||
/* guoye: start */
|
||||
// const GPUMatrix<ElemType>& dnnoutput, const GPUMatrix<ElemType>& gamma, ElemType alpha)
|
||||
const GPUMatrix<ElemType>& dnnoutput, const GPUMatrix<ElemType>& gamma, ElemType alpha, bool MBR)
|
||||
/* guoye: end */
|
||||
{
|
||||
if (IsEmpty())
|
||||
LogicError("AssignSequenceError: Matrix is empty.");
|
||||
|
@ -4685,7 +4688,10 @@ GPUMatrix<ElemType>& GPUMatrix<ElemType>::AssignSequenceError(const ElemType hsm
|
|||
SyncGuard syncGuard;
|
||||
long N = (LONG64) label.GetNumElements();
|
||||
int blocksPerGrid = (int) ceil(1.0 * N / GridDim::maxThreadsPerBlock);
|
||||
_AssignSequenceError<<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, t_stream>>>(hsmoothingWeight, Data(), label.Data(), dnnoutput.Data(), gamma.Data(), alpha, N);
|
||||
/* guoye: start */
|
||||
//_AssignSequenceError<<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, t_stream>>>(hsmoothingWeight, Data(), label.Data(), dnnoutput.Data(), gamma.Data(), alpha, N);
|
||||
_AssignSequenceError << <blocksPerGrid, GridDim::maxThreadsPerBlock, 0, t_stream >> >(hsmoothingWeight, Data(), label.Data(), dnnoutput.Data(), gamma.Data(), alpha, N, MBR);
|
||||
/* guoye: end */
|
||||
return *this;
|
||||
}
|
||||
|
||||
|
|
|
@ -370,8 +370,10 @@ public:
|
|||
|
||||
// sequence training
|
||||
GPUMatrix<ElemType>& DropFrame(const GPUMatrix<ElemType>& label, const GPUMatrix<ElemType>& gamma, const ElemType& threshhold);
|
||||
GPUMatrix<ElemType>& AssignSequenceError(const ElemType hsmoothingWeight, const GPUMatrix<ElemType>& label, const GPUMatrix<ElemType>& dnnoutput, const GPUMatrix<ElemType>& gamma, ElemType alpha);
|
||||
|
||||
/* guoye: start */
|
||||
//GPUMatrix<ElemType>& AssignSequenceError(const ElemType hsmoothingWeight, const GPUMatrix<ElemType>& label, const GPUMatrix<ElemType>& dnnoutput, const GPUMatrix<ElemType>& gamma, ElemType alpha);
|
||||
GPUMatrix<ElemType>& AssignSequenceError(const ElemType hsmoothingWeight, const GPUMatrix<ElemType>& label, const GPUMatrix<ElemType>& dnnoutput, const GPUMatrix<ElemType>& gamma, ElemType alpha, bool MBR);
|
||||
/* guoye: end */
|
||||
GPUMatrix<ElemType>& AssignCTCScore(const GPUMatrix<ElemType>& prob, GPUMatrix<ElemType>& alpha, GPUMatrix<ElemType>& beta,
|
||||
const GPUMatrix<ElemType> phoneSeq, const GPUMatrix<ElemType> phoneBoundary, GPUMatrix<ElemType> & totalScore, const vector<size_t>& uttMap, const vector<size_t> & uttBeginFrame, const vector<size_t> & uttFrameNum,
|
||||
const vector<size_t> & uttPhoneNum, const size_t samplesInRecurrentStep, const size_t maxFrameNum, const size_t blankTokenId, const int delayConstraint, const bool isColWise);
|
||||
|
|
|
@ -5256,13 +5256,24 @@ __global__ void _DropFrame(
|
|||
|
||||
template <class ElemType>
|
||||
__global__ void _AssignSequenceError(const ElemType hsmoothingWeight, ElemType* error, const ElemType* label,
|
||||
const ElemType* dnnoutput, const ElemType* gamma, ElemType alpha, const long N)
|
||||
/* guoye: start */
|
||||
// const ElemType* dnnoutput, const ElemType* gamma, ElemType alpha, const long N)
|
||||
const ElemType* dnnoutput, const ElemType* gamma, ElemType alpha, const long N, bool MBR)
|
||||
/* guoye: end */
|
||||
{
|
||||
typedef typename TypeSelector<ElemType>::comp_t comp_t;
|
||||
int id = blockDim.x * blockIdx.x + threadIdx.x;
|
||||
if (id >= N)
|
||||
return;
|
||||
error[id] = (comp_t)error[id] - (comp_t)alpha * ((comp_t)label[id] - (1.0 - (comp_t)hsmoothingWeight) * (comp_t)dnnoutput[id] - (comp_t)hsmoothingWeight * (comp_t)gamma[id]);
|
||||
/* guoye: start */
|
||||
// error[id] -= alpha * (label[id] - (1.0 - hsmoothingWeight) * dnnoutput[id] - hsmoothingWeight * gamma[id]);
|
||||
if(!MBR)
|
||||
error[id] -= alpha * (label[id] - (1.0 - hsmoothingWeight) * dnnoutput[id] - hsmoothingWeight * gamma[id]);
|
||||
else
|
||||
error[id] -= alpha * ( (1.0 - hsmoothingWeight) * (label[id] - dnnoutput[id]) + hsmoothingWeight * gamma[id]);
|
||||
|
||||
/* guoye: end */
|
||||
|
||||
// change to ce
|
||||
// error[id] -= alpha * (label[id] - dnnoutput[id] );
|
||||
}
|
||||
|
|
|
@ -61,7 +61,7 @@
|
|||
<ClCompile>
|
||||
<AdditionalIncludeDirectories>$(MathIncludePath);$(BOOST_INCLUDE_PATH);$(SolutionDir)Source\Common\include;%(AdditionalIncludeDirectories)</AdditionalIncludeDirectories>
|
||||
<DisableSpecificWarnings>4819</DisableSpecificWarnings>
|
||||
<AdditionalOptions>/d2Zi+ /bigobj %(AdditionalOptions)</AdditionalOptions>
|
||||
<AdditionalOptions>/d2Zi+ /bigobj /FS %(AdditionalOptions)</AdditionalOptions>
|
||||
<TreatWarningAsError>true</TreatWarningAsError>
|
||||
<WarningLevel>Level4</WarningLevel>
|
||||
<SDLCheck>true</SDLCheck>
|
||||
|
|
|
@ -99,7 +99,7 @@ xcopy /D /Y "$(CuDnnDll)" "$(OutputPath)"
|
|||
<ClCompile>
|
||||
<EnableParallelCodeGeneration>true</EnableParallelCodeGeneration>
|
||||
<FloatingPointExceptions>false</FloatingPointExceptions>
|
||||
<AdditionalOptions>/d2Zi+ %(AdditionalOptions)</AdditionalOptions>
|
||||
<!-- Add /FS while keeping %(AdditionalOptions), so previously defined/inherited options are not discarded. -->
<AdditionalOptions>/d2Zi+ /FS %(AdditionalOptions)</AdditionalOptions>
|
||||
</ClCompile>
|
||||
<CudaCompile>
|
||||
<HostDebugInfo>false</HostDebugInfo>
|
||||
|
|
|
@ -6098,7 +6098,10 @@ Matrix<ElemType>& Matrix<ElemType>::DropFrame(const Matrix<ElemType>& label, con
|
|||
/// <param name="c">Resulting matrix, user is responsible for allocating this</param>
|
||||
template <class ElemType>
|
||||
Matrix<ElemType>& Matrix<ElemType>::AssignSequenceError(const ElemType hsmoothingWeight, const Matrix<ElemType>& label,
|
||||
const Matrix<ElemType>& dnnoutput, const Matrix<ElemType>& gamma, ElemType alpha)
|
||||
/* guoye: start */
|
||||
// const Matrix<ElemType>& dnnoutput, const Matrix<ElemType>& gamma, ElemType alpha)
|
||||
const Matrix<ElemType>& dnnoutput, const Matrix<ElemType>& gamma, ElemType alpha, bool MBR)
|
||||
/* guoye: end */
|
||||
{
|
||||
DecideAndMoveToRightDevice(label, dnnoutput, gamma);
|
||||
|
||||
|
@ -6106,11 +6109,16 @@ Matrix<ElemType>& Matrix<ElemType>::AssignSequenceError(const ElemType hsmoothin
|
|||
NOT_IMPLEMENTED;
|
||||
|
||||
SwitchToMatrixType(label.GetMatrixType(), label.GetFormat(), false);
|
||||
|
||||
|
||||
|
||||
DISPATCH_MATRIX_ON_FLAG(this,
|
||||
this,
|
||||
m_CPUMatrix->AssignSequenceError(hsmoothingWeight, *label.m_CPUMatrix, *dnnoutput.m_CPUMatrix, *gamma.m_CPUMatrix, alpha),
|
||||
m_GPUMatrix->AssignSequenceError(hsmoothingWeight, *label.m_GPUMatrix, *dnnoutput.m_GPUMatrix, *gamma.m_GPUMatrix, alpha),
|
||||
/* guoye: start */
|
||||
// m_GPUMatrix->AssignSequenceError(hsmoothingWeight, *label.m_GPUMatrix, *dnnoutput.m_GPUMatrix, *gamma.m_GPUMatrix, alpha),
|
||||
m_GPUMatrix->AssignSequenceError(hsmoothingWeight, *label.m_GPUMatrix, *dnnoutput.m_GPUMatrix, *gamma.m_GPUMatrix, alpha, MBR),
|
||||
/* guoye: end */
|
||||
NOT_IMPLEMENTED,
|
||||
NOT_IMPLEMENTED);
|
||||
return *this;
|
||||
|
|
|
@ -402,8 +402,10 @@ public:
|
|||
|
||||
// sequence training
|
||||
Matrix<ElemType>& DropFrame(const Matrix<ElemType>& label, const Matrix<ElemType>& gamma, const ElemType& threshhold);
|
||||
Matrix<ElemType>& AssignSequenceError(const ElemType hsmoothingWeight, const Matrix<ElemType>& label, const Matrix<ElemType>& dnnoutput, const Matrix<ElemType>& gamma, ElemType alpha);
|
||||
|
||||
/* guoye: start */
|
||||
// Matrix<ElemType>& AssignSequenceError(const ElemType hsmoothingWeight, const Matrix<ElemType>& label, const Matrix<ElemType>& dnnoutput, const Matrix<ElemType>& gamma, ElemType alpha);
|
||||
Matrix<ElemType>& AssignSequenceError(const ElemType hsmoothingWeight, const Matrix<ElemType>& label, const Matrix<ElemType>& dnnoutput, const Matrix<ElemType>& gamma, ElemType alpha, bool MBR);
|
||||
/* guoye: end */
|
||||
Matrix<ElemType>& AssignCTCScore(const Matrix<ElemType>& prob, Matrix<ElemType>& alpha, Matrix<ElemType>& beta, const Matrix<ElemType>& phoneSeq, const Matrix<ElemType>& phoneBound, Matrix<ElemType>& totalScore,
|
||||
const vector<size_t> & extraUttMap, const vector<size_t> & uttBeginFrame, const vector<size_t> & uttFrameNum, const vector<size_t> & uttPhoneNum, const size_t samplesInRecurrentStep,
|
||||
const size_t mbSize, const size_t blankTokenId, const int delayConstraint, const bool isColWise);
|
||||
|
|
|
@ -162,6 +162,22 @@ private:
|
|||
dynamic_cast<vectorbaseimpl<doublevector, vectorref<double>> &>(Eframescorrectbuf),
|
||||
logEframescorrecttotal, totalfwscore);
|
||||
}
|
||||
/* guoye: start */
|
||||
void backwardlatticeEMBR(const size_t *batchsizebackward, const size_t numlaunchbackward,
|
||||
const floatvector &edgeacscores, const edgeinfowithscoresvector &edges,
|
||||
const nodeinfovector &nodes, doublevector &edgelogbetas, doublevector &logbetas,
|
||||
const float lmf, const float wp, const float amf, double &totalbwscore)
|
||||
{
|
||||
ondevice no(deviceid);
|
||||
latticefunctionsops::backwardlatticeEMBR(batchsizebackward, numlaunchbackward,
|
||||
dynamic_cast<const vectorbaseimpl<floatvector, vectorref<float>> &>(edgeacscores),
|
||||
dynamic_cast<const vectorbaseimpl<edgeinfowithscoresvector, vectorref<msra::lattices::edgeinfowithscores>> &>(edges),
|
||||
dynamic_cast<const vectorbaseimpl<nodeinfovector, vectorref<msra::lattices::nodeinfo>> &>(nodes),
|
||||
dynamic_cast<vectorbaseimpl<doublevector, vectorref<double>> &>(edgelogbetas),
|
||||
dynamic_cast<vectorbaseimpl<doublevector, vectorref<double>> &>(logbetas),
|
||||
lmf, wp, amf, totalbwscore);
|
||||
}
|
||||
/* guoye: end */
|
||||
|
||||
void sMBRerrorsignal(const ushortvector &alignstateids,
|
||||
const uintvector &alignoffsets,
|
||||
|
@ -183,6 +199,25 @@ private:
|
|||
logEframescorrecttotal, dengammasMatrixRef, dengammasbufMatrixRef);
|
||||
}
|
||||
|
||||
/* guoye: start */
|
||||
void EMBRerrorsignal(const ushortvector &alignstateids,
|
||||
const uintvector &alignoffsets,
|
||||
const edgeinfowithscoresvector &edges, const nodeinfovector &nodes,
|
||||
const doublevector &edgeweights,
|
||||
Microsoft::MSR::CNTK::Matrix<float> &dengammas)
|
||||
{
|
||||
ondevice no(deviceid);
|
||||
|
||||
matrixref<float> dengammasMatrixRef = tomatrixref(dengammas);
|
||||
|
||||
latticefunctionsops::EMBRerrorsignal(dynamic_cast<const vectorbaseimpl<ushortvector, vectorref<unsigned short>> &>(alignstateids),
|
||||
dynamic_cast<const vectorbaseimpl<uintvector, vectorref<unsigned int>> &>(alignoffsets),
|
||||
dynamic_cast<const vectorbaseimpl<edgeinfowithscoresvector, vectorref<msra::lattices::edgeinfowithscores>> &>(edges),
|
||||
dynamic_cast<const vectorbaseimpl<nodeinfovector, vectorref<msra::lattices::nodeinfo>> &>(nodes),
|
||||
dynamic_cast<const vectorbaseimpl<doublevector, vectorref<double>> &>(edgeweights),
|
||||
dengammasMatrixRef);
|
||||
}
|
||||
/* guoye: end */
|
||||
void mmierrorsignal(const ushortvector &alignstateids, const uintvector &alignoffsets,
|
||||
const edgeinfowithscoresvector &edges, const nodeinfovector &nodes,
|
||||
const doublevector &logpps, Microsoft::MSR::CNTK::Matrix<float> &dengammas)
|
||||
|
|
|
@ -99,6 +99,19 @@ struct latticefunctions : public vectorbase<msra::lattices::empty>
|
|||
doublevector& logaccalphas, doublevector& logaccbetas,
|
||||
doublevector& logframescorrectedge, doublevector& logEframescorrect,
|
||||
doublevector& Eframescorrectbuf, double& logEframescorrecttotal, double& totalfwscore) = 0;
|
||||
|
||||
/* guoye: start */
|
||||
virtual void backwardlatticeEMBR(const size_t* batchsizebackward, const size_t numlaunchbackward,
|
||||
const floatvector& edgeacscores, const edgeinfowithscoresvector& edges,
|
||||
const nodeinfovector& nodes, doublevector& edgelogbetas, doublevector& logbetas,
|
||||
const float lmf, const float wp, const float amf, double& totalbwscore) = 0;
|
||||
|
||||
virtual void EMBRerrorsignal(const ushortvector& alignstateids, const uintvector& alignoffsets,
|
||||
const edgeinfowithscoresvector& edges, const nodeinfovector& nodes,
|
||||
const doublevector& edgeweights, Microsoft::MSR::CNTK::Matrix<float>& dengammas) = 0;
|
||||
/* guoye: end */
|
||||
|
||||
|
||||
virtual void sMBRerrorsignal(const ushortvector& alignstateids, const uintvector& alignoffsets,
|
||||
const edgeinfowithscoresvector& edges, const nodeinfovector& nodes,
|
||||
const doublevector& logpps, const float amf, const doublevector& logEframescorrect,
|
||||
|
|
|
@ -226,7 +226,25 @@ __global__ void backwardlatticej(const size_t batchsize, const size_t startindex
|
|||
logEframescorrect, logaccbetas);
|
||||
}
|
||||
}
|
||||
/* guoye: start */
|
||||
// Kernel: one thread per edge of the current batch. Thread number j within the launch
// processes edge (j + startindex); threads with j >= batchsize do nothing.
__global__ void backwardlatticejEMBR(const size_t batchsize, const size_t startindex, const vectorref<float> edgeacscores,
                                     vectorref<msra::lattices::edgeinfowithscores> edges, vectorref<msra::lattices::nodeinfo> nodes,
                                     vectorref<double> edgelogbetas, vectorref<double> logbetas,
                                     float lmf, float wp, float amf)
{
    const size_t threadsperblock = blockDim.x * blockDim.y; // total #threads in a block
    const size_t indexinblock = threadIdx.x + threadIdx.y * blockDim.x;
    const size_t j = indexinblock + blockIdx.x * threadsperblock;

    // note: exiting early like this will cause issues if we ever use __syncthreads()
    if (j >= batchsize)
        return;

    msra::lattices::latticefunctionskernels::backwardlatticejEMBR(j + startindex, edgeacscores,
                                                                  edges, nodes, edgelogbetas, logbetas,
                                                                  lmf, wp, amf);
}
|
||||
/* guoye: end */
|
||||
void latticefunctionsops::forwardbackwardlattice(const size_t *batchsizeforward, const size_t *batchsizebackward,
|
||||
const size_t numlaunchforward, const size_t numlaunchbackward,
|
||||
const size_t spalignunitid, const size_t silalignunitid,
|
||||
|
@ -326,6 +344,46 @@ void latticefunctionsops::forwardbackwardlattice(const size_t *batchsizeforward,
|
|||
}
|
||||
}
|
||||
|
||||
/* guoye: start */
|
||||
void latticefunctionsops::backwardlatticeEMBR( const size_t *batchsizebackward, const size_t numlaunchbackward,
|
||||
const vectorref<float> &edgeacscores,
|
||||
const vectorref<msra::lattices::edgeinfowithscores> &edges,
|
||||
const vectorref<msra::lattices::nodeinfo> &nodes, vectorref<double> &edgelogbetas, vectorref<double> &logbetas,
|
||||
const float lmf, const float wp, const float amf, double &totalbwscore) const
|
||||
{
|
||||
// initialize log{,acc}(alhas/betas)
|
||||
dim3 t(32, 8);
|
||||
const size_t tpb = t.x * t.y;
|
||||
dim3 b((unsigned int)((logbetas.size() + tpb - 1) / tpb));
|
||||
|
||||
// TODO: is this really efficient? One thread per value?
|
||||
setvaluej << <b, t, 0, GetCurrentStream() >> >(logbetas, LOGZERO, logbetas.size());
|
||||
checklaunch("setvaluej");
|
||||
|
||||
// set initial tokens to probability 1 (0 in log)
|
||||
double log1 = 0.0;
|
||||
memcpy(logbetas.get(), nodes.size() - 1, &log1, 1);
|
||||
|
||||
|
||||
// backward pass
|
||||
size_t startindex = 0;
|
||||
startindex = edges.size();
|
||||
for (size_t i = 0; i < numlaunchbackward; i++)
|
||||
{
|
||||
dim3 b2((unsigned int)((batchsizebackward[i] + tpb - 1) / tpb));
|
||||
backwardlatticejEMBR << <b2, t, 0, GetCurrentStream() >> >(batchsizebackward[i], startindex - batchsizebackward[i],
|
||||
edgeacscores, edges, nodes, edgelogbetas, logbetas,
|
||||
lmf, wp, amf);
|
||||
|
||||
|
||||
checklaunch("edgealignment");
|
||||
startindex -= batchsizebackward[i];
|
||||
}
|
||||
memcpy<double>(&totalbwscore, logbetas.get(), 0, 1);
|
||||
|
||||
}
|
||||
|
||||
/* guoye: end */
|
||||
// -----------------------------------------------------------------------
|
||||
// sMBRerrorsignal -- accumulate difference of logEframescorrect and logEframescorrecttotal into errorsignal
|
||||
// -----------------------------------------------------------------------
|
||||
|
@ -342,6 +400,22 @@ __global__ void sMBRerrorsignalj(const vectorref<unsigned short> alignstateids,
|
|||
msra::lattices::latticefunctionskernels::sMBRerrorsignalj(j, alignstateids, alignoffsets, edges, nodes, logpps, amf, logEframescorrect, logEframescorrecttotal, errorsignal, errorsignalneg);
|
||||
}
|
||||
}
|
||||
/* guoye: start */
|
||||
|
||||
// Kernel: one thread per lattice edge. The edge index j is obtained from the thread/block
// coordinates through latticefunctionskernels::shuffle(); threads whose j falls outside
// edges.size() do nothing.
__global__ void EMBRerrorsignalj(const vectorref<unsigned short> alignstateids, const vectorref<unsigned int> alignoffsets,
                                 const vectorref<msra::lattices::edgeinfowithscores> edges, const vectorref<msra::lattices::nodeinfo> nodes,
                                 vectorref<double> edgeweights,
                                 matrixref<float> errorsignal)
{
    const size_t shufflemode = 1; // [v-hansu] this gives us about 100% speed up than shufflemode = 0 (no shuffle)
    const size_t j = msra::lattices::latticefunctionskernels::shuffle(threadIdx.x, blockDim.x, threadIdx.y, blockDim.y, blockIdx.x, gridDim.x, shufflemode);

    // note: exiting early like this will cause issues if we ever use __syncthreads()
    if (j >= edges.size())
        return;

    msra::lattices::latticefunctionskernels::EMBRerrorsignalj(j, alignstateids, alignoffsets, edges, nodes, edgeweights, errorsignal);
}
|
||||
|
||||
/* guoye: end */
|
||||
|
||||
// -----------------------------------------------------------------------
|
||||
// stateposteriors --accumulate a per-edge quantity into the states that the edge is aligned with
|
||||
|
@ -433,6 +507,28 @@ void latticefunctionsops::sMBRerrorsignal(const vectorref<unsigned short> &align
|
|||
#endif
|
||||
}
|
||||
|
||||
/* guoye: start */
|
||||
void latticefunctionsops::EMBRerrorsignal(const vectorref<unsigned short> &alignstateids, const vectorref<unsigned int> &alignoffsets,
|
||||
const vectorref<msra::lattices::edgeinfowithscores> &edges, const vectorref<msra::lattices::nodeinfo> &nodes,
|
||||
const vectorref<double> &edgeweights,
|
||||
matrixref<float> &errorsignal) const
|
||||
{
|
||||
// Layout: each thread block takes 1024 threads; and we have #edges/1024 blocks.
|
||||
// This limits us to 16 million edges. If you need more, please adjust to either use wider thread blocks or a second dimension for the grid. Don't forget to adjust the kernel as well.
|
||||
const size_t numedges = edges.size();
|
||||
dim3 t(32, 8);
|
||||
const size_t tpb = t.x * t.y;
|
||||
dim3 b((unsigned int)((numedges + tpb - 1) / tpb));
|
||||
|
||||
setvaluei << <dim3((((unsigned int)errorsignal.rows()) + 31) / 32), 32, 0, GetCurrentStream() >> >(errorsignal, 0);
|
||||
checklaunch("setvaluei");
|
||||
|
||||
EMBRerrorsignalj << <b, t, 0, GetCurrentStream() >> >(alignstateids, alignoffsets, edges, nodes, edgeweights, errorsignal);
|
||||
checklaunch("EMBRerrorsignal");
|
||||
|
||||
|
||||
}
|
||||
/* guoye: end */
|
||||
void latticefunctionsops::mmierrorsignal(const vectorref<unsigned short> &alignstateids, const vectorref<unsigned int> &alignoffsets,
|
||||
const vectorref<msra::lattices::edgeinfowithscores> &edges, const vectorref<msra::lattices::nodeinfo> &nodes,
|
||||
const vectorref<double> &logpps, matrixref<float> &errorsignal) const
|
||||
|
|
|
@ -53,6 +53,17 @@ protected:
|
|||
vectorref<double>& logframescorrectedge, vectorref<double>& logEframescorrect, vectorref<double>& Eframescorrectbuf,
|
||||
double& logEframescorrecttotal, double& totalfwscore) const;
|
||||
|
||||
/* guoye: start */
|
||||
void backwardlatticeEMBR(const size_t *batchsizebackward, const size_t numlaunchbackward,
|
||||
const vectorref<float> &edgeacscores,
|
||||
const vectorref<msra::lattices::edgeinfowithscores> &edges,
|
||||
const vectorref<msra::lattices::nodeinfo> &nodes, vectorref<double> &edgelogbetas, vectorref<double> &logbetas,
|
||||
const float lmf, const float wp, const float amf, double &totalbwscore) const;
|
||||
void EMBRerrorsignal(const vectorref<unsigned short> &alignstateids, const vectorref<unsigned int> &alignoffsets,
|
||||
const vectorref<msra::lattices::edgeinfowithscores> &edges, const vectorref<msra::lattices::nodeinfo> &nodes,
|
||||
const vectorref<double> &edgeweights,
|
||||
matrixref<float> &errorsignal) const;
|
||||
/* guoye: end */
|
||||
void sMBRerrorsignal(const vectorref<unsigned short>& alignstateids, const vectorref<unsigned int>& alignoffsets,
|
||||
const vectorref<msra::lattices::edgeinfowithscores>& edges, const vectorref<msra::lattices::nodeinfo>& nodes,
|
||||
const vectorref<double>& logpps, const float amf, const vectorref<double>& logEframescorrect, const double logEframescorrecttotal,
|
||||
|
|
|
@ -302,6 +302,25 @@ struct latticefunctionskernels
|
|||
// note: critically, ^^ this comparison must compare the bits ('int') instead of the converted float values, since this will fail for NaNs (NaN != NaN is true always)
|
||||
return bitsasfloat(old);
|
||||
}
|
||||
/* guoye: start */
|
||||
|
||||
template <typename FLOAT> // adapted from [http://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#ixzz32EuzZjxV]
|
||||
static __device__ FLOAT atomicAdd(FLOAT *address, FLOAT val) // direct adaptation from NVidia source code
|
||||
{
|
||||
typedef decltype(floatasbits(val)) bitstype;
|
||||
bitstype *address_as_ull = (bitstype *)address;
|
||||
bitstype old = *address_as_ull, assumed;
|
||||
do
|
||||
{
|
||||
assumed = old;
|
||||
FLOAT sum = bitsasfloat(assumed);
|
||||
sum = sum + val;
|
||||
old = atomicCAS(address_as_ull, assumed, floatasbits(sum));
|
||||
} while (assumed != old);
|
||||
// note: critically, ^^ this comparison must copare the bits ('int') instead of the converted float values, since this will fail for NaNs (NaN != NaN is true always)
|
||||
return bitsasfloat(old);
|
||||
}
|
||||
/* guoye: end */
|
||||
#else // this code does not work because (assumed != old) will not compare correctly in case of NaNs
|
||||
// same pattern as atomicAdd(), but performing the log-add operation instead
|
||||
template <typename FLOAT>
|
||||
|
@ -889,6 +908,66 @@ struct latticefunctionskernels
|
|||
logEframescorrect[j] = logEframescorrectj;
|
||||
}
|
||||
|
||||
/* guoye: start */
|
||||
template <typename edgeinforvector, typename nodeinfovector, typename floatvector, typename doublevector>
|
||||
static inline __device__ void backwardlatticejEMBR(size_t j, const floatvector &edgeacscores,
|
||||
const edgeinforvector &edges, const nodeinfovector &nodes, doublevector & edgelogbetas,
|
||||
doublevector &logbetas, float lmf, float wp, float amf)
|
||||
{
|
||||
|
||||
// edge info
|
||||
const edgeinfowithscores &e = edges[j];
|
||||
double edgescore = (e.l * lmf + wp + edgeacscores[j]) / amf;
|
||||
// zhaorui to deal with the abnormal score for sent start.
|
||||
if (e.l < -200.0f)
|
||||
edgescore = (0.0 * lmf + wp + edgeacscores[j]) / amf;
|
||||
|
||||
|
||||
|
||||
#ifdef FORBID_INVALID_SIL_PATHS
|
||||
// original mode
|
||||
const bool forbidinvalidsilpath = (logbetas.size() > nodes.size()); // we prune sil to sil path if alphabetablowup != 1
|
||||
const bool isaddedsil = forbidinvalidsilpath && (e.unused == 1); // HACK: 'unused' indicates artificially added sil/sp edge
|
||||
|
||||
if (!isaddedsil) // original mode
|
||||
#endif
|
||||
{
|
||||
const size_t S = e.S;
|
||||
const size_t E = e.E;
|
||||
|
||||
// backward pass
|
||||
const double inscore = logbetas[E];
|
||||
const double pathscore = inscore + edgescore;
|
||||
edgelogbetas[j] = pathscore;
|
||||
atomicLogAdd(&logbetas[S], pathscore);
|
||||
}
|
||||
|
||||
#ifdef FORBID_INVALID_SIL_PATHS
|
||||
|
||||
// silence edge or second speech edge
|
||||
if ((isaddedsil && e.E != nodes.size() - 1) || (forbidinvalidsilpath && e.S != 0))
|
||||
{
|
||||
const size_t S = (size_t)(!isaddedsil ? e.S + nodes.size() : e.S); // second speech edge comes from special 'silence state' node
|
||||
const size_t E = (size_t)(isaddedsil ? e.E + nodes.size() : e.E); // silence edge goes into special 'silence state' node
|
||||
// remaining lines here are code dup from above, with two changes: logadd2/logEframescorrectj2 instead of logadd/logEframescorrectj
|
||||
|
||||
// backward pass
|
||||
const double inscore = logbetas[E];
|
||||
const double pathscore = inscore + edgescore;
|
||||
edgelogbetas[j] = pathscore;
|
||||
atomicLogAdd(&logbetas[S], pathscore);
|
||||
|
||||
}
|
||||
#else
|
||||
nodes;
|
||||
#endif
|
||||
|
||||
|
||||
}
|
||||
|
||||
/* guoye: end */
|
||||
|
||||
|
||||
template <typename ushortvector, typename uintvector, typename edgeinfowithscoresvector, typename nodeinfovector, typename doublevector, typename matrix>
|
||||
static inline __device__ void sMBRerrorsignalj(size_t j, const ushortvector &alignstateids, const uintvector &alignoffsets,
|
||||
const edgeinfowithscoresvector &edges,
|
||||
|
@ -930,6 +1009,63 @@ struct latticefunctionskernels
|
|||
}
|
||||
}
|
||||
|
||||
/* guoye: start */
|
||||
template <typename ushortvector, typename uintvector, typename edgeinfowithscoresvector, typename nodeinfovector, typename doublevector, typename matrix>
|
||||
static inline __device__ void EMBRerrorsignalj(size_t j, const ushortvector &alignstateids, const uintvector &alignoffsets,
|
||||
const edgeinfowithscoresvector &edges,
|
||||
const nodeinfovector &nodes, const doublevector &edgeweights,
|
||||
matrix &errorsignal)
|
||||
{
|
||||
size_t ts = nodes[edges[j].S].t;
|
||||
size_t te = nodes[edges[j].E].t;
|
||||
if (ts != te)
|
||||
{
|
||||
|
||||
|
||||
const float weight = (float)(edgeweights[j]);
|
||||
size_t offset = alignoffsets[j];
|
||||
|
||||
/*
|
||||
if (weight <= 1)
|
||||
{
|
||||
*/
|
||||
// size_t k = 0;
|
||||
for (size_t t = ts; t < te; t++)
|
||||
{
|
||||
const size_t s = (size_t)alignstateids[t - ts + offset];
|
||||
|
||||
/*
|
||||
errorsignal(0, k) = float(t);
|
||||
k = k + 1;
|
||||
errorsignal(0, k) = float(s);
|
||||
k = k + 1;
|
||||
*/
|
||||
// atomicLogAdd(&errorsignal(s, t), weight);
|
||||
// use atomic function for lock the value
|
||||
atomicAdd(&errorsignal(s, t), weight);
|
||||
//errorsignal(s, t) = errorsignal(s, t) + weight;
|
||||
// errorsignal(s, t) = errorsignal(s, t) + (float)(ts);
|
||||
}
|
||||
// }
|
||||
|
||||
|
||||
/* guoye: start */
|
||||
/*
|
||||
if (weight > 1)
|
||||
{
|
||||
for (size_t t = 63; t < 71; t++)
|
||||
errorsignal(7, t) = errorsignal(7, t) + ts;
|
||||
errorsignal(8, 71) = errorsignal(8, 71) + te;
|
||||
|
||||
}
|
||||
*/
|
||||
/* guoye: end */
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
/* guoye: end */
|
||||
|
||||
// accumulate a per-edge quantity into the states that the edge is aligned with
|
||||
// Use this for MMI passing the edge posteriors logpps[] as logq, or for sMBR passing logEframescorrect[].
|
||||
// j=edge index, alignment in (alignstateids, alignoffsets)
|
||||
|
|
|
@ -26,6 +26,10 @@
|
|||
#include "ScriptableObjects.h"
|
||||
#include "HTKMLFReader.h"
|
||||
#include "TimerUtility.h"
|
||||
/* guoye: start */
|
||||
#include "fileutil.h"
|
||||
#include <string>
|
||||
/* guoye: end */
|
||||
#ifdef LEAKDETECT
|
||||
#include <vld.h> // for memory leak detection
|
||||
#endif
|
||||
|
@ -99,6 +103,33 @@ void HTKMLFReader<ElemType>::InitFromConfig(const ConfigRecordType& readerConfig
|
|||
}
|
||||
}
|
||||
|
||||
/* guoye: start */
|
||||
void readwordidmap(const std::wstring &pathname, std::unordered_map<std::string, int>& wordidmap, int start_id)
|
||||
{
|
||||
std::unordered_map<std::string, int>::iterator mp_itr;
|
||||
auto_file_ptr f(fopenOrDie(pathname, L"rbS"));
|
||||
fprintf(stderr, "readwordidmap: reading %ls \n", pathname.c_str());
|
||||
char buf[1024];
|
||||
char word[1024];
|
||||
int dumid;
|
||||
while (!feof(f))
|
||||
{
|
||||
fgetline(f, buf);
|
||||
if (sscanf(buf, "%s %d", word, &dumid) != 2)
|
||||
{
|
||||
fprintf(stderr, "readwordidmap: reaching the end of line, with content = %s", buf);
|
||||
break;
|
||||
}
|
||||
if (wordidmap.find(std::string(word)) == wordidmap.end())
|
||||
{
|
||||
wordidmap.insert(pair<std::string, int>(string(word),start_id++));
|
||||
}
|
||||
}
|
||||
|
||||
fclose(f);
|
||||
}
|
||||
|
||||
/* guoye: end */
|
||||
// Load all input and output data.
|
||||
// Note that the terms features imply be real-valued quantities and
|
||||
// labels imply categorical quantities, irrespective of whether they
|
||||
|
@ -116,6 +147,9 @@ void HTKMLFReader<ElemType>::PrepareForTrainingOrTesting(const ConfigRecordType&
|
|||
vector<vector<wstring>> infilesmulti;
|
||||
size_t numFiles;
|
||||
wstring unigrampath(L"");
|
||||
/* guoye: start */
|
||||
wstring wordidmappath(L"");
|
||||
/* guoye: end */
|
||||
|
||||
size_t randomize = randomizeAuto;
|
||||
size_t iFeat, iLabel;
|
||||
|
@ -443,19 +477,146 @@ void HTKMLFReader<ElemType>::PrepareForTrainingOrTesting(const ConfigRecordType&
|
|||
if (readerConfig.Exists(L"unigram"))
|
||||
unigrampath = (const wstring&) readerConfig(L"unigram");
|
||||
|
||||
/*guoye: start */
|
||||
if (readerConfig.Exists(L"wordidmap"))
|
||||
wordidmappath = (const wstring&)readerConfig(L"wordidmap");
|
||||
/*guoye: end */
|
||||
|
||||
// load a unigram if needed (this is used for MMI training)
|
||||
msra::lm::CSymbolSet unigramsymbols;
|
||||
/* guoye: start */
|
||||
std::set<int> specialwordids;
|
||||
std::vector<string> specialwords;
|
||||
std::unordered_map<std::string, int> wordidmap;
|
||||
std::unordered_map<std::string, int>::iterator wordidmap_itr;
|
||||
/* guoye: end */
|
||||
|
||||
std::unique_ptr<msra::lm::CMGramLM> unigram;
|
||||
size_t silencewordid = SIZE_MAX;
|
||||
size_t startwordid = SIZE_MAX;
|
||||
size_t endwordid = SIZE_MAX;
|
||||
/* guoye: debug */
|
||||
if (unigrampath != L"")
|
||||
// if(true)
|
||||
{
|
||||
// RuntimeError("should not come here.");
|
||||
/* guoye: start (this code order must be consistent with dbn.exe in main.cpp */
|
||||
|
||||
unigram.reset(new msra::lm::CMGramLM());
|
||||
|
||||
unigramsymbols["!NULL"];
|
||||
unigramsymbols["<s>"];
|
||||
unigramsymbols["</s>"];
|
||||
unigramsymbols["!sent_start"];
|
||||
unigramsymbols["!sent_end"];
|
||||
unigramsymbols["!silence"];
|
||||
|
||||
/* guoye: end */
|
||||
|
||||
unigram->read(unigrampath, unigramsymbols, false /*filterVocabulary--false will build the symbol map*/, 1 /*maxM--unigram only*/);
|
||||
|
||||
/* guoye: end */
|
||||
|
||||
|
||||
silencewordid = unigramsymbols["!silence"]; // give this an id (even if not in the LM vocabulary)
|
||||
startwordid = unigramsymbols["<s>"];
|
||||
endwordid = unigramsymbols["</s>"];
|
||||
|
||||
|
||||
/* guoye: start */
|
||||
|
||||
specialwordids.clear();
|
||||
|
||||
|
||||
|
||||
specialwordids.insert(unigramsymbols["<s>"]);
|
||||
specialwordids.insert(unigramsymbols["</s>"]);
|
||||
specialwordids.insert(unigramsymbols["!NULL"]);
|
||||
specialwordids.insert(unigramsymbols["!sent_start"]);
|
||||
specialwordids.insert(unigramsymbols["!sent_end"]);
|
||||
specialwordids.insert(unigramsymbols["!silence"]);
|
||||
specialwordids.insert(unigramsymbols["[/CNON]"]);
|
||||
specialwordids.insert(unigramsymbols["[/CSPN]"]);
|
||||
specialwordids.insert(unigramsymbols["[/NPS]"]);
|
||||
specialwordids.insert(unigramsymbols["[CNON/]"]);
|
||||
specialwordids.insert(unigramsymbols["[CNON]"]);
|
||||
specialwordids.insert(unigramsymbols["[CSPN]"]);
|
||||
specialwordids.insert(unigramsymbols["[FILL/]"]);
|
||||
specialwordids.insert(unigramsymbols["[NON/]"]);
|
||||
specialwordids.insert(unigramsymbols["[NONNATIVE/]"]);
|
||||
specialwordids.insert(unigramsymbols["[NPS]"]);
|
||||
|
||||
specialwordids.insert(unigramsymbols["[SB/]"]);
|
||||
specialwordids.insert(unigramsymbols["[SBP/]"]);
|
||||
specialwordids.insert(unigramsymbols["[SN/]"]);
|
||||
specialwordids.insert(unigramsymbols["[SPN/]"]);
|
||||
specialwordids.insert(unigramsymbols["[UNKNOWN/]"]);
|
||||
specialwordids.insert(unigramsymbols[".]"]);
|
||||
|
||||
// Also exclude unknown words (id 0xfffff), which are brought into the lattice when the numerator lattice is merged into the denominator lattice.
|
||||
specialwordids.insert(0xfffff);
|
||||
|
||||
/* guoye: end */
|
||||
|
||||
}
|
||||
|
||||
else if (wordidmappath != L"")
|
||||
// if(true)
|
||||
{
|
||||
wordidmap.insert(pair<std::string, int>("!NULL", 0));
|
||||
wordidmap.insert(pair<std::string, int>("<s>", 1));
|
||||
wordidmap.insert(pair<std::string, int>("</s>", 2));
|
||||
wordidmap.insert(pair<std::string, int>("!sent_start", 3));
|
||||
wordidmap.insert(pair<std::string, int>("!sent_end", 4));
|
||||
wordidmap.insert(pair<std::string, int>("!silence", 5));
|
||||
|
||||
silencewordid = 5; // give this an id (even if not in the LM vocabulary)
|
||||
startwordid = 1;
|
||||
endwordid = 2;
|
||||
|
||||
int start_id = 6;
|
||||
readwordidmap(wordidmappath, wordidmap, start_id);
|
||||
|
||||
/* guoye: start */
|
||||
|
||||
specialwordids.clear();
|
||||
specialwords.clear();
|
||||
|
||||
specialwords.push_back("<s>");
|
||||
|
||||
specialwords.push_back("</s>");
|
||||
specialwords.push_back("!NULL");
|
||||
specialwords.push_back("!sent_start");
|
||||
specialwords.push_back("!sent_end");
|
||||
specialwords.push_back("!silence");
|
||||
specialwords.push_back("[/CNON]");
|
||||
specialwords.push_back("[/CSPN]");
|
||||
specialwords.push_back("[/NPS]");
|
||||
specialwords.push_back("[CNON/]");
|
||||
specialwords.push_back("[CNON]");
|
||||
specialwords.push_back("[CSPN]");
|
||||
specialwords.push_back("[FILL/]");
|
||||
specialwords.push_back("[NON/]");
|
||||
specialwords.push_back("[NONNATIVE/]");
|
||||
specialwords.push_back("[NPS]");
|
||||
|
||||
specialwords.push_back("[SB/]");
|
||||
specialwords.push_back("[SBP/]");
|
||||
specialwords.push_back("[SN/]");
|
||||
specialwords.push_back("[SPN/]");
|
||||
specialwords.push_back("[UNKNOWN/]");
|
||||
specialwords.push_back(".]");
|
||||
|
||||
for (size_t i = 0; i < specialwords.size(); i++)
|
||||
{
|
||||
wordidmap_itr = wordidmap.find(specialwords[i]);
|
||||
specialwordids.insert((wordidmap_itr == wordidmap.end()) ? -1 : wordidmap_itr->second);
|
||||
}
|
||||
|
||||
// Also exclude unknown words (id 0xfffff), which are brought into the lattice when the numerator lattice is merged into the denominator lattice.
|
||||
specialwordids.insert(0xfffff);
|
||||
/* guoye: end */
|
||||
|
||||
}
|
||||
|
||||
if (!unigram && latticetocs.second.size() > 0)
|
||||
|
@ -497,19 +658,41 @@ void HTKMLFReader<ElemType>::PrepareForTrainingOrTesting(const ConfigRecordType&
|
|||
|
||||
double htktimetoframe = 100000.0; // default is 10ms
|
||||
// std::vector<msra::asr::htkmlfreader<msra::asr::htkmlfentry,msra::lattices::lattice::htkmlfwordsequence>> labelsmulti;
|
||||
std::vector<std::map<std::wstring, std::vector<msra::asr::htkmlfentry>>> labelsmulti;
|
||||
/* guoye: start */
|
||||
// std::vector<std::map<std::wstring, std::vector<msra::asr::htkmlfentry>>> labelsmulti;
|
||||
std::vector<std::map<std::wstring, std::pair<std::vector<msra::asr::htkmlfentry>, std::vector<unsigned int>>>> labelsmulti;
|
||||
// std::vector<std::map<std::wstring, msra::lattices::lattice::htkmlfwordsequence>> wordlabelsmulti;
|
||||
|
||||
/* debug to clean wordidmap */
|
||||
// wordidmap.clear();
|
||||
/* guoye: end */
|
||||
// std::vector<std::wstring> pagepath;
|
||||
foreach_index (i, mlfpathsmulti)
|
||||
{
|
||||
/* guoye: start */
|
||||
/*
|
||||
const msra::lm::CSymbolSet* wordmap = unigram ? &unigramsymbols : NULL;
|
||||
msra::asr::htkmlfreader<msra::asr::htkmlfentry, msra::lattices::lattice::htkmlfwordsequence>
|
||||
labels(mlfpathsmulti[i], restrictmlftokeys, statelistpaths[i], wordmap, (map<string, size_t>*) NULL, htktimetoframe); // label MLF
|
||||
*/
|
||||
msra::asr::htkmlfreader<msra::asr::htkmlfentry, msra::lattices::lattice::htkmlfwordsequence>
|
||||
// msra::asr::htkmlfreader<msra::asr::htkmlfentry>
|
||||
labels(mlfpathsmulti[i], restrictmlftokeys, statelistpaths[i], wordidmap, htktimetoframe); // label MLF
|
||||
// labels(mlfpathsmulti[i], restrictmlftokeys, statelistpaths[i], wordidmap, (map<string, size_t>*) NULL, htktimetoframe); // label MLF
|
||||
/* guoye: end */
|
||||
// get the temp file name for the page file
|
||||
|
||||
// Make sure 'msra::asr::htkmlfreader' type has a move constructor
|
||||
static_assert(std::is_move_constructible<msra::asr::htkmlfreader<msra::asr::htkmlfentry, msra::lattices::lattice::htkmlfwordsequence>>::value,
|
||||
"Type 'msra::asr::htkmlfreader' should be move constructible!");
|
||||
|
||||
/* guoye: start */
|
||||
// map<wstring, msra::lattices::lattice::htkmlfwordsequence> wordlabels = labels.get_wordlabels();
|
||||
// guoye debug purpose
|
||||
// fprintf(stderr, "debug to set wordlabels to empty");
|
||||
// map<wstring, msra::lattices::lattice::htkmlfwordsequence> wordlabels;
|
||||
// wordlabelsmulti.push_back(std::move(wordlabels));
|
||||
/* guoye: end */
|
||||
labelsmulti.push_back(std::move(labels));
|
||||
}
|
||||
|
||||
|
@ -522,7 +705,11 @@ void HTKMLFReader<ElemType>::PrepareForTrainingOrTesting(const ConfigRecordType&
|
|||
|
||||
// now get the frame source. This has better randomization and doesn't create temp files
|
||||
bool useMersenneTwisterRand = readerConfig(L"useMersenneTwisterRand", false);
|
||||
m_frameSource.reset(new msra::dbn::minibatchutterancesourcemulti(useMersenneTwisterRand, infilesmulti, labelsmulti, m_featDims, m_labelDims,
|
||||
/* guoye: start */
|
||||
// m_frameSource.reset(new msra::dbn::minibatchutterancesourcemulti(useMersenneTwisterRand, infilesmulti, labelsmulti, m_featDims, m_labelDims,
|
||||
// m_frameSource.reset(new msra::dbn::minibatchutterancesourcemulti(useMersenneTwisterRand, infilesmulti, labelsmulti, wordlabelsmulti, specialwordids, m_featDims, m_labelDims,
|
||||
m_frameSource.reset(new msra::dbn::minibatchutterancesourcemulti(useMersenneTwisterRand, infilesmulti, labelsmulti, specialwordids, m_featDims, m_labelDims,
|
||||
/* guoye: end */
|
||||
numContextLeft, numContextRight, randomize,
|
||||
*m_lattices, m_latticeMap, m_frameMode,
|
||||
m_expandToUtt, m_maxUtteranceLength, m_truncated));
|
||||
|
@ -756,6 +943,10 @@ void HTKMLFReader<ElemType>::StartDistributedMinibatchLoop(size_t requestedMBSiz
|
|||
// for the multi-utterance process for lattice and phone boundary
|
||||
m_latticeBufferMultiUtt.assign(m_numSeqsPerMB, nullptr);
|
||||
m_labelsIDBufferMultiUtt.resize(m_numSeqsPerMB);
|
||||
/* guoye: start */
|
||||
m_wlabelsIDBufferMultiUtt.resize(m_numSeqsPerMB);
|
||||
m_nwsBufferMultiUtt.resize(m_numSeqsPerMB);
|
||||
/* guoye: end */
|
||||
m_phoneboundaryIDBufferMultiUtt.resize(m_numSeqsPerMB);
|
||||
|
||||
if (m_frameMode && (m_numSeqsPerMB > 1))
|
||||
|
@ -894,11 +1085,17 @@ void HTKMLFReader<ElemType>::StartMinibatchLoopToWrite(size_t mbSize, size_t /*e
|
|||
|
||||
template <class ElemType>
|
||||
bool HTKMLFReader<ElemType>::GetMinibatch4SE(std::vector<shared_ptr<const msra::dbn::latticepair>>& latticeinput,
|
||||
vector<size_t>& uids, vector<size_t>& boundaries, vector<size_t>& extrauttmap)
|
||||
/* guoye: start */
|
||||
vector<size_t>& uids, vector<size_t>& wids, vector<short>& nws, vector<size_t>& boundaries, vector<size_t>& extrauttmap)
|
||||
// vector<size_t>& uids, vector<size_t>& boundaries, vector<size_t>& extrauttmap)
|
||||
/* guoye: end */
|
||||
{
|
||||
if (m_trainOrTest)
|
||||
{
|
||||
return GetMinibatch4SEToTrainOrTest(latticeinput, uids, boundaries, extrauttmap);
|
||||
/* guoye: start */
|
||||
// return GetMinibatch4SEToTrainOrTest(latticeinput, uids, boundaries, extrauttmap);
|
||||
return GetMinibatch4SEToTrainOrTest(latticeinput, uids, wids, nws, boundaries, extrauttmap);
|
||||
/* guoye: end */
|
||||
}
|
||||
else
|
||||
{
|
||||
|
@ -907,16 +1104,31 @@ bool HTKMLFReader<ElemType>::GetMinibatch4SE(std::vector<shared_ptr<const msra::
|
|||
}
|
||||
template <class ElemType>
|
||||
bool HTKMLFReader<ElemType>::GetMinibatch4SEToTrainOrTest(std::vector<shared_ptr<const msra::dbn::latticepair>>& latticeinput,
|
||||
std::vector<size_t>& uids, std::vector<size_t>& boundaries, std::vector<size_t>& extrauttmap)
|
||||
|
||||
/* guoye: start */
|
||||
std::vector<size_t>& uids, std::vector<size_t>& wids, std::vector<short>& nws, std::vector<size_t>& boundaries, std::vector<size_t>& extrauttmap)
|
||||
// std::vector<size_t>& uids, std::vector<size_t>& boundaries, std::vector<size_t>& extrauttmap)
|
||||
|
||||
/* guoye: end */
|
||||
{
|
||||
latticeinput.clear();
|
||||
uids.clear();
|
||||
/* guoye: start */
|
||||
wids.clear();
|
||||
nws.clear();
|
||||
/* guoye: end */
|
||||
boundaries.clear();
|
||||
extrauttmap.clear();
|
||||
for (size_t i = 0; i < m_extraSeqsPerMB.size(); i++)
|
||||
{
|
||||
latticeinput.push_back(m_extraLatticeBufferMultiUtt[i]);
|
||||
uids.insert(uids.end(), m_extraLabelsIDBufferMultiUtt[i].begin(), m_extraLabelsIDBufferMultiUtt[i].end());
|
||||
/* guoye: start */
|
||||
wids.insert(wids.end(), m_extraWLabelsIDBufferMultiUtt[i].begin(), m_extraWLabelsIDBufferMultiUtt[i].end());
|
||||
|
||||
nws.insert(nws.end(), m_extraNWsBufferMultiUtt[i].begin(), m_extraNWsBufferMultiUtt[i].end());
|
||||
|
||||
/* guoye: end */
|
||||
boundaries.insert(boundaries.end(), m_extraPhoneboundaryIDBufferMultiUtt[i].begin(), m_extraPhoneboundaryIDBufferMultiUtt[i].end());
|
||||
}
|
||||
|
||||
|
@ -984,6 +1196,11 @@ bool HTKMLFReader<ElemType>::GetMinibatchToTrainOrTest(StreamMinibatchInputs& ma
|
|||
m_extraLabelsIDBufferMultiUtt.clear();
|
||||
m_extraPhoneboundaryIDBufferMultiUtt.clear();
|
||||
m_extraSeqsPerMB.clear();
|
||||
/* guoye: start */
|
||||
m_extraWLabelsIDBufferMultiUtt.clear();
|
||||
|
||||
m_extraNWsBufferMultiUtt.clear();
|
||||
/* guoye: end */
|
||||
if (m_noData && m_numFramesToProcess[0] == 0) // no data left for the first channel of this minibatch,
|
||||
{
|
||||
return false;
|
||||
|
@ -1064,6 +1281,11 @@ bool HTKMLFReader<ElemType>::GetMinibatchToTrainOrTest(StreamMinibatchInputs& ma
|
|||
{
|
||||
m_extraLatticeBufferMultiUtt.push_back(m_latticeBufferMultiUtt[i]);
|
||||
m_extraLabelsIDBufferMultiUtt.push_back(m_labelsIDBufferMultiUtt[i]);
|
||||
/* guoye: start */
|
||||
m_extraWLabelsIDBufferMultiUtt.push_back(m_wlabelsIDBufferMultiUtt[i]);
|
||||
|
||||
m_extraNWsBufferMultiUtt.push_back(m_nwsBufferMultiUtt[i]);
|
||||
/* guoye: end */
|
||||
m_extraPhoneboundaryIDBufferMultiUtt.push_back(m_phoneboundaryIDBufferMultiUtt[i]);
|
||||
}
|
||||
}
|
||||
|
@ -1106,6 +1328,12 @@ bool HTKMLFReader<ElemType>::GetMinibatchToTrainOrTest(StreamMinibatchInputs& ma
|
|||
{
|
||||
m_extraLatticeBufferMultiUtt.push_back(m_latticeBufferMultiUtt[src]);
|
||||
m_extraLabelsIDBufferMultiUtt.push_back(m_labelsIDBufferMultiUtt[src]);
|
||||
/* guoye: start */
|
||||
m_extraWLabelsIDBufferMultiUtt.push_back(m_wlabelsIDBufferMultiUtt[src]);
|
||||
|
||||
m_extraNWsBufferMultiUtt.push_back(m_nwsBufferMultiUtt[src]);
|
||||
|
||||
/* guoye: end */
|
||||
m_extraPhoneboundaryIDBufferMultiUtt.push_back(m_phoneboundaryIDBufferMultiUtt[src]);
|
||||
}
|
||||
|
||||
|
@ -1811,6 +2039,15 @@ bool HTKMLFReader<ElemType>::ReNewBufferForMultiIO(size_t i)
|
|||
m_phoneboundaryIDBufferMultiUtt[i] = m_mbiter->bounds();
|
||||
m_labelsIDBufferMultiUtt[i].clear();
|
||||
m_labelsIDBufferMultiUtt[i] = m_mbiter->labels();
|
||||
/* guoye: start */
|
||||
m_wlabelsIDBufferMultiUtt[i].clear();
|
||||
m_wlabelsIDBufferMultiUtt[i] = m_mbiter->wlabels();
|
||||
|
||||
m_nwsBufferMultiUtt[i].clear();
|
||||
m_nwsBufferMultiUtt[i] = m_mbiter->nwords();
|
||||
|
||||
/* guoye: end */
|
||||
|
||||
}
|
||||
|
||||
m_processedFrame[i] = 0;
|
||||
|
@ -2031,8 +2268,7 @@ unique_ptr<CUDAPageLockedMemAllocator>& HTKMLFReader<ElemType>::GetCUDAAllocator
|
|||
if (m_cudaAllocator == nullptr)
|
||||
{
|
||||
m_cudaAllocator.reset(new CUDAPageLockedMemAllocator(deviceID));
|
||||
}
|
||||
|
||||
}
|
||||
return m_cudaAllocator;
|
||||
}
|
||||
|
||||
|
@ -2049,6 +2285,7 @@ std::shared_ptr<ElemType> HTKMLFReader<ElemType>::AllocateIntermediateBuffer(int
|
|||
this->GetCUDAAllocator(deviceID)->Free((char*) p);
|
||||
});
|
||||
}
|
||||
|
||||
else
|
||||
{
|
||||
return std::shared_ptr<ElemType>(new ElemType[numElements],
|
||||
|
@ -2059,6 +2296,9 @@ std::shared_ptr<ElemType> HTKMLFReader<ElemType>::AllocateIntermediateBuffer(int
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
template class HTKMLFReader<float>;
|
||||
template class HTKMLFReader<double>;
|
||||
} } }
|
||||
|
||||
|
||||
|
|
|
@ -77,6 +77,16 @@ private:
|
|||
std::vector<std::vector<size_t>> m_phoneboundaryIDBufferMultiUtt;
|
||||
std::vector<shared_ptr<const msra::dbn::latticepair>> m_extraLatticeBufferMultiUtt;
|
||||
std::vector<std::vector<size_t>> m_extraLabelsIDBufferMultiUtt;
|
||||
|
||||
/* guoye: start */
|
||||
/* word labels */
|
||||
std::vector<std::vector<size_t>> m_wlabelsIDBufferMultiUtt;
|
||||
std::vector<std::vector<size_t>> m_extraWLabelsIDBufferMultiUtt;
|
||||
|
||||
std::vector<std::vector<short>> m_nwsBufferMultiUtt;
|
||||
std::vector<std::vector<short>> m_extraNWsBufferMultiUtt;
|
||||
|
||||
/* guoye: end */
|
||||
std::vector<std::vector<size_t>> m_extraPhoneboundaryIDBufferMultiUtt;
|
||||
|
||||
// hmm
|
||||
|
@ -109,7 +119,10 @@ private:
|
|||
void PrepareForWriting(const ConfigRecordType& config);
|
||||
|
||||
bool GetMinibatchToTrainOrTest(StreamMinibatchInputs& matrices);
|
||||
bool GetMinibatch4SEToTrainOrTest(std::vector<shared_ptr<const msra::dbn::latticepair>>& latticeinput, vector<size_t>& uids, vector<size_t>& boundaries, std::vector<size_t>& extrauttmap);
|
||||
/* guoye: start */
|
||||
// bool GetMinibatch4SEToTrainOrTest(std::vector<shared_ptr<const msra::dbn::latticepair>>& latticeinput, vector<size_t>& uids, vector<size_t>& boundaries, std::vector<size_t>& extrauttmap);
|
||||
bool GetMinibatch4SEToTrainOrTest(std::vector<shared_ptr<const msra::dbn::latticepair>>& latticeinput, vector<size_t>& uids, vector<size_t>& wids, vector<short>& nws, vector<size_t>& boundaries, std::vector<size_t>& extrauttmap);
|
||||
/* guoye: end */
|
||||
void fillOneUttDataforParallelmode(StreamMinibatchInputs& matrices, size_t startFr, size_t framenum, size_t channelIndex, size_t sourceChannelIndex); // TODO: PascalCase()
|
||||
bool GetMinibatchToWrite(StreamMinibatchInputs& matrices);
|
||||
|
||||
|
@ -189,7 +202,10 @@ public:
|
|||
virtual const std::map<LabelIdType, LabelType>& GetLabelMapping(const std::wstring& sectionName);
|
||||
virtual void SetLabelMapping(const std::wstring& sectionName, const std::map<LabelIdType, LabelType>& labelMapping);
|
||||
virtual bool GetData(const std::wstring& sectionName, size_t numRecords, void* data, size_t& dataBufferSize, size_t recordStart = 0);
|
||||
virtual bool GetMinibatch4SE(std::vector<shared_ptr<const msra::dbn::latticepair>>& latticeinput, vector<size_t>& uids, vector<size_t>& boundaries, vector<size_t>& extrauttmap);
|
||||
/* guoye: start */
|
||||
// virtual bool GetMinibatch4SE(std::vector<shared_ptr<const msra::dbn::latticepair>>& latticeinput, vector<size_t>& uids, vector<size_t>& boundaries, vector<size_t>& extrauttmap);
|
||||
virtual bool GetMinibatch4SE(std::vector<shared_ptr<const msra::dbn::latticepair>>& latticeinput, vector<size_t>& uids, vector<size_t>& wids, vector<short>& nws, vector<size_t>& boundaries, vector<size_t>& extrauttmap);
|
||||
/* guoye: end */
|
||||
virtual bool GetHmmData(msra::asr::simplesenonehmm* hmm);
|
||||
|
||||
virtual bool DataEnd();
|
||||
|
|
|
@ -71,6 +71,7 @@
|
|||
<TreatWarningAsError>true</TreatWarningAsError>
|
||||
<AdditionalIncludeDirectories Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">..\..\common\include;..\..\Math</AdditionalIncludeDirectories>
|
||||
<AdditionalIncludeDirectories Condition="'$(Configuration)|$(Platform)'=='Debug_CpuOnly|x64'">..\..\common\include;..\..\Math</AdditionalIncludeDirectories>
|
||||
<AdditionalOptions Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">/bigobj %(AdditionalOptions)</AdditionalOptions>
|
||||
</ClCompile>
|
||||
<Link>
|
||||
<SubSystem>Console</SubSystem>
|
||||
|
|
|
@ -864,14 +864,23 @@ public:
|
|||
setdata(ts, te, uid);
|
||||
}
|
||||
};
|
||||
|
||||
/* guoye: start */
|
||||
template <class ENTRY, class WORDSEQUENCE>
|
||||
class htkmlfreader : public map<wstring, vector<ENTRY>> // [key][i] the data
|
||||
|
||||
// In each map value, vector<ENTRY> stores the state-level labels and vector<unsigned int> stores the word-level labels (word ids)
|
||||
// template <class ENTRY>
|
||||
class htkmlfreader : public map<wstring, std::pair<vector<ENTRY>, vector<unsigned int>>> // [key][i] the data
|
||||
// class htkmlfreader : public map<wstring, vector<ENTRY>> // [key][i] the data
|
||||
/* guoye: end */
|
||||
{
|
||||
wstring curpath; // for error messages
|
||||
unordered_map<std::string, size_t> statelistmap; // for state <=> index
|
||||
/* guoye: start */
|
||||
map<wstring, WORDSEQUENCE> wordsequences; // [key] word sequences (if we are building word entries as well, for MMI)
|
||||
|
||||
/* guoye: end */
|
||||
std::unordered_map<std::string, size_t> symmap;
|
||||
|
||||
void strtok(char* s, const char* delim, vector<char*>& toks)
|
||||
{
|
||||
toks.resize(0);
|
||||
|
@ -900,10 +909,14 @@ class htkmlfreader : public map<wstring, vector<ENTRY>> // [key][i] the data
|
|||
return lines;
|
||||
}
|
||||
|
||||
template <typename WORDSYMBOLTABLE, typename UNITSYMBOLTABLE>
|
||||
// template <typename WORDSYMBOLTABLE, typename UNITSYMBOLTABLE>
|
||||
template <typename WORDSYMBOLTABLE>
|
||||
void parseentry(const vector<std::string>& lines, size_t line, const set<wstring>& restricttokeys,
|
||||
const WORDSYMBOLTABLE* wordmap, const UNITSYMBOLTABLE* unitmap,
|
||||
vector<typename WORDSEQUENCE::word>& wordseqbuffer, vector<typename WORDSEQUENCE::aligninfo>& alignseqbuffer,
|
||||
/* guoye: start */
|
||||
const WORDSYMBOLTABLE* wordmap, /* const UNITSYMBOLTABLE* unitmap, */
|
||||
|
||||
// vector<typename WORDSEQUENCE::word>& wordseqbuffer, vector<typename WORDSEQUENCE::aligninfo>& alignseqbuffer,
|
||||
/* guoye: end */
|
||||
const double htkTimeToFrame)
|
||||
{
|
||||
size_t idx = 0;
|
||||
|
@ -936,13 +949,25 @@ class htkmlfreader : public map<wstring, vector<ENTRY>> // [key][i] the data
|
|||
// don't parse unused entries (this is supposed to be used for very small debugging setups with huge MLFs)
|
||||
if (!restricttokeys.empty() && restricttokeys.find(key) == restricttokeys.end())
|
||||
return;
|
||||
/* guoye: start */
|
||||
// vector<ENTRY>& entries = (*this)[key]; // this creates a new entry
|
||||
|
||||
vector<ENTRY>& entries = (*this)[key]; // this creates a new entry
|
||||
vector<ENTRY>& entries = (*this)[key].first; // this creates a new entry
|
||||
if (!entries.empty())
|
||||
malformed(msra::strfun::strprintf("duplicate entry '%ls'", key.c_str()));
|
||||
/* guoye: start */
|
||||
// malformed(msra::strfun::strprintf("duplicate entry '%ls'", key.c_str()));
|
||||
// do not want to die immediately
|
||||
fprintf(stderr,
|
||||
"Warning: duplicate entry: %ls \n",
|
||||
key.c_str());
|
||||
/* guoye: end */
|
||||
entries.resize(e - s);
|
||||
wordseqbuffer.resize(0);
|
||||
alignseqbuffer.resize(0);
|
||||
|
||||
// wordseqbuffer.resize(0);
|
||||
// alignseqbuffer.resize(0);
|
||||
vector<size_t>& wordids = (*this)[key].second;
|
||||
wordids.resize(0);
|
||||
/* guoye: end */
|
||||
vector<char*> toks;
|
||||
for (size_t i = s; i < e; i++)
|
||||
{
|
||||
|
@ -957,13 +982,100 @@ class htkmlfreader : public map<wstring, vector<ENTRY>> // [key][i] the data
|
|||
{
|
||||
if (toks.size() > 6 /*word entry are in this column*/)
|
||||
{
|
||||
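// Convert the word to uppercase (ASCII a-z only), leaving the special tokens <s>, </s>, !sent_start, !sent_end and !silence unchanged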
|
||||
if (strcmp(toks[6], "<s>") != 0
|
||||
&& strcmp(toks[6], "</s>") != 0
|
||||
&& strcmp(toks[6], "!sent_start") != 0
|
||||
&& strcmp(toks[6], "!sent_end") != 0
|
||||
&& strcmp(toks[6], "!silence") != 0)
|
||||
{
|
||||
for(size_t j = 0; j < strlen(toks[6]); j++)
|
||||
{
|
||||
if(toks[6][j] >= 'a' && toks[6][j] <= 'z')
|
||||
{
|
||||
toks[6][j] = toks[6][j] + 'A' - 'a';
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const char* w = toks[6]; // the word name
|
||||
int wid = (*wordmap)[w]; // map to word id --may be -1 for unseen words in the transcript (word list typically comes from a test LM)
|
||||
size_t wordindex = (wid == -1) ? WORDSEQUENCE::word::unknownwordindex : (size_t) wid;
|
||||
wordseqbuffer.push_back(typename WORDSEQUENCE::word(wordindex, entries[i - s].firstframe, alignseqbuffer.size()));
|
||||
/* guoye: start */
|
||||
// In some alignment MLFs both the sentence start and the sentence end are written as <s>; treat any <s> that is not on the first line of the entry as </s>.
|
||||
if (i > s && strcmp(w, "<s>") == 0)
|
||||
{
|
||||
w = "</s>";
|
||||
}
|
||||
/* guoye: end */
|
||||
/*guoye: start */
|
||||
/* skip the words that are not used in WER computation */
|
||||
/* ugly hard code, will improve later */
|
||||
if (strcmp(w, "<s>") != 0
|
||||
&& strcmp(w, "</s>") != 0
|
||||
&& strcmp(w, "!NULL") != 0
|
||||
&& strcmp(w, "!sent_start") != 0
|
||||
&& strcmp(w, "!sent_end") != 0
|
||||
&& strcmp(w, "!silence") != 0
|
||||
&& strcmp(w, "[/CNON]") != 0
|
||||
&& strcmp(w, "[/CSPN]") != 0
|
||||
&& strcmp(w, "[/NPS]") != 0
|
||||
&& strcmp(w, "[CNON/]") != 0
|
||||
&& strcmp(w, "[CNON]") != 0
|
||||
&& strcmp(w, "[CSPN]") != 0
|
||||
&& strcmp(w, "[FILL/]") != 0
|
||||
&& strcmp(w, "[NON/]") != 0
|
||||
&& strcmp(w, "[NONNATIVE/]") != 0
|
||||
&& strcmp(w, "[NPS]") != 0
|
||||
&& strcmp(w, "[SB/]") != 0
|
||||
&& strcmp(w, "[SBP/]") != 0
|
||||
&& strcmp(w, "[SN/]") != 0
|
||||
&& strcmp(w, "[SPN/]") != 0
|
||||
&& strcmp(w, "[UNKNOWN/]") != 0
|
||||
&& strcmp(w, ".]") != 0
|
||||
)
|
||||
{
|
||||
int wid = (*wordmap)[w]; // map to word id --may be -1 for unseen words in the transcript (word list typically comes from a test LM)
|
||||
/* guoye: start */
|
||||
// size_t wordindex = (wid == -1) ? WORDSEQUENCE::word::unknownwordindex : (size_t)wid;
|
||||
// wordseqbuffer.push_back(typename WORDSEQUENCE::word(wordindex, entries[i - s].firstframe, alignseqbuffer.size()));
|
||||
static const unsigned int unknownwordindex = 0xfffff;
|
||||
|
||||
// TNed the word, to try one more time if there is an OOV
|
||||
/*
|
||||
if (wid == -1)
|
||||
{
|
||||
// remove / \ * _ - to see if a match could be found
|
||||
char tnw[200];
|
||||
size_t i = 0, j = 0;
|
||||
while (w[i] != '\0')
|
||||
{
|
||||
if (w[i] != '\\' && w[i] != '/' && w[i] != '*' && w[i] != '-' && w[i] != '_')
|
||||
{
|
||||
tnw[j] = w[i]; j++;
|
||||
}
|
||||
i++;
|
||||
}
|
||||
tnw[j] = '\0';
|
||||
|
||||
wid = (*wordmap)[tnw];
|
||||
|
||||
fprintf(stderr,
|
||||
"Warning: parseentry: wid = %d, new wid = %d, w = %s, tnw = %s \n",
|
||||
-1, wid, w, tnw);
|
||||
|
||||
}
|
||||
*/
|
||||
|
||||
size_t wordindex = (wid == -1) ? unknownwordindex : (size_t)wid;
|
||||
wordids.push_back(wordindex);
|
||||
/* guoye: end */
|
||||
}
|
||||
/*guoye: end */
|
||||
}
|
||||
/* guoye: start */
|
||||
/*
|
||||
if (unitmap)
|
||||
{
|
||||
|
||||
if (toks.size() > 4)
|
||||
{
|
||||
const char* u = toks[4]; // the triphone name
|
||||
|
@ -971,41 +1083,346 @@ class htkmlfreader : public map<wstring, vector<ENTRY>> // [key][i] the data
|
|||
if (iter == unitmap->end())
|
||||
RuntimeError("parseentry: unknown unit %s in utterance %ls", u, key.c_str());
|
||||
const size_t uid = iter->second;
|
||||
alignseqbuffer.push_back(typename WORDSEQUENCE::aligninfo(uid, 0 /*#frames--we accumulate*/));
|
||||
alignseqbuffer.push_back(typename WORDSEQUENCE::aligninfo(uid, 0 /*#frames--we accumulate*/ /* ));
|
||||
}
|
||||
|
||||
if (alignseqbuffer.empty())
|
||||
RuntimeError("parseentry: lonely senone entry at start without phone/word entry found, for utterance %ls", key.c_str());
|
||||
alignseqbuffer.back().frames += entries[i - s].numframes; // (we do not have an overflow check here, but should...)
|
||||
|
||||
}
|
||||
*/
|
||||
/* guoye: end */
|
||||
}
|
||||
}
|
||||
if (wordmap) // if reading word sequences as well (for MMI), then record it (in a separate map)
|
||||
{
|
||||
if (!entries.empty() && wordseqbuffer.empty())
|
||||
RuntimeError("parseentry: got state alignment but no word-level info, although being requested, for utterance %ls", key.c_str());
|
||||
/* guoye: start */
|
||||
// if (!entries.empty() && wordseqbuffer.empty())
|
||||
if (!entries.empty() && wordids.empty())
|
||||
/* guoye: end */
|
||||
// RuntimeError("parseentry: got state alignment but no word-level info, although being requested, for utterance %ls", key.c_str());
|
||||
{
|
||||
fprintf(stderr,
|
||||
"Warning: parseentry: got state alignment but no word-level info, although being requested, for utterance %ls \n",
|
||||
key.c_str());
|
||||
}
|
||||
|
||||
// post-process silence
|
||||
// - first !silence -> !sent_start
|
||||
// - last !silence -> !sent_end
|
||||
int silence = (*wordmap)["!silence"];
|
||||
if (silence >= 0)
|
||||
else
|
||||
{
|
||||
int sentstart = (*wordmap)["!sent_start"]; // these must have been created
|
||||
int sentend = (*wordmap)["!sent_end"];
|
||||
// map first and last !silence to !sent_start and !sent_end, respectively
|
||||
if (sentstart >= 0 && wordseqbuffer.front().wordindex == (size_t) silence)
|
||||
wordseqbuffer.front().wordindex = sentstart;
|
||||
if (sentend >= 0 && wordseqbuffer.back().wordindex == (size_t) silence)
|
||||
wordseqbuffer.back().wordindex = sentend;
|
||||
int silence = (*wordmap)["!silence"];
|
||||
if (silence >= 0)
|
||||
{
|
||||
int sentstart = (*wordmap)["!sent_start"]; // these must have been created
|
||||
int sentend = (*wordmap)["!sent_end"];
|
||||
// map first and last !silence to !sent_start and !sent_end, respectively
|
||||
/* guoye: start */
|
||||
/*
|
||||
if (sentstart >= 0 && wordseqbuffer.front().wordindex == (size_t) silence)
|
||||
wordseqbuffer.front().wordindex = sentstart;
|
||||
if (sentend >= 0 && wordseqbuffer.back().wordindex == (size_t) silence)
|
||||
wordseqbuffer.back().wordindex = sentend;
|
||||
*/
|
||||
if (sentstart >= 0 && wordids.front() == (size_t)silence)
|
||||
wordids.front() = sentstart;
|
||||
if (sentend >= 0 && wordids.back() == (size_t)silence)
|
||||
wordids.back() = sentend;
|
||||
/* guoye: end */
|
||||
}
|
||||
}
|
||||
/* guoye: end */
|
||||
// if (sentstart < 0 || sentend < 0 || silence < 0)
|
||||
// LogicError("parseentry: word map must contain !silence, !sent_start, and !sent_end");
|
||||
// implant
|
||||
/* guoye: start */
|
||||
/*
|
||||
auto& wordsequence = wordsequences[key]; // this creates the map entry
|
||||
wordsequence.words = wordseqbuffer; // makes a copy
|
||||
wordsequence.align = alignseqbuffer;
|
||||
*/
|
||||
/* guoye: end */
|
||||
}
|
||||
}
|
||||
/* guoye: start */
|
||||
// template <typename UNITSYMBOLTABLE>
|
||||
void parseentry(const vector<std::string>& lines, size_t line, const set<wstring>& restricttokeys,
|
||||
/* guoye: start */
|
||||
const std::unordered_map<std::string, int>& wordidmap, /* const UNITSYMBOLTABLE* unitmap,
|
||||
vector<typename WORDSEQUENCE::word>& wordseqbuffer, vector<typename WORDSEQUENCE::aligninfo>& alignseqbuffer,
|
||||
*/
|
||||
/* guoye: end */
|
||||
const double htkTimeToFrame)
|
||||
{
|
||||
|
||||
std::unordered_map<std::string, int>::const_iterator mp_itr;
|
||||
|
||||
size_t idx = 0;
|
||||
string filename = lines[idx++];
|
||||
while (filename == "#!MLF!#") // skip embedded duplicate MLF headers (so user can 'cat' MLFs)
|
||||
filename = lines[idx++];
|
||||
|
||||
// Some MLF files contain write errors, so warn and skip a malformed entry
|
||||
if (filename.length() < 3 || filename[0] != '"' || filename[filename.length() - 1] != '"')
|
||||
{
|
||||
fprintf(stderr, "warning: filename entry (%s)\n", filename.c_str());
|
||||
fprintf(stderr, "skip current mlf entry from line (%lu) until line (%lu).\n", (unsigned long)(line + idx), (unsigned long)(line + lines.size()));
|
||||
return;
|
||||
}
|
||||
|
||||
filename = filename.substr(1, filename.length() - 2); // strip quotes
|
||||
if (filename.find("*/") == 0)
|
||||
filename = filename.substr(2);
|
||||
#ifdef _MSC_VER
|
||||
wstring key = msra::strfun::utf16(regex_replace(filename, regex("\\.[^\\.\\\\/:]*$"), string())); // delete extension (or not if none)
|
||||
#else
|
||||
wstring key = msra::strfun::utf16(msra::dbn::removeExtension(filename)); // note that c++ 4.8 is incomplete for supporting regex
|
||||
#endif
|
||||
|
||||
// determine lines range
|
||||
size_t s = idx;
|
||||
size_t e = lines.size() - 1;
|
||||
// lines range: [s,e)
|
||||
|
||||
// don't parse unused entries (this is supposed to be used for very small debugging setups with huge MLFs)
|
||||
if (!restricttokeys.empty() && restricttokeys.find(key) == restricttokeys.end())
|
||||
return;
|
||||
/* guoye: start */
|
||||
// vector<ENTRY>& entries = (*this)[key]; // this creates a new entry
|
||||
vector<ENTRY>& entries = (*this)[key].first;
|
||||
if (!entries.empty())
|
||||
//malformed(msra::strfun::strprintf("duplicate entry '%ls'", key.c_str()));
|
||||
// do not want to die immediately
|
||||
fprintf(stderr,
|
||||
"Warning: duplicate entry : %ls \n",
|
||||
key.c_str());
|
||||
|
||||
entries.resize(e - s);
|
||||
|
||||
vector<unsigned int>& wordids = (*this)[key].second;
|
||||
wordids.resize(0);
|
||||
/*
|
||||
wordseqbuffer.resize(0);
|
||||
alignseqbuffer.resize(0);
|
||||
*/
|
||||
/* guoye: end */
|
||||
vector<char*> toks;
|
||||
for (size_t i = s; i < e; i++)
|
||||
{
|
||||
// We can mutate the original string as it is no longer needed after tokenization
|
||||
strtok(const_cast<char*>(lines[i].c_str()), " \t", toks);
|
||||
if (statelistmap.size() == 0)
|
||||
entries[i - s].parse(toks, htkTimeToFrame);
|
||||
else
|
||||
entries[i - s].parsewithstatelist(toks, statelistmap, htkTimeToFrame, symmap);
|
||||
// if we also read word entries, do it here
|
||||
if (wordidmap.size() != 0)
|
||||
{
|
||||
if (toks.size() > 6 /*word entry are in this column*/)
|
||||
{
|
||||
|
||||
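// Convert the word to uppercase (ASCII a-z only), leaving the special tokens <s>, </s>, !sent_start, !sent_end and !silence unchanged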
|
||||
if (strcmp(toks[6], "<s>") != 0
|
||||
&& strcmp(toks[6], "</s>") != 0
|
||||
&& strcmp(toks[6], "!sent_start") != 0
|
||||
&& strcmp(toks[6], "!sent_end") != 0
|
||||
&& strcmp(toks[6], "!silence") != 0)
|
||||
{
|
||||
for(size_t j = 0; j < strlen(toks[6]); j++)
|
||||
{
|
||||
if(toks[6][j] >= 'a' && toks[6][j] <= 'z')
|
||||
{
|
||||
toks[6][j] = toks[6][j] + 'A' - 'a';
|
||||
}
|
||||
}
|
||||
}
|
||||
const char* w = toks[6]; // the word name
|
||||
/* guoye: start */
|
||||
// In some alignment MLFs both the sentence start and the sentence end are written as <s>; treat any <s> that is not on the first line of the entry as </s>.
|
||||
if (i > s && strcmp(w, "<s>") == 0)
|
||||
{
|
||||
w = "</s>";
|
||||
}
|
||||
/* guoye: end */
|
||||
/*guoye: start */
|
||||
/* skip the words that are not used in WER computation */
|
||||
/* ugly hard code, will improve later */
|
||||
if (strcmp(w, "<s>") != 0
|
||||
&& strcmp(w, "</s>") != 0
|
||||
&& strcmp(w, "!NULL") != 0
|
||||
&& strcmp(w, "!sent_start") != 0
|
||||
&& strcmp(w, "!sent_end") != 0
|
||||
&& strcmp(w, "!silence") != 0
|
||||
&& strcmp(w, "[/CNON]") != 0
|
||||
&& strcmp(w, "[/CSPN]") != 0
|
||||
&& strcmp(w, "[/NPS]") != 0
|
||||
&& strcmp(w, "[CNON/]") != 0
|
||||
&& strcmp(w, "[CNON]") != 0
|
||||
&& strcmp(w, "[CSPN]") != 0
|
||||
&& strcmp(w, "[FILL/]") != 0
|
||||
&& strcmp(w, "[NON/]") != 0
|
||||
&& strcmp(w, "[NONNATIVE/]") != 0
|
||||
&& strcmp(w, "[NPS]") != 0
|
||||
&& strcmp(w, "[SB/]") != 0
|
||||
&& strcmp(w, "[SBP/]") != 0
|
||||
&& strcmp(w, "[SN/]") != 0
|
||||
&& strcmp(w, "[SPN/]") != 0
|
||||
&& strcmp(w, "[UNKNOWN/]") != 0
|
||||
&& strcmp(w, ".]") != 0
|
||||
)
|
||||
{
|
||||
// int wid = (*wordmap)[w]; // map to word id --may be -1 for unseen words in the transcript (word list typically comes from a test LM)
|
||||
|
||||
mp_itr = wordidmap.find(std::string(w));
|
||||
int wid = ((mp_itr == wordidmap.end()) ? -1: mp_itr->second);
|
||||
|
||||
// debug
|
||||
// int wid = -1;
|
||||
|
||||
/* guoye: start */
|
||||
// size_t wordindex = (wid == -1) ? WORDSEQUENCE::word::unknownwordindex : (size_t)wid;
|
||||
|
||||
// guoye: debug
|
||||
// wordseqbuffer.push_back(typename WORDSEQUENCE::word(wordindex, entries[i - s].firstframe, alignseqbuffer.size()));
|
||||
|
||||
|
||||
// TNed the word, to try one more time if there is an OOV
|
||||
/**/
|
||||
/*
|
||||
if (wid == -1)
|
||||
{
|
||||
// remove / \ * _ - to see if a match could be found
|
||||
char tnw[200];
|
||||
size_t i1 = 0, j = 0;
|
||||
while (w[i1] != '\0')
|
||||
{
|
||||
if (w[i1] != '\\' && w[i1] != '/' && w[i1] != '*' && w[i1] != '-' && w[i1] != '_')
|
||||
{
|
||||
tnw[j] = w[i1]; j++;
|
||||
}
|
||||
i1++;
|
||||
}
|
||||
tnw[j] = '\0';
|
||||
|
||||
|
||||
mp_itr = wordidmap.find(std::string(tnw));
|
||||
wid = ((mp_itr == wordidmap.end()) ? -1 : mp_itr->second);
|
||||
|
||||
|
||||
}
|
||||
|
||||
*/
|
||||
|
||||
|
||||
static const unsigned int unknownwordindex = 0xfffff;
|
||||
unsigned int wordindex = (wid == -1) ? unknownwordindex : (unsigned int)wid;
|
||||
wordids.push_back(wordindex);
|
||||
/* guoye: end */
|
||||
}
|
||||
/*guoye: end */
|
||||
}
|
||||
/* guoye: start */
|
||||
/*
|
||||
if (unitmap)
|
||||
{
|
||||
|
||||
if (toks.size() > 4)
|
||||
{
|
||||
const char* u = toks[4]; // the triphone name
|
||||
auto iter = unitmap->find(u); // map to unit id
|
||||
if (iter == unitmap->end())
|
||||
RuntimeError("parseentry: unknown unit %s in utterance %ls", u, key.c_str());
|
||||
const size_t uid = iter->second;
|
||||
alignseqbuffer.push_back(typename WORDSEQUENCE::aligninfo(uid, 0 /*#frames--we accumulate*/ /* ));
|
||||
|
||||
}
|
||||
if (alignseqbuffer.empty())
|
||||
RuntimeError("parseentry: lonely senone entry at start without phone/word entry found, for utterance %ls", key.c_str());
|
||||
alignseqbuffer.back().frames += entries[i - s].numframes; // (we do not have an overflow check here, but should...)
|
||||
|
||||
}
|
||||
*/
|
||||
/* guoye: end */
|
||||
}
|
||||
}
|
||||
if (wordidmap.size() != 0) // if reading word sequences as well (for MMI), then record it (in a separate map)
|
||||
{
|
||||
/* guoye: start */
|
||||
// if (!entries.empty() && wordseqbuffer.empty())
|
||||
if (!entries.empty() && wordids.empty())
|
||||
/* guoye: end */
|
||||
// RuntimeError("parseentry: got state alignment but no word-level info, although being requested, for utterance %ls", key.c_str());
|
||||
{
|
||||
|
||||
fprintf(stderr,
|
||||
"Warning: parseentry: got state alignment but no word-level info, although being requested, for utterance %ls. Ignoring this utterance for EMBR \n",
|
||||
key.c_str());
|
||||
// delete this item
|
||||
(*this).erase(key);
|
||||
return;
|
||||
|
||||
}
|
||||
|
||||
// post-process silence
|
||||
// - first !silence -> !sent_start
|
||||
// - last !silence -> !sent_end
|
||||
else
|
||||
{
|
||||
|
||||
mp_itr = wordidmap.find("!silence");
|
||||
int silence = ((mp_itr == wordidmap.end()) ? -1: mp_itr->second);
|
||||
|
||||
|
||||
// debug
|
||||
// int silence = -1;
|
||||
|
||||
if (silence >= 0)
|
||||
{
|
||||
mp_itr = wordidmap.find("!sent_start");
|
||||
int sentstart = ((mp_itr == wordidmap.end()) ? -1: mp_itr->second);
|
||||
|
||||
mp_itr = wordidmap.find("!sent_end");
|
||||
int sentend = ((mp_itr == wordidmap.end()) ? -1: mp_itr->second);
|
||||
|
||||
// map first and last !silence to !sent_start and !sent_end, respectively
|
||||
/* guoye: start */
|
||||
/*
|
||||
if (sentstart >= 0 && wordseqbuffer.front().wordindex == (size_t)silence)
|
||||
wordseqbuffer.front().wordindex = sentstart;
|
||||
if (sentend >= 0 && wordseqbuffer.back().wordindex == (size_t)silence)
|
||||
wordseqbuffer.back().wordindex = sentend;
|
||||
*/
|
||||
if (sentstart >= 0 && wordids.front() == (size_t)silence)
|
||||
wordids.front() = sentstart;
|
||||
if (sentend >= 0 && wordids.back() == (size_t)silence)
|
||||
wordids.back() = sentend;
|
||||
/* guoye: end */
|
||||
}
|
||||
}
|
||||
/* guoye: end */
|
||||
// if (sentstart < 0 || sentend < 0 || silence < 0)
|
||||
// LogicError("parseentry: word map must contain !silence, !sent_start, and !sent/_end");
|
||||
// implant
|
||||
|
||||
/* guoye: start */
|
||||
/*
|
||||
wordseqbuffer.resize(0);
|
||||
alignseqbuffer.resize(0);
|
||||
|
||||
auto& wordsequence = wordsequences[key]; // this creates the map entry
|
||||
|
||||
wordsequence.words = wordseqbuffer; // makes a copy
|
||||
|
||||
|
||||
wordsequence.align = alignseqbuffer;
|
||||
*/
|
||||
/* guoye: end */
|
||||
}
|
||||
}
|
||||
|
||||
/* guoye: end */
|
||||
public:
|
||||
// returns whether the input statename is a sil state (hard-coded: compares the first 3 chars with "sil")
|
||||
bool issilstate(const string& statename) const // (later use some configuration table)
|
||||
|
@ -1013,6 +1430,14 @@ public:
|
|||
return (statename.size() > 3 && statename.at(0) == 's' && statename.at(1) == 'i' && statename.at(2) == 'l');
|
||||
}
|
||||
|
||||
/* guoye: start */
|
||||
/*
|
||||
map<wstring, WORDSEQUENCE> get_wordlabels()
|
||||
{
|
||||
return wordsequences;
|
||||
}
|
||||
*/
|
||||
/* guoye: end */
|
||||
vector<bool> issilstatetable; // [state index] => true if is sil state (cached)
|
||||
|
||||
// return if input stateid represent sil state (by table lookup)
|
||||
|
@ -1044,9 +1469,48 @@ public:
|
|||
read(paths[i], restricttokeys, wordmap, unitmap, htkTimeToFrame);
|
||||
}
|
||||
|
||||
// note: this function is not designed to be pretty but to be fast
|
||||
|
||||
/* guoye: start */
|
||||
|
||||
|
||||
// alternate constructor that takes wordidmap
|
||||
// template <typename UNITSYMBOLTABLE>
|
||||
/* guoye: start */
|
||||
// htkmlfreader(const vector<wstring>& paths, const set<wstring>& restricttokeys, const wstring& stateListPath, const std::unordered_map<std::string, int>& wordidmap, const UNITSYMBOLTABLE* unitmap, const double htkTimeToFrame)
|
||||
htkmlfreader(const vector<wstring>& paths, const set<wstring>& restricttokeys, const wstring& stateListPath, const std::unordered_map<std::string, int>& wordidmap, const double htkTimeToFrame)
|
||||
/* guoye: end */
|
||||
{
|
||||
// read state list
|
||||
if (stateListPath != L"")
|
||||
readstatelist(stateListPath);
|
||||
|
||||
// read MLF(s) --note: there can be multiple, so this is a loop
|
||||
foreach_index(i, paths)
|
||||
/* guoye: start */
|
||||
// read(paths[i], restricttokeys, wordidmap, unitmap, htkTimeToFrame);
|
||||
read(paths[i], restricttokeys, wordidmap, htkTimeToFrame);
|
||||
/* guoye: end */
|
||||
}
|
||||
/* guoye: end */
|
||||
|
||||
// phone boundary
|
||||
template <typename WORDSYMBOLTABLE, typename UNITSYMBOLTABLE>
|
||||
void read(const wstring& path, const set<wstring>& restricttokeys, const WORDSYMBOLTABLE* wordmap, const UNITSYMBOLTABLE* unitmap, const double htkTimeToFrame)
|
||||
htkmlfreader(const vector<wstring>& paths, const set<wstring>& restricttokeys, const wstring& stateListPath, const WORDSYMBOLTABLE* wordmap, const UNITSYMBOLTABLE* unitmap,
|
||||
const double htkTimeToFrame, const msra::asr::simplesenonehmm& hset)
|
||||
{
|
||||
if (stateListPath != L"")
|
||||
readstatelist(stateListPath);
|
||||
symmap = hset.symmap;
|
||||
foreach_index (i, paths)
|
||||
read(paths[i], restricttokeys, wordmap, unitmap, htkTimeToFrame);
|
||||
}
|
||||
// note: this function is not designed to be pretty but to be fast
|
||||
/* guoye: start */
|
||||
// template <typename WORDSYMBOLTABLE, typename UNITSYMBOLTABLE>
|
||||
template <typename WORDSYMBOLTABLE>
|
||||
// void read(const wstring& path, const set<wstring>& restricttokeys, const WORDSYMBOLTABLE* wordmap, const UNITSYMBOLTABLE* unitmap, const double htkTimeToFrame)
|
||||
void read(const wstring& path, const set<wstring>& restricttokeys, const WORDSYMBOLTABLE* wordmap, const double htkTimeToFrame)
|
||||
/* guoye: end */
|
||||
{
|
||||
if (!restricttokeys.empty() && this->size() >= restricttokeys.size()) // no need to even read the file if we are there (we support multiple files)
|
||||
return;
|
||||
|
@ -1060,8 +1524,12 @@ public:
|
|||
malformed("header missing");
|
||||
|
||||
// Read the file in blocks and parse MLF entries
|
||||
/* guoye: start */
|
||||
/*
|
||||
std::vector<typename WORDSEQUENCE::word> wordsequencebuffer;
|
||||
std::vector<typename WORDSEQUENCE::aligninfo> alignsequencebuffer;
|
||||
*/
|
||||
/* guoye: end */
|
||||
size_t readBlockSize = 1000000;
|
||||
std::vector<char> currBlockBuf(readBlockSize + 1);
|
||||
size_t currLineNum = 1;
|
||||
|
@ -1091,7 +1559,10 @@ public:
|
|||
{
|
||||
if (restricttokeys.empty() || (this->size() < restricttokeys.size()))
|
||||
{
|
||||
parseentry(currMLFLines, currLineNum - currMLFLines.size(), restricttokeys, wordmap, unitmap, wordsequencebuffer, alignsequencebuffer, htkTimeToFrame);
|
||||
/* guoye: start */
|
||||
// parseentry(currMLFLines, currLineNum - currMLFLines.size(), restricttokeys, wordmap, unitmap, wordsequencebuffer, alignsequencebuffer, htkTimeToFrame);
|
||||
parseentry(currMLFLines, currLineNum - currMLFLines.size(), restricttokeys, wordmap, htkTimeToFrame);
|
||||
/* guoye: end */
|
||||
}
|
||||
|
||||
currMLFLines.clear();
|
||||
|
@ -1134,6 +1605,104 @@ public:
|
|||
fprintf(stderr, " total %lu entries\n", (unsigned long)this->size());
|
||||
}
|
||||
|
||||
// note: this function is not designed to be pretty but to be fast
|
||||
/* guoye: start */
|
||||
// template <typename UNITSYMBOLTABLE>
|
||||
// void read(const wstring& path, const set<wstring>& restricttokeys, const std::unordered_map<std::string, int>& wordidmap, const UNITSYMBOLTABLE* unitmap, const double htkTimeToFrame)
|
||||
void read(const wstring& path, const set<wstring>& restricttokeys, const std::unordered_map<std::string, int>& wordidmap, const double htkTimeToFrame)
|
||||
{
|
||||
if (!restricttokeys.empty() && this->size() >= restricttokeys.size()) // no need to even read the file if we are there (we support multiple files)
|
||||
return;
|
||||
|
||||
fprintf(stderr, "htkmlfreader: reading MLF file %ls ...", path.c_str());
|
||||
curpath = path; // for error messages only
|
||||
|
||||
auto_file_ptr f(fopenOrDie(path, L"rb"));
|
||||
std::string headerLine = fgetline(f);
|
||||
if (headerLine != "#!MLF!#")
|
||||
malformed("header missing");
|
||||
|
||||
// Read the file in blocks and parse MLF entries
|
||||
/* guoye: start */
|
||||
/*
|
||||
std::vector<typename WORDSEQUENCE::word> wordsequencebuffer;
|
||||
std::vector<typename WORDSEQUENCE::aligninfo> alignsequencebuffer;
|
||||
*/
|
||||
/* guoye: end */
|
||||
size_t readBlockSize = 1000000;
|
||||
std::vector<char> currBlockBuf(readBlockSize + 1);
|
||||
size_t currLineNum = 1;
|
||||
std::vector<string> currMLFLines;
|
||||
bool reachedEOF = (feof(f) != 0);
|
||||
char* nextReadPtr = currBlockBuf.data();
|
||||
size_t nextReadSize = readBlockSize;
|
||||
while (!reachedEOF)
|
||||
{
|
||||
size_t numBytesRead = fread(nextReadPtr, sizeof(char), nextReadSize, f);
|
||||
reachedEOF = (numBytesRead != nextReadSize);
|
||||
if (ferror(f))
|
||||
RuntimeError("error reading from file: %s", strerror(errno));
|
||||
|
||||
// Add 0 at the end to make it a proper C string
|
||||
nextReadPtr[numBytesRead] = 0;
|
||||
|
||||
// Now extract lines from the currBlockBuf and parse MLF entries
|
||||
char* context = nullptr;
|
||||
const char* delim = "\r\n";
|
||||
|
||||
auto consumeMLFLine = [&](const char* mlfLine)
|
||||
{
|
||||
currLineNum++;
|
||||
currMLFLines.push_back(mlfLine);
|
||||
if ((mlfLine[0] == '.') && (mlfLine[1] == 0)) // utterance end delimiter: a single dot on a line
|
||||
{
|
||||
if (restricttokeys.empty() || (this->size() < restricttokeys.size()))
|
||||
{
|
||||
// parseentry(currMLFLines, currLineNum - currMLFLines.size(), restricttokeys, wordidmap, unitmap, wordsequencebuffer, alignsequencebuffer, htkTimeToFrame);
|
||||
parseentry(currMLFLines, currLineNum - currMLFLines.size(), restricttokeys, wordidmap, htkTimeToFrame);
|
||||
}
|
||||
|
||||
currMLFLines.clear();
|
||||
}
|
||||
};
|
||||
|
||||
char* prevLine = strtok_s(currBlockBuf.data(), delim, &context);
|
||||
for (char* currLine = strtok_s(NULL, delim, &context); currLine; currLine = strtok_s(NULL, delim, &context))
|
||||
{
|
||||
consumeMLFLine(prevLine);
|
||||
prevLine = currLine;
|
||||
}
|
||||
|
||||
// The last line read from the block may be a full line or part of a line
|
||||
// We can tell by whether the terminating NULL for this line is the NULL
|
||||
// we inserted after reading from the file
|
||||
size_t prevLineLen = strlen(prevLine);
|
||||
if ((prevLine + prevLineLen) == (nextReadPtr + numBytesRead))
|
||||
{
|
||||
// This is not a full line, but just a truncated part of a line.
|
||||
// Lets copy this to the start of the currBlockBuf and read new data
|
||||
// from there on
|
||||
strcpy_s(currBlockBuf.data(), currBlockBuf.size(), prevLine);
|
||||
nextReadPtr = currBlockBuf.data() + prevLineLen;
|
||||
nextReadSize = readBlockSize - prevLineLen;
|
||||
}
|
||||
else
|
||||
{
|
||||
// A full line
|
||||
consumeMLFLine(prevLine);
|
||||
nextReadPtr = currBlockBuf.data();
|
||||
nextReadSize = readBlockSize;
|
||||
}
|
||||
}
|
||||
|
||||
if (!currMLFLines.empty())
|
||||
malformed("unexpected end in mid-utterance");
|
||||
|
||||
curpath.clear();
|
||||
fprintf(stderr, " total %lu entries\n", (unsigned long)this->size());
|
||||
}
|
||||
/* guoye: end */
|
||||
|
||||
// read state list, index is from 0
|
||||
void readstatelist(const wstring& stateListPath = L"")
|
||||
{
|
||||
|
@ -1168,10 +1737,14 @@ public:
|
|||
}
|
||||
|
||||
// access to word sequences
|
||||
/* guoye: start */
|
||||
|
||||
// Read-only accessor: returns the 'wordsequences' member by const reference (no copy is made).
// NOTE(review): the map is presumably keyed by utterance key -- confirm against where it is filled.
const map<wstring, WORDSEQUENCE>& allwordtranscripts() const
|
||||
{
|
||||
return wordsequences;
|
||||
}
|
||||
|
||||
/* guoye: end */
|
||||
};
|
||||
};
|
||||
}; // namespaces
|
||||
|
|
|
@ -405,8 +405,11 @@ void lattice::dedup()
|
|||
// - empty ("") -> don't output, just check the format
|
||||
// - dash ("-") -> dump lattice to stdout instead
|
||||
/*static*/ void archive::convert(const std::wstring &intocpath, const std::wstring &intocpath2, const std::wstring &outpath,
|
||||
const msra::asr::simplesenonehmm &hset)
|
||||
{
|
||||
/* guoye: start */
|
||||
// const msra::asr::simplesenonehmm &hset)
|
||||
const msra::asr::simplesenonehmm &hset, std::set<int>& specialwordids)
|
||||
/* guoye: end */
|
||||
{
|
||||
const auto &modelsymmap = hset.getsymmap();
|
||||
|
||||
const std::wstring tocpath = outpath + L".toc";
|
||||
|
@ -457,8 +460,10 @@ void lattice::dedup()
|
|||
|
||||
// fetch lattice --this performs any necessary format conversions already
|
||||
lattice L;
|
||||
archive.getlattice(key, L);
|
||||
|
||||
/* guoye: start */
|
||||
// archive.getlattice(key, L);
|
||||
archive.getlattice(key, L, specialwordids);
|
||||
/* guoye: end */
|
||||
lattice L2;
|
||||
if (mergemode)
|
||||
{
|
||||
|
@ -468,8 +473,10 @@ void lattice::dedup()
|
|||
skippedmerges++;
|
||||
continue;
|
||||
}
|
||||
archive2.getlattice(key, L2);
|
||||
|
||||
/* guoye: start */
|
||||
// archive2.getlattice(key, L2);
|
||||
archive2.getlattice(key, L2, specialwordids);
|
||||
/* guoye: end */
|
||||
// merge it in
|
||||
// This will connect each node with matching 1-phone context conditions; aimed at merging numer lattices.
|
||||
L.removefinalnull(); // get rid of that final !NULL headache
|
||||
|
@ -563,6 +570,9 @@ void lattice::fromhtklattice(const wstring &path, const std::unordered_map<std::
|
|||
|
||||
assert(info.numnodes > 0);
|
||||
nodes.reserve(info.numnodes);
|
||||
/* guoye: start */
|
||||
vt_node_out_edge_indices.resize(info.numnodes);
|
||||
/* guoye: end */
|
||||
// parse the nodes
|
||||
for (size_t i = 0; i < info.numnodes; i++, iter++)
|
||||
{
|
||||
|
@ -570,11 +580,24 @@ void lattice::fromhtklattice(const wstring &path, const std::unordered_map<std::
|
|||
RuntimeError("lattice: not enough I lines in lattice");
|
||||
unsigned long itest;
|
||||
float t;
|
||||
if (sscanf_s(*iter, "I=%lu t=%f%c", &itest, &t, &dummychar, (unsigned int)sizeof(dummychar)) < 2)
|
||||
/* guoye: start */
|
||||
|
||||
char d[100];
|
||||
// if (sscanf_s(*iter, "I=%lu t=%f%c", &itest, &t, &dummychar, (unsigned int)sizeof(dummychar)) < 2)
|
||||
if (sscanf_s(*iter, "I=%lu t=%f W=%s", &itest, &t, &d, (unsigned int)sizeof(d)) < 3)
|
||||
RuntimeError("lattice: mal-formed node line in lattice: %s", *iter);
|
||||
|
||||
/* guoye: end */
|
||||
|
||||
if (i != (size_t) itest)
|
||||
RuntimeError("lattice: out-of-sequence node line in lattice: %s", *iter);
|
||||
nodes.push_back(nodeinfo((unsigned int) (t / info.frameduration + 0.5)));
|
||||
/* guoye: start */
|
||||
// nodes.push_back(nodeinfo((unsigned int) (t / info.frameduration + 0.5)));
|
||||
// To do: we need to map the d to the wordid. It is P2 task.
|
||||
// For current speech production pipeline, we read from lattice archive rather than from the raw lattice. So, this code is actually not used.
|
||||
|
||||
nodes.push_back(nodeinfo((unsigned int)(t / info.frameduration + 0.5), 0));
|
||||
/* guoye: end */
|
||||
info.numframes = max(info.numframes, (size_t) nodes.back().t);
|
||||
}
|
||||
// parse the edges
|
||||
|
@ -600,6 +623,10 @@ void lattice::fromhtklattice(const wstring &path, const std::unordered_map<std::
|
|||
if (j != (size_t) jtest)
|
||||
RuntimeError("lattice: out-of-sequence edge line in lattice: %s", *iter);
|
||||
edges.push_back(edgeinfowithscores(S, E, a, l, align.size()));
|
||||
|
||||
/* guoye: start */
|
||||
vt_node_out_edge_indices[S].push_back(j);
|
||||
/* guoye: end */
|
||||
// build align array
|
||||
size_t edgeframes = 0; // (for checking whether the alignment sums up right)
|
||||
const char *p = d;
|
||||
|
@ -731,5 +758,10 @@ void lattice::frommlf(const wstring &key2, const std::unordered_map<std::string,
|
|||
|
||||
showstats();
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
};
|
||||
};
|
||||
|
|
|
@ -34,35 +34,64 @@ public:
|
|||
// - lattices are returned as a shared_ptr
|
||||
// Thus, getbatch() can be called in a thread-safe fashion, allowing for a 'minibatchsource' implementation that wraps another with a read-ahead thread.
|
||||
// Return value is 'true' if it did read anything from disk, and 'false' if data came only from RAM cache. This is used for controlling the read-ahead thread.
|
||||
|
||||
virtual bool getbatch(const size_t globalts,
|
||||
const size_t framesrequested, msra::dbn::matrix &feat, std::vector<size_t> &uids,
|
||||
std::vector<const_array_ref<msra::lattices::lattice::htkmlfwordsequence::word>> &transcripts,
|
||||
std::vector<std::shared_ptr<const latticesource::latticepair>> &lattices) = 0;
|
||||
// alternate (updated) definition for multiple inputs/outputs - read as a vector of feature matrixes or a vector of label strings
|
||||
|
||||
|
||||
// alternate (updated) definition for multiple inputs/outputs - read as a vector of feature matrixes or a vector of label strings
|
||||
virtual bool getbatch(const size_t globalts,
|
||||
const size_t framesrequested, std::vector<msra::dbn::matrix> &feat, std::vector<std::vector<size_t>> &uids,
|
||||
std::vector<const_array_ref<msra::lattices::lattice::htkmlfwordsequence::word>> &transcripts,
|
||||
std::vector<std::shared_ptr<const latticesource::latticepair>> &lattices, std::vector<std::vector<size_t>> &sentendmark,
|
||||
std::vector<std::vector<size_t>> &phoneboundaries) = 0;
|
||||
const size_t framesrequested, std::vector<msra::dbn::matrix> &feat, std::vector<std::vector<size_t>> &uids,
|
||||
std::vector<const_array_ref<msra::lattices::lattice::htkmlfwordsequence::word>> &transcripts,
|
||||
std::vector<std::shared_ptr<const latticesource::latticepair>> &lattices, std::vector<std::vector<size_t>> &sentendmark,
|
||||
std::vector<std::vector<size_t>> &phoneboundaries) = 0;
|
||||
|
||||
|
||||
|
||||
// getbatch() overload to support subsetting of mini-batches for parallel training
|
||||
// Default implementation does not support subsetting and throws an exception on
|
||||
// calling this overload with a numsubsets value other than 1.
|
||||
|
||||
virtual bool getbatch(const size_t globalts,
|
||||
const size_t framesrequested, const size_t subsetnum, const size_t numsubsets, size_t &framesadvanced,
|
||||
std::vector<msra::dbn::matrix> &feat, std::vector<std::vector<size_t>> &uids,
|
||||
std::vector<const_array_ref<msra::lattices::lattice::htkmlfwordsequence::word>> &transcripts,
|
||||
std::vector<std::shared_ptr<const latticesource::latticepair>> &lattices, std::vector<std::vector<size_t>> &sentendmark,
|
||||
std::vector<std::vector<size_t>> &phoneboundaries)
|
||||
const size_t framesrequested, const size_t subsetnum, const size_t numsubsets, size_t &framesadvanced,
|
||||
std::vector<msra::dbn::matrix> &feat, std::vector<std::vector<size_t>> &uids,
|
||||
std::vector<const_array_ref<msra::lattices::lattice::htkmlfwordsequence::word>> &transcripts,
|
||||
std::vector<std::shared_ptr<const latticesource::latticepair>> &lattices, std::vector<std::vector<size_t>> &sentendmark,
|
||||
std::vector<std::vector<size_t>> &phoneboundaries)
|
||||
{
|
||||
assert((subsetnum == 0) && (numsubsets == 1) && !supportsbatchsubsetting());
|
||||
subsetnum;
|
||||
numsubsets;
|
||||
|
||||
bool retVal = getbatch(globalts, framesrequested, feat, uids, transcripts, lattices, sentendmark, phoneboundaries);
|
||||
framesadvanced = feat[0].cols();
|
||||
|
||||
return retVal;
|
||||
}
|
||||
|
||||
/* guoye: start */
|
||||
virtual bool getbatch(const size_t globalts,
|
||||
const size_t framesrequested, const size_t subsetnum, const size_t numsubsets, size_t &framesadvanced,
|
||||
std::vector<msra::dbn::matrix> &feat, std::vector<std::vector<size_t>> &uids, std::vector<std::vector<size_t>> &wids, std::vector<std::vector<short>> &nws,
|
||||
std::vector<const_array_ref<msra::lattices::lattice::htkmlfwordsequence::word>> &transcripts,
|
||||
std::vector<std::shared_ptr<const latticesource::latticepair>> &lattices, std::vector<std::vector<size_t>> &sentendmark,
|
||||
std::vector<std::vector<size_t>> &phoneboundaries)
|
||||
{
|
||||
wids.resize(0);
|
||||
nws.resize(0);
|
||||
|
||||
|
||||
bool retVal = getbatch(globalts, framesrequested, subsetnum, numsubsets, framesadvanced, feat, uids, transcripts, lattices, sentendmark, phoneboundaries);
|
||||
|
||||
return retVal;
|
||||
}
|
||||
|
||||
/* guoye: end */
|
||||
|
||||
|
||||
virtual bool supportsbatchsubsetting() const
|
||||
{
|
||||
return false;
|
||||
|
@ -102,6 +131,10 @@ class minibatchiterator
|
|||
|
||||
std::vector<msra::dbn::matrix> featbuf; // buffer for holding curernt minibatch's frames
|
||||
std::vector<std::vector<size_t>> uids; // buffer for storing current minibatch's frame-level label sequence
|
||||
/* guoye: start */
|
||||
std::vector<std::vector<size_t>> wids; // buffer for storing current minibatch's word-level label sequence
|
||||
std::vector<std::vector<short>> nws; // buffer for storing current minibatch's number of words for each utterance
|
||||
/* guoye: end */
|
||||
std::vector<const_array_ref<msra::lattices::lattice::htkmlfwordsequence::word>> transcripts; // buffer for storing current minibatch's word-level label sequences (if available and used; empty otherwise)
|
||||
std::vector<std::shared_ptr<const latticesource::latticepair>> lattices; // lattices of the utterances in current minibatch (empty in frame mode)
|
||||
|
||||
|
@ -126,7 +159,13 @@ private:
|
|||
|
||||
foreach_index (i, uids)
|
||||
uids[i].clear();
|
||||
/* guoye: start */
|
||||
foreach_index(i, wids)
|
||||
wids[i].clear();
|
||||
|
||||
foreach_index(i, nws)
|
||||
nws[i].clear();
|
||||
/* guoye: end */
|
||||
transcripts.clear();
|
||||
actualmbframes = 0;
|
||||
return;
|
||||
|
@ -135,7 +174,10 @@ private:
|
|||
assert(requestedmbframes > 0);
|
||||
const size_t requestedframes = std::min(requestedmbframes, epochendframe - mbstartframe); // (< mbsize at end)
|
||||
assert(requestedframes > 0);
|
||||
source.getbatch(mbstartframe, requestedframes, subsetnum, numsubsets, mbframesadvanced, featbuf, uids, transcripts, lattices, sentendmark, phoneboundaries);
|
||||
/* guoye: start */
|
||||
// source.getbatch(mbstartframe, requestedframes, subsetnum, numsubsets, mbframesadvanced, featbuf, uids, transcripts, lattices, sentendmark, phoneboundaries);
|
||||
source.getbatch(mbstartframe, requestedframes, subsetnum, numsubsets, mbframesadvanced, featbuf, uids, wids, nws, transcripts, lattices, sentendmark, phoneboundaries);
|
||||
/* guoye: end */
|
||||
timegetbatch = source.gettimegetbatch();
|
||||
actualmbframes = featbuf[0].cols(); // for single i/o, there featbuf is length 1
|
||||
// note:
|
||||
|
@ -314,6 +356,26 @@ public:
|
|||
assert(uids.size() >= i + 1);
|
||||
return uids[i];
|
||||
}
|
||||
/* guoye: start */
|
||||
// return the reference transcript word labels (word labels) for current minibatch
|
||||
/*const*/ std::vector<size_t> &wlabels()
|
||||
{
|
||||
checkhasdata();
|
||||
assert(wids.size() == 1);
|
||||
|
||||
return wids[0];
|
||||
}
|
||||
|
||||
// return the number of words for current minibatch
|
||||
/*const*/ std::vector<short> &nwords()
|
||||
{
|
||||
checkhasdata();
|
||||
assert(nws.size() == 1);
|
||||
|
||||
return nws[0];
|
||||
}
|
||||
|
||||
/* guoye: end */
|
||||
|
||||
std::vector<size_t> &sentends()
|
||||
{
|
||||
|
|
|
@ -194,6 +194,9 @@ static void augmentneighbors(const std::vector<std::vector<float>>& frames, cons
|
|||
// TODO: This is currently being hardcoded to unsigned short for saving space, which means untied context-dependent phones
|
||||
// will not work. This needs to be changed to dynamically choose what size to use based on the number of class ids.
|
||||
typedef unsigned short CLASSIDTYPE;
|
||||
/* guoye: start */
|
||||
typedef unsigned int WORDIDTYPE;
|
||||
/* guoye: end */
|
||||
typedef unsigned short HMMIDTYPE;
|
||||
|
||||
#ifndef _MSC_VER
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
//
|
||||
//
|
||||
// Copyright (c) Microsoft. All rights reserved.
|
||||
// Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
|
||||
//
|
||||
|
@ -9,12 +9,16 @@
|
|||
|
||||
#include "Basics.h"
|
||||
#include "fileutil.h" // for opening/reading the ARPA file
|
||||
/* guoye: start */
|
||||
// #include "fileutil.cpp"
|
||||
/* guoye: end */
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <algorithm> // for various sort() calls
|
||||
#include <math.h>
|
||||
|
||||
|
||||
namespace msra { namespace lm {
|
||||
|
||||
// ===========================================================================
|
||||
|
@ -92,15 +96,40 @@ static inline double invertlogprob(double logP)
|
|||
// compare function to allow char* as keys (without, unordered_map will correctly
|
||||
// compute a hash key from the actual strings, but then compare the pointers
|
||||
// -- duh!)
|
||||
struct less_strcmp : public std::binary_function<const char *, const char *, bool>
|
||||
/* guoye: start */
|
||||
// struct less_strcmp : public std::binary_function<const char *, const char *, bool>
|
||||
struct equal_strcmp : public std::binary_function<const char *, const char *, bool>
|
||||
{ // this implements operator<
|
||||
bool operator()(const char *const &_Left, const char *const &_Right) const
|
||||
{
|
||||
return strcmp(_Left, _Right) < 0;
|
||||
// return strcmp(_Left, _Right) < 0;
|
||||
return strcmp(_Left, _Right) == 0;
|
||||
}
|
||||
};
|
||||
/* guoye: end */
|
||||
// Hash functor for NUL-terminated C strings using the BKDR algorithm: the hash is computed
// from the characters of the string (not from the pointer value), so two distinct buffers
// with equal content hash to the same value.
// Returns std::size_t, as the standard Hash requirements for unordered containers specify;
// the value itself is the 32-bit BKDR hash with the top bit masked off.
struct BKDRHash
{
    std::size_t operator()(const char *str) const
    {
        const unsigned int seed = 131; // 31 131 1313 13131 131313 etc.
        unsigned int hash = 0;
        for (; *str; ++str)
        {
            hash = (hash * seed) + (*str); // unsigned arithmetic: wraps around, never overflows
        }

        return static_cast<std::size_t>(hash & (0x7FFFFFFF));
    }
};
|
||||
|
||||
class CSymbolSet : public std::unordered_map<const char *, int, std::hash<const char *>, less_strcmp>
|
||||
|
||||
/* guoye: start */
|
||||
/* Bug fix: the custom comparison functor passed to unordered_map must test for equality; the less-than comparator used in the original declaration (commented out below) is not right. The resulting behavior was very strange: the map was not built correctly. So, fix it. */
|
||||
// class CSymbolSet : public std::unordered_map<const char *, int, std::hash<const char *>, less_strcmp>
|
||||
// class CSymbolSet : public std::unordered_map<const char *, int, std::hash<const char *>, equal_strcmp>
|
||||
class CSymbolSet : public std::unordered_map<const char *, int, BKDRHash, equal_strcmp>
|
||||
/* guoye: end */
|
||||
{
|
||||
std::vector<const char *> symbols; // the symbols
|
||||
|
||||
|
@ -128,7 +157,9 @@ public:
|
|||
// get id for an existing word, returns -1 if not existing
|
||||
int operator[](const char *key) const
|
||||
{
|
||||
unordered_map<const char *, int>::const_iterator iter = find(key);
|
||||
/* guoye: start */
|
||||
// unordered_map<const char *, int>::const_iterator iter = find(key);
|
||||
unordered_map<const char *, int, BKDRHash, equal_strcmp>::const_iterator iter = find(key);
|
||||
return (iter != end()) ? iter->second : -1;
|
||||
}
|
||||
|
||||
|
@ -136,7 +167,10 @@ public:
|
|||
// determine unique id for a word ('key')
|
||||
int operator[](const char *key)
|
||||
{
|
||||
unordered_map<const char *, int>::const_iterator iter = find(key);
|
||||
/* guoye: start */
|
||||
// unordered_map<const char *, int>::const_iterator iter = find(key);
|
||||
unordered_map<const char *, int, BKDRHash, equal_strcmp>::const_iterator iter = find(key);
|
||||
|
||||
if (iter != end())
|
||||
return iter->second;
|
||||
|
||||
|
@ -149,7 +183,11 @@ public:
|
|||
{
|
||||
int id = (int) symbols.size();
|
||||
symbols.push_back(p); // we own the memory--remember to free it
|
||||
insert(std::make_pair(p, id));
|
||||
/* guoye: start */
|
||||
// insert(std::make_pair(p, id));
|
||||
if(!insert(std::make_pair(p, id)).second)
|
||||
RuntimeError("Insertion key %s into map failed in msra_mgram.h", p);
|
||||
/* guoye: end */
|
||||
return id;
|
||||
}
|
||||
catch (...)
|
||||
|
@ -1450,30 +1488,67 @@ public:
|
|||
int lineNo = 0;
|
||||
auto_file_ptr f(fopenOrDie(pathname, L"rbS"));
|
||||
fprintf(stderr, "read: reading %ls", pathname.c_str());
|
||||
/* guoye: start */
|
||||
//fprintf(stderr, "\n msra_mgram.h: read: debug 0\n");
|
||||
/* guoye: end */
|
||||
filename = pathname; // (keep this info for debugging)
|
||||
|
||||
/* guoye: start */
|
||||
// fprintf(stderr, "\n msra_mgram.h: read: debug 0.1\n");
|
||||
/* guoye: end */
|
||||
// --- read header information
|
||||
|
||||
// search for header line
|
||||
char buf[1024];
|
||||
lineNo++, fgetline(f, buf);
|
||||
/* guoye: start */
|
||||
// fprintf(stderr, "\n msra_mgram.h: read: debug 0.2\n");
|
||||
/* guoye: end */
|
||||
/* guoye: start */
|
||||
// lineNo++, fgetline(f, buf);
|
||||
lineNo++;
|
||||
// fprintf(stderr, "\n msra_mgram.h: read: debug 0.25\n");
|
||||
fgetline(f, buf);
|
||||
|
||||
// fprintf(stderr, "\n msra_mgram.h: read: debug 0.3\n");
|
||||
/* guoye: end */
|
||||
while (strcmp(buf, "\\data\\") != 0 && !feof(f))
|
||||
lineNo++, fgetline(f, buf);
|
||||
lineNo++, fgetline(f, buf);
|
||||
/* guoye: start */
|
||||
{
|
||||
// lineNo++, fgetline(f, buf);
|
||||
lineNo++;
|
||||
fgetline(f, buf);
|
||||
}
|
||||
/* guoye: end */
|
||||
/* guoye: start */
|
||||
|
||||
// lineNo++, fgetline(f, buf);
|
||||
lineNo++;
|
||||
fgetline(f, buf);
|
||||
|
||||
// fprintf(stderr, "\n msra_mgram.h: read: debug 1\n");
|
||||
/* guoye: end */
|
||||
|
||||
// get the dimensions
|
||||
std::vector<int> dims;
|
||||
dims.reserve(4);
|
||||
|
||||
while (buf[0] == 0 && !feof(f))
|
||||
lineNo++, fgetline(f, buf);
|
||||
|
||||
/* guoye: start */
|
||||
{
|
||||
// lineNo++, fgetline(f, buf);
|
||||
lineNo++;
|
||||
fgetline(f, buf);
|
||||
}
|
||||
/* guoye: end */
|
||||
int n, dim;
|
||||
dims.push_back(1); // dummy zerogram entry
|
||||
while (sscanf(buf, "ngram %d=%d", &n, &dim) == 2 && n == (int) dims.size())
|
||||
{
|
||||
dims.push_back(dim);
|
||||
lineNo++, fgetline(f, buf);
|
||||
/* guoye: start */
|
||||
// lineNo++, fgetline(f, buf);
|
||||
lineNo++;
|
||||
fgetline(f, buf);
|
||||
/* guoye: end */
|
||||
}
|
||||
|
||||
M = (int) dims.size() - 1;
|
||||
|
@ -1483,6 +1558,9 @@ public:
|
|||
if (M > maxM)
|
||||
M = maxM;
|
||||
|
||||
/* guoye: start */
|
||||
// fprintf(stderr, "\n msra_mgram.h: read: debug 2\n");
|
||||
/* guoye: end */
|
||||
// allocate main storage
|
||||
map.init(M);
|
||||
logP.init(M);
|
||||
|
@ -1502,18 +1580,34 @@ public:
|
|||
std::vector<bool> skipWord; // true: skip entry containing this word
|
||||
skipWord.reserve(lmSymbols.capacity());
|
||||
|
||||
/* guoye: start */
|
||||
// fprintf(stderr, "\n msra_mgram.h: read: debug 3\n");
|
||||
/* guoye: end */
|
||||
|
||||
// --- read main sections
|
||||
|
||||
const double ln10xLMF = log(10.0); // ARPA scores are strangely scaled
|
||||
msra::strfun::tokenizer tokens(" \t\n\r", M + 1); // used in tokenizing the input line
|
||||
/* guoye: start */
|
||||
// fprintf(stderr, "\n msra_mgram.h: read: debug 4\n");
|
||||
/* guoye: end */
|
||||
for (int m = 1; m <= M; m++)
|
||||
{
|
||||
while (buf[0] == 0 && !feof(f))
|
||||
lineNo++, fgetline(f, buf);
|
||||
|
||||
/* guoye: start */
|
||||
{
|
||||
// lineNo++, fgetline(f, buf);
|
||||
lineNo++;
|
||||
fgetline(f, buf);
|
||||
}
|
||||
/* guoye: end */
|
||||
if (sscanf(buf, "\\%d-grams:", &n) != 1 || n != m)
|
||||
RuntimeError("read: mal-formed LM file, bad section header (%d): %ls", lineNo, pathname.c_str());
|
||||
lineNo++, fgetline(f, buf);
|
||||
/* guoye: start */
|
||||
//lineNo++, fgetline(f, buf);
|
||||
lineNo++;
|
||||
fgetline(f, buf);
|
||||
/* guoye: end */
|
||||
|
||||
std::vector<int> mgram(m + 1, -1); // current mgram being read ([0]=dummy)
|
||||
std::vector<int> prevmgram(m + 1, -1); // cache to speed up symbol lookup
|
||||
|
@ -1524,7 +1618,11 @@ public:
|
|||
{
|
||||
if (buf[0] == 0)
|
||||
{
|
||||
lineNo++, fgetline(f, buf);
|
||||
/* guoye: start */
|
||||
// lineNo++, fgetline(f, buf);
|
||||
lineNo++;
|
||||
fgetline(f, buf);
|
||||
/* guoye: end */
|
||||
continue;
|
||||
}
|
||||
|
||||
|
@ -1576,8 +1674,11 @@ public:
|
|||
double boVal = atof(tokens[m + 1]); // ... use sscanf() instead for error checking?
|
||||
thisLogB = boVal * ln10xLMF; // convert to natural log
|
||||
}
|
||||
|
||||
lineNo++, fgetline(f, buf);
|
||||
/* guoye: start */
|
||||
// lineNo++, fgetline(f, buf);
|
||||
lineNo++;
|
||||
fgetline(f, buf);
|
||||
/* guoye: end */
|
||||
|
||||
if (skipEntry) // word contained unknown vocabulary: skip entire entry
|
||||
goto skipMGram;
|
||||
|
@ -1615,17 +1716,29 @@ public:
|
|||
|
||||
fprintf(stderr, ", %d %d-grams", map.size(m), m);
|
||||
}
|
||||
/* guoye: start */
|
||||
// fprintf(stderr, "\n msra_mgram.h: read: debug 5\n");
|
||||
/* guoye: end */
|
||||
fprintf(stderr, "\n");
|
||||
|
||||
// check end tag
|
||||
if (M == fileM)
|
||||
{ // only if caller did not restrict us to a lower order
|
||||
while (buf[0] == 0 && !feof(f))
|
||||
lineNo++, fgetline(f, buf);
|
||||
/* guoye: start */
|
||||
{
|
||||
lineNo++;
|
||||
fgetline(f, buf);
|
||||
// lineNo++, fgetline(f, buf);
|
||||
}
|
||||
if (strcmp(buf, "\\end\\") != 0)
|
||||
RuntimeError("read: mal-formed LM file, no \\end\\ tag (%d): %ls", lineNo, pathname.c_str());
|
||||
}
|
||||
|
||||
/* guoye: start */
|
||||
// fprintf(stderr, "\n msra_mgram.h: read: debug 6 \n");
|
||||
/* guoye: end */
|
||||
|
||||
// update zerogram score by one appropriate for OOVs
|
||||
updateOOVScore();
|
||||
|
||||
|
@ -1638,7 +1751,13 @@ public:
|
|||
int id = symbolToId(sym); // may be -1 if not found
|
||||
userToLMSymMap[i] = id;
|
||||
}
|
||||
/* guoye: start */
|
||||
// fprintf(stderr, "\n msra_mgram.h: read: debug 7 \n");
|
||||
/* guoye: end */
|
||||
map.created(userToLMSymMap);
|
||||
/* guoye: start */
|
||||
// fprintf(stderr, "\n msra_mgram.h: read: debug 8 \n");
|
||||
/* guoye: end */
|
||||
}
|
||||
|
||||
protected:
|
||||
|
|
|
@ -561,7 +561,10 @@ class minibatchframesourcemulti : public minibatchsource
|
|||
public:
|
||||
// constructor
|
||||
// Pass empty labels to denote unsupervised training (so getbatch() will not return uids).
|
||||
minibatchframesourcemulti(const std::vector<std::vector<std::wstring>> &infiles, const std::vector<std::map<std::wstring, std::vector<msra::asr::htkmlfentry>>> &labels,
|
||||
/* guoye: start */
|
||||
// minibatchframesourcemulti(const std::vector<std::vector<std::wstring>> &infiles, const std::vector<std::map<std::wstring, std::vector<msra::asr::htkmlfentry>>> &labels,
|
||||
minibatchframesourcemulti(const std::vector<std::vector<std::wstring>> &infiles, const std::vector<std::map<std::wstring, std::pair<std::vector<msra::asr::htkmlfentry>, std::vector<unsigned int>>>> &labels,
|
||||
/* guoye: end */
|
||||
std::vector<size_t> vdim, std::vector<size_t> udim, std::vector<size_t> leftcontext, std::vector<size_t> rightcontext, size_t randomizationrange, const std::vector<std::wstring> &pagepath, const bool mayhavenoframe = false, int addEnergy = 0)
|
||||
: vdim(vdim), leftcontext(leftcontext), rightcontext(rightcontext), sampperiod(0), featdim(0), numframes(0), timegetbatch(0), verbosity(2), maxvdim(0)
|
||||
{
|
||||
|
@ -656,7 +659,10 @@ public:
|
|||
// HVite occasionally generates mismatching output --skip such files
|
||||
if (!key.empty()) // (we have a key if supervised mode)
|
||||
{
|
||||
const auto &labseq = labels[0].find(key)->second; // (we already checked above that it exists)
|
||||
/* guoye: start */
|
||||
// const auto &labseq = labels[0].find(key)->second; // (we already checked above that it exists)
|
||||
const auto &labseq = labels[0].find(key)->second.first; // (we already checked above that it exists)
|
||||
/* guoye: end */
|
||||
size_t labframes = labseq.empty() ? 0 : (labseq[labseq.size() - 1].firstframe + labseq[labseq.size() - 1].numframes);
|
||||
if (abs((int) labframes - (int) feat.cols()) > 0)
|
||||
{
|
||||
|
@ -695,7 +701,7 @@ public:
|
|||
{
|
||||
foreach_index (j, labels)
|
||||
{
|
||||
const auto &labseq = labels[j].find(key)->second; // (we already checked above that it exists)
|
||||
const auto &labseq = labels[j].find(key)->second.first; // (we already checked above that it exists)
|
||||
foreach_index (i2, labseq)
|
||||
{
|
||||
const auto &e = labseq[i2];
|
||||
|
|
|
@ -14,6 +14,9 @@
|
|||
#include "minibatchiterator.h"
|
||||
#include <unordered_set>
|
||||
#include <random>
|
||||
/* guoye: start */
|
||||
#include <set>
|
||||
/* guoye: end */
|
||||
|
||||
namespace msra { namespace dbn {
|
||||
|
||||
|
@ -36,6 +39,9 @@ class minibatchutterancesourcemulti : public minibatchsource
|
|||
const bool truncated; //false -> truncated utterance or not within minibatch
|
||||
size_t maxUtteranceLength; //10000 ->maximum utterance length in non-frame and non-truncated mode
|
||||
|
||||
/* guoye: start */
|
||||
std::set<int> specialwordids; // stores the word ids that will not be counted for WER computation
|
||||
/* guoye: end */
|
||||
std::vector<std::vector<size_t>> counts; // [s] occurence count for all states (used for priors)
|
||||
int verbosity;
|
||||
// lattice reader
|
||||
|
@ -55,9 +61,18 @@ class minibatchutterancesourcemulti : public minibatchsource
|
|||
{
|
||||
msra::asr::htkfeatreader::parsedpath parsedpath; // archive filename and frame range in that file
|
||||
size_t classidsbegin; // index into allclassids[] array (first frame)
|
||||
/* guoye: start */
|
||||
size_t wordidsbegin;
|
||||
|
||||
short numwords;
|
||||
|
||||
//utterancedesc(msra::asr::htkfeatreader::parsedpath &&ppath, size_t classidsbegin)
|
||||
// : parsedpath(std::move(ppath)), classidsbegin(classidsbegin), framesToExpand(0), needsExpansion(false)
|
||||
// Builds an utterance descriptor from a parsed feature path and the start offsets of
// this utterance's class ids (classidsbegin) and word ids (wordidsbegin).
// numwords is explicitly zero-initialized: it only receives a real value through
// setnumwords(), and descriptors that never get that call would otherwise expose an
// indeterminate value to readers of numwords.
utterancedesc(msra::asr::htkfeatreader::parsedpath &&ppath, size_t classidsbegin, size_t wordidsbegin)
    : parsedpath(std::move(ppath)), classidsbegin(classidsbegin), wordidsbegin(wordidsbegin), numwords(0), framesToExpand(0), needsExpansion(false)
/* guoye: end */
{
}
|
||||
bool needsExpansion; // ivector type of feature
|
||||
|
@ -73,6 +88,17 @@ class minibatchutterancesourcemulti : public minibatchsource
|
|||
else
|
||||
return parsedpath.numframes();
|
||||
}
|
||||
/* guoye: start */
|
||||
short getnumwords() const
|
||||
{
|
||||
return numwords;
|
||||
}
|
||||
|
||||
void setnumwords(short nw)
|
||||
{
|
||||
numwords = nw;
|
||||
}
|
||||
/* guoye: end */
|
||||
std::wstring key() const // key used for looking up lattice (not stored to save space)
|
||||
{
|
||||
#ifdef _MSC_VER
|
||||
|
@ -129,6 +155,18 @@ class minibatchutterancesourcemulti : public minibatchsource
|
|||
{
|
||||
return utteranceset[i].classidsbegin;
|
||||
}
|
||||
/* guoye: start */
|
||||
size_t getwordidsbegin(size_t i) const
|
||||
{
|
||||
return utteranceset[i].wordidsbegin;
|
||||
}
|
||||
|
||||
short numwords(size_t i) const
|
||||
{
|
||||
return utteranceset[i].numwords;
|
||||
}
|
||||
|
||||
/* guoye: end */
|
||||
msra::dbn::matrixstripe getutteranceframes(size_t i) const // return the frame set for a given utterance
|
||||
{
|
||||
if (!isinram())
|
||||
|
@ -152,8 +190,9 @@ class minibatchutterancesourcemulti : public minibatchsource
|
|||
}
|
||||
// page in data for this chunk
|
||||
// We pass in the feature info variables by ref which will be filled lazily upon first read
|
||||
void requiredata(std::string &featkind, size_t &featdim, unsigned int &sampperiod, const latticesource &latticesource, int verbosity = 0) const
|
||||
void requiredata(std::string &featkind, size_t &featdim, unsigned int &sampperiod, const latticesource &latticesource, std::set<int>& specialwordids, int verbosity = 0) const
|
||||
{
|
||||
|
||||
if (numutterances() == 0)
|
||||
LogicError("requiredata: cannot page in virgin block");
|
||||
if (isinram())
|
||||
|
@ -181,8 +220,12 @@ class minibatchutterancesourcemulti : public minibatchsource
|
|||
auto uttframes = getutteranceframes(i); // matrix stripe for this utterance (currently unfilled)
|
||||
reader.read(utteranceset[i].parsedpath, (const std::string &)featkind, sampperiod, uttframes, utteranceset[i].needsExpansion); // note: file info here used for checkuing only
|
||||
// page in lattice data
|
||||
/* guoye: start */
|
||||
|
||||
if (!latticesource.empty())
|
||||
latticesource.getlattices(utteranceset[i].key(), lattices[i], uttframes.cols());
|
||||
latticesource.getlattices(utteranceset[i].key(), lattices[i], uttframes.cols(), specialwordids);
|
||||
|
||||
/* guoye: end */
|
||||
}
|
||||
if (verbosity)
|
||||
{
|
||||
|
@ -234,6 +277,10 @@ class minibatchutterancesourcemulti : public minibatchsource
|
|||
std::vector<std::vector<utterancechunkdata>> allchunks; // set of utterances organized in chunks, referred to by an iterator (not an index)
|
||||
std::vector<std::unique_ptr<biggrowablevector<CLASSIDTYPE>>> classids; // [classidsbegin+t] concatenation of all state sequences
|
||||
|
||||
/* guoye: start */
|
||||
std::vector<std::unique_ptr<biggrowablevector<WORDIDTYPE>>> wordids; // [wordidsbegin+t] concatenation of all state sequences
|
||||
|
||||
/* guoye: end */
|
||||
bool m_generatePhoneBoundaries;
|
||||
std::vector<std::unique_ptr<biggrowablevector<HMMIDTYPE>>> phoneboundaries;
|
||||
bool issupervised() const
|
||||
|
@ -299,6 +346,9 @@ class minibatchutterancesourcemulti : public minibatchsource
|
|||
}
|
||||
|
||||
size_t numframes; // (cached since we cannot directly access the underlying data from here)
|
||||
/* guoye: start */
|
||||
short numwords;
|
||||
/* guoye: end */
|
||||
size_t globalts; // start frame in global space after randomization (for mapping frame index to utterance position)
|
||||
size_t globalte() const
|
||||
{
|
||||
|
@ -850,6 +900,32 @@ class minibatchutterancesourcemulti : public minibatchsource
|
|||
}
|
||||
return allclassids; // nothing to return
|
||||
}
|
||||
|
||||
/* guoye: start */
|
||||
template <class UTTREF>
|
||||
std::vector<shiftedvector<biggrowablevector<WORDIDTYPE>>> getwordids(const UTTREF &uttref) // return sub-vector of classids[] for a given utterance
|
||||
{
|
||||
std::vector<shiftedvector<biggrowablevector<WORDIDTYPE>>> allwordids;
|
||||
|
||||
if (!issupervised())
|
||||
{
|
||||
foreach_index(i, wordids)
|
||||
allwordids.push_back(std::move(shiftedvector<biggrowablevector<WORDIDTYPE>>((*wordids[i]), 0, 0)));
|
||||
return allwordids; // nothing to return
|
||||
}
|
||||
const auto &chunk = randomizedchunks[0][uttref.chunkindex];
|
||||
const auto &chunkdata = chunk.getchunkdata();
|
||||
const size_t wordidsbegin = chunkdata.getwordidsbegin(uttref.utteranceindex()); // index of first state label in global concatenated classids[] array
|
||||
const size_t n = chunkdata.numwords(uttref.utteranceindex());
|
||||
foreach_index(i, wordids)
|
||||
{
|
||||
if ((*wordids[i])[wordidsbegin + n] != (WORDIDTYPE)-1)
|
||||
LogicError("getwordids: expected boundary marker not found, internal data structure screwed up");
|
||||
allwordids.push_back(std::move(shiftedvector<biggrowablevector<WORDIDTYPE>>((*wordids[i]), wordidsbegin, n)));
|
||||
}
|
||||
return allwordids; // nothing to return
|
||||
}
|
||||
/* guoye: end */
|
||||
template <class UTTREF>
|
||||
std::vector<shiftedvector<biggrowablevector<HMMIDTYPE>>> getphonebound(const UTTREF &uttref) // return sub-vector of classids[] for a given utterance
|
||||
{
|
||||
|
@ -882,13 +958,23 @@ public:
|
|||
// constructor
|
||||
// Pass empty labels to denote unsupervised training (so getbatch() will not return uids).
|
||||
// This mode requires utterances with time stamps.
|
||||
minibatchutterancesourcemulti(bool useMersenneTwister, const std::vector<std::vector<std::wstring>> &infiles, const std::vector<std::map<std::wstring, std::vector<msra::asr::htkmlfentry>>> &labels,
|
||||
/* guoye: start */
|
||||
|
||||
// minibatchutterancesourcemulti(bool useMersenneTwister, const std::vector<std::vector<std::wstring>> &infiles, const std::vector<std::map<std::wstring, std::vector<msra::asr::htkmlfentry>>> &labels,
|
||||
// minibatchutterancesourcemulti(bool useMersenneTwister, const std::vector<std::vector<std::wstring>> &infiles, const std::vector<std::map<std::wstring, std::vector<msra::asr::htkmlfentry>>> &labels,
|
||||
minibatchutterancesourcemulti(bool useMersenneTwister, const std::vector<std::vector<std::wstring>> &infiles, const std::vector<std::map<std::wstring, std::pair<std::vector<msra::asr::htkmlfentry>, std::vector<unsigned int>>>> &labels,
|
||||
// const std::vector<std::map<std::wstring, msra::lattices::lattice::htkmlfwordsequence>>& wordlabels,
|
||||
std::set<int>& specialwordids,
|
||||
/* guoye: end */
|
||||
std::vector<size_t> vdim, std::vector<size_t> udim, std::vector<size_t> leftcontext, std::vector<size_t> rightcontext, size_t randomizationrange,
|
||||
const latticesource &lattices, const std::map<std::wstring, msra::lattices::lattice::htkmlfwordsequence> &allwordtranscripts, const bool framemode, std::vector<bool> expandToUtt,
|
||||
const size_t maxUtteranceLength, const bool truncated)
|
||||
: vdim(vdim), leftcontext(leftcontext), rightcontext(rightcontext), sampperiod(0), featdim(0), randomizationrange(randomizationrange), currentsweep(SIZE_MAX),
|
||||
lattices(lattices), allwordtranscripts(allwordtranscripts), framemode(framemode), chunksinram(0), timegetbatch(0), verbosity(2), m_generatePhoneBoundaries(!lattices.empty()),
|
||||
m_frameRandomizer(randomizedchunks, useMersenneTwister), expandToUtt(expandToUtt), m_useMersenneTwister(useMersenneTwister), maxUtteranceLength(maxUtteranceLength), truncated(truncated)
|
||||
/* guoye: start */
|
||||
, specialwordids(specialwordids)
|
||||
/* guoye: end */
|
||||
// [v-hansu] change framemode (lattices.empty()) into framemode (false) to run utterance mode without lattice
|
||||
// you also need to change another line, search : [v-hansu] comment out to run utterance mode without lattice
|
||||
{
|
||||
|
@ -905,6 +991,9 @@ public:
|
|||
std::vector<size_t> uttduration; // track utterance durations to determine utterance validity
|
||||
|
||||
std::vector<size_t> classidsbegin;
|
||||
/* guoye: start */
|
||||
std::vector<size_t> wordidsbegin;
|
||||
/* guoye: end */
|
||||
|
||||
allchunks = std::vector<std::vector<utterancechunkdata>>(infiles.size(), std::vector<utterancechunkdata>());
|
||||
featdim = std::vector<size_t>(infiles.size(), 0);
|
||||
|
@ -917,6 +1006,9 @@ public:
|
|||
foreach_index (i, labels)
|
||||
{
|
||||
classids.push_back(std::unique_ptr<biggrowablevector<CLASSIDTYPE>>(new biggrowablevector<CLASSIDTYPE>()));
|
||||
/* guoye: start */
|
||||
wordids.push_back(std::unique_ptr<biggrowablevector<WORDIDTYPE>>(new biggrowablevector<WORDIDTYPE>()));
|
||||
/* guoye: end */
|
||||
if (m_generatePhoneBoundaries)
|
||||
phoneboundaries.push_back(std::unique_ptr<biggrowablevector<HMMIDTYPE>>(new biggrowablevector<HMMIDTYPE>()));
|
||||
|
||||
|
@ -945,7 +1037,10 @@ public:
|
|||
|
||||
foreach_index (i, infiles[m])
|
||||
{
|
||||
utterancedesc utterance(msra::asr::htkfeatreader::parsedpath(infiles[m][i]), 0); // mseltzer - is this foolproof for multiio? is classids always non-empty?
|
||||
/* guoye: start */
|
||||
// utterancedesc utterance(msra::asr::htkfeatreader::parsedpath(infiles[m][i]), 0); // mseltzer - is this foolproof for multiio? is classids always non-empty?
|
||||
utterancedesc utterance(msra::asr::htkfeatreader::parsedpath(infiles[m][i]), 0, 0);
|
||||
/* guoye: end */
|
||||
const size_t uttframes = utterance.numframes(); // will throw if frame bounds not given --required to be given in this mode
|
||||
if (expandToUtt[m] && uttframes != 1)
|
||||
RuntimeError("minibatchutterancesource: utterance-based features must be 1 frame in duration");
|
||||
|
@ -1002,9 +1097,13 @@ public:
|
|||
// else
|
||||
// if (infiles[m].size()!=numutts)
|
||||
// RuntimeError("minibatchutterancesourcemulti: all feature files must have same number of utterances\n");
|
||||
/* guoye: start */
|
||||
if (m == 0)
|
||||
{
|
||||
classidsbegin.clear();
|
||||
|
||||
wordidsbegin.clear();
|
||||
}
|
||||
/* guoye: end */
|
||||
foreach_index (i, infiles[m])
|
||||
{
|
||||
if (i % (infiles[m].size() / 100 + 1) == 0)
|
||||
|
@ -1013,12 +1112,21 @@ public:
|
|||
fflush(stderr);
|
||||
}
|
||||
// build utterance descriptor
|
||||
/* guoye: start */
|
||||
if (m == 0 && !labels.empty())
|
||||
{
|
||||
classidsbegin.push_back(classids[0]->size());
|
||||
wordidsbegin.push_back(wordids[0]->size());
|
||||
}
|
||||
/* guoye: end */
|
||||
|
||||
if (uttisvalid[i])
|
||||
{
|
||||
utterancedesc utterance(msra::asr::htkfeatreader::parsedpath(infiles[m][i]), labels.empty() ? 0 : classidsbegin[i]); // mseltzer - is this foolproof for multiio? is classids always non-empty?
|
||||
/* guoye: start */
|
||||
// utterancedesc utterance(msra::asr::htkfeatreader::parsedpath(infiles[m][i]), labels.empty() ? 0 : classidsbegin[i]); // mseltzer - is this foolproof for multiio? is classids always non-empty?
|
||||
utterancedesc utterance(msra::asr::htkfeatreader::parsedpath(infiles[m][i]), labels.empty() ? 0 : classidsbegin[i], labels.empty() ? 0 : wordidsbegin[i]); // mseltzer - is this foolproof for multiio? is classids always non-empty?
|
||||
|
||||
/* guoye: end */
|
||||
const size_t uttframes = utterance.numframes(); // will throw if frame bounds not given --required to be given in this mode
|
||||
if (expandToUtt[m])
|
||||
{
|
||||
|
@ -1078,7 +1186,10 @@ public:
|
|||
// first verify that all the label files have the proper duration
|
||||
foreach_index (j, labels)
|
||||
{
|
||||
const auto &labseq = labels[j].find(key)->second;
|
||||
/* guoye: start */
|
||||
// const auto &labseq = labels[j].find(key)->second;
|
||||
const auto &labseq = labels[j].find(key)->second.first;
|
||||
/* guoye: end */
|
||||
// check if durations match; skip if not
|
||||
size_t labframes = labseq.empty() ? 0 : (labseq[labseq.size() - 1].firstframe + labseq[labseq.size() - 1].numframes);
|
||||
if (labframes != uttframes)
|
||||
|
@ -1092,12 +1203,19 @@ public:
|
|||
}
|
||||
if (uttisvalid[i])
|
||||
{
|
||||
utteranceset.push_back(std::move(utterance));
|
||||
/* guoye: start */
|
||||
// utteranceset.push_back(std::move(utterance));
|
||||
/* guoye: end */
|
||||
_totalframes += uttframes;
|
||||
// then parse each mlf if the durations are consistent
|
||||
foreach_index (j, labels)
|
||||
{
|
||||
const auto &labseq = labels[j].find(key)->second;
|
||||
/* guoye: start */
|
||||
// const auto &labseq = labels[j].find(key)->second;
|
||||
const auto & seqs = labels[j].find(key)->second;
|
||||
// const auto &labseq = labels[j].find(key)->second.first;
|
||||
const auto &labseq = seqs.first;
|
||||
/* guoye: end */
|
||||
// expand classid sequence into flat array
|
||||
foreach_index (i2, labseq)
|
||||
{
|
||||
|
@ -1126,18 +1244,75 @@ public:
|
|||
}
|
||||
|
||||
classids[j]->push_back((CLASSIDTYPE) -1); // append a boundary marker marker for checking
|
||||
|
||||
|
||||
if (m_generatePhoneBoundaries)
|
||||
phoneboundaries[j]->push_back((HMMIDTYPE) -1); // append a boundary marker marker for checking
|
||||
|
||||
/* guoye: start */
|
||||
/*
|
||||
if (!labels[j].empty() && classids[j]->size() != _totalframes + utteranceset.size())
|
||||
LogicError("minibatchutterancesource: label duration inconsistent with feature file in MLF label set: %ls", key.c_str());
|
||||
assert(labels[j].empty() || classids[j]->size() == _totalframes + utteranceset.size());
|
||||
*/
|
||||
// guoye: the "+ 1" is needed because utteranceset.push_back(std::move(utterance)) now happens later, after this check
|
||||
if (!labels[j].empty() && classids[j]->size() != _totalframes + utteranceset.size() + 1)
|
||||
LogicError("minibatchutterancesource: label duration inconsistent with feature file in MLF label set: %ls", key.c_str());
|
||||
assert(labels[j].empty() || classids[j]->size() == _totalframes + utteranceset.size() + 1);
|
||||
/* guoye: end */
|
||||
|
||||
const auto &wordlabseq = seqs.second;
|
||||
|
||||
if (j == 0)
|
||||
utterance.setnumwords(short(wordlabseq.size()));
|
||||
|
||||
foreach_index(i2, wordlabseq)
|
||||
{
|
||||
const auto &e = wordlabseq[i2];
|
||||
if (e != (WORDIDTYPE)e)
|
||||
RuntimeError("WORDIDTYPE has too few bits");
|
||||
|
||||
wordids[j]->push_back(e);
|
||||
}
|
||||
wordids[j]->push_back((WORDIDTYPE)-1); // append a boundary marker marker for checking
|
||||
}
|
||||
|
||||
/* guoye: start */
|
||||
|
||||
/* mask for guoye debug */
|
||||
/*
|
||||
foreach_index(j, wordlabels)
|
||||
{
|
||||
const auto &wordlabseq = wordlabels[j].find(key)->second.words;
|
||||
// expand classid sequence into flat array
|
||||
|
||||
if (j == 0)
|
||||
utterance.setnumwords(short(wordlabseq.size()));
|
||||
|
||||
foreach_index(i2, wordlabseq)
|
||||
{
|
||||
const auto &e = wordlabseq[i2];
|
||||
if (e.wordindex != (WORDIDTYPE)e.wordindex)
|
||||
RuntimeError("WORDIDTYPE has too few bits");
|
||||
|
||||
wordids[j]->push_back(e.wordindex);
|
||||
}
|
||||
wordids[j]->push_back((WORDIDTYPE)-1); // append a boundary marker marker for checking
|
||||
|
||||
}
|
||||
*/
|
||||
|
||||
/* guoye: end */
|
||||
utteranceset.push_back(std::move(utterance));
|
||||
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
assert(classids.empty() && labels.empty());
|
||||
/* guoye: start */
|
||||
// assert(classids.empty() && labels.empty());
|
||||
assert(classids.empty() && labels.empty() && wordids.empty());
|
||||
/* guoye: end */
|
||||
utteranceset.push_back(std::move(utterance));
|
||||
_totalframes += uttframes;
|
||||
}
|
||||
|
@ -1424,6 +1599,9 @@ private:
|
|||
auto &uttref = randomizedutterancerefs[i];
|
||||
uttref.globalts = t;
|
||||
uttref.numframes = randomizedchunks[0][uttref.chunkindex].getchunkdata().numframes(uttref.utteranceindex());
|
||||
/* guoye: start */
|
||||
uttref.numwords = randomizedchunks[0][uttref.chunkindex].getchunkdata().numwords(uttref.utteranceindex());
|
||||
/* guoye: end */
|
||||
t = uttref.globalte();
|
||||
}
|
||||
assert(t == sweepts + _totalframes);
|
||||
|
@ -1486,6 +1664,7 @@ private:
|
|||
// Returns true if we actually did read something.
|
||||
bool requirerandomizedchunk(const size_t chunkindex, const size_t windowbegin, const size_t windowend)
|
||||
{
|
||||
|
||||
size_t numinram = 0;
|
||||
|
||||
if (chunkindex < windowbegin || chunkindex >= windowend)
|
||||
|
@ -1510,7 +1689,10 @@ private:
|
|||
fprintf(stderr, "feature set %d: requirerandomizedchunk: paging in randomized chunk %d (frame range [%d..%d]), %d resident in RAM\n", m, (int) chunkindex, (int) chunk.globalts, (int) (chunk.globalte() - 1), (int) (chunksinram + 1));
|
||||
msra::util::attempt(5, [&]() // (reading from network)
|
||||
{
|
||||
chunkdata.requiredata(featkind[m], featdim[m], sampperiod[m], this->lattices, verbosity);
|
||||
/* guoye: start */
|
||||
// chunkdata.requiredata(featkind[m], featdim[m], sampperiod[m], this->lattices, verbosity);
|
||||
chunkdata.requiredata(featkind[m], featdim[m], sampperiod[m], this->lattices, specialwordids, verbosity);
|
||||
/* guoye: end */
|
||||
});
|
||||
}
|
||||
chunksinram++;
|
||||
|
@ -1561,6 +1743,8 @@ public:
|
|||
verbosity = newverbosity;
|
||||
}
|
||||
|
||||
|
||||
|
||||
// get the next minibatch
|
||||
// A minibatch is made up of one or more utterances.
|
||||
// We will return less than 'framesrequested' unless the first utterance is too long.
|
||||
|
@ -1569,16 +1753,21 @@ public:
|
|||
// This is efficient since getbatch() is called with sequential 'globalts' except at epoch start.
|
||||
// Note that the start of an epoch does not necessarily fall onto an utterance boundary. The caller must use firstvalidglobalts() to find the first valid globalts at or after a given time.
|
||||
// Support for data parallelism: If mpinodes > 1 then we will
|
||||
|
||||
// - load only a subset of blocks from the disk
|
||||
// - skip frames/utterances in not-loaded blocks in the returned data
|
||||
// - 'framesadvanced' will still return the logical #frames; that is, by how much the global time index is advanced
|
||||
bool getbatch(const size_t globalts, const size_t framesrequested,
|
||||
const size_t subsetnum, const size_t numsubsets, size_t &framesadvanced,
|
||||
std::vector<msra::dbn::matrix> &feat, std::vector<std::vector<size_t>> &uids,
|
||||
/* guoye: start */
|
||||
// std::vector<msra::dbn::matrix> &feat, std::vector<std::vector<size_t>> &uids,
|
||||
std::vector<msra::dbn::matrix> &feat, std::vector<std::vector<size_t>> &uids, std::vector<std::vector<size_t>> &wids, std::vector<std::vector<short>> &nws,
|
||||
/* guoye: end */
|
||||
std::vector<const_array_ref<msra::lattices::lattice::htkmlfwordsequence::word>> &transcripts,
|
||||
std::vector<std::shared_ptr<const latticesource::latticepair>> &latticepairs, std::vector<std::vector<size_t>> &sentendmark,
|
||||
std::vector<std::vector<size_t>> &phoneboundaries2) override
|
||||
{
|
||||
|
||||
bool readfromdisk = false; // return value: shall be 'true' if we paged in anything
|
||||
|
||||
auto_timer timergetbatch;
|
||||
|
@ -1624,6 +1813,9 @@ public:
|
|||
|
||||
// determine the true #frames we return, for allocation--it is less than mbframes in the case of MPI/data-parallel sub-set mode
|
||||
size_t tspos = 0;
|
||||
/* guoye: start */
|
||||
size_t twrds = 0;
|
||||
/* guoye: end */
|
||||
for (size_t pos = spos; pos < epos; pos++)
|
||||
{
|
||||
const auto &uttref = randomizedutterancerefs[pos];
|
||||
|
@ -1631,11 +1823,18 @@ public:
|
|||
continue;
|
||||
|
||||
tspos += uttref.numframes;
|
||||
/* guoye: start */
|
||||
twrds += uttref.numwords;
|
||||
/* guoye: end */
|
||||
}
|
||||
|
||||
// resize feat and uids
|
||||
feat.resize(vdim.size());
|
||||
uids.resize(classids.size());
|
||||
/* guoye: start */
|
||||
wids.resize(wordids.size());
|
||||
nws.resize(wordids.size());
|
||||
/* guoye: end */
|
||||
if (m_generatePhoneBoundaries)
|
||||
phoneboundaries2.resize(classids.size());
|
||||
sentendmark.resize(vdim.size());
|
||||
|
@ -1649,15 +1848,29 @@ public:
|
|||
{
|
||||
foreach_index (j, uids)
|
||||
{
|
||||
/* guoye: start */
|
||||
nws[j].clear();
|
||||
/* guoye: end */
|
||||
|
||||
if (issupervised()) // empty means unsupervised training -> return empty uids
|
||||
{
|
||||
uids[j].resize(tspos);
|
||||
/* guoye: start */
|
||||
wids[j].resize(twrds);
|
||||
/* guoye: end */
|
||||
if (m_generatePhoneBoundaries)
|
||||
phoneboundaries2[j].resize(tspos);
|
||||
}
|
||||
else
|
||||
{
|
||||
uids[i].clear();
|
||||
/* guoye: start */
|
||||
|
||||
// uids[i].clear();
|
||||
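// guoye: I think this was a bug: the index should be j, not i. NOTE(review): phoneboundaries2[i].clear() below looks like the same mix-up -- confirm.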
|
||||
uids[j].clear();
|
||||
|
||||
wids[j].clear();
|
||||
/* guoye: end */
|
||||
if (m_generatePhoneBoundaries)
|
||||
phoneboundaries2[i].clear();
|
||||
}
|
||||
|
@ -1674,6 +1887,9 @@ public:
|
|||
if (verbosity > 0)
|
||||
fprintf(stderr, "getbatch: getting utterances %d..%d (%d subset of %d frames out of %d requested) in sweep %d\n", (int) spos, (int) (epos - 1), (int) tspos, (int) mbframes, (int) framesrequested, (int) sweep);
|
||||
tspos = 0; // relative start of utterance 'pos' within the returned minibatch
|
||||
/* guoye: start */
|
||||
twrds = 0;
|
||||
/* guoye: end */
|
||||
for (size_t pos = spos; pos < epos; pos++)
|
||||
{
|
||||
const auto &uttref = randomizedutterancerefs[pos];
|
||||
|
@ -1681,6 +1897,9 @@ public:
|
|||
continue;
|
||||
|
||||
size_t n = 0;
|
||||
/* guoye: start */
|
||||
size_t nw = 0;
|
||||
/* guoye: end */
|
||||
foreach_index (i, randomizedchunks)
|
||||
{
|
||||
const auto &chunk = randomizedchunks[i][uttref.chunkindex];
|
||||
|
@ -1692,6 +1911,9 @@ public:
|
|||
sentendmark[i].push_back(n + tspos);
|
||||
assert(n == uttframes.cols() && uttref.numframes == n && chunkdata.numframes(uttref.utteranceindex()) == n);
|
||||
|
||||
/* guoye: start */
|
||||
nw = uttref.numwords;
|
||||
/* guoye: end */
|
||||
// copy the frames and class labels
|
||||
for (size_t t = 0; t < n; t++) // t = time index into source utterance
|
||||
{
|
||||
|
@ -1714,6 +1936,9 @@ public:
|
|||
if (i == 0)
|
||||
{
|
||||
auto uttclassids = getclassids(uttref);
|
||||
/* guoye: start */
|
||||
auto uttwordids = getwordids(uttref);
|
||||
/* guoye: end */
|
||||
std::vector<shiftedvector<biggrowablevector<HMMIDTYPE>>> uttphoneboudaries;
|
||||
if (m_generatePhoneBoundaries)
|
||||
uttphoneboudaries = getphonebound(uttref);
|
||||
|
@ -1742,9 +1967,28 @@ public:
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
/* guoye: start */
|
||||
foreach_index(j, uttwordids)
|
||||
{
|
||||
nws[j].push_back(short(nw));
|
||||
|
||||
for (size_t t = 0; t < nw; t++) // t = time index into source utterance
|
||||
{
|
||||
if (issupervised())
|
||||
{
|
||||
wids[j][t + twrds] = uttwordids[j][t];
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
/* guoye: end */
|
||||
}
|
||||
}
|
||||
tspos += n;
|
||||
/* guoye: start */
|
||||
twrds += nw;
|
||||
/* guoye: end */
|
||||
}
|
||||
|
||||
foreach_index (i, feat)
|
||||
|
@ -1795,6 +2039,9 @@ public:
|
|||
// resize feat and uids
|
||||
feat.resize(vdim.size());
|
||||
uids.resize(classids.size());
|
||||
/* guoye: start */
|
||||
// no need to care about wids for framemode = true
|
||||
/* guoye: end */
|
||||
assert(feat.size() == vdim.size());
|
||||
assert(feat.size() == randomizedchunks.size());
|
||||
foreach_index (i, feat)
|
||||
|
@ -1878,31 +2125,360 @@ public:
|
|||
return readfromdisk;
|
||||
}
|
||||
|
||||
// get the next minibatch
|
||||
// A minibatch is made up of one or more utterances.
|
||||
// We will return less than 'framesrequested' unless the first utterance is too long.
|
||||
// Note that this may return frames that are beyond the epoch end, but the first frame is always within the epoch.
|
||||
// We specify the utterance by its global start time (in a space of a infinitely repeated training set).
|
||||
// This is efficient since getbatch() is called with sequential 'globalts' except at epoch start.
|
||||
// Note that the start of an epoch does not necessarily fall onto an utterance boundary. The caller must use firstvalidglobalts() to find the first valid globalts at or after a given time.
|
||||
// Support for data parallelism: If mpinodes > 1 then we will
|
||||
// - load only a subset of blocks from the disk
|
||||
// - skip frames/utterances in not-loaded blocks in the returned data
|
||||
// - 'framesadvanced' will still return the logical #frames; that is, by how much the global time index is advanced
|
||||
/* guoye: start */
|
||||
bool getbatch(const size_t globalts, const size_t framesrequested,
|
||||
const size_t subsetnum, const size_t numsubsets, size_t &framesadvanced,
|
||||
std::vector<msra::dbn::matrix> &feat, std::vector<std::vector<size_t>> &uids,
|
||||
std::vector<const_array_ref<msra::lattices::lattice::htkmlfwordsequence::word>> &transcripts,
|
||||
std::vector<std::shared_ptr<const latticesource::latticepair>> &latticepairs, std::vector<std::vector<size_t>> &sentendmark,
|
||||
std::vector<std::vector<size_t>> &phoneboundaries2) override
|
||||
{
|
||||
|
||||
bool readfromdisk = false; // return value: shall be 'true' if we paged in anything
|
||||
|
||||
auto_timer timergetbatch;
|
||||
assert(_totalframes > 0);
|
||||
|
||||
// update randomization if a new sweep is entered --this is a complex operation that updates many of the data members used below
|
||||
const size_t sweep = lazyrandomization(globalts);
|
||||
|
||||
size_t mbframes = 0;
|
||||
const std::vector<char> noboundaryflags; // dummy
|
||||
if (!framemode) // regular utterance mode
|
||||
{
|
||||
// find utterance position for globalts
|
||||
// There must be a precise match; it is not possible to specify frames that are not on boundaries.
|
||||
auto positer = randomizedutteranceposmap.find(globalts);
|
||||
if (positer == randomizedutteranceposmap.end())
|
||||
LogicError("getbatch: invalid 'globalts' parameter; must match an existing utterance boundary");
|
||||
const size_t spos = positer->second;
|
||||
|
||||
// determine how many utterances will fit into the requested minibatch size
|
||||
mbframes = randomizedutterancerefs[spos].numframes; // at least one utterance, even if too long
|
||||
size_t epos;
|
||||
for (epos = spos + 1; epos < numutterances && ((mbframes + randomizedutterancerefs[epos].numframes) < framesrequested); epos++) // add more utterances as long as they fit within requested minibatch size
|
||||
mbframes += randomizedutterancerefs[epos].numframes;
|
||||
|
||||
// do some paging housekeeping
|
||||
// This will also set the feature-kind information if it's the first time.
|
||||
// Free all chunks left of the range.
|
||||
// Page-in all chunks right of the range.
|
||||
// We are a little more blunt for now: Free all outside the range, and page in only what is touched. We could save some loop iterations.
|
||||
const size_t windowbegin = positionchunkwindows[spos].windowbegin();
|
||||
const size_t windowend = positionchunkwindows[epos - 1].windowend();
|
||||
for (size_t k = 0; k < windowbegin; k++)
|
||||
releaserandomizedchunk(k);
|
||||
for (size_t k = windowend; k < randomizedchunks[0].size(); k++)
|
||||
releaserandomizedchunk(k);
|
||||
for (size_t pos = spos; pos < epos; pos++)
|
||||
if ((randomizedutterancerefs[pos].chunkindex % numsubsets) == subsetnum)
|
||||
readfromdisk |= requirerandomizedchunk(randomizedutterancerefs[pos].chunkindex, windowbegin, windowend); // (window range passed in for checking only)
|
||||
|
||||
// Note that the above loop loops over all chunks incl. those that we already should have.
|
||||
// This has an effect, e.g., if 'numsubsets' has changed (we will fill gaps).
|
||||
|
||||
// determine the true #frames we return, for allocation--it is less than mbframes in the case of MPI/data-parallel sub-set mode
|
||||
size_t tspos = 0;
|
||||
for (size_t pos = spos; pos < epos; pos++)
|
||||
{
|
||||
const auto &uttref = randomizedutterancerefs[pos];
|
||||
if ((uttref.chunkindex % numsubsets) != subsetnum) // chunk not to be returned for this MPI node
|
||||
continue;
|
||||
|
||||
tspos += uttref.numframes;
|
||||
}
|
||||
|
||||
// resize feat and uids
|
||||
feat.resize(vdim.size());
|
||||
uids.resize(classids.size());
|
||||
|
||||
if (m_generatePhoneBoundaries)
|
||||
phoneboundaries2.resize(classids.size());
|
||||
sentendmark.resize(vdim.size());
|
||||
assert(feat.size() == vdim.size());
|
||||
assert(feat.size() == randomizedchunks.size());
|
||||
foreach_index(i, feat)
|
||||
{
|
||||
feat[i].resize(vdim[i], tspos);
|
||||
|
||||
if (i == 0)
|
||||
{
|
||||
foreach_index(j, uids)
|
||||
{
|
||||
if (issupervised()) // empty means unsupervised training -> return empty uids
|
||||
{
|
||||
uids[j].resize(tspos);
|
||||
if (m_generatePhoneBoundaries)
|
||||
phoneboundaries2[j].resize(tspos);
|
||||
}
|
||||
else
|
||||
{
|
||||
uids[i].clear();
|
||||
if (m_generatePhoneBoundaries)
|
||||
phoneboundaries2[i].clear();
|
||||
}
|
||||
latticepairs.clear(); // will push_back() below
|
||||
transcripts.clear();
|
||||
}
|
||||
foreach_index(j, sentendmark)
|
||||
{
|
||||
sentendmark[j].clear();
|
||||
}
|
||||
}
|
||||
}
|
||||
// return these utterances
|
||||
if (verbosity > 0)
|
||||
fprintf(stderr, "getbatch: getting utterances %d..%d (%d subset of %d frames out of %d requested) in sweep %d\n", (int)spos, (int)(epos - 1), (int)tspos, (int)mbframes, (int)framesrequested, (int)sweep);
|
||||
tspos = 0; // relative start of utterance 'pos' within the returned minibatch
|
||||
for (size_t pos = spos; pos < epos; pos++)
|
||||
{
|
||||
const auto &uttref = randomizedutterancerefs[pos];
|
||||
if ((uttref.chunkindex % numsubsets) != subsetnum) // chunk not to be returned for this MPI node
|
||||
continue;
|
||||
|
||||
size_t n = 0;
|
||||
foreach_index(i, randomizedchunks)
|
||||
{
|
||||
const auto &chunk = randomizedchunks[i][uttref.chunkindex];
|
||||
const auto &chunkdata = chunk.getchunkdata();
|
||||
assert((numsubsets > 1) || (uttref.globalts == globalts + tspos));
|
||||
auto uttframes = chunkdata.getutteranceframes(uttref.utteranceindex());
|
||||
matrixasvectorofvectors uttframevectors(uttframes); // (wrapper that allows m[j].size() and m[j][i] as required by augmentneighbors())
|
||||
n = uttframevectors.size();
|
||||
sentendmark[i].push_back(n + tspos);
|
||||
assert(n == uttframes.cols() && uttref.numframes == n && chunkdata.numframes(uttref.utteranceindex()) == n);
|
||||
|
||||
// copy the frames and class labels
|
||||
for (size_t t = 0; t < n; t++) // t = time index into source utterance
|
||||
{
|
||||
size_t leftextent, rightextent;
|
||||
// page in the needed range of frames
|
||||
if (leftcontext[i] == 0 && rightcontext[i] == 0)
|
||||
{
|
||||
leftextent = rightextent = augmentationextent(uttframevectors[t].size(), vdim[i]);
|
||||
}
|
||||
else
|
||||
{
|
||||
leftextent = leftcontext[i];
|
||||
rightextent = rightcontext[i];
|
||||
}
|
||||
augmentneighbors(uttframevectors, noboundaryflags, t, leftextent, rightextent, feat[i], t + tspos);
|
||||
// augmentneighbors(uttframevectors, noboundaryflags, t, feat[i], t + tspos);
|
||||
}
|
||||
|
||||
// copy the frames and class labels
|
||||
if (i == 0)
|
||||
{
|
||||
auto uttclassids = getclassids(uttref);
|
||||
std::vector<shiftedvector<biggrowablevector<HMMIDTYPE>>> uttphoneboudaries;
|
||||
if (m_generatePhoneBoundaries)
|
||||
uttphoneboudaries = getphonebound(uttref);
|
||||
foreach_index(j, uttclassids)
|
||||
{
|
||||
for (size_t t = 0; t < n; t++) // t = time index into source utterance
|
||||
{
|
||||
if (issupervised())
|
||||
{
|
||||
uids[j][t + tspos] = uttclassids[j][t];
|
||||
if (m_generatePhoneBoundaries)
|
||||
phoneboundaries2[j][t + tspos] = uttphoneboudaries[j][t];
|
||||
}
|
||||
}
|
||||
|
||||
if (!this->lattices.empty())
|
||||
{
|
||||
auto latticepair = chunkdata.getutterancelattice(uttref.utteranceindex());
|
||||
latticepairs.push_back(latticepair);
|
||||
// look up reference
|
||||
const auto &key = latticepair->getkey();
|
||||
if (!allwordtranscripts.empty())
|
||||
{
|
||||
const auto &transcript = allwordtranscripts.find(key)->second;
|
||||
transcripts.push_back(transcript.words);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
tspos += n;
|
||||
}
|
||||
|
||||
foreach_index(i, feat)
|
||||
{
|
||||
assert(tspos == feat[i].cols());
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
const size_t sweepts = sweep * _totalframes; // first global frame index for this sweep
|
||||
const size_t sweepte = sweepts + _totalframes; // and its end
|
||||
const size_t globalte = std::min(globalts + framesrequested, sweepte); // we return as much as requested, but not exceeding sweep end
|
||||
mbframes = globalte - globalts; // that's our mb size
|
||||
|
||||
// Perform randomization of the desired frame range
|
||||
m_frameRandomizer.randomizeFrameRange(globalts, globalte);
|
||||
|
||||
// determine window range
|
||||
// We enumerate all frames--can this be done more efficiently?
|
||||
const size_t firstchunk = chunkforframepos(globalts);
|
||||
const size_t lastchunk = chunkforframepos(globalte - 1);
|
||||
const size_t windowbegin = randomizedchunks[0][firstchunk].windowbegin;
|
||||
const size_t windowend = randomizedchunks[0][lastchunk].windowend;
|
||||
if (verbosity > 0)
|
||||
fprintf(stderr, "getbatch: getting randomized frames [%d..%d] (%d frames out of %d requested) in sweep %d; chunks [%d..%d] -> chunk window [%d..%d)\n",
|
||||
(int)globalts, (int)globalte, (int)mbframes, (int)framesrequested, (int)sweep, (int)firstchunk, (int)lastchunk, (int)windowbegin, (int)windowend);
|
||||
// release all data outside, and page in all data inside
|
||||
for (size_t k = 0; k < windowbegin; k++)
|
||||
releaserandomizedchunk(k);
|
||||
for (size_t k = windowbegin; k < windowend; k++)
|
||||
if ((k % numsubsets) == subsetnum) // in MPI mode, we skip chunks this way
|
||||
readfromdisk |= requirerandomizedchunk(k, windowbegin, windowend); // (window range passed in for checking only, redundant here)
|
||||
for (size_t k = windowend; k < randomizedchunks[0].size(); k++)
|
||||
releaserandomizedchunk(k);
|
||||
|
||||
// determine the true #frames we return--it is less than mbframes in the case of MPI/data-parallel sub-set mode
|
||||
// First determine it for all nodes, then pick the min over all nodes, as to give all the same #frames for better load balancing.
|
||||
// TODO: No, return all; and leave it to caller to redistribute them [Zhijie Yan]
|
||||
std::vector<size_t> subsetsizes(numsubsets, 0);
|
||||
for (size_t i = 0; i < mbframes; i++) // i is input frame index; j < i in case of MPI/data-parallel sub-set mode
|
||||
{
|
||||
const frameref &frameref = m_frameRandomizer.randomizedframeref(globalts + i);
|
||||
subsetsizes[frameref.chunkindex % numsubsets]++;
|
||||
}
|
||||
size_t j = subsetsizes[subsetnum]; // return what we have --TODO: we can remove the above full computation again now
|
||||
const size_t allocframes = std::max(j, (mbframes + numsubsets - 1) / numsubsets); // we leave space for the desired #frames, assuming caller will try to pad them later
|
||||
|
||||
// resize feat and uids
|
||||
feat.resize(vdim.size());
|
||||
uids.resize(classids.size());
|
||||
assert(feat.size() == vdim.size());
|
||||
assert(feat.size() == randomizedchunks.size());
|
||||
foreach_index(i, feat)
|
||||
{
|
||||
feat[i].resize(vdim[i], allocframes);
|
||||
feat[i].shrink(vdim[i], j);
|
||||
|
||||
if (i == 0)
|
||||
{
|
||||
foreach_index(k, uids)
|
||||
{
|
||||
if (issupervised()) // empty means unsupervised training -> return empty uids
|
||||
uids[k].resize(j);
|
||||
else
|
||||
uids[k].clear();
|
||||
latticepairs.clear(); // will push_back() below
|
||||
transcripts.clear();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// return randomized frames for the time range of those utterances
|
||||
size_t currmpinodeframecount = 0;
|
||||
for (size_t j2 = 0; j2 < mbframes; j2++)
|
||||
{
|
||||
if (currmpinodeframecount >= feat[0].cols()) // MPI/data-parallel mode: all nodes return the same #frames, which is how feat(,) is allocated
|
||||
break;
|
||||
|
||||
// map to time index inside arrays
|
||||
const frameref &frameref = m_frameRandomizer.randomizedframeref(globalts + j2);
|
||||
|
||||
// in MPI/data-parallel mode, skip frames that are not in chunks loaded for this MPI node
|
||||
if ((frameref.chunkindex % numsubsets) != subsetnum)
|
||||
continue;
|
||||
|
||||
// random utterance
|
||||
readfromdisk |= requirerandomizedchunk(frameref.chunkindex, windowbegin, windowend); // (this is just a check; should not actually page in anything)
|
||||
|
||||
foreach_index(i, randomizedchunks)
|
||||
{
|
||||
const auto &chunk = randomizedchunks[i][frameref.chunkindex];
|
||||
const auto &chunkdata = chunk.getchunkdata();
|
||||
auto uttframes = chunkdata.getutteranceframes(frameref.utteranceindex());
|
||||
matrixasvectorofvectors uttframevectors(uttframes); // (wrapper that allows m[.].size() and m[.][.] as required by augmentneighbors())
|
||||
const size_t n = uttframevectors.size();
|
||||
assert(n == uttframes.cols() && chunkdata.numframes(frameref.utteranceindex()) == n);
|
||||
n;
|
||||
|
||||
// copy frame and class labels
|
||||
const size_t t = frameref.frameindex();
|
||||
|
||||
size_t leftextent, rightextent;
|
||||
// page in the needed range of frames
|
||||
if (leftcontext[i] == 0 && rightcontext[i] == 0)
|
||||
{
|
||||
leftextent = rightextent = augmentationextent(uttframevectors[t].size(), vdim[i]);
|
||||
}
|
||||
else
|
||||
{
|
||||
leftextent = leftcontext[i];
|
||||
rightextent = rightcontext[i];
|
||||
}
|
||||
augmentneighbors(uttframevectors, noboundaryflags, t, leftextent, rightextent, feat[i], currmpinodeframecount);
|
||||
|
||||
if (issupervised() && i == 0)
|
||||
{
|
||||
auto frameclassids = getclassids(frameref);
|
||||
foreach_index(k, uids)
|
||||
uids[k][currmpinodeframecount] = frameclassids[k][t];
|
||||
}
|
||||
}
|
||||
|
||||
currmpinodeframecount++;
|
||||
}
|
||||
}
|
||||
timegetbatch = timergetbatch;
|
||||
|
||||
// this is the number of frames we actually moved ahead in time
|
||||
framesadvanced = mbframes;
|
||||
|
||||
return readfromdisk;
|
||||
}
|
||||
bool supportsbatchsubsetting() const override
|
||||
{
|
||||
return true;
|
||||
}
|
||||
|
||||
// Convenience overload of getbatch() without data-parallel subsetting.
// Forwards to the subset-aware getbatch() with subsetnum = 0 and numsubsets = 1,
// and returns its result; the 'framesadvanced' value reported by that call is discarded.
// NOTE(review): the parameter list below appears twice because this text is a diff
// rendering that shows both the removed and the added lines of the change.
bool getbatch(const size_t globalts,
|
||||
const size_t framesrequested, std::vector<msra::dbn::matrix> &feat, std::vector<std::vector<size_t>> &uids,
|
||||
std::vector<const_array_ref<msra::lattices::lattice::htkmlfwordsequence::word>> &transcripts,
|
||||
std::vector<std::shared_ptr<const latticesource::latticepair>> &lattices2, std::vector<std::vector<size_t>> &sentendmark,
|
||||
std::vector<std::vector<size_t>> &phoneboundaries2)
|
||||
/* guoye: start */
|
||||
const size_t framesrequested, std::vector<msra::dbn::matrix> &feat, std::vector<std::vector<size_t>> &uids,
|
||||
// const size_t framesrequested, std::vector<msra::dbn::matrix> &feat, std::vector<std::vector<size_t>> &uids, std::vector<std::vector<size_t>> &wids,
|
||||
/* guoye: end */
|
||||
std::vector<const_array_ref<msra::lattices::lattice::htkmlfwordsequence::word>> &transcripts,
|
||||
std::vector<std::shared_ptr<const latticesource::latticepair>> &lattices2, std::vector<std::vector<size_t>> &sentendmark,
|
||||
std::vector<std::vector<size_t>> &phoneboundaries2)
|
||||
|
||||
{
|
||||
// receives 'framesadvanced' from the forwarded call; not used afterwards
size_t dummy;
|
||||
/* guoye: start */
|
||||
return getbatch(globalts, framesrequested, 0, 1, dummy, feat, uids, transcripts, lattices2, sentendmark, phoneboundaries2);
|
||||
// return getbatch(globalts, framesrequested, 0, 1, dummy, feat, uids, wids, transcripts, lattices, sentendmark, phoneboundaries);
|
||||
/* guoye: end */
|
||||
|
||||
}
|
||||
|
||||
double gettimegetbatch()
|
||||
{
|
||||
return timegetbatch;
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
// alternate (updated) definition for multiple inputs/outputs - read as a vector of feature matrixes or a vector of label strings
|
||||
bool getbatch(const size_t /*globalts*/,
|
||||
const size_t /*framesrequested*/, msra::dbn::matrix & /*feat*/, std::vector<size_t> & /*uids*/,
|
||||
std::vector<const_array_ref<msra::lattices::lattice::htkmlfwordsequence::word>> & /*transcripts*/,
|
||||
std::vector<std::shared_ptr<const latticesource::latticepair>> & /*latticepairs*/)
|
||||
/* guoye: start */
|
||||
const size_t /*framesrequested*/, msra::dbn::matrix & /*feat*/, std::vector<size_t> & /*uids*/,
|
||||
// const size_t /*framesrequested*/, msra::dbn::matrix & /*feat*/, std::vector<size_t> & /*uids*/, std::vector<size_t> & /*wids*/,
|
||||
/* guoye: end */
|
||||
std::vector<const_array_ref<msra::lattices::lattice::htkmlfwordsequence::word>> & /*transcripts*/,
|
||||
std::vector<std::shared_ptr<const latticesource::latticepair>> & /*latticepairs*/) override
|
||||
{
|
||||
// should never get here
|
||||
RuntimeError("minibatchframesourcemulti: getbatch() being called for single input feature and single output feature, should use minibatchutterancesource instead\n");
|
||||
|
@ -1912,6 +2488,14 @@ public:
|
|||
// uids.resize(1);
|
||||
// return getbatch(globalts, framesrequested, feat[0], uids[0], transcripts, latticepairs);
|
||||
}
|
||||
|
||||
|
||||
double gettimegetbatch()
|
||||
{
|
||||
return timegetbatch;
|
||||
}
|
||||
|
||||
|
||||
|
||||
size_t totalframes() const
|
||||
{
|
||||
|
|
|
@ -36,6 +36,8 @@ class minibatchutterancesourcemulti : public minibatchsource
|
|||
// const std::vector<std::unique_ptr<latticesource>> &lattices;
|
||||
const latticesource &lattices;
|
||||
|
||||
|
||||
|
||||
// std::vector<latticesource> lattices;
|
||||
// word-level transcripts (for MMI mode when adding best path to lattices)
|
||||
const std::map<std::wstring, msra::lattices::lattice::htkmlfwordsequence> &allwordtranscripts; // (used for getting word-level transcripts)
|
||||
|
@ -158,7 +160,15 @@ class minibatchutterancesourcemulti : public minibatchsource
|
|||
reader.readNoAlloc(utteranceset[i].parsedpath, (const string &) featkind, sampperiod, uttframes); // note: file info here used for checking only
|
||||
// page in lattice data
|
||||
if (!latticesource.empty())
|
||||
latticesource.getlattices(utteranceset[i].key(), lattices[i], uttframes.cols());
|
||||
/* guoye: start */
|
||||
// we currently don't care about kaldi format, so, just to make the compiler happy
|
||||
// latticesource.getlattices(utteranceset[i].key(), lattices[i], uttframes.cols());
|
||||
{
|
||||
std::set<int> specialwordids;
|
||||
specialwordids.clear();
|
||||
latticesource.getlattices(utteranceset[i].key(), lattices[i], uttframes.cols(), specialwordids);
|
||||
}
|
||||
/* guoye: end */
|
||||
}
|
||||
// fprintf (stderr, "\n");
|
||||
if (verbosity)
|
||||
|
|
|
@ -1,13 +1,13 @@
|
|||
|
||||
== Authors of the Linux Building README ==
|
||||
|
||||
Kaisheng Yao
|
||||
Kaisheng Yao
|
||||
Microsoft Research
|
||||
email: kaisheny@microsoft.com
|
||||
|
||||
Wengong Jin,
|
||||
Shanghai Jiao Tong University
|
||||
email: acmgokun@gmail.com
|
||||
Wengong Jin,
|
||||
Shanghai Jiao Tong University
|
||||
email: acmgokun@gmail.com
|
||||
|
||||
Yu Zhang, Leo Liu, Scott Cyphers
|
||||
CSAIL, Massachusetts Institute of Technology
|
||||
|
@ -78,10 +78,10 @@ To clean
|
|||
|
||||
== Run ==
|
||||
All executables are in bin directory:
|
||||
cntk: The main executable for CNTK
|
||||
*.so: shared library for corresponding reader, these readers will be linked and loaded dynamically at runtime.
|
||||
cntk: The main executable for CNTK
|
||||
*.so: shared library for corresponding reader, these readers will be linked and loaded dynamically at runtime.
|
||||
|
||||
./cntk configFile=${your cntk config file}
|
||||
./cntk configFile=${your cntk config file}
|
||||
|
||||
== Kaldi Reader ==
|
||||
This is a HTKMLF reader and kaldi writer (for decode)
|
||||
|
|
|
@ -46,7 +46,11 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
bool useParallelTrain,
|
||||
StreamMinibatchInputs& inputMatrices,
|
||||
size_t& actualMBSize,
|
||||
const MPIWrapperPtr& mpi)
|
||||
/* guoye: start */
|
||||
// const MPIWrapperPtr& mpi)
|
||||
const MPIWrapperPtr& mpi,
|
||||
size_t& actualNumWords)
|
||||
/* guoye: end */
|
||||
{
|
||||
// Reading consists of a sequence of Reader API calls:
|
||||
// - GetMinibatch() --fills the inputMatrices and copies the MBLayout from Reader into inputMatrices
|
||||
|
@ -71,8 +75,16 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
auto uids = node->getuidprt();
|
||||
auto boundaries = node->getboundaryprt();
|
||||
auto extrauttmap = node->getextrauttmap();
|
||||
/* guoye: start */
|
||||
auto wids = node->getwidprt();
|
||||
auto nws = node->getnwprt();
|
||||
// trainSetDataReader.GetMinibatch4SE(*latticeinput, *uids, *boundaries, *extrauttmap);
|
||||
trainSetDataReader.GetMinibatch4SE(*latticeinput, *uids, *wids, *nws, *boundaries, *extrauttmap);
|
||||
|
||||
trainSetDataReader.GetMinibatch4SE(*latticeinput, *uids, *boundaries, *extrauttmap);
|
||||
actualNumWords = 0;
|
||||
for (size_t i = 0; i < (*nws).size(); i++)
|
||||
actualNumWords += (*nws)[i];
|
||||
/* guoye: end */
|
||||
}
|
||||
|
||||
// TODO: move this into shim for the old readers.
|
||||
|
@ -284,11 +296,20 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
private:
|
||||
typedef std::vector<shared_ptr<const msra::dbn::latticesource::latticepair>> Lattice;
|
||||
typedef std::vector<size_t> Uid;
|
||||
/* guoye: start */
|
||||
typedef std::vector<size_t> Wid;
|
||||
typedef std::vector<short> Nw;
|
||||
/* guoye: end */
|
||||
|
||||
typedef std::vector<size_t> ExtrauttMap;
|
||||
typedef std::vector<size_t> Boundaries;
|
||||
|
||||
typedef std::vector<shared_ptr<const msra::dbn::latticesource::latticepair>>* LatticePtr;
|
||||
typedef std::vector<size_t>* UidPtr;
|
||||
/* guoye: start */
|
||||
typedef std::vector<size_t>* WidPtr;
|
||||
typedef std::vector<short>* NwPtr;
|
||||
/* guoye: end */
|
||||
typedef std::vector<size_t>* ExtrauttMapPtr;
|
||||
typedef std::vector<size_t>* BoundariesPtr;
|
||||
typedef StreamMinibatchInputs Matrices;
|
||||
|
@ -298,6 +319,10 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
MBLayoutPtr m_MBLayoutCache;
|
||||
Lattice m_LatticeCache;
|
||||
Uid m_uidCache;
|
||||
/* guoye: start */
|
||||
Wid m_widCache;
|
||||
Nw m_nwCache;
|
||||
/* guoye: end */
|
||||
ExtrauttMap m_extrauttmapCache;
|
||||
Boundaries m_BoundariesCache;
|
||||
shared_ptr<Matrix<ElemType>> m_netCriterionAccumulator;
|
||||
|
@ -313,6 +338,10 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
Matrices m_netInputMatrixPtr;
|
||||
LatticePtr m_netLatticePtr;
|
||||
UidPtr m_netUidPtr;
|
||||
/* guoye: start */
|
||||
WidPtr m_netWidPtr;
|
||||
NwPtr m_netNwPtr;
|
||||
/* guoye: end */
|
||||
ExtrauttMapPtr m_netExtrauttMapPtr;
|
||||
BoundariesPtr m_netBoundariesPtr;
|
||||
// we remember the pointer to the learnable Nodes so that we can accumulate the gradient once a sub-minibatch is done
|
||||
|
@ -352,7 +381,10 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
|
||||
public:
|
||||
// Default constructor: starts with no cached MBLayout and with every network-side
// pointer (lattice, extra-utterance map, uid, boundaries, wid, nw) set to nullptr.
// NOTE(review): two initializer lists appear below because this text is a diff rendering
// showing both the removed and the added line; the added one also nulls m_netWidPtr and m_netNwPtr.
// NOTE(review): the initializer order differs from the member declaration order
// (m_netWidPtr/m_netNwPtr are declared before m_netExtrauttMapPtr); members are initialized in
// declaration order regardless, which is harmless here since all values are nullptr.
SubminibatchDispatcher()
|
||||
: m_MBLayoutCache(nullptr), m_netLatticePtr(nullptr), m_netExtrauttMapPtr(nullptr), m_netUidPtr(nullptr), m_netBoundariesPtr(nullptr)
|
||||
/* guoye: start */
|
||||
// : m_MBLayoutCache(nullptr), m_netLatticePtr(nullptr), m_netExtrauttMapPtr(nullptr), m_netUidPtr(nullptr), m_netBoundariesPtr(nullptr)
|
||||
: m_MBLayoutCache(nullptr), m_netLatticePtr(nullptr), m_netExtrauttMapPtr(nullptr), m_netUidPtr(nullptr), m_netBoundariesPtr(nullptr), m_netWidPtr(nullptr), m_netNwPtr(nullptr)
|
||||
/* guoye: end */
|
||||
{
|
||||
}
|
||||
|
||||
|
@ -398,6 +430,10 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
m_netLatticePtr = node->getLatticePtr();
|
||||
m_netExtrauttMapPtr = node->getextrauttmap();
|
||||
m_netUidPtr = node->getuidprt();
|
||||
/* guoye: start */
|
||||
m_netWidPtr = node->getwidprt();
|
||||
m_netNwPtr = node->getnwprt();
|
||||
/* guoye: end */
|
||||
m_netBoundariesPtr = node->getboundaryprt();
|
||||
m_hasLattices = true;
|
||||
}
|
||||
|
@ -408,6 +444,10 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
m_netUidPtr = nullptr;
|
||||
m_netBoundariesPtr = nullptr;
|
||||
m_hasLattices = false;
|
||||
/* guoye: start */
|
||||
m_netWidPtr = nullptr;
|
||||
m_netNwPtr = nullptr;
|
||||
/* guoye: end */
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -444,11 +484,20 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
m_uidCache.clear();
|
||||
m_extrauttmapCache.clear();
|
||||
m_BoundariesCache.clear();
|
||||
/* guoye: start */
|
||||
m_widCache.clear();
|
||||
m_nwCache.clear();
|
||||
/* guoye: end */
|
||||
|
||||
|
||||
m_LatticeCache = *m_netLatticePtr;
|
||||
m_uidCache = *m_netUidPtr;
|
||||
m_extrauttmapCache = *m_netExtrauttMapPtr;
|
||||
m_BoundariesCache = *m_netBoundariesPtr;
|
||||
/* guoye: start */
|
||||
m_widCache = *m_netWidPtr;
|
||||
m_nwCache = *m_netNwPtr;
|
||||
/* guoye: end */
|
||||
}
|
||||
|
||||
// subminibatches are cut at the parallel sequence level;
|
||||
|
@ -495,10 +544,18 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
BoundariesPtr decimatedBoundaryPtr, /* output: boundary after decimation*/
|
||||
ExtrauttMapPtr decimatedExtraMapPtr, /* output: extramap after decimation*/
|
||||
UidPtr decimatedUidPtr, /* output: Uid after decimation*/
|
||||
/* guoye: start */
|
||||
WidPtr decimatedWidPtr, /* output: Wid after decimation*/
|
||||
NwPtr decimatedNwPtr, /* output: Nw after decimation*/
|
||||
/* guoye: end */
|
||||
const Lattice lattices, /* input: lattices to be decimated */
|
||||
const Boundaries boundaries, /* input: boundary to be decimated */
|
||||
const ExtrauttMap extraMaps, /* input: extra map to be decimated */
|
||||
const Uid uids, /* input: uid to be decimated*/
|
||||
/* guoye: start */
|
||||
const Wid wids, /* input: wid to be decimated*/
|
||||
const Nw nws, /* input: nw (per-utterance word counts) to be decimated*/
|
||||
/* guoye: end */
|
||||
pair<size_t, size_t> parallelSeqRange /* input: what parallel sequence range we are looking at */
|
||||
)
|
||||
{
|
||||
|
@ -509,12 +566,22 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
decimatedBoundaryPtr->clear();
|
||||
decimatedExtraMapPtr->clear();
|
||||
decimatedUidPtr->clear();
|
||||
/* guoye: start */
|
||||
decimatedWidPtr->clear();
|
||||
decimatedNwPtr->clear();
|
||||
/* guoye: end */
|
||||
|
||||
size_t stFrame = 0;
|
||||
/* guoye: start */
|
||||
size_t stWord = 0;
|
||||
/* guoye: end */
|
||||
for (size_t iUtt = 0; iUtt < extraMaps.size(); iUtt++)
|
||||
{
|
||||
size_t numFramesInThisUtterance = lattices[iUtt]->getnumframes();
|
||||
size_t iParallelSeq = extraMaps[iUtt]; // i-th utterance belongs to iParallelSeq-th parallel sequence
|
||||
/* guoye: start */
|
||||
size_t numWordsInThisUtterance = nws[iUtt];
|
||||
/* guoye: end */
|
||||
if (iParallelSeq >= parallelSeqStId && iParallelSeq < parallelSeqEnId)
|
||||
{
|
||||
// this utterance has been selected
|
||||
|
@ -522,8 +589,16 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
decimatedBoundaryPtr->insert(decimatedBoundaryPtr->end(), boundaries.begin() + stFrame, boundaries.begin() + stFrame + numFramesInThisUtterance);
|
||||
decimatedUidPtr->insert(decimatedUidPtr->end(), uids.begin() + stFrame, uids.begin() + stFrame + numFramesInThisUtterance);
|
||||
decimatedExtraMapPtr->push_back(extraMaps[iUtt] - parallelSeqStId);
|
||||
/* guoye: start */
|
||||
|
||||
decimatedWidPtr->insert(decimatedWidPtr->end(), wids.begin() + stWord, wids.begin() + stWord + numWordsInThisUtterance);
|
||||
decimatedNwPtr->push_back(numWordsInThisUtterance);
|
||||
/* guoye: end */
|
||||
}
|
||||
stFrame += numFramesInThisUtterance;
|
||||
/* guoye: start */
|
||||
stWord += numWordsInThisUtterance;
|
||||
/* guoye: end */
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -538,12 +613,16 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
if (m_hasLattices)
|
||||
{
|
||||
DecimateLattices(
|
||||
/* guoye: start */
|
||||
/*output */
|
||||
m_netLatticePtr, m_netBoundariesPtr, m_netExtrauttMapPtr, m_netUidPtr,
|
||||
// m_netLatticePtr, m_netBoundariesPtr, m_netExtrauttMapPtr, m_netUidPtr,
|
||||
m_netLatticePtr, m_netBoundariesPtr, m_netExtrauttMapPtr, m_netUidPtr, m_netWidPtr, m_netNwPtr,
|
||||
/*input to be decimated */
|
||||
m_LatticeCache, m_BoundariesCache, m_extrauttmapCache, m_uidCache,
|
||||
// m_LatticeCache, m_BoundariesCache, m_extrauttmapCache, m_uidCache,
|
||||
m_LatticeCache, m_BoundariesCache, m_extrauttmapCache, m_uidCache, m_widCache, m_nwCache,
|
||||
/* what range we want ? */
|
||||
seqRange);
|
||||
/* guoye: end */
|
||||
}
|
||||
|
||||
// The following does m_netInputMatrixPtr = decimatedMatrices; with ownership shenanigans.
|
||||
|
|
|
@ -81,6 +81,9 @@ void PostComputingActions<ElemType>::BatchNormalizationStatistics(IDataReader *
|
|||
let bnNode = static_pointer_cast<BatchNormalizationNode<ElemType>>(node);
|
||||
size_t actualMBSize = 0;
|
||||
|
||||
/* guoye: start */
|
||||
size_t actualNumWords = 0;
|
||||
/* guoye: end */
|
||||
LOGPRINTF(stderr, "Estimating Statistics --> %ls\n", bnNode->GetName().c_str());
|
||||
|
||||
|
||||
|
@ -90,8 +93,10 @@ void PostComputingActions<ElemType>::BatchNormalizationStatistics(IDataReader *
|
|||
{
|
||||
// during the bn stat, dataRead must be ensured
|
||||
bool wasDataRead = DataReaderHelpers::GetMinibatchIntoNetwork<ElemType>(*dataReader, m_net,
|
||||
nullptr, useDistributedMBReading, useParallelTrain, inputMatrices, actualMBSize, m_mpi);
|
||||
|
||||
/* guoye: start */
|
||||
// nullptr, useDistributedMBReading, useParallelTrain, inputMatrices, actualMBSize, m_mpi);
|
||||
nullptr, useDistributedMBReading, useParallelTrain, inputMatrices, actualMBSize, m_mpi, actualNumWords);
|
||||
/* guoye: end */
|
||||
if (!wasDataRead) LogicError("DataRead Failure in batch normalization statistics");
|
||||
|
||||
ComputationNetwork::BumpEvalTimeStamp(featureNodes);
|
||||
|
|
|
@ -262,7 +262,9 @@ void SGD<ElemType>::TrainOrAdaptModel(int startEpoch, ComputationNetworkPtr net,
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
/* guoye: start */
|
||||
// LOGPRINTF(stderr, "SGD debug 1 \n");
|
||||
/* guoye: end */
|
||||
std::vector<ComputationNodeBasePtr> additionalNodesToEvaluate;
|
||||
|
||||
// Do not include the output nodes in the matrix sharing structure when using forward value matrix
|
||||
|
@ -273,13 +275,18 @@ void SGD<ElemType>::TrainOrAdaptModel(int startEpoch, ComputationNetworkPtr net,
|
|||
auto& outputNodes = net->OutputNodes();
|
||||
additionalNodesToEvaluate.insert(additionalNodesToEvaluate.end(), outputNodes.cbegin(), outputNodes.cend());
|
||||
}
|
||||
|
||||
/* guoye: start */
|
||||
// LOGPRINTF(stderr, "SGD debug 2 \n");
|
||||
/* guoye: end */
|
||||
auto preComputeNodesList = net->GetNodesRequiringPreComputation();
|
||||
// LOGPRINTF(stderr, "SGD debug 2.1 \n");
|
||||
additionalNodesToEvaluate.insert(additionalNodesToEvaluate.end(), preComputeNodesList.cbegin(), preComputeNodesList.cend());
|
||||
|
||||
// LOGPRINTF(stderr, "SGD debug 2.2 \n");
|
||||
// allocate memory for forward and backward computation
|
||||
net->AllocateAllMatrices(evaluationNodes, additionalNodesToEvaluate, criterionNodes[0]); // TODO: use criterionNodes.front() throughout
|
||||
|
||||
/* guoye: start */
|
||||
// LOGPRINTF(stderr, "SGD debug 3 \n");
|
||||
/* guoye: end */
|
||||
// get feature and label nodes into an array of matrices that will be passed to GetMinibatch()
|
||||
// TODO: instead, remember the nodes directly, to be able to handle both float and double nodes; current version will crash for mixed networks
|
||||
StreamMinibatchInputs* inputMatrices = new StreamMinibatchInputs();
|
||||
|
@ -287,23 +294,34 @@ void SGD<ElemType>::TrainOrAdaptModel(int startEpoch, ComputationNetworkPtr net,
|
|||
let& featureNodes = net->FeatureNodes();
|
||||
let& labelNodes = net->LabelNodes();
|
||||
// BUGBUG: ^^ should not get all feature/label nodes, but only the ones referenced in a criterion
|
||||
/* guoye: start */
|
||||
// LOGPRINTF(stderr, "SGD debug 4 \n");
|
||||
/* guoye: end */
|
||||
for (size_t pass = 0; pass < 2; pass++)
|
||||
{
|
||||
auto& nodes = (pass == 0) ? featureNodes : labelNodes;
|
||||
for (const auto & node : nodes)
|
||||
inputMatrices->AddInput(node->NodeName(), node->ValuePtr(), node->GetMBLayout(), node->GetSampleLayout());
|
||||
}
|
||||
|
||||
/* guoye: start */
|
||||
// LOGPRINTF(stderr, "SGD debug 5 \n");
|
||||
/* guoye: end */
|
||||
// get hmm file for sequence training
|
||||
bool isSequenceTrainingCriterion = (criterionNodes[0]->OperationName() == L"SequenceWithSoftmax");
|
||||
// LOGPRINTF(stderr, "SGD debug 5.1 \n");
|
||||
if (isSequenceTrainingCriterion)
|
||||
{
|
||||
// SequenceWithSoftmaxNode<ElemType>* node = static_cast<SequenceWithSoftmaxNode<ElemType>*>(criterionNodes[0]);
|
||||
// LOGPRINTF(stderr, "SGD debug 5.2 \n");
|
||||
auto node = dynamic_pointer_cast<SequenceWithSoftmaxNode<ElemType>>(criterionNodes[0]);
|
||||
|
||||
auto hmm = node->gethmm();
|
||||
// LOGPRINTF(stderr, "SGD debug 5.4 \n");
|
||||
trainSetDataReader->GetHmmData(hmm);
|
||||
// LOGPRINTF(stderr, "SGD debug 5.5 \n");
|
||||
}
|
||||
/* guoye: start */
|
||||
// LOGPRINTF(stderr, "SGD debug 6 \n");
|
||||
/* guoye: end */
|
||||
|
||||
// used for KLD regularized adaptation. For all other adaptation techniques
|
||||
// use MEL to edit the model and using normal training algorithm
|
||||
|
@ -329,6 +347,9 @@ void SGD<ElemType>::TrainOrAdaptModel(int startEpoch, ComputationNetworkPtr net,
|
|||
// allocate memory for forward computation
|
||||
refNet->AllocateAllMatrices({refNode}, {}, nullptr);
|
||||
}
|
||||
/* guoye: start */
|
||||
// LOGPRINTF(stderr, "SGD debug 7 \n");
|
||||
/* guoye: end */
|
||||
|
||||
// initializing weights and gradient holder
|
||||
// only one criterion so far TODO: support multiple ones?
|
||||
|
@ -338,6 +359,9 @@ void SGD<ElemType>::TrainOrAdaptModel(int startEpoch, ComputationNetworkPtr net,
|
|||
size_t numParameters = 0;
|
||||
|
||||
vector<wstring> nodesToUpdateDescriptions; // for logging only
|
||||
/* guoye: start */
|
||||
// LOGPRINTF(stderr, "SGD debug 8 \n");
|
||||
/* guoye: end */
|
||||
for (auto nodeIter = learnableNodes.begin(); nodeIter != learnableNodes.end(); nodeIter++)
|
||||
{
|
||||
ComputationNodePtr node = dynamic_pointer_cast<ComputationNode<ElemType>>(*nodeIter);
|
||||
|
@ -354,12 +378,18 @@ void SGD<ElemType>::TrainOrAdaptModel(int startEpoch, ComputationNetworkPtr net,
|
|||
numParameters += node->GetSampleLayout().GetNumElements();
|
||||
}
|
||||
}
|
||||
/* guoye: start */
|
||||
// LOGPRINTF(stderr, "SGD debug 9 \n");
|
||||
/* guoye: end */
|
||||
size_t numNeedsGradient = 0;
|
||||
for (let node : net->GetEvalOrder(criterionNodes[0]))
|
||||
{
|
||||
if (node->NeedsGradient())
|
||||
numNeedsGradient++;
|
||||
}
|
||||
/* guoye: start */
|
||||
// LOGPRINTF(stderr, "SGD debug 10 \n");
|
||||
/* guoye: end */
|
||||
fprintf(stderr, "\n");
|
||||
LOGPRINTF(stderr, "Training %.0f parameters in %d ",
|
||||
(double)numParameters, (int)nodesToUpdateDescriptions.size());
|
||||
|
@ -482,8 +512,14 @@ void SGD<ElemType>::TrainOrAdaptModel(int startEpoch, ComputationNetworkPtr net,
|
|||
// likewise for sequence training parameters
|
||||
if (isSequenceTrainingCriterion)
|
||||
{
|
||||
ComputationNetwork::SetSeqParam<ElemType>(net, criterionNodes[0], m_hSmoothingWeight, m_frameDropThresh, m_doReferenceAlign,
|
||||
/* guoye: start */
|
||||
/* ComputationNetwork::SetSeqParam<ElemType>(net, criterionNodes[0], m_hSmoothingWeight, m_frameDropThresh, m_doReferenceAlign,
|
||||
m_seqGammarCalcAMF, m_seqGammarCalcLMF, m_seqGammarCalcWP, m_seqGammarCalcbMMIFactor, m_seqGammarCalcUsesMBR);
|
||||
*/
|
||||
ComputationNetwork::SetSeqParam<ElemType>(net, criterionNodes[0], m_hSmoothingWeight, m_frameDropThresh, m_doReferenceAlign,
|
||||
m_seqGammarCalcAMF, m_seqGammarCalcLMF, m_seqGammarCalcWP, m_seqGammarCalcbMMIFactor, m_seqGammarCalcUsesMBR,
|
||||
m_seqGammarCalcUseEMBR, m_EMBRUnit, m_numPathsEMBR, m_enforceValidPathEMBR, m_getPathMethodEMBR, m_showWERMode, m_excludeSpecialWords, m_wordNbest, m_useAccInNbest, m_accWeightInNbest, m_numRawPathsEMBR);
|
||||
/* guoye: end */
|
||||
}
|
||||
|
||||
// Multiverso Warpper for ASGD logic init
|
||||
|
@ -660,8 +696,16 @@ void SGD<ElemType>::TrainOrAdaptModel(int startEpoch, ComputationNetworkPtr net,
|
|||
learnableNodes, smoothedGradients, smoothedCounts,
|
||||
epochCriterion, epochEvalErrors,
|
||||
"", SIZE_MAX, totalMBsSeen, tensorBoardWriter, startEpoch);
|
||||
totalTrainingSamplesSeen += epochCriterion.second; // aggregate #training samples, for logging purposes only
|
||||
|
||||
/* guoye: start */
|
||||
// totalTrainingSamplesSeen += epochCriterion.second; // aggregate #training samples, for logging purposes only
|
||||
|
||||
if(!m_seqGammarCalcUseEMBR)
|
||||
totalTrainingSamplesSeen += epochCriterion.second;
|
||||
else
|
||||
totalTrainingSamplesSeen += epochEvalErrors[0].second;
|
||||
|
||||
/* guoye: end */
|
||||
timer.Stop();
|
||||
double epochTime = timer.ElapsedSeconds();
|
||||
|
||||
|
@ -1167,10 +1211,14 @@ size_t SGD<ElemType>::TrainOneEpoch(ComputationNetworkPtr net,
|
|||
// get minibatch
|
||||
// TODO: is it guaranteed that the GPU is already completed at this point, is it safe to overwrite the buffers?
|
||||
size_t actualMBSize = 0;
|
||||
/* guoye_start */
|
||||
size_t actualNumWords = 0;
|
||||
|
||||
auto profGetMinibatch = ProfilerTimeBegin();
|
||||
bool wasDataRead = DataReaderHelpers::GetMinibatchIntoNetwork<ElemType>(*trainSetDataReader, net, criterionNodes[0],
|
||||
useDistributedMBReading, useParallelTrain, *inputMatrices, actualMBSize, m_mpi);
|
||||
// useDistributedMBReading, useParallelTrain, *inputMatrices, actualMBSize, m_mpi);
|
||||
useDistributedMBReading, useParallelTrain, *inputMatrices, actualMBSize, m_mpi, actualNumWords);
|
||||
/* guoye_end */
|
||||
|
||||
if (maxNumSamplesExceeded) // Dropping data.
|
||||
wasDataRead = false;
|
||||
|
@ -1294,7 +1342,15 @@ size_t SGD<ElemType>::TrainOneEpoch(ComputationNetworkPtr net,
|
|||
// accumulate criterion values (objective, eval)
|
||||
assert(wasDataRead || numSamplesWithLabelOfNetwork == 0);
|
||||
// criteria are in Value()(0,0), we accumulate into another 1x1 Matrix (to avoid having to pull the values off the GPU)
|
||||
localEpochCriterion.Add(0, numSamplesWithLabelOfNetwork);
|
||||
// localEpochCriterion.Add(0, numSamplesWithLabelOfNetwork);
|
||||
|
||||
/* guoye: start */
|
||||
if(!m_seqGammarCalcUseEMBR)
|
||||
localEpochCriterion.Add(0, numSamplesWithLabelOfNetwork);
|
||||
else
|
||||
localEpochCriterion.Add(0, actualNumWords);
|
||||
|
||||
/* guoye: end */
|
||||
for (size_t i = 0; i < evaluationNodes.size(); i++)
|
||||
localEpochEvalErrors.Add(i, numSamplesWithLabelOfNetwork);
|
||||
}
|
||||
|
@ -1326,14 +1382,29 @@ size_t SGD<ElemType>::TrainOneEpoch(ComputationNetworkPtr net,
|
|||
}
|
||||
|
||||
// hoist the criterion into CPU space for all-reduce
|
||||
localEpochCriterion.Assign(0, numSamplesWithLabelOfNetwork);
|
||||
/* guoye: start */
|
||||
|
||||
if (!m_seqGammarCalcUseEMBR)
|
||||
localEpochCriterion.Assign(0, numSamplesWithLabelOfNetwork);
|
||||
else
|
||||
localEpochCriterion.Assign(0, actualNumWords);
|
||||
|
||||
// localEpochCriterion.Assign(0, numSamplesWithLabelOfNetwork);
|
||||
/* guoye: end */
|
||||
for (size_t i = 0; i < evaluationNodes.size(); i++)
|
||||
localEpochEvalErrors.Assign(i, numSamplesWithLabelOfNetwork);
|
||||
|
||||
// copy all values to be aggregated into the header
|
||||
m_gradHeader->numSamples = aggregateNumSamples;
|
||||
m_gradHeader->criterion = localEpochCriterion.GetCriterion(0).first;
|
||||
m_gradHeader->numSamplesWithLabel = localEpochCriterion.GetCriterion(0).second; // same as aggregateNumSamplesWithLabel
|
||||
/* guoye: start */
|
||||
// m_gradHeader->numSamplesWithLabel = localEpochCriterion.GetCriterion(0).second; // same as aggregateNumSamplesWithLabel
|
||||
|
||||
if (!m_seqGammarCalcUseEMBR)
|
||||
m_gradHeader->numSamplesWithLabel = localEpochCriterion.GetCriterion(0).second; // same as aggregateNumSamplesWithLabel
|
||||
else
|
||||
m_gradHeader->numSamplesWithLabel = numSamplesWithLabelOfNetwork;
|
||||
/* guoye: end */
|
||||
assert(m_gradHeader->numSamplesWithLabel == aggregateNumSamplesWithLabel);
|
||||
for (size_t i = 0; i < evaluationNodes.size(); i++)
|
||||
m_gradHeader->evalErrors[i] = localEpochEvalErrors.GetCriterion(i);
|
||||
|
@ -1482,8 +1553,14 @@ size_t SGD<ElemType>::TrainOneEpoch(ComputationNetworkPtr net,
|
|||
// epochCriterion aggregates over entire epoch, but we only show difference to last time we logged
|
||||
EpochCriterion epochCriterionSinceLastLogged = epochCriterion - epochCriterionLastLogged;
|
||||
let trainLossSinceLastLogged = epochCriterionSinceLastLogged.Average(); // TODO: Check whether old trainSamplesSinceLastLogged matches this ^^ difference
|
||||
let trainSamplesSinceLastLogged = (int)epochCriterionSinceLastLogged.second;
|
||||
|
||||
/* guoye: start */
|
||||
// let trainSamplesSinceLastLogged = (int)epochCriterionSinceLastLogged.second;
|
||||
|
||||
// for EMBR, epochCriterionSinceLastLogged.second stores the #words rather than #frames
|
||||
let trainSamplesSinceLastLogged = (m_seqGammarCalcUseEMBR? (int)(epochEvalErrors[0].second - epochEvalErrorsLastLogged[0].second) : (int)epochCriterionSinceLastLogged.second);
|
||||
|
||||
/* guoye: end */
|
||||
// determine progress in percent
|
||||
int mbProgNumPrecision = 2;
|
||||
double mbProg = 0.0;
|
||||
|
@ -1777,7 +1854,12 @@ bool SGD<ElemType>::PreCompute(ComputationNetworkPtr net,
|
|||
const size_t numIterationsBeforePrintingProgress = 100;
|
||||
size_t numItersSinceLastPrintOfProgress = 0;
|
||||
size_t actualMBSizeDummy;
|
||||
while (DataReaderHelpers::GetMinibatchIntoNetwork<ElemType>(*trainSetDataReader, net, nullptr, false, false, *inputMatrices, actualMBSizeDummy, m_mpi))
|
||||
/* guoye: start */
|
||||
size_t actualNumWordsDummy;
|
||||
|
||||
// while (DataReaderHelpers::GetMinibatchIntoNetwork<ElemType>(*trainSetDataReader, net, nullptr, false, false, *inputMatrices, actualMBSizeDummy, m_mpi))
|
||||
while (DataReaderHelpers::GetMinibatchIntoNetwork<ElemType>(*trainSetDataReader, net, nullptr, false, false, *inputMatrices, actualMBSizeDummy, m_mpi, actualNumWordsDummy))
|
||||
/* guoye: end */
|
||||
{
|
||||
// TODO: move these into GetMinibatchIntoNetwork() --but those are passed around; necessary? Can't we get them from 'net'?
|
||||
ComputationNetwork::BumpEvalTimeStamp(featureNodes);
|
||||
|
@ -2981,6 +3063,42 @@ SGDParams::SGDParams(const ConfigRecordType& configSGD, size_t sizeofElemType)
|
|||
m_frameDropThresh = configSGD(L"frameDropThresh", 1e-10);
|
||||
m_doReferenceAlign = configSGD(L"doReferenceAlign", false);
|
||||
m_seqGammarCalcUsesMBR = configSGD(L"seqGammarUsesMBR", false);
|
||||
|
||||
/* guoye: start */
|
||||
m_seqGammarCalcUseEMBR = configSGD(L"seqGammarUseEMBR", false);
|
||||
m_EMBRUnit = configSGD(L"EMBRUnit", "word");
|
||||
|
||||
m_numPathsEMBR = configSGD(L"numPathsEMBR", (size_t)100);
|
||||
// enforce the path starting with sentence start
|
||||
m_enforceValidPathEMBR = configSGD(L"enforceValidPathEMBR", false);
|
||||
//could be sampling or nbest
|
||||
m_getPathMethodEMBR = configSGD(L"getPathMethodEMBR", "sampling");
|
||||
// could be average or onebest
|
||||
m_showWERMode = configSGD(L"showWERMode", "average");
|
||||
|
||||
// don't include path that has special words if true
|
||||
m_excludeSpecialWords = configSGD(L"excludeSpecialWords", false);
|
||||
|
||||
// true then, we force the nbest has different word sequence
|
||||
m_wordNbest = configSGD(L"wordNbest", false);
|
||||
m_useAccInNbest = configSGD(L"useAccInNbest", false);
|
||||
m_accWeightInNbest = configSGD(L"accWeightInNbest", 1.0f);
|
||||
|
||||
m_numRawPathsEMBR = configSGD(L"numRawPathsEMBR", (size_t)100);
|
||||
|
||||
if (!m_useAccInNbest)
|
||||
{
|
||||
if (m_numRawPathsEMBR > m_numPathsEMBR)
|
||||
{
|
||||
fprintf(stderr, "SGDParams: WARNING: we do not use acc in nbest, so no need to make numRawPathsEMBR = %d larger than numPathsEMBR = %d \n", (int)m_numRawPathsEMBR, (int)m_numPathsEMBR);
|
||||
}
|
||||
}
|
||||
if (m_getPathMethodEMBR == "sampling" && m_showWERMode == "onebest")
|
||||
{
|
||||
RuntimeError("There is no way to show onebest WER in sampling based EMBR");
|
||||
}
|
||||
/* guoye: end */
|
||||
|
||||
m_seqGammarCalcAMF = configSGD(L"seqGammarAMF", 14.0);
|
||||
m_seqGammarCalcLMF = configSGD(L"seqGammarLMF", 14.0);
|
||||
m_seqGammarCalcbMMIFactor = configSGD(L"seqGammarBMMIFactor", 0.0);
|
||||
|
|
|
@ -323,6 +323,20 @@ protected:
|
|||
double m_seqGammarCalcbMMIFactor;
|
||||
bool m_seqGammarCalcUsesMBR;
|
||||
|
||||
/* guoye: start */
|
||||
bool m_seqGammarCalcUseEMBR; // read from config key "seqGammarUseEMBR" (default false)
|
||||
string m_EMBRUnit; // unit could be: word, phone, state (we all compute edit distance); config default is "word"
|
||||
bool m_enforceValidPathEMBR; // enforce the path starting with sentence start
|
||||
string m_getPathMethodEMBR; // could be sampling or nbest
|
||||
size_t m_numPathsEMBR; // number of sampled paths
|
||||
string m_showWERMode; // could be average or onebest; onebest combined with sampling is rejected when the config is parsed
|
||||
bool m_excludeSpecialWords; // don't include path that has special words if true
|
||||
bool m_wordNbest; // if true, force the nbest entries to have different word sequences
|
||||
bool m_useAccInNbest; // read from config key "useAccInNbest" (default false)
|
||||
float m_accWeightInNbest; // read from config key "accWeightInNbest" (default 1.0f)
|
||||
size_t m_numRawPathsEMBR; // read from config key "numRawPathsEMBR"; a warning is printed if it exceeds m_numPathsEMBR while m_useAccInNbest is false
|
||||
/* guoye: end */
|
||||
|
||||
// decide whether should apply regularization into BatchNormalizationNode
|
||||
// true: disable Regularization
|
||||
// false: enable Regularization (default)
|
||||
|
|
|
@ -120,7 +120,11 @@ public:
|
|||
for (;;)
|
||||
{
|
||||
size_t actualMBSize = 0;
|
||||
bool wasDataRead = DataReaderHelpers::GetMinibatchIntoNetwork<ElemType>(*dataReader, m_net, nullptr, useDistributedMBReading, useParallelTrain, inputMatrices, actualMBSize, m_mpi);
|
||||
/* guoye: start */
|
||||
size_t actualNumWords = 0;
|
||||
/* guoye: end */
|
||||
// bool wasDataRead = DataReaderHelpers::GetMinibatchIntoNetwork<ElemType>(*dataReader, m_net, nullptr, useDistributedMBReading, useParallelTrain, inputMatrices, actualMBSize, m_mpi);
|
||||
bool wasDataRead = DataReaderHelpers::GetMinibatchIntoNetwork<ElemType>(*dataReader, m_net, nullptr, useDistributedMBReading, useParallelTrain, inputMatrices, actualMBSize, m_mpi, actualNumWords);
|
||||
// in case of distributed reading, we do a few more loops until all ranks have completed
|
||||
// end of epoch
|
||||
if (!wasDataRead && (!useDistributedMBReading || noMoreSamplesToProcess))
|
||||
|
|
|
@ -62,7 +62,11 @@ public:
|
|||
const size_t numIterationsBeforePrintingProgress = 100;
|
||||
size_t numItersSinceLastPrintOfProgress = 0;
|
||||
size_t actualMBSize;
|
||||
while (DataReaderHelpers::GetMinibatchIntoNetwork<ElemType>(dataReader, m_net, nullptr, false, false, inputMatrices, actualMBSize, nullptr))
|
||||
/* guoye: start */
|
||||
size_t actualNumWords;
|
||||
// while (DataReaderHelpers::GetMinibatchIntoNetwork<ElemType>(dataReader, m_net, nullptr, false, false, inputMatrices, actualMBSize, nullptr))
|
||||
while (DataReaderHelpers::GetMinibatchIntoNetwork<ElemType>(dataReader, m_net, nullptr, false, false, inputMatrices, actualMBSize, nullptr, actualNumWords))
|
||||
/* guoye: end */
|
||||
{
|
||||
ComputationNetwork::BumpEvalTimeStamp(inputNodes);
|
||||
m_net->ForwardProp(outputNodes);
|
||||
|
@ -230,7 +234,11 @@ public:
|
|||
char formatChar = !formattingOptions.isCategoryLabel ? 'f' : !formattingOptions.labelMappingFile.empty() ? 's' : 'u';
|
||||
std::string valueFormatString = "%" + formattingOptions.precisionFormat + formatChar; // format string used in fprintf() for formatting the values
|
||||
|
||||
for (size_t numMBsRun = 0; DataReaderHelpers::GetMinibatchIntoNetwork<ElemType>(dataReader, m_net, nullptr, false, false, inputMatrices, actualMBSize, nullptr); numMBsRun++)
|
||||
/* guoye: start */
|
||||
size_t actualNumWords;
|
||||
//for (size_t numMBsRun = 0; DataReaderHelpers::GetMinibatchIntoNetwork<ElemType>(dataReader, m_net, nullptr, false, false, inputMatrices, actualMBSize, nullptr); numMBsRun++)
|
||||
for (size_t numMBsRun = 0; DataReaderHelpers::GetMinibatchIntoNetwork<ElemType>(dataReader, m_net, nullptr, false, false, inputMatrices, actualMBSize, nullptr, actualNumWords); numMBsRun++)
|
||||
/* guoye: end */
|
||||
{
|
||||
ComputationNetwork::BumpEvalTimeStamp(inputNodes);
|
||||
m_net->ForwardProp(outputNodes);
|
||||
|
|
|
@ -11,6 +11,11 @@
|
|||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
/* guoye: start */
|
||||
#include <string>
|
||||
|
||||
/* guoye: end */
|
||||
|
||||
#pragma warning(disable : 4127) // conditional expression is constant
|
||||
|
||||
namespace msra { namespace lattices {
|
||||
|
@ -22,6 +27,19 @@ struct SeqGammarCalParam
|
|||
double wp;
|
||||
double bMMIfactor;
|
||||
bool sMBRmode;
|
||||
/* guoye: start */
|
||||
bool EMBR;
|
||||
std::string EMBRUnit;
|
||||
size_t numPathsEMBR;
|
||||
bool enforceValidPathEMBR;
|
||||
std::string getPathMethodEMBR;
|
||||
std::string showWERMode;
|
||||
bool excludeSpecialWords;
|
||||
bool wordNbest;
|
||||
bool useAccInNbest;
|
||||
float accWeightInNbest;
|
||||
size_t numRawPathsEMBR;
|
||||
/* guoye: end */
|
||||
SeqGammarCalParam()
|
||||
{
|
||||
amf = 14.0;
|
||||
|
@ -29,6 +47,20 @@ struct SeqGammarCalParam
|
|||
wp = 0.0;
|
||||
bMMIfactor = 0.0;
|
||||
sMBRmode = false;
|
||||
|
||||
/* guoye: start */
|
||||
EMBR = false;
|
||||
EMBRUnit = "word";
|
||||
numPathsEMBR = 100;
|
||||
enforceValidPathEMBR = false;
|
||||
getPathMethodEMBR = "sampling";
|
||||
showWERMode = "average";
|
||||
excludeSpecialWords = false;
|
||||
wordNbest = false;
|
||||
useAccInNbest = false;
|
||||
accWeightInNbest = 1.0;
|
||||
numRawPathsEMBR = 100;
|
||||
/* guoye: end*/
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -81,6 +113,19 @@ public:
|
|||
wp = (float) gammarParam.wp;
|
||||
seqsMBRmode = gammarParam.sMBRmode;
|
||||
boostmmifactor = (float) gammarParam.bMMIfactor;
|
||||
/* guoye: start */
|
||||
EMBR = gammarParam.EMBR;
|
||||
EMBRUnit = gammarParam.EMBRUnit;
|
||||
numPathsEMBR = gammarParam.numPathsEMBR;
|
||||
enforceValidPathEMBR = gammarParam.enforceValidPathEMBR;
|
||||
getPathMethodEMBR = gammarParam.getPathMethodEMBR;
|
||||
showWERMode = gammarParam.showWERMode;
|
||||
excludeSpecialWords = gammarParam.excludeSpecialWords;
|
||||
wordNbest = gammarParam.wordNbest;
|
||||
useAccInNbest = gammarParam.useAccInNbest;
|
||||
accWeightInNbest = gammarParam.accWeightInNbest;
|
||||
numRawPathsEMBR = gammarParam.numRawPathsEMBR;
|
||||
/* guoye: end */
|
||||
}
|
||||
|
||||
// ========================================
|
||||
|
@ -91,7 +136,10 @@ public:
|
|||
const Microsoft::MSR::CNTK::Matrix<ElemType>& loglikelihood,
|
||||
Microsoft::MSR::CNTK::Matrix<ElemType>& labels,
|
||||
Microsoft::MSR::CNTK::Matrix<ElemType>& gammafromlattice,
|
||||
std::vector<size_t>& uids, std::vector<size_t>& boundaries,
|
||||
/* guoye: start */
|
||||
// std::vector<size_t>& uids, std::vector<size_t>& boundaries,
|
||||
std::vector<size_t>& uids, std::vector<size_t>& wids, std::vector<short>& nws, std::vector<size_t>& boundaries,
|
||||
/* guoye: end */
|
||||
size_t samplesInRecurrentStep, /* numParallelUtterance ? */
|
||||
std::shared_ptr<Microsoft::MSR::CNTK::MBLayout> pMBLayout,
|
||||
std::vector<size_t>& extrauttmap,
|
||||
|
@ -99,6 +147,13 @@ public:
|
|||
{
|
||||
// check total frame number to be added ?
|
||||
// int deviceid = loglikelihood.GetDeviceId();
|
||||
/* guoye: start */
|
||||
/*
|
||||
for (size_t i = 0; i < lattices.size(); i++)
|
||||
{
|
||||
// fprintf(stderr, "calgammaformb: i = %d, utt = %ls \n", int(i), lattices[i]->second.key.c_str());
|
||||
}
|
||||
*/
|
||||
size_t boundaryframenum;
|
||||
std::vector<size_t> validframes; // [s] cursor pointing to next utterance begin within a single parallel sequence [s]
|
||||
validframes.assign(samplesInRecurrentStep, 0);
|
||||
|
@ -128,9 +183,15 @@ public:
|
|||
size_t mapi = 0; // parallel-sequence index for utterance [i]
|
||||
// cal gamma for each utterance
|
||||
size_t ts = 0;
|
||||
/* guoye: start */
|
||||
size_t ws = 0;
|
||||
/* guoye: end */
|
||||
for (size_t i = 0; i < lattices.size(); i++)
|
||||
{
|
||||
const size_t numframes = lattices[i]->getnumframes();
|
||||
/* guoye: start */
|
||||
const short numwords = nws[i];
|
||||
/* guoye: end */
|
||||
|
||||
msra::dbn::matrixstripe predstripe(pred, ts, numframes); // logLLs for this utterance
|
||||
msra::dbn::matrixstripe dengammasstripe(dengammas, ts, numframes); // denominator gammas
|
||||
|
@ -186,6 +247,9 @@ public:
|
|||
}
|
||||
|
||||
array_ref<size_t> uidsstripe(&uids[ts], numframes);
|
||||
/* guoye: start */
|
||||
std::vector<size_t> widsstripe(wids.begin() + ws, wids.begin() + ws + numwords);
|
||||
/* guoye: end */
|
||||
|
||||
if (doreferencealign)
|
||||
{
|
||||
|
@ -204,12 +268,28 @@ public:
|
|||
numavlogp /= numframes;
|
||||
|
||||
// auto_timer dengammatimer;
|
||||
/* guoye: start */
|
||||
|
||||
// double denavlogp = lattices[i]->second.forwardbackward(parallellattice,
|
||||
// (const msra::math::ssematrixbase&) predstripe, (const msra::asr::simplesenonehmm&) m_hset,
|
||||
// (msra::math::ssematrixbase&) dengammasstripe, (msra::math::ssematrixbase&) gammasbuffer /*empty, not used*/,
|
||||
// lmf, wp, amf, boostmmifactor, seqsMBRmode, uidsstripe, boundariesstripe);
|
||||
|
||||
// fprintf(stderr, "calgammaformb: i = %d, utt = %ls \n", int(i), lattices[i]->second.key.c_str());
|
||||
|
||||
double denavlogp = lattices[i]->second.forwardbackward(parallellattice,
|
||||
(const msra::math::ssematrixbase&) predstripe, (const msra::asr::simplesenonehmm&) m_hset,
|
||||
(msra::math::ssematrixbase&) dengammasstripe, (msra::math::ssematrixbase&) gammasbuffer /*empty, not used*/,
|
||||
lmf, wp, amf, boostmmifactor, seqsMBRmode, uidsstripe, boundariesstripe);
|
||||
objectValue += (ElemType)((numavlogp - denavlogp) * numframes);
|
||||
|
||||
lmf, wp, amf, boostmmifactor, seqsMBRmode, EMBR, EMBRUnit, numPathsEMBR, enforceValidPathEMBR, getPathMethodEMBR, showWERMode, excludeSpecialWords, wordNbest, useAccInNbest, accWeightInNbest, numRawPathsEMBR, uidsstripe, widsstripe, boundariesstripe);
|
||||
|
||||
/* guoye: end */
|
||||
/* guoye: start */
|
||||
// objectValue += (ElemType)((numavlogp - denavlogp) * numframes);
|
||||
numavlogp;
|
||||
denavlogp;
|
||||
// objectValue += (ElemType)( 0 * numframes);
|
||||
objectValue += (ElemType)(denavlogp*numwords);
|
||||
/* guoye: end */
|
||||
if (samplesInRecurrentStep == 1)
|
||||
{
|
||||
tempmatrix = gammafromlattice.ColumnSlice(ts, numframes);
|
||||
|
@ -243,8 +323,13 @@ public:
|
|||
}
|
||||
if (samplesInRecurrentStep > 1)
|
||||
validframes[mapi] += numframes; // advance the cursor within the parallel sequence
|
||||
fprintf(stderr, "dengamma value %f\n", denavlogp);
|
||||
/* guoye: start */
|
||||
// fprintf(stderr, "dengamma value %f\n", denavlogp);
|
||||
/* guoye: end */
|
||||
ts += numframes;
|
||||
/* guoye: start */
|
||||
ws += numwords;
|
||||
/* guoye: end */
|
||||
}
|
||||
functionValues.SetValue(objectValue);
|
||||
}
|
||||
|
@ -509,6 +594,20 @@ protected:
|
|||
float boostmmifactor;
|
||||
bool seqsMBRmode;
|
||||
|
||||
/* guoye: start */
|
||||
bool EMBR;
|
||||
std::string EMBRUnit;
|
||||
size_t numPathsEMBR;
|
||||
bool enforceValidPathEMBR;
|
||||
std::string getPathMethodEMBR;
|
||||
std::string showWERMode;
|
||||
bool excludeSpecialWords;
|
||||
bool wordNbest;
|
||||
bool useAccInNbest;
|
||||
float accWeightInNbest;
|
||||
size_t numRawPathsEMBR;
|
||||
/* guoye: end */
|
||||
|
||||
private:
|
||||
std::unique_ptr<Microsoft::MSR::CNTK::CUDAPageLockedMemAllocator> m_cudaAllocator;
|
||||
std::shared_ptr<ElemType> m_intermediateCUDACopyBuffer;
|
||||
|
|
Разница между файлами не показана из-за своего большого размера
Загрузить разницу
|
@ -133,6 +133,27 @@ void backwardlatticej(const size_t batchsize, const size_t startindex, const std
|
|||
}
|
||||
}
|
||||
|
||||
/* guoye: start */
|
||||
|
||||
|
||||
void backwardlatticejEMBR(const size_t batchsize, const size_t startindex, const std::vector<float>& edgeacscores,
|
||||
const std::vector<msra::lattices::edgeinfowithscores>& edges,
|
||||
const std::vector<msra::lattices::nodeinfo>& nodes,
|
||||
std::vector<double>& edgelogbetas, std::vector<double>& logbetas,
|
||||
float lmf, float wp, float amf)
|
||||
{
|
||||
const size_t tpb = blockDim.x * blockDim.y; // total #threads in a block
|
||||
const size_t jinblock = threadIdx.x + threadIdx.y * blockDim.x;
|
||||
const size_t j = jinblock + blockIdx.x * tpb;
|
||||
if (j < batchsize) // note: will cause issues if we ever use __synctreads() in backwardlatticej
|
||||
{
|
||||
msra::lattices::latticefunctionskernels::backwardlatticejEMBR(j + startindex, edgeacscores,
|
||||
edges, nodes, edgelogbetas,
|
||||
logbetas, lmf, wp, amf);
|
||||
}
|
||||
}
|
||||
|
||||
/* guoye: end */
|
||||
void sMBRerrorsignalj(const std::vector<unsigned short>& alignstateids, const std::vector<unsigned int>& alignoffsets,
|
||||
const std::vector<msra::lattices::edgeinfowithscores>& edges, const std::vector<msra::lattices::nodeinfo>& nodes,
|
||||
const std::vector<double>& logpps, const float amf, const std::vector<double>& logEframescorrect,
|
||||
|
@ -147,6 +168,19 @@ void sMBRerrorsignalj(const std::vector<unsigned short>& alignstateids, const st
|
|||
}
|
||||
}
|
||||
|
||||
/* guoye: start */
|
||||
void EMBRerrorsignalj(const std::vector<unsigned short>& alignstateids, const std::vector<unsigned int>& alignoffsets,
|
||||
const std::vector<msra::lattices::edgeinfowithscores>& edges, const std::vector<msra::lattices::nodeinfo>& nodes,
|
||||
const std::vector<double>& edgeweights, msra::math::ssematrixbase& errorsignal)
|
||||
{
|
||||
const size_t shufflemode = 3;
|
||||
const size_t j = msra::lattices::latticefunctionskernels::shuffle(threadIdx.x, blockDim.x, threadIdx.y, blockDim.y, blockIdx.x, gridDim.x, shufflemode);
|
||||
if (j < edges.size()) // note: will cause issues if we ever use __synctreads()
|
||||
{
|
||||
msra::lattices::latticefunctionskernels::EMBRerrorsignalj(j, alignstateids, alignoffsets, edges, nodes, edgeweights, errorsignal);
|
||||
}
|
||||
}
|
||||
/* guoye: end */
|
||||
void stateposteriorsj(const std::vector<unsigned short>& alignstateids, const std::vector<unsigned int>& alignoffsets,
|
||||
const std::vector<msra::lattices::edgeinfowithscores>& edges, const std::vector<msra::lattices::nodeinfo>& nodes,
|
||||
const std::vector<double>& logqs, msra::math::ssematrixbase& logacc)
|
||||
|
@ -298,6 +332,50 @@ static double emulateforwardbackwardlattice(const size_t* batchsizeforward, cons
|
|||
#endif
|
||||
return totalfwscore;
|
||||
}
|
||||
/* guoye: start */
|
||||
|
||||
static double emulatebackwardlatticeEMBR(const size_t* batchsizebackward, const size_t numlaunchbackward,
|
||||
const std::vector<float>& edgeacscores,
|
||||
const std::vector<msra::lattices::edgeinfowithscores>& edges, const std::vector<msra::lattices::nodeinfo>& nodes,
|
||||
std::vector<double>& edgelogbetas, std::vector<double>& logbetas,
|
||||
const float lmf, const float wp, const float amf)
|
||||
{
|
||||
dim3 t(32, 8);
|
||||
const size_t tpb = t.x * t.y;
|
||||
dim3 b((unsigned int)((logbetas.size() + tpb - 1) / tpb));
|
||||
|
||||
emulatecuda(b, t, [&]()
|
||||
{
|
||||
setvaluej(logbetas, LOGZERO, logbetas.size());
|
||||
});
|
||||
|
||||
|
||||
logbetas[nodes.size() - 1] = 0;
|
||||
|
||||
// forward pass
|
||||
|
||||
|
||||
// backward pass
|
||||
size_t startindex = edges.size();
|
||||
for (size_t i = 0; i < numlaunchbackward; i++)
|
||||
{
|
||||
dim3 b3((unsigned int)((batchsizebackward[i] + tpb - 1) / tpb));
|
||||
emulatecuda(b3, t, [&]()
|
||||
{
|
||||
backwardlatticejEMBR(batchsizebackward[i], startindex - batchsizebackward[i], edgeacscores,
|
||||
edges, nodes, edgelogbetas, logbetas, lmf, wp, amf);
|
||||
|
||||
|
||||
});
|
||||
startindex -= batchsizebackward[i];
|
||||
}
|
||||
double totalbwscore = logbetas.front();
|
||||
|
||||
|
||||
return totalbwscore;
|
||||
}
|
||||
/* guoye: end */
|
||||
|
||||
// this function behaves as its CUDA conterparts, except that it takes CPU-side std::vectors for everything
|
||||
// this must be identical to CUDA kernel-launch function in -ops class (except for the input data types: vectorref -> std::vector)
|
||||
static void emulatesMBRerrorsignal(const std::vector<unsigned short>& alignstateids, const std::vector<unsigned int>& alignoffsets,
|
||||
|
@ -324,6 +402,29 @@ static void emulatesMBRerrorsignal(const std::vector<unsigned short>& alignstate
|
|||
});
|
||||
}
|
||||
|
||||
/* guoye: start */
|
||||
|
||||
// this function behaves as its CUDA conterparts, except that it takes CPU-side std::vectors for everything
|
||||
// this must be identical to CUDA kernel-launch function in -ops class (except for the input data types: vectorref -> std::vector)
|
||||
static void emulateEMBRerrorsignal(const std::vector<unsigned short>& alignstateids, const std::vector<unsigned int>& alignoffsets,
|
||||
const std::vector<msra::lattices::edgeinfowithscores>& edges, const std::vector<msra::lattices::nodeinfo>& nodes,
|
||||
const std::vector<double>& edgeweights,
|
||||
msra::math::ssematrixbase& errorsignal)
|
||||
{
|
||||
|
||||
const size_t numedges = edges.size();
|
||||
dim3 t(32, 8);
|
||||
const size_t tpb = t.x * t.y;
|
||||
foreach_coord(i, j, errorsignal)
|
||||
errorsignal(i, j) = 0;
|
||||
dim3 b((unsigned int)((numedges + tpb - 1) / tpb));
|
||||
emulatecuda(b, t, [&]()
|
||||
{
|
||||
EMBRerrorsignalj(alignstateids, alignoffsets, edges, nodes, edgeweights, errorsignal);
|
||||
});
|
||||
dim3 b1((((unsigned int)errorsignal.rows()) + 31) / 32);
|
||||
}
|
||||
/* guoye: end */
|
||||
// this function behaves as its CUDA conterparts, except that it takes CPU-side std::vectors for everything
|
||||
// this must be identical to CUDA kernel-launch function in -ops class (except for the input data types: vectorref -> std::vector)
|
||||
static void emulatemmierrorsignal(const std::vector<unsigned short>& alignstateids, const std::vector<unsigned int>& alignoffsets,
|
||||
|
@ -388,6 +489,11 @@ struct parallelstateimpl
|
|||
logppsgpu(msra::cuda::newdoublevector(deviceid)),
|
||||
logalphasgpu(msra::cuda::newdoublevector(deviceid)),
|
||||
logbetasgpu(msra::cuda::newdoublevector(deviceid)),
|
||||
/* guoye: start */
|
||||
edgelogbetasgpu(msra::cuda::newdoublevector(deviceid)),
|
||||
edgeweightsgpu(msra::cuda::newdoublevector(deviceid)),
|
||||
/* guoye: end */
|
||||
|
||||
logaccalphasgpu(msra::cuda::newdoublevector(deviceid)),
|
||||
logaccbetasgpu(msra::cuda::newdoublevector(deviceid)),
|
||||
logframescorrectedgegpu(msra::cuda::newdoublevector(deviceid)),
|
||||
|
@ -526,6 +632,10 @@ struct parallelstateimpl
|
|||
|
||||
std::unique_ptr<doublevector> logppsgpu;
|
||||
std::unique_ptr<doublevector> logalphasgpu;
|
||||
/* guoye: start */
|
||||
std::unique_ptr<doublevector> edgelogbetasgpu;
|
||||
std::unique_ptr<doublevector> edgeweightsgpu;
|
||||
/* guoye: end */
|
||||
std::unique_ptr<doublevector> logbetasgpu;
|
||||
std::unique_ptr<doublevector> logaccalphasgpu;
|
||||
std::unique_ptr<doublevector> logaccbetasgpu;
|
||||
|
@ -619,6 +729,20 @@ struct parallelstateimpl
|
|||
logEframescorrectgpu->allocate(edges.size());
|
||||
}
|
||||
}
|
||||
/* guoye: start */
|
||||
template <class edgestype, class nodestype>
|
||||
void allocbwvectorsEMBR(const edgestype& edges, const nodestype& nodes)
|
||||
{
|
||||
#ifndef TWO_CHANNEL
|
||||
const size_t alphabetanoderatio = 1;
|
||||
#else
|
||||
const size_t alphabetanoderatio = 2;
|
||||
#endif
|
||||
logbetasgpu->allocate(alphabetanoderatio * nodes.size());
|
||||
edgelogbetasgpu->allocate(edges.size());
|
||||
|
||||
}
|
||||
/* guoye: end */
|
||||
|
||||
// check if gpumatrixstorage supports size of cpumatrix, if not allocate. set gpumatrix to part of gpumatrixstorage
|
||||
// This function checks the size of errorsignalgpustorage, and then sets errorsignalgpu to a columnslice of the
|
||||
|
@ -664,6 +788,35 @@ struct parallelstateimpl
|
|||
edgealignments.resize(alignresult->size());
|
||||
alignresult->fetch(edgealignments, true);
|
||||
}
|
||||
/* guoye: start */
|
||||
|
||||
void getlogbetas(std::vector<double>& logbetas)
|
||||
{
|
||||
logbetas.resize(logbetasgpu->size());
|
||||
logbetasgpu->fetch(logbetas, true);
|
||||
}
|
||||
|
||||
void getedgelogbetas(std::vector<double>& edgelogbetas)
|
||||
{
|
||||
edgelogbetas.resize(edgelogbetasgpu->size());
|
||||
edgelogbetasgpu->fetch(edgelogbetas, true);
|
||||
}
|
||||
|
||||
void getedgeweights(std::vector<double>& edgeweights)
|
||||
{
|
||||
edgeweights.resize(edgeweightsgpu->size());
|
||||
edgeweightsgpu->fetch(edgeweights, true);
|
||||
}
|
||||
|
||||
|
||||
|
||||
void setedgeweights(const std::vector<double>& edgeweights)
|
||||
{
|
||||
edgeweightsgpu->assign(edgeweights, false);
|
||||
}
|
||||
|
||||
|
||||
/* guoye: end */
|
||||
};
|
||||
|
||||
void lattice::parallelstate::setdevice(size_t deviceid)
|
||||
|
@ -725,6 +878,29 @@ void lattice::parallelstate::getedgealignments(std::vector<unsigned short>& edge
|
|||
{
|
||||
pimpl->getedgealignments(edgealignments);
|
||||
}
|
||||
/* guoye: start */
|
||||
void lattice::parallelstate::getlogbetas(std::vector<double>& logbetas)
|
||||
{
|
||||
pimpl->getlogbetas(logbetas);
|
||||
}
|
||||
void lattice::parallelstate::getedgelogbetas(std::vector<double>& edgelogbetas)
|
||||
{
|
||||
pimpl->getedgelogbetas(edgelogbetas);
|
||||
}
|
||||
void lattice::parallelstate::getedgeweights(std::vector<double>& edgeweights)
|
||||
{
|
||||
pimpl->getedgeweights(edgeweights);
|
||||
}
|
||||
void lattice::parallelstate::setedgeweights(const std::vector<double>& edgeweights)
|
||||
{
|
||||
pimpl->setedgeweights(edgeweights);
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
/* guoye: end */
|
||||
//template<class ElemType>
|
||||
void lattice::parallelstate::setloglls(const Microsoft::MSR::CNTK::Matrix<float>& loglls)
|
||||
{
|
||||
|
@ -909,6 +1085,73 @@ double lattice::parallelforwardbackwardlattice(parallelstate& parallelstate, con
|
|||
return totalfwscore;
|
||||
}
|
||||
|
||||
/* guoye: start */
|
||||
|
||||
|
||||
// parallelforwardbackwardlattice() -- compute the latticelevel logpps using forwardbackward
|
||||
double lattice::parallelbackwardlatticeEMBR(parallelstate& parallelstate, const std::vector<float>& edgeacscores,
|
||||
const float lmf, const float wp, const float amf,
|
||||
std::vector<double>& edgelogbetas, std::vector<double>& logbetas) const
|
||||
{ // ^^ TODO: remove this
|
||||
vector<size_t> batchsizebackward; // record the batch size that exclude the data dependency for backward
|
||||
|
||||
|
||||
size_t endindexbackward = edges.back().S;
|
||||
size_t countbatchbackward = 0;
|
||||
foreach_index (j, edges) // compute the batch size info for kernel launches
|
||||
{
|
||||
const size_t backj = edges.size() - 1 - j;
|
||||
if (edges[backj].E > endindexbackward)
|
||||
{
|
||||
countbatchbackward++;
|
||||
if (endindexbackward < edges[backj].S)
|
||||
endindexbackward = edges[backj].S;
|
||||
}
|
||||
else
|
||||
{
|
||||
batchsizebackward.push_back(countbatchbackward);
|
||||
countbatchbackward = 1;
|
||||
endindexbackward = edges[backj].S;
|
||||
}
|
||||
}
|
||||
batchsizebackward.push_back(countbatchbackward);
|
||||
|
||||
|
||||
double totalbwscore = 0.0f;
|
||||
if (!parallelstate->emulation)
|
||||
{
|
||||
if (verbosity >= 2)
|
||||
fprintf(stderr, "parallelbackwardlatticeEMBR: %d launches for backward\n", (int) batchsizebackward.size());
|
||||
|
||||
|
||||
parallelstate->allocbwvectorsEMBR(edges, nodes);
|
||||
|
||||
std::unique_ptr<latticefunctions> latticefunctions(msra::cuda::newlatticefunctions(parallelstate.getdevice())); // final CUDA call
|
||||
latticefunctions->backwardlatticeEMBR(&batchsizebackward[0], batchsizebackward.size(),
|
||||
*parallelstate->edgeacscoresgpu.get(), *parallelstate->edgesgpu.get(),
|
||||
*parallelstate->nodesgpu.get(), *parallelstate->edgelogbetasgpu.get(),
|
||||
*parallelstate->logbetasgpu.get(), lmf, wp, amf, totalbwscore);
|
||||
|
||||
}
|
||||
else // emulation
|
||||
{
|
||||
#ifndef TWO_CHANNEL
|
||||
fprintf(stderr, "forbid invalid sil path\n");
|
||||
const size_t alphabetanoderatio = 1;
|
||||
#else
|
||||
const size_t alphabetanoderatio = 2;
|
||||
#endif
|
||||
logbetas.resize(alphabetanoderatio * nodes.size());
|
||||
edgelogbetas.resize(edges.size());
|
||||
|
||||
|
||||
totalbwscore = emulatebackwardlatticeEMBR(&batchsizebackward[0], batchsizebackward.size(),
|
||||
edgeacscores, edges, nodes,
|
||||
edgelogbetas, logbetas, lmf, wp, amf);
|
||||
}
|
||||
return totalbwscore;
|
||||
}
|
||||
/* guoye: end */
|
||||
// ------------------------------------------------------------------------
|
||||
// parallel implementations of sMBR error updating step
|
||||
// ------------------------------------------------------------------------
|
||||
|
@ -948,6 +1191,36 @@ void lattice::parallelsMBRerrorsignal(parallelstate& parallelstate, const edgeal
|
|||
}
|
||||
}
|
||||
|
||||
/* guoye: start */
|
||||
// ------------------------------------------------------------------------
|
||||
// parallelEMBRerrorsignal() -- compute the EMBR error signal into 'errorsignal'
//
//  - Emulation path: calls emulateEMBRerrorsignal() with the CPU-side alignments, the lattice
//    edges/nodes and the caller's 'edgeweights'.
//  - GPU path: calls EMBRerrorsignal() on the GPU-resident data held by 'parallelstate'.
//    'edgeweights' and 'thisedgealignments' are not read in this path; the GPU-side
//    edgeweightsgpu and alignresult buffers are used instead.
//    NOTE(review): presumably setedgeweights() must have been called beforehand -- confirm with callers.
void lattice::parallelEMBRerrorsignal(parallelstate& parallelstate, const edgealignments& thisedgealignments,
                                      const std::vector<double>& edgeweights,
                                      msra::math::ssematrixbase& errorsignal) const
{
    if (parallelstate->emulation)
    {
        emulateEMBRerrorsignal(thisedgealignments.getalignmentsbuffer(), thisedgealignments.getalignoffsets(), edges, nodes, edgeweights, errorsignal);
        return;
    }

    // no need negative buffer for EMBR
    const bool cacheerrorsignalneg = false;
    parallelstate->cacheerrorsignal(errorsignal, cacheerrorsignalneg);

    std::unique_ptr<latticefunctions> gpufunctions(msra::cuda::newlatticefunctions(parallelstate.getdevice()));
    auto& alignresultgpu = *parallelstate->alignresult.get();
    auto& alignoffsetsgpu = *parallelstate->alignoffsetsgpu.get();
    auto& edgesgpu = *parallelstate->edgesgpu.get();
    auto& nodesgpu = *parallelstate->nodesgpu.get();
    auto& edgeweightsgpu = *parallelstate->edgeweightsgpu.get();
    auto& errorsignalgpu = *parallelstate->errorsignalgpu.get();
    gpufunctions->EMBRerrorsignal(alignresultgpu, alignoffsetsgpu, edgesgpu,
                                  nodesgpu, edgeweightsgpu,
                                  errorsignalgpu);

    // CopySection() is only called for a non-empty matrix, since &errorsignal(0, 0) is taken
    // NOTE(review): presumably this copies the GPU result back into the CPU-side matrix -- confirm
    const bool hasdata = errorsignal.rows() > 0 && errorsignal.cols() > 0;
    if (hasdata)
    {
        parallelstate->errorsignalgpu->CopySection(errorsignal.rows(), errorsignal.cols(), &errorsignal(0, 0), errorsignal.getcolstride());
    }
}
|
||||
/* guoye: end */
|
||||
|
||||
// ------------------------------------------------------------------------
|
||||
// parallel implementations of MMI error updating step
|
||||
// ------------------------------------------------------------------------
|
||||
|
|
Разница между файлами не показана из-за своего большого размера
Загрузить разницу
Загрузка…
Ссылка в новой задаче