Merge remote-tracking branch 'cntk/master' into merge
This commit is contained in:
Коммит
cf29bf0f38
Двоичные данные
BrainScript/BrainScript--extending the CNTK config language, Frank Seide August 2015.pptx
Normal file
Двоичные данные
BrainScript/BrainScript--extending the CNTK config language, Frank Seide August 2015.pptx
Normal file
Двоичный файл не отображается.
Разница между файлами не показана из-за своего большого размера
Загрузить разницу
|
@ -0,0 +1,29 @@
|
|||
// BrainScriptEvaluator.h -- execute what's given in a config file
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "Basics.h"
|
||||
#include "ScriptableObjects.h"
|
||||
#include "BrainScriptParser.h"
|
||||
|
||||
#include <memory> // for shared_ptr
|
||||
|
||||
namespace Microsoft { namespace MSR { namespace BS {
|
||||
|
||||
using namespace std;
|
||||
using namespace Microsoft::MSR::ScriptableObjects;
|
||||
|
||||
// -----------------------------------------------------------------------
|
||||
// functions exposed by this module
|
||||
// TODO: This is the only thing that should stay in an actual BrainScriptEvaluator.h.
|
||||
// -----------------------------------------------------------------------
|
||||
|
||||
// understand and execute from the syntactic expression tree
|
||||
ConfigValuePtr Evaluate(ExpressionPtr); // evaluate the expression tree
|
||||
void Do(ExpressionPtr e); // evaluate e.do
|
||||
shared_ptr<Object> EvaluateField(ExpressionPtr e, const wstring & id); // for experimental CNTK integration
|
||||
|
||||
// some simple tests
|
||||
void SomeTests();
|
||||
|
||||
}}} // end namespaces
|
|
@ -0,0 +1,797 @@
|
|||
// ConfigParser.cpp -- config parser (syntactic only, that is, source -> Expression tree)
|
||||
|
||||
#define _CRT_SECURE_NO_WARNINGS // "secure" CRT not available on all platforms --add this at the top of all CPP files that give "function or variable may be unsafe" warnings
|
||||
|
||||
#include "BrainScriptParser.h"
|
||||
#include <cstdio>
|
||||
#include <cstdlib>
|
||||
#include <cctype>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <deque>
|
||||
#include <set>
|
||||
#include <stdexcept>
|
||||
#include <algorithm>
|
||||
|
||||
#ifndef let
|
||||
#define let const auto
|
||||
#endif
|
||||
|
||||
namespace Microsoft { namespace MSR { namespace BS {
|
||||
|
||||
using namespace std;
|
||||
using namespace msra::strfun;
|
||||
using namespace Microsoft::MSR::CNTK;
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// source files and text references (location) into them
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// SourceFile constructors
|
||||
SourceFile::SourceFile(wstring location, wstring text) : path(location), lines(split(text, L"\r\n")) { } // from string, e.g. command line
|
||||
SourceFile::SourceFile(wstring path) : path(path) // from file
|
||||
{
|
||||
File(path, fileOptionsRead).GetLines(lines);
|
||||
}
|
||||
|
||||
bool TextLocation::IsValid() const { return sourceFileAsIndex != SIZE_MAX; }
|
||||
|
||||
// register a new source file and return a TextPosition that points to its start
|
||||
/*static*/ TextLocation TextLocation::NewSourceFile(SourceFile && sourceFile)
|
||||
{
|
||||
TextLocation loc;
|
||||
loc.lineNo = 0;
|
||||
loc.charPos = 0;
|
||||
loc.sourceFileAsIndex = sourceFileMap.size(); // index under which we store the source file
|
||||
sourceFileMap.push_back(move(sourceFile)); // take ownership of the source file and give it a numeric index
|
||||
return loc;
|
||||
}
|
||||
|
||||
// helper for pretty-printing errors: Show source-code line with ...^ under it to mark up the point of error
|
||||
struct Issue
|
||||
{
|
||||
TextLocation location; // using lineno and source file; char position only for printing the overall error loc
|
||||
wstring markup; // string with markup symbols at char positions and dots inbetween
|
||||
void AddMarkup(wchar_t symbol, size_t charPos)
|
||||
{
|
||||
if (charPos >= markup.size())
|
||||
markup.resize(charPos+1, L' '); // fill with '.' up to desired position if the string is not that long yet
|
||||
if (markup[charPos] == L' ') // don't overwrite
|
||||
markup[charPos] = symbol;
|
||||
}
|
||||
Issue(TextLocation location) : location(location) { }
|
||||
};
|
||||
|
||||
// trace
|
||||
/*static*/ void TextLocation::Trace(TextLocation location, const wchar_t * traceKind, const wchar_t * op, const wchar_t * exprPath)
|
||||
{
|
||||
fprintf(stderr, "%ls: %ls (path %ls)\n", traceKind, op, exprPath);
|
||||
const auto & lines = location.GetSourceFile().lines;
|
||||
const auto line = (location.lineNo == lines.size()) ? L"(end)" : lines[location.lineNo].c_str();
|
||||
Issue issue(location);
|
||||
issue.AddMarkup(L'^', location.charPos);
|
||||
fprintf(stderr, " %ls\n %ls\n", line, issue.markup.c_str());
|
||||
}
|
||||
|
||||
// report an error
|
||||
// The source line is shown, and the position is marked as '^'.
|
||||
// Because it is often hard to recognize an issue only from the point where it occurred, we also report the history in compact visual form.
|
||||
// Since often multiple contexts are on the same source line, we only print each source line once in a consecutive row of contexts.
|
||||
/*static*/ void TextLocation::PrintIssue(const vector<TextLocation> & locations, const wchar_t * errorKind, const wchar_t * kind, const wchar_t * what)
|
||||
{
|
||||
vector<Issue> issues; // tracing the error backwards
|
||||
size_t symbolIndex = 0;
|
||||
for (size_t n = 0; n < locations.size(); n++)
|
||||
{
|
||||
let & location = locations[n];
|
||||
if (!location.IsValid()) // means thrower has no location, go up one context
|
||||
continue;
|
||||
// build the array
|
||||
if (symbolIndex == 0 || location.lineNo != issues.back().location.lineNo || location.sourceFileAsIndex != issues.back().location.sourceFileAsIndex)
|
||||
{
|
||||
if (issues.size() == 10)
|
||||
break;
|
||||
else
|
||||
issues.push_back(location);
|
||||
}
|
||||
// get the symbol to indicate how many steps back, in this sequence: ^ 0..9 a..z A..Z (we don't go further than this)
|
||||
wchar_t symbol;
|
||||
if (symbolIndex == 0) symbol = '^';
|
||||
else if (symbolIndex < 1 + 10) symbol = '0' + (wchar_t)symbolIndex - 1;
|
||||
else if (symbolIndex < 1 + 10 + 26) symbol = 'a' + (wchar_t)symbolIndex - (1 + 10);
|
||||
else if (symbolIndex < 1 + 10 + 26 + 26) symbol = 'A' + (wchar_t)symbolIndex - (1 + 10 + 26);
|
||||
else break;
|
||||
symbolIndex++;
|
||||
// insert the markup
|
||||
issues.back().AddMarkup(symbol, location.charPos);
|
||||
}
|
||||
// print it backwards
|
||||
if (!locations.empty()) // (be resilient to some throwers not having a TextrLocation; to be avoided)
|
||||
{
|
||||
let & firstLoc = issues.front().location;
|
||||
fprintf(stderr, "\n%ls while %ls line %d char %d of %ls\n", errorKind, kind, (int)firstLoc.lineNo + 1/*report 1-based*/, (int)firstLoc.charPos + 1, firstLoc.GetSourceFile().path.c_str());
|
||||
fprintf(stderr, "see location marked ^ and parent contexts marked 0..9, a..z, A..Z:\n\n");
|
||||
for (auto i = issues.rbegin(); i != issues.rend(); i++)
|
||||
{
|
||||
let & issue = *i;
|
||||
auto & where = issue.location;
|
||||
const auto & lines = where.GetSourceFile().lines;
|
||||
const auto line = (where.lineNo == lines.size()) ? L"(end)" : lines[where.lineNo].c_str();
|
||||
fprintf(stderr, " %ls\n %ls\n", line, issue.markup.c_str());
|
||||
}
|
||||
}
|
||||
fprintf(stderr, "%ls: %ls\n", errorKind, what);
|
||||
fflush(stderr);
|
||||
}
|
||||
/*static*/ vector<SourceFile> TextLocation::sourceFileMap;
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// reader -- reads source code, including loading from disk
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
class CodeSource
|
||||
{
|
||||
vector<TextLocation> locationStack; // parent locations in case of included files
|
||||
TextLocation cursor; // current location
|
||||
const wchar_t * currentLine; // cache of cursor.GetSourceFile().lines[cursor.lineNo]
|
||||
// update currentLine from cursor
|
||||
void CacheCurrentLine()
|
||||
{
|
||||
let & lines = cursor.GetSourceFile().lines;
|
||||
if (cursor.lineNo == lines.size())
|
||||
currentLine = nullptr;
|
||||
else
|
||||
currentLine = lines[cursor.lineNo].c_str();
|
||||
}
|
||||
protected:
|
||||
// set a source file; only do that from constructor or inside PushSourceFile()
|
||||
void SetSourceFile(SourceFile && sourceFile)
|
||||
{
|
||||
cursor = TextLocation::NewSourceFile(move(sourceFile)); // save source file and set the cursor to its start
|
||||
CacheCurrentLine(); // re-cache current line
|
||||
}
|
||||
public:
|
||||
class CodeSourceError : public ConfigError
|
||||
{
|
||||
public:
|
||||
CodeSourceError(const wstring & msg, TextLocation where) : ConfigError(msg, where) { }
|
||||
/*ConfigError::*/ const wchar_t * kind() const { return L"reading source"; }
|
||||
};
|
||||
|
||||
__declspec_noreturn static void Fail(wstring msg, TextLocation where) { throw CodeSourceError(msg, where); }
|
||||
|
||||
// enter a source file, at start or as a result of an include statement
|
||||
void PushSourceFile(SourceFile && sourceFile)
|
||||
{
|
||||
locationStack.push_back(cursor);
|
||||
SetSourceFile(move(sourceFile));
|
||||
}
|
||||
|
||||
// are we inside an include file?
|
||||
bool IsInInclude() { return locationStack.size() > 0; }
|
||||
|
||||
// done with a source file. Only call this for nested files; the outermost one must not be popped.
|
||||
void PopSourceFile()
|
||||
{
|
||||
if (!IsInInclude())
|
||||
LogicError("PopSourceFile: location stack empty");
|
||||
cursor = locationStack.back(); // restore cursor we came from
|
||||
CacheCurrentLine(); // re-cache current line
|
||||
locationStack.pop_back();
|
||||
}
|
||||
|
||||
// get current cursor; this is remembered for each token, and also used when throwing errors
|
||||
TextLocation GetCursor() const { return cursor; }
|
||||
|
||||
// get character at current position.
|
||||
// Special cases:
|
||||
// - end of line is returned as '\n'
|
||||
// - end of file is returned as 0
|
||||
wchar_t GotChar() const
|
||||
{
|
||||
if (!currentLine) return 0; // end of file
|
||||
else if (!currentLine[cursor.charPos]) return '\n'; // end of line
|
||||
else return currentLine[cursor.charPos];
|
||||
}
|
||||
|
||||
// we chan also return the address of the current character, e.g. for passing it to a C stdlib funcion such as wcstod()
|
||||
const wchar_t * GotCharPtr() const { return currentLine + cursor.charPos; }
|
||||
|
||||
// advance cursor by #chars (but across line boundaries)
|
||||
void ConsumeChars(size_t chars)
|
||||
{
|
||||
let ch = GotChar();
|
||||
if (!ch) LogicError("Consume: cannot run beyond end of source file");
|
||||
if (ch == '\n' && chars > 0)
|
||||
{
|
||||
if (chars != 1) LogicError("Consume: cannot run beyond end of line");
|
||||
cursor.lineNo++;
|
||||
CacheCurrentLine(); // line no has changed: re-cache the line ptr
|
||||
cursor.charPos = 0;
|
||||
}
|
||||
else
|
||||
cursor.charPos += chars;
|
||||
}
|
||||
|
||||
// get the next character
|
||||
wchar_t GetChar()
|
||||
{
|
||||
ConsumeChars(1);
|
||||
return GotChar();
|
||||
}
|
||||
};
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// lexer -- iterates over the source code and returns token by token
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
class Lexer : public CodeSource
|
||||
{
|
||||
set<wstring> keywords;
|
||||
set<wstring> punctuations;
|
||||
public:
|
||||
Lexer() : CodeSource(), currentToken(TextLocation())
|
||||
{
|
||||
keywords = set<wstring>
|
||||
{
|
||||
L"include",
|
||||
L"new", L"true", L"false",
|
||||
L"if", L"then", L"else",
|
||||
L"array",
|
||||
};
|
||||
punctuations = set<wstring>
|
||||
{
|
||||
L"=", L";", L",", L"\n",
|
||||
L"[", L"]", L"(", L")",
|
||||
L"+", L"-", L"*", L"/", L"**", L".*", L"%", L"||", L"&&", L"^",
|
||||
L"!",
|
||||
L"==", L"!=", L"<", L"<=", L">", L">=",
|
||||
L":", L"=>",
|
||||
L"..", L".",
|
||||
L"//", L"#", L"/*"
|
||||
};
|
||||
}
|
||||
|
||||
enum TokenKind
|
||||
{
|
||||
invalid, punctuation, numberliteral, stringliteral, booleanliter, identifier, keyword, eof // TODO: what are true and false? Literals or identifiers?
|
||||
};
|
||||
|
||||
struct Token
|
||||
{
|
||||
wstring symbol; // identifier, keyword, punctuation, or string literal
|
||||
double number; // number
|
||||
TokenKind kind;
|
||||
TextLocation beginLocation; // text loc of first character of this token
|
||||
bool isLineInitial; // this token is the first on the line (ignoring comments)
|
||||
Token(TextLocation loc) : beginLocation(loc), kind(invalid), number(0.0), isLineInitial(false) { }
|
||||
// diagnostic helper
|
||||
static wstring TokenKindToString(TokenKind kind)
|
||||
{
|
||||
switch (kind)
|
||||
{
|
||||
case invalid: return L"invalid";
|
||||
case punctuation: return L"punctuation";
|
||||
case numberliteral: return L"numberliteral";
|
||||
case stringliteral: return L"stringliteral";
|
||||
case identifier: return L"identifier";
|
||||
case keyword: return L"keyword";
|
||||
case eof: return L"eof";
|
||||
default: return L"(unknown?)";
|
||||
}
|
||||
}
|
||||
wstring ToString() const // string to show the content of token for debugging
|
||||
{
|
||||
let kindStr = TokenKindToString(kind);
|
||||
switch (kind)
|
||||
{
|
||||
case numberliteral: return kindStr + wstrprintf(L" %f", number);
|
||||
case stringliteral: return kindStr + L" '" + symbol + L"'";
|
||||
case identifier: case keyword: case punctuation: return kindStr + L" " + symbol;
|
||||
default: return kindStr;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
class LexerError : public ConfigError
|
||||
{
|
||||
public:
|
||||
LexerError(const wstring & msg, TextLocation where) : ConfigError(msg, where) { }
|
||||
/*ConfigError::*/ const wchar_t * kind() const { return L"tokenizing"; }
|
||||
};
|
||||
|
||||
private:
|
||||
__declspec_noreturn static void Fail(wstring msg, Token where) { throw LexerError(msg, where.beginLocation); }
|
||||
|
||||
Token currentToken;
|
||||
// consume input characters to form a next token
|
||||
// - this function mutates the cursor, but does not set currentToken
|
||||
// - white space and comments are skipped
|
||||
// - including files is handled here
|
||||
// - the cursor is left on the first character that does not belong to the token
|
||||
// TODO: need to know whether we want to see '\n' or not
|
||||
Token NextToken()
|
||||
{
|
||||
auto ch = GotChar();
|
||||
// skip white space
|
||||
// We remember whether we crossed a line end. Dictionary assignments end at newlines if syntactically acceptable.
|
||||
bool crossedLineEnd = (GetCursor().lineNo == 0 && GetCursor().charPos == 0);
|
||||
while (iswblank(ch) || ch == '\n' || ch == '\r')
|
||||
{
|
||||
crossedLineEnd |= (ch == '\n' || ch == '\r');
|
||||
ch = GetChar();
|
||||
}
|
||||
Token t(GetCursor());
|
||||
t.isLineInitial = crossedLineEnd;
|
||||
// handle end of (include) file
|
||||
if (ch == 0)
|
||||
{
|
||||
if (IsInInclude())
|
||||
{
|
||||
PopSourceFile();
|
||||
t = NextToken(); // tail call--the current 't' gets dropped/ignored
|
||||
t.isLineInitial = true; // eof is a line end
|
||||
return t;
|
||||
}
|
||||
// really end of all source code: we are done. If calling this function multiple times, we will keep returning this.
|
||||
t.kind = eof;
|
||||
}
|
||||
else if (iswdigit(ch) || (ch == L'.' && iswdigit(GotCharPtr()[1]))) // --- number
|
||||
{
|
||||
let beginPtr = GotCharPtr();
|
||||
wchar_t * endPtr = nullptr;
|
||||
t.number = wcstod(beginPtr, &endPtr); // BUGBUG: this seems to honor locale settings. We need one that doesn't. With this, CNTK won't parse right in Germany.
|
||||
if (endPtr == beginPtr) Fail(L"parsing number", t); // should not really happen!
|
||||
t.kind = numberliteral;
|
||||
if (endPtr[0] == L'.' && endPtr[-1] == L'.') // prevent 1..2 from begin tokenized 1. .2
|
||||
endPtr--;
|
||||
ConsumeChars(endPtr - beginPtr);
|
||||
}
|
||||
else if (iswalpha(ch) || ch == L'_') // --- identifier or keyword
|
||||
{
|
||||
while (iswalpha(ch) || ch == L'_' || iswdigit(ch)) // inside we also allow digits
|
||||
{
|
||||
t.symbol.push_back(ch);
|
||||
ch = GetChar();
|
||||
}
|
||||
// check against keyword list
|
||||
if (keywords.find(t.symbol) != keywords.end()) t.kind = keyword;
|
||||
else t.kind = identifier;
|
||||
// special case: include "path"
|
||||
if (t.symbol == L"include")
|
||||
{
|
||||
let nameTok = NextToken(); // must be followed by a string literal
|
||||
if (nameTok.kind != stringliteral) Fail(L"'include' must be followed by a quoted string", nameTok);
|
||||
let path = nameTok.symbol; // TODO: some massaging of the path
|
||||
PushSourceFile(SourceFile(path)); // current cursor is right after the pathname; that's where we will pick up later
|
||||
return NextToken();
|
||||
}
|
||||
}
|
||||
else if (ch == L'"' || ch == 0x27) // --- string literal
|
||||
{
|
||||
t.kind = stringliteral;
|
||||
let q = ch; // remember quote character
|
||||
ch = GetChar(); // consume the quote character
|
||||
while (ch != 0 && ch != q) // note: our strings do not have any escape characters to consider
|
||||
{
|
||||
t.symbol.append(1, ch);
|
||||
ch = GetChar();
|
||||
}
|
||||
if (ch == 0) // runaway string
|
||||
Fail(L"string without closing quotation mark", t);
|
||||
GetChar(); // consume the closing quote
|
||||
}
|
||||
else // --- punctuation
|
||||
{
|
||||
t.kind = punctuation;
|
||||
t.symbol = ch;
|
||||
t.symbol.append(1, GetChar()); // first try two-char punctuation
|
||||
if (punctuations.find(t.symbol) != punctuations.end())
|
||||
GetChar(); // it is a two-char one: need to consume the second one of them
|
||||
else // try single-char one
|
||||
{
|
||||
t.symbol.pop_back(); // drop the last one & try again
|
||||
if (punctuations.find(t.symbol) == punctuations.end()) // unknown
|
||||
Fail(L"unexpected character: " + t.symbol, t);
|
||||
}
|
||||
// special case: comments
|
||||
if (t.symbol == L"#" || t.symbol == L"//")
|
||||
{
|
||||
ConsumeChars(wcslen(GotCharPtr()));
|
||||
return NextToken();
|
||||
}
|
||||
else if (t.symbol == L"/*")
|
||||
{
|
||||
ch = GotChar();
|
||||
while (ch != 0 && !(ch == L'*' && GetChar() == L'/')) // note: this test leverages short-circuit evaluation semantics of C
|
||||
ch = GetChar();
|
||||
if (ch == 0)
|
||||
Fail(L"comment without closing */", t);
|
||||
GetChar(); // consume the final '/'
|
||||
return NextToken(); // and return the next token
|
||||
}
|
||||
}
|
||||
return t;
|
||||
}
|
||||
public:
|
||||
const Token & GotToken() { return currentToken; }
|
||||
void ConsumeToken() { currentToken = NextToken(); }
|
||||
const Token & GetToken()
|
||||
{
|
||||
ConsumeToken();
|
||||
return GotToken();
|
||||
}
|
||||
|
||||
// some simple test function
|
||||
void Test()
|
||||
{
|
||||
let lexerTest = L"new CNTK [ do = (train:eval) # main\ntrain=/*test * */if eval include 'c:/me/test.txt' then 13 else array[1..10](i=>i*i); eval=\"a\"+'b' // line-end\n ] 'a\nb\nc' new";
|
||||
PushSourceFile(SourceFile(L"(command line)", lexerTest));
|
||||
while (GotToken().kind != Lexer::TokenKind::eof)
|
||||
{
|
||||
let & token = GotToken(); // get first token
|
||||
fprintf(stderr, "%ls\n", token.ToString().c_str());
|
||||
ConsumeToken();
|
||||
}
|
||||
Fail(L"error test", GetCursor());
|
||||
}
|
||||
};
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// parser -- parses configurations
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// diagnostics helper: print the content
|
||||
void Expression::Dump(int indent) const
|
||||
{
|
||||
fprintf(stderr, "%*s", indent, "");
|
||||
if (op == L"s") fprintf(stderr, "'%ls' ", s.c_str());
|
||||
else if (op == L"d") fprintf(stderr, "%.f ", d);
|
||||
else if (op == L"b") fprintf(stderr, "%s ", b ? "true" : "false");
|
||||
else if (op == L"id") fprintf(stderr, "%ls ", id.c_str());
|
||||
else if (op == L"new" || op == L"array" || op == L".") fprintf(stderr, "%ls %ls ", op.c_str(), id.c_str());
|
||||
else fprintf(stderr, "%ls ", op.c_str());
|
||||
if (!args.empty())
|
||||
{
|
||||
fprintf(stderr, "\n");
|
||||
for (const auto & arg : args)
|
||||
arg->Dump(indent + 2);
|
||||
}
|
||||
if (!namedArgs.empty())
|
||||
{
|
||||
fprintf(stderr, "\n");
|
||||
for (const auto & arg : namedArgs)
|
||||
{
|
||||
fprintf(stderr, "%*s%ls =\n", indent + 2, "", arg.first.c_str());
|
||||
arg.second.second->Dump(indent + 4);
|
||||
}
|
||||
}
|
||||
fprintf(stderr, "\n");
|
||||
}
|
||||
|
||||
class Parser : public Lexer
|
||||
{
|
||||
// errors
|
||||
class ParseError : public ConfigError
|
||||
{
|
||||
public:
|
||||
ParseError(const wstring & msg, TextLocation where) : ConfigError(msg, where) { }
|
||||
/*ConfigError::*/ const wchar_t * kind() const { return L"parsing"; }
|
||||
};
|
||||
|
||||
__declspec_noreturn static void Fail(const wstring & msg, Token where) { throw ParseError(msg, where.beginLocation); }
|
||||
|
||||
//void Expected(const wstring & what) { Fail(strprintf("%ls expected", what.c_str()), GotToken().beginLocation); } // I don't know why this does not work
|
||||
void Expected(const wstring & what) { Fail(what + L" expected", GotToken().beginLocation); }
|
||||
|
||||
// this token must be punctuation 's'; check and get the next
|
||||
void ConsumePunctuation(const wchar_t * s)
|
||||
{
|
||||
let & tok = GotToken();
|
||||
if (tok.kind != punctuation || tok.symbol != s)
|
||||
Expected(L"'" + wstring(s) + L"'");
|
||||
ConsumeToken();
|
||||
}
|
||||
|
||||
// this token must be keyword 's'; check and get the next
|
||||
void ConsumeKeyword(const wchar_t * s)
|
||||
{
|
||||
let & tok = GotToken();
|
||||
if (tok.kind != keyword || tok.symbol != s)
|
||||
Expected(L"'" + wstring(s) + L"'");
|
||||
ConsumeToken();
|
||||
}
|
||||
|
||||
// this token must be an identifier; check and get the next token. Return the identifier.
|
||||
wstring ConsumeIdentifier()
|
||||
{
|
||||
let & tok = GotToken();
|
||||
if (tok.kind != identifier)
|
||||
Expected(L"identifier");
|
||||
let id = tok.symbol;
|
||||
ConsumeToken();
|
||||
return id;
|
||||
}
|
||||
|
||||
map<wstring, int> infixPrecedence; // precedence level of infix operators
|
||||
public:
|
||||
Parser(SourceFile && sourceFile) : Lexer()
|
||||
{
|
||||
infixPrecedence = map<wstring, int>
|
||||
{
|
||||
{ L".", 100 }, { L"[", 100 }, { L"(", 100 }, // also sort-of infix operands...
|
||||
{ L"*", 10 }, { L"/", 10 }, { L".*", 10 }, { L"**", 10 }, { L"%", 10 },
|
||||
{ L"+", 9 }, { L"-", 9 },
|
||||
{ L"==", 8 }, { L"!=", 8 }, { L"<", 8 }, { L"<=", 8 }, { L">", 8 }, { L">=", 8 },
|
||||
{ L"&&", 7 },
|
||||
{ L"||", 6 },
|
||||
{ L":", 5 },
|
||||
{ L"=>", 0 },
|
||||
};
|
||||
SetSourceFile(move(sourceFile));
|
||||
ConsumeToken(); // get the very first token
|
||||
}
|
||||
ExpressionPtr OperandFromTokenSymbol(const Token & tok) // helper to make an Operand expression with op==tok.symbol and then consume it
|
||||
{
|
||||
auto operand = make_shared<Expression>(tok.beginLocation, tok.symbol);
|
||||
ConsumeToken();
|
||||
return operand;
|
||||
}
|
||||
ExpressionPtr ParseOperand(bool stopAtNewline)
|
||||
{
|
||||
let & tok = GotToken();
|
||||
ExpressionPtr operand;
|
||||
if (tok.kind == numberliteral) // === numeral literal
|
||||
{
|
||||
operand = make_shared<Expression>(tok.beginLocation, L"d", tok.number, wstring(), false);
|
||||
ConsumeToken();
|
||||
}
|
||||
else if (tok.kind == stringliteral) // === string literal
|
||||
{
|
||||
operand = make_shared<Expression>(tok.beginLocation, L"s", 0.0, tok.symbol, false);
|
||||
ConsumeToken();
|
||||
}
|
||||
else if (tok.symbol == L"true" || tok.symbol == L"false") // === boolean literal
|
||||
{
|
||||
operand = make_shared<Expression>(tok.beginLocation, L"b", 0.0, wstring(), (tok.symbol == L"true"));
|
||||
ConsumeToken();
|
||||
}
|
||||
else if (tok.kind == identifier) // === dict member (unqualified)
|
||||
{
|
||||
operand = make_shared<Expression>(tok.beginLocation, L"id");
|
||||
operand->id = ConsumeIdentifier();
|
||||
}
|
||||
else if (tok.symbol == L"+" || tok.symbol == L"-" // === unary operators
|
||||
|| tok.symbol == L"!")
|
||||
{
|
||||
operand = make_shared<Expression>(tok.beginLocation, tok.symbol + L"("); // encoded as +( -( !(
|
||||
ConsumeToken();
|
||||
operand->args.push_back(ParseExpression(100, stopAtNewline));
|
||||
}
|
||||
else if (tok.symbol == L"new") // === new class instance
|
||||
{
|
||||
operand = OperandFromTokenSymbol(tok);
|
||||
operand->id = ConsumeIdentifier();
|
||||
operand->args.push_back(ParseOperand(stopAtNewline));
|
||||
}
|
||||
else if (tok.symbol == L"if") // === conditional expression
|
||||
{
|
||||
operand = OperandFromTokenSymbol(tok);
|
||||
operand->args.push_back(ParseExpression(0, false)); // [0] condition
|
||||
ConsumeKeyword(L"then");
|
||||
operand->args.push_back(ParseExpression(0, false)); // [1] then expression
|
||||
ConsumeKeyword(L"else");
|
||||
operand->args.push_back(ParseExpression(0, false)); // [2] else expression
|
||||
}
|
||||
else if (tok.symbol == L"(") // === nested parentheses
|
||||
{
|
||||
ConsumeToken();
|
||||
operand = ParseExpression(0, false/*go across newlines*/); // just return the content of the parens (they do not become part of the expression tree)
|
||||
ConsumePunctuation(L")");
|
||||
}
|
||||
else if (tok.symbol == L"[") // === dictionary constructor
|
||||
{
|
||||
operand = make_shared<Expression>(tok.beginLocation, L"[]");
|
||||
ConsumeToken();
|
||||
operand->namedArgs = ParseRecordMembers();
|
||||
ConsumePunctuation(L"]");
|
||||
}
|
||||
else if (tok.symbol == L"array") // === array constructor
|
||||
{
|
||||
operand = OperandFromTokenSymbol(tok);
|
||||
ConsumePunctuation(L"[");
|
||||
operand->args.push_back(ParseExpression(0, false)); // [0] first index
|
||||
ConsumePunctuation(L"..");
|
||||
operand->args.push_back(ParseExpression(0, false)); // [1] last index
|
||||
ConsumePunctuation(L"]");
|
||||
ConsumePunctuation(L"(");
|
||||
operand->args.push_back(ParseExpression(0, false)); // [2] one-argument lambda to initialize
|
||||
ConsumePunctuation(L")");
|
||||
}
|
||||
else
|
||||
Expected(L"operand");
|
||||
return operand; // not using returns above to avoid "not all control paths return a value"
|
||||
}
|
||||
ExpressionPtr ParseExpression(int requiredPrecedence, bool stopAtNewline)
|
||||
{
|
||||
auto left = ParseOperand(stopAtNewline); // get first operand
|
||||
for (;;)
|
||||
{
|
||||
let & opTok = GotToken();
|
||||
if (stopAtNewline && opTok.isLineInitial)
|
||||
break;
|
||||
let opIter = infixPrecedence.find(opTok.symbol);
|
||||
if (opIter == infixPrecedence.end()) // not an infix operator: we are done here, 'left' is our expression
|
||||
break;
|
||||
let opPrecedence = opIter->second;
|
||||
if (opPrecedence < requiredPrecedence) // operator below required precedence level: does not belong to this sub-expression
|
||||
break;
|
||||
let op = opTok.symbol;
|
||||
auto operation = make_shared<Expression>(opTok.beginLocation, op, left); // [0] is left operand; we will add [1] except for macro application
|
||||
// deal with special cases first
|
||||
// We treat member lookup (.), macro application (a()), and indexing (a[i]) together with the true infix operators.
|
||||
if (op == L".") // === reference of a dictionary item
|
||||
{
|
||||
ConsumeToken();
|
||||
operation->location = GotToken().beginLocation; // location of the identifier after the .
|
||||
operation->id = ConsumeIdentifier();
|
||||
}
|
||||
else if (op == L"=>")
|
||||
{
|
||||
if (left->op != L"id") // currently only allow for a single argument
|
||||
Expected(L"identifier");
|
||||
ConsumeToken();
|
||||
let macroArgs = make_shared<Expression>(left->location, L"()", left); // wrap identifier in a '()' macro-args expression
|
||||
// TODO: test parsing of i => j => i*j
|
||||
let body = ParseExpression(opPrecedence, stopAtNewline); // pass same precedence; this makes '=>' right-associative e.g.i=>j=>i*j
|
||||
operation->args[0] = macroArgs; // [0]: parameter list
|
||||
operation->args.push_back(body); // [1]: right operand
|
||||
}
|
||||
else if (op == L"(") // === macro application
|
||||
{
|
||||
// op = "(" means 'apply'
|
||||
// args[0] = lambda expression (lambda: op="=>", args[0] = param list, args[1] = expression with unbound vars)
|
||||
// args[1] = arguments (arguments: op="(), args=vector of expressions, one per arg; and namedArgs)
|
||||
operation->args.push_back(ParseMacroArgs(false)); // [1]: all arguments
|
||||
}
|
||||
else if (op == L"[") // === array index
|
||||
{
|
||||
ConsumeToken();
|
||||
operation->args.push_back(ParseExpression(0, false)); // [1]: index
|
||||
ConsumePunctuation(L"]");
|
||||
}
|
||||
else // === regular infix operator
|
||||
{
|
||||
ConsumeToken();
|
||||
let right = ParseExpression(opPrecedence + 1, stopAtNewline); // get right operand, or entire multi-operand expression with higher precedence
|
||||
operation->args.push_back(right); // [1]: right operand
|
||||
}
|
||||
left = operation;
|
||||
}
|
||||
return left;
|
||||
}
|
||||
// a macro-args expression lists position-dependent and optional parameters
|
||||
// This is used both for defining macros (LHS) and using macros (RHS).
|
||||
// Result:
|
||||
// op = "()"
|
||||
// args = vector of arguments (which are given comma-separated)
|
||||
// In case of macro definition, all arguments must be of type "id". Pass 'defining' to check for that.
|
||||
// namedArgs = dictionary of optional args
|
||||
// In case of macro definition, dictionary values are default values that are used if the argument is not given
|
||||
ExpressionPtr ParseMacroArgs(bool defining)
|
||||
{
|
||||
ConsumePunctuation(L"(");
|
||||
auto macroArgs = make_shared<Expression>(GotToken().beginLocation, L"()");
|
||||
for (;;)
|
||||
{
|
||||
let expr = ParseExpression(0, false); // this could be an optional arg (var = val)
|
||||
if (defining && expr->op != L"id") // when defining we only allow a single identifier
|
||||
Fail(L"argument identifier expected", expr->location);
|
||||
if (expr->op == L"id" && GotToken().symbol == L"=")
|
||||
{
|
||||
let id = expr->id; // 'expr' gets resolved (to 'id') and forgotten
|
||||
ConsumeToken();
|
||||
let defValueExpr = ParseExpression(0, false); // default value
|
||||
let res = macroArgs->namedArgs.insert(make_pair(id, make_pair(expr->location, defValueExpr)));
|
||||
if (!res.second)
|
||||
Fail(L"duplicate optional parameter '" + id + L"'", expr->location);
|
||||
}
|
||||
else
|
||||
macroArgs->args.push_back(expr); // [0..]: position args
|
||||
if (GotToken().symbol != L",")
|
||||
break;
|
||||
ConsumeToken();
|
||||
}
|
||||
ConsumePunctuation(L")");
|
||||
return macroArgs;
|
||||
}
|
||||
map<wstring, pair<TextLocation,ExpressionPtr>> ParseRecordMembers()
|
||||
{
|
||||
// A dictionary is a map
|
||||
// member identifier -> expression
|
||||
// Macro declarations are translated into lambdas, e.g.
|
||||
// F(A,B) = expr(A,B)
|
||||
// gets represented in the dictionary as
|
||||
// F = (A,B) => expr(A,B)
|
||||
// where a lambda expression has this structure:
|
||||
// op="=>"
|
||||
// args[0] = parameter list (op="()", with args (all of op="id") and namedArgs)
|
||||
// args[1] = expression with unbound arguments
|
||||
// An array constructor of the form
|
||||
// V[i:from..to] = expression of i
|
||||
// gets mapped to the explicit array operator
|
||||
// V = array[from..to] (i => expression of i)
|
||||
map<wstring, pair<TextLocation,ExpressionPtr>> members;
|
||||
auto idTok = GotToken();
|
||||
while (idTok.kind == identifier)
|
||||
{
|
||||
let location = idTok.beginLocation; // for error message
|
||||
let id = ConsumeIdentifier(); // the member's name
|
||||
// optional array constructor
|
||||
ExpressionPtr arrayIndexExpr, fromExpr, toExpr;
|
||||
if (GotToken().symbol == L"[")
|
||||
{
|
||||
// X[i:from..to]
|
||||
ConsumeToken();
|
||||
arrayIndexExpr = ParseOperand(false); // 'i' name of index variable
|
||||
if (arrayIndexExpr->op != L"id")
|
||||
Expected(L"identifier");
|
||||
ConsumePunctuation(L":");
|
||||
fromExpr = ParseExpression(0, false); // 'from' start index
|
||||
ConsumePunctuation(L"..");
|
||||
toExpr = ParseExpression(0, false); // 'to' end index
|
||||
ConsumePunctuation(L"]");
|
||||
}
|
||||
// optional macro args
|
||||
let parameters = (GotToken().symbol == L"(") ? ParseMacroArgs(true/*defining*/) : ExpressionPtr(); // optionally, macro arguments
|
||||
ConsumePunctuation(L"=");
|
||||
auto rhs = ParseExpression(0, true/*can end at newline*/); // and the right-hand side
|
||||
// if macro then rewrite it as an assignment of a lambda expression
|
||||
if (parameters)
|
||||
rhs = make_shared<Expression>(parameters->location, L"=>", parameters, rhs);
|
||||
// if array then rewrite it as an assignment of a array-constructor expression
|
||||
if (arrayIndexExpr)
|
||||
{
|
||||
// create a lambda expression over the index variable
|
||||
let macroArgs = make_shared<Expression>(arrayIndexExpr->location, L"()", arrayIndexExpr); // wrap identifier in a '()' macro-args expression
|
||||
let initLambdaExpr = make_shared<Expression>(arrayIndexExpr->location, L"=>", macroArgs, rhs); // [0] is id, [1] is body
|
||||
rhs = make_shared<Expression>(location, L"array");
|
||||
rhs->args.push_back(fromExpr); // [0] first index
|
||||
rhs->args.push_back(toExpr); // [1] last index
|
||||
rhs->args.push_back(initLambdaExpr); // [2] one-argument lambda to initialize
|
||||
}
|
||||
// insert
|
||||
let res = members.insert(make_pair(id, make_pair(location, rhs)));
|
||||
if (!res.second)
|
||||
Fail(L"duplicate member definition '" + id + L"'", location);
|
||||
// advance
|
||||
idTok = GotToken();
|
||||
if (idTok.symbol == L";")
|
||||
idTok = GetToken();
|
||||
}
|
||||
return members;
|
||||
}
|
||||
// top-level parse function parses dictonary members
|
||||
ExpressionPtr Parse()
|
||||
{
|
||||
let topMembers = ParseRecordMembers();
|
||||
if (GotToken().kind != eof)
|
||||
Fail(L"junk at end of source", GetCursor());
|
||||
ExpressionPtr topDict = make_shared<Expression>(GetCursor(), L"[]");
|
||||
topDict->namedArgs = topMembers;
|
||||
return topDict;
|
||||
}
|
||||
// simple test function for use during development
|
||||
static void Test()
|
||||
{
|
||||
let parserTest = L"a=1\na1_=13;b=2 // cmt\ndo = (print\n:train:eval) ; x = array[1..13] (i=>1+i*print.message==13*42) ; print = new PrintAction [ message = 'Hello World' ]";
|
||||
ParseConfigString(parserTest)->Dump();
|
||||
}
|
||||
};
|
||||
|
||||
// globally exported functions to execute the parser
|
||||
static ExpressionPtr Parse(SourceFile && sourceFile) { return Parser(move(sourceFile)).Parse(); }
|
||||
ExpressionPtr ParseConfigString(wstring text) { return Parse(SourceFile(L"(command line)", text)); }
|
||||
ExpressionPtr ParseConfigFile(wstring path) { return Parse(SourceFile(path)); }
|
||||
|
||||
}}} // namespaces
|
|
@ -0,0 +1,103 @@
|
|||
// ConfigParser.h -- config parser (syntactic only, that is, source -> Expression tree)
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "Basics.h"
|
||||
#include "ScriptableObjects.h"
|
||||
#include "File.h"
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
|
||||
namespace Microsoft { namespace MSR { namespace BS {
|
||||
|
||||
using namespace std;
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// TextLocation -- holds a pointer into a source file
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
struct SourceFile // content of one source file (only in this header because TextLocation's private member uses it)
|
||||
{
|
||||
/*const*/ wstring path; // where it came from
|
||||
/*const*/ vector<wstring> lines; // source code lines
|
||||
SourceFile(wstring location, wstring text); // from string, e.g. command line
|
||||
SourceFile(wstring path); // from file
|
||||
};
|
||||
|
||||
struct TextLocation // position in the text. Lightweight value struct that we can copy around, even into dictionaries etc., for error messages
|
||||
{
|
||||
// source-code locations are given by line number, character position, and the source file
|
||||
size_t lineNo, charPos; // line number and character index (0-based)
|
||||
const SourceFile & GetSourceFile() const { return sourceFileMap[sourceFileAsIndex]; } // get the corresponding source-code line
|
||||
|
||||
// helpers for pretty-printing errors: Show source-code line with ...^ under it to mark up the point of error
|
||||
static void PrintIssue(const vector<TextLocation> & locations, const wchar_t * errorKind, const wchar_t * kind, const wchar_t * what);
|
||||
static void Trace(TextLocation, const wchar_t * traceKind, const wchar_t * op, const wchar_t * exprPath);
|
||||
|
||||
// construction
|
||||
TextLocation() : lineNo(SIZE_MAX), charPos(SIZE_MAX), sourceFileAsIndex(SIZE_MAX) { } // default constructor constructs an unmissably invalid object
|
||||
bool IsValid() const;
|
||||
|
||||
// register a new source file and return a TextPosition that points to its start
|
||||
static TextLocation NewSourceFile(SourceFile && sourceFile);
|
||||
|
||||
private:
|
||||
size_t sourceFileAsIndex; // source file is remembered in the value struct as an index into the static sourceFileMap[]
|
||||
// the meaning of the 'sourceFile' index is global, stored in this static map
|
||||
static vector<SourceFile> sourceFileMap;
|
||||
};
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// ConfigError -- all errors from processing the config files are reported as ConfigError
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
class ConfigError : public Microsoft::MSR::ScriptableObjects::ScriptingError
|
||||
{
|
||||
vector<TextLocation> locations; // error location (front()) and evaluation parents (upper)
|
||||
public:
|
||||
// Note: All our Error objects use wide strings, which we round-trip through runtime_error as utf8.
|
||||
ConfigError(const wstring & msg, TextLocation where) : Microsoft::MSR::ScriptableObjects::ScriptingError(msra::strfun::utf8(msg)) { locations.push_back(where); }
|
||||
|
||||
// these are used in pretty-printing
|
||||
TextLocation where() const { return locations.front(); } // where the error happened
|
||||
virtual const wchar_t * kind() const = 0; // e.g. "warning" or "error"
|
||||
|
||||
// pretty-print this as an error message
|
||||
void /*ScriptingError::*/PrintError() const { TextLocation::PrintIssue(locations, L"error", kind(), msra::strfun::utf16(what()).c_str()); }
|
||||
void AddLocation(TextLocation where) { locations.push_back(where); }
|
||||
};
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Expression -- the entire config is a tree of Expression types
|
||||
// We don't use polymorphism here because C++ is so verbose...
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
struct Expression
|
||||
{
|
||||
wstring op; // operation, encoded as a string; 'symbol' for punctuation and keywords, otherwise used in constructors below ...TODO: use constexpr
|
||||
wstring id; // identifier; op == "id", "new", "array", and "." (if macro then it also has args)
|
||||
wstring s; // string literal; op == "s"
|
||||
double d; // numeric literal; op == "d"
|
||||
bool b; // boolean literal; op == "b"
|
||||
typedef shared_ptr<struct Expression> ExpressionPtr;
|
||||
vector<ExpressionPtr> args; // position-dependent expression/function args
|
||||
map<wstring, pair<TextLocation,ExpressionPtr>> namedArgs; // named expression/function args; also dictionary members (loc is of the identifier)
|
||||
TextLocation location; // where in the source code (for downstream error reporting)
|
||||
// constructors
|
||||
Expression(TextLocation location) : location(location), d(0.0), b(false) { }
|
||||
Expression(TextLocation location, wstring op) : location(location), d(0.0), b(false), op(op) { }
|
||||
Expression(TextLocation location, wstring op, double d, wstring s, bool b) : location(location), d(d), s(s), b(b), op(op) { }
|
||||
Expression(TextLocation location, wstring op, ExpressionPtr arg) : location(location), d(0.0), b(false), op(op) { args.push_back(arg); }
|
||||
Expression(TextLocation location, wstring op, ExpressionPtr arg1, ExpressionPtr arg2) : location(location), d(0.0), b(false), op(op) { args.push_back(arg1); args.push_back(arg2); }
|
||||
// diagnostics helper: print the content
|
||||
void Dump(int indent = 0) const;
|
||||
};
|
||||
typedef Expression::ExpressionPtr ExpressionPtr; // circumvent some circular definition problem
|
||||
|
||||
// access the parser through one of these two functions
|
||||
ExpressionPtr ParseConfigString(wstring text);
|
||||
ExpressionPtr ParseConfigFile(wstring path);
|
||||
|
||||
}}} // namespaces
|
|
@ -0,0 +1,217 @@
|
|||
// BrainScriptTest.cpp -- some tests
|
||||
|
||||
#define _CRT_SECURE_NO_WARNINGS // "secure" CRT not available on all platforms --add this at the top of all CPP files that give "function or variable may be unsafe" warnings
|
||||
|
||||
#include "Basics.h"
|
||||
#include "BrainScriptEvaluator.h"
|
||||
#include "BrainScriptParser.h"
|
||||
|
||||
#ifndef let
|
||||
#define let const auto
|
||||
#endif
|
||||
|
||||
namespace Microsoft { namespace MSR { namespace BS {
|
||||
|
||||
using namespace std;
|
||||
using namespace msra::strfun;
|
||||
|
||||
// Note: currently this seems to be the master copy; got to check whether the other one was also changed
|
||||
|
||||
//extern wstring standardFunctions, computationNodes, commonMacros;
|
||||
|
||||
#if 1 // TODO: these may be newer, merge into Experimentalthingy
|
||||
|
||||
static wstring standardFunctions =
|
||||
L"Print(value, format='') = new PrintAction [ what = value /*; how = format*/ ] \n"
|
||||
L"Fail(msg) = new FailAction [ what = msg ] \n"
|
||||
L"RequiredParameter(message) = Fail('RequiredParameter: ' + message) \n"
|
||||
L"Format(value, format) = new StringFunction [ what = 'Format' ; arg = value ; how = format ] \n"
|
||||
L"Replace(s, from, to) = new StringFunction [ what = 'Replace' ; arg = s ; replacewhat = from ; withwhat = to ] \n"
|
||||
L"Substr(s, begin, num) = new StringFunction [ what = 'Substr' ; arg = s ; pos = begin ; chars = num ] \n"
|
||||
L"Chr(c) = new StringFunction [ what = 'Chr' ; arg = c ] \n"
|
||||
L"Floor(x) = new NumericFunction [ what = 'Floor' ; arg = x ] \n"
|
||||
L"Length(x) = new NumericFunction [ what = 'Length' ; arg = x ] \n"
|
||||
L"Ceil(x) = -Floor(-x) \n"
|
||||
L"Round(x) = Floor(x+0.5) \n"
|
||||
L"Abs(x) = if x >= 0 then x else -x \n"
|
||||
L"Sign(x) = if x > 0 then 1 else if x < 0 then -1 else 0 \n"
|
||||
L"Min(a,b) = if a < b then a else b \n"
|
||||
L"Max(a,b) = if a > b then a else b \n"
|
||||
L"Fac(n) = if n > 1 then Fac(n-1) * n else 1 \n"
|
||||
;
|
||||
|
||||
static wstring computationNodes = // BUGBUG: optional args not working yet, some scope problem causing a circular reference
|
||||
L"Mean(z, tag='') = new ComputationNode [ class = 'MeanNode' ; inputs = z /* ; tag = tag */ ]\n"
|
||||
L"InvStdDev(z, tag='') = new ComputationNode [ class = 'InvStdDevNode' ; inputs = z /* ; tag = tag */ ]\n"
|
||||
L"PerDimMeanVarNormalization(feat,mean,invStdDev, tag='') = new ComputationNode [ class = 'PerDimMeanVarNormalizationNode' ; inputs = feat:mean:invStdDev /* ; tag = tag */ ]\n"
|
||||
L"Parameter(outD, inD, tag='parameter') = new ComputationNode [ class = 'LearnableParameterNode' ; outDim = outD ; inDim = inD /*; tag = tag*/ ]\n"
|
||||
L"Input(dim,tag='features') = Parameter(dim,1,tag=tag) // TODO: for now \n"
|
||||
L"RowSlice(firstRow, rows, features, tag='') = new ComputationNode [ class = 'RowSliceNode' ; inputs = features ; first = firstRow ; num = rows /* ; tag = tag */ ]\n"
|
||||
L"Delay(in, delay, tag='') = new RecurrentComputationNode [ class = 'DelayNode' ; inputs = in ; deltaT = -delay /* ; tag = tag */ ]\n"
|
||||
L"Sigmoid(z, tag='') = new ComputationNode [ class = 'SigmoidNode' ; inputs = z /* ; tag = tag */ ]\n"
|
||||
L"Log(z, tag='') = new ComputationNode [ class = 'LogNode' ; inputs = z /* ; tag = tag */ ]\n"
|
||||
L"CrossEntropyWithSoftmax(labels, outZ, tag='') = new ComputationNode [ class = 'CrossEntropyWithSoftmaxNode' ; inputs = labels:outZ /* ; tag = tag */ ]\n"
|
||||
L"ErrorPrediction(labels, outZ, tag='') = new ComputationNode [ class = 'ErrorPredictionNode' ; inputs = labels:outZ /* ; tag = tag */ ]\n"
|
||||
;
|
||||
|
||||
static wstring commonMacros = // TODO: rename rows and cols to inDim and outDim or vice versa, whichever it is
|
||||
L"BFF(in, rows, cols) = [ B = Parameter(rows, 1/*init = fixedvalue, value = 0*/) ; W = Parameter(rows, cols) ; z = W*in+B ] \n"
|
||||
L"SBFF(in, rows, cols) = [ Eh = Sigmoid(BFF(in, rows, cols).z) ] \n "
|
||||
L"MeanVarNorm(feat) = PerDimMeanVarNormalization(feat, Mean(feat), InvStdDev(feat)) \n"
|
||||
L"LogPrior(labels) = Log(Mean(labels)) \n"
|
||||
;
|
||||
|
||||
#endif
|
||||
|
||||
void SomeTests()
|
||||
{
|
||||
try
|
||||
{
|
||||
// collecting all sorts of test cases here
|
||||
const wchar_t * parserTests[] =
|
||||
{
|
||||
L"do = Parameter(13,42) * Input(42) + Parameter(13,1)"
|
||||
,
|
||||
L"do = Print(array [1..10] (i=>i*i))"
|
||||
,
|
||||
L"do = new PrintAction [ what = 'abc' ]"
|
||||
,
|
||||
L"do = Print(new StringFunction [ x = 13 ; y = 42 ; what = 'Format' ; how = '.2' ; arg = x*y ])"
|
||||
,
|
||||
L"do = Print(\"new StringFunction [ what = 'Format' ; how = '.2' ; arg = '13 > 42' ]\")"
|
||||
,
|
||||
L"do = new PrintAction [ what = if 13 > 42 || 12 > 1 then 'Hello World' + \"!\" else 'Oops?']"
|
||||
,
|
||||
L"i2s(i) = new StringFunction [ what = 'Format' ; arg = i ; how = '.2' ] ; do = Print('result=' + i2s((( [ v = (i => i + delta) ].v(5)))+13)) ; delta = 42 "
|
||||
,
|
||||
L"do = Print(1+2*3) : Print('hello'+' world')"
|
||||
,
|
||||
L"do = Print(Format( (13:(fortytwo:1):100), '')) ; fortytwo=42 "
|
||||
,
|
||||
L"do = Print(val) ; val=if !false then 42 else -+-++-13:[a='a';b=42]:+14; arr = array [1..10] (i => 2*i)"
|
||||
,
|
||||
L"do = Print(arg) ; N = 5 ; arr = array [1..N] (i => if i < N then arr[i+1]*i else N) ; arg = arr "
|
||||
,
|
||||
L"do = Print(val) ; val = [ v = (i => i + offset) ].v(42) ; offset = 13 "
|
||||
,
|
||||
// #12: DNN with recursion
|
||||
L"do = Print(val) \n"
|
||||
L"val = new NDLComputationNetwork [\n"
|
||||
L" featDim=40*31 ; labelDim=9000 ; hiddenDim=2048 ; numHiddenLayers = 3 \n"
|
||||
L" myFeatures = Input(featDim) ; myLabels = Input(labelDim) \n"
|
||||
L" featNorm = MeanVarNorm(myFeatures) \n"
|
||||
L" HiddenStack(layer) = if layer > 1 then SBFF(HiddenStack(layer - 1).Eh, hiddenDim, hiddenDim) else SBFF(featNorm, hiddenDim, featDim) \n"
|
||||
L" outLayer = BFF(HiddenStack(numHiddenLayers).Eh, labelDim, hiddenDim) \n"
|
||||
L" outZ = outLayer.z \n"
|
||||
L" CE = CrossEntropyWithSoftmax(myLabels, outZ) \n"
|
||||
L" Err = ErrorPrediction(myLabels, outZ) \n"
|
||||
L" logPrior = LogPrior(myLabels) \n"
|
||||
L" ScaledLogLikelihood = outZ - logPrior \n"
|
||||
L"]\n"
|
||||
,
|
||||
// #13: factorial
|
||||
L"do = Print(fac(5)) ; fac(i) = if i > 1 then fac(i-1)*i else 1 "
|
||||
,
|
||||
// #14: Fibonacci sequence with memoization
|
||||
L"do = Print(fibs(10)) ; fibs(n) = [ vals = array[1..n] (i => if i < 3 then i-1 else vals[i-1]+vals[i-2]) ].vals[n] "
|
||||
,
|
||||
// #15: DNN with array
|
||||
L"do = Print(val) \n"
|
||||
L"val = new NDLComputationNetwork [\n"
|
||||
L" featDim=40*31 ; labelDim=9000 ; hiddenDim=2048 ; numHiddenLayers = 3 \n"
|
||||
L" myFeatures = Input(featDim, tag='features') ; myLabels = Input(labelDim, tag='labels') \n"
|
||||
L" featNorm = MeanVarNorm(myFeatures) \n"
|
||||
L" layers[layer:1..numHiddenLayers] = if layer > 1 then SBFF(layers[layer-1].Eh, hiddenDim, hiddenDim) else SBFF(featNorm, hiddenDim, featDim) \n"
|
||||
L" outLayer = BFF(layers[numHiddenLayers].Eh, labelDim, hiddenDim) \n"
|
||||
L" outZ = outLayer.z + Delay(outZ, 1) \n"
|
||||
L" CE = CrossEntropyWithSoftmax(myLabels, outZ) \n"
|
||||
L" Err = ErrorPrediction(myLabels, outZ) \n"
|
||||
L" logPrior = LogPrior(myLabels) \n"
|
||||
L" ScaledLogLikelihood = outZ - logPrior \n"
|
||||
L"]\n"
|
||||
,
|
||||
// #16: windowed RNN
|
||||
L"do = Print(val) \n"
|
||||
L"val = new NDLComputationNetwork [ \n"
|
||||
L" hiddenDim = 512 \n"
|
||||
L" numHiddenLayers = 2 \n"
|
||||
L" T = 3 // total context window \n"
|
||||
L" \n"
|
||||
L" // data sources \n"
|
||||
L" featDim = 40 ; labelDim = 9000 \n"
|
||||
L" myFeatures = Input(featDim) ; myLabels = Input(labelDim) \n"
|
||||
L" \n"
|
||||
L" // split the augmented input vector into individual frame vectors \n"
|
||||
L" subframes[t:0..T - 1] = RowSlice(t * featDim, featDim, myFeatures) \n"
|
||||
L" \n"
|
||||
L" // hidden layers \n"
|
||||
L" layers[layer:1..numHiddenLayers] = [ // each layer stores a dict that stores its hidden fwd and bwd state vectors \n"
|
||||
L" // model parameters \n"
|
||||
L" W_fwd = Parameter(hiddenDim, featDim) // Parameter(outdim, indim) \n"
|
||||
L" W_bwd = if layer > 1 then Parameter(hiddenDim, hiddenDim) else Fail('no W_bwd') // input-to-hidden \n"
|
||||
L" H_fwd = Parameter(hiddenDim, hiddenDim) // hidden-to-hidden \n"
|
||||
L" H_bwd = Parameter(hiddenDim, hiddenDim) \n"
|
||||
L" b = Parameter(hiddenDim, 1) // bias \n"
|
||||
L" // shared part of activations (input connections and bias) \n"
|
||||
L" z_shared[t:0..T-1] = (if layer > 1 \n"
|
||||
L" then W_fwd * layers[layer - 1].h_fwd[t] + W_bwd * layers[layer - 1].h_bwd[t] \n"
|
||||
L" else W_fwd * subframes[t] \n"
|
||||
L" ) + b \n"
|
||||
L" // recurrent part and non-linearity \n"
|
||||
L" step(H, h, dt, t) = Sigmoid(if (t + dt >= 0 && t + dt < T) \n"
|
||||
L" then z_shared[t] + H * h[t + dt] \n"
|
||||
L" else z_shared[t]) \n"
|
||||
L" h_fwd[t:0..T-1] = step(H_fwd, h_fwd, -1, t) \n"
|
||||
L" h_bwd[t:0..T-1] = step(H_bwd, h_bwd, 1, t) \n"
|
||||
L" ] \n"
|
||||
L" // output layer --linear only at this point; Softmax is applied later \n"
|
||||
L" outLayer = [ \n"
|
||||
L" // model parameters \n"
|
||||
L" W_fwd = Parameter(labelDim, hiddenDim) \n"
|
||||
L" W_bwd = Parameter(labelDim, hiddenDim) \n"
|
||||
L" b = Parameter(labelDim, 1) \n"
|
||||
L" // output \n"
|
||||
L" topHiddenLayer = layers[numHiddenLayers] \n"
|
||||
L" centerT = Floor(T/2) \n"
|
||||
L" z = W_fwd * topHiddenLayer.h_fwd[centerT] + W_bwd * topHiddenLayer.h_bwd[centerT] + b \n"
|
||||
L" ] \n"
|
||||
L" outZ = outLayer.z // we only want this one & don't care about the rest of this dictionary \n"
|
||||
L" \n"
|
||||
L" // define criterion nodes \n"
|
||||
L" CE = CrossEntropyWithSoftmax(myLabels, outZ) \n"
|
||||
L" Err = ErrorPrediction(myLabels, outZ) \n"
|
||||
L" \n"
|
||||
L" // define output node for decoding \n"
|
||||
L" logPrior = LogPrior(myLabels) \n"
|
||||
L" ScaledLogLikelihood = outZ - logPrior // before: Minus(CE.BFF.FF.P,logPrior,tag=Output) \n"
|
||||
L"]\n"
|
||||
,
|
||||
L" \n" // this fails because dict is outside val; expression name is not local to it
|
||||
L"do = Print(val) \n"
|
||||
L"dict = [ outY = Input(13) ] ; val = new NDLComputationNetwork [ outZ = dict.outY \n"
|
||||
L"]\n"
|
||||
,
|
||||
L"f(x,option='default') = Print(option); do = f(42,option='value')"
|
||||
,
|
||||
NULL
|
||||
};
|
||||
let first = 0; // 0 for all
|
||||
bool oneOnly = first > 0;
|
||||
for (size_t i = first; parserTests[i]; i++)
|
||||
{
|
||||
fprintf(stderr, "\n### Test %d ###\n\n", (int)i), fflush(stderr);
|
||||
let parserTest = parserTests[i];
|
||||
let expr = ParseConfigString(standardFunctions + computationNodes + commonMacros + parserTest);
|
||||
//expr->Dump();
|
||||
Do(expr);
|
||||
if (oneOnly)
|
||||
break;
|
||||
}
|
||||
}
|
||||
catch (const ConfigError & err)
|
||||
{
|
||||
err.PrintError();
|
||||
}
|
||||
}
|
||||
|
||||
}}} // namespaces
|
|
@ -0,0 +1,400 @@
|
|||
CNTK configuration language redesign (ongoing work)
|
||||
====================================
|
||||
|
||||
F. Seide, August 2015
|
||||
|
||||
These are the original notes from before coding began. Basic ideas are correct, but may be a bit outdated.
|
||||
|
||||
- config specifies all configurable runtime objects and their initialization parameters
|
||||
- basic concepts: dictionaries and runtime-object definitions
|
||||
- basic syntactic elements:
|
||||
- runtime object definitions // new classname initargsdictionary
|
||||
- macro definition // M(x,y,z) = expression // expression uses x, y, and z
|
||||
- expressions
|
||||
- dictionaries // [ a=expr1 ; c=expr2 ]
|
||||
- math ops and parentheses as usual // W*v+a, n==0
|
||||
- conditional expression // if c then a else b
|
||||
- array // a:b:c ; array [1..N] (i => f(i))
|
||||
- syntax supports usual math and boolean expressions
|
||||
- functions are runtime objects defined through macros, e.g. Replace(s,with,withwhat) = String [ from=s ; replacing=what ; with=withwhat ]
|
||||
- config is parsed eagerly but evaluated lazily
|
||||
- CNTK command line "configFile=conf.bs a=b c=d" expands to "new CNTK {content of conf.bs} + [ a=b ; c=d ]"
|
||||
|
||||
current issues
|
||||
--------------
|
||||
|
||||
- syntax does not distinguish between dictionary members, intermediate variables, and actual parameter names
|
||||
- dictionary editing needs to allow a.b.c syntax; and subtracting is not pretty as it needs dummy values -> maybe use a delete symbol? a=delete?
|
||||
- missing: optional parameters to macros; and how this whole thing would work with MEL
|
||||
|
||||
grammar
|
||||
-------
|
||||
|
||||
// --- top level defines a runtime object of class 'CNTK'
|
||||
// example: new CNTK [ actions=train ; train=TrainAction [ ... ] ] // where "new CNTK [" is prepended by the command-line parser
|
||||
|
||||
$ = $dictitems // this is a dictionary without enclosing [ ... ] that defines instantiation args of CNTK class
|
||||
|
||||
// --- defining a runtime object and its parameters
|
||||
// example: new ComputeNode [ class="Plus" ; arg1=A ; arg2=B ]
|
||||
|
||||
$newinstance = 'new' $classname $expr
|
||||
where $expr must be a dictionary expression
|
||||
$classname = $identifier
|
||||
where $identifier is one of the known pre-defined C++ class names
|
||||
|
||||
// --- dictionaries are groups of key-value pairs.
|
||||
// Dictionaries are expressions.
|
||||
// Multiple dictionaries can be edited (dict1 + dict2) where dict2 members override dict1 ones of the same name.
|
||||
// examples: [ arg1=A ; arg2=B ]
|
||||
// dict1 + (if (dpt && layer < totallayers) then [ numiter = 5 ] else []) // overrides 'numiter' in 'dict1' if condition is fulfilled
|
||||
|
||||
$dictdef = '[' $dictitems ']'
|
||||
$dictitems = $itemdef*
|
||||
|
||||
$itemdef = $paramdef // var=val
|
||||
| $macrodef // macro(args)=expression
|
||||
|
||||
$paramdef = $identifier '=' $expr // e.g. numiter = 13
|
||||
$macrodef = $identifier '(' $arg (',' $arg) ')' = $expr // e.g. sqr(x) = x*x
|
||||
|
||||
// --- expressions
|
||||
// Expressions are what you'd expect. Infix operators those of C, with addition of '.*' '**' ':' '..'
|
||||
// ML-style "let ... in" (expression-local variables) are possible but not super-pretty: [ a=13; b=42; res=a*b ].res
|
||||
// There are infix ops for strings (concatenation) and dictionaries (editing).
|
||||
|
||||
$expr = $operand
|
||||
| $expr $infixop $operand
|
||||
| $expr '.' $memberref // dict.member TODO: fix this; memberrefs exist without '.'
|
||||
where $expr is a dictionary
|
||||
| $expr '(' $expr (',' $expr)* ')' // a(13) also: dict.a(13); note: partial application possible, i.e. macros may be passed as args and curried
|
||||
where $expr is a macro
|
||||
| $expr '[' $expr ']' // h_fwd[t]
|
||||
where first $expr must be a array and second $expr a number (that must be an integer value)
|
||||
$infixop = // highest precedence level
|
||||
'*' // numbers; also magic short-hand for "Times" and "Scale" ComputeNodes
|
||||
| '/' // numbers; Scale ComputeNode
|
||||
| '.*' // ComputeNodes: component-wise product
|
||||
| '**' // numbers (exponentiation, FORTRAN style!)
|
||||
| '%' // numbers: remainder
|
||||
// next lower precedence level
|
||||
| '+' // numbers; ComputeNodes; strings; dictionary editing
|
||||
| '-' // numbers; ComputeNodes; dictionary editing
|
||||
// next lower precedence level
|
||||
| '==' '!=' '<' '>' '<=' '>=' // applies to config items only; objects other than boxed primitive values are compared by object identity not content
|
||||
// next lower precedence level
|
||||
| '&&' // booleans
|
||||
// next lower precedence level
|
||||
| '||' | '^' // booleans
|
||||
// next lower precedence level
|
||||
| ':' // concatenate items and/or arrays --TODO: can arrays have nested arrays? Syntax?
|
||||
$operand = $literal // "Hello World"
|
||||
| $memberref // a
|
||||
| $dictdef // [ a="Hello World" ]
|
||||
| $newinstance // new ComputeNode [ ... ]
|
||||
| ('-' | '+' | '!') $operand // -X+Y
|
||||
| '(' $expr ')' // (a==b) || (c==d)
|
||||
| $arrayconstructor // array [1..N] (i => i*i)
|
||||
|
||||
$literal = $number // built-in literal types are numeric, string, and boolean
|
||||
| $string
|
||||
| $boolconst
|
||||
$number = // floating point number; no separate 'int' type, 'int' args are checked at runtime to be non-fractional
|
||||
$string = // characters enclosed in "" or ''; no escape characters inside, use combinations of "", '', and + instead (TODO: do we need string interpolation?).
|
||||
// Strings may span multiple lines (containing newlines)
|
||||
$boolconst = 'true' | 'false'
|
||||
|
||||
$memberref = $identifier // will search parent scopes
|
||||
|
||||
$arrayconstructor = 'array' '[' $expr '..' $expr ']' '(' $identifier '=>' $expr ')' // array [1..N] (i => i*i)
|
||||
where ^start ^end (int) ^index variable ^function of index variable
|
||||
|
||||
// --- predefined functions
|
||||
// *All* functions are defined as macros that instantiate a runtime object. (The same is true for operators above, too, actually.)
|
||||
|
||||
// functions that really are macros that instantiate ComputeNodes:
|
||||
// - Times(,), Plus(,), Sigmoid(), etc.
|
||||
// numeric functions:
|
||||
// - Floor() (for int division), Ceil(), Round() (for rounding), Abs(), Sign(), ...
|
||||
// string functions:
|
||||
// - Replace(s,what,withwhat), Str(number) (number to string), Chr(number) (convert Unicode codepoint to string), Format(fmt,val) (sprintf-like formatting with one arg)
|
||||
// other:
|
||||
// - Fail("error description") --will throw exception when executed; use this like assertion
|
||||
|
||||
dictionaries
|
||||
------------
|
||||
|
||||
- dictionaries are key-value pairs; they are records or compound data structures for use inside the config file itself
|
||||
- dictionaries are immutable and exist inside the parser but are not serialized to disk with a model --TODO: it might be needed to do that for MEL
|
||||
- the argument to a runtime-object instantiation is also a dictionary
|
||||
- the config file can access that dictionary's members directly from the runtime-object expression, for convenience
|
||||
- intermediate variables that are only used to construct dictionary entries also become dictionary entries (no syntactic distinction) --TODO: should we distinguish them?
|
||||
- macros are also dictionary members
|
||||
- dictionary values are read out using dict.field syntax, where 'dict' is any expression that evaluates to a dictionary
|
||||
- object instantiations will also traverse outer scopes to find values (e.g. precision, which is shared by many)
|
||||
- runtime objects themselves are inputs to other runtime objects, but they cannot have data members that output values
|
||||
- instead, output arguments use a proxy class ComputeNodeRef that can be used as a ComputeNode for input, and gets filled in at runtime
|
||||
- dictionaries can be "edited" by "adding" (+) a second dictionary to it; items from the second will overwrite the same items in the first.
|
||||
Subtracting a dictionary will remove all items in the second dict from the first.
|
||||
This is used to allow for overriding variables on the command line. --TODO: not fully fleshed out how to access nested inner variables inside a dict
|
||||
|
||||
arrays
|
||||
------
|
||||
|
||||
- another core data type is the array. Like dictionaries, arrays are immutable and exist inside the parser only.
|
||||
- arrays are created at once in two ways
|
||||
- 'array' expression:
|
||||
array [1..N] (i => f(i)) // fake lambda syntax could be made real lambda; also could extend to multi-dim arrays
|
||||
- ':' operator concatenates arrays and/or elements. Arrays are flattened.
|
||||
1:2:3
|
||||
- elements are read-accessed with index operator
|
||||
X[i]
|
||||
- example syntax of how one could define useful operators for arrays
|
||||
- Append(seq,item) = seq : item
|
||||
- Repeat(item,N) = array [1..N] (i => item)
|
||||
- arrays with repetition can be created like this:
|
||||
0.8 : array [1..3] (i => 0.2) : 0.05
|
||||
or
|
||||
0.8 : Repeat(0.2,3) : 0.05
|
||||
- the array[] () argument looks like a C# lambda, but for now is hard-coded syntax (but with potential to be a true lambda in the future)
|
||||
|
||||
towards MEL
|
||||
-----------
|
||||
|
||||
Model editing is now done in a functional manner, like this:
|
||||
|
||||
TIMIT_AddLayer = new EditAction [
|
||||
|
||||
currModelPath = "ExpDir\TrainWithPreTrain\dptmodel1\cntkSpeech.dnn"
|
||||
newModelPath = "ExpDir\TrainWithPreTrain\dptmodel2\cntkSpeech.dnn.0"
|
||||
|
||||
model = LoadModel(currModelPath);
|
||||
newModel = EditModel(model, [
|
||||
// new items here
|
||||
outZ = SBFF(model.outZ.INPUT, LABELDIM, outZ.INPUT.OUTDIM)
|
||||
])
|
||||
do = ( Dump(newModel, newModelPath + ".dump.txt")
|
||||
: SaveModel(newModel, newModelPath) )
|
||||
|
||||
]
|
||||
|
||||
sample
|
||||
------
|
||||
|
||||
// This sample is a modification of the original TIMIT_TrainSimpleNetwork.config and TIMIT_TrainNDLNetwork.config.
|
||||
// The changes compared to the origina syntax are called out in comments.
|
||||
|
||||
stderr = ExpDir + "\TrainSimpleNetwork\log\log" // before: $ExpDir$\TrainSimpleNetwork\log\log
|
||||
actions = TIMIT_TrainSimple // before: command = ... ('command' is singular, but this can be a sequence of actions)
|
||||
|
||||
// these values are used by several runtime-object instantiations below
|
||||
precision = 'float' // before: precision = float
|
||||
deviceId = DeviceNumber // before: $DeviceNumber$
|
||||
|
||||
#######################################
|
||||
# TRAINING CONFIG (Simple, Fixed LR) #
|
||||
#######################################
|
||||
|
||||
Repeat(val,count) = array [1..count] (i => val) // new: array helper to repeat a value (result is a array) (this would be defined in a library eventually)
|
||||
|
||||
TIMIT_TrainSimple = new TrainAction [ // new: added TrainAction; this is a class name of the underlying runtime object
|
||||
// new: TrainAction takes three main parameters: 'source' -> 'model' -> 'optimizer' (-> indicating logical dependency)
|
||||
//action = train // removed (covered by class name)
|
||||
traceLevel = 1
|
||||
|
||||
// new: Model object; some parameters were moved into this
|
||||
model = new Model [ // this is an input to TrainAction
|
||||
modelPath = ExpDir + "\TrainSimpleNetwork\model\cntkSpeech.dnn" // before: $ExpDir$\TrainSimpleNetwork\model\cntkSpeech.dnn
|
||||
|
||||
// EXAMPLE 1: SimpleNetworkBuilder --TODO: do we even need a C++ class, or can we use a macro instead? Would make life easier re connecting inputs
|
||||
network = new SimpleNetworkBuilder [ // before: SimpleNetworkBuilder = [
|
||||
layerSizes = 792 : Repeat(512,3) : 183 // before: 792:512*3:183
|
||||
layerTypes = 'Sigmoid' // before: no quotes
|
||||
initValueScale = 1.0
|
||||
applyMeanVarNorm = true
|
||||
uniformInit = true
|
||||
needPrior = true
|
||||
// the following two belong into SGD, so they were removed here
|
||||
//trainingCriterion = CrossEntropyWithSoftmax
|
||||
//evalCriterion = ErrorPrediction
|
||||
// new: connect to input stream from source; and expose the output layer
|
||||
input = source.features.data // these are also ComputeNodeRefs, exposed by the source
|
||||
output = ComputeNodeRef [ dim = source.labels.dim ] // SimpleNetworkBuilder will put top layer affine transform output (input to softmax) here
|
||||
// criteria are configurable here; these are ComputeNodes created here
|
||||
trainingCriterion = CrossEntropyWithSoftmax (source.labels.data, output)
|
||||
evalCriterion = ErrorPrediction (source.labels.data, output)
|
||||
// new: (and half-baked) define Input nodes
|
||||
myFeatures=Input(featDim) // reader stream will reference this
|
||||
myLabels=Input(labelDim)
|
||||
]
|
||||
|
||||
// EXAMPLE 2: network from NDL (an actual config would contain one of these two examples)
|
||||
network = new NDL [ // before: run=ndlCreateNetwork ; ndlCreateNetwork=[
|
||||
featDim = myFeatures.dim // before: 792 hard-coded; note: myFeatures and myLabels are defined below
|
||||
labelDim = myLabels.dim // before: 183 hard-coded
|
||||
hiddenDim = 512
|
||||
|
||||
// input nodes
|
||||
myFeatures=Input(featDim) // before: optional arg tag=feature
|
||||
myLabels=Input(labelDim) // before: optional arg tag=label
|
||||
|
||||
// old
|
||||
//# define network
|
||||
//featNorm = MeanVarNorm(myFeatures)
|
||||
//L1 = SBFF(featNorm,hiddenDim,featDim)
|
||||
//L2 = SBFF(L1,hiddenDim,hiddenDim)
|
||||
//L3 = SBFF(L2,hiddenDim,hiddenDim)
|
||||
//CE = SMBFF(L3,labelDim,hiddenDim,myLabels,tag=Criteria)
|
||||
//Err = ErrorPrediction(myLabels,CE.BFF.FF.P,tag=Eval)
|
||||
//logPrior = LogPrior(myLabels)
|
||||
//ScaledLogLikelihood=Minus(CE.BFF.FF.P,logPrior,tag=Output)
|
||||
|
||||
// new:
|
||||
// Let's have the macros declared here for illustration (in the end, these would live in a library)
|
||||
FF(X1, W1, B1) = W1 * X1 + B1 // before: T=Times(W1,X1) ; P=Plus(T, B1)
|
||||
BFF(in, rows, cols) = [ // before: BFF(in, rows, cols) { ... }
|
||||
B = Parameter(rows, init = fixedvalue, value = 0)
|
||||
W = Parameter(rows, cols)
|
||||
z = FF(in, w, b) // before: FF = ...; illegal now, cannot use same name again
|
||||
]
|
||||
SBFF(in, rowCount, colCount) = [ // before: SBFF(in,rowCount,colCount) { ... }
|
||||
z = BFF(in, rowCount, colCount).z // before: BFF = BFF(in, rowCount, colCount)
|
||||
Eh = Sigmoid(z)
|
||||
]
|
||||
// Macros are expressions. FF returns a ComputeNode; while BFF and SBFF return a dictionary that contains multiple named ComputeNode.
|
||||
|
||||
// new: define network in a loop. This allows parameterizing over the network depth.
|
||||
numLayers = 7
|
||||
layers = array [0..numLayers] ( layer =>
|
||||
if layer == 0 then featNorm
|
||||
else if layer == 1 then SBFF(layers[layer-1].Eh, hiddenDim, featDim)
|
||||
else if layer < numLayers then SBFF(layers[layer-1].Eh, hiddenDim, hiddenDim)
|
||||
else BFF(layers[layer-1].Eh, labelDim, hiddenDim)
|
||||
)
|
||||
outZ = layers[numLayers].z // new: to access the output value, the variable name (dictionary member) cannot be omitted
|
||||
|
||||
// alternative to the above: define network with recursion
|
||||
HiddenStack(layer) = if layer > 1 then SBFF(HiddenStack(layer-1).Eh, hiddenDim, hiddenDim) else SBFF(featNorm, hiddenDim, featDim)
|
||||
outZ = BFF(HiddenStack(numlayers).Eh, labelDim, hiddenDim)
|
||||
|
||||
// define criterion nodes
|
||||
CE = CrossEntropyWithSoftmax(myLabels, outZ)
|
||||
Err = ErrorPrediction(myLabels, outZ)
|
||||
|
||||
// define output node for decoding
|
||||
logPrior = LogPrior(myLabels)
|
||||
ScaledLogLikelihood = outZ - logPrior // before: Minus(CE.BFF.FF.P,logPrior,tag=Output)
|
||||
]
|
||||
]
|
||||
|
||||
// the SGD optimizer
|
||||
optimizer = new SGDOptimizer [ // before: SGD = [
|
||||
epochSize = 0
|
||||
minibatchSize = 256 : 1024
|
||||
learningRatesPerMB = 0.8 : Repeat(3.2,14) : 0.08 // (syntax change for repetition)
|
||||
momentumPerMB = 0.9
|
||||
dropoutRate = 0.0
|
||||
maxEpochs = 25
|
||||
// new: link to the criterion node
|
||||
trainingCriterion = model.network.CE // (note: I would like to rename this to 'objective')
|
||||
]
|
||||
|
||||
// The RandomizingSource performs randomization and mini-batching, while driving low-level random-access readers.
|
||||
source = new RandomizingSource [ // before: reader = [
|
||||
//readerType = HTKMLFReader // removed since covered by class name
|
||||
|
||||
// new: define what utterances to get from what stream sources
|
||||
dataSetFile = ScpDir + "\TIMIT.train.scp.fbank.fullpath" // (new) defines set of utterances to train on; accepts HTK archives
|
||||
streams = ( [ // This passes the 'features' and 'labels' runtime objects to the source, and also connects them to the model's Input nodes.
|
||||
reader = features // random-access reader
|
||||
input = model.network.myFeatures // Input node that this feeds into
|
||||
]
|
||||
: [
|
||||
reader = labels
|
||||
input = model.network.myLabels
|
||||
] ) // note: ':' is array syntax. Parentheses are only for readability
|
||||
|
||||
readMethod = 'blockRandomize' // before: no quotes
|
||||
miniBatchMode = 'Partial' // before: no quotes
|
||||
randomize = 'Auto' // before: no quotes
|
||||
verbosity = 1
|
||||
|
||||
// change: The following two are not accessed directly by the source, but indirectly through the 'streams' argument.
|
||||
// They could also be defined outside of this dictionary. They are from the NDL, though.
|
||||
// The 'RandomizingSource' does not know about features and labels specifically.
|
||||
features = new HTKFeatReader [ // before: features = [
|
||||
//dim = 792 // (moved to 'data' node)
|
||||
scpFile = dataSetFile // HTK reader can share source's archive file that defines dataSet
|
||||
data = new ComputeNodeRef [ dim = 792 ] // an input node the model can connect to; dimension is verified when files are opened
|
||||
]
|
||||
|
||||
labels = new HTKMLFReader [ // before: labels = [
|
||||
mlfFile = MlfDir + "\TIMIT.train.align_cistate.mlf.cntk" // before: $MlfDir$\TIMIT.train.align_cistate.mlf.cntk
|
||||
//labelDim = 183 // (moved to 'data' node)
|
||||
labelMappingFile = MlfDir + "\TIMIT.statelist" // before: $MlfDir$\TIMIT.statelist
|
||||
data = new ComputeNodeRef [ dim = 183 ] // an input node the model can connect to; dimension is verified when reading statelist file
|
||||
]
|
||||
]
|
||||
]
|
||||
|
||||
Example 2: truncated bidirectional RNN
|
||||
--------------------------------------
|
||||
|
||||
network = new NDL [
|
||||
// network parameters
|
||||
hiddenDim = 512
|
||||
numHiddenLayers = 6 // 6 hidden layers
|
||||
T = 41 // total context window
|
||||
|
||||
// data sources
|
||||
myFeatures = source.features.data
|
||||
myLabels = source.labels.data
|
||||
|
||||
// derived dimensions
|
||||
augmentedFeatDim = myFeatures.dim // feature arrays are context window frames stacked into a single long array
|
||||
labelDim = myLabels.dim
|
||||
|
||||
centerT = Floor(T/2) // center frame to predict
|
||||
featDim = Floor(augmentedFeatDim / T)
|
||||
|
||||
// split the augmented input vector into individual frame vectors
|
||||
subframes = array [0..T-1] (t => RowSlice(t * featDim, featDim, myFeatures))
|
||||
|
||||
// hidden layers
|
||||
// Hidden state arrays for all frames are stored in a array object.
|
||||
layers = array [1..numHiddenLayers] (layer => [ // each layer stores a dictionary that stores its output hidden fwd and bwd state vectors
|
||||
// model parameters
|
||||
W_fwd = Parameter(hiddenDim, featDim) // Parameter(outdim, indim) --in_fwd.rows is an initialization parameter read from the dict
|
||||
W_bwd = if layer > 1 then Parameter(hiddenDim, hiddenDim) else Fail("no W_bwd") // W denotes input-to-hidden connections
|
||||
H_fwd = Parameter(hiddenDim, hiddenDim) // H denotes hidden-to-hidden lateral connections
|
||||
H_bwd = Parameter(hiddenDim, hiddenDim)
|
||||
b = Parameter(hiddenDim, 1) // bias
|
||||
// shared part of activations (input connections and bias)
|
||||
z_shared = array [0..T-1] (t => if layers > 1 then W_fwd * layers[layer-1].h_fwd[t] + W_bwd * layers[layer-1].h_bwd[t] + b // intermediate layer gets fed fwd and bwd hidden state
|
||||
else W_fwd * subframes + b) // input layer reads frames directly
|
||||
// recurrent part and non-linearity
|
||||
neededT = if layer < numHiddenLayers then T else centerT+1 // last hidden layer does not require all frames
|
||||
step(H,h,dt,t) = Sigmoid(if (t+dt > 0 && t+dt < T) then z_shared[t] + H * h[t+dt]
|
||||
else z_shared[t])
|
||||
h_fwd = array [0..neededT-1] (t => step(H_fwd, h_fwd, -1, t))
|
||||
h_bwd = array [T-neededT..T-1] (t => step(H_bwd, h_bwd, 1, t))
|
||||
])
|
||||
// output layer --linear only at this point; Softmax is applied later
|
||||
outZ = [
|
||||
// model parameters
|
||||
W_fwd = Parameter(labelDim, hiddenDim)
|
||||
W_bwd = Parameter(labelDim, hiddenDim)
|
||||
b = Parameter(labelDim, 1)
|
||||
// output
|
||||
topHiddenLayer = layers[numHiddenLayers]
|
||||
z = W_fwd * topHiddenLayer.h_fwd[centerT] + W_bwd * topHiddenLayer.h_bwd[centerT] + b
|
||||
].z // we only want this one & don't care about the rest of this dictionary
|
||||
|
||||
// define criterion nodes
|
||||
CE = CrossEntropyWithSoftmax(myLabels, outZ)
|
||||
Err = ErrorPrediction(myLabels, outZ)
|
||||
|
||||
// define output node for decoding
|
||||
logPrior = LogPrior(myLabels)
|
||||
ScaledLogLikelihood = outZ - logPrior // before: Minus(CE.BFF.FF.P,logPrior,tag=Output)
|
||||
]
|
|
@ -0,0 +1,80 @@
|
|||
#
|
||||
# test this with this command line:
|
||||
# configFile=$(SolutionDir)BrainScript/test.config RunDir=$(SolutionDir)\Tests\Speech\RunDir DataDir=$(SolutionDir)\Tests\Speech\Data DeviceId=Auto
|
||||
|
||||
precision=float
|
||||
command=speechTrain
|
||||
deviceId=$DeviceId$
|
||||
|
||||
parallelTrain=false
|
||||
|
||||
speechTrain=[
|
||||
action=train
|
||||
modelPath=$RunDir$/models/cntkSpeech.dnn
|
||||
deviceId=$DeviceId$
|
||||
traceLevel=1
|
||||
# inside here is the new stuff
|
||||
ExperimentalNetworkBuilder=[
|
||||
//deviceId = -21 ; precision = 'floax' // for now
|
||||
layerSizes=363:512:512:132
|
||||
trainingCriterion=CE
|
||||
evalCriterion=Err
|
||||
//layerTypes=Sigmoid
|
||||
//initValueScale=1.0
|
||||
//applyMeanVarNorm=true
|
||||
//uniformInit=true
|
||||
//needPrior=true
|
||||
|
||||
numHiddenLayers = 3
|
||||
myFeatures = Input(layerSizes[0]) ; myLabels = Input(layerSizes[Length(layerSizes)-1])
|
||||
featNorm = MeanVarNorm(myFeatures)
|
||||
layers = array[1..numHiddenLayers] (layer => if layer > 1 then SBFF(layers[layer-1].Eh, layerSizes[layer], layerSizes[layer-1]) else SBFF(featNorm, layerSizes[layer], layerSizes[layer-1]))
|
||||
outLayer = BFF(layers[numHiddenLayers].Eh, labelDim, hiddenDim)
|
||||
outZ = outLayer.z
|
||||
CE = CrossEntropyWithSoftmax(myLabels, outZ)
|
||||
Err = ErrorPrediction(myLabels, outZ)
|
||||
logPrior = LogPrior(myLabels)
|
||||
ScaledLogLikelihood = outZ - logPrior
|
||||
]
|
||||
|
||||
SGD=[
|
||||
epochSize=20480
|
||||
minibatchSize=64:256:1024:
|
||||
learningRatesPerMB=1.0:0.5:0.1
|
||||
numMBsToShowResult=10
|
||||
momentumPerMB=0.9:0.656119
|
||||
dropoutRate=0.0
|
||||
maxEpochs=3
|
||||
keepCheckPointFiles=true
|
||||
|
||||
AutoAdjust=[
|
||||
reduceLearnRateIfImproveLessThan=0
|
||||
loadBestModel=true
|
||||
increaseLearnRateIfImproveMoreThan=1000000000
|
||||
learnRateDecreaseFactor=0.5
|
||||
learnRateIncreaseFactor=1.382
|
||||
autoAdjustLR=AdjustAfterEpoch
|
||||
]
|
||||
clippingThresholdPerSample=1#INF
|
||||
]
|
||||
reader=[
|
||||
readerType=HTKMLFReader
|
||||
readMethod=blockRandomize
|
||||
miniBatchMode=Partial
|
||||
randomize=Auto
|
||||
verbosity=0
|
||||
features=[
|
||||
dim=363
|
||||
type=Real
|
||||
scpFile=glob_0000.scp
|
||||
]
|
||||
|
||||
labels=[
|
||||
mlfFile=$DataDir$/glob_0000.mlf
|
||||
labelMappingFile=$DataDir$/state.list
|
||||
|
||||
labelDim=132
|
||||
labelType=Category
|
||||
]
|
||||
]
|
||||
]
|
77
CNTK.sln
77
CNTK.sln
|
@ -3,13 +3,14 @@ Microsoft Visual Studio Solution File, Format Version 12.00
|
|||
# Visual Studio 2013
|
||||
VisualStudioVersion = 12.0.21005.1
|
||||
MinimumVisualStudioVersion = 10.0.40219.1
|
||||
Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "CNTKMath", "Math\Math\Math.vcxproj", "{60BDB847-D0C4-4FD3-A947-0C15C08BCDB5}"
|
||||
Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "CNTKMathDll", "Math\Math\Math.vcxproj", "{60BDB847-D0C4-4FD3-A947-0C15C08BCDB5}"
|
||||
ProjectSection(ProjectDependencies) = postProject
|
||||
{B3DD765E-694E-4494-BAD7-37BBF2942517} = {B3DD765E-694E-4494-BAD7-37BBF2942517}
|
||||
EndProjectSection
|
||||
EndProject
|
||||
Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "CNTK", "MachineLearning\CNTK\CNTK.vcxproj", "{E6F26F9A-FF64-4F0A-B749-CD309EE357EE}"
|
||||
ProjectSection(ProjectDependencies) = postProject
|
||||
{928ABD1B-4D3B-4017-AEF1-0FA1B4467513} = {928ABD1B-4D3B-4017-AEF1-0FA1B4467513}
|
||||
{33D2FD22-DEF2-4507-A58A-368F641AEBE5} = {33D2FD22-DEF2-4507-A58A-368F641AEBE5}
|
||||
{D667AF32-028A-4A5D-BE19-F46776F0F6B2} = {D667AF32-028A-4A5D-BE19-F46776F0F6B2}
|
||||
{9A2F2441-5972-4EA8-9215-4119FCE0FB68} = {9A2F2441-5972-4EA8-9215-4119FCE0FB68}
|
||||
|
@ -17,6 +18,7 @@ Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "CNTK", "MachineLearning\CNT
|
|||
{014DA766-B37B-4581-BC26-963EA5507931} = {014DA766-B37B-4581-BC26-963EA5507931}
|
||||
{62836DC1-DF77-4B98-BF2D-45C943B7DDC6} = {62836DC1-DF77-4B98-BF2D-45C943B7DDC6}
|
||||
{1D5787D4-52E4-45DB-951B-82F220EE0C6A} = {1D5787D4-52E4-45DB-951B-82F220EE0C6A}
|
||||
{DE3C54E5-D7D0-47AF-A783-DFDCE59E7937} = {DE3C54E5-D7D0-47AF-A783-DFDCE59E7937}
|
||||
{E6646FFE-3588-4276-8A15-8D65C22711C1} = {E6646FFE-3588-4276-8A15-8D65C22711C1}
|
||||
EndProjectSection
|
||||
EndProject
|
||||
|
@ -50,8 +52,9 @@ Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "LUSequenceReader", "DataRea
|
|||
{60BDB847-D0C4-4FD3-A947-0C15C08BCDB5} = {60BDB847-D0C4-4FD3-A947-0C15C08BCDB5}
|
||||
EndProjectSection
|
||||
EndProject
|
||||
Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "CNTKEval", "MachineLearning\CNTKEval\CNTKEval.vcxproj", "{482999D1-B7E2-466E-9F8D-2119F93EAFD9}"
|
||||
Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "CNTKEvalDll", "MachineLearning\CNTKEval\CNTKEval.vcxproj", "{482999D1-B7E2-466E-9F8D-2119F93EAFD9}"
|
||||
ProjectSection(ProjectDependencies) = postProject
|
||||
{928ABD1B-4D3B-4017-AEF1-0FA1B4467513} = {928ABD1B-4D3B-4017-AEF1-0FA1B4467513}
|
||||
{60BDB847-D0C4-4FD3-A947-0C15C08BCDB5} = {60BDB847-D0C4-4FD3-A947-0C15C08BCDB5}
|
||||
EndProjectSection
|
||||
EndProject
|
||||
|
@ -84,9 +87,8 @@ Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "LibSVMBinaryReader", "DataR
|
|||
EndProject
|
||||
Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Linux build files", "Linux build files", "{3ED0465D-23E7-4855-9694-F788717B6533}"
|
||||
ProjectSection(SolutionItems) = preProject
|
||||
configure = configure
|
||||
Makefile = Makefile
|
||||
Makefile_kaldi.cpu = Makefile_kaldi.cpu
|
||||
Makefile_kaldi.gpu = Makefile_kaldi.gpu
|
||||
README = README
|
||||
EndProjectSection
|
||||
EndProject
|
||||
|
@ -197,12 +199,59 @@ Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Data", "Data", "{5F733BBA-F
|
|||
EndProject
|
||||
Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "LSTM", "LSTM", "{19EE975B-232D-49F0-94C7-6F1C6424FB53}"
|
||||
ProjectSection(SolutionItems) = preProject
|
||||
Tests\Speech\LSTM\baseline.cpu.txt = Tests\Speech\LSTM\baseline.cpu.txt
|
||||
Tests\Speech\LSTM\baseline.gpu.txt = Tests\Speech\LSTM\baseline.gpu.txt
|
||||
Tests\Speech\LSTM\baseline.windows.cpu.txt = Tests\Speech\LSTM\baseline.windows.cpu.txt
|
||||
Tests\Speech\LSTM\baseline.windows.gpu.txt = Tests\Speech\LSTM\baseline.windows.gpu.txt
|
||||
Tests\Speech\LSTM\cntk.config = Tests\Speech\LSTM\cntk.config
|
||||
Tests\Speech\LSTM\lstmp-3layer_WithSelfStab.ndl = Tests\Speech\LSTM\lstmp-3layer_WithSelfStab.ndl
|
||||
Tests\Speech\LSTM\run-test = Tests\Speech\LSTM\run-test
|
||||
Tests\Speech\LSTM\testcases.yml = Tests\Speech\LSTM\testcases.yml
|
||||
EndProjectSection
|
||||
EndProject
|
||||
Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "ParseConfig", "MachineLearning\ParseConfig\ParseConfig.vcxproj", "{7C4E77C9-6B17-4B02-82C1-DB62EEE2635B}"
|
||||
EndProject
|
||||
Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "CNTKComputationNetworkLib", "MachineLearning\CNTKComputationNetworkLib\CNTKComputationNetworkLib.vcxproj", "{928ABD1B-4D3B-4017-AEF1-0FA1B4467513}"
|
||||
ProjectSection(ProjectDependencies) = postProject
|
||||
{60BDB847-D0C4-4FD3-A947-0C15C08BCDB5} = {60BDB847-D0C4-4FD3-A947-0C15C08BCDB5}
|
||||
EndProjectSection
|
||||
EndProject
|
||||
Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "CNTKSGDLib", "MachineLearning\CNTKSGDLib\CNTKSGDLib.vcxproj", "{DE3C54E5-D7D0-47AF-A783-DFDCE59E7937}"
|
||||
ProjectSection(ProjectDependencies) = postProject
|
||||
{928ABD1B-4D3B-4017-AEF1-0FA1B4467513} = {928ABD1B-4D3B-4017-AEF1-0FA1B4467513}
|
||||
EndProjectSection
|
||||
EndProject
|
||||
Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "ParallelTraining", "ParallelTraining", "{5E666C53-2D82-49C9-9127-3FDDC321C741}"
|
||||
ProjectSection(SolutionItems) = preProject
|
||||
Tests\ParallelTraining\SimpleMultiGPU.config = Tests\ParallelTraining\SimpleMultiGPU.config
|
||||
EndProjectSection
|
||||
EndProject
|
||||
Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Data", "Data", "{6D1353D6-F196-466F-B886-F16D48759B20}"
|
||||
ProjectSection(SolutionItems) = preProject
|
||||
Tests\ParallelTraining\Data\SimpleDataTrain.txt = Tests\ParallelTraining\Data\SimpleDataTrain.txt
|
||||
Tests\ParallelTraining\Data\SimpleMapping.txt = Tests\ParallelTraining\Data\SimpleMapping.txt
|
||||
EndProjectSection
|
||||
EndProject
|
||||
Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "NoQuantization", "NoQuantization", "{B6725C9F-A6D2-4269-9B74-7888A90F7884}"
|
||||
EndProject
|
||||
Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "SinglePrecision", "SinglePrecision", "{B27DD434-EECD-4EE0-A03B-1150EB87258E}"
|
||||
ProjectSection(SolutionItems) = preProject
|
||||
Tests\ParallelTraining\NoQuantization\SinglePrecision\baseline.cpu.txt = Tests\ParallelTraining\NoQuantization\SinglePrecision\baseline.cpu.txt
|
||||
Tests\ParallelTraining\NoQuantization\SinglePrecision\baseline.gpu.txt = Tests\ParallelTraining\NoQuantization\SinglePrecision\baseline.gpu.txt
|
||||
Tests\ParallelTraining\NoQuantization\SinglePrecision\baseline.windows.cpu.txt = Tests\ParallelTraining\NoQuantization\SinglePrecision\baseline.windows.cpu.txt
|
||||
Tests\ParallelTraining\NoQuantization\SinglePrecision\run-test = Tests\ParallelTraining\NoQuantization\SinglePrecision\run-test
|
||||
Tests\ParallelTraining\NoQuantization\SinglePrecision\testcases.yml = Tests\ParallelTraining\NoQuantization\SinglePrecision\testcases.yml
|
||||
EndProjectSection
|
||||
EndProject
|
||||
Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "DoublePrecision", "DoublePrecision", "{A4884465-CFBB-4A64-A9DE-690E1A63EF7E}"
|
||||
ProjectSection(SolutionItems) = preProject
|
||||
Tests\ParallelTraining\NoQuantization\DoublePrecision\baseline.cpu.txt = Tests\ParallelTraining\NoQuantization\DoublePrecision\baseline.cpu.txt
|
||||
Tests\ParallelTraining\NoQuantization\DoublePrecision\baseline.gpu.txt = Tests\ParallelTraining\NoQuantization\DoublePrecision\baseline.gpu.txt
|
||||
Tests\ParallelTraining\NoQuantization\DoublePrecision\baseline.windows.cpu.txt = Tests\ParallelTraining\NoQuantization\DoublePrecision\baseline.windows.cpu.txt
|
||||
Tests\ParallelTraining\NoQuantization\DoublePrecision\run-test = Tests\ParallelTraining\NoQuantization\DoublePrecision\run-test
|
||||
Tests\ParallelTraining\NoQuantization\DoublePrecision\testcases.yml = Tests\ParallelTraining\NoQuantization\DoublePrecision\testcases.yml
|
||||
EndProjectSection
|
||||
EndProject
|
||||
Global
|
||||
GlobalSection(SolutionConfigurationPlatforms) = preSolution
|
||||
Debug|x64 = Debug|x64
|
||||
|
@ -268,6 +317,18 @@ Global
|
|||
{CE429AA2-3778-4619-8FD1-49BA3B81197B}.Debug|x64.Build.0 = Debug|x64
|
||||
{CE429AA2-3778-4619-8FD1-49BA3B81197B}.Release|x64.ActiveCfg = Release|x64
|
||||
{CE429AA2-3778-4619-8FD1-49BA3B81197B}.Release|x64.Build.0 = Release|x64
|
||||
{7C4E77C9-6B17-4B02-82C1-DB62EEE2635B}.Debug|x64.ActiveCfg = Debug|x64
|
||||
{7C4E77C9-6B17-4B02-82C1-DB62EEE2635B}.Debug|x64.Build.0 = Debug|x64
|
||||
{7C4E77C9-6B17-4B02-82C1-DB62EEE2635B}.Release|x64.ActiveCfg = Release|x64
|
||||
{7C4E77C9-6B17-4B02-82C1-DB62EEE2635B}.Release|x64.Build.0 = Release|x64
|
||||
{928ABD1B-4D3B-4017-AEF1-0FA1B4467513}.Debug|x64.ActiveCfg = Debug|x64
|
||||
{928ABD1B-4D3B-4017-AEF1-0FA1B4467513}.Debug|x64.Build.0 = Debug|x64
|
||||
{928ABD1B-4D3B-4017-AEF1-0FA1B4467513}.Release|x64.ActiveCfg = Release|x64
|
||||
{928ABD1B-4D3B-4017-AEF1-0FA1B4467513}.Release|x64.Build.0 = Release|x64
|
||||
{DE3C54E5-D7D0-47AF-A783-DFDCE59E7937}.Debug|x64.ActiveCfg = Debug|x64
|
||||
{DE3C54E5-D7D0-47AF-A783-DFDCE59E7937}.Debug|x64.Build.0 = Debug|x64
|
||||
{DE3C54E5-D7D0-47AF-A783-DFDCE59E7937}.Release|x64.ActiveCfg = Release|x64
|
||||
{DE3C54E5-D7D0-47AF-A783-DFDCE59E7937}.Release|x64.Build.0 = Release|x64
|
||||
EndGlobalSection
|
||||
GlobalSection(SolutionProperties) = preSolution
|
||||
HideSolutionNode = FALSE
|
||||
|
@ -277,11 +338,15 @@ Global
|
|||
{482999D1-B7E2-466E-9F8D-2119F93EAFD9} = {DD043083-71A4-409A-AA91-F9C548DCF7EC}
|
||||
{60BDB847-D0C4-4FD3-A947-0C15C08BCDB5} = {DD043083-71A4-409A-AA91-F9C548DCF7EC}
|
||||
{B3DD765E-694E-4494-BAD7-37BBF2942517} = {DD043083-71A4-409A-AA91-F9C548DCF7EC}
|
||||
{928ABD1B-4D3B-4017-AEF1-0FA1B4467513} = {DD043083-71A4-409A-AA91-F9C548DCF7EC}
|
||||
{DE3C54E5-D7D0-47AF-A783-DFDCE59E7937} = {DD043083-71A4-409A-AA91-F9C548DCF7EC}
|
||||
{6CEE834A-8104-46A8-8902-64C81BD7928F} = {D45DF403-6781-444E-B654-A96868C5BE68}
|
||||
{668BEED5-AC07-4F35-B3AE-EE65A7F9C976} = {D45DF403-6781-444E-B654-A96868C5BE68}
|
||||
{0F30EBCF-09F3-4EED-BF54-4214BCE53FEC} = {D45DF403-6781-444E-B654-A96868C5BE68}
|
||||
{DBB3C106-B0B4-4059-8477-C89528CEC1B0} = {D45DF403-6781-444E-B654-A96868C5BE68}
|
||||
{C47CDAA5-6D6C-429E-BC89-7CA0F868FDC8} = {D45DF403-6781-444E-B654-A96868C5BE68}
|
||||
{5E666C53-2D82-49C9-9127-3FDDC321C741} = {D45DF403-6781-444E-B654-A96868C5BE68}
|
||||
{7C4E77C9-6B17-4B02-82C1-DB62EEE2635B} = {D45DF403-6781-444E-B654-A96868C5BE68}
|
||||
{E6646FFE-3588-4276-8A15-8D65C22711C1} = {33EBFE78-A1A8-4961-8938-92A271941F94}
|
||||
{1D5787D4-52E4-45DB-951B-82F220EE0C6A} = {33EBFE78-A1A8-4961-8938-92A271941F94}
|
||||
{62836DC1-DF77-4B98-BF2D-45C943B7DDC6} = {33EBFE78-A1A8-4961-8938-92A271941F94}
|
||||
|
@ -301,5 +366,9 @@ Global
|
|||
{4BBF2950-3DBD-469A-AD57-6CACBEBAF541} = {C47CDAA5-6D6C-429E-BC89-7CA0F868FDC8}
|
||||
{5F733BBA-FE83-4668-8F83-8B0E78A36619} = {C47CDAA5-6D6C-429E-BC89-7CA0F868FDC8}
|
||||
{19EE975B-232D-49F0-94C7-6F1C6424FB53} = {C47CDAA5-6D6C-429E-BC89-7CA0F868FDC8}
|
||||
{6D1353D6-F196-466F-B886-F16D48759B20} = {5E666C53-2D82-49C9-9127-3FDDC321C741}
|
||||
{B6725C9F-A6D2-4269-9B74-7888A90F7884} = {5E666C53-2D82-49C9-9127-3FDDC321C741}
|
||||
{B27DD434-EECD-4EE0-A03B-1150EB87258E} = {B6725C9F-A6D2-4269-9B74-7888A90F7884}
|
||||
{A4884465-CFBB-4A64-A9DE-690E1A63EF7E} = {B6725C9F-A6D2-4269-9B74-7888A90F7884}
|
||||
EndGlobalSection
|
||||
EndGlobal
|
||||
|
|
|
@ -23,6 +23,8 @@
|
|||
#include <nvml.h> // note: expected at "c:\Program Files\NVIDIA Corporation\GDK\gdk_win7_amd64_release\nvml\include" (Windows) and /the path you installed deployment kit/usr/include/nvidia/gdk (Linux)
|
||||
#pragma comment (lib, "nvml.lib") // note: expected at "c:\Program Files\NVIDIA Corporation\GDK\gdk_win7_amd64_release\nvml\lib" (Windows) and /the path you installed deployment kit/usr/include/nvidia/gdk (Linux)
|
||||
#include <vector>
|
||||
#else
|
||||
int bestGPUDummy = 42; // put something into this CPP, as to avoid a linker warning
|
||||
#endif
|
||||
#include "CommonMatrix.h" // for CPUDEVICE and AUTOPLACEMATRIX
|
||||
|
||||
|
@ -43,9 +45,6 @@
|
|||
|
||||
#include <memory>
|
||||
#include "CrossProcessMutex.h"
|
||||
#include "../../MachineLearning/CNTK/MPIWrapper.h"
|
||||
extern Microsoft::MSR::CNTK::MPIWrapper *g_mpi;
|
||||
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// BestGpu class
|
||||
|
@ -123,6 +122,9 @@ private:
|
|||
// 0:2:3- an array of ids to use, (PTask will only use the specified IDs)
|
||||
// *3 - a count of GPUs to use (PTask)
|
||||
// All - Use all the GPUs (PTask)
|
||||
#ifdef MATH_EXPORTS
|
||||
__declspec(dllexport)
|
||||
#endif
|
||||
DEVICEID_TYPE DeviceFromConfig(const ConfigParameters& config)
|
||||
{
|
||||
static BestGpu* g_bestGpu = NULL;
|
||||
|
|
|
@ -242,7 +242,7 @@ bool DataReader<ElemType>::GetProposalObs(std::map<std::wstring, Matrix<ElemType
|
|||
}
|
||||
|
||||
template<class ElemType>
|
||||
void DataReader<ElemType>::SetSentenceSegBatch(Matrix<ElemType> &sentenceEnd, vector<MinibatchPackingFlag>& minibatchPackingFlag)
|
||||
void DataReader<ElemType>::SetSentenceSegBatch(Matrix<float> &sentenceEnd, vector<MinibatchPackingFlag>& minibatchPackingFlag)
|
||||
{
|
||||
for (size_t i = 0; i < m_ioNames.size(); i++)
|
||||
m_dataReader[m_ioNames[i]]->SetSentenceSegBatch(sentenceEnd, minibatchPackingFlag);
|
||||
|
@ -259,7 +259,7 @@ template<class ElemType>
|
|||
bool DataReader<ElemType>::GetMinibatchCopy(
|
||||
std::vector<std::vector<std::pair<wstring, size_t>>>& uttInfo,
|
||||
std::map<std::wstring, Matrix<ElemType>*>& matrices,
|
||||
Matrix<ElemType>& sentenceBegin,
|
||||
Matrix<float>& sentenceBegin,
|
||||
std::vector<MinibatchPackingFlag>& minibatchPackingFlag)
|
||||
{
|
||||
bool ans = false;
|
||||
|
@ -272,7 +272,7 @@ template<class ElemType>
|
|||
bool DataReader<ElemType>::SetNetOutput(
|
||||
const std::vector<std::vector<std::pair<wstring, size_t>>>& uttInfo,
|
||||
const Matrix<ElemType>& outputs,
|
||||
const Matrix<ElemType>& sentenceBegin,
|
||||
const Matrix<float>& sentenceBegin,
|
||||
const std::vector<MinibatchPackingFlag>& minibatchPackingFlag)
|
||||
{
|
||||
bool ans = false;
|
||||
|
|
|
@ -169,6 +169,20 @@ void File::GetLine(string& str)
|
|||
str = fgetline(m_file);
|
||||
}
|
||||
|
||||
// GetLines - get all lines from a file
|
||||
template<typename STRING> static void FileGetLines(File & file, std::vector<STRING>& lines)
|
||||
{
|
||||
STRING line;
|
||||
while (!file.IsEOF())
|
||||
{
|
||||
file.GetLine(line);
|
||||
lines.push_back(line);
|
||||
}
|
||||
}
|
||||
void File::GetLines(std::vector<std::wstring>& lines) { FileGetLines(*this, lines); };
|
||||
void File::GetLines(std::vector<std::string>& lines) { FileGetLines(*this, lines); }
|
||||
|
||||
|
||||
// Put a zero/space terminated wstring into a file
|
||||
// val - value to write to the file
|
||||
File& File::operator<<(const std::wstring& val)
|
||||
|
|
|
@ -8,6 +8,7 @@
|
|||
#define _BASICS_H_
|
||||
|
||||
#include "basetypes.h" // TODO: gradually move over here all that's needed of basetypes.h, then remove basetypes.h.
|
||||
#include "Platform.h"
|
||||
|
||||
#define TWO_PI 6.283185307f // TODO: find the official standards-confirming definition of this and use it instead
|
||||
|
||||
|
@ -25,61 +26,46 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
bool operator()(const std::wstring& left, const std::wstring& right) { return _wcsicmp(left.c_str(), right.c_str()) < 0; }
|
||||
};
|
||||
|
||||
// ThrowFormatted() - template function to throw a std::exception with a formatted error string
|
||||
template<class E>
|
||||
__declspec_noreturn static inline void ThrowFormatted(const char * format, ...)
|
||||
{
|
||||
va_list args;
|
||||
char buffer[1024];
|
||||
va_start(args, format);
|
||||
vsprintf(buffer, format, args);
|
||||
throw E(buffer);
|
||||
};
|
||||
|
||||
// if it receives a lonely std::string then throw that directly
|
||||
template<class E>
|
||||
__declspec_noreturn static inline void ThrowFormatted(const string & message) { throw E(message); }
|
||||
|
||||
// RuntimeError - throw a std::runtime_error with a formatted error string
|
||||
#ifdef _MSC_VER
|
||||
__declspec(noreturn)
|
||||
#endif
|
||||
static inline void RuntimeError(const char * format, ...)
|
||||
{
|
||||
va_list args;
|
||||
char buffer[1024];
|
||||
|
||||
va_start(args, format);
|
||||
vsprintf(buffer, format, args);
|
||||
throw std::runtime_error(buffer);
|
||||
};
|
||||
static inline void RuntimeError(const string & message) { RuntimeError("%s", message.c_str()); }
|
||||
|
||||
// LogicError - throw a std::logic_error with a formatted error string
|
||||
#ifdef _MSC_VER
|
||||
__declspec(noreturn)
|
||||
#endif
|
||||
static inline void LogicError(const char * format, ...)
|
||||
{
|
||||
va_list args;
|
||||
char buffer[1024];
|
||||
|
||||
va_start(args, format);
|
||||
vsprintf(buffer, format, args);
|
||||
throw std::logic_error(buffer);
|
||||
};
|
||||
static inline void LogicError(const string & message) { LogicError("%s", message.c_str()); }
|
||||
|
||||
// InvalidArgument - throw a std::logic_error with a formatted error string
|
||||
#ifdef _MSC_VER
|
||||
__declspec(noreturn)
|
||||
#endif
|
||||
static inline void InvalidArgument(const char * format, ...)
|
||||
{
|
||||
va_list args;
|
||||
char buffer[1024];
|
||||
|
||||
va_start(args, format);
|
||||
vsprintf(buffer, format, args);
|
||||
throw std::invalid_argument(buffer);
|
||||
};
|
||||
static inline void InvalidArgument(const string & message) { InvalidArgument("%s", message.c_str()); }
|
||||
template<class... _Types>
|
||||
__declspec_noreturn static inline void RuntimeError(_Types&&... _Args) { ThrowFormatted<std::runtime_error>(forward<_Types>(_Args)...); }
|
||||
template<class... _Types>
|
||||
__declspec_noreturn static inline void LogicError(_Types&&... _Args) { ThrowFormatted<std::logic_error>(forward<_Types>(_Args)...); }
|
||||
template<class... _Types>
|
||||
__declspec_noreturn static inline void InvalidArgument(_Types&&... _Args) { ThrowFormatted<std::invalid_argument>(forward<_Types>(_Args)...); }
|
||||
|
||||
// Warning - warn with a formatted error string
|
||||
static inline void Warning(const char * format, ...)
|
||||
{
|
||||
va_list args;
|
||||
char buffer[1024];
|
||||
|
||||
va_start(args, format);
|
||||
vsprintf(buffer, format, args);
|
||||
};
|
||||
static inline void Warning(const string & message) { Warning("%s", message.c_str()); }
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// random collection of stuff we needed at some place
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
// TODO: maybe change to type id of an actual thing we pass in
|
||||
// TODO: is this header appropriate?
|
||||
template<class C> static wstring TypeId() { return msra::strfun::utf16(typeid(C).name()); }
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// dynamic loading of modules --TODO: not Basics, should move to its own header
|
||||
|
@ -91,7 +77,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
HMODULE m_hModule; // module handle for the writer DLL
|
||||
std::wstring m_dllName; // name of the writer DLL
|
||||
public:
|
||||
Plugin() { m_hModule = NULL; }
|
||||
Plugin() : m_hModule(NULL) { }
|
||||
template<class STRING> // accepts char (UTF-8) and wide string
|
||||
FARPROC Load(const STRING & plugin, const std::string & proc)
|
||||
{
|
||||
|
@ -99,13 +85,12 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
m_dllName += L".dll";
|
||||
m_hModule = LoadLibrary(m_dllName.c_str());
|
||||
if (m_hModule == NULL)
|
||||
Microsoft::MSR::CNTK::RuntimeError("Plugin not found: %s", msra::strfun::utf8(m_dllName).c_str());
|
||||
|
||||
RuntimeError("Plugin not found: %s", msra::strfun::utf8(m_dllName).c_str());
|
||||
// create a variable of each type just to call the proper templated version
|
||||
return GetProcAddress(m_hModule, proc.c_str());
|
||||
}
|
||||
~Plugin(){}
|
||||
// removed because this causes the exception messages to be lost (exception vftables are unloaded when DLL is unloaded)
|
||||
// we do not unload because this causes the exception messages to be lost (exception vftables are unloaded when DLL is unloaded)
|
||||
// ~Plugin() { if (m_hModule) FreeLibrary(m_hModule); }
|
||||
};
|
||||
#else
|
||||
|
@ -114,11 +99,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
private:
|
||||
void *handle;
|
||||
public:
|
||||
Plugin()
|
||||
{
|
||||
handle = NULL;
|
||||
}
|
||||
|
||||
Plugin() : handle (NULL) { }
|
||||
template<class STRING> // accepts char (UTF-8) and wide string
|
||||
void * Load(const STRING & plugin, const std::string & proc)
|
||||
{
|
||||
|
@ -126,14 +107,10 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
soName = soName + ".so";
|
||||
void *handle = dlopen(soName.c_str(), RTLD_LAZY);
|
||||
if (handle == NULL)
|
||||
RuntimeError("Plugin not found: %s", soName.c_str());
|
||||
RuntimeError("Plugin not found: %s (error: %s)", soName.c_str(), dlerror());
|
||||
return dlsym(handle, proc.c_str());
|
||||
}
|
||||
|
||||
~Plugin() {
|
||||
if (handle != NULL)
|
||||
dlclose(handle);
|
||||
}
|
||||
~Plugin() { if (handle != NULL) dlclose(handle); }
|
||||
};
|
||||
#endif
|
||||
|
||||
|
|
|
@ -22,10 +22,12 @@
|
|||
#else
|
||||
#define DATAREADER_API
|
||||
#endif
|
||||
|
||||
#include "Basics.h"
|
||||
#include "Matrix.h"
|
||||
#include "commandArgUtil.h" // for ConfigParameters
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include "Basics.h"
|
||||
|
||||
namespace Microsoft { namespace MSR { namespace CNTK {
|
||||
|
||||
|
@ -84,7 +86,7 @@ public:
|
|||
virtual void SetLabelMapping(const std::wstring&, const std::map<LabelIdType, LabelType>&) { NOT_IMPLEMENTED; };
|
||||
virtual bool GetData(const std::wstring&, size_t, void*, size_t&, size_t) { NOT_IMPLEMENTED; };
|
||||
virtual bool DataEnd(EndDataType) { NOT_IMPLEMENTED; };
|
||||
virtual void SetSentenceSegBatch(Matrix<ElemType>&, vector<MinibatchPackingFlag>& ) { NOT_IMPLEMENTED; };
|
||||
virtual void SetSentenceSegBatch(Matrix<float>&, vector<MinibatchPackingFlag>& ) { NOT_IMPLEMENTED; };
|
||||
virtual void SetRandomSeed(unsigned seed = 0) { m_seed = seed; };
|
||||
virtual bool GetProposalObs(std::map<std::wstring, Matrix<ElemType>*>*, const size_t, vector<size_t>&) { return false; }
|
||||
virtual void InitProposals(std::map<std::wstring, Matrix<ElemType>*>*) { }
|
||||
|
@ -103,7 +105,7 @@ public:
|
|||
virtual bool GetMinibatchCopy(
|
||||
std::vector<std::vector<std::pair<wstring, size_t>>>& /*uttInfo*/,
|
||||
std::map<std::wstring, Matrix<ElemType>*>& /*matrices*/,
|
||||
Matrix<ElemType>& /*sentenceBegin*/,
|
||||
Matrix<float>& /*sentenceBegin*/,
|
||||
std::vector<MinibatchPackingFlag>& /*minibatchPackingFlag*/)
|
||||
{
|
||||
return false;
|
||||
|
@ -114,7 +116,7 @@ public:
|
|||
virtual bool SetNetOutput(
|
||||
const std::vector<std::vector<std::pair<wstring, size_t>>>& /*uttInfo*/,
|
||||
const Matrix<ElemType>& /*outputs*/,
|
||||
const Matrix<ElemType>& /*sentenceBegin*/,
|
||||
const Matrix<float>& /*sentenceBegin*/,
|
||||
const std::vector<MinibatchPackingFlag>& /*minibatchPackingFlag*/)
|
||||
{
|
||||
return false;
|
||||
|
@ -225,7 +227,7 @@ public:
|
|||
virtual bool GetMinibatchCopy(
|
||||
std::vector<std::vector<std::pair<wstring, size_t>>>& uttInfo,
|
||||
std::map<std::wstring, Matrix<ElemType>*>& matrices,
|
||||
Matrix<ElemType>& sentenceBegin,
|
||||
Matrix<float>& sentenceBegin,
|
||||
std::vector<MinibatchPackingFlag>& minibatchPackingFlag);
|
||||
|
||||
// Sets the neural network output to the reader. This can be useful if some
|
||||
|
@ -233,10 +235,10 @@ public:
|
|||
virtual bool SetNetOutput(
|
||||
const std::vector<std::vector<std::pair<wstring, size_t>>>& uttInfo,
|
||||
const Matrix<ElemType>& outputs,
|
||||
const Matrix<ElemType>& sentenceBegin,
|
||||
const Matrix<float>& sentenceBegin,
|
||||
const std::vector<MinibatchPackingFlag>& minibatchPackingFlag);
|
||||
|
||||
void SetSentenceSegBatch(Matrix<ElemType> & sentenceBegin, vector<MinibatchPackingFlag>& minibatchPackingFlag);
|
||||
void SetSentenceSegBatch(Matrix<float> & sentenceBegin, vector<MinibatchPackingFlag>& minibatchPackingFlag);
|
||||
|
||||
void SetRandomSeed(int);
|
||||
|
||||
|
|
|
@ -0,0 +1,140 @@
|
|||
thoughts on DataReader redesign
|
||||
===============================
|
||||
|
||||
current basic usage pattern:
|
||||
|
||||
- StartMinibatchLoop() sets the start and MB size
|
||||
- GetMinibatch() fills matrices in a dictionary of named matrices
|
||||
- sample code:
|
||||
|
||||
std::map<std::wstring, Matrix<ElemType>*> matrices;
|
||||
matrices[featureNames[0]] = &featuresMatrix;
|
||||
matrices[labelNames[0]] = &labelsMatrix;
|
||||
|
||||
dataReader.StartMinibatchLoop(mbSize, epoch, epochSize);
|
||||
while (dataReader.GetMinibatch(matrices))
|
||||
{
|
||||
Matrix<ElemType>& features = *matrices[featureNames[0]];
|
||||
Matrix<ElemType>& labels = *matrices[labelNames[0]];
|
||||
// no function called at end, implied in GetMinibatch()
|
||||
|
||||
issues with current data reader design:
|
||||
|
||||
- monolithic, combines all (or not) of these:
|
||||
- paging in data (incl. format parsing)
|
||||
- randomization
|
||||
- caching
|
||||
- packing of parallel streams
|
||||
- prefetch (in the original DBN.exe version of the HTK reader)
|
||||
- minibatch decimation in presence of MPI (done in a way that avoids to read data that is not needed by a node)
|
||||
- multiple streams must match in their timing frame-by-frame
|
||||
- kills sequence-to-sequence
|
||||
- currently circumvented by paring multiple networks
|
||||
|
||||
goals:
|
||||
|
||||
- remove time-synchronity limitation
|
||||
- which means that the interface must separate the notion of frames and utterances
|
||||
- break into composable blocks
|
||||
- hopefully, people in the future will only have to implement the data paging
|
||||
- note: packing is not a reader function; nodes themselves may determine packing for each minibatch
|
||||
- more abstract notion of 'utterance' e.g. include variable-size images (2D) and video (3D)
|
||||
- seems we canb keep the existing DataReader interface, but with extensions and a new implementation
|
||||
|
||||
feature details that must be considered/covered:
|
||||
|
||||
- augmentation of context frames
|
||||
- some utterances are missing
|
||||
- multi-lingual training (multi-task learning where each utterance only has labels for one task)
|
||||
- realignment may fail on some utterances
|
||||
- should support non-matrix data, e.g. lattices
|
||||
- maybe we can improve efficiency of decimated minibatch reading (current approach from DBN.exe is not optimally load-balanced)
|
||||
|
||||
thinking out loud on how we may proceed (high level):
|
||||
|
||||
- basic unit of thought is the utterance, not the minibatch
|
||||
- a minibatch is a set of utterances
|
||||
- framewise CE training: each frame is an utterance; the N frames are batched into N streams of 1 frame
|
||||
- note: design must be non-wasteful in this important special case
|
||||
- an utterance should be understood more generally as a fixed or variable-dimension N-dimensional tensor,
|
||||
including images (2D tensor, of possibly variable size) and even video (3D tensor).
|
||||
And 'utterance length' generalizes to image dimensions as well. Everything that's variable length.
|
||||
- interface Sequencer
|
||||
- determines the sequence of utterances and grouping into minibatches
|
||||
- by driving on utterance level, different feature streams with mismatching timing are not a concern of the Sequencer
|
||||
- owns knowlegde of blocks
|
||||
- provides caching control information, that is, when to release data from memory
|
||||
- for frame mode, there must be some form of translation between utterances and frames, so that we can cache utterances while randomizing over frames
|
||||
- does NOT actually read data; only provides descriptors of what to read, which are passed to pagers
|
||||
- DataReader class does the reading
|
||||
- in eval mode, there is also a DataWriter
|
||||
- class RandomSequencer
|
||||
- performs block randomization, based on one user-selected data pager
|
||||
- for SGD
|
||||
- class RandomFrameSequencer
|
||||
- treats frames of utterances into individual utterances and randomizes those (for CE training of DNNs or other windows models)
|
||||
- class LinearSequencer
|
||||
- returns data in original sequence
|
||||
- for evaluation
|
||||
- interface DataPager
|
||||
- random access to page in utterances
|
||||
- specified by a descriptor obtained from the Sequencer
|
||||
- knowledge of how to parse input data formats is in these pagers
|
||||
- data assumed immutable
|
||||
- examples:
|
||||
- HTK features
|
||||
- HTK labels from MLF (also: labels stored in feature format, for reduced startup time)
|
||||
- Python adapter
|
||||
- lightweight agreement between DataPager and Sequencer:
|
||||
- pager provides block-forming relevant information, such that the reading of data consecutively in each block will be optimal;
|
||||
sequencer will ask one user-selected pager to provide this information as a basis for block randomization
|
||||
- class CachedDataPager
|
||||
- TODO: think this through:
|
||||
- are DataPagers driven in blocks? That would be the unit of caching
|
||||
- releasing a block from cache must be an explicit function
|
||||
- maybe that helper class needs to do that
|
||||
- or we use a ref count for utterances to control releasing of blocks? Could be expensive, since invidiual speech frames can be utterances (DNN/CE). It's only a refcount of 0 or 1
|
||||
- should we call this DataPageReader or DataBlockReader?
|
||||
- let's call them Pager for now (need to change because name has a problem with reading vs. writing)
|
||||
- class DataReader
|
||||
- outer layer of the new structure
|
||||
- designed to support reading data streams that that have mismatching utterance lengths
|
||||
- there is only one DataReader instance that handles all utterance-data streams (including mismatching lenths)
|
||||
- takes a reference to one user-specified sequencer
|
||||
- takes ownership of one or more user-supplied pagers
|
||||
- after construction, the above are only accessed through the DataReader
|
||||
- a nested hierarchy of DataReaders implement specific functionaliaty
|
||||
class CachingDataReader
|
||||
- wraps a DataReader with caching--this is what one would use when randomizing (not needed for evaluation)
|
||||
class PrefetchingDataReader
|
||||
- wraps a DataReader and reads ahead on a parallel thread
|
||||
- TODO: so where does the sequencer run?? Or does sequencer provides a FIFO of minibatches (lookahead)?
|
||||
maybe sequence info is routed through the prefetch for everything? Consider that we also need to do writing, so this becomes weird
|
||||
in that one would always have to access the sequencer through the DataReader (in order to get the correctly delayed sequencing information)
|
||||
class BatchingDataReader?
|
||||
- meant to batch utterances into streams
|
||||
- NO: this should not be a DataReader, as this is a network function. But we may have supporting code in the reader interface or a reader helper class
|
||||
- instead, there should be a Minibatch batcher class that (re-)batches and reshuffles minibatches (this could be a ComputeNode, actually)
|
||||
- this new DataReader differs (extends) the current DataReader as:
|
||||
- GetMinibatch() has to return utterance and length/packing information for every minibatch
|
||||
- minibatches must also carry their own sequencing information (utterance ids); this can then be used for data writing
|
||||
- we may want to rethink the glue between reading and Input nodes. Maybe Input nodes can know about readers?
|
||||
- how to set up the whole thing:
|
||||
- create all desired data pagers
|
||||
- create a Sequencer of the desired type
|
||||
- e.g. random utterance, random frame, non-random
|
||||
- pass it one user-selected data pager to let it determining how data is grouped in blocks
|
||||
- create the desired DataReader by passing it the readers and the sequencer
|
||||
- there may be a hierarchy of nested DataReaders, to do caching and prefetching
|
||||
- use it mostly as before
|
||||
- Note: sequencer information can only be accessed through the DataReader.
|
||||
|
||||
on writing:
|
||||
|
||||
- writing is used for evaluation, and also for data conversion
|
||||
- DataReader returns minibatches that carry their utterance information (i.e. utterance ids)
|
||||
- class DataWriter
|
||||
- new SaveData() overload takes an output minibatch complete with utterance information
|
||||
TODO: we could make the two interfaces a little more symmetric w.r.t. function naming
|
||||
|
||||
[fseide 9/2015]
|
|
@ -25,10 +25,10 @@
|
|||
|
||||
#include "Basics.h"
|
||||
#include "Matrix.h"
|
||||
#include "commandArgUtil.h" // for ConfigParameters
|
||||
#include <map>
|
||||
#include <string>
|
||||
|
||||
|
||||
namespace Microsoft { namespace MSR { namespace CNTK {
|
||||
|
||||
// type of data in this section
|
||||
|
|
|
@ -4,6 +4,8 @@
|
|||
// </copyright>
|
||||
//
|
||||
#pragma once
|
||||
|
||||
#include "Basics.h"
|
||||
#include <stdio.h>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
@ -15,6 +17,8 @@
|
|||
#include <unistd.h>
|
||||
#endif
|
||||
#include "fileutil.h" // for f{ge,pu}t{,Text}()
|
||||
#include <fstream> // for LoadMatrixFromTextFile() --TODO: change to using this File class
|
||||
#include <sstream>
|
||||
|
||||
namespace Microsoft{ namespace MSR { namespace CNTK {
|
||||
|
||||
|
@ -127,6 +131,8 @@ public:
|
|||
|
||||
void GetLine(std::wstring& str);
|
||||
void GetLine(std::string& str);
|
||||
void GetLines(std::vector<std::wstring>& lines);
|
||||
void GetLines(std::vector<std::string>& lines);
|
||||
|
||||
// put operator for basic types
|
||||
template <typename T>
|
||||
|
@ -238,6 +244,74 @@ public:
|
|||
return *this;
|
||||
}
|
||||
|
||||
// Read a matrix stored in text format from 'filePath' (whitespace-separated columns, newline-separated rows),
|
||||
// and return a flat array containing the contents of this file in column-major format.
|
||||
// filePath: path to file containing matrix in text format.
|
||||
// numRows/numCols: after this function is called, these parameters contain the number of rows/columns in the matrix.
|
||||
// returns: a flat array containing the contents of this file in column-major format
|
||||
// NOTE: caller is responsible for deleting the returned buffer once it is finished using it.
|
||||
// TODO: change to return a std::vector<ElemType>; solves the ownership issue
|
||||
// This function does not quite fit here, but it fits elsewhere even worse. TODO: change to use File class!
|
||||
template<class ElemType>
|
||||
static vector<ElemType> LoadMatrixFromTextFile(const std::string filePath, size_t& numRows, size_t& numCols)
|
||||
{
|
||||
size_t r = 0;
|
||||
size_t numColsInFirstRow = 0;
|
||||
|
||||
// NOTE: Not using the Microsoft.MSR.CNTK.File API here because it
|
||||
// uses a buffer of fixed size, which doesn't allow very long rows.
|
||||
// See fileutil.cpp fgetline method (std::string fgetline (FILE * f) { fixed_vector<char> buf (1000000); ... })
|
||||
std::ifstream myfile(filePath);
|
||||
|
||||
// load matrix into vector of vectors (since we don't know the size in advance).
|
||||
std::vector<std::vector<ElemType>> elements;
|
||||
if (myfile.is_open())
|
||||
{
|
||||
std::string line;
|
||||
while (std::getline(myfile, line))
|
||||
{
|
||||
// Break on empty line. This allows there to be an empty line at the end of the file.
|
||||
if (line == "")
|
||||
break;
|
||||
|
||||
istringstream iss(line);
|
||||
ElemType element;
|
||||
int numElementsInRow = 0;
|
||||
elements.push_back(std::vector<ElemType>());
|
||||
while (iss >> element)
|
||||
{
|
||||
elements[r].push_back(element);
|
||||
numElementsInRow++;
|
||||
}
|
||||
|
||||
if (r == 0)
|
||||
numColsInFirstRow = numElementsInRow;
|
||||
else if (numElementsInRow != numColsInFirstRow)
|
||||
RuntimeError("The rows in the provided file do not all have the same number of columns: " + filePath);
|
||||
|
||||
r++;
|
||||
}
|
||||
myfile.close();
|
||||
}
|
||||
else
|
||||
RuntimeError("Unable to open file");
|
||||
|
||||
numRows = r;
|
||||
numCols = numColsInFirstRow;
|
||||
|
||||
vector<ElemType> array(numRows * numCols);
|
||||
|
||||
// Perform transpose when copying elements from vectors to ElemType[],
|
||||
// in order to store in column-major format.
|
||||
for (int i = 0; i < numCols; i++)
|
||||
{
|
||||
for (int j = 0; j < numRows; j++)
|
||||
array[i * numRows + j] = elements[j][i];
|
||||
}
|
||||
|
||||
return array;
|
||||
}
|
||||
|
||||
operator FILE*() const { return m_file; }
|
||||
};
|
||||
|
||||
|
|
|
@ -222,7 +222,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
}
|
||||
|
||||
// for raw pointer
|
||||
template<typename ElemType>
|
||||
template<class ElemType>
|
||||
void AllReduce(ElemType* pData, size_t nData)
|
||||
{
|
||||
if ((NumNodesInUse() > 1 && (Communicator() != MPI_COMM_NULL)))
|
||||
|
@ -231,7 +231,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
}
|
||||
}
|
||||
|
||||
template<typename ElemType>
|
||||
template<class ElemType>
|
||||
void Bcast(ElemType* pData, size_t nData, size_t srcRank)
|
||||
{
|
||||
if ((NumNodesInUse() > 1) && (Communicator() != MPI_COMM_NULL))
|
||||
|
@ -247,3 +247,5 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
}
|
||||
};
|
||||
}}}
|
||||
|
||||
extern Microsoft::MSR::CNTK::MPIWrapper *g_mpi;
|
|
@ -9,6 +9,16 @@
|
|||
#define __UNIX__
|
||||
#endif
|
||||
|
||||
// ===========================================================================
|
||||
// stuff to avoid compiler warnings
|
||||
// ===========================================================================
|
||||
|
||||
#ifdef _MSC_VER
|
||||
#define __declspec_noreturn __declspec(noreturn)
|
||||
#else
|
||||
#define __declspec_noreturn
|
||||
#endif
|
||||
|
||||
// ===========================================================================
|
||||
// emulation of some MSVC proprietary CRT
|
||||
// ===========================================================================
|
||||
|
@ -53,10 +63,6 @@ typedef void* HANDLE;
|
|||
#define VOID void
|
||||
#define CONST const
|
||||
|
||||
//standard library conversion
|
||||
//#define min std::min
|
||||
#define hash_map unordered_map
|
||||
|
||||
//macro conversion
|
||||
#define __forceinline inline
|
||||
//string and io conversion
|
||||
|
|
|
@ -0,0 +1,513 @@
|
|||
// BrainScriptObjects.h -- objects that the config parser operates on
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "Basics.h"
|
||||
|
||||
#include <memory> // for shared_ptr<>
|
||||
#include <functional> // for function<>
|
||||
|
||||
namespace Microsoft { namespace MSR { namespace ScriptableObjects {
|
||||
|
||||
using namespace std;
|
||||
using namespace msra::strfun; // for wstrprintf()
|
||||
using namespace Microsoft::MSR::CNTK; // for stuff from Basics.h
|
||||
|
||||
// -----------------------------------------------------------------------
|
||||
// ScriptingError -- base class for any errors thrown by scripting
|
||||
// It's a runtime_error with an additional virtual function PrintError().
|
||||
// -----------------------------------------------------------------------
|
||||
|
||||
class ScriptingError : public runtime_error
|
||||
{
|
||||
public:
|
||||
template<typename M>
|
||||
ScriptingError(const M & msg) : runtime_error(msg) { }
|
||||
virtual void PrintError() const = 0;
|
||||
};
|
||||
|
||||
// -----------------------------------------------------------------------
|
||||
// Object -- common base class for objects that can be used in config files
|
||||
// -----------------------------------------------------------------------
|
||||
|
||||
// All values that can be used in config files
|
||||
// - are heap objects
|
||||
// - primitives are wrapped
|
||||
// - object pointers are ref-counted shared_ptr, wrapped in ConfigValuePtr (see BrainScriptEvaluator.h)
|
||||
// - derive from Object (outside classes get wrapped)
|
||||
//
|
||||
// This code supports three kinds of value types:
|
||||
// - self-defined classes -> derive from Object, e.g. Expression
|
||||
// - classes defined outside -> wrap in a BoxOf object, e.g. String = BoxOf<wstring>
|
||||
// - C++ primitives like 'double' -> wrap in a Wrapper first then in a BoxOf, e.g. Number = BoxOf<Wrapped<double>>
|
||||
|
||||
struct Object { virtual ~Object() { } };
|
||||
|
||||
// indicates that the object has a name should be set from the expression path
|
||||
|
||||
struct HasName { virtual void SetName(const wstring & name) = 0; };
|
||||
|
||||
// -----------------------------------------------------------------------
|
||||
// Wrapped<T> -- wraps non-class primitive C++ type into a class, like 'double'.
|
||||
// (It can also be used for class types, but better use BoxOf<> below directly.)
|
||||
// -----------------------------------------------------------------------
|
||||
|
||||
template<typename T> class Wrapped
|
||||
{
|
||||
T value; // meant to be a primitive type
|
||||
public:
|
||||
operator const T&() const { return value; }
|
||||
operator T&() { return value; }
|
||||
Wrapped(T value) : value(value) { }
|
||||
T & operator=(const T & newValue) { value = newValue; }
|
||||
};
|
||||
typedef Wrapped<double> Double;
|
||||
typedef Wrapped<bool> Bool;
|
||||
|
||||
// -----------------------------------------------------------------------
|
||||
// BoxOf<T> -- wraps a pre-defined type, e.g. std::wstring, to derive from Object.
|
||||
// BoxOf<T> can dynamic_cast to T (e.g. BoxOf<wstring> is a wstring).
|
||||
// -----------------------------------------------------------------------
|
||||
|
||||
template<class C>
|
||||
class BoxOf : public Object, public C
|
||||
{
|
||||
public:
|
||||
#if 1
|
||||
template<class... _Types> BoxOf(_Types&&... _Args) : C(forward<_Types>(_Args)...) { }
|
||||
#else
|
||||
// TODO: change this to variadic templates, then we can instantiate everything we need through this
|
||||
BoxOf(const C & val) : C(val) { }
|
||||
BoxOf(){}
|
||||
#endif
|
||||
};
|
||||
|
||||
// -----------------------------------------------------------------------
|
||||
// String -- a string in config files
|
||||
// Can cast to wstring (done in a way that ConfigValuePtr can also cast to wstring).
|
||||
// -----------------------------------------------------------------------
|
||||
|
||||
typedef BoxOf<wstring> String;
|
||||
|
||||
// -----------------------------------------------------------------------
|
||||
// ComputationNodeObject -- the 'magic' class that our parser understands for infix operations
|
||||
// TODO: unify with ComputationNodeBase
|
||||
// -----------------------------------------------------------------------
|
||||
|
||||
class ComputationNodeObject : public Object { }; // a base class for all nodes (that has no template parameter)
|
||||
|
||||
// -----------------------------------------------------------------------
|
||||
// HasToString -- trait to indicate an object can print their content
|
||||
// Derive from HasToString() and implement ToString() method.
|
||||
// FormatConfigValue() will then return ToString().
|
||||
// -----------------------------------------------------------------------
|
||||
|
||||
struct HasToString
|
||||
{
|
||||
virtual wstring ToString() const = 0;
|
||||
|
||||
// some string helpers useful for ToString() operations of nested structures
|
||||
// TODO: move these out from this header into some more general place (I had to move them here because otherwise CNTKEval failed to compile)
|
||||
static wstring IndentString(wstring s, size_t indent)
|
||||
{
|
||||
const wstring prefix(indent, L' ');
|
||||
size_t pos = 0;
|
||||
for (;;)
|
||||
{
|
||||
s.insert(pos, prefix);
|
||||
pos = s.find(L'\n', pos + 2);
|
||||
if (pos == wstring::npos)
|
||||
return s;
|
||||
pos++;
|
||||
}
|
||||
}
|
||||
static wstring NestString(wstring s, wchar_t open, bool newline, wchar_t close)
|
||||
{
|
||||
wstring result = IndentString(s, 2);
|
||||
if (newline) // have a new line after the open symbol
|
||||
result = L" \n" + result + L"\n ";
|
||||
else
|
||||
result.append(L" ");
|
||||
result.front() = open;
|
||||
result.back() = close;
|
||||
return result;
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
// -----------------------------------------------------------------------
|
||||
// WithTag -- trait to give an object a tag string
|
||||
// -----------------------------------------------------------------------
|
||||
|
||||
class WithTag
|
||||
{
|
||||
wstring m_tag;
|
||||
public:
|
||||
WithTag(){}
|
||||
void SetTag(const wstring & tag) { m_tag = tag; }
|
||||
const wstring & GetTag() const { return m_tag; }
|
||||
};
|
||||
|
||||
// =======================================================================
|
||||
// ConfigValuePtr -- shared pointer to a config value
|
||||
// =======================================================================
|
||||
|
||||
// A ConfigValuePtr holds the value of a configuration variable.
|
||||
// - specifically, it holds a shared_ptr to a strongly typed C++ object
|
||||
// - ConfigValuePtrs are immutable when consumed.
|
||||
//
|
||||
// All configuration values, that is, values that can be held by a ConfigValuePtr, derive from BS::Object.
|
||||
// To get a shared_ptr<T> of an expected type T, type-cast the ConfigValuePtr to it.
|
||||
// To get the value of a copyable type like T=double or wstring, type-cast to T directly.
|
||||
//
|
||||
// ConfigValuePtrs are evaluated on-demand upon first retrieval:
|
||||
// - initially, a ConfigValuePtr would hold a Thunk; that is, a lambda that computes (resolves) the value
|
||||
// - upon first use, the Thunk is invoked to compute the value, which will then *replace* the Thunk
|
||||
// - any consumer of a ConfigValuePtr will only ever see the resolved value, since any access for consumption will force it to be resolved
|
||||
// - a resolved ConfigValuePtr is immutable
|
||||
//
|
||||
// On-demand evaluation is critical to the semantics of this entire configuration system.
|
||||
// A configuration is but one big expression (of nested records), but some evaluations cause side effects (such as saving a model), and some expressions may not even be in use at all.
|
||||
// Thus, we must use on-demand evaluation in order to ensure that side effects are only executed when desired.
|
||||
//
|
||||
// Further, to ensure a Thunk is executed at most once (otherwise we may get the same side-effect multiple times),
|
||||
// an unresolved ConfigValuePtr can only live in a single place. This means,
|
||||
// - an unresolved ConfigValuePtr (i.e. one holding a Thunk) cannot be copied (while resolved ones are immutable and can be copied freely)
|
||||
// - it can be moved (std::move()) during creation
|
||||
// - after creation, it should only live in a known location from which it can be retrieved; specifically:
|
||||
// - ConfigRecord entries
|
||||
// - ConfigArrays elements
|
||||
// - ConfigLambdas (default values of named arguments)
|
||||
|
||||
// TODO: separate this out from BrainScript to an interface that still does type casts--possible?
|
||||
class ConfigValuePtr : public shared_ptr<Object>
|
||||
{
|
||||
function<void(const wstring &)> failfn; // function to call in case of failure due to this value
|
||||
wstring expressionName; // the expression name reflects the path to reach this expression in the (possibly dynamically macro-expanded) expression tree. Used for naming ComputationNodes.
|
||||
|
||||
// Thunk for resolving a value. This Object represents a function that returns a ConfigValuePtr; call to resolve a deferred value
|
||||
class Thunk : public Object
|
||||
{
|
||||
function<ConfigValuePtr()> f; // the function to compute the value
|
||||
bool currentlyResolving; // set during resolution phase, to detect circular references
|
||||
function<void(const wstring &)> failfn; // function to call in case of failure due to this value
|
||||
public:
|
||||
Thunk(function<ConfigValuePtr()> f, const function<void(const wstring &)> & failfn) : f(f), failfn(failfn), currentlyResolving(false) { }
|
||||
ConfigValuePtr ResolveValue()
|
||||
{
|
||||
if (currentlyResolving) // detect circular references (infinite recursion)
|
||||
failfn(L"circular reference (expression to compute identifier's value uses the identifier's value)");
|
||||
currentlyResolving = true; // can't run from inside ourselves
|
||||
return f();
|
||||
// no need to reset currentlyResolving because this object gets replaced and thus deleted anyway
|
||||
}
|
||||
};
|
||||
Thunk * GetThunk() const { return dynamic_cast<Thunk*>(get()); } // get Thunk object or nullptr if already resolved
|
||||
public:
|
||||
|
||||
// --- assignment and copy/move constructors
|
||||
|
||||
ConfigValuePtr() {} // (formally needed somehow)
|
||||
ConfigValuePtr(const shared_ptr<Object> & p, const function<void(const wstring &)> & failfn, const wstring & expressionName) : shared_ptr<Object>(p), failfn(failfn), expressionName(expressionName) { }
|
||||
//ConfigValuePtr(const function<ConfigValuePtr()> & f, TextLocation location, const wstring & expressionName) : shared_ptr<Object>(make_shared<Thunk>(f, location)), location(location), expressionName(expressionName) { }
|
||||
static ConfigValuePtr MakeThunk(const function<ConfigValuePtr()> & f, const function<void(const wstring &)> & failfn, const wstring & expressionName)
|
||||
{
|
||||
return ConfigValuePtr(make_shared<Thunk>(f, failfn), failfn, expressionName);
|
||||
}
|
||||
// TODO: somehow the constructor overload from Thunk function fails to compile, so for now use MakeThunk instead
|
||||
|
||||
ConfigValuePtr(const ConfigValuePtr & other) { *this = other; }
|
||||
ConfigValuePtr(ConfigValuePtr && other) { *this = move(other); }
|
||||
void operator=(const ConfigValuePtr & other)
|
||||
{
|
||||
if (other.GetThunk()) // unresolved ConfigValuePtrs are not copyable, only movable
|
||||
Microsoft::MSR::CNTK::LogicError("ConfigValuePtr::operator=() on unresolved object; ConfigValuePtr is not assignable until resolved");
|
||||
(shared_ptr<Object>&)*this = other;
|
||||
failfn = other.failfn;
|
||||
expressionName = other.expressionName;
|
||||
}
|
||||
void operator=(ConfigValuePtr && other)
|
||||
{
|
||||
failfn = move(other.failfn);
|
||||
expressionName = move(other.expressionName);
|
||||
(shared_ptr<Object>&)*this = move(other);
|
||||
}
|
||||
void Fail(const wstring & msg) const { failfn(msg); }
|
||||
const function<void(const wstring &)> & GetFailFn() const { return failfn; } // if you need to pass on the fail function
|
||||
|
||||
// --- retrieving values by type cast
|
||||
|
||||
// access as a reference, that is, as a shared_ptr<T> --use this for Objects
|
||||
template<typename T> operator shared_ptr<T>() const { return AsPtr<T>(); }
|
||||
// access as a (const & to) value --use this for primitive types (also works to get a const wstring & from a String)
|
||||
template<typename T> operator T() const { return AsRef<T>(); }
|
||||
// Linux gcc barfs on this ^^ for 'us = (double)((wstring)arg).size();' due to some ambiguity error (while it works fine with Visual Studio).
|
||||
// If you encounter this, instead say 'us = (double)((wstring&)arg).size();' with a &
|
||||
operator double() const { return AsRef<Double>(); }
|
||||
operator float() const { return (float) AsRef<Double>(); }
|
||||
operator bool() const { return AsRef<Bool>(); }
|
||||
template<typename INT> INT AsInt() const
|
||||
{
|
||||
double val = AsRef<Double>();
|
||||
INT ival = (INT)val;
|
||||
const wchar_t * type = L"size_t";
|
||||
const char * t = typeid(INT).name(); t;
|
||||
// TODO: there is some duplication of type checking; can we unify that?
|
||||
if (ival != val)
|
||||
Fail(wstrprintf(L"expected expression of type %ls instead of floating-point value %f", type, val));
|
||||
return ival;
|
||||
}
|
||||
operator size_t() const { return AsInt<size_t>(); }
|
||||
operator int() const { return AsInt<int>(); }
|
||||
|
||||
// --- access functions
|
||||
|
||||
template<class C>
|
||||
bool Is() const
|
||||
{
|
||||
EnsureIsResolved();
|
||||
const auto p = dynamic_cast<C*>(get());
|
||||
return p != nullptr;
|
||||
}
|
||||
template<class C>
|
||||
const C & AsRef() const // returns reference to what the 'value' member. Configs are considered immutable, so return a const&
|
||||
{
|
||||
// TODO: factor these lines into a separate function
|
||||
// Note: since this returns a reference into 'this', you must keep the object you call this on around as long as you use the returned reference
|
||||
EnsureIsResolved();
|
||||
//const C * wanted = (C *) nullptr; const auto * got = get(); wanted; got; // allows to see C in the debugger
|
||||
const auto p = dynamic_cast<C*>(get());
|
||||
if (p == nullptr) // TODO: can we make this look the same as TypeExpected in BrainScriptEvaluator.cpp? We'd need the type name
|
||||
Fail(L"config member has wrong type (" + msra::strfun::utf16(typeid(*get()).name()) + L"), expected a " + TypeId<C>());
|
||||
return *p;
|
||||
}
|
||||
template<class C>
|
||||
shared_ptr<C> AsPtr() const // returns a shared_ptr cast to the 'value' member
|
||||
{
|
||||
EnsureIsResolved();
|
||||
const auto p = dynamic_pointer_cast<C>(*this);
|
||||
if (!p) // TODO: can we make this look the same as TypeExpected in BrainScriptEvaluator.cpp? We'd need the type name
|
||||
Fail(L"config member has wrong type (" + msra::strfun::utf16(typeid(*get()).name()) + L"), expected a " + TypeId<C>());
|
||||
return p;
|
||||
}
|
||||
|
||||
// --- properties
|
||||
|
||||
const char * TypeName() const { return typeid(*get()).name(); }
|
||||
const wstring & GetExpressionName() const{ return expressionName; }
|
||||
// TODO: ^^ it seems by saving the name in the ConfigValuePtr itself, we don't gain anything; maybe remove again in the future
|
||||
|
||||
// --- methods for resolving the value
|
||||
|
||||
const ConfigValuePtr & ResolveValue() const // (this is const but mutates the value if it resolves)
|
||||
{
|
||||
// call this when a a member might be as-of-yet unresolved, to evaluate it on-demand
|
||||
// get() is a pointer to a Thunk in that case, that is, a function object that yields the value
|
||||
const auto thunkp = GetThunk(); // is it a Thunk?
|
||||
if (thunkp) // value is a Thunk: we need to resolve
|
||||
{
|
||||
const auto value = thunkp->ResolveValue(); // completely replace ourselves with the actual result. This also releases the Thunk object
|
||||
const_cast<ConfigValuePtr&>(*this) = value;
|
||||
ResolveValue(); // allow it to return another Thunk...
|
||||
}
|
||||
return *this; // return ourselves so we can access a value as p_resolved = p->ResolveValue()
|
||||
}
|
||||
void EnsureIsResolved() const
|
||||
{
|
||||
if (GetThunk())
|
||||
Microsoft::MSR::CNTK::LogicError("ConfigValuePtr: unexpected access to unresolved object; ConfigValuePtrs can only be accessed after resolution");
|
||||
}
|
||||
}; // ConfigValuePtr
|
||||
|
||||
// use this for primitive values, double and bool
|
||||
template<typename T> static inline ConfigValuePtr MakePrimitiveConfigValuePtr(const T & val, const function<void(const wstring &)> & failfn, const wstring & exprPath)
|
||||
{
|
||||
return ConfigValuePtr(make_shared<BoxOf<Wrapped<T>>>(val), failfn, exprPath);
|
||||
}
|
||||
|
||||
// -----------------------------------------------------------------------
|
||||
// IConfigRecord -- config record
|
||||
// Inside BrainScript, this would be a BS::ConfigRecord, but outside of the
|
||||
// evaluator, we will only pass it through this interface, to allow for
|
||||
// extensibility (e.g. Python interfacing).
|
||||
// Also, Objects themselves can expose this interface to make something accessible.
|
||||
// -----------------------------------------------------------------------
|
||||
|
||||
struct IConfigRecord // any class that exposes config can derive from this
|
||||
{
|
||||
virtual const ConfigValuePtr & operator[](const wstring & id) const = 0; // e.g. confRec[L"message"]
|
||||
virtual const ConfigValuePtr * Find(const wstring & id) const = 0; // returns nullptr if not found
|
||||
virtual vector<wstring> GetMemberIds() const = 0; // returns the names of all members in this record (but not including parent scopes)
|
||||
};
|
||||
typedef shared_ptr<struct IConfigRecord> IConfigRecordPtr;
|
||||
|
||||
|
||||
// -----------------------------------------------------------------------
|
||||
// ConfigRecord -- collection of named config values
|
||||
// -----------------------------------------------------------------------
|
||||
|
||||
class ConfigRecord : public Object, public IConfigRecord // all configuration arguments to class construction, resolved into ConfigValuePtrs
|
||||
{
|
||||
function<void(const wstring &)> failfn; // function to call in case of failure due to this value
|
||||
// change to ContextInsensitiveMap<ConfigValuePtr>
|
||||
map<wstring, ConfigValuePtr> members;
|
||||
IConfigRecordPtr parentScope; // we look up the chain
|
||||
ConfigRecord() { } // forbidden (private) to instantiate without a scope
|
||||
public:
|
||||
|
||||
// --- creation phase
|
||||
|
||||
ConfigRecord(IConfigRecordPtr parentScope, const function<void(const wstring &)> & failfn) : parentScope(parentScope), failfn(failfn) { }
|
||||
void Add(const wstring & id, const function<void(const wstring &)> & failfn, const ConfigValuePtr & value) { members[id] = value; failfn; }
|
||||
void Add(const wstring & id, const function<void(const wstring &)> & failfn, ConfigValuePtr && value) { members[id] = move(value); failfn; } // use this for unresolved ConfigPtrs
|
||||
// TODO: Add() does not yet correctly handle the failfn. It is meant to flag the location of the variable identifier
|
||||
|
||||
// --- usage phase
|
||||
|
||||
// regular lookup: just use record[id] or record(id, L"helpful message what 'id' does")
|
||||
// Any unresolved value is resolved at this time, as it is being consumed. Only after resolving a ConfigValuePtr, it can be copied.
|
||||
const ConfigValuePtr & /*IConfigRecord::*/operator[](const wstring & id) const // e.g. confRec[L"name"]
|
||||
{
|
||||
const auto memberIter = members.find(id);
|
||||
if (memberIter != members.end())
|
||||
return memberIter->second.ResolveValue(); // resolve upon access
|
||||
if (!parentScope) // not found: if at top scope, we fail
|
||||
failfn(L"required parameter '" + id + L"' not found");
|
||||
// The failfn will report the location where the dictionary itself was formed.
|
||||
// This is because this function is meant to be used by C++ code.
|
||||
// When we look up a name by a BrainScript ".FIELD" expression, we will use Find() so we can report the error for the offending FIELD itself.
|
||||
return (*parentScope)[id]; // have parent: look it up there
|
||||
}
|
||||
const ConfigValuePtr * /*IConfigRecord::*/Find(const wstring & id) const // returns nullptr if not found
|
||||
{
|
||||
auto memberIter = members.find(id);
|
||||
if (memberIter == members.end())
|
||||
if (parentScope)
|
||||
return parentScope->Find(id);
|
||||
else
|
||||
return nullptr;
|
||||
else
|
||||
return &memberIter->second.ResolveValue();
|
||||
}
|
||||
// get member ids; use this when you intend to consume all record entries and do not know the names
|
||||
// Note that unlike Find() and operator[], which return parent matches, this only returns entries in this record.
|
||||
virtual vector<wstring> /*IConfigRecord::*/GetMemberIds() const
|
||||
{
|
||||
vector<wstring> ids;
|
||||
for (auto & member : members)
|
||||
ids.push_back(member.first);
|
||||
return ids;
|
||||
}
|
||||
};
|
||||
typedef shared_ptr<ConfigRecord> ConfigRecordPtr;
|
||||
// TODO: can ConfigRecordPtr be IConfigRecordPtr?
|
||||
|
||||
// create a runtime object from its type --general case
|
||||
// There can be specializations of this that instantiate objects that do not take ConfigRecords or involve mapping like ComputationNode.
|
||||
template<typename C>
|
||||
shared_ptr<Object> MakeRuntimeObject(const IConfigRecordPtr config)
|
||||
{
|
||||
return make_shared<C>(config);
|
||||
}
|
||||
|
||||
// -----------------------------------------------------------------------
|
||||
// ConfigArray -- an array of config values
|
||||
// -----------------------------------------------------------------------
|
||||
|
||||
// an array is just a vector of config values
|
||||
class ConfigArray : public Object
|
||||
{
|
||||
vector<ConfigValuePtr> values;
|
||||
int firstIndex;
|
||||
public:
|
||||
ConfigArray() : firstIndex(0) { }
|
||||
ConfigArray(int firstIndex, vector<ConfigValuePtr> && values) : firstIndex(firstIndex), values(move(values)) { }
|
||||
pair<int, int> GetIndexRange() const { return make_pair(firstIndex, firstIndex+(int)values.size()-1); }
|
||||
// building the array from expressions: append an element or an array
|
||||
void Append(ConfigValuePtr value) { values.push_back(value); }
|
||||
void Append(const ConfigArray & other) { values.insert(values.end(), other.values.begin(), other.values.end()); }
|
||||
// get element at index, including bounds check
|
||||
template<typename FAILFN>
|
||||
const ConfigValuePtr & At(int index, const FAILFN & failfn/*should report location of the index*/) const
|
||||
{
|
||||
if (index < firstIndex || index >= firstIndex + values.size())
|
||||
failfn(L"index out of bounds");
|
||||
return values[(size_t)(index - firstIndex)].ResolveValue(); // resolve upon access
|
||||
}
|
||||
};
|
||||
typedef shared_ptr<ConfigArray> ConfigArrayPtr;
|
||||
|
||||
// -----------------------------------------------------------------------
|
||||
// ConfigLambda -- a lambda
|
||||
// -----------------------------------------------------------------------
|
||||
|
||||
class ConfigLambda : public Object
|
||||
{
|
||||
public:
|
||||
typedef map<wstring, ConfigValuePtr> NamedParams; // TODO: maybe even not use a typedef, just use the type
|
||||
private:
|
||||
// the function itself is a C++ lambda
|
||||
function<ConfigValuePtr(vector<ConfigValuePtr> &&, NamedParams &&, const wstring & exprName)> f;
|
||||
// inputs. This defines the interface to the function. Very simple in our case though.
|
||||
// We pass rvalue references because that allows to pass Thunks.
|
||||
vector<wstring> paramNames; // #parameters and parameter names (names are used for naming expressions only)
|
||||
NamedParams namedParams; // lists named parameters with their default values. Named parameters are optional and thus always must have a default.
|
||||
public:
|
||||
template<typename F>
|
||||
ConfigLambda(vector<wstring> && paramNames, NamedParams && namedParams, const F & f) : paramNames(move(paramNames)), namedParams(move(namedParams)), f(f) { }
|
||||
size_t GetNumParams() const { return paramNames.size(); }
|
||||
const vector<wstring> & GetParamNames() const { return paramNames; } // used for expression naming
|
||||
// what this function does is call f() held in this object with the given arguments except optional arguments are verified and fall back to their defaults if not given
|
||||
// The arguments are rvalue references, which allows us to pass Thunks, which is important to allow stuff with circular references like CNTK's DelayedNode.
|
||||
ConfigValuePtr Apply(vector<ConfigValuePtr> && args, NamedParams && namedArgs, const wstring & exprName)
|
||||
{
|
||||
NamedParams actualNamedArgs;
|
||||
// actualNamedArgs is a filtered version of namedArgs that contains all optional args listed in namedParams,
|
||||
// falling back to their default if not given in namedArgs.
|
||||
// On the other hand, any name in namedArgs that is not found in namedParams should be rejected.
|
||||
for (const auto & namedParam : namedParams)
|
||||
{
|
||||
const auto & id = namedParam.first; // id of expected named parameter
|
||||
const auto valuei = namedArgs.find(id); // was such parameter passed?
|
||||
if (valuei == namedArgs.end()) // named parameter not passed
|
||||
{ // if not given then fall back to default
|
||||
auto f = [&namedParam]() // we pass a lambda that resolves it upon first use, in our original location
|
||||
{
|
||||
return namedParam.second.ResolveValue();
|
||||
};
|
||||
actualNamedArgs[id] = move(ConfigValuePtr::MakeThunk(f, namedParam.second.GetFailFn(), exprName));
|
||||
}
|
||||
else // named parameter was passed
|
||||
actualNamedArgs[id] = move(valuei->second); // move it, possibly remaining unresolved
|
||||
// BUGBUG: we should pass in the location of the identifier, not that of the expression
|
||||
}
|
||||
for (const auto & namedArg : namedArgs) // make sure there are no extra named args that the macro does not take
|
||||
if (namedParams.find(namedArg.first) == namedParams.end())
|
||||
namedArg.second.Fail(L"function does not have an optional argument '" + namedArg.first + L"'");
|
||||
return f(move(args), move(actualNamedArgs), exprName);
|
||||
}
|
||||
// TODO: define an overload that takes const & for external users (which will then take a copy and pass it on to Apply &&)
|
||||
};
|
||||
typedef shared_ptr<ConfigLambda> ConfigLambdaPtr;
|
||||
|
||||
// -----------------------------------------------------------------------
|
||||
// ConfigurableRuntimeType -- interface to scriptable runtime types
|
||||
// -----------------------------------------------------------------------
|
||||
|
||||
// helper for configurableRuntimeTypes initializer below
|
||||
// This returns a ConfigurableRuntimeType info structure that consists of
|
||||
// - a lambda that is a constructor for a given runtime type and
|
||||
// - a bool saying whether T derives from IConfigRecord
|
||||
struct ConfigurableRuntimeType // TODO: rename to ScriptableObjects::Factory or something like that
|
||||
{
|
||||
bool isConfigRecord; // exposes IConfigRecord --in this case the expression name is computed differently, namely relative to this item
|
||||
// TODO: is this ^^ actually still used anywhere?
|
||||
function<shared_ptr<Object>(const IConfigRecordPtr)> construct; // lambda to construct an object of this class
|
||||
// TODO: we should pass the expression name to construct() as well
|
||||
};
|
||||
|
||||
// scriptable runtime types must be exposed by this function
|
||||
// TODO: should this be a static member of above class?
|
||||
const ConfigurableRuntimeType * FindExternalRuntimeTypeInfo(const wstring & typeId);
|
||||
|
||||
}}} // end namespaces
|
|
@ -355,75 +355,55 @@ public:
|
|||
// understood. Also, braces in strings are not protected. [fseide]
|
||||
static std::string::size_type FindBraces(const std::string& str, std::string::size_type tokenStart)
|
||||
{
|
||||
// open braces and quote
|
||||
static const std::string openBraces = OPENBRACES;
|
||||
|
||||
// close braces and quote
|
||||
static const std::string closingBraces = CLOSINGBRACES;
|
||||
|
||||
const auto len = str.length();
|
||||
|
||||
// start is outside (or rather, at end of string): no brace here
|
||||
if (tokenStart >= len) {
|
||||
return npos;
|
||||
}
|
||||
|
||||
auto braceFound = openBraces.find(str[tokenStart]);
|
||||
// open braces and quote
|
||||
static const std::string openBraces = OPENBRACES;
|
||||
// close braces and quote
|
||||
static const std::string closingBraces = CLOSINGBRACES;
|
||||
|
||||
const auto charsToLookFor = closingBraces + openBraces; // all chars we match for
|
||||
|
||||
// get brace index for first character of input string
|
||||
const auto braceFound = openBraces.find(str[tokenStart]);
|
||||
// no brace present at tokenStart
|
||||
if (braceFound == npos) {
|
||||
if (braceFound == npos)
|
||||
return npos;
|
||||
}
|
||||
|
||||
// string begins with a brace--find the closing brace, while correctly handling nested braces
|
||||
std::vector<std::string::size_type> bracesFound;
|
||||
std::string::size_type current, opening;
|
||||
|
||||
current = opening = tokenStart;
|
||||
|
||||
// create a brace pair for string searches
|
||||
std::string braces;
|
||||
braces += openBraces[braceFound];
|
||||
braces += closingBraces[braceFound];
|
||||
|
||||
std::string braceStack; // nesting stack; .back() is closing symbol for inner-most brace
|
||||
braceStack.push_back(closingBraces[braceFound]); // closing symbol for current
|
||||
// search for end brace or other nested layers of this brace type
|
||||
while (current != npos && current + 1 < len)
|
||||
for (auto current = tokenStart; current + 1 < len;)
|
||||
{
|
||||
current = str.find_first_of(braces, current + 1);
|
||||
// check for a nested opening brace
|
||||
if (current == npos)
|
||||
{
|
||||
// look for closing brace and also for another opening brace
|
||||
// Inside strings we only accept the closing quote, and ignore any braces inside.
|
||||
current = str.find_first_of(braceStack.back() == '"' ? "\"" : charsToLookFor, current + 1); //
|
||||
if (current == string::npos) // none found: done or error
|
||||
break;
|
||||
}
|
||||
|
||||
// found a closing brace
|
||||
if (str[current] == braces[1])
|
||||
char brace = str[current];
|
||||
// found the expected closing brace?
|
||||
if (brace == braceStack.back())
|
||||
{
|
||||
// no braces on the stack, we are done
|
||||
if (bracesFound.empty())
|
||||
{
|
||||
braceStack.pop_back(); // yes: pop up and continue (or stop if stack is empty)
|
||||
if (braceStack.empty()) // fully closed: done
|
||||
return current;
|
||||
}
|
||||
|
||||
// have braces on the stack, pop the current one off
|
||||
opening = bracesFound.back();
|
||||
bracesFound.pop_back();
|
||||
}
|
||||
// or any other closing brace? That's an error.
|
||||
else if (closingBraces.find(brace) != string::npos)
|
||||
RuntimeError("unmatched bracket found in parameters");
|
||||
// found another opening brace, push it on the stack
|
||||
else
|
||||
{
|
||||
// found another opening brace, push it on the stack
|
||||
bracesFound.push_back(opening);
|
||||
opening = current;
|
||||
const auto braceFound = openBraces.find(brace); // index of brace
|
||||
braceStack.push_back(closingBraces[braceFound]); // closing symbol for current
|
||||
}
|
||||
}
|
||||
|
||||
// if we found unmatched parenthesis, throw an exception
|
||||
if (opening != npos)
|
||||
{
|
||||
RuntimeError("unmatched bracket found in parameters");
|
||||
}
|
||||
|
||||
return current;
|
||||
// hit end before everything was closed: error
|
||||
RuntimeError("no closing bracket found in parameters");
|
||||
}
|
||||
|
||||
// ParseValue - virtual function to parse a "token" as tokenized by Parse() below.
|
||||
|
@ -981,7 +961,7 @@ public:
|
|||
// ensure that this method was called on a single line (eg, no newline characters exist in 'configLine').
|
||||
if (configLine.find_first_of("\n") != std::string::npos)
|
||||
{
|
||||
throw std::logic_error(
|
||||
LogicError(
|
||||
"\"ResolveVariablesInSingleLine\" shouldn't be called with a string containing a newline character");
|
||||
}
|
||||
|
||||
|
@ -1028,7 +1008,7 @@ public:
|
|||
|
||||
if (varValue.find_first_of("\n") != std::string::npos)
|
||||
{
|
||||
throw std::logic_error(
|
||||
LogicError(
|
||||
"Newline character cannot be contained in the value of a variable which is resolved using $varName$ feature");
|
||||
}
|
||||
|
||||
|
|
|
@ -1545,7 +1545,7 @@ static BOOL ExpandWildcards (wstring path, vector<wstring> & paths)
|
|||
return FALSE; // another error
|
||||
}
|
||||
size_t pos = path.find_last_of (L"\\");
|
||||
if (pos == wstring::npos) throw std::logic_error ("unexpected missing \\ in path");
|
||||
if (pos == wstring::npos) LogicError ("unexpected missing \\ in path");
|
||||
wstring parent = path.substr (0, pos);
|
||||
do
|
||||
{
|
||||
|
|
|
@ -421,7 +421,7 @@ public:
|
|||
|
||||
size_t NumberSlicesInEachRecurrentIter() { return 1 ;}
|
||||
void SetNbrSlicesEachRecurrentIter(const size_t) { };
|
||||
void SetSentenceSegBatch(Matrix<ElemType> &/*sentenceBegin*/, vector<MinibatchPackingFlag>& /*sentenceExistsBeginOrNoLabels*/) {};
|
||||
void SetSentenceSegBatch(Matrix<float> &/*sentenceBegin*/, vector<MinibatchPackingFlag>& /*sentenceExistsBeginOrNoLabels*/) {};
|
||||
|
||||
virtual const std::map<LabelIdType, LabelType>& GetLabelMapping(const std::wstring& sectionName);
|
||||
virtual void SetLabelMapping(const std::wstring& sectionName, const std::map<typename BinaryReader<ElemType>::LabelIdType, typename BinaryReader<ElemType>::LabelType>& labelMapping);
|
||||
|
|
|
@ -72,7 +72,7 @@
|
|||
<Link>
|
||||
<SubSystem>Windows</SubSystem>
|
||||
<GenerateDebugInformation>true</GenerateDebugInformation>
|
||||
<AdditionalDependencies>CNTKMath.lib;kernel32.lib;user32.lib;gdi32.lib;winspool.lib;comdlg32.lib;advapi32.lib;shell32.lib;ole32.lib;oleaut32.lib;uuid.lib;odbc32.lib;odbccp32.lib;%(AdditionalDependencies)</AdditionalDependencies>
|
||||
<AdditionalDependencies>CNTKMathDll.lib;kernel32.lib;user32.lib;gdi32.lib;winspool.lib;comdlg32.lib;advapi32.lib;shell32.lib;ole32.lib;oleaut32.lib;uuid.lib;odbc32.lib;odbccp32.lib;%(AdditionalDependencies)</AdditionalDependencies>
|
||||
<AdditionalLibraryDirectories>$(SolutionDir)$(Platform)\$(Configuration)\;..\..\math\$(Platform)\$(Configuration);..\$(Platform)\$(Configuration)</AdditionalLibraryDirectories>
|
||||
</Link>
|
||||
</ItemDefinitionGroup>
|
||||
|
@ -95,7 +95,7 @@
|
|||
<GenerateDebugInformation>true</GenerateDebugInformation>
|
||||
<EnableCOMDATFolding>true</EnableCOMDATFolding>
|
||||
<OptimizeReferences>true</OptimizeReferences>
|
||||
<AdditionalDependencies>CNTKmath.lib;kernel32.lib;user32.lib;gdi32.lib;winspool.lib;comdlg32.lib;advapi32.lib;shell32.lib;ole32.lib;oleaut32.lib;uuid.lib;odbc32.lib;odbccp32.lib;%(AdditionalDependencies)</AdditionalDependencies>
|
||||
<AdditionalDependencies>CNTKMathDll.lib;kernel32.lib;user32.lib;gdi32.lib;winspool.lib;comdlg32.lib;advapi32.lib;shell32.lib;ole32.lib;oleaut32.lib;uuid.lib;odbc32.lib;odbccp32.lib;%(AdditionalDependencies)</AdditionalDependencies>
|
||||
<AdditionalLibraryDirectories>..\..\math\$(Platform)\$(Configuration);$(SolutionDir)$(Platform)\$(Configuration)\</AdditionalLibraryDirectories>
|
||||
<Profile>true</Profile>
|
||||
</Link>
|
||||
|
|
|
@ -143,7 +143,7 @@ public:
|
|||
|
||||
size_t NumberSlicesInEachRecurrentIter() { return 1 ;}
|
||||
void SetNbrSlicesEachRecurrentIter(const size_t) { };
|
||||
void SetSentenceSegBatch(Matrix<ElemType> &/*sentenceBegin*/, vector<MinibatchPackingFlag>& /*sentenceExistsBeginOrNoLabels*/) {};
|
||||
void SetSentenceSegBatch(Matrix<float> &/*sentenceBegin*/, vector<MinibatchPackingFlag>& /*sentenceExistsBeginOrNoLabels*/) {};
|
||||
|
||||
virtual const std::map<LabelIdType, LabelType>& GetLabelMapping(const std::wstring& sectionName);
|
||||
virtual void SetLabelMapping(const std::wstring& sectionName, const std::map<LabelIdType, typename LabelType>& labelMapping);
|
||||
|
|
|
@ -74,7 +74,7 @@
|
|||
<Link>
|
||||
<SubSystem>Windows</SubSystem>
|
||||
<GenerateDebugInformation>true</GenerateDebugInformation>
|
||||
<AdditionalDependencies>CNTKMath.lib;kernel32.lib;user32.lib;gdi32.lib;winspool.lib;comdlg32.lib;advapi32.lib;shell32.lib;ole32.lib;oleaut32.lib;uuid.lib;odbc32.lib;odbccp32.lib;%(AdditionalDependencies)</AdditionalDependencies>
|
||||
<AdditionalDependencies>CNTKMathDll.lib;kernel32.lib;user32.lib;gdi32.lib;winspool.lib;comdlg32.lib;advapi32.lib;shell32.lib;ole32.lib;oleaut32.lib;uuid.lib;odbc32.lib;odbccp32.lib;%(AdditionalDependencies)</AdditionalDependencies>
|
||||
<AdditionalLibraryDirectories>$(SolutionDir)$(Platform)\$(Configuration)\;..\..\math\$(Platform)\$(Configuration);..\$(Platform)\$(Configuration)</AdditionalLibraryDirectories>
|
||||
</Link>
|
||||
</ItemDefinitionGroup>
|
||||
|
@ -97,7 +97,7 @@
|
|||
<GenerateDebugInformation>true</GenerateDebugInformation>
|
||||
<EnableCOMDATFolding>true</EnableCOMDATFolding>
|
||||
<OptimizeReferences>true</OptimizeReferences>
|
||||
<AdditionalDependencies>CNTKmath.lib;kernel32.lib;user32.lib;gdi32.lib;winspool.lib;comdlg32.lib;advapi32.lib;shell32.lib;ole32.lib;oleaut32.lib;uuid.lib;odbc32.lib;odbccp32.lib;%(AdditionalDependencies)</AdditionalDependencies>
|
||||
<AdditionalDependencies>CNTKMathDll.lib;kernel32.lib;user32.lib;gdi32.lib;winspool.lib;comdlg32.lib;advapi32.lib;shell32.lib;ole32.lib;oleaut32.lib;uuid.lib;odbc32.lib;odbccp32.lib;%(AdditionalDependencies)</AdditionalDependencies>
|
||||
<AdditionalLibraryDirectories>..\..\math\$(Platform)\$(Configuration);$(SolutionDir)$(Platform)\$(Configuration)\</AdditionalLibraryDirectories>
|
||||
<Profile>true</Profile>
|
||||
</Link>
|
||||
|
|
|
@ -10,17 +10,22 @@
|
|||
#include "basetypes.h"
|
||||
|
||||
#include "htkfeatio.h" // for reading HTK features
|
||||
#ifdef _WIN32
|
||||
#include "latticearchive.h" // for reading HTK phoneme lattices (MMI training)
|
||||
#endif
|
||||
#include "simplesenonehmm.h" // for MMI scoring
|
||||
#include "msra_mgram.h" // for unigram scores of ground-truth path in sequence training
|
||||
|
||||
#include "rollingwindowsource.h" // minibatch sources
|
||||
#include "utterancesource.h"
|
||||
#ifdef _WIN32
|
||||
#include "readaheadsource.h"
|
||||
#endif
|
||||
#include "chunkevalsource.h"
|
||||
#define DATAREADER_EXPORTS
|
||||
#include "DataReader.h"
|
||||
#include "HTKMLFReader.h"
|
||||
#include "commandArgUtil.h"
|
||||
|
||||
namespace Microsoft { namespace MSR { namespace CNTK {
|
||||
|
||||
|
@ -38,6 +43,7 @@ extern "C" DATAREADER_API void GetReaderD(IDataReader<double>** preader)
|
|||
{
|
||||
GetReader(preader);
|
||||
}
|
||||
#ifdef _WIN32
|
||||
// Utility function, in ConfigFile.cpp, but HTKMLFReader doesn't need that code...
|
||||
|
||||
// Trim - trim white space off the start and end of the string
|
||||
|
@ -56,6 +62,7 @@ void Trim(std::string& str)
|
|||
if (found != npos)
|
||||
str.erase(found+1);
|
||||
}
|
||||
#endif
|
||||
|
||||
|
||||
}}}
|
|
@ -99,7 +99,7 @@ bool DataWriter<ElemType>::SaveData(size_t recordStart, const std::map<std::wstr
|
|||
// saveId - name of the section to save into (section:subsection format)
|
||||
// labelMapping - map we are saving to the file
|
||||
template<class ElemType>
|
||||
void DataWriter<ElemType>::SaveMapping(std::wstring saveId, const std::map<typename LabelIdType, typename LabelType>& labelMapping)
|
||||
void DataWriter<ElemType>::SaveMapping(std::wstring saveId, const std::map<LabelIdType, LabelType>& labelMapping)
|
||||
{
|
||||
m_dataWriter->SaveMapping(saveId, labelMapping);
|
||||
}
|
||||
|
|
|
@ -7,7 +7,9 @@
|
|||
//
|
||||
|
||||
#include "stdafx.h"
|
||||
#ifdef _WIN32
|
||||
#include <objbase.h>
|
||||
#endif
|
||||
#include "basetypes.h"
|
||||
|
||||
#include "htkfeatio.h" // for reading HTK features
|
||||
|
@ -19,19 +21,34 @@
|
|||
#include "utterancesourcemulti.h"
|
||||
#include "utterancesource.h"
|
||||
#include "utterancesourcemulti.h"
|
||||
#ifdef _WIN32
|
||||
#include "readaheadsource.h"
|
||||
#endif
|
||||
#include "chunkevalsource.h"
|
||||
#include "minibatchiterator.h"
|
||||
#define DATAREADER_EXPORTS // creating the exports here
|
||||
#include "DataReader.h"
|
||||
#include "commandArgUtil.h"
|
||||
#include "HTKMLFReader.h"
|
||||
#ifdef LEAKDETECT
|
||||
#include <vld.h> // for memory leak detection
|
||||
#endif
|
||||
|
||||
#ifdef __unix__
|
||||
#include <limits.h>
|
||||
typedef unsigned long DWORD;
|
||||
typedef unsigned short WORD;
|
||||
typedef unsigned int UNINT32;
|
||||
#endif
|
||||
#pragma warning (disable: 4127) // conditional expression is constant; "if (sizeof(ElemType)==sizeof(float))" triggers this
|
||||
|
||||
#ifdef _WIN32
|
||||
int msra::numa::node_override = -1; // for numahelpers.h
|
||||
#endif
|
||||
|
||||
namespace msra { namespace lm {
|
||||
/*static*/ const mgram_map::index_t mgram_map::nindex = (mgram_map::index_t) -1; // invalid index
|
||||
}}
|
||||
|
||||
namespace Microsoft { namespace MSR { namespace CNTK {
|
||||
|
||||
|
@ -44,7 +61,9 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
m_cudaAllocator = nullptr;
|
||||
m_mbiter = NULL;
|
||||
m_frameSource = NULL;
|
||||
#ifdef _WIN32
|
||||
m_readAheadSource = NULL;
|
||||
#endif
|
||||
m_lattices = NULL;
|
||||
|
||||
m_truncated = readerConfig("Truncated", "false");
|
||||
|
@ -202,7 +221,22 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
m_labelNameToIdMap[labelNames[i]]=iLabel;
|
||||
m_labelNameToDimMap[labelNames[i]]=m_labelDims[i];
|
||||
mlfpaths.clear();
|
||||
mlfpaths.push_back(thisLabel("mlfFile"));
|
||||
if (thisLabel.ExistsCurrent("mlfFile"))
|
||||
{
|
||||
mlfpaths.push_back(thisLabel("mlfFile"));
|
||||
}
|
||||
else
|
||||
{
|
||||
if (!thisLabel.ExistsCurrent("mlfFileList"))
|
||||
{
|
||||
RuntimeError("Either mlfFile or mlfFileList must exist in HTKMLFReder");
|
||||
}
|
||||
wstring list = thisLabel("mlfFileList");
|
||||
for (msra::files::textreader r(list); r;)
|
||||
{
|
||||
mlfpaths.push_back(r.wgetline());
|
||||
}
|
||||
}
|
||||
mlfpathsmulti.push_back(mlfpaths);
|
||||
|
||||
m_labelsBufferMultiIO.push_back(nullptr);
|
||||
|
@ -271,7 +305,9 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
}
|
||||
|
||||
// see if they want to use readAhead
|
||||
#ifdef _WIN32
|
||||
m_readAhead = readerConfig("readAhead", "false");
|
||||
#endif
|
||||
|
||||
// read all input files (from multiple inputs)
|
||||
// TO DO: check for consistency (same number of files in each script file)
|
||||
|
@ -319,7 +355,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
}
|
||||
|
||||
if (readerConfig.Exists("unigram"))
|
||||
unigrampath = readerConfig("unigram");
|
||||
unigrampath = (wstring)readerConfig("unigram");
|
||||
|
||||
// load a unigram if needed (this is used for MMI training)
|
||||
msra::lm::CSymbolSet unigramsymbols;
|
||||
|
@ -361,13 +397,13 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
//std::vector<std::wstring> pagepath;
|
||||
foreach_index(i, mlfpathsmulti)
|
||||
{
|
||||
const msra::lm::CSymbolSet* wordmap = unigram ? &unigramsymbols : NULL;
|
||||
msra::asr::htkmlfreader<msra::asr::htkmlfentry,msra::lattices::lattice::htkmlfwordsequence>
|
||||
labels(mlfpathsmulti[i], restrictmlftokeys, statelistpaths[i], unigram ? &unigramsymbols : NULL, (map<string,size_t>*) NULL, htktimetoframe); // label MLF
|
||||
labels(mlfpathsmulti[i], restrictmlftokeys, statelistpaths[i], wordmap, (map<string,size_t>*) NULL, htktimetoframe); // label MLF
|
||||
// get the temp file name for the page file
|
||||
labelsmulti.push_back(labels);
|
||||
}
|
||||
|
||||
|
||||
if (!_stricmp(readMethod.c_str(),"blockRandomize"))
|
||||
{
|
||||
// construct all the parameters we don't need, but need to be passed to the constructor...
|
||||
|
@ -383,7 +419,11 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
}
|
||||
else if (!_stricmp(readMethod.c_str(),"rollingWindow"))
|
||||
{
|
||||
#ifdef _WIN32
|
||||
std::wstring pageFilePath;
|
||||
#else
|
||||
std::string pageFilePath;
|
||||
#endif
|
||||
std::vector<std::wstring> pagePaths;
|
||||
if (readerConfig.Exists("pageFilePath"))
|
||||
{
|
||||
|
@ -391,28 +431,57 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
|
||||
// replace any '/' with '\' for compat with default path
|
||||
std::replace(pageFilePath.begin(), pageFilePath.end(), '/','\\');
|
||||
|
||||
#ifdef _WIN32
|
||||
// verify path exists
|
||||
DWORD attrib = GetFileAttributes(pageFilePath.c_str());
|
||||
if (attrib==INVALID_FILE_ATTRIBUTES || !(attrib & FILE_ATTRIBUTE_DIRECTORY))
|
||||
throw std::runtime_error ("pageFilePath does not exist");
|
||||
#endif
|
||||
#ifdef __unix__
|
||||
struct stat statbuf;
|
||||
if (stat(pageFilePath.c_str(), &statbuf)==-1)
|
||||
{
|
||||
throw std::runtime_error ("pageFilePath does not exist");
|
||||
}
|
||||
|
||||
#endif
|
||||
}
|
||||
else // using default temporary path
|
||||
{
|
||||
#ifdef _WIN32
|
||||
pageFilePath.reserve(MAX_PATH);
|
||||
GetTempPath(MAX_PATH, &pageFilePath[0]);
|
||||
#endif
|
||||
#ifdef __unix__
|
||||
pageFilePath.reserve(PATH_MAX);
|
||||
pageFilePath = "/tmp/temp.CNTK.XXXXXX";
|
||||
#endif
|
||||
}
|
||||
|
||||
#ifdef _WIN32
|
||||
if (pageFilePath.size()>MAX_PATH-14) // max length of input to GetTempFileName is MAX_PATH-14
|
||||
throw std::runtime_error (msra::strfun::strprintf ("pageFilePath must be less than %d characters", MAX_PATH-14));
|
||||
|
||||
#endif
|
||||
#ifdef __unix__
|
||||
if (pageFilePath.size()>PATH_MAX-14) // max length of input to GetTempFileName is PATH_MAX-14
|
||||
throw std::runtime_error (msra::strfun::strprintf ("pageFilePath must be less than %d characters", PATH_MAX-14));
|
||||
#endif
|
||||
foreach_index(i, infilesmulti)
|
||||
{
|
||||
|
||||
#ifdef _WIN32
|
||||
wchar_t tempFile[MAX_PATH];
|
||||
GetTempFileName(pageFilePath.c_str(), L"CNTK", 0, tempFile);
|
||||
pagePaths.push_back(tempFile);
|
||||
|
||||
#endif
|
||||
#ifdef __unix__
|
||||
char* tempFile;
|
||||
//GetTempFileName(pageFilePath.c_str(), L"CNTK", 0, tempFile);
|
||||
tempFile = (char*) pageFilePath.c_str();
|
||||
int fid = mkstemp(tempFile);
|
||||
unlink (tempFile);
|
||||
close (fid);
|
||||
pagePaths.push_back(GetWC(tempFile));
|
||||
#endif
|
||||
}
|
||||
|
||||
const bool mayhavenoframe=false;
|
||||
|
@ -513,7 +582,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
n++;
|
||||
}
|
||||
|
||||
fprintf (stderr, " %d entries\n", n);
|
||||
fprintf (stderr, " %d entries\n", (int)n);
|
||||
|
||||
if (i==0)
|
||||
numFiles=n;
|
||||
|
@ -534,7 +603,9 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
HTKMLFReader<ElemType>::~HTKMLFReader()
|
||||
{
|
||||
delete m_mbiter;
|
||||
#ifdef _WIN32
|
||||
delete m_readAheadSource;
|
||||
#endif
|
||||
delete m_frameSource;
|
||||
delete m_lattices;
|
||||
|
||||
|
@ -587,7 +658,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
void HTKMLFReader<ElemType>::StartDistributedMinibatchLoop(size_t mbSize, size_t epoch, size_t subsetNum, size_t numSubsets, size_t requestedEpochSamples /*= requestDataSize*/)
|
||||
{
|
||||
assert(subsetNum < numSubsets);
|
||||
assert(this->SupportsDistributedMBRead() || ((subsetNum == 0) && (numSubsets == 1)));
|
||||
assert(((subsetNum == 0) && (numSubsets == 1)) || this->SupportsDistributedMBRead());
|
||||
|
||||
m_mbSize = mbSize;
|
||||
|
||||
|
@ -664,6 +735,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
// delete the old one first (in case called more than once)
|
||||
delete m_mbiter;
|
||||
msra::dbn::minibatchsource* source = m_frameSource;
|
||||
#ifdef _WIN32
|
||||
if (m_readAhead)
|
||||
{
|
||||
if (m_readAheadSource == NULL)
|
||||
|
@ -677,6 +749,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
}
|
||||
source = m_readAheadSource;
|
||||
}
|
||||
#endif
|
||||
m_mbiter = new msra::dbn::minibatchiterator(*source, epoch, requestedEpochSamples, mbSize, subsetNum, numSubsets, datapasses);
|
||||
if (!m_featuresBufferMultiIO.empty())
|
||||
{
|
||||
|
@ -789,7 +862,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
|
||||
// now, access all features and and labels by iterating over map of "matrices"
|
||||
bool first = true;
|
||||
std::map<std::wstring, Matrix<ElemType>*>::iterator iter;
|
||||
typename std::map<std::wstring, Matrix<ElemType>*>::iterator iter;
|
||||
for (iter = matrices.begin();iter!=matrices.end(); iter++)
|
||||
{
|
||||
// dereference matrix that corresponds to key (input/output name) and
|
||||
|
@ -810,9 +883,10 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
|
||||
m_sentenceBegin.SetValue((ElemType) SEQUENCE_MIDDLE);
|
||||
m_sentenceBegin.SetValue(0, 0, (ElemType) SEQUENCE_START);
|
||||
|
||||
m_sentenceBegin.SetValue(0, (size_t)feat.cols()-1, (ElemType) SEQUENCE_END);
|
||||
std::fill(m_minibatchPackingFlag.begin(), m_minibatchPackingFlag.end(), MinibatchPackingFlag::None);
|
||||
m_minibatchPackingFlag[0] = MinibatchPackingFlag::SequenceStart;
|
||||
m_minibatchPackingFlag[(size_t)feat.cols() - 1] = MinibatchPackingFlag::SequenceEnd;
|
||||
first = false;
|
||||
}
|
||||
|
||||
|
@ -969,6 +1043,11 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
{
|
||||
m_sentenceEnd[i] = false;
|
||||
m_switchFrame[i] = m_mbSize+1;
|
||||
if (m_processedFrame[i] == 1)
|
||||
{
|
||||
m_sentenceBegin.SetValue(i, 0, (ElemType)SEQUENCE_END);
|
||||
m_minibatchPackingFlag[0] = MinibatchPackingFlag::SequenceEnd;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
|
@ -979,7 +1058,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
}
|
||||
actualmbsize[i] = m_mbSize;
|
||||
endFr = startFr + actualmbsize[i];
|
||||
std::map<std::wstring, Matrix<ElemType>*>::iterator iter;
|
||||
typename std::map<std::wstring, Matrix<ElemType>*>::iterator iter;
|
||||
for (iter = matrices.begin();iter!=matrices.end(); iter++)
|
||||
{
|
||||
// dereference matrix that corresponds to key (input/output name) and
|
||||
|
@ -1044,7 +1123,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
actualmbsize[i] = m_toProcess[i] - m_processedFrame[i];
|
||||
endFr = startFr + actualmbsize[i];
|
||||
|
||||
std::map<std::wstring, Matrix<ElemType>*>::iterator iter;
|
||||
typename std::map<std::wstring, Matrix<ElemType>*>::iterator iter;
|
||||
for (iter = matrices.begin();iter!=matrices.end(); iter++)
|
||||
{
|
||||
// dereference matrix that corresponds to key (input/output name) and
|
||||
|
@ -1108,6 +1187,11 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
m_sentenceBegin.SetValue(i, actualmbsize[i], (ElemType)SEQUENCE_START);
|
||||
m_minibatchPackingFlag[actualmbsize[i]] |= MinibatchPackingFlag::SequenceStart;
|
||||
}
|
||||
if (actualmbsize[i] == m_mbSize)
|
||||
{
|
||||
m_sentenceBegin.SetValue(i, actualmbsize[i]-1, (ElemType)SEQUENCE_END);
|
||||
m_minibatchPackingFlag[actualmbsize[i]-1] |= MinibatchPackingFlag::SequenceEnd;
|
||||
}
|
||||
startFr = m_switchFrame[i];
|
||||
endFr = m_mbSize;
|
||||
bool reNewSucc = ReNewBufferForMultiIO(i);
|
||||
|
@ -1158,7 +1242,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
|
||||
}
|
||||
}
|
||||
std::map<std::wstring, Matrix<ElemType>*>::iterator iter;
|
||||
typename std::map<std::wstring, Matrix<ElemType>*>::iterator iter;
|
||||
for (iter = matrices.begin();iter!=matrices.end(); iter++)
|
||||
{
|
||||
// dereference matrix that corresponds to key (input/output name) and
|
||||
|
@ -1195,7 +1279,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
{
|
||||
if (matrices.find(iter->first)==matrices.end())
|
||||
{
|
||||
fprintf(stderr,"GetMinibatchToWrite: feature node %ws specified in reader not found in the network\n",iter->first.c_str());
|
||||
fprintf(stderr,"GetMinibatchToWrite: feature node %ls specified in reader not found in the network\n", iter->first.c_str());
|
||||
throw std::runtime_error("GetMinibatchToWrite: feature node specified in reader not found in the network.");
|
||||
}
|
||||
}
|
||||
|
@ -1227,7 +1311,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
{
|
||||
reader.read (path, featkind, sampperiod, feat); // whole file read as columns of feature vectors
|
||||
});
|
||||
fprintf (stderr, "evaluate: reading %d frames of %S\n", feat.cols(), ((wstring)path).c_str());
|
||||
fprintf (stderr, "evaluate: reading %d frames of %S\n", (int)feat.cols(), ((wstring)path).c_str());
|
||||
m_fileEvalSource->AddFile(feat, featkind, sampperiod, i);
|
||||
}
|
||||
m_inputFileIndex++;
|
||||
|
@ -1237,7 +1321,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
|
||||
// populate input matrices
|
||||
bool first = true;
|
||||
std::map<std::wstring, Matrix<ElemType>*>::iterator iter;
|
||||
typename std::map<std::wstring, Matrix<ElemType>*>::iterator iter;
|
||||
for (iter = matrices.begin();iter!=matrices.end(); iter++)
|
||||
{
|
||||
// dereference matrix that corresponds to key (input/output name) and
|
||||
|
@ -1256,9 +1340,10 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
m_minibatchPackingFlag.resize((size_t)feat.cols());
|
||||
m_sentenceBegin.SetValue((ElemType)SEQUENCE_MIDDLE);
|
||||
m_sentenceBegin.SetValue(0, 0, (ElemType)SEQUENCE_START);
|
||||
|
||||
m_sentenceBegin.SetValue(0, (size_t)feat.cols() - 1, (ElemType)SEQUENCE_END);
|
||||
std::fill(m_minibatchPackingFlag.begin(), m_minibatchPackingFlag.end(), MinibatchPackingFlag::None);
|
||||
m_minibatchPackingFlag[0] = MinibatchPackingFlag::SequenceStart;
|
||||
m_minibatchPackingFlag[(size_t)feat.cols() - 1] = MinibatchPackingFlag::SequenceEnd;
|
||||
first = false;
|
||||
}
|
||||
|
||||
|
@ -1557,7 +1642,17 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
}
|
||||
|
||||
template<class ElemType>
|
||||
void HTKMLFReader<ElemType>::SetSentenceSegBatch(Matrix<ElemType> &sentenceBegin, vector<MinibatchPackingFlag>& minibatchPackingFlag)
|
||||
void HTKMLFReader<ElemType>::SetSentenceEndInBatch(vector<size_t> &sentenceEnd)
|
||||
{
|
||||
sentenceEnd.resize(m_switchFrame.size());
|
||||
for (size_t i = 0; i < m_switchFrame.size() ; i++)
|
||||
{
|
||||
sentenceEnd[i] = m_switchFrame[i];
|
||||
}
|
||||
}
|
||||
|
||||
template<class ElemType>
|
||||
void HTKMLFReader<ElemType>::SetSentenceSegBatch(Matrix<float> &sentenceBegin, vector<MinibatchPackingFlag>& minibatchPackingFlag)
|
||||
{
|
||||
if (!m_framemode)
|
||||
{
|
||||
|
@ -1582,7 +1677,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
{
|
||||
features.push_back(msra::strfun::utf16(iter->first));
|
||||
}
|
||||
else if (temp.ExistsCurrent("mlfFile"))
|
||||
else if (temp.ExistsCurrent("mlfFile")|| temp.ExistsCurrent("mlfFileList"))
|
||||
{
|
||||
labels.push_back(msra::strfun::utf16(iter->first));
|
||||
}
|
||||
|
|
|
@ -21,7 +21,9 @@ private:
|
|||
|
||||
msra::dbn::minibatchiterator* m_mbiter;
|
||||
msra::dbn::minibatchsource* m_frameSource;
|
||||
#ifdef _WIN32
|
||||
msra::dbn::minibatchreadaheadsource* m_readAheadSource;
|
||||
#endif
|
||||
msra::dbn::FileEvalSource* m_fileEvalSource;
|
||||
msra::dbn::latticesource* m_lattices;
|
||||
map<wstring,msra::lattices::lattice::htkmlfwordsequence> m_latticeMap;
|
||||
|
@ -39,6 +41,8 @@ private:
|
|||
vector<size_t> m_switchFrame;
|
||||
bool m_noData;
|
||||
bool m_trainOrTest; // if false, in file writing mode
|
||||
using LabelType = typename IDataReader<ElemType>::LabelType;
|
||||
using LabelIdType = typename IDataReader<ElemType>::LabelIdType;
|
||||
|
||||
std::map<LabelIdType, LabelType> m_idToLabelMap;
|
||||
|
||||
|
@ -141,9 +145,29 @@ private:
|
|||
}
|
||||
|
||||
public:
|
||||
Matrix<ElemType> m_sentenceBegin;
|
||||
/// a matrix of n_stream x n_length
|
||||
/// n_stream is the number of streams
|
||||
/// n_length is the maximum lenght of each stream
|
||||
/// for example, two sentences used in parallel in one minibatch would be
|
||||
/// [2 x 5] if the max length of one of the sentences is 5
|
||||
/// the elements of the matrix is 0, 1, or -1, defined as SEQUENCE_START, SEQUENCE_MIDDLE, NO_INPUT in cbasetype.h
|
||||
/// 0 1 1 0 1
|
||||
/// 1 0 1 0 0
|
||||
/// for two parallel data streams. The first has two sentences, with 0 indicating begining of a sentence
|
||||
/// the second data stream has two sentences, with 0 indicating begining of sentences
|
||||
/// you may use 1 even if a sentence begins at that position, in this case, the trainer will carry over hidden states to the following
|
||||
/// frame.
|
||||
Matrix<float> m_sentenceBegin;
|
||||
|
||||
/// a matrix of 1 x n_length
|
||||
/// 1 denotes the case that there exists sentnece begin or no_labels case in this frame
|
||||
/// 0 denotes such case is not in this frame
|
||||
vector<MinibatchPackingFlag> m_minibatchPackingFlag;
|
||||
|
||||
/// by default it is false
|
||||
/// if true, reader will set to SEQUENCE_MIDDLE for time positions that are orignally correspond to SEQUENCE_START
|
||||
/// set to true so that a current minibatch can uses state activities from the previous minibatch.
|
||||
/// default will have truncated BPTT, which only does BPTT inside a minibatch
|
||||
bool mIgnoreSentenceBeginTag;
|
||||
HTKMLFReader() : m_sentenceBegin(CPUDEVICE) {
|
||||
}
|
||||
|
@ -158,18 +182,19 @@ public:
|
|||
|
||||
virtual bool SupportsDistributedMBRead() const override
|
||||
{
|
||||
return m_frameSource->supportsbatchsubsetting();
|
||||
return ((m_frameSource != nullptr) && m_frameSource->supportsbatchsubsetting());
|
||||
}
|
||||
|
||||
virtual void StartDistributedMinibatchLoop(size_t mbSize, size_t epoch, size_t subsetNum, size_t numSubsets, size_t requestedEpochSamples = requestDataSize) override;
|
||||
|
||||
virtual bool GetMinibatch(std::map<std::wstring, Matrix<ElemType>*>& matrices);
|
||||
virtual const std::map<LabelIdType, LabelType>& GetLabelMapping(const std::wstring& sectionName);
|
||||
virtual void SetLabelMapping(const std::wstring& sectionName, const std::map<unsigned, LabelType>& labelMapping);
|
||||
virtual void SetLabelMapping(const std::wstring& sectionName, const std::map<LabelIdType, LabelType>& labelMapping);
|
||||
virtual bool GetData(const std::wstring& sectionName, size_t numRecords, void* data, size_t& dataBufferSize, size_t recordStart=0);
|
||||
|
||||
virtual bool DataEnd(EndDataType endDataType);
|
||||
void SetSentenceSegBatch(Matrix<ElemType> &sentenceBegin, vector<MinibatchPackingFlag>& sentenceExistsBeginOrNoInputs);
|
||||
void SetSentenceSegBatch(Matrix<float> &sentenceBegin, vector<MinibatchPackingFlag>& sentenceExistsBeginOrNoLabels);
|
||||
void SetSentenceEndInBatch(vector<size_t> &/*sentenceEnd*/);
|
||||
void SetSentenceEnd(int /*actualMbSize*/){};
|
||||
void SetRandomSeed(int){ NOT_IMPLEMENTED };
|
||||
|
||||
|
|
|
@ -71,7 +71,7 @@
|
|||
<Link>
|
||||
<SubSystem>Windows</SubSystem>
|
||||
<GenerateDebugInformation>true</GenerateDebugInformation>
|
||||
<AdditionalDependencies>CNTKMath.lib;kernel32.lib;user32.lib;gdi32.lib;winspool.lib;comdlg32.lib;advapi32.lib;shell32.lib;ole32.lib;oleaut32.lib;uuid.lib;odbc32.lib;odbccp32.lib;%(AdditionalDependencies)</AdditionalDependencies>
|
||||
<AdditionalDependencies>CNTKMathDll.lib;kernel32.lib;user32.lib;gdi32.lib;winspool.lib;comdlg32.lib;advapi32.lib;shell32.lib;ole32.lib;oleaut32.lib;uuid.lib;odbc32.lib;odbccp32.lib;%(AdditionalDependencies)</AdditionalDependencies>
|
||||
</Link>
|
||||
</ItemDefinitionGroup>
|
||||
<ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'">
|
||||
|
@ -91,7 +91,7 @@
|
|||
<GenerateDebugInformation>true</GenerateDebugInformation>
|
||||
<EnableCOMDATFolding>true</EnableCOMDATFolding>
|
||||
<OptimizeReferences>true</OptimizeReferences>
|
||||
<AdditionalDependencies>CNTKMath.lib;kernel32.lib;user32.lib;gdi32.lib;winspool.lib;comdlg32.lib;advapi32.lib;shell32.lib;ole32.lib;oleaut32.lib;uuid.lib;odbc32.lib;odbccp32.lib;%(AdditionalDependencies)</AdditionalDependencies>
|
||||
<AdditionalDependencies>CNTKMathDll.lib;kernel32.lib;user32.lib;gdi32.lib;winspool.lib;comdlg32.lib;advapi32.lib;shell32.lib;ole32.lib;oleaut32.lib;uuid.lib;odbc32.lib;odbccp32.lib;%(AdditionalDependencies)</AdditionalDependencies>
|
||||
<Profile>true</Profile>
|
||||
</Link>
|
||||
</ItemDefinitionGroup>
|
||||
|
|
|
@ -7,7 +7,9 @@
|
|||
//
|
||||
|
||||
#include "stdafx.h"
|
||||
#ifdef _WIN32
|
||||
#include <objbase.h>
|
||||
#endif
|
||||
#include "basetypes.h"
|
||||
|
||||
#include "htkfeatio.h" // for reading HTK features
|
||||
|
@ -85,7 +87,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
n++;
|
||||
}
|
||||
|
||||
fprintf (stderr, " %d entries\n", n);
|
||||
fprintf (stderr, " %d entries\n", (int)n);
|
||||
|
||||
if (i==0)
|
||||
numFiles=n;
|
||||
|
@ -163,17 +165,17 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
msra::files::make_intermediate_dirs (outputFile);
|
||||
msra::util::attempt (5, [&]()
|
||||
{
|
||||
msra::asr::htkfeatwriter::write (outputFile, "USER", sampPeriod, output);
|
||||
msra::asr::htkfeatwriter::write (outputFile, "USER", this->sampPeriod, output);
|
||||
});
|
||||
|
||||
fprintf (stderr, "evaluate: writing %d frames of %S\n", output.cols(), outputFile.c_str());
|
||||
fprintf (stderr, "evaluate: writing %d frames of %S\n", (int)output.cols(), outputFile.c_str());
|
||||
|
||||
|
||||
}
|
||||
|
||||
|
||||
template<class ElemType>
|
||||
void HTKMLFWriter<ElemType>::SaveMapping(std::wstring saveId, const std::map<typename LabelIdType, typename LabelType>& /*labelMapping*/)
|
||||
void HTKMLFWriter<ElemType>::SaveMapping(std::wstring saveId, const std::map<LabelIdType, LabelType>& /*labelMapping*/)
|
||||
{
|
||||
}
|
||||
|
||||
|
|
|
@ -6,6 +6,8 @@
|
|||
// HTKMLFReader.h - Include file for the MTK and MLF format of features and samples
|
||||
#pragma once
|
||||
#include "DataWriter.h"
|
||||
#include <map>
|
||||
#include <vector>
|
||||
|
||||
namespace Microsoft { namespace MSR { namespace CNTK {
|
||||
|
||||
|
@ -33,11 +35,13 @@ private:
|
|||
};
|
||||
|
||||
public:
|
||||
using LabelType = typename IDataWriter<ElemType>::LabelType;
|
||||
using LabelIdType = typename IDataWriter<ElemType>::LabelIdType;
|
||||
virtual void Init(const ConfigParameters& writerConfig);
|
||||
virtual void Destroy();
|
||||
virtual void GetSections(std::map<std::wstring, SectionType, nocase_compare>& sections);
|
||||
virtual bool SaveData(size_t recordStart, const std::map<std::wstring, void*, nocase_compare>& matrices, size_t numRecords, size_t datasetSize, size_t byteVariableSized);
|
||||
virtual void SaveMapping(std::wstring saveId, const std::map<typename LabelIdType, typename LabelType>& labelMapping);
|
||||
virtual void SaveMapping(std::wstring saveId, const std::map<LabelIdType, LabelType>& labelMapping);
|
||||
};
|
||||
|
||||
}}}
|
|
@ -82,18 +82,51 @@ OACR_WARNING_DISABLE(POTENTIAL_ARGUMENT_TYPE_MISMATCH, "Not level1 or level2_sec
|
|||
#pragma warning(disable : 4702) // unreachable code
|
||||
#endif
|
||||
|
||||
#include "Platform.h"
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
#include <string.h> // include here because we redefine some names later
|
||||
#include <errno.h>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <cmath> // for HUGE_VAL
|
||||
#include <assert.h>
|
||||
#include <stdarg.h>
|
||||
#include <map>
|
||||
#include <windows.h> // for CRITICAL_SECTION
|
||||
#include <stdexcept>
|
||||
#include <locale> // std::wstring_convert
|
||||
#include <string>
|
||||
#include <algorithm> // for transform()
|
||||
#include <mutex>
|
||||
#include <unordered_map>
|
||||
#include <chrono>
|
||||
#include <thread>
|
||||
|
||||
#ifdef _MSC_VER
|
||||
#include <codecvt> // std::codecvt_utf8
|
||||
#endif
|
||||
#ifdef _WIN32
|
||||
#include <windows.h> // for CRITICAL_SECTION and Unicode conversion functions --TODO: is there a portable alternative?
|
||||
#endif
|
||||
|
||||
#if __unix__
|
||||
#include <strings.h>
|
||||
#include <unistd.h>
|
||||
#include <sys/stat.h>
|
||||
#include <dlfcn.h>
|
||||
#include <sys/time.h>
|
||||
|
||||
typedef unsigned char byte;
|
||||
#endif
|
||||
|
||||
#ifdef _WIN32
|
||||
#pragma push_macro("STRSAFE_NO_DEPRECATE")
|
||||
#define STRSAFE_NO_DEPRECATE // deprecation managed elsewhere, not by strsafe
|
||||
#include <strsafe.h> // for strbcpy() etc templates
|
||||
#pragma pop_macro("STRSAFE_NO_DEPRECATE")
|
||||
#endif
|
||||
|
||||
using namespace std;
|
||||
|
||||
// CRT error handling seems to not be included in wince headers
|
||||
// so we define our own imports
|
||||
|
@ -106,6 +139,7 @@ OACR_WARNING_DISABLE(POTENTIAL_ARGUMENT_TYPE_MISMATCH, "Not level1 or level2_sec
|
|||
#define strerror(x) "strerror error but can't report error number sorry!"
|
||||
#endif
|
||||
|
||||
#ifdef _WIN32
|
||||
#ifndef __in // dummies for sal annotations if compiler does not support it
|
||||
#define __in
|
||||
#define __inout_z
|
||||
|
@ -122,11 +156,103 @@ OACR_WARNING_DISABLE(POTENTIAL_ARGUMENT_TYPE_MISMATCH, "Not level1 or level2_sec
|
|||
#ifndef __override // and some more non-std extensions required by Office
|
||||
#define __override virtual
|
||||
#endif
|
||||
#endif
|
||||
|
||||
// disable warnings for which fixing would make code less readable
|
||||
#pragma warning(disable : 4290) // throw() declaration ignored
|
||||
#pragma warning(disable : 4244) // conversion from typeA to typeB, possible loss of data
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// (w)cstring -- helper class like std::string but with auto-cast to char*
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
namespace msra { namespace strfun {
|
||||
// a class that can return a std::string with auto-convert into a const char*
|
||||
template<typename C> struct basic_cstring : public std::basic_string<C>
|
||||
{
|
||||
template<typename S> basic_cstring (S p) : std::basic_string<C> (p) { }
|
||||
operator const C * () const { return this->c_str(); }
|
||||
};
|
||||
typedef basic_cstring<char> cstring;
|
||||
typedef basic_cstring<wchar_t> wcstring;
|
||||
}}
|
||||
static inline wchar_t*GetWC(const char *c)
|
||||
{
|
||||
const size_t cSize = strlen(c)+1;
|
||||
wchar_t* wc = new wchar_t[cSize];
|
||||
mbstowcs (wc, c, cSize);
|
||||
|
||||
return wc;
|
||||
}
|
||||
struct MatchPathSeparator
|
||||
{
|
||||
bool operator()( char ch ) const
|
||||
{
|
||||
return ch == '\\' || ch == '/';
|
||||
}
|
||||
};
|
||||
static inline std::string basename( std::string const& pathname)
|
||||
{
|
||||
return std::string (std::find_if(pathname.rbegin(), pathname.rend(),MatchPathSeparator()).base(), pathname.end());
|
||||
}
|
||||
|
||||
static inline std::string removeExtension (std::string const& filename)
|
||||
{
|
||||
//std::string::const_reverse_iterator pivot = std::find(filename.rbegin(), filename.rend(), '.');
|
||||
//return pivot == filename.rend() ? filename: std::string(filename.begin(), pivot.base()-1);
|
||||
size_t lastindex = filename.find_last_of(".");
|
||||
return filename.substr(0, lastindex);
|
||||
}
|
||||
static inline std::wstring basename( std::wstring const& pathname)
|
||||
{
|
||||
return std::wstring (std::find_if(pathname.rbegin(), pathname.rend(),MatchPathSeparator()).base(), pathname.end());
|
||||
}
|
||||
|
||||
static inline std::wstring removeExtension (std::wstring const& filename)
|
||||
{
|
||||
//std::wstring::const_reverse_iterator pivot = std::find(filename.rbegin(), filename.rend(), '.');
|
||||
//return pivot == filename.rend() ? filename: std::wstring(filename.begin(), pivot.base()-1);
|
||||
size_t lastindex = filename.find_last_of(L".");
|
||||
return filename.substr(0, lastindex);
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// some mappings for non-Windows builds
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
#ifndef _MSC_VER // add some functions that are VS-only
|
||||
// --- basic file functions
|
||||
// convert a wchar_t path to what gets passed to CRT functions that take narrow characters
|
||||
// This is needed for the Linux CRT which does not accept wide-char strings for pathnames anywhere.
|
||||
// Always use this function for mapping the paths.
|
||||
static inline msra::strfun::cstring charpath (const std::wstring & p)
|
||||
{
|
||||
#ifdef _WIN32
|
||||
return std::wstring_convert<std::codecvt_utf8_utf16<wchar_t>>().to_bytes(p);
|
||||
#else // old version, delete once we know it works
|
||||
size_t len = p.length();
|
||||
std::vector<char> buf(2 * len + 1, 0); // max: 1 wchar => 2 mb chars
|
||||
::wcstombs(buf.data(), p.c_str(), 2 * len + 1);
|
||||
return msra::strfun::cstring (&buf[0]);
|
||||
#endif
|
||||
}
|
||||
static inline FILE* _wfopen (const wchar_t * path, const wchar_t * mode) { return fopen(charpath(path), charpath(mode)); }
|
||||
static inline int _wunlink (const wchar_t * p) { return unlink (charpath (p)); }
|
||||
static inline int _wmkdir (const wchar_t * p) { return mkdir (charpath (p), 0777/*correct?*/); }
|
||||
// --- basic string functions
|
||||
static inline wchar_t* wcstok_s (wchar_t* s, const wchar_t* delim, wchar_t** ptr) { return ::wcstok(s, delim, ptr); }
|
||||
static inline int _stricmp (const char * a, const char * b) { return ::strcasecmp (a, b); }
|
||||
static inline int _strnicmp (const char * a, const char * b, size_t n) { return ::strncasecmp (a, b, n); }
|
||||
static inline int _wcsicmp (const wchar_t * a, const wchar_t * b) { return ::wcscasecmp (a, b); }
|
||||
static inline int _wcsnicmp (const wchar_t * a, const wchar_t * b, size_t n) { return ::wcsncasecmp (a, b, n); }
|
||||
static inline int64_t _strtoi64 (const char * s, char ** ep, int r) { return strtoll (s, ep, r); } // TODO: check if correct
|
||||
static inline uint64_t _strtoui64 (const char * s, char ** ep, int r) { return strtoull (s, ep, r); } // TODO: correct for size_t?
|
||||
// -- other
|
||||
//static inline void memcpy_s(void * dst, size_t dstsize, const void * src, size_t maxcount) { assert (maxcount <= dstsize); memcpy (dst, src, maxcount); }
|
||||
static inline void Sleep (size_t ms) { std::this_thread::sleep_for (std::chrono::milliseconds (ms)); }
|
||||
#define _countof(_Array) (sizeof(_Array) / sizeof(_Array[0]))
|
||||
#endif
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// basic macros
|
||||
// ----------------------------------------------------------------------------
|
||||
|
@ -142,6 +268,9 @@ extern void _CHECKED_ASSERT_error(const char * file, int line, const char * exp)
|
|||
#endif
|
||||
#endif
|
||||
|
||||
#define EPSILON 1e-5
|
||||
#define ISCLOSE(a, b, threshold) (abs(a - b) < threshold)?true:false
|
||||
|
||||
/**
|
||||
These macros are used for sentence segmentation information.
|
||||
*/
|
||||
|
@ -190,6 +319,8 @@ namespace msra { namespace basetypes {
|
|||
|
||||
// class ARRAY -- std::vector with array-bounds checking
|
||||
// VS 2008 and above do this, so there is no longer a need for this.
|
||||
#pragma warning(push)
|
||||
#pragma warning(disable : 4555) // expression has no affect, used so retail won't be empty
|
||||
|
||||
template<class _ElemType>
|
||||
class ARRAY : public std::vector<_ElemType>
|
||||
|
@ -296,6 +427,7 @@ public:
|
|||
};
|
||||
template<class _T> inline void swap (fixed_vector<_T> & L, fixed_vector<_T> & R) throw() { L.swap (R); }
|
||||
|
||||
#pragma warning(pop) // pop off waring: expression has no effect
|
||||
// class matrix - simple fixed-size 2-dimensional array, access elements as m(i,j)
|
||||
// stored as concatenation of rows
|
||||
|
||||
|
@ -307,14 +439,14 @@ public:
|
|||
typedef T elemtype;
|
||||
matrix() : numcols (0) {}
|
||||
matrix (size_t n, size_t m) { resize (n, m); }
|
||||
void resize (size_t n, size_t m) { numcols = m; fixed_vector::resize (n * m); }
|
||||
void resize (size_t n, size_t m) { numcols = m; fixed_vector<T>::resize (n * m); }
|
||||
size_t cols() const { return numcols; }
|
||||
size_t rows() const { return empty() ? 0 : size() / cols(); }
|
||||
size_t size() const { return fixed_vector::size(); } // use this for reading and writing... not nice!
|
||||
bool empty() const { return fixed_vector::empty(); }
|
||||
size_t size() const { return fixed_vector<T>::size(); } // use this for reading and writing... not nice!
|
||||
bool empty() const { return fixed_vector<T>::empty(); }
|
||||
T & operator() (size_t i, size_t j) { return (*this)[locate(i,j)]; }
|
||||
const T & operator() (size_t i, size_t j) const { return (*this)[locate(i,j)]; }
|
||||
void swap (matrix & other) throw() { std::swap (numcols, other.numcols); fixed_vector::swap (other); }
|
||||
void swap (matrix & other) throw() { std::swap (numcols, other.numcols); fixed_vector<T>::swap (other); }
|
||||
};
|
||||
template<class _T> inline void swap (matrix<_T> & L, matrix<_T> & R) throw() { L.swap (R); }
|
||||
|
||||
|
@ -333,16 +465,16 @@ public:
|
|||
noncopyable(){}
|
||||
};
|
||||
|
||||
// class CCritSec and CAutoLock -- simple critical section handling
|
||||
class CCritSec
|
||||
{
|
||||
CCritSec (const CCritSec &); CCritSec & operator= (const CCritSec &);
|
||||
CRITICAL_SECTION m_CritSec;
|
||||
CCritSec (const CCritSec &) = delete;
|
||||
CCritSec & operator= (const CCritSec &) = delete;
|
||||
std::mutex m_CritSec;
|
||||
public:
|
||||
CCritSec() { InitializeCriticalSection(&m_CritSec); };
|
||||
~CCritSec() { DeleteCriticalSection(&m_CritSec); };
|
||||
void Lock() { EnterCriticalSection(&m_CritSec); };
|
||||
void Unlock() { LeaveCriticalSection(&m_CritSec); };
|
||||
CCritSec() {};
|
||||
~CCritSec() {};
|
||||
void Lock() { m_CritSec.lock(); };
|
||||
void Unlock() { m_CritSec.unlock(); };
|
||||
};
|
||||
|
||||
// locks a critical section, and unlocks it automatically
|
||||
|
@ -356,6 +488,7 @@ public:
|
|||
~CAutoLock() { m_rLock.Unlock(); };
|
||||
};
|
||||
|
||||
#ifdef _WIN32
|
||||
// an efficient way to write COM code
|
||||
// usage examples:
|
||||
// COM_function() || throw_hr ("message");
|
||||
|
@ -436,9 +569,11 @@ public:
|
|||
operator void * () { return TlsGetValue (tlsSlot); }
|
||||
void *operator = (void *val) { if (!TlsSetValue (tlsSlot,val)) throw std::runtime_error ("tls: TlsSetValue failed"); return val; }
|
||||
};
|
||||
#endif
|
||||
|
||||
};}; // namespace
|
||||
|
||||
#ifdef _WIN32
|
||||
#ifndef BASETYPES_NO_UNSAFECRTOVERLOAD // if on, no unsafe CRT overload functions
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
|
@ -465,7 +600,11 @@ public:
|
|||
#include <xlocale> // uses strlen()
|
||||
#endif
|
||||
#define strlen strlen_
|
||||
#ifndef LINUX
|
||||
template<typename _T> inline __declspec(deprecated("Dummy general template, cannot be used directly"))
|
||||
#else
|
||||
template<typename _T> inline
|
||||
#endif // LINUX
|
||||
size_t strlen_(_T &s) { return strnlen_s(static_cast<const char *>(s), SIZE_MAX); } // never be called but needed to keep compiler happy
|
||||
template<typename _T> inline size_t strlen_(const _T &s) { return strnlen_s(static_cast<const char *>(s), SIZE_MAX); }
|
||||
template<> inline size_t strlen_(char * &s) { return strnlen_s(s, SIZE_MAX); }
|
||||
|
@ -544,7 +683,10 @@ static inline const char *strerror_(int e)
|
|||
if (msgs.find(e) == msgs.end()) { char msg[1024]; strerror_s (msg, e); msgs[e] = msg; }
|
||||
return msgs[e].c_str();
|
||||
}
|
||||
|
||||
#endif
|
||||
#endif
|
||||
#ifdef __unix__
|
||||
extern int fileno(FILE*); // somehow got deprecated in C++11
|
||||
#endif
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
|
@ -560,8 +702,11 @@ template<class _T> struct _strprintf : public std::basic_string<_T>
|
|||
{ // works for both wchar_t* and char*
|
||||
_strprintf (const _T * format, ...)
|
||||
{
|
||||
va_list args; va_start (args, format); // varargs stuff
|
||||
va_list args;
|
||||
va_start (args, format); // varargs stuff
|
||||
size_t n = _cprintf (format, args); // num chars excl. '\0'
|
||||
va_end(args);
|
||||
va_start(args, format);
|
||||
const int FIXBUF_SIZE = 128; // incl. '\0'
|
||||
if (n < FIXBUF_SIZE)
|
||||
{
|
||||
|
@ -576,16 +721,47 @@ template<class _T> struct _strprintf : public std::basic_string<_T>
|
|||
}
|
||||
private:
|
||||
// helpers
|
||||
inline size_t _cprintf (const wchar_t * format, va_list args) { return _vscwprintf (format, args); }
|
||||
inline size_t _cprintf (const char * format, va_list args) { return _vscprintf (format, args); }
|
||||
inline const wchar_t * _sprintf (wchar_t * buf, size_t bufsiz, const wchar_t * format, va_list args) { vswprintf_s (buf, bufsiz, format, args); return buf; }
|
||||
inline const char * _sprintf ( char * buf, size_t bufsiz, const char * format, va_list args) { vsprintf_s (buf, bufsiz, format, args); return buf; }
|
||||
inline size_t _cprintf(const wchar_t* format, va_list args)
|
||||
{
|
||||
#ifdef __WINDOWS__
|
||||
return vswprintf(nullptr, 0, format, args);
|
||||
#elif defined(__UNIX__)
|
||||
// TODO: Really??? Write to file in order to know the length of a string?
|
||||
FILE *dummyf = fopen("/dev/null", "w");
|
||||
if (dummyf == NULL)
|
||||
perror("The following error occurred in basetypes.h:cprintf");
|
||||
int n = vfwprintf (dummyf, format, args);
|
||||
if (n < 0)
|
||||
perror("The following error occurred in basetypes.h:cprintf");
|
||||
fclose(dummyf);
|
||||
return n;
|
||||
#endif
|
||||
}
|
||||
inline size_t _cprintf(const char* format, va_list args)
|
||||
{
|
||||
#ifdef __WINDOWS__
|
||||
return vsprintf(nullptr, format, args);
|
||||
#elif defined(__UNIX__)
|
||||
// TODO: Really??? Write to file in order to know the length of a string?
|
||||
FILE *dummyf = fopen("/dev/null", "wb");
|
||||
if (dummyf == NULL)
|
||||
perror("The following error occurred in basetypes.h:cprintf");
|
||||
int n = vfprintf (dummyf, format, args);
|
||||
if (n < 0)
|
||||
perror("The following error occurred in basetypes.h:cprintf");
|
||||
fclose(dummyf);
|
||||
return n;
|
||||
#endif
|
||||
}
|
||||
inline const wchar_t * _sprintf(wchar_t * buf, size_t bufsiz, const wchar_t * format, va_list args) { vswprintf(buf, bufsiz, format, args); return buf; }
|
||||
inline const char * _sprintf ( char * buf, size_t /*bufsiz*/, const char * format, va_list args) { vsprintf (buf, format, args); return buf; }
|
||||
};
|
||||
typedef strfun::_strprintf<char> strprintf; // char version
|
||||
typedef strfun::_strprintf<wchar_t> wstrprintf; // wchar_t version
|
||||
|
||||
#endif
|
||||
|
||||
#ifdef _WIN32
|
||||
// string-encoding conversion functions
|
||||
struct utf8 : std::string { utf8 (const std::wstring & p) // utf-16 to -8
|
||||
{
|
||||
|
@ -612,6 +788,7 @@ struct utf16 : std::wstring { utf16 (const std::string & p) // utf-8 to -16
|
|||
ASSERT (rc < buf.size ());
|
||||
(*(std::wstring*)this) = &buf[0];
|
||||
}};
|
||||
#endif
|
||||
|
||||
#pragma warning(push)
|
||||
#pragma warning(disable : 4996) // Reviewed by Yusheng Li, March 14, 2006. depr. fn (wcstombs, mbstowcs)
|
||||
|
@ -633,6 +810,19 @@ static inline std::wstring mbstowcs (const std::string & p) // input: MBCS
|
|||
return std::wstring (&buf[0]);
|
||||
}
|
||||
#pragma warning(pop)
|
||||
#ifdef _WIN32
|
||||
static inline cstring utf8 (const std::wstring & p) { return std::wstring_convert<std::codecvt_utf8_utf16<wchar_t>>().to_bytes(p); } // utf-16 to -8
|
||||
static inline wcstring utf16 (const std::string & p) { return std::wstring_convert<std::codecvt_utf8_utf16<wchar_t>>().from_bytes(p); } // utf-8 to -16
|
||||
#else // BUGBUG: we cannot compile the above on Cygwin GCC, so for now fake it using the mbs functions, which will only work for 7-bit ASCII strings
|
||||
static inline std::string utf8 (const std::wstring & p) { return msra::strfun::wcstombs (p.c_str()); } // output: UTF-8... not really
|
||||
static inline std::wstring utf16 (const std::string & p) { return msra::strfun::mbstowcs(p.c_str()); } // input: UTF-8... not really
|
||||
#endif
|
||||
static inline cstring utf8 (const std::string & p) { return p; } // no conversion (useful in templated functions)
|
||||
static inline wcstring utf16 (const std::wstring & p) { return p; }
|
||||
|
||||
// convert a string to lowercase --TODO: currently only correct for 7-bit ASCII
|
||||
template<typename CHAR>
|
||||
static inline void tolower_ascii (std::basic_string<CHAR> & s) { std::transform(s.begin(), s.end(), s.begin(), [] (CHAR c) { return (c >= 0 && c < 128) ? ::tolower(c) : c; }); }
|
||||
|
||||
// split and join -- tokenize a string like strtok() would, join() strings together
|
||||
template<class _T> static inline std::vector<std::basic_string<_T>> split (const std::basic_string<_T> & s, const _T * delim)
|
||||
|
@ -662,7 +852,11 @@ template<class _T> static inline std::basic_string<_T> join (const std::vector<s
|
|||
// parsing strings to numbers
|
||||
static inline int toint (const wchar_t * s)
|
||||
{
|
||||
#ifdef _WIN32
|
||||
return _wtoi (s); // ... TODO: check it
|
||||
#else
|
||||
return (int)wcstol(s, 0, 10);
|
||||
#endif
|
||||
}
|
||||
static inline int toint (const char * s)
|
||||
{
|
||||
|
@ -766,7 +960,7 @@ public:
|
|||
auto_file_ptr() : f (NULL) { }
|
||||
~auto_file_ptr() { close(); }
|
||||
auto_file_ptr (const char * path, const char * mode) { f = fopen (path, mode); if (f == NULL) openfailed (path); }
|
||||
auto_file_ptr (const wchar_t * path, const char * mode) { f = _wfopen (path, msra::strfun::utf16 (mode).c_str()); if (f == NULL) openfailed (msra::strfun::utf8 (path)); }
|
||||
auto_file_ptr (const wchar_t * wpath, const char * mode) { f = _wfopen (wpath, msra::strfun::utf16 (mode).c_str()); if (f == NULL) openfailed (msra::strfun::utf8 (wpath)); }
|
||||
FILE * operator= (FILE * other) { close(); f = other; return f; }
|
||||
auto_file_ptr (FILE * other) : f (other) { }
|
||||
operator FILE * () const { return f; }
|
||||
|
@ -775,6 +969,7 @@ public:
|
|||
};
|
||||
inline int fclose (auto_file_ptr & af) { return af.fclose(); }
|
||||
|
||||
#ifdef _MSC_VER
|
||||
// auto-closing container for Win32 handles.
|
||||
// Pass close function if not CloseHandle(), e.g.
|
||||
// auto_handle h (FindFirstFile(...), FindClose);
|
||||
|
@ -791,6 +986,7 @@ public:
|
|||
operator _H () const { return h; }
|
||||
};
|
||||
typedef auto_handle_t<HANDLE> auto_handle;
|
||||
#endif
|
||||
|
||||
// like auto_ptr but calls freeFunc_p (type free_func_t) instead of delete to clean up
|
||||
// minor difference - wrapped object is T, not T *, so to wrap a
|
||||
|
@ -814,6 +1010,9 @@ public:
|
|||
|
||||
// simple timer
|
||||
// auto_timer timer; run(); double seconds = timer; // now can abandon the object
|
||||
#ifdef __unix__
|
||||
typedef timeval LARGE_INTEGER;
|
||||
#endif
|
||||
class auto_timer
|
||||
{
|
||||
LARGE_INTEGER freq, start;
|
||||
|
@ -821,15 +1020,26 @@ class auto_timer
|
|||
public:
|
||||
auto_timer()
|
||||
{
|
||||
#ifdef _WIN32
|
||||
if (!QueryPerformanceFrequency (&freq)) // count ticks per second
|
||||
throw std::runtime_error ("auto_timer: QueryPerformanceFrequency failure");
|
||||
QueryPerformanceCounter (&start);
|
||||
#endif
|
||||
#ifdef __unix__
|
||||
gettimeofday (&start, NULL);
|
||||
#endif
|
||||
}
|
||||
operator double() const // each read gives time elapsed since start, in seconds
|
||||
{
|
||||
LARGE_INTEGER end;
|
||||
#ifdef _WIN32
|
||||
QueryPerformanceCounter (&end);
|
||||
return (end.QuadPart - start.QuadPart) / (double) freq.QuadPart;
|
||||
#endif
|
||||
#ifdef __unix__
|
||||
gettimeofday (&end,NULL);
|
||||
return (end.tv_sec - start.tv_sec) + (end.tv_usec - start.tv_usec)/(1000*1000);
|
||||
#endif
|
||||
}
|
||||
void show (const std::string & msg) const
|
||||
{
|
||||
|
@ -881,8 +1091,10 @@ public:
|
|||
#define foreach_index(_i,_dat) for (int _i = 0; _i < (int) (_dat).size(); _i++)
|
||||
#define map_array(_x,_expr,_y) { _y.resize (_x.size()); foreach_index(_i,_x) _y[_i]=_expr(_x[_i]); }
|
||||
#define reduce_array(_x,_expr,_y) { foreach_index(_i,_x) _y = (_i==0) ? _x[_i] : _expr(_y,_x[_i]); }
|
||||
#ifdef _WIN32
|
||||
template<class _A,class _F>
|
||||
static void fill_array(_A & a, _F v) { ::fill (a.begin(), a.end(), v); }
|
||||
#endif
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// frequently missing utility functions
|
||||
|
@ -897,7 +1109,11 @@ namespace msra { namespace util {
|
|||
class command_line
|
||||
{
|
||||
int num;
|
||||
#ifdef _WIN32
|
||||
(const wchar_t *) * args;
|
||||
#else
|
||||
const wchar_t ** args;
|
||||
#endif
|
||||
public:
|
||||
command_line (int argc, wchar_t * argv[]) : num (argc), args ((const wchar_t **) argv) { shift(); }
|
||||
inline int size() const { return num; }
|
||||
|
@ -948,6 +1164,7 @@ template<typename FUNCTION> static void attempt (int retries, const FUNCTION & b
|
|||
|
||||
};}; // namespace
|
||||
|
||||
#ifdef _WIN32
|
||||
// ----------------------------------------------------------------------------
|
||||
// frequently missing Win32 functions
|
||||
// ----------------------------------------------------------------------------
|
||||
|
@ -988,6 +1205,7 @@ static inline LPWSTR CoTaskMemString (const wchar_t * s)
|
|||
if (p) for (size_t i = 0; i < n; i++) p[i] = s[i];
|
||||
return p;
|
||||
}
|
||||
#endif
|
||||
|
||||
template<class S> static inline void ZeroStruct (S & s) { memset (&s, 0, sizeof (s)); }
|
||||
|
||||
|
@ -1002,9 +1220,7 @@ using namespace msra::basetypes; // for compatibility
|
|||
#pragma warning (pop)
|
||||
|
||||
// RuntimeError - throw a std::runtime_error with a formatted error string
|
||||
#ifdef _MSC_VER
|
||||
__declspec(noreturn)
|
||||
#endif
|
||||
__declspec_noreturn
|
||||
static inline void RuntimeError(const char * format, ...)
|
||||
{
|
||||
va_list args;
|
||||
|
@ -1016,9 +1232,7 @@ static inline void RuntimeError(const char * format, ...)
|
|||
};
|
||||
|
||||
// LogicError - throw a std::logic_error with a formatted error string
|
||||
#ifdef _MSC_VER
|
||||
__declspec(noreturn)
|
||||
#endif
|
||||
__declspec_noreturn
|
||||
static inline void LogicError(const char * format, ...)
|
||||
{
|
||||
va_list args;
|
||||
|
@ -1047,7 +1261,7 @@ public:
|
|||
m_dllName += L".dll";
|
||||
m_hModule = LoadLibrary(m_dllName.c_str());
|
||||
if (m_hModule == NULL)
|
||||
RuntimeError("Plugin not found: %s", msra::strfun::utf8(m_dllName));
|
||||
RuntimeError("Plugin not found: %s", msra::strfun::utf8(m_dllName).c_str());
|
||||
|
||||
// create a variable of each type just to call the proper templated version
|
||||
return GetProcAddress(m_hModule, proc.c_str());
|
||||
|
@ -1057,14 +1271,37 @@ public:
|
|||
#else
|
||||
class Plugin
|
||||
{
|
||||
private:
|
||||
void *handle;
|
||||
public:
|
||||
Plugin()
|
||||
{
|
||||
handle = NULL;
|
||||
}
|
||||
|
||||
template<class STRING> // accepts char (UTF-8) and wide string
|
||||
void * Load(const STRING & plugin, const std::string & proc)
|
||||
{
|
||||
RuntimeError("Plugins not implemented on Linux yet");
|
||||
return nullptr;
|
||||
string soName = msra::strfun::utf8(plugin);
|
||||
soName = soName + ".so";
|
||||
void *handle = dlopen(soName.c_str(), RTLD_LAZY);
|
||||
if (handle == NULL)
|
||||
RuntimeError("Plugin not found: %s (error: %s)", soName.c_str(), dlerror());
|
||||
return dlsym(handle, proc.c_str());
|
||||
}
|
||||
|
||||
~Plugin()
|
||||
{
|
||||
if (handle != NULL)
|
||||
dlclose(handle);
|
||||
}
|
||||
};
|
||||
#endif
|
||||
|
||||
template<class F>
|
||||
static inline bool comparator(const pair<int, F>& l, const pair<int, F>& r)
|
||||
{
|
||||
return l.second > r.second;
|
||||
}
|
||||
|
||||
#endif // _BASETYPES_
|
||||
|
|
|
@ -92,30 +92,30 @@ public:
|
|||
// ---------------------------------------------------------------------------
|
||||
// biggrowablevector -- big vector we can push_back to
|
||||
// ---------------------------------------------------------------------------
|
||||
template<typename ELEMTYPE> class biggrowablevector : public growablevectorbase<std::vector<ELEMTYPE>>
|
||||
template<class ELEMTYPE> class biggrowablevector : public growablevectorbase<std::vector<ELEMTYPE>>
|
||||
{
|
||||
public:
|
||||
biggrowablevector() : growablevectorbase (65536) { }
|
||||
biggrowablevector() : growablevectorbase<std::vector<ELEMTYPE>>::growablevectorbase (65536) { }
|
||||
|
||||
template<typename VALTYPE> void push_back (VALTYPE e) // VALTYPE could be an rvalue reference
|
||||
{
|
||||
size_t i = size();
|
||||
resize_without_commit (i + 1);
|
||||
auto & block = getblockptr (i);
|
||||
size_t i = this->size();
|
||||
this->resize_without_commit (i + 1);
|
||||
auto & block = this->getblockptr (i);
|
||||
if (block.get() == NULL)
|
||||
block.reset (new std::vector<ELEMTYPE> (elementsperblock));
|
||||
(*block)[getblockt (i)] = e;
|
||||
block.reset (new std::vector<ELEMTYPE> (this->elementsperblock));
|
||||
(*block)[this->getblockt (i)] = e;
|
||||
}
|
||||
|
||||
ELEMTYPE & operator[] (size_t t) { return getblock(t)[getblockt (t)]; } // get an element
|
||||
const ELEMTYPE & operator[] (size_t t) const { return getblock(t)[getblockt (t)]; } // get an element
|
||||
ELEMTYPE & operator[] (size_t t) { return this->getblock(t)[this->getblockt (t)]; } // get an element
|
||||
const ELEMTYPE & operator[] (size_t t) const { return this->getblock(t)[this->getblockt (t)]; } // get an element
|
||||
|
||||
void resize (const size_t n)
|
||||
{
|
||||
resize_without_commit (n);
|
||||
foreach_index (i, blocks)
|
||||
if (blocks[i].get() == NULL)
|
||||
blocks[i].reset (new std::vector<ELEMTYPE> (elementsperblock));
|
||||
this->resize_without_commit (n);
|
||||
foreach_index (i, this->blocks)
|
||||
if (this->blocks[i].get() == NULL)
|
||||
this->blocks[i].reset (new std::vector<ELEMTYPE> (this->elementsperblock));
|
||||
}
|
||||
};
|
||||
|
||||
|
|
|
@ -10,7 +10,9 @@
|
|||
#include "basetypes.h" // for attempt()
|
||||
#include "htkfeatio.h" // for reading HTK features
|
||||
#include "minibatchsourcehelpers.h"
|
||||
#include "ssematrix.h"
|
||||
#ifndef __unix__
|
||||
#include "ssematrix.h" // TODO: why can it not be removed for Windows as well? At least needs a comment here.
|
||||
#endif
|
||||
|
||||
#ifdef LEAKDETECT
|
||||
#include <vld.h> // for memory leak detection
|
||||
|
@ -58,7 +60,7 @@ namespace msra { namespace dbn {
|
|||
unsigned int sampperiod = sampperiods[k];
|
||||
size_t n = numframes[k];
|
||||
msra::files::make_intermediate_dirs (outfile);
|
||||
fprintf (stderr, "saveandflush: writing %d frames to %S\n", n, outfile.c_str());
|
||||
fprintf (stderr, "saveandflush: writing %d frames to %S\n", (int)n, outfile.c_str());
|
||||
msra::dbn::matrixstripe thispred (pred, firstframe, n);
|
||||
// some sanity check for the data we've written
|
||||
const size_t nansinf = thispred.countnaninf();
|
||||
|
@ -171,7 +173,7 @@ namespace msra { namespace dbn {
|
|||
unsigned int sampperiod = sampperiods[index][k];
|
||||
size_t n = numframes[k];
|
||||
msra::files::make_intermediate_dirs (outfile);
|
||||
fprintf (stderr, "saveandflush: writing %d frames to %S\n", n, outfile.c_str());
|
||||
fprintf (stderr, "saveandflush: writing %d frames to %S\n", (int)n, outfile.c_str());
|
||||
msra::dbn::matrixstripe thispred (pred, firstframe, n);
|
||||
// some sanity check for the data we've written
|
||||
const size_t nansinf = thispred.countnaninf();
|
||||
|
|
|
@ -245,7 +245,7 @@ void fflushOrDie (FILE * f)
|
|||
// ----------------------------------------------------------------------------
|
||||
size_t filesize (FILE * f)
|
||||
{
|
||||
#ifdef WIN32
|
||||
#ifdef _WIN32
|
||||
size_t curPos = _ftelli64 (f);
|
||||
if (curPos == -1L)
|
||||
{
|
||||
|
@ -269,6 +269,27 @@ size_t filesize (FILE * f)
|
|||
return len;
|
||||
#else
|
||||
// linux version
|
||||
long curPos = ftell (f);
|
||||
if (curPos == -1L)
|
||||
{
|
||||
RuntimeError ("error determining file position: %s", strerror (errno));
|
||||
}
|
||||
int rc = fseek (f, 0, SEEK_END);
|
||||
if (rc != 0)
|
||||
{
|
||||
RuntimeError ("error seeking to end of file: %s", strerror (errno));
|
||||
}
|
||||
long len = ftell (f);
|
||||
if (len == -1L)
|
||||
{
|
||||
RuntimeError ("error determining file position: %s", strerror (errno));
|
||||
}
|
||||
rc = fseek (f, curPos, SEEK_SET);
|
||||
if (rc != 0)
|
||||
{
|
||||
RuntimeError ("error resetting file position: %s", strerror (errno));
|
||||
}
|
||||
return (size_t) len;
|
||||
#endif
|
||||
}
|
||||
|
||||
|
|
|
@ -10,13 +10,28 @@
|
|||
#ifndef _FILEUTIL_
|
||||
#define _FILEUTIL_
|
||||
|
||||
#include "Platform.h"
|
||||
#ifdef _WIN32
|
||||
#include "basetypes.h"
|
||||
#endif
|
||||
#include <stdio.h>
|
||||
#ifdef __WINDOWS__
|
||||
#include <windows.h> // for mmreg.h and FILETIME
|
||||
#include <mmreg.h>
|
||||
#endif
|
||||
#ifdef __unix__
|
||||
#include <sys/types.h>
|
||||
#include <sys/stat.h>
|
||||
#endif
|
||||
#include <algorithm> // for std::find
|
||||
#include <vector>
|
||||
#include <map>
|
||||
#include <functional>
|
||||
#include <cctype>
|
||||
#include <errno.h>
|
||||
#include <stdint.h>
|
||||
#include <assert.h>
|
||||
#include <string.h> // for strerror()
|
||||
using namespace std;
|
||||
|
||||
#define SAFE_CLOSE(f) (((f) == NULL) || (fcloseOrDie ((f)), (f) = NULL))
|
||||
|
@ -28,8 +43,8 @@ using namespace std;
|
|||
// not to fclose() such a handle.
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
FILE * fopenOrDie (const STRING & pathname, const char * mode);
|
||||
FILE * fopenOrDie (const WSTRING & pathname, const wchar_t * mode);
|
||||
FILE * fopenOrDie (const string & pathname, const char * mode);
|
||||
FILE * fopenOrDie (const wstring & pathname, const wchar_t * mode);
|
||||
|
||||
#ifndef __unix__ // don't need binary/text distinction on unix
|
||||
// ----------------------------------------------------------------------------
|
||||
|
@ -44,7 +59,9 @@ void fsetmode (FILE * f, char type);
|
|||
// ----------------------------------------------------------------------------
|
||||
|
||||
void freadOrDie (void * ptr, size_t size, size_t count, FILE * f);
|
||||
#ifdef _WIN32
|
||||
void freadOrDie (void * ptr, size_t size, size_t count, const HANDLE f);
|
||||
#endif
|
||||
|
||||
template<class _T>
|
||||
void freadOrDie (_T & data, int num, FILE * f) // template for vector<>
|
||||
|
@ -53,12 +70,14 @@ template<class _T>
|
|||
void freadOrDie (_T & data, size_t num, FILE * f) // template for vector<>
|
||||
{ data.resize (num); if (data.size() > 0) freadOrDie (&data[0], sizeof (data[0]), data.size(), f); }
|
||||
|
||||
#ifdef _WIN32
|
||||
template<class _T>
|
||||
void freadOrDie (_T & data, int num, const HANDLE f) // template for vector<>
|
||||
{ data.resize (num); if (data.size() > 0) freadOrDie (&data[0], sizeof (data[0]), data.size(), f); }
|
||||
template<class _T>
|
||||
void freadOrDie (_T & data, size_t num, const HANDLE f) // template for vector<>
|
||||
{ data.resize (num); if (data.size() > 0) freadOrDie (&data[0], sizeof (data[0]), data.size(), f); }
|
||||
#endif
|
||||
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
|
@ -66,15 +85,19 @@ void freadOrDie (_T & data, size_t num, const HANDLE f) // template for vecto
|
|||
// ----------------------------------------------------------------------------
|
||||
|
||||
void fwriteOrDie (const void * ptr, size_t size, size_t count, FILE * f);
|
||||
#ifdef _WIN32
|
||||
void fwriteOrDie (const void * ptr, size_t size, size_t count, const HANDLE f);
|
||||
#endif
|
||||
|
||||
template<class _T>
|
||||
void fwriteOrDie (const _T & data, FILE * f) // template for vector<>
|
||||
{ if (data.size() > 0) fwriteOrDie (&data[0], sizeof (data[0]), data.size(), f); }
|
||||
|
||||
#ifdef _WIN32
|
||||
template<class _T>
|
||||
void fwriteOrDie (const _T & data, const HANDLE f) // template for vector<>
|
||||
{ if (data.size() > 0) fwriteOrDie (&data[0], sizeof (data[0]), data.size(), f); }
|
||||
#endif
|
||||
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
|
@ -111,6 +134,10 @@ int64_t filesize64 (const wchar_t * pathname);
|
|||
// 32-bit offsets only
|
||||
long fseekOrDie (FILE * f, long offset, int mode = SEEK_SET);
|
||||
#define ftellOrDie ftell
|
||||
// ----------------------------------------------------------------------------
|
||||
// fget/setpos(): seek functions with error handling
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
uint64_t fgetpos (FILE * f);
|
||||
void fsetpos (FILE * f, uint64_t pos);
|
||||
|
||||
|
@ -158,27 +185,6 @@ void fskipspace (FILE * F);
|
|||
// fskipNewLine(): skip all white space until end of line incl. the newline
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
template<class CHAR> CHAR * fgetline (FILE * f, CHAR * buf, int size);
|
||||
template<class CHAR, size_t n> CHAR * fgetline (FILE * f, CHAR (& buf)[n]) { return fgetline (f, buf, n); }
|
||||
STRING fgetline (FILE * f);
|
||||
WSTRING fgetlinew (FILE * f);
|
||||
void fgetline (FILE * f, std::string & s, ARRAY<char> & buf);
|
||||
void fgetline (FILE * f, std::wstring & s, ARRAY<char> & buf);
|
||||
void fgetline (FILE * f, ARRAY<char> & buf);
|
||||
void fgetline (FILE * f, ARRAY<wchar_t> & buf);
|
||||
|
||||
const char * fgetstring (FILE * f, char * buf, int size);
|
||||
template<size_t n> const char * fgetstring (FILE * f, char (& buf)[n]) { return fgetstring (f, buf, n); }
|
||||
const char * fgetstring (const HANDLE f, char * buf, int size);
|
||||
template<size_t n> const char * fgetstring (const HANDLE f, char (& buf)[n]) { return fgetstring (f, buf, n); }
|
||||
wstring fgetwstring (FILE * f);
|
||||
|
||||
const char * fgettoken (FILE * f, char * buf, int size);
|
||||
template<size_t n> const char * fgettoken (FILE * f, char (& buf)[n]) { return fgettoken (f, buf, n); }
|
||||
STRING fgettoken (FILE * f);
|
||||
|
||||
void fskipNewline (FILE * f);
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// fputstring(): write a 0-terminated string (terminate if error)
|
||||
// ----------------------------------------------------------------------------
|
||||
|
@ -189,32 +195,75 @@ void fputstring (FILE * f, const std::string &);
|
|||
void fputstring (FILE * f, const wchar_t *);
|
||||
void fputstring (FILE * f, const std::wstring &);
|
||||
|
||||
template<class CHAR> CHAR * fgetline (FILE * f, CHAR * buf, int size);
|
||||
template<class CHAR, size_t n> CHAR * fgetline (FILE * f, CHAR (& buf)[n]) { return fgetline (f, buf, n); }
|
||||
string fgetline (FILE * f);
|
||||
wstring fgetlinew (FILE * f);
|
||||
void fgetline (FILE * f, std::string & s, std::vector<char> & buf);
|
||||
void fgetline (FILE * f, std::wstring & s, std::vector<char> & buf);
|
||||
void fgetline (FILE * f, std::vector<char> & buf);
|
||||
void fgetline (FILE * f, std::vector<wchar_t> & buf);
|
||||
|
||||
const char * fgetstring (FILE * f, char * buf, int size);
|
||||
template<size_t n> const char * fgetstring (FILE * f, char (& buf)[n]) { return fgetstring (f, buf, n); }
|
||||
const char * fgetstring (const HANDLE f, char * buf, int size);
|
||||
template<size_t n> const char * fgetstring (const HANDLE f, char (& buf)[n]) { return fgetstring (f, buf, n); }
|
||||
|
||||
const wchar_t * fgetstring (FILE * f, wchar_t * buf, int size);
|
||||
wstring fgetwstring (FILE * f);
|
||||
string fgetstring (FILE * f);
|
||||
|
||||
const char * fgettoken (FILE * f, char * buf, int size);
|
||||
template<size_t n> const char * fgettoken (FILE * f, char (& buf)[n]) { return fgettoken (f, buf, n); }
|
||||
string fgettoken (FILE * f);
|
||||
const wchar_t * fgettoken (FILE * f, wchar_t * buf, int size);
|
||||
wstring fgetwtoken (FILE * f);
|
||||
|
||||
int fskipNewline (FILE * f, bool skip = true);
|
||||
int fskipwNewline (FILE * f, bool skip = true);
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// fputstring(): write a 0-terminated string (terminate if error)
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
void fputstring (FILE * f, const char *);
|
||||
#ifdef _WIN32
|
||||
void fputstring (const HANDLE f, const char * str);
|
||||
#endif
|
||||
void fputstring (FILE * f, const std::string &);
|
||||
void fputstring (FILE * f, const wchar_t *);
|
||||
void fputstring (FILE * f, const std::wstring &);
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// fgetTag(): read a 4-byte tag & return as a string
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
STRING fgetTag (FILE * f);
|
||||
string fgetTag (FILE * f);
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// fcheckTag(): read a 4-byte tag & verify it; terminate if wrong tag
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
void fcheckTag (FILE * f, const char * expectedTag);
|
||||
#ifdef _WIN32
|
||||
void fcheckTag (const HANDLE f, const char * expectedTag);
|
||||
void fcheckTag_ascii (FILE * f, const STRING & expectedTag);
|
||||
#endif
|
||||
void fcheckTag_ascii (FILE * f, const string & expectedTag);
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// fcompareTag(): compare two tags; terminate if wrong tag
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
void fcompareTag (const STRING & readTag, const STRING & expectedTag);
|
||||
void fcompareTag (const string & readTag, const string & expectedTag);
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// fputTag(): write a 4-byte tag
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
void fputTag (FILE * f, const char * tag);
|
||||
#ifdef _WIN32
|
||||
void fputTag(const HANDLE f, const char * tag);
|
||||
#endif
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// fskipstring(): skip a 0-terminated string, such as a pad string
|
||||
|
@ -252,10 +301,17 @@ int fgetint24 (FILE * f);
|
|||
// ----------------------------------------------------------------------------
|
||||
|
||||
int fgetint (FILE * f);
|
||||
#ifdef _WIN32
|
||||
int fgetint (const HANDLE f);
|
||||
#endif
|
||||
int fgetint_bigendian (FILE * f);
|
||||
int fgetint_ascii (FILE * f);
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// fgetlong(): read an long value
|
||||
// ----------------------------------------------------------------------------
|
||||
long fgetlong (FILE * f);
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// fgetfloat(): read a float value
|
||||
// ----------------------------------------------------------------------------
|
||||
|
@ -270,6 +326,7 @@ float fgetfloat_ascii (FILE * f);
|
|||
|
||||
double fgetdouble (FILE * f);
|
||||
|
||||
#ifdef _WIN32
|
||||
// ----------------------------------------------------------------------------
|
||||
// fgetwav(): read an entire .wav file
|
||||
// ----------------------------------------------------------------------------
|
||||
|
@ -283,6 +340,7 @@ void fgetwav (const wstring & fn, ARRAY<short> & wav, int & sampleRate);
|
|||
|
||||
void fputwav (FILE * f, const vector<short> & wav, int sampleRate, int nChannels = 1);
|
||||
void fputwav (const wstring & fn, const vector<short> & wav, int sampleRate, int nChannels = 1);
|
||||
#endif
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// fputbyte(): write a byte value
|
||||
|
@ -307,7 +365,16 @@ void fputint24 (FILE * f, int v);
|
|||
// ----------------------------------------------------------------------------
|
||||
|
||||
void fputint (FILE * f, int val);
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// fputlong(): write an long value
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
void fputlong (FILE * f, long val);
|
||||
|
||||
#ifdef _WIN32
|
||||
void fputint (const HANDLE f, int v);
|
||||
#endif
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// fputfloat(): write a float value
|
||||
|
@ -320,27 +387,154 @@ void fputfloat (FILE * f, float val);
|
|||
// ----------------------------------------------------------------------------
|
||||
|
||||
void fputdouble (FILE * f, double val);
|
||||
// template versions of put/get functions for binary files
|
||||
template <typename T>
|
||||
void fput(FILE * f, T v)
|
||||
{
|
||||
fwriteOrDie (&v, sizeof (v), 1, f);
|
||||
}
|
||||
|
||||
|
||||
// template versions of put/get functions for binary files
|
||||
template <typename T>
|
||||
void fget(FILE * f, T& v)
|
||||
{
|
||||
freadOrDie ((void *)&v, sizeof (v), 1, f);
|
||||
}
|
||||
|
||||
|
||||
// GetFormatString - get the format string for a particular type
|
||||
template <typename T>
|
||||
const wchar_t* GetFormatString(T /*t*/)
|
||||
{
|
||||
// if this _ASSERT goes off it means that you are using a type that doesn't have
|
||||
// a read and/or write routine.
|
||||
// If the type is a user defined class, you need to create some global functions that handles file in/out.
|
||||
// for example:
|
||||
//File& operator>>(File& stream, MyClass& test);
|
||||
//File& operator<<(File& stream, MyClass& test);
|
||||
//
|
||||
// in your class you will probably want to add these functions as friends so you can access any private members
|
||||
// friend File& operator>>(File& stream, MyClass& test);
|
||||
// friend File& operator<<(File& stream, MyClass& test);
|
||||
//
|
||||
// if you are using wchar_t* or char* types, these use other methods because they require buffers to be passed
|
||||
// either use std::string and std::wstring, or use the WriteString() and ReadString() methods
|
||||
assert(false); // need a specialization
|
||||
return NULL;
|
||||
}
|
||||
|
||||
// GetFormatString - specalizations to get the format string for a particular type
|
||||
template <> const wchar_t* GetFormatString(char);
|
||||
template <> const wchar_t* GetFormatString(wchar_t);
|
||||
template <> const wchar_t* GetFormatString(short);
|
||||
template <> const wchar_t* GetFormatString(int);
|
||||
template <> const wchar_t* GetFormatString(long);
|
||||
template <> const wchar_t* GetFormatString(unsigned short);
|
||||
template <> const wchar_t* GetFormatString(unsigned int);
|
||||
template <> const wchar_t* GetFormatString(unsigned long);
|
||||
template <> const wchar_t* GetFormatString(float);
|
||||
template <> const wchar_t* GetFormatString(double);
|
||||
template <> const wchar_t* GetFormatString(size_t);
|
||||
template <> const wchar_t* GetFormatString(long long);
|
||||
template <> const wchar_t* GetFormatString(const char*);
|
||||
template <> const wchar_t* GetFormatString(const wchar_t*);
|
||||
|
||||
// GetScanFormatString - get the format string for a particular type
|
||||
template <typename T>
|
||||
const wchar_t* GetScanFormatString(T t)
|
||||
{
|
||||
assert(false); // need a specialization
|
||||
return NULL;
|
||||
}
|
||||
|
||||
// GetScanFormatString - specalizations to get the format string for a particular type
|
||||
template <> const wchar_t* GetScanFormatString(char);
|
||||
template <> const wchar_t* GetScanFormatString(wchar_t);
|
||||
template <> const wchar_t* GetScanFormatString(short);
|
||||
template <> const wchar_t* GetScanFormatString(int);
|
||||
template <> const wchar_t* GetScanFormatString(long);
|
||||
template <> const wchar_t* GetScanFormatString(unsigned short);
|
||||
template <> const wchar_t* GetScanFormatString(unsigned int);
|
||||
template <> const wchar_t* GetScanFormatString(unsigned long);
|
||||
template <> const wchar_t* GetScanFormatString(float);
|
||||
template <> const wchar_t* GetScanFormatString(double);
|
||||
template <> const wchar_t* GetScanFormatString(size_t);
|
||||
template <> const wchar_t* GetScanFormatString(long long);
|
||||
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// fgetText(): get a value from a text file
|
||||
// ----------------------------------------------------------------------------
|
||||
template <typename T>
|
||||
void fgetText(FILE * f, T& v)
|
||||
{
|
||||
int rc = ftrygetText(f, v);
|
||||
if (rc == 0)
|
||||
throw std::runtime_error("error reading value from file (invalid format)");
|
||||
else if (rc == EOF)
|
||||
throw std::runtime_error(std::string("error reading from file: ") + strerror(errno));
|
||||
assert(rc == 1);
|
||||
}
|
||||
|
||||
// version to try and get a string, and not throw exceptions if contents don't match
|
||||
template <typename T>
|
||||
int ftrygetText(FILE * f, T& v)
|
||||
{
|
||||
const wchar_t* formatString = GetScanFormatString<T>(v);
|
||||
int rc = fwscanf (f, formatString, &v);
|
||||
assert(rc == 1 || rc == 0);
|
||||
return rc;
|
||||
}
|
||||
|
||||
template <> int ftrygetText<bool>(FILE * f, bool& v);
|
||||
// ----------------------------------------------------------------------------
|
||||
// fgetText() specializations for fwscanf_s differences: get a value from a text file
|
||||
// ----------------------------------------------------------------------------
|
||||
void fgetText(FILE * f, char& v);
|
||||
void fgetText(FILE * f, wchar_t& v);
|
||||
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// fputText(): write a value out as text
|
||||
// ----------------------------------------------------------------------------
|
||||
template <typename T>
|
||||
void fputText(FILE * f, T v)
|
||||
{
|
||||
const wchar_t* formatString = GetFormatString(v);
|
||||
int rc = fwprintf(f, formatString, v);
|
||||
if (rc == 0)
|
||||
throw std::runtime_error("error writing value to file, no values written");
|
||||
else if (rc < 0)
|
||||
throw std::runtime_error(std::string("error writing to file: ") + strerror(errno));
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// fputText(): write a bool out as character
|
||||
// ----------------------------------------------------------------------------
|
||||
template <> void fputText<bool>(FILE * f, bool v);
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// fputfile(): write a binary block or a string as a file
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
void fputfile (const WSTRING & pathname, const ARRAY<char> & buffer);
|
||||
void fputfile (const WSTRING & pathname, const std::wstring & string);
|
||||
void fputfile (const WSTRING & pathname, const std::string & string);
|
||||
void fputfile (const wstring & pathname, const std::vector<char> & buffer);
|
||||
void fputfile (const wstring & pathname, const std::wstring & string);
|
||||
void fputfile (const wstring & pathname, const std::string & string);
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// fgetfile(): load a file as a binary block
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
void fgetfile (const WSTRING & pathname, ARRAY<char> & buffer);
|
||||
void fgetfile (FILE * f, ARRAY<char> & buffer);
|
||||
void fgetfile (const wstring & pathname, std::vector<char> & buffer);
|
||||
void fgetfile (FILE * f, std::vector<char> & buffer);
|
||||
namespace msra { namespace files {
|
||||
void fgetfilelines (const std::wstring & pathname, vector<char> & readbuffer, std::vector<std::string> & lines);
|
||||
static inline std::vector<std::string> fgetfilelines (const std::wstring & pathname) { vector<char> buffer; std::vector<std::string> lines; fgetfilelines (pathname, buffer, lines); return lines; }
|
||||
vector<char*> fgetfilelines (const wstring & pathname, vector<char> & readbuffer);
|
||||
};};
|
||||
|
||||
#ifdef _WIN32
|
||||
// ----------------------------------------------------------------------------
|
||||
// getfiletime(), setfiletime(): access modification time
|
||||
// ----------------------------------------------------------------------------
|
||||
|
@ -348,6 +542,7 @@ namespace msra { namespace files {
|
|||
bool getfiletime (const std::wstring & path, FILETIME & time);
|
||||
void setfiletime (const std::wstring & path, const FILETIME & time);
|
||||
|
||||
#endif
|
||||
// ----------------------------------------------------------------------------
|
||||
// expand_wildcards() -- expand a path with wildcards (also intermediate ones)
|
||||
// ----------------------------------------------------------------------------
|
||||
|
@ -370,6 +565,7 @@ namespace msra { namespace files {
|
|||
bool fuptodate (const wstring & target, const wstring & input, bool inputrequired = true);
|
||||
};};
|
||||
|
||||
#ifdef _WIN32
|
||||
// ----------------------------------------------------------------------------
|
||||
// simple support for WAV file I/O
|
||||
// ----------------------------------------------------------------------------
|
||||
|
@ -408,7 +604,8 @@ void fputwfx (FILE *f, const WAVEFORMATEX & wfx, unsigned int numSamples);
|
|||
// For example, data[i][j]: i is channel index, 0 means the first
|
||||
// channel. j is sample index.
|
||||
// ----------------------------------------------------------------------------
|
||||
void fgetraw (FILE *f,ARRAY< ARRAY<short> > & data,const WAVEHEADER & wavhd);
|
||||
void fgetraw (FILE *f,std::vector< std::vector<short> > & data,const WAVEHEADER & wavhd);
|
||||
#endif
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// temp functions -- clean these up
|
||||
|
@ -445,4 +642,23 @@ static inline bool relpath (const wchar_t * path)
|
|||
template<class CHAR>
|
||||
static inline bool relpath (const std::basic_string<CHAR> & s) { return relpath (s.c_str()); }
|
||||
|
||||
// trim from start
|
||||
static inline std::string <rim(std::string &s) {
|
||||
s.erase(s.begin(), std::find_if(s.begin(), s.end(), std::not1(std::ptr_fun<int, int>(std::isspace))));
|
||||
return s;
|
||||
}
|
||||
|
||||
// trim from end
|
||||
static inline std::string &rtrim(std::string &s) {
|
||||
s.erase(std::find_if(s.rbegin(), s.rend(), std::not1(std::ptr_fun<int, int>(std::isspace))).base(), s.end());
|
||||
return s;
|
||||
}
|
||||
|
||||
// trim from both ends
|
||||
static inline std::string &trim(std::string &s) {
|
||||
return ltrim(rtrim(s));
|
||||
}
|
||||
|
||||
vector<string> sep_string(const string & str, const string & sep);
|
||||
|
||||
#endif // _FILEUTIL_
|
||||
|
|
|
@ -14,8 +14,10 @@
|
|||
#include <string>
|
||||
#include <regex>
|
||||
#include <set>
|
||||
#include <hash_map>
|
||||
#include <unordered_map>
|
||||
#include <stdint.h>
|
||||
#include <limits.h>
|
||||
#include <wchar.h>
|
||||
|
||||
namespace msra { namespace asr {
|
||||
|
||||
|
@ -263,9 +265,11 @@ public:
|
|||
#else
|
||||
W.close (numframes);
|
||||
#endif
|
||||
#ifdef _WIN32 // BUGBUG: and on Linux??
|
||||
// rename to final destination
|
||||
// (This would only fail in strange circumstances such as accidental multiple processes writing to the same file.)
|
||||
renameOrDie (tmppath, path);
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -386,7 +390,7 @@ private:
|
|||
{
|
||||
wstring physpath = ppath.physicallocation();
|
||||
//auto_file_ptr f = fopenOrDie (physpath, L"rbS");
|
||||
auto_file_ptr f = fopenOrDie (physpath, L"rb"); // removed 'S' for now, as we mostly run local anyway, and this will speed up debugging
|
||||
auto_file_ptr f(fopenOrDie (physpath, L"rb")); // removed 'S' for now, as we mostly run local anyway, and this will speed up debugging
|
||||
|
||||
// read the header (12 bytes for htk feature files)
|
||||
fileheader H;
|
||||
|
@ -655,7 +659,7 @@ private:
|
|||
public:
|
||||
|
||||
// parse format with original HTK state align MLF format and state list
|
||||
void parsewithstatelist (const vector<char*> & toks, const hash_map<const string, size_t> & statelisthash, const double htkTimeToFrame)
|
||||
void parsewithstatelist (const vector<char*> & toks, const unordered_map<std::string, size_t> & statelisthash, const double htkTimeToFrame)
|
||||
{
|
||||
size_t ts, te;
|
||||
parseframerange (toks, ts, te, htkTimeToFrame);
|
||||
|
@ -682,7 +686,7 @@ template<class ENTRY, class WORDSEQUENCE>
|
|||
class htkmlfreader : public map<wstring,vector<ENTRY>> // [key][i] the data
|
||||
{
|
||||
wstring curpath; // for error messages
|
||||
hash_map<const std::string, size_t> statelistmap; // for state <=> index
|
||||
unordered_map<std::string, size_t> statelistmap; // for state <=> index
|
||||
map<wstring,WORDSEQUENCE> wordsequences; // [key] word sequences (if we are building word entries as well, for MMI)
|
||||
|
||||
void strtok (char * s, const char * delim, vector<char*> & toks)
|
||||
|
@ -700,7 +704,7 @@ class htkmlfreader : public map<wstring,vector<ENTRY>> // [key][i] the data
|
|||
vector<char*> readlines (const wstring & path, vector<char> & buffer)
|
||||
{
|
||||
// load it into RAM in one huge chunk
|
||||
auto_file_ptr f = fopenOrDie (path, L"rb");
|
||||
auto_file_ptr f(fopenOrDie (path, L"rb"));
|
||||
size_t len = filesize (f);
|
||||
buffer.reserve (len +1);
|
||||
freadOrDie (buffer, len, f);
|
||||
|
@ -752,7 +756,12 @@ class htkmlfreader : public map<wstring,vector<ENTRY>> // [key][i] the data
|
|||
|
||||
filename = filename.substr (1, filename.length() -2); // strip quotes
|
||||
if (filename.find ("*/") == 0) filename = filename.substr (2);
|
||||
#ifdef _WIN32
|
||||
wstring key = msra::strfun::utf16 (regex_replace (filename, regex ("\\.[^\\.\\\\/:]*$"), string())); // delete extension (or not if none)
|
||||
#endif
|
||||
#ifdef __unix__
|
||||
wstring key = msra::strfun::utf16 (removeExtension(basename(filename))); // note that c++ 4.8 is incomplete for supporting regex
|
||||
#endif
|
||||
|
||||
// determine lines range
|
||||
size_t s = line;
|
||||
|
@ -785,7 +794,7 @@ class htkmlfreader : public map<wstring,vector<ENTRY>> // [key][i] the data
|
|||
const char * w = toks[6]; // the word name
|
||||
int wid = (*wordmap)[w]; // map to word id --may be -1 for unseen words in the transcript (word list typically comes from a test LM)
|
||||
size_t wordindex = (wid == -1) ? WORDSEQUENCE::word::unknownwordindex : (size_t) wid;
|
||||
wordseqbuffer.push_back (WORDSEQUENCE::word (wordindex, entries[i-s].firstframe, alignseqbuffer.size()));
|
||||
wordseqbuffer.push_back (typename WORDSEQUENCE::word (wordindex, entries[i-s].firstframe, alignseqbuffer.size()));
|
||||
}
|
||||
if (unitmap)
|
||||
{
|
||||
|
@ -796,7 +805,7 @@ class htkmlfreader : public map<wstring,vector<ENTRY>> // [key][i] the data
|
|||
if (iter == unitmap->end())
|
||||
throw std::runtime_error (string ("parseentry: unknown unit ") + u + " in utterance " + strfun::utf8 (key));
|
||||
const size_t uid = iter->second;
|
||||
alignseqbuffer.push_back (WORDSEQUENCE::aligninfo (uid, 0/*#frames--we accumulate*/));
|
||||
alignseqbuffer.push_back (typename WORDSEQUENCE::aligninfo (uid, 0/*#frames--we accumulate*/));
|
||||
}
|
||||
if (alignseqbuffer.empty())
|
||||
throw std::runtime_error ("parseentry: lonely senone entry at start without phone/word entry found, for utterance " + strfun::utf8 (key));
|
||||
|
@ -880,7 +889,7 @@ public:
|
|||
template<typename WORDSYMBOLTABLE, typename UNITSYMBOLTABLE>
|
||||
void read (const wstring & path, const set<wstring> & restricttokeys, const WORDSYMBOLTABLE * wordmap, const UNITSYMBOLTABLE * unitmap, const double htkTimeToFrame)
|
||||
{
|
||||
if (!restricttokeys.empty() && size() >= restricttokeys.size()) // no need to even read the file if we are there (we support multiple files)
|
||||
if (!restricttokeys.empty() && this->size() >= restricttokeys.size()) // no need to even read the file if we are there (we support multiple files)
|
||||
return;
|
||||
|
||||
fprintf (stderr, "htkmlfreader: reading MLF file %S ...", path.c_str());
|
||||
|
@ -888,18 +897,18 @@ public:
|
|||
|
||||
vector<char> buffer; // buffer owns the characters--don't release until done
|
||||
vector<char*> lines = readlines (path, buffer);
|
||||
vector<WORDSEQUENCE::word> wordsequencebuffer;
|
||||
vector<WORDSEQUENCE::aligninfo> alignsequencebuffer;
|
||||
vector<typename WORDSEQUENCE::word> wordsequencebuffer;
|
||||
vector<typename WORDSEQUENCE::aligninfo> alignsequencebuffer;
|
||||
|
||||
if (lines.empty() || strcmp (lines[0], "#!MLF!#")) malformed ("header missing");
|
||||
|
||||
// parse entries
|
||||
size_t line = 1;
|
||||
while (line < lines.size() && (restricttokeys.empty() || size() < restricttokeys.size()))
|
||||
while (line < lines.size() && (restricttokeys.empty() || this->size() < restricttokeys.size()))
|
||||
parseentry (lines, line, restricttokeys, wordmap, unitmap, wordsequencebuffer, alignsequencebuffer, htkTimeToFrame);
|
||||
|
||||
curpath.clear();
|
||||
fprintf (stderr, " total %lu entries\n", size());
|
||||
fprintf (stderr, " total %lu entries\n", this->size());
|
||||
}
|
||||
|
||||
// read state list, index is from 0
|
||||
|
|
|
@ -18,7 +18,7 @@
|
|||
#include <vector>
|
||||
#include <string>
|
||||
#include <set>
|
||||
#include <hash_map>
|
||||
#include <unordered_map>
|
||||
#include <regex>
|
||||
|
||||
#pragma warning(disable : 4996)
|
||||
|
@ -95,20 +95,6 @@ static size_t tryfind (const MAPTYPE & map, const KEYTYPE & key, VALTYPE deflt)
|
|||
const msra::asr::htkmlfreader<msra::asr::htkmlfentry,msra::lattices::lattice::htkmlfwordsequence> & labels, // non-empty: build numer lattices
|
||||
const msra::lm::CMGramLM & unigram, const msra::lm::CSymbolSet & unigramsymbols) // for numer lattices
|
||||
{
|
||||
#if 0 // little unit test helper for testing the read function
|
||||
bool test = true;
|
||||
if (test)
|
||||
{
|
||||
archive a;
|
||||
a.open (outpath + L".toc");
|
||||
lattice L;
|
||||
std::hash_map<string,size_t> symmap;
|
||||
a.getlattice (L"sw2001_A_1263622500_1374610000", L, symmap);
|
||||
a.getlattice (L"sw2001_A_1391162500_1409287500", L, symmap);
|
||||
return;
|
||||
}
|
||||
#endif
|
||||
|
||||
const bool numermode = !labels.empty(); // if labels are passed then we shall convert the MLFs to lattices, and 'infiles' are regular keys
|
||||
|
||||
const std::wstring tocpath = outpath + L".toc";
|
||||
|
|
|
@ -12,6 +12,8 @@
|
|||
#undef HACK_IN_SILENCE // [v-hansu] hack to simulate DEL in the lattice
|
||||
#define SILENCE_PENALTY // give penalty to added silence
|
||||
|
||||
#define __STDC_FORMAT_MACROS
|
||||
#include <inttypes.h>
|
||||
|
||||
#include "basetypes.h"
|
||||
#include "latticestorage.h"
|
||||
|
@ -20,11 +22,9 @@
|
|||
#include <stdint.h>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <hash_map>
|
||||
#include <unordered_map>
|
||||
#include <algorithm> // for find()
|
||||
#include "simplesenonehmm.h"
|
||||
|
||||
namespace msra { namespace math { class ssematrixbase; template<class ssematrixbase> class ssematrix; template<class ssematrixbase> class ssematrixstriperef; };};
|
||||
|
||||
namespace msra { namespace lm { class CMGramLM; class CSymbolSet; };}; // for numer-lattice building
|
||||
|
@ -60,7 +60,7 @@ class lattice
|
|||
size_t impliedspunitid : 31; // id of implied last unit (intended as /sp/); only used in V2
|
||||
size_t hasacscores : 1; // if 1 then ac scores are embedded
|
||||
|
||||
header_v1_v2() : numnodes (0), numedges (0), lmf (1.0f), wp (0.0f), frameduration (0.01/*assumption*/), numframes (0), impliedspunitid (SIZE_MAX), hasacscores (1) { }
|
||||
header_v1_v2() : numnodes (0), numedges (0), lmf (1.0f), wp (0.0f), frameduration (0.01/*assumption*/), numframes (0), impliedspunitid (INT_MAX), hasacscores (1) { }
|
||||
};
|
||||
header_v1_v2 info; // information about the lattice
|
||||
static const unsigned int NOEDGE = 0xffffff; // 24 bits
|
||||
|
@ -188,7 +188,7 @@ public: // TODO: make private again once
|
|||
if (ai.size() < 2) // less than 2--must be /sil/
|
||||
continue;
|
||||
spunit = ai[ai.size() - 1].unit;
|
||||
fprintf (stderr, "builduniquealignments: /sp/ unit inferred through heuristics as %d\n", spunit);
|
||||
fprintf (stderr, "builduniquealignments: /sp/ unit inferred through heuristics as %d\n", (int)spunit);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
@ -235,7 +235,7 @@ public: // TODO: make private again once
|
|||
&& nodes[edges[prevj].E].t == nodes[edges[j].E].t
|
||||
&& edges[prevj].l != edges[j].l) // some diagnostics
|
||||
fprintf (stderr, "build: merging edges %d and %d despite slightly different LM scores %.8f vs. %.8f, ts/te=%.2f/%.2f\n",
|
||||
prevj, j, edges[prevj].l, edges[j].l, nodes[edges[prevj].S].t * 0.01f, nodes[edges[prevj].E].t * 0.01f);
|
||||
(int)prevj, (int)j, edges[prevj].l, edges[j].l, nodes[edges[prevj].S].t * 0.01f, nodes[edges[prevj].E].t * 0.01f);
|
||||
#endif
|
||||
if (prevj == SIZE_MAX || fabs (edges[prevj].l - edges[j].l) > lmargin || (info.hasacscores && edges[prevj].a != edges[j].a) || comparealign (prevj, j, false) != 0)
|
||||
{
|
||||
|
@ -287,8 +287,8 @@ public: // TODO: make private again once
|
|||
const size_t uniquealigntokens = uniquededgedatatokens.size() - (numuniquealignments * (info.hasacscores ? 2 : 1));
|
||||
const size_t nonuniquenonsptokens = align.size() - numimpliedsp;
|
||||
fprintf (stderr, "builduniquealignments: %d edges: %d unique alignments (%.2f%%); %d align tokens - %d implied /sp/ units = %d, uniqued to %d (%.2f%%)\n",
|
||||
edges.size(), numuniquealignments, 100.0f * numuniquealignments / edges.size(),
|
||||
align.size(), numimpliedsp, nonuniquenonsptokens, uniquealigntokens, 100.0f * uniquealigntokens / nonuniquenonsptokens);
|
||||
(int)edges.size(), (int)numuniquealignments, 100.0f * numuniquealignments / edges.size(),
|
||||
(int)align.size(), (int)numimpliedsp, (int)nonuniquenonsptokens, (int)uniquealigntokens, 100.0f * uniquealigntokens / nonuniquenonsptokens);
|
||||
|
||||
// sort it back into original order (sorted by E, then by S)
|
||||
sort (edges2.begin(), edges2.end(), [&] (const edgeinfo & e1, const edgeinfo & e2) { return latticeorder (e1, e2) < 0; });
|
||||
|
@ -507,7 +507,7 @@ public:
|
|||
}
|
||||
};
|
||||
|
||||
typedef aligninfo aligninfo; // now we can access it as htkmlfwordsequence::aligninfo although it comes from some totally other corner of the system
|
||||
typedef msra::lattices::aligninfo aligninfo; // now we can access it as htkmlfwordsequence::aligninfo although it comes from some totally other corner of the system
|
||||
|
||||
std::vector<word> words;
|
||||
std::vector<aligninfo> align;
|
||||
|
@ -593,7 +593,7 @@ private:
|
|||
#if 1 // multiple /sil/ -> log this (as we are not sure whether this is actually proper--probably it is)
|
||||
if (numsilunits > 1)
|
||||
{
|
||||
fprintf (stderr, "backpointers: lattice '%S', edge %d has %d /sil/ phonemes\n", L.getkey(), j, numsilunits);
|
||||
fprintf (stderr, "backpointers: lattice '%S', edge %d has %d /sil/ phonemes\n", L.getkey(), j, (int)numsilunits);
|
||||
fprintf (stderr, "alignments: :");
|
||||
foreach_index (a, aligntokens)
|
||||
{
|
||||
|
@ -643,9 +643,9 @@ private:
|
|||
double bestpathlattice (const std::vector<float> & edgeacscores, std::vector<double> & logpps,
|
||||
const float lmf, const float wp, const float amf) const;
|
||||
|
||||
static float lattice::alignedge (const_array_ref<aligninfo> units, const msra::asr::simplesenonehmm & hset,
|
||||
const msra::math::ssematrixbase & logLLs, msra::math::ssematrixbase & gammas,
|
||||
size_t edgeindex, const bool returnsenoneids, array_ref<unsigned short> thisedgealignments);
|
||||
static float alignedge (const_array_ref<aligninfo> units, const msra::asr::simplesenonehmm & hset,
|
||||
const msra::math::ssematrixbase & logLLs, msra::math::ssematrixbase & gammas,
|
||||
size_t edgeindex, const bool returnsenoneids, array_ref<unsigned short> thisedgealignments);
|
||||
|
||||
const_array_ref<aligninfo> getaligninfo (size_t j) const { size_t begin = (size_t) edges[j].firstalign; size_t end = j+1 < edges.size() ? (size_t) edges[j+1].firstalign : align.size(); return const_array_ref<aligninfo> (align.data() + begin, end - begin); }
|
||||
|
||||
|
@ -674,9 +674,9 @@ private:
|
|||
const std::vector<float> & transcriptunigrams, const msra::math::ssematrixbase & logLLs,
|
||||
const msra::asr::simplesenonehmm & hset, const float lmf, const float wp, const float amf);
|
||||
|
||||
static float lattice::forwardbackwardedge (const_array_ref<aligninfo> units, const msra::asr::simplesenonehmm & hset,
|
||||
const msra::math::ssematrixbase & logLLs, msra::math::ssematrixbase & gammas,
|
||||
size_t edgeindex);
|
||||
static float forwardbackwardedge (const_array_ref<aligninfo> units, const msra::asr::simplesenonehmm & hset,
|
||||
const msra::math::ssematrixbase & logLLs, msra::math::ssematrixbase & gammas,
|
||||
size_t edgeindex);
|
||||
|
||||
double forwardbackwardlattice (const std::vector<float> & edgeacscores, parallelstate & parallelstate,
|
||||
std::vector<double> & logpps, std::vector<double> & logalphas, std::vector<double> & logbetas,
|
||||
|
@ -747,7 +747,7 @@ public:
|
|||
for (size_t j = 0; j < info.numedges; j++)
|
||||
totaledgeframes += nodes[edges[j].E].t - (size_t) nodes[edges[j].S].t;
|
||||
fprintf (stderr, "lattice: read %d nodes, %d edges, %d units, %d frames, %.1f edges/node, %.1f units/edge, %.1f frames/edge, density %.1f\n",
|
||||
info.numnodes, info.numedges, align.size(), info.numframes,
|
||||
(int)info.numnodes, (int)info.numedges, (int)align.size(), (int)info.numframes,
|
||||
info.numedges / (double) info.numnodes, align.size() / (double) info.numedges, totaledgeframes / (double) info.numedges, totaledgeframes / (double) info.numframes);
|
||||
}
|
||||
|
||||
|
@ -895,7 +895,7 @@ public:
|
|||
#if 1 // post-bugfix for incorrect inference of spunit
|
||||
if (info.impliedspunitid != SIZE_MAX && info.impliedspunitid >= idmap.size()) // we have buggy lattices like that--what do they mean??
|
||||
{
|
||||
fprintf (stderr, "fread: detected buggy spunit id %d which is out of range (%d entries in map)\n", info.impliedspunitid, idmap.size());
|
||||
fprintf (stderr, "fread: detected buggy spunit id %d which is out of range (%d entries in map)\n", (int)info.impliedspunitid, (int)idmap.size());
|
||||
throw std::runtime_error ("fread: out of bounds spunitid");
|
||||
}
|
||||
#endif
|
||||
|
@ -949,7 +949,7 @@ public:
|
|||
k += skipscoretokens;
|
||||
uniquealignments++;
|
||||
}
|
||||
fprintf (stderr, "fread: mapped %d unique alignments\n", uniquealignments);
|
||||
fprintf (stderr, "fread: mapped %d unique alignments\n", (int)uniquealignments);
|
||||
}
|
||||
if (info.impliedspunitid != spunit)
|
||||
{
|
||||
|
@ -1078,7 +1078,7 @@ class archive
|
|||
|
||||
mutable size_t currentarchiveindex; // which archive is open
|
||||
mutable auto_file_ptr f; // cached archive file handle of currentarchiveindex
|
||||
hash_map<std::wstring,latticeref> toc; // [key] -> (file, offset) --table of content (.toc file)
|
||||
unordered_map<std::wstring, latticeref> toc; // [key] -> (file, offset) --table of content (.toc file)
|
||||
public:
|
||||
// construct = open the archive
|
||||
//archive() : currentarchiveindex (SIZE_MAX) {}
|
||||
|
@ -1091,13 +1091,13 @@ public:
|
|||
{
|
||||
if (tocpaths.empty()) // nothing to read--keep silent
|
||||
return;
|
||||
fprintf (stderr, "archive: opening %d lattice-archive TOC files ('%S' etc.)..", tocpaths.size(), tocpaths[0].c_str());
|
||||
fprintf (stderr, "archive: opening %d lattice-archive TOC files ('%S' etc.)..", (int)tocpaths.size(), tocpaths[0].c_str());
|
||||
foreach_index (i, tocpaths)
|
||||
{
|
||||
fprintf (stderr, ".");
|
||||
open (tocpaths[i]);
|
||||
}
|
||||
fprintf (stderr, " %d total lattices referenced in %d archive files\n", toc.size(), archivepaths.size());
|
||||
fprintf (stderr, " %d total lattices referenced in %d archive files\n", (int)toc.size(), (int)archivepaths.size());
|
||||
}
|
||||
|
||||
// open an archive
|
||||
|
@ -1133,7 +1133,12 @@ public:
|
|||
throw std::runtime_error ("open: invalid TOC line (empty archive pathname): " + std::string (line));
|
||||
char c;
|
||||
uint64_t offset;
|
||||
#ifdef _WIN32
|
||||
if (sscanf_s (q, "[%I64u]%c", &offset, &c, sizeof (c)) != 1)
|
||||
#else
|
||||
|
||||
if (sscanf (q, "[%" PRIu64 "]%c", &offset, &c) != 1)
|
||||
#endif
|
||||
throw std::runtime_error ("open: invalid TOC line (bad [] expression): " + std::string (line));
|
||||
if (!toc.insert (make_pair (key, latticeref (offset, archiveindex))).second)
|
||||
throw std::runtime_error ("open: TOC entry leads to duplicate key: " + std::string (line));
|
||||
|
|
|
@ -25,7 +25,7 @@ static void checkoverflow (size_t fieldval, size_t targetval, const char * field
|
|||
if (fieldval != targetval)
|
||||
{
|
||||
char buf[1000];
|
||||
sprintf_s (buf, "lattice: bit field %s too small for value 0x%x (cut from 0x%x)", fieldname, targetval, fieldval);
|
||||
sprintf_s (buf, sizeof(buf), "lattice: bit field %s too small for value 0x%x (cut from 0x%x)", fieldname, (unsigned int)targetval, (unsigned int)fieldval);
|
||||
throw std::runtime_error (buf);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -209,7 +209,7 @@ public:
|
|||
{
|
||||
firstvalidepochstartframe = source.firstvalidglobalts (epochstartframe); // epochstartframe may fall between utterance boundaries; this gets us the first valid boundary
|
||||
fprintf (stderr, "minibatchiterator: epoch %d: frames [%d..%d] (first utterance at frame %d), data subset %d of %d, with %d datapasses\n",
|
||||
epoch, epochstartframe, epochendframe, firstvalidepochstartframe, subsetnum, numsubsets, datapasses);
|
||||
(int)epoch, (int)epochstartframe, (int)epochendframe, (int)firstvalidepochstartframe, (int)subsetnum, (int)numsubsets, (int)datapasses);
|
||||
mbstartframe = firstvalidepochstartframe;
|
||||
datapass = 0;
|
||||
fillorclear(); // get the first batch
|
||||
|
@ -228,7 +228,7 @@ public:
|
|||
{
|
||||
firstvalidepochstartframe = source.firstvalidglobalts (epochstartframe); // epochstartframe may fall between utterance boundaries; this gets us the first valid boundary
|
||||
fprintf (stderr, "minibatchiterator: epoch %d: frames [%d..%d] (first utterance at frame %d), data subset %d of %d, with %d datapasses\n",
|
||||
epoch, epochstartframe, epochendframe, firstvalidepochstartframe, subsetnum, numsubsets, datapasses);
|
||||
(int)epoch, (int)epochstartframe, (int)epochendframe, (int)firstvalidepochstartframe, (int)subsetnum, (int)numsubsets, (int)datapasses);
|
||||
mbstartframe = firstvalidepochstartframe;
|
||||
datapass = 0;
|
||||
fillorclear(); // get the first batch
|
||||
|
@ -253,7 +253,7 @@ public:
|
|||
{
|
||||
mbstartframe = firstvalidepochstartframe;
|
||||
datapass++;
|
||||
fprintf (stderr, "\nminibatchiterator: entering %d-th repeat pass through the data\n", datapass+1);
|
||||
fprintf (stderr, "\nminibatchiterator: entering %d-th repeat pass through the data\n", (int)(datapass+1));
|
||||
}
|
||||
fillorclear();
|
||||
}
|
||||
|
|
|
@ -12,7 +12,9 @@
|
|||
#include <stdio.h>
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
#ifndef __unix__
|
||||
#include "ssematrix.h" // for matrix type
|
||||
#endif
|
||||
|
||||
namespace msra { namespace dbn {
|
||||
|
||||
|
@ -246,7 +248,7 @@ public:
|
|||
retries++;
|
||||
}
|
||||
}
|
||||
fprintf (stderr, "randomordering: %d retries for %d elements (%.1f%%) to ensure window condition\n", retries, map.size(), 100.0 * retries / map.size());
|
||||
fprintf (stderr, "randomordering: %d retries for %d elements (%.1f%%) to ensure window condition\n", (int)retries, (int)map.size(), 100.0 * retries / map.size());
|
||||
// ensure the window condition
|
||||
foreach_index (t, map) assert ((size_t) t <= map[t] + randomizationrange/2 && map[t] < (size_t) t + randomizationrange/2);
|
||||
#if 1 // and a live check since I don't trust myself here yet
|
||||
|
|
|
@ -12,7 +12,7 @@
|
|||
#include "fileutil.h" // for opening/reading the ARPA file
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <hash_map>
|
||||
#include <unordered_map>
|
||||
#include <algorithm> // for various sort() calls
|
||||
#include <math.h>
|
||||
|
||||
|
@ -22,8 +22,9 @@ namespace msra { namespace lm {
|
|||
// core LM interface -- LM scores are accessed through this exclusively
|
||||
// ===========================================================================
|
||||
|
||||
interface ILM // generic interface -- mostly the score() function
|
||||
class ILM // generic interface -- mostly the score() function
|
||||
{
|
||||
public:
|
||||
virtual double score (const int * mgram, int m) const = 0;
|
||||
virtual bool oov (int w) const = 0; // needed for perplexity calculation
|
||||
// ... TODO (?): return true/false to indicate whether anything changed.
|
||||
|
@ -31,8 +32,9 @@ interface ILM // generic interface -- mostly the score() function
|
|||
virtual void adapt (const int * data, size_t m) = 0; // (NULL,M) to reset, (!NULL,0) to flush
|
||||
|
||||
// iterator for composing models --iterates in increasing order w.r.t. w
|
||||
interface IIter
|
||||
class IIter
|
||||
{
|
||||
public:
|
||||
virtual operator bool() const = 0; // has iterator not yet reached end?
|
||||
// ... TODO: ensure iterators do not return OOVs w.r.t. user symbol table
|
||||
// (It needs to be checked which LM type's iterator currently does.)
|
||||
|
@ -85,7 +87,7 @@ static inline double invertlogprob (double logP) { return logclip (1.0 - exp (lo
|
|||
// CSymbolSet -- a simple symbol table
|
||||
// ===========================================================================
|
||||
|
||||
// compare function to allow char* as keys (without, hash_map will correctly
|
||||
// compare function to allow char* as keys (without, unordered_map will correctly
|
||||
// compute a hash key from the actual strings, but then compare the pointers
|
||||
// -- duh!)
|
||||
struct less_strcmp : public binary_function<const char *, const char *, bool>
|
||||
|
@ -94,7 +96,7 @@ struct less_strcmp : public binary_function<const char *, const char *, bool>
|
|||
{ return strcmp (_Left, _Right) < 0; }
|
||||
};
|
||||
|
||||
class CSymbolSet : public stdext::hash_map<const char *, int, stdext::hash_compare<const char*,less_strcmp>>
|
||||
class CSymbolSet : public std::unordered_map<const char *, int, std::hash<const char*>, less_strcmp>
|
||||
{
|
||||
vector<const char *> symbols; // the symbols
|
||||
|
||||
|
@ -106,14 +108,14 @@ public:
|
|||
void clear()
|
||||
{
|
||||
foreach_index (i, symbols) free ((void*) symbols[i]);
|
||||
hash_map::clear();
|
||||
unordered_map::clear();
|
||||
}
|
||||
|
||||
// operator[key] on a 'const' object
|
||||
// get id for an existing word, returns -1 if not existing
|
||||
int operator[] (const char * key) const
|
||||
{
|
||||
hash_map<const char *,int>::const_iterator iter = find (key);
|
||||
unordered_map<const char *, int>::const_iterator iter = find(key);
|
||||
return (iter != end()) ? iter->second : -1;
|
||||
}
|
||||
|
||||
|
@ -121,14 +123,18 @@ public:
|
|||
// determine unique id for a word ('key')
|
||||
int operator[] (const char * key)
|
||||
{
|
||||
hash_map<const char *,int>::const_iterator iter = find (key);
|
||||
unordered_map<const char *, int>::const_iterator iter = find(key);
|
||||
if (iter != end())
|
||||
return iter->second;
|
||||
|
||||
// create
|
||||
const char * p = _strdup (key);
|
||||
if (!p)
|
||||
#ifdef _WIN32
|
||||
throw std::bad_exception ("CSymbolSet:id string allocation failure");
|
||||
#else
|
||||
throw std::bad_exception ();
|
||||
#endif
|
||||
try
|
||||
{
|
||||
int id = (int) symbols.size();
|
||||
|
@ -274,7 +280,7 @@ class mgram_map
|
|||
{
|
||||
typedef unsigned int index_t; // (-> size_t when we really need it)
|
||||
//typedef size_t index_t; // (tested once, seems to work)
|
||||
static const index_t nindex = (index_t) -1; // invalid index
|
||||
static const index_t nindex; // invalid index
|
||||
// entry [m][i] is first index of children in level m+1, entry[m][i+1] the end.
|
||||
int M; // order, e.g. M=3 for trigram
|
||||
std::vector<std::vector<index_t>> firsts; // [M][i] ([0] = zerogram = root)
|
||||
|
@ -1124,7 +1130,7 @@ public:
|
|||
void read (const std::wstring & pathname, SYMMAP & userSymMap, bool filterVocabulary, int maxM)
|
||||
{
|
||||
int lineNo = 0;
|
||||
msra::basetypes::auto_file_ptr f = fopenOrDie (pathname, L"rbS");
|
||||
msra::basetypes::auto_file_ptr f(fopenOrDie (pathname, L"rbS"));
|
||||
fprintf (stderr, "read: reading %S", pathname.c_str());
|
||||
filename = pathname; // (keep this info for debugging)
|
||||
|
||||
|
@ -1769,7 +1775,7 @@ protected:
|
|||
mcounts.push_back (mmap.create (newkey, mmapCache), count); // store 'count' under 'key'
|
||||
}
|
||||
}
|
||||
fprintf (stderr, " %d %d-grams", mcounts.size (m), m);
|
||||
fprintf (stderr, " %d %d-grams", (int)mcounts.size (m), m);
|
||||
}
|
||||
|
||||
// remove used up tokens from the buffer
|
||||
|
@ -1977,10 +1983,11 @@ public:
|
|||
//// set prune value to 0 3 3
|
||||
//setMinObs (iMinObs);
|
||||
|
||||
for (size_t i = 0; i < minObs.size(); i++)
|
||||
{
|
||||
MESSAGE("minObs %d: %d.", i, minObs[i]);
|
||||
}
|
||||
// TODO: Re-enable when MESSAGE definition is provided (printf?)
|
||||
// for (size_t i = 0; i < minObs.size(); i++)
|
||||
// {
|
||||
// MESSAGE("minObs %d: %d.", i, minObs[i]);
|
||||
// }
|
||||
|
||||
estimate (startId, minObs, dropWord);
|
||||
|
||||
|
@ -2027,7 +2034,7 @@ public:
|
|||
while (M > 0 && counts.size (M) == 0) resize (M-1);
|
||||
|
||||
for (int m = 1; m <= M; m++)
|
||||
fprintf (stderr, "estimate: read %d %d-grams\n", counts.size (m), m);
|
||||
fprintf (stderr, "estimate: read %d %d-grams\n", (int)counts.size (m), m);
|
||||
|
||||
// === Kneser-Ney smoothing
|
||||
// This is a strange algorithm.
|
||||
|
@ -2197,8 +2204,8 @@ public:
|
|||
for (int m = 1; m <= M; m++)
|
||||
{
|
||||
fprintf (stderr, "estimate: %d-grams after pruning: %d out of %d (%.1f%%)\n", m,
|
||||
numMGrams[m], counts.size (m),
|
||||
100.0 * numMGrams[m] / max (counts.size (m), 1));
|
||||
numMGrams[m], (int)counts.size (m),
|
||||
100.0 * numMGrams[m] / max (counts.size (m), size_t(1)));
|
||||
}
|
||||
|
||||
// ensure M reflects the actual order of read data after pruning
|
||||
|
@ -2282,6 +2289,9 @@ public:
|
|||
}
|
||||
}
|
||||
|
||||
double dcount;
|
||||
double dP;
|
||||
|
||||
// pruned case
|
||||
if (count == 0) // this entry was pruned before
|
||||
goto skippruned;
|
||||
|
@ -2314,7 +2324,7 @@ public:
|
|||
}
|
||||
|
||||
// estimate discounted probability
|
||||
double dcount = count; // "modified Kneser-Ney" discounting
|
||||
dcount = count; // "modified Kneser-Ney" discounting
|
||||
if (count >= 3) dcount -= d3[m];
|
||||
else if (count == 2) dcount -= d2[m];
|
||||
else if (count == 1) dcount -= d1[m];
|
||||
|
@ -2323,7 +2333,7 @@ public:
|
|||
|
||||
if (histCount == 0)
|
||||
RuntimeError ("estimate: unexpected 0 denominator");
|
||||
double dP = dcount / histCount;
|
||||
dP = dcount / histCount;
|
||||
// and this is the discounted probability value
|
||||
{
|
||||
// Actually, 'key' uses a "mapped" word ids, while create()
|
||||
|
@ -2412,7 +2422,7 @@ skippruned:; // m-gram was pruned
|
|||
updateOOVScore();
|
||||
|
||||
fprintf (stderr, "estimate: done");
|
||||
for (int m = 1; m <= M; m++) fprintf (stderr, ", %d %d-grams", logP.size (m), m);
|
||||
for (int m = 1; m <= M; m++) fprintf (stderr, ", %d %d-grams", (int)logP.size (m), m);
|
||||
fprintf (stderr, "\n");
|
||||
}
|
||||
};
|
||||
|
@ -2521,7 +2531,7 @@ skipMGram:
|
|||
wstring dir, file;
|
||||
splitpath (clonepath, dir, file); // we allow relative paths in the file
|
||||
|
||||
msra::basetypes::auto_file_ptr f = fopenOrDie (clonepath, L"rbS");
|
||||
msra::basetypes::auto_file_ptr f(fopenOrDie (clonepath, L"rbS"));
|
||||
std::string line = fgetline (f);
|
||||
if (line != "#clone")
|
||||
throw runtime_error ("read: invalid header line " + line);
|
||||
|
|
|
@ -7,9 +7,11 @@
|
|||
|
||||
#pragma once
|
||||
|
||||
#ifndef __unix__
|
||||
#include <Windows.h>
|
||||
#include <stdexcept>
|
||||
#include "pplhelpers.h"
|
||||
#endif
|
||||
#include <stdexcept>
|
||||
#include "simple_checked_arrays.h"
|
||||
#include "basetypes.h" // for FormatWin32Error
|
||||
|
||||
|
|
|
@ -8,8 +8,9 @@
|
|||
|
||||
#pragma once
|
||||
|
||||
#ifndef __unix__
|
||||
#include <ppl.h>
|
||||
|
||||
#endif
|
||||
namespace msra { namespace parallel {
|
||||
|
||||
// ===========================================================================
|
||||
|
|
|
@ -12,7 +12,9 @@
|
|||
#include "basetypes.h"
|
||||
#include "minibatchiterator.h"
|
||||
#include "latticearchive.h"
|
||||
#ifdef _WIN32
|
||||
#include "simplethread.h"
|
||||
#endif
|
||||
#include <deque>
|
||||
#include <stdexcept>
|
||||
|
||||
|
|
|
@ -9,7 +9,9 @@
|
|||
#pragma once
|
||||
|
||||
#include "basetypes.h" // for attempt()
|
||||
#ifdef _WIN32
|
||||
#include "numahelpers.h" // for NUMA allocation
|
||||
#endif
|
||||
#include "minibatchsourcehelpers.h"
|
||||
#include "minibatchiterator.h"
|
||||
#include "biggrowablevectors.h"
|
||||
|
@ -37,9 +39,13 @@ namespace msra { namespace dbn {
|
|||
msra::dbn::matrix * newblock() const
|
||||
{
|
||||
// we stripe the data across NUMA nodes as to not fill up one node with the feature data
|
||||
#ifdef _WIN32
|
||||
msra::numa::overridenode ((int) msra::numa::getmostspaciousnumanode());
|
||||
#endif
|
||||
msra::dbn::matrix * res = new msra::dbn::matrix (m, elementsperblock);
|
||||
#ifdef _WIN32
|
||||
msra::numa::overridenode (-1); // note: we really should reset it also in case of failure
|
||||
#endif
|
||||
return res;
|
||||
}
|
||||
|
||||
|
@ -100,7 +106,7 @@ namespace msra { namespace dbn {
|
|||
size_t blockid = t0 / elementsperblock;
|
||||
assert (blockid * elementsperblock == t0);
|
||||
assert (blocks[blockid]);
|
||||
fprintf (stderr, "recoverblock: releasing feature block %d [%d..%d)\n", blockid, t0, t0 + elementsperblock -1);
|
||||
fprintf (stderr, "recoverblock: releasing feature block %d [%d..%d)\n", (int)blockid, (int)t0, (int)(t0 + elementsperblock -1));
|
||||
blocks[blockid].reset(); // free the memory
|
||||
}
|
||||
void recoverblock (size_t t0) // t0=block start time
|
||||
|
@ -109,7 +115,7 @@ namespace msra { namespace dbn {
|
|||
size_t blockid = t0 / elementsperblock;
|
||||
assert (blockid * elementsperblock == t0);
|
||||
assert (!blocks[blockid]);
|
||||
fprintf (stderr, "recoverblock: recovering feature block %d [%d..%d)\n", blockid, t0, t0 + elementsperblock -1);
|
||||
fprintf (stderr, "recoverblock: recovering feature block %d [%d..%d)\n", (int)blockid, (int)t0, (int)(t0 + elementsperblock -1));
|
||||
blocks[blockid].reset (newblock());
|
||||
msra::dbn::matrix & block = *blocks[blockid];
|
||||
fsetpos (f, blockid * block.sizeinpagefile());
|
||||
|
@ -163,7 +169,7 @@ namespace msra { namespace dbn {
|
|||
// finish off last block
|
||||
flushlastblock();
|
||||
fflushOrDie (f);
|
||||
fprintf (stderr, "biggrowablevectorarray: disk backup store created, %d frames, %ull bytes\n", (int) n, fgetpos (f));
|
||||
fprintf (stderr, "biggrowablevectorarray: disk backup store created, %d frames, %lu bytes\n", (int) n, fgetpos (f));
|
||||
fclose (f);
|
||||
foreach_index (i, blocks) assert (!blocks[i]); // ensure we flushed
|
||||
assert (inmembegin == inmemend); // nothing in cache
|
||||
|
@ -265,7 +271,7 @@ namespace msra { namespace dbn {
|
|||
// - implement block-wise paging directly from HTK feature files through htkfeatreader
|
||||
featkind.clear();
|
||||
std::vector<float> frame;
|
||||
fprintf (stderr, "minibatchframesource: reading %d utterances..", infiles.size());
|
||||
fprintf (stderr, "minibatchframesource: reading %d utterances..", (int)infiles.size());
|
||||
size_t numclasses = 0; // number of units found (actually max id +1)
|
||||
size_t notfound = 0; // number of entries missing in MLF
|
||||
msra::asr::htkfeatreader reader; // feature reader
|
||||
|
@ -281,7 +287,12 @@ namespace msra { namespace dbn {
|
|||
wstring key;
|
||||
if (!labels.empty()) // empty means unsupervised mode (don't load any)
|
||||
{
|
||||
#ifdef _WIN32
|
||||
key = regex_replace ((wstring)ppath, wregex (L"\\.[^\\.\\\\/:]*$"), wstring()); // delete extension (or not if none)
|
||||
#endif
|
||||
#ifdef __unix__
|
||||
key = removeExtension(basename(ppath));
|
||||
#endif
|
||||
if (labels.find (key) == labels.end())
|
||||
{
|
||||
if (notfound < 5)
|
||||
|
@ -309,7 +320,7 @@ namespace msra { namespace dbn {
|
|||
size_t labframes = labseq.empty() ? 0 : (labseq[labseq.size()-1].firstframe + labseq[labseq.size()-1].numframes);
|
||||
if (abs ((int) labframes - (int) feat.cols()) > 0)
|
||||
{
|
||||
fprintf (stderr, "\nminibatchframesource: %d-th file has small duration mismatch (%d in label vs. %d in feat file), skipping: %S", i, labframes, feat.cols(), key.c_str());
|
||||
fprintf (stderr, "\nminibatchframesource: %d-th file has small duration mismatch (%d in label vs. %d in feat file), skipping: %S", i, (int)labframes, (int)feat.cols(), key.c_str());
|
||||
notfound++;
|
||||
continue; // skip this utterance at all
|
||||
}
|
||||
|
@ -346,7 +357,7 @@ namespace msra { namespace dbn {
|
|||
if (e.classid != (CLASSIDTYPE) e.classid)
|
||||
throw std::runtime_error ("CLASSIDTYPE has too few bits");
|
||||
classids.push_back ((CLASSIDTYPE) e.classid);
|
||||
numclasses = max (numclasses, 1u + e.classid);
|
||||
numclasses = max (numclasses, (size_t)(1u + e.classid));
|
||||
}
|
||||
}
|
||||
if (vdim == 0)
|
||||
|
@ -364,10 +375,10 @@ namespace msra { namespace dbn {
|
|||
assert (labels.empty() || numframes == classids.size());
|
||||
if ((vdim != 0 && numframes != frames.size()) || (!labels.empty() && numframes != classids.size()))
|
||||
throw std::runtime_error ("minibatchframesource: numframes variable screwup");
|
||||
fprintf (stderr, " %d frames read from %d utterances; %d classes\n", numframes, infiles.size(), numclasses);
|
||||
fprintf (stderr, " %d frames read from %d utterances; %d classes\n", (int)numframes, (int)infiles.size(), (int)numclasses);
|
||||
if (notfound > 0)
|
||||
{
|
||||
fprintf (stderr, "minibatchframesource: %d files out of %d not found in label set\n", notfound, infiles.size());
|
||||
fprintf (stderr, "minibatchframesource: %d files out of %d not found in label set\n", (int)notfound, (int)infiles.size());
|
||||
if (notfound > infiles.size() / 2)
|
||||
throw std::runtime_error ("minibatchframesource: too many files not found in label set--assuming broken configuration\n");
|
||||
}
|
||||
|
@ -421,7 +432,7 @@ namespace msra { namespace dbn {
|
|||
const size_t te = min (ts + framesrequested, totalframes()); // do not go beyond sweep boundary
|
||||
assert (te > ts);
|
||||
if (verbosity >= 2)
|
||||
fprintf (stderr, "getbatch: frames [%d..%d] in sweep %d\n", ts, te-1, sweep);
|
||||
fprintf (stderr, "getbatch: frames [%d..%d] in sweep %d\n", (int)ts, (int)(te-1), (int)sweep);
|
||||
|
||||
// get random sequence (each time index occurs exactly once)
|
||||
// If the sweep changes, this will re-cache the sequence. We optimize for rare, monotonous sweep changes.
|
||||
|
@ -543,7 +554,7 @@ namespace msra { namespace dbn {
|
|||
}
|
||||
|
||||
|
||||
fprintf (stderr, "minibatchframesourcemulti: reading %d feature sets and %d label sets...", infiles.size(),labels.size());
|
||||
fprintf (stderr, "minibatchframesourcemulti: reading %d feature sets and %d label sets...", (int)infiles.size(), (int)labels.size());
|
||||
|
||||
foreach_index (m, infiles)
|
||||
{
|
||||
|
@ -567,7 +578,12 @@ namespace msra { namespace dbn {
|
|||
{
|
||||
if (!labels[0].empty()) // empty means unsupervised mode (don't load any)
|
||||
{
|
||||
#ifdef _WIN32
|
||||
key = regex_replace ((wstring)ppath, wregex (L"\\.[^\\.\\\\/:]*$"), wstring()); // delete extension (or not if none)
|
||||
#endif
|
||||
#ifdef __unix__
|
||||
key = removeExtension(basename(ppath));
|
||||
#endif
|
||||
if (labels[0].find (key) == labels[0].end())
|
||||
{
|
||||
if (notfound < 5)
|
||||
|
@ -595,7 +611,7 @@ namespace msra { namespace dbn {
|
|||
size_t labframes = labseq.empty() ? 0 : (labseq[labseq.size()-1].firstframe + labseq[labseq.size()-1].numframes);
|
||||
if (abs ((int) labframes - (int) feat.cols()) > 0)
|
||||
{
|
||||
fprintf (stderr, "\nminibatchframesourcemulti: %d-th file has small duration mismatch (%d in label vs. %d in feat file), skipping: %S", i, labframes, feat.cols(), key.c_str());
|
||||
fprintf (stderr, "\nminibatchframesourcemulti: %d-th file has small duration mismatch (%d in label vs. %d in feat file), skipping: %S", i, (int)labframes, (int)feat.cols(), key.c_str());
|
||||
notfound++;
|
||||
continue; // skip this utterance at all
|
||||
}
|
||||
|
@ -645,7 +661,7 @@ namespace msra { namespace dbn {
|
|||
if (e.classid != (CLASSIDTYPE) e.classid)
|
||||
throw std::runtime_error ("CLASSIDTYPE has too few bits");
|
||||
classids[j].push_back ((CLASSIDTYPE) e.classid);
|
||||
numclasses[j] = max (numclasses[j], 1u + e.classid);
|
||||
numclasses[j] = max (numclasses[j], (size_t)(1u + e.classid));
|
||||
}
|
||||
}
|
||||
if (vdim[m] == 0)
|
||||
|
@ -676,12 +692,12 @@ namespace msra { namespace dbn {
|
|||
if (m==0)
|
||||
{
|
||||
foreach_index (j, numclasses)
|
||||
fprintf (stderr, "\nminibatchframesourcemulti: read label set %d: %d classes\n", j, numclasses[j]);
|
||||
fprintf (stderr, "\nminibatchframesourcemulti: read label set %d: %d classes\n", j, (int)numclasses[j]);
|
||||
}
|
||||
fprintf (stderr, "\nminibatchframesourcemulti: feature set %d: %d frames read from %d utterances\n", m, pframes[m]->size(), infiles[m].size());
|
||||
fprintf (stderr, "\nminibatchframesourcemulti: feature set %d: %d frames read from %d utterances\n", m, (int)pframes[m]->size(), (int)infiles[m].size());
|
||||
if (notfound > 0)
|
||||
{
|
||||
fprintf (stderr, "minibatchframesourcemulti: %d files out of %d not found in label set\n", notfound, infiles[m].size());
|
||||
fprintf (stderr, "minibatchframesourcemulti: %d files out of %d not found in label set\n", (int)notfound, (int)infiles[m].size());
|
||||
if (notfound > infiles[m].size() / 2)
|
||||
throw std::runtime_error ("minibatchframesourcemulti: too many files not found in label set--assuming broken configuration\n");
|
||||
}
|
||||
|
@ -741,7 +757,7 @@ namespace msra { namespace dbn {
|
|||
const size_t te = min (ts + framesrequested, totalframes()); // do not go beyond sweep boundary
|
||||
assert (te > ts);
|
||||
if (verbosity >= 2)
|
||||
fprintf (stderr, "getbatch: frames [%d..%d] in sweep %d\n", ts, te-1, sweep);
|
||||
fprintf (stderr, "getbatch: frames [%d..%d] in sweep %d\n", (int)ts, (int)(te-1), (int)sweep);
|
||||
|
||||
// get random sequence (each time index occurs exactly once)
|
||||
// If the sweep changes, this will re-cache the sequence. We optimize for rare, monotonous sweep changes.
|
||||
|
|
|
@ -64,7 +64,7 @@ public: // (TODO: better encapsulation)
|
|||
transP() : numstates (0) {}
|
||||
};
|
||||
std::vector<transP> transPs; // the transition matrices --TODO: finish this
|
||||
std::hash_map<std::string,size_t> transPmap; // [transPname] -> index into transPs[]
|
||||
std::unordered_map<std::string, size_t> transPmap; // [transPname] -> index into transPs[]
|
||||
public:
|
||||
// get an hmm by index
|
||||
const hmm & gethmm (size_t i) const { return hmms[i]; }
|
||||
|
@ -216,7 +216,7 @@ public:
|
|||
}
|
||||
}
|
||||
fprintf (stderr, "simplesenonehmm: %d units with %d unique HMMs, %d tied states, and %d trans matrices read\n",
|
||||
symmap.size(), hmms.size(), statemap.size(), transPs.size());
|
||||
(int)symmap.size(), (int)hmms.size(), (int)statemap.size(), (int)transPs.size());
|
||||
}
|
||||
|
||||
// exposed so we can pass it to the lattice reader, which maps the symbol ids for us
|
||||
|
|
|
@ -9,7 +9,9 @@
|
|||
#pragma once
|
||||
|
||||
#include "basetypes.h"
|
||||
#ifdef _WIN32
|
||||
#include <process.h> // for _beginthread()
|
||||
#endif
|
||||
|
||||
namespace msra { namespace util {
|
||||
|
||||
|
|
|
@ -8,7 +8,12 @@
|
|||
|
||||
#pragma once
|
||||
|
||||
#ifdef _WIN32
|
||||
#include <intrin.h> // for intrinsics
|
||||
#endif
|
||||
#ifdef __unix__
|
||||
#include <x86intrin.h>
|
||||
#endif
|
||||
|
||||
namespace msra { namespace math {
|
||||
|
||||
|
|
|
@ -13,11 +13,14 @@
|
|||
#include "simple_checked_arrays.h" // ... for dotprod(); we can eliminate this I believe
|
||||
#include "ssefloat4.h"
|
||||
#include <stdexcept>
|
||||
#include "numahelpers.h"
|
||||
#ifndef __unix__
|
||||
#include <ppl.h>
|
||||
#include "pplhelpers.h"
|
||||
#include "numahelpers.h"
|
||||
#endif
|
||||
#include "fileutil.h" // for saving and reading matrices
|
||||
#include <limits> // for NaN
|
||||
#include <malloc.h>
|
||||
|
||||
namespace msra { namespace math {
|
||||
|
||||
|
@ -275,7 +278,7 @@ public:
|
|||
bool addtoresult, const float thisscale, const float weight)
|
||||
{
|
||||
assert (a.size() == b.size());
|
||||
assert ((15 & (int) &a[0]) == 0); assert ((15 & (int) &b[0]) == 0); // enforce SSE alignment
|
||||
assert ((15 & reinterpret_cast<uintptr_t>(&a[0])) == 0); assert ((15 & reinterpret_cast<uintptr_t>(&b[0])) == 0); // enforce SSE alignment
|
||||
|
||||
size_t nlong = (a.size() + 3) / 4; // number of SSE elements
|
||||
const msra::math::float4 * pa = (const msra::math::float4 *) &a[0];
|
||||
|
@ -310,9 +313,9 @@ public:
|
|||
// for (size_t k = 0; k < 4; k++)
|
||||
// dotprod (row, const_array_ref<float> (&cols4[k * cols4stride], cols4stride), usij[k * usijstride]);
|
||||
|
||||
assert ((15 & (int) &row[0]) == 0);
|
||||
assert ((15 & (int) &cols4[0]) == 0);
|
||||
assert ((15 & (int) &cols4[cols4stride]) == 0);
|
||||
assert ((15 & reinterpret_cast<uintptr_t>(&row[0])) == 0);
|
||||
assert ((15 & reinterpret_cast<uintptr_t>(&cols4[0])) == 0);
|
||||
assert ((15 & reinterpret_cast<uintptr_t>(&cols4[cols4stride])) == 0);
|
||||
//assert (cols4stride * 4 == cols4.size()); // (passed in one vector with 4 columns stacked on top of each other)
|
||||
//assert (row.size() * 4 == cols4.size()); // this assert is no longer appropriate because of further breaking into blocks
|
||||
|
||||
|
@ -389,6 +392,7 @@ public:
|
|||
matprod_mtm (Mt, 0, Mt.cols(), V);
|
||||
}
|
||||
|
||||
#ifdef _WIN32
|
||||
void parallel_matprod_mtm (const ssematrixbase & Mt, const ssematrixbase & V)
|
||||
{
|
||||
msra::parallel::foreach_index_block (Mt.cols(), Mt.cols(), 1, [&] (size_t i0, size_t i1)
|
||||
|
@ -396,6 +400,7 @@ public:
|
|||
matprod_mtm (Mt, i0, i1, V);
|
||||
});
|
||||
}
|
||||
#endif
|
||||
|
||||
// swap data of i-th column and j-th column
|
||||
void swapcolumn (size_t i, size_t j)
|
||||
|
@ -801,6 +806,7 @@ public:
|
|||
scaleandaddmatprod_mtm (thisscale, Mt, 0, Mt.cols(), V);
|
||||
}
|
||||
|
||||
#ifdef _WIN32
|
||||
void parallel_scaleandaddmatprod_mtm (const float thisscale, const ssematrixbase & Mt, const ssematrixbase & V)
|
||||
{
|
||||
#if 0
|
||||
|
@ -813,6 +819,7 @@ public:
|
|||
});
|
||||
#endif
|
||||
}
|
||||
#endif
|
||||
|
||||
// same as matprod_mtm except result is added to result matrix instead of replacing it
|
||||
// For all comments, see matprod_mtm.
|
||||
|
@ -912,6 +919,7 @@ public:
|
|||
// to = this'
|
||||
void transpose (ssematrixbase & to) const { transposecolumns (to, 0, cols()); }
|
||||
|
||||
#ifdef _WIN32
|
||||
void parallel_transpose (ssematrixbase & to) const
|
||||
{
|
||||
msra::parallel::foreach_index_block (cols(), cols(), 4/*align*/, [&] (size_t j0, size_t j1)
|
||||
|
@ -925,6 +933,7 @@ public:
|
|||
throw std::logic_error ("parallel_transpose: post-condition check failed--you got it wrong, man!");
|
||||
#endif
|
||||
}
|
||||
#endif
|
||||
|
||||
// transpose columns [j0,j1) to rows [j0,j1) of 'to'
|
||||
void transposecolumns (ssematrixbase & to, size_t j0, size_t j1) const
|
||||
|
@ -1149,7 +1158,7 @@ public:
|
|||
foreach_coord (i, j, us)
|
||||
if (std::isnan (us(i,j)))
|
||||
{
|
||||
fprintf (stderr, "hasnan: NaN detected at %s (%d,%d)\n", name, i, j);
|
||||
fprintf (stderr, "hasnan: NaN detected at %s (%d,%d)\n", name, (int)i, (int)j);
|
||||
return true;
|
||||
}
|
||||
#endif
|
||||
|
@ -1200,7 +1209,7 @@ class ssematrixfrombuffer : public ssematrixbase
|
|||
{
|
||||
void operator= (const ssematrixfrombuffer &); ssematrixfrombuffer (const ssematrixfrombuffer &); // base cannot be assigned except by move
|
||||
public:
|
||||
ssematrixfrombuffer() { clear(); }
|
||||
ssematrixfrombuffer() { this->clear(); }
|
||||
|
||||
// instantiate from a float vector --buffer must be SSE-aligned
|
||||
template<class VECTOR> ssematrixfrombuffer (VECTOR & buffer, size_t n, size_t m) : ssematrixbase (buffer, n, m) {}
|
||||
|
@ -1233,10 +1242,10 @@ public:
|
|||
assert (other.empty() || j0 + m <= other.cols());
|
||||
if (!other.empty() && j0 + m > other.cols()) // (runtime check to be sure--we use this all the time)
|
||||
throw std::logic_error ("ssematrixstriperef: stripe outside original matrix' dimension");
|
||||
p = other.empty() ? NULL : &other(0,j0);
|
||||
numrows = other.rows();
|
||||
numcols = m;
|
||||
colstride = other.getcolstride();
|
||||
this->p = other.empty() ? NULL : &other(0,j0);
|
||||
this->numrows = other.rows();
|
||||
this->numcols = m;
|
||||
this->colstride = other.getcolstride();
|
||||
}
|
||||
|
||||
// only assignment is by rvalue reference
|
||||
|
@ -1255,14 +1264,20 @@ public:
|
|||
template<class ssematrixbase> class ssematrix : public ssematrixbase
|
||||
{
|
||||
// helpers for SSE-compatible memory allocation
|
||||
static __declspec(noreturn) void failed (size_t nbytes) { static/*not thread-safe--for diagnostics only*/ char buf[80] = { 0 }; sprintf_s (buf, "allocation of SSE vector failed (%d bytes)", nbytes); throw std::bad_exception (buf); }
|
||||
#if 1 // TODO: move to separate header file numahelpers.h
|
||||
template<typename T> static T * new_sse (size_t nbytes) { T * pv = (T *) msra::numa::malloc (nbytes * sizeof (T), 16); if (pv) return pv; failed (nbytes * sizeof (T)); }
|
||||
static void delete_sse (void * p) { if (p) msra::numa::free (p); }
|
||||
#else
|
||||
#ifdef _MSC_VER
|
||||
static __declspec_noreturn void failed(size_t nbytes) { static/*not thread-safe--for diagnostics only*/ char buf[80] = { 0 }; sprintf_s(buf, "allocation of SSE vector failed (%d bytes)", nbytes); throw std::bad_exception(buf); }
|
||||
#endif
|
||||
#ifdef __unix__
|
||||
static void failed (size_t nbytes) { static/*not thread-safe--for diagnostics only*/ char buf[80] = { 0 }; sprintf_s (buf, sizeof(buf), "allocation of SSE vector failed (%d bytes)", (int)nbytes); throw std::bad_exception (); }
|
||||
#endif
|
||||
#ifdef _WIN32
|
||||
template<typename T> static T * new_sse (size_t nbytes) { T * pv = (T *) _aligned_malloc (nbytes * sizeof (T), 16); if (pv) return pv; failed (nbytes * sizeof (T)); }
|
||||
static void delete_sse (void * p) { if (p) _aligned_free (p); }
|
||||
#endif
|
||||
#ifdef __unix__
|
||||
template<typename T> static T * new_sse (size_t nbytes) { T * pv = (T *) _mm_malloc (nbytes * sizeof (T),16); if (pv) return pv; failed (nbytes * sizeof (T)); }
|
||||
static void delete_sse (void * p) { if (p) _mm_free (p); }
|
||||
#endif
|
||||
|
||||
// helper to assign a copy from another matrix
|
||||
void assign (const ssematrixbase & other)
|
||||
|
@ -1272,18 +1287,18 @@ template<class ssematrixbase> class ssematrix : public ssematrixbase
|
|||
};
|
||||
public:
|
||||
// construction
|
||||
ssematrix() { clear(); }
|
||||
ssematrix (size_t n, size_t m) { clear(); resize (n, m); }
|
||||
ssematrix (size_t n) { clear(); resize (n, 1); } // vector
|
||||
ssematrix (const ssematrix & other) { clear(); assign (other); }
|
||||
ssematrix (const ssematrixbase & other) { clear(); assign (other); }
|
||||
ssematrix (ssematrix && other) { move (other); }
|
||||
ssematrix (const std::vector<float> & other) { clear(); resize (other.size(), 1); foreach_index (k, other) (*this)[k] = other[k]; }
|
||||
ssematrix() { this->clear(); }
|
||||
ssematrix (size_t n, size_t m) { this->clear(); resize (n, m); }
|
||||
ssematrix (size_t n) { this->clear(); resize (n, 1); } // vector
|
||||
ssematrix (const ssematrix & other) { this->clear(); assign (other); }
|
||||
ssematrix (const ssematrixbase & other) { this->clear(); assign (other); }
|
||||
ssematrix (ssematrix && other) { this->move (other); }
|
||||
ssematrix (const std::vector<float> & other) { this->clear(); resize (other.size(), 1); foreach_index (k, other) (*this)[k] = other[k]; }
|
||||
|
||||
// construct elementwise with a function f(i,j)
|
||||
template<typename FUNCTION> ssematrix (size_t n, size_t m, const FUNCTION & f)
|
||||
{
|
||||
clear();
|
||||
this->clear();
|
||||
resize (n, m);
|
||||
auto & us = *this;
|
||||
foreach_coord (i, j, us)
|
||||
|
@ -1291,12 +1306,12 @@ public:
|
|||
}
|
||||
|
||||
// destructor
|
||||
~ssematrix() { delete_sse (p); }
|
||||
~ssematrix() { delete_sse (this->p); }
|
||||
|
||||
// assignment
|
||||
ssematrix & operator= (const ssematrix & other) { assign (other); return *this; }
|
||||
ssematrix & operator= (const ssematrixbase & other) { assign (other); return *this; }
|
||||
ssematrix & operator= (ssematrix && other) { delete_sse(p); move (other); return *this; }
|
||||
ssematrix & operator= (ssematrix && other) { delete_sse(this->p); move (other); return *this; }
|
||||
|
||||
void swap (ssematrix & other) throw() { ssematrixbase::swap (other); }
|
||||
|
||||
|
@ -1304,23 +1319,23 @@ public:
|
|||
// One or both dimensions can be 0, for special purposes.
|
||||
void resize (size_t n, size_t m)
|
||||
{
|
||||
if (n == numrows && m == numcols)
|
||||
if (n == this->numrows && m == this->numcols)
|
||||
return; // no resize needed
|
||||
const size_t newcolstride = (n + 3) & ~3; // pad to multiples of four floats (required SSE alignment)
|
||||
const size_t totalelem = newcolstride * m;
|
||||
//fprintf (stderr, "resize (%d, %d) allocating %d elements\n", n, m, totalelem);
|
||||
float * pnew = totalelem > 0 ? new_sse<float> (totalelem) : NULL;
|
||||
::swap (p, pnew);
|
||||
::swap (this->p, pnew);
|
||||
delete_sse (pnew); // pnew is now the old p
|
||||
numrows = n; numcols = m;
|
||||
colstride = newcolstride;
|
||||
this->numrows = n; this->numcols = m;
|
||||
this->colstride = newcolstride;
|
||||
// touch the memory to ensure the page is created
|
||||
for (size_t offset = 0; offset < totalelem; offset += 4096 / sizeof (float))
|
||||
p[offset] = 0.0f; //nan;
|
||||
this->p[offset] = 0.0f; //nan;
|
||||
// clear padding elements (numrows <= i < colstride) to 0.0 for SSE optimization
|
||||
for (size_t j = 0; j < numcols; j++)
|
||||
for (size_t i = numrows; i < colstride; i++)
|
||||
p[j * colstride + i] = 0.0f;
|
||||
for (size_t j = 0; j < this->numcols; j++)
|
||||
for (size_t i = this->numrows; i < this->colstride; i++)
|
||||
this->p[j * this->colstride + i] = 0.0f;
|
||||
#if 1 // for debugging: set all elements to 0
|
||||
// We keep this code alive because allocations are supposed to be done at the start only.
|
||||
auto & us = *this;
|
||||
|
@ -1335,8 +1350,8 @@ public:
|
|||
void resizeonce (size_t n, size_t m)
|
||||
{
|
||||
#if 1 // BUGBUG: at end of epoch, resizes are OK... so we log but allow them
|
||||
if (!empty() && (n != numrows || m != numcols))
|
||||
fprintf (stderr, "resizeonce: undesired resize from %d x %d to %d x %d\n", numrows, numcols, n, m);
|
||||
if (!this->empty() && (n != this->numrows || m != this->numcols))
|
||||
fprintf (stderr, "resizeonce: undesired resize from %d x %d to %d x %d\n", this->numrows, this->numcols, n, m);
|
||||
resize (n, m);
|
||||
#else
|
||||
if (empty())
|
||||
|
@ -1349,10 +1364,10 @@ public:
|
|||
// non-destructive resize() to a smaller size
|
||||
void shrink(size_t newrows, size_t newcols)
|
||||
{
|
||||
if (newrows > numrows || newcols > numcols)
|
||||
if (newrows > this->numrows || newcols > this->numcols)
|
||||
throw std::logic_error ("shrink: attempted to grow the matrix");
|
||||
numrows = newrows;
|
||||
numcols = newcols;
|
||||
this->numrows = newrows;
|
||||
this->numcols = newcols;
|
||||
}
|
||||
|
||||
// file I/O
|
||||
|
@ -1360,8 +1375,8 @@ public:
|
|||
{
|
||||
fputTag (f, "BMAT");
|
||||
fputstring (f, name);
|
||||
fputint (f, (int) numrows);
|
||||
fputint (f, (int) numcols);
|
||||
fputint (f, (int) this->numrows);
|
||||
fputint (f, (int) this->numcols);
|
||||
const auto & us = *this;
|
||||
foreach_column (j, us)
|
||||
{
|
||||
|
@ -1375,8 +1390,8 @@ public:
|
|||
{
|
||||
fputTag(f, "BMAT");
|
||||
fputstring (f, name);
|
||||
fputint (f, (int) numrows);
|
||||
fputint (f, (int) numcols);
|
||||
fputint (f, (int) this->numrows);
|
||||
fputint (f, (int) this->numcols);
|
||||
const auto & us = *this;
|
||||
foreach_column (j, us)
|
||||
{
|
||||
|
@ -1426,9 +1441,9 @@ public:
|
|||
}
|
||||
|
||||
// paging support (used in feature source)
|
||||
void topagefile (FILE * f) const { if (!empty()) fwriteOrDie (p, sizeinpagefile(), 1, f); }
|
||||
void frompagefile (FILE * f) { if (!empty()) freadOrDie (p, sizeinpagefile(), 1, f); }
|
||||
size_t sizeinpagefile() const { return colstride * numcols * sizeof (*p); }
|
||||
void topagefile (FILE * f) const { if (!this->empty()) fwriteOrDie (this->p, sizeinpagefile(), 1, f); }
|
||||
void frompagefile (FILE * f) { if (!this->empty()) freadOrDie (this->p, sizeinpagefile(), 1, f); }
|
||||
size_t sizeinpagefile() const { return this->colstride * this->numcols * sizeof (*(this->p)); }
|
||||
|
||||
// getting a one-column sub-view on this
|
||||
ssematrixstriperef<ssematrixbase> col (size_t j)
|
||||
|
@ -1541,7 +1556,7 @@ template<class M> pair<unsigned int,unsigned int> printmatvaluedistributionf (co
|
|||
const size_t numparts = 100;
|
||||
for (size_t i=1; i<=numparts; i++)
|
||||
{
|
||||
fprintf (stderr, "%.5f%% absolute values are under %.10f\n", i*100.0/numparts, vals[min(num-1,i*num/numparts)]);
|
||||
fprintf (stderr, "%.5f%% absolute values are under %.10f\n", i*100.0/numparts, vals[min((size_t)num-1,i*num/numparts)]);
|
||||
}
|
||||
fprintf (stderr, "\n%.5f%% values are zero\n\n", 100.0*numzeros/num);
|
||||
#endif
|
||||
|
|
|
@ -12,9 +12,9 @@
|
|||
|
||||
#include "Platform.h"
|
||||
#define _CRT_SECURE_NO_WARNINGS // "secure" CRT not available on all platforms
|
||||
#include "targetver.h"
|
||||
|
||||
#ifndef __unix__
|
||||
#include "targetver.h"
|
||||
#define WIN32_LEAN_AND_MEAN // Exclude rarely-used stuff from Windows headers
|
||||
// Windows Header Files:
|
||||
#include <windows.h>
|
||||
|
|
|
@ -113,7 +113,7 @@ class minibatchutterancesource : public minibatchsource
|
|||
if (featdim == 0)
|
||||
{
|
||||
reader.getinfo (utteranceset[0].parsedpath, featkind, featdim, sampperiod);
|
||||
fprintf (stderr, "requiredata: determined feature kind as %d-dimensional '%s' with frame shift %.1f ms\n", featdim, featkind.c_str(), sampperiod / 1e4);
|
||||
fprintf (stderr, "requiredata: determined feature kind as %d-dimensional '%s' with frame shift %.1f ms\n", (int)featdim, featkind.c_str(), sampperiod / 1e4);
|
||||
}
|
||||
// read all utterances; if they are in the same archive, htkfeatreader will be efficient in not closing the file
|
||||
frames.resize (featdim, totalframes);
|
||||
|
@ -130,7 +130,7 @@ class minibatchutterancesource : public minibatchsource
|
|||
latticesource.getlattices (utteranceset[i].key(), lattices[i], uttframes.cols());
|
||||
}
|
||||
//fprintf (stderr, "\n");
|
||||
fprintf (stderr, "requiredata: %d utterances read\n", utteranceset.size());
|
||||
fprintf (stderr, "requiredata: %d utterances read\n", (int)utteranceset.size());
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
|
@ -199,17 +199,17 @@ class minibatchutterancesource : public minibatchsource
|
|||
}
|
||||
};
|
||||
std::vector<utteranceref> randomizedutterancerefs; // [pos] randomized utterance ids
|
||||
std::hash_map<size_t,size_t> randomizedutteranceposmap; // [globalts] -> pos lookup table
|
||||
std::unordered_map<size_t, size_t> randomizedutteranceposmap; // [globalts] -> pos lookup table
|
||||
struct positionchunkwindow // chunk window required in memory when at a certain position, for controlling paging
|
||||
{
|
||||
std::vector<chunk>::const_iterator definingchunk; // the chunk in randomizedchunks[] that defined the utterance position of this utterance
|
||||
std::vector<chunk>::iterator definingchunk; // the chunk in randomizedchunks[] that defined the utterance position of this utterance
|
||||
size_t windowbegin() const { return definingchunk->windowbegin; }
|
||||
size_t windowend() const { return definingchunk->windowend; }
|
||||
bool isvalidforthisposition (const utteranceref & utt) const
|
||||
{
|
||||
return utt.chunkindex >= windowbegin() && utt.chunkindex < windowend(); // check if 'utt' lives in is in allowed range for this position
|
||||
}
|
||||
positionchunkwindow (std::vector<chunk>::const_iterator definingchunk) : definingchunk (definingchunk) {}
|
||||
positionchunkwindow (std::vector<chunk>::iterator definingchunk) : definingchunk (definingchunk) {}
|
||||
};
|
||||
std::vector<positionchunkwindow> positionchunkwindows; // [utterance position] -> [windowbegin, windowend) for controlling paging
|
||||
|
||||
|
@ -297,7 +297,7 @@ public:
|
|||
throw std::runtime_error ("minibatchutterancesource: utterances < 2 frames not supported");
|
||||
if (uttframes > frameref::maxframesperutterance)
|
||||
{
|
||||
fprintf (stderr, "minibatchutterancesource: skipping %d-th file (%d frames) because it exceeds max. frames (%d) for frameref bit field: %S", i, uttframes, frameref::maxframesperutterance, key.c_str());
|
||||
fprintf (stderr, "minibatchutterancesource: skipping %d-th file (%d frames) because it exceeds max. frames (%d) for frameref bit field: %S", i, (int)uttframes, (int)frameref::maxframesperutterance, key.c_str());
|
||||
continue;
|
||||
}
|
||||
|
||||
|
@ -331,7 +331,7 @@ public:
|
|||
size_t labframes = labseq.empty() ? 0 : (labseq[labseq.size()-1].firstframe + labseq[labseq.size()-1].numframes);
|
||||
if (labframes != uttframes)
|
||||
{
|
||||
fprintf (stderr, " [duration mismatch (%d in label vs. %d in feat file), skipping %S]", labframes, uttframes, key.c_str());
|
||||
fprintf (stderr, " [duration mismatch (%d in label vs. %d in feat file), skipping %S]", (int)labframes, (int)uttframes, key.c_str());
|
||||
nomlf++;
|
||||
continue; // skip this utterance at all
|
||||
}
|
||||
|
@ -347,7 +347,7 @@ public:
|
|||
throw std::runtime_error ("CLASSIDTYPE has too few bits");
|
||||
for (size_t t = e.firstframe; t < e.firstframe + e.numframes; t++)
|
||||
classids.push_back ((CLASSIDTYPE) e.classid);
|
||||
numclasses = max (numclasses, 1u + e.classid);
|
||||
numclasses = max (numclasses, (size_t)(1u + e.classid));
|
||||
counts.resize (numclasses, 0);
|
||||
counts[e.classid] += e.numframes;
|
||||
}
|
||||
|
@ -360,7 +360,7 @@ public:
|
|||
throw std::logic_error (msra::strfun::strprintf ("minibatchutterancesource: label duration inconsistent with feature file in MLF label set: %S", key.c_str()));
|
||||
assert (labels.empty() || classids.size() == _totalframes + utteranceset.size());
|
||||
}
|
||||
fprintf (stderr, " %d frames in %d out of %d utterances; %d classes\n", _totalframes, utteranceset.size(),infiles.size(), numclasses);
|
||||
fprintf (stderr, " %d frames in %d out of %d utterances; %d classes\n", (int)_totalframes, (int)utteranceset.size(), (int)infiles.size(), (int)numclasses);
|
||||
if (!labels.empty())
|
||||
foreach_index (i, utteranceset)
|
||||
{
|
||||
|
@ -369,7 +369,7 @@ public:
|
|||
}
|
||||
if (nomlf + nolat > 0)
|
||||
{
|
||||
fprintf (stderr, "minibatchutterancesource: out of %d files, %d files not found in label set and %d have no lattice\n", infiles.size(), nomlf, nolat);
|
||||
fprintf (stderr, "minibatchutterancesource: out of %d files, %d files not found in label set and %d have no lattice\n", (int)infiles.size(), (int)nomlf, (int)nolat);
|
||||
if (nomlf + nolat > infiles.size() / 2)
|
||||
throw std::runtime_error ("minibatchutterancesource: too many files not found in label set--assuming broken configuration\n");
|
||||
}
|
||||
|
@ -398,7 +398,7 @@ public:
|
|||
}
|
||||
numutterances = utteranceset.size();
|
||||
fprintf (stderr, "minibatchutterancesource: %d utterances grouped into %d chunks, av. chunk size: %.1f utterances, %.1f frames\n",
|
||||
numutterances, allchunks.size(), numutterances / (double) allchunks.size(), _totalframes / (double) allchunks.size());
|
||||
(int)numutterances, (int)allchunks.size(), numutterances / (double) allchunks.size(), _totalframes / (double) allchunks.size());
|
||||
// Now utterances are stored exclusively in allchunks[]. They are never referred to by a sequential utterance id at this point, only by chunk/within-chunk index.
|
||||
|
||||
// preliminary mem allocation for frame references (if in frame mode)
|
||||
|
@ -462,7 +462,7 @@ private:
|
|||
return sweep;
|
||||
|
||||
currentsweep = sweep;
|
||||
fprintf (stderr, "lazyrandomization: re-randomizing for sweep %d in %s mode\n", currentsweep, framemode ? "frame" : "utterance");
|
||||
fprintf (stderr, "lazyrandomization: re-randomizing for sweep %d in %s mode\n", (int)currentsweep, framemode ? "frame" : "utterance");
|
||||
|
||||
const size_t sweepts = sweep * _totalframes; // first global frame index for this sweep
|
||||
|
||||
|
@ -751,7 +751,7 @@ private:
|
|||
|
||||
if (verbosity)
|
||||
fprintf (stderr, "releaserandomizedchunk: paging out randomized chunk %d (frame range [%d..%d]), %d resident in RAM\n",
|
||||
k, randomizedchunks[k].globalts, randomizedchunks[k].globalte()-1, chunksinram-1);
|
||||
(int)k, (int)randomizedchunks[k].globalts, (int)(randomizedchunks[k].globalte()-1), (int)(chunksinram-1));
|
||||
chunkdata.releasedata();
|
||||
chunksinram--;
|
||||
}
|
||||
|
@ -770,7 +770,7 @@ private:
|
|||
return false;
|
||||
|
||||
if (verbosity)
|
||||
fprintf (stderr, "requirerandomizedchunk: paging in randomized chunk %d (frame range [%d..%d]), %d resident in RAM\n", chunkindex, chunk.globalts, chunk.globalte()-1, chunksinram+1);
|
||||
fprintf (stderr, "requirerandomizedchunk: paging in randomized chunk %d (frame range [%d..%d]), %d resident in RAM\n", (int)chunkindex, (int)chunk.globalts, (int)(chunk.globalte()-1), (int)(chunksinram+1));
|
||||
msra::util::attempt (5, [&]() // (reading from network)
|
||||
{
|
||||
chunkdata.requiredata (featkind, featdim, sampperiod, this->lattices);
|
||||
|
@ -861,7 +861,7 @@ public:
|
|||
|
||||
// return these utterances
|
||||
if (verbosity > 0)
|
||||
fprintf (stderr, "getbatch: getting utterances %d..%d (%d frames out of %d requested) in sweep %d\n", spos, epos -1, mbframes, framesrequested, sweep);
|
||||
fprintf (stderr, "getbatch: getting utterances %d..%d (%d frames out of %d requested) in sweep %d\n", (int)spos, (int)(epos -1), (int)mbframes, (int)framesrequested, (int)sweep);
|
||||
size_t tspos = 0; // relative start of utterance 'pos' within the returned minibatch
|
||||
for (size_t pos = spos; pos < epos; pos++)
|
||||
{
|
||||
|
@ -927,7 +927,7 @@ public:
|
|||
const size_t windowend = randomizedchunks[lastchunk].windowend;
|
||||
if (verbosity > 0)
|
||||
fprintf (stderr, "getbatch: getting randomized frames [%d..%d] (%d frames out of %d requested) in sweep %d; chunks [%d..%d] -> chunk window [%d..%d)\n",
|
||||
globalts, globalte, mbframes, framesrequested, sweep, firstchunk, lastchunk, windowbegin, windowend);
|
||||
(int)globalts, (int)globalte, (int)mbframes, (int)framesrequested, (int)sweep, (int)firstchunk, (int)lastchunk, (int)windowbegin, (int)windowend);
|
||||
// release all data outside, and page in all data inside
|
||||
for (size_t k = 0; k < windowbegin; k++)
|
||||
releaserandomizedchunk (k);
|
||||
|
|
|
@ -54,9 +54,14 @@ class minibatchutterancesourcemulti : public minibatchsource
|
|||
size_t numframes() const { return parsedpath.numframes(); }
|
||||
const wstring key() const // key used for looking up lattice (not stored to save space)
|
||||
{
|
||||
#ifdef _WIN32
|
||||
static const wstring emptywstring;
|
||||
static const wregex deleteextensionre (L"\\.[^\\.\\\\/:]*$");
|
||||
return regex_replace (logicalpath(), deleteextensionre, emptywstring); // delete extension (or not if none)
|
||||
#endif
|
||||
#ifdef __unix__
|
||||
return removeExtension(basename(logicalpath()));
|
||||
#endif
|
||||
}
|
||||
};
|
||||
struct utterancechunkdata // data for a chunk of utterances
|
||||
|
@ -116,7 +121,7 @@ class minibatchutterancesourcemulti : public minibatchsource
|
|||
if (featdim == 0)
|
||||
{
|
||||
reader.getinfo (utteranceset[0].parsedpath, featkind, featdim, sampperiod);
|
||||
fprintf (stderr, "requiredata: determined feature kind as %d-dimensional '%s' with frame shift %.1f ms\n", featdim, featkind.c_str(), sampperiod / 1e4);
|
||||
fprintf (stderr, "requiredata: determined feature kind as %d-dimensional '%s' with frame shift %.1f ms\n", (int)featdim, featkind.c_str(), sampperiod / 1e4);
|
||||
}
|
||||
// read all utterances; if they are in the same archive, htkfeatreader will be efficient in not closing the file
|
||||
frames.resize (featdim, totalframes);
|
||||
|
@ -134,7 +139,7 @@ class minibatchutterancesourcemulti : public minibatchsource
|
|||
}
|
||||
//fprintf (stderr, "\n");
|
||||
if (verbosity)
|
||||
fprintf (stderr, "requiredata: %d utterances read\n", utteranceset.size());
|
||||
fprintf (stderr, "requiredata: %d utterances read\n", (int)utteranceset.size());
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
|
@ -203,41 +208,30 @@ class minibatchutterancesourcemulti : public minibatchsource
|
|||
}
|
||||
};
|
||||
std::vector<utteranceref> randomizedutterancerefs; // [pos] randomized utterance ids
|
||||
std::hash_map<size_t,size_t> randomizedutteranceposmap; // [globalts] -> pos lookup table
|
||||
std::unordered_map<size_t, size_t> randomizedutteranceposmap; // [globalts] -> pos lookup table
|
||||
struct positionchunkwindow // chunk window required in memory when at a certain position, for controlling paging
|
||||
{
|
||||
std::vector<chunk>::const_iterator definingchunk; // the chunk in randomizedchunks[] that defined the utterance position of this utterance
|
||||
std::vector<chunk>::iterator definingchunk; // the chunk in randomizedchunks[] that defined the utterance position of this utterance
|
||||
size_t windowbegin() const { return definingchunk->windowbegin; }
|
||||
size_t windowend() const { return definingchunk->windowend; }
|
||||
bool isvalidforthisposition (const utteranceref & utt) const
|
||||
{
|
||||
return utt.chunkindex >= windowbegin() && utt.chunkindex < windowend(); // check if 'utt' lives in is in allowed range for this position
|
||||
}
|
||||
positionchunkwindow (std::vector<chunk>::const_iterator definingchunk) : definingchunk (definingchunk) {}
|
||||
positionchunkwindow (std::vector<chunk>::iterator definingchunk) : definingchunk (definingchunk) {}
|
||||
};
|
||||
std::vector<positionchunkwindow> positionchunkwindows; // [utterance position] -> [windowbegin, windowend) for controlling paging
|
||||
|
||||
// frame-level randomization layered on top of utterance chunking (randomized, where randomization is cached)
|
||||
struct frameref
|
||||
{
|
||||
#ifdef _WIN64 // (sadly, the compiler makes this 8 bytes, not 6)
|
||||
unsigned short chunkindex; // lives in this chunk (index into randomizedchunks[])
|
||||
unsigned short utteranceindex; // utterance index in that chunk
|
||||
static const size_t maxutterancesperchunk = 65535;
|
||||
unsigned short frameindex; // frame index within the utterance
|
||||
static const size_t maxframesperutterance = 65535;
|
||||
#else // For Win32, we care to keep it inside 32 bits. We have already encountered setups where that's not enough.
|
||||
unsigned int chunkindex : 13; // lives in this chunk (index into randomizedchunks[])
|
||||
unsigned int utteranceindex : 8; // utterance index in that chunk
|
||||
static const size_t maxutterancesperchunk = 255;
|
||||
unsigned int frameindex : 11; // frame index within the utterance
|
||||
static const size_t maxframesperutterance = 2047;
|
||||
#endif
|
||||
frameref (size_t ci, size_t ui, size_t fi) : chunkindex ((unsigned short) ci), utteranceindex ((unsigned short) ui), frameindex ((unsigned short) fi)
|
||||
{
|
||||
#ifndef _WIN64
|
||||
static_assert (sizeof (frameref) == 4, "frameref: bit fields too large to fit into 32-bit integer");
|
||||
#endif
|
||||
if (ci == chunkindex && ui == utteranceindex && fi == frameindex)
|
||||
return;
|
||||
throw std::logic_error ("frameref: bit fields too small");
|
||||
|
@ -334,8 +328,8 @@ public:
|
|||
// first check consistency across feature streams
|
||||
// We'll go through the SCP files for each stream to make sure the duration is consistent
|
||||
// If not, we'll plan to ignore the utterance, and inform the user
|
||||
// m indexes the feature stream
|
||||
// i indexes the files within a stream, i.e. in the SCP file)
|
||||
// m indexes the feature stream
|
||||
// i indexes the files within a stream, i.e. in the SCP file)
|
||||
foreach_index(m, infiles){
|
||||
if (m == 0){
|
||||
numutts = infiles[m].size();
|
||||
|
@ -353,7 +347,7 @@ public:
|
|||
throw std::runtime_error("minibatchutterancesource: utterances < 2 frames not supported");
|
||||
if (uttframes > frameref::maxframesperutterance)
|
||||
{
|
||||
fprintf(stderr, "minibatchutterancesource: skipping %d-th file (%d frames) because it exceeds max. frames (%d) for frameref bit field: %S\n", i, uttframes, frameref::maxframesperutterance, key.c_str());
|
||||
fprintf(stderr, "minibatchutterancesource: skipping %d-th file (%d frames) because it exceeds max. frames (%d) for frameref bit field: %S\n", i, (int)uttframes, (int)frameref::maxframesperutterance, key.c_str());
|
||||
uttduration[i] = 0;
|
||||
uttisvalid[i] = false;
|
||||
}
|
||||
|
@ -363,7 +357,7 @@ public:
|
|||
uttisvalid[i] = true;
|
||||
}
|
||||
else if (uttduration[i] != uttframes){
|
||||
fprintf(stderr, "minibatchutterancesource: skipping %d-th file due to inconsistency in duration in different feature streams (%d vs %d frames)\n", i, uttduration[i], uttframes);
|
||||
fprintf(stderr, "minibatchutterancesource: skipping %d-th file due to inconsistency in duration in different feature streams (%d vs %d frames)\n", i, (int)uttduration[i], (int)uttframes);
|
||||
uttduration[i] = 0;
|
||||
uttisvalid[i] = false;
|
||||
}
|
||||
|
@ -378,7 +372,7 @@ public:
|
|||
if (invalidutts > uttisvalid.size() / 2)
|
||||
throw std::runtime_error("minibatchutterancesource: too many files with inconsistent durations, assuming broken configuration\n");
|
||||
else if (invalidutts>0)
|
||||
fprintf(stderr, "Found inconsistent durations across feature streams in %d out of %d files\n", invalidutts, uttisvalid.size());
|
||||
fprintf(stderr, "Found inconsistent durations across feature streams in %d out of %d files\n", (int)invalidutts, (int)uttisvalid.size());
|
||||
|
||||
|
||||
// now process the features and labels
|
||||
|
@ -459,7 +453,7 @@ public:
|
|||
size_t labframes = labseq.empty() ? 0 : (labseq[labseq.size()-1].firstframe + labseq[labseq.size()-1].numframes);
|
||||
if (labframes != uttframes)
|
||||
{
|
||||
fprintf (stderr, " [duration mismatch (%d in label vs. %d in feat file), skipping %S]", labframes, uttframes, key.c_str());
|
||||
fprintf (stderr, " [duration mismatch (%d in label vs. %d in feat file), skipping %S]", (int)labframes, (int)uttframes, key.c_str());
|
||||
nomlf++;
|
||||
uttisvalid[i] = false;
|
||||
//continue; // skip this utterance at all
|
||||
|
@ -484,13 +478,13 @@ public:
|
|||
}
|
||||
if (e.classid >= udim[j])
|
||||
{
|
||||
throw std::runtime_error(msra::strfun::strprintf("minibatchutterancesource: class id %d exceeds model output dimension %d in file %S", e.classid, udim, key.c_str()));
|
||||
throw std::runtime_error(msra::strfun::strprintf("minibatchutterancesource: class id %d exceeds model output dimension %d in file %S", e.classid, udim[j], key.c_str()));
|
||||
}
|
||||
if (e.classid != (CLASSIDTYPE) e.classid)
|
||||
throw std::runtime_error ("CLASSIDTYPE has too few bits");
|
||||
for (size_t t = e.firstframe; t < e.firstframe + e.numframes; t++)
|
||||
classids[j]->push_back ((CLASSIDTYPE) e.classid);
|
||||
numclasses[j] = max (numclasses[j], 1u + e.classid);
|
||||
numclasses[j] = max (numclasses[j], (size_t)(1u + e.classid));
|
||||
counts[j].resize (numclasses[j], 0);
|
||||
counts[j][e.classid] += e.numframes;
|
||||
}
|
||||
|
@ -521,7 +515,7 @@ public:
|
|||
else
|
||||
assert(utteranceset.size() == utterancesetsize);
|
||||
|
||||
fprintf (stderr, "feature set %d: %d frames in %d out of %d utterances\n", m, _totalframes, utteranceset.size(),infiles[m].size());
|
||||
fprintf (stderr, "feature set %d: %d frames in %d out of %d utterances\n", m, (int)_totalframes, (int)utteranceset.size(), (int)infiles[m].size());
|
||||
|
||||
if (!labels.empty()){
|
||||
foreach_index (j, labels){
|
||||
|
@ -538,11 +532,11 @@ public:
|
|||
}
|
||||
if (nomlf + nolat > 0)
|
||||
{
|
||||
fprintf (stderr, "minibatchutterancesource: out of %d files, %d files not found in label set and %d have no lattice\n", infiles[0].size(), nomlf, nolat);
|
||||
fprintf (stderr, "minibatchutterancesource: out of %d files, %d files not found in label set and %d have no lattice\n", (int)infiles[0].size(), (int)nomlf, (int)nolat);
|
||||
if (nomlf + nolat > infiles[m].size() / 2)
|
||||
throw std::runtime_error ("minibatchutterancesource: too many files not found in label set--assuming broken configuration\n");
|
||||
}
|
||||
if (m==0) {foreach_index(j, numclasses) { fprintf(stderr,"label set %d: %d classes\n",j, numclasses[j]); } }
|
||||
if (m==0) {foreach_index(j, numclasses) { fprintf(stderr,"label set %d: %d classes\n", j, (int)numclasses[j]); } }
|
||||
// distribute them over chunks
|
||||
// We simply count off frames until we reach the chunk size.
|
||||
// Note that we first randomize the chunks, i.e. when used, chunks are non-consecutive and thus cause the disk head to seek for each chunk.
|
||||
|
@ -568,7 +562,7 @@ public:
|
|||
}
|
||||
numutterances = utteranceset.size();
|
||||
fprintf (stderr, "minibatchutterancesource: %d utterances grouped into %d chunks, av. chunk size: %.1f utterances, %.1f frames\n",
|
||||
numutterances, thisallchunks.size(), numutterances / (double) thisallchunks.size(), _totalframes / (double) thisallchunks.size());
|
||||
(int)numutterances, (int)thisallchunks.size(), numutterances / (double) thisallchunks.size(), _totalframes / (double) thisallchunks.size());
|
||||
// Now utterances are stored exclusively in allchunks[]. They are never referred to by a sequential utterance id at this point, only by chunk/within-chunk index.
|
||||
}
|
||||
// preliminary mem allocation for frame references (if in frame mode)
|
||||
|
@ -657,7 +651,7 @@ private:
|
|||
|
||||
currentsweep = sweep;
|
||||
if (verbosity>0)
|
||||
fprintf (stderr, "lazyrandomization: re-randomizing for sweep %d in %s mode\n", currentsweep, framemode ? "frame" : "utterance");
|
||||
fprintf (stderr, "lazyrandomization: re-randomizing for sweep %d in %s mode\n", (int)currentsweep, framemode ? "frame" : "utterance");
|
||||
|
||||
const size_t sweepts = sweep * _totalframes; // first global frame index for this sweep
|
||||
|
||||
|
@ -968,7 +962,7 @@ private:
|
|||
{
|
||||
if (verbosity)
|
||||
fprintf (stderr, "releaserandomizedchunk: paging out randomized chunk %d (frame range [%d..%d]), %d resident in RAM\n",
|
||||
k, randomizedchunks[m][k].globalts, randomizedchunks[m][k].globalte()-1, chunksinram-1);
|
||||
(int)k, (int)randomizedchunks[m][k].globalts, (int)(randomizedchunks[m][k].globalte()-1), (int)(chunksinram-1));
|
||||
chunkdata.releasedata();
|
||||
numreleased++;
|
||||
}
|
||||
|
@ -1010,7 +1004,7 @@ private:
|
|||
auto & chunk = randomizedchunks[m][chunkindex];
|
||||
auto & chunkdata = chunk.getchunkdata();
|
||||
if (verbosity)
|
||||
fprintf (stderr, "feature set %d: requirerandomizedchunk: paging in randomized chunk %d (frame range [%d..%d]), %d resident in RAM\n", m, chunkindex, chunk.globalts, chunk.globalte()-1, chunksinram+1);
|
||||
fprintf (stderr, "feature set %d: requirerandomizedchunk: paging in randomized chunk %d (frame range [%d..%d]), %d resident in RAM\n", m, (int)chunkindex, (int)chunk.globalts, (int)(chunk.globalte()-1), (int)(chunksinram+1));
|
||||
msra::util::attempt (5, [&]() // (reading from network)
|
||||
{
|
||||
chunkdata.requiredata (featkind[m], featdim[m], sampperiod[m], this->lattices, verbosity);
|
||||
|
@ -1154,7 +1148,7 @@ public:
|
|||
}
|
||||
// return these utterances
|
||||
if (verbosity > 0)
|
||||
fprintf(stderr, "getbatch: getting utterances %d..%d (%d subset of %d frames out of %d requested) in sweep %d\n", spos, epos - 1, tspos, mbframes, framesrequested, sweep);
|
||||
fprintf(stderr, "getbatch: getting utterances %d..%d (%d subset of %d frames out of %d requested) in sweep %d\n", (int)spos, (int)(epos - 1), (int)tspos, (int)mbframes, (int)framesrequested, (int)sweep);
|
||||
tspos = 0; // relative start of utterance 'pos' within the returned minibatch
|
||||
for (size_t pos = spos; pos < epos; pos++)
|
||||
{
|
||||
|
@ -1239,9 +1233,9 @@ public:
|
|||
const size_t lastchunk = chunkforframepos (globalte-1);
|
||||
const size_t windowbegin = randomizedchunks[0][firstchunk].windowbegin;
|
||||
const size_t windowend = randomizedchunks[0][lastchunk].windowend;
|
||||
if (verbosity)
|
||||
if (verbosity > 0)
|
||||
fprintf (stderr, "getbatch: getting randomized frames [%d..%d] (%d frames out of %d requested) in sweep %d; chunks [%d..%d] -> chunk window [%d..%d)\n",
|
||||
globalts, globalte, mbframes, framesrequested, sweep, firstchunk, lastchunk, windowbegin, windowend);
|
||||
(int)globalts, (int)globalte, (int)mbframes, (int)framesrequested, (int)sweep, (int)firstchunk, (int)lastchunk, (int)windowbegin, (int)windowend);
|
||||
// release all data outside, and page in all data inside
|
||||
for (size_t k = 0; k < windowbegin; k++)
|
||||
releaserandomizedchunk (k);
|
||||
|
|
|
@ -1,63 +0,0 @@
|
|||
//
|
||||
// <copyright file="DataReader.cpp" company="Microsoft">
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// </copyright>
|
||||
//
|
||||
// DataReader.cpp : Defines the exported functions for the DLL application.
|
||||
//
|
||||
|
||||
#include "stdafx.h"
|
||||
#include "basetypes.h"
|
||||
|
||||
#include "htkfeatio.h" // for reading HTK features
|
||||
//#include "latticearchive.h" // for reading HTK phoneme lattices (MMI training)
|
||||
#include "simplesenonehmm.h" // for MMI scoring
|
||||
//#include "msra_mgram.h" // for unigram scores of ground-truth path in sequence training
|
||||
|
||||
#include "rollingwindowsource.h" // minibatch sources
|
||||
#include "utterancesource.h"
|
||||
//#include "readaheadsource.h"
|
||||
#include "chunkevalsource.h"
|
||||
#define DATAREADER_EXPORTS
|
||||
#include "DataReader.h"
|
||||
#include "HTKMLFReader.h"
|
||||
#include "commandArgUtil.h"
|
||||
|
||||
namespace Microsoft { namespace MSR { namespace CNTK {
|
||||
|
||||
template<class ElemType>
|
||||
void DATAREADER_API GetReader(IDataReader<ElemType>** preader)
|
||||
{
|
||||
*preader = new HTKMLFReader<ElemType>();
|
||||
}
|
||||
|
||||
extern "C" DATAREADER_API void GetReaderF(IDataReader<float>** preader)
|
||||
{
|
||||
GetReader(preader);
|
||||
}
|
||||
extern "C" DATAREADER_API void GetReaderD(IDataReader<double>** preader)
|
||||
{
|
||||
GetReader(preader);
|
||||
}
|
||||
|
||||
// Utility function, in ConfigFile.cpp, but HTKMLFReader doesn't need that code...
|
||||
|
||||
// Trim - trim white space off the start and end of the string
|
||||
// str - string to trim
|
||||
// NOTE: if the entire string is empty, then the string will be set to an empty string
|
||||
/* void Trim(std::string& str)
|
||||
{
|
||||
auto found = str.find_first_not_of(" \t");
|
||||
if (found == npos)
|
||||
{
|
||||
str.erase(0);
|
||||
return;
|
||||
}
|
||||
str.erase(0, found);
|
||||
found = str.find_last_not_of(" \t");
|
||||
if (found != npos)
|
||||
str.erase(found+1);
|
||||
}*/
|
||||
|
||||
|
||||
}}}
|
|
@ -1,111 +0,0 @@
|
|||
//
|
||||
// <copyright file="DataWriter.cpp" company="Microsoft">
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// </copyright>
|
||||
//
|
||||
// DataWriter.cpp : Defines the exported functions for the DLL application.
|
||||
//
|
||||
|
||||
#include "stdafx.h"
|
||||
#include "basetypes.h"
|
||||
|
||||
#include "htkfeatio.h" // for reading HTK features
|
||||
|
||||
#define DATAWRITER_EXPORTS
|
||||
#include "DataWriter.h"
|
||||
#include "HTKMLFWriter.h"
|
||||
|
||||
namespace Microsoft { namespace MSR { namespace CNTK {
|
||||
|
||||
template<class ElemType>
|
||||
void DATAWRITER_API GetWriter(IDataWriter<ElemType>** pwriter)
|
||||
{
|
||||
*pwriter = new HTKMLFWriter<ElemType>();
|
||||
}
|
||||
|
||||
extern "C" DATAWRITER_API void GetWriterF(IDataWriter<float>** pwriter)
|
||||
{
|
||||
GetWriter(pwriter);
|
||||
}
|
||||
extern "C" DATAWRITER_API void GetWriterD(IDataWriter<double>** pwriter)
|
||||
{
|
||||
GetWriter(pwriter);
|
||||
}
|
||||
|
||||
|
||||
template<class ElemType>
|
||||
void DataWriter<ElemType>::Init(const ConfigParameters& writerConfig)
|
||||
{
|
||||
m_dataWriter = new HTKMLFWriter<ElemType>();
|
||||
m_dataWriter->Init(writerConfig);
|
||||
}
|
||||
|
||||
|
||||
template<class ElemType>
|
||||
void DataWriter<ElemType>::GetDataWriter(const ConfigParameters& /*config*/)
|
||||
{
|
||||
NOT_IMPLEMENTED;
|
||||
}
|
||||
|
||||
|
||||
// Destroy - cleanup and remove this class
|
||||
// NOTE: this destroys the object, and it can't be used past this point
|
||||
template<class ElemType>
|
||||
void DataWriter<ElemType>::Destroy()
|
||||
{
|
||||
delete m_dataWriter;
|
||||
m_dataWriter = NULL;
|
||||
}
|
||||
|
||||
|
||||
// DataWriter Constructor
|
||||
// config - [in] configuration data for the data writer
|
||||
template<class ElemType>
|
||||
DataWriter<ElemType>::DataWriter(const ConfigParameters& config)
|
||||
{
|
||||
Init(config);
|
||||
}
|
||||
|
||||
|
||||
// destructor - cleanup temp files, etc.
|
||||
template<class ElemType>
|
||||
DataWriter<ElemType>::~DataWriter()
|
||||
{
|
||||
delete m_dataWriter;
|
||||
m_dataWriter = NULL;
|
||||
}
|
||||
|
||||
// GetSections - Get the sections of the file
|
||||
// sections - a map of section name to section. Data sepcifications from config file will be used to determine where and how to save data
|
||||
template<class ElemType>
|
||||
void DataWriter<ElemType>::GetSections(std::map<std::wstring, SectionType, nocase_compare>& sections)
|
||||
{
|
||||
m_dataWriter->GetSections(sections);
|
||||
}
|
||||
|
||||
// SaveData - save data in the file/files
|
||||
// recordStart - Starting record number
|
||||
// matricies - a map of section name (section:subsection) to data pointer. Data sepcifications from config file will be used to determine where and how to save data
|
||||
// numRecords - number of records we are saving, can be zero if not applicable
|
||||
// datasetSize - Size of the dataset
|
||||
// byteVariableSized - for variable sized data, size of current block to be written, zero when not used, or ignored if not variable sized data
|
||||
template<class ElemType>
|
||||
bool DataWriter<ElemType>::SaveData(size_t recordStart, const std::map<std::wstring, void*, nocase_compare>& matrices, size_t numRecords, size_t datasetSize, size_t byteVariableSized)
|
||||
{
|
||||
return m_dataWriter->SaveData(recordStart, matrices, numRecords, datasetSize, byteVariableSized);
|
||||
}
|
||||
|
||||
// SaveMapping - save a map into the file
|
||||
// saveId - name of the section to save into (section:subsection format)
|
||||
// labelMapping - map we are saving to the file
|
||||
template<class ElemType>
|
||||
void DataWriter<ElemType>::SaveMapping(std::wstring saveId, const std::map<LabelIdType, LabelType>& labelMapping)
|
||||
{
|
||||
m_dataWriter->SaveMapping(saveId, labelMapping);
|
||||
}
|
||||
|
||||
//The explicit instantiation
|
||||
template class DataWriter<double>;
|
||||
template class DataWriter<float>;
|
||||
|
||||
}}}
|
Разница между файлами не показана из-за своего большого размера
Загрузить разницу
|
@ -1,202 +0,0 @@
|
|||
//
|
||||
// <copyright file="HTKMLFReader.h" company="Microsoft">
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// </copyright>
|
||||
//
|
||||
// HTKMLFReader.h - Include file for the MTK and MLF format of features and samples
|
||||
#pragma once
|
||||
#include "DataReader.h"
|
||||
#include "commandArgUtil.h" // for intargvector
|
||||
#include "CUDAPageLockedMemAllocator.h"
|
||||
|
||||
namespace Microsoft { namespace MSR { namespace CNTK {
|
||||
|
||||
template<class ElemType>
|
||||
class HTKMLFReader : public IDataReader<ElemType>
|
||||
{
|
||||
private:
|
||||
const static size_t m_htkRandomizeAuto = 0;
|
||||
const static size_t m_htkRandomizeDisable = (size_t)-1;
|
||||
|
||||
msra::dbn::minibatchiterator* m_mbiter;
|
||||
msra::dbn::minibatchsource* m_frameSource;
|
||||
//msra::dbn::minibatchreadaheadsource* m_readAheadSource;
|
||||
msra::dbn::FileEvalSource* m_fileEvalSource;
|
||||
msra::dbn::latticesource* m_lattices;
|
||||
map<wstring,msra::lattices::lattice::htkmlfwordsequence> m_latticeMap;
|
||||
|
||||
vector<bool> m_sentenceEnd;
|
||||
bool m_readAhead;
|
||||
bool m_truncated;
|
||||
bool m_framemode;
|
||||
vector<size_t> m_processedFrame;
|
||||
intargvector m_numberOfuttsPerMinibatchForAllEpochs;
|
||||
size_t m_numberOfuttsPerMinibatch;
|
||||
size_t m_actualnumberOfuttsPerMinibatch;
|
||||
size_t m_mbSize;
|
||||
vector<size_t> m_toProcess;
|
||||
vector<size_t> m_switchFrame;
|
||||
bool m_noData;
|
||||
|
||||
bool m_trainOrTest; // if false, in file writing mode
|
||||
using LabelType = typename IDataReader<ElemType>::LabelType;
|
||||
using LabelIdType = typename IDataReader<ElemType>::LabelIdType;
|
||||
|
||||
std::map<LabelIdType, LabelType> m_idToLabelMap;
|
||||
|
||||
bool m_partialMinibatch; // allow partial minibatches?
|
||||
|
||||
std::vector<ElemType*> m_featuresBufferMultiUtt;
|
||||
std::vector<size_t> m_featuresBufferAllocatedMultiUtt;
|
||||
std::vector<ElemType*> m_labelsBufferMultiUtt;
|
||||
std::vector<size_t> m_labelsBufferAllocatedMultiUtt;
|
||||
std::vector<size_t> m_featuresStartIndexMultiUtt;
|
||||
std::vector<size_t> m_labelsStartIndexMultiUtt;
|
||||
|
||||
CUDAPageLockedMemAllocator* m_cudaAllocator;
|
||||
std::vector<std::shared_ptr<ElemType>> m_featuresBufferMultiIO;
|
||||
std::vector<size_t> m_featuresBufferAllocatedMultiIO;
|
||||
std::vector<std::shared_ptr<ElemType>> m_labelsBufferMultiIO;
|
||||
std::vector<size_t> m_labelsBufferAllocatedMultiIO;
|
||||
|
||||
std::map<std::wstring,size_t> m_featureNameToIdMap;
|
||||
std::map<std::wstring,size_t> m_labelNameToIdMap;
|
||||
std::map<std::wstring,size_t> m_nameToTypeMap;
|
||||
std::map<std::wstring,size_t> m_featureNameToDimMap;
|
||||
std::map<std::wstring,size_t> m_labelNameToDimMap;
|
||||
// for writing outputs to files (standard single input/output network) - deprecate eventually
|
||||
bool m_checkDictionaryKeys;
|
||||
bool m_convertLabelsToTargets;
|
||||
std::vector <bool> m_convertLabelsToTargetsMultiIO;
|
||||
std::vector<std::vector<std::wstring>> m_inputFilesMultiIO;
|
||||
|
||||
size_t m_inputFileIndex;
|
||||
std::vector<size_t> m_featDims;
|
||||
std::vector<size_t> m_labelDims;
|
||||
|
||||
std::vector<std::vector<std::vector<ElemType>>>m_labelToTargetMapMultiIO;
|
||||
|
||||
void PrepareForTrainingOrTesting(const ConfigParameters& config);
|
||||
void PrepareForWriting(const ConfigParameters& config);
|
||||
|
||||
bool GetMinibatchToTrainOrTest(std::map<std::wstring, Matrix<ElemType>*>&matrices);
|
||||
bool GetMinibatchToWrite(std::map<std::wstring, Matrix<ElemType>*>&matrices);
|
||||
|
||||
void StartMinibatchLoopToTrainOrTest(size_t mbSize, size_t epoch, size_t subsetNum, size_t numSubsets, size_t requestedEpochSamples = requestDataSize);
|
||||
void StartMinibatchLoopToWrite(size_t mbSize, size_t epoch, size_t requestedEpochSamples=requestDataSize);
|
||||
|
||||
bool ReNewBufferForMultiIO(size_t i);
|
||||
|
||||
size_t NumberSlicesInEachRecurrentIter() { return m_numberOfuttsPerMinibatch ;}
|
||||
void SetNbrSlicesEachRecurrentIter(const size_t) { };
|
||||
|
||||
void GetDataNamesFromConfig(const ConfigParameters& readerConfig, std::vector<std::wstring>& features, std::vector<std::wstring>& labels);
|
||||
|
||||
|
||||
size_t ReadLabelToTargetMappingFile (const std::wstring& labelToTargetMappingFile, const std::wstring& labelListFile, std::vector<std::vector<ElemType>>& labelToTargetMap);
|
||||
void ExpandDotDotDot(wstring & featPath, const wstring & scpPath, wstring & scpDirCached);
|
||||
enum InputOutputTypes
|
||||
{
|
||||
real,
|
||||
category,
|
||||
};
|
||||
|
||||
private:
|
||||
CUDAPageLockedMemAllocator* GetCUDAAllocator(int deviceID)
|
||||
{
|
||||
if (m_cudaAllocator != nullptr)
|
||||
{
|
||||
if (m_cudaAllocator->GetDeviceID() != deviceID)
|
||||
{
|
||||
delete m_cudaAllocator;
|
||||
m_cudaAllocator = nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
if (m_cudaAllocator == nullptr)
|
||||
{
|
||||
m_cudaAllocator = new CUDAPageLockedMemAllocator(deviceID);
|
||||
}
|
||||
|
||||
return m_cudaAllocator;
|
||||
}
|
||||
|
||||
std::shared_ptr<ElemType> AllocateIntermediateBuffer(int deviceID, size_t numElements)
|
||||
{
|
||||
if (deviceID >= 0)
|
||||
{
|
||||
// Use pinned memory for GPU devices for better copy performance
|
||||
size_t totalSize = sizeof(ElemType) * numElements;
|
||||
return std::shared_ptr<ElemType>((ElemType*)GetCUDAAllocator(deviceID)->Malloc(totalSize), [this, deviceID](ElemType* p) {
|
||||
this->GetCUDAAllocator(deviceID)->Free((char*)p);
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
return std::shared_ptr<ElemType>(new ElemType[numElements], [](ElemType* p) {
|
||||
delete[] p;
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
public:
|
||||
/// a matrix of n_stream x n_length
|
||||
/// n_stream is the number of streams
|
||||
/// n_length is the maximum lenght of each stream
|
||||
/// for example, two sentences used in parallel in one minibatch would be
|
||||
/// [2 x 5] if the max length of one of the sentences is 5
|
||||
/// the elements of the matrix is 0, 1, or -1, defined as SEQUENCE_START, SEQUENCE_MIDDLE, NO_INPUT in cbasetype.h
|
||||
/// 0 1 1 0 1
|
||||
/// 1 0 1 0 0
|
||||
/// for two parallel data streams. The first has two sentences, with 0 indicating begining of a sentence
|
||||
/// the second data stream has two sentences, with 0 indicating begining of sentences
|
||||
/// you may use 1 even if a sentence begins at that position, in this case, the trainer will carry over hidden states to the following
|
||||
/// frame.
|
||||
Matrix<ElemType> m_sentenceBegin;
|
||||
|
||||
/// a matrix of 1 x n_length
|
||||
/// 1 denotes the case that there exists sentnece begin or no_labels case in this frame
|
||||
/// 0 denotes such case is not in this frame
|
||||
|
||||
|
||||
vector<MinibatchPackingFlag> m_minibatchPackingFlag;
|
||||
|
||||
/// by default it is false
|
||||
/// if true, reader will set to SEQUENCE_MIDDLE for time positions that are orignally correspond to SEQUENCE_START
|
||||
/// set to true so that a current minibatch can uses state activities from the previous minibatch.
|
||||
/// default will have truncated BPTT, which only does BPTT inside a minibatch
|
||||
|
||||
bool mIgnoreSentenceBeginTag;
|
||||
HTKMLFReader() : m_sentenceBegin(CPUDEVICE) {
|
||||
}
|
||||
|
||||
virtual void Init(const ConfigParameters& config);
|
||||
virtual void Destroy() {delete this;}
|
||||
virtual ~HTKMLFReader();
|
||||
|
||||
virtual void StartMinibatchLoop(size_t mbSize, size_t epoch, size_t requestedEpochSamples = requestDataSize)
|
||||
{
|
||||
return StartDistributedMinibatchLoop(mbSize, epoch, 0, 1, requestedEpochSamples);
|
||||
}
|
||||
|
||||
virtual bool SupportsDistributedMBRead() const override
|
||||
{
|
||||
return m_frameSource->supportsbatchsubsetting();
|
||||
}
|
||||
|
||||
virtual void StartDistributedMinibatchLoop(size_t mbSize, size_t epoch, size_t subsetNum, size_t numSubsets, size_t requestedEpochSamples = requestDataSize) override;
|
||||
|
||||
virtual bool GetMinibatch(std::map<std::wstring, Matrix<ElemType>*>& matrices);
|
||||
virtual const std::map<LabelIdType, LabelType>& GetLabelMapping(const std::wstring& sectionName);
|
||||
virtual void SetLabelMapping(const std::wstring& sectionName, const std::map<LabelIdType, LabelType>& labelMapping);
|
||||
virtual bool GetData(const std::wstring& sectionName, size_t numRecords, void* data, size_t& dataBufferSize, size_t recordStart=0);
|
||||
|
||||
virtual bool DataEnd(EndDataType endDataType);
|
||||
void SetSentenceEndInBatch(vector<size_t> &/*sentenceEnd*/);
|
||||
void SetSentenceEnd(int /*actualMbSize*/){};
|
||||
void SetSentenceSegBatch(Matrix<ElemType> &sentenceBegin, vector<MinibatchPackingFlag>& sentenceExistsBeginOrNoLabels);
|
||||
|
||||
bool RequireSentenceSeg() { return !m_framemode; };
|
||||
};
|
||||
|
||||
}}}
|
|
@ -1,184 +0,0 @@
|
|||
//
|
||||
// <copyright file="HTKMLFReader.cpp" company="Microsoft">
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// </copyright>
|
||||
//
|
||||
// HTKMLFReader.cpp : Defines the exported functions for the DLL application.
|
||||
//
|
||||
|
||||
#include "stdafx.h"
|
||||
#include "basetypes.h"
|
||||
|
||||
#include "htkfeatio.h" // for reading HTK features
|
||||
//#ifndef __unix__
|
||||
#include "ssematrix.h"
|
||||
//#endif
|
||||
|
||||
#define DATAWRITER_EXPORTS // creating the exports here
|
||||
#include "DataWriter.h"
|
||||
#include "commandArgUtil.h"
|
||||
#include "HTKMLFWriter.h"
|
||||
#ifdef LEAKDETECT
|
||||
#include <vld.h> // for memory leak detection
|
||||
#endif
|
||||
|
||||
|
||||
|
||||
namespace Microsoft { namespace MSR { namespace CNTK {
|
||||
|
||||
// Create a Data Writer
|
||||
//DATAWRITER_API IDataWriter* DataWriterFactory(void)
|
||||
|
||||
template<class ElemType>
|
||||
void HTKMLFWriter<ElemType>::Init(const ConfigParameters& writerConfig)
|
||||
{
|
||||
m_tempArray = nullptr;
|
||||
m_tempArraySize = 0;
|
||||
|
||||
vector<wstring> scriptpaths;
|
||||
vector<wstring> filelist;
|
||||
size_t numFiles;
|
||||
size_t firstfilesonly = SIZE_MAX; // set to a lower value for testing
|
||||
|
||||
ConfigArray outputNames = writerConfig("outputNodeNames","");
|
||||
if (outputNames.size()<1)
|
||||
RuntimeError("writer needs at least one outputNodeName specified in config");
|
||||
|
||||
|
||||
foreach_index(i, outputNames) // inputNames should map to node names
|
||||
{
|
||||
ConfigParameters thisOutput = writerConfig(outputNames[i]);
|
||||
if (thisOutput.Exists("dim"))
|
||||
udims.push_back(thisOutput("dim"));
|
||||
else
|
||||
RuntimeError("HTKMLFWriter::Init: writer need to specify dim of output");
|
||||
|
||||
if (thisOutput.Exists("file"))
|
||||
scriptpaths.push_back(thisOutput("file"));
|
||||
else if (thisOutput.Exists("scpFile"))
|
||||
scriptpaths.push_back(thisOutput("scpFile"));
|
||||
else
|
||||
RuntimeError("HTKMLFWriter::Init: writer needs to specify scpFile for output");
|
||||
|
||||
outputNameToIdMap[outputNames[i]]= i;
|
||||
outputNameToDimMap[outputNames[i]]=udims[i];
|
||||
wstring type = thisOutput("type","Real");
|
||||
if (type == L"Real")
|
||||
{
|
||||
outputNameToTypeMap[outputNames[i]] = OutputTypes::outputReal;
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error ("HTKMLFWriter::Init: output type for writer output expected to be Real");
|
||||
}
|
||||
}
|
||||
|
||||
numFiles=0;
|
||||
foreach_index(i,scriptpaths)
|
||||
{
|
||||
filelist.clear();
|
||||
std::wstring scriptPath = scriptpaths[i];
|
||||
fprintf(stderr, "HTKMLFWriter::Init: reading output script file %S ...", scriptPath.c_str());
|
||||
size_t n = 0;
|
||||
for (msra::files::textreader reader(scriptPath); reader && filelist.size() <= firstfilesonly/*optimization*/; )
|
||||
{
|
||||
filelist.push_back (reader.wgetline());
|
||||
n++;
|
||||
}
|
||||
|
||||
fprintf (stderr, " %zu entries\n", n);
|
||||
|
||||
if (i==0)
|
||||
numFiles=n;
|
||||
else
|
||||
if (n!=numFiles)
|
||||
throw std::runtime_error (msra::strfun::strprintf ("HTKMLFWriter:Init: number of files in each scriptfile inconsistent (%d vs. %d)", numFiles,n));
|
||||
|
||||
outputFiles.push_back(filelist);
|
||||
}
|
||||
outputFileIndex=0;
|
||||
sampPeriod=100000;
|
||||
|
||||
}
|
||||
|
||||
template<class ElemType>
|
||||
void HTKMLFWriter<ElemType>::Destroy()
|
||||
{
|
||||
delete [] m_tempArray;
|
||||
m_tempArray = nullptr;
|
||||
m_tempArraySize = 0;
|
||||
}
|
||||
|
||||
template<class ElemType>
|
||||
void HTKMLFWriter<ElemType>::GetSections(std::map<std::wstring, SectionType, nocase_compare>& /*sections*/)
|
||||
{
|
||||
}
|
||||
|
||||
template<class ElemType>
|
||||
bool HTKMLFWriter<ElemType>::SaveData(size_t /*recordStart*/, const std::map<std::wstring, void*, nocase_compare>& matrices, size_t /*numRecords*/, size_t /*datasetSize*/, size_t /*byteVariableSized*/)
|
||||
{
|
||||
|
||||
|
||||
//std::map<std::wstring, void*, nocase_compare>::iterator iter;
|
||||
if (outputFileIndex>=outputFiles[0].size())
|
||||
RuntimeError("index for output scp file out of range...");
|
||||
|
||||
for (auto iter = matrices.begin();iter!=matrices.end(); iter++)
|
||||
{
|
||||
wstring outputName = iter->first;
|
||||
Matrix<ElemType>& outputData = *(static_cast<Matrix<ElemType>*>(iter->second));
|
||||
size_t id = outputNameToIdMap[outputName];
|
||||
size_t dim = outputNameToDimMap[outputName];
|
||||
wstring outFile = outputFiles[id][outputFileIndex];
|
||||
|
||||
assert(outputData.GetNumRows()==dim); dim;
|
||||
|
||||
SaveToFile(outFile,outputData);
|
||||
}
|
||||
|
||||
outputFileIndex++;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
template<class ElemType>
|
||||
void HTKMLFWriter<ElemType>::SaveToFile(std::wstring& outputFile, Matrix<ElemType>& outputData)
|
||||
{
|
||||
msra::dbn::matrix output;
|
||||
output.resize(outputData.GetNumRows(),outputData.GetNumCols());
|
||||
outputData.CopyToArray(m_tempArray, m_tempArraySize);
|
||||
ElemType * pValue = m_tempArray;
|
||||
|
||||
for (int j=0; j< outputData.GetNumCols(); j++)
|
||||
{
|
||||
for (int i=0; i<outputData.GetNumRows(); i++)
|
||||
{
|
||||
output(i,j) = (float)*pValue++;
|
||||
}
|
||||
}
|
||||
|
||||
const size_t nansinf = output.countnaninf();
|
||||
if (nansinf > 0)
|
||||
fprintf (stderr, "chunkeval: %d NaNs or INF detected in '%S' (%d frames)\n", (int) nansinf, outputFile.c_str(), (int) output.cols());
|
||||
// save it
|
||||
msra::files::make_intermediate_dirs (outputFile);
|
||||
msra::util::attempt (5, [&]()
|
||||
{
|
||||
msra::asr::htkfeatwriter::write (outputFile, "USER", this->sampPeriod, output);
|
||||
});
|
||||
|
||||
fprintf (stderr, "evaluate: writing %zu frames of %S\n", output.cols(), outputFile.c_str());
|
||||
|
||||
|
||||
}
|
||||
|
||||
|
||||
template<class ElemType>
|
||||
void HTKMLFWriter<ElemType>::SaveMapping(std::wstring saveId, const std::map<LabelIdType, LabelType>& /*labelMapping*/)
|
||||
{
|
||||
}
|
||||
|
||||
template class HTKMLFWriter<float>;
|
||||
template class HTKMLFWriter<double>;
|
||||
|
||||
}}}
|
|
@ -1,47 +0,0 @@
|
|||
//
|
||||
// <copyright file="HTKMLFReader.h" company="Microsoft">
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// </copyright>
|
||||
//
|
||||
// HTKMLFReader.h - Include file for the MTK and MLF format of features and samples
|
||||
#pragma once
|
||||
#include "DataWriter.h"
|
||||
#include <map>
|
||||
#include <vector>
|
||||
|
||||
namespace Microsoft { namespace MSR { namespace CNTK {
|
||||
|
||||
template<class ElemType>
|
||||
class HTKMLFWriter : public IDataWriter<ElemType>
|
||||
{
|
||||
private:
|
||||
std::vector<size_t> outputDims;
|
||||
std::vector<std::vector<std::wstring>> outputFiles;
|
||||
|
||||
std::vector<size_t> udims;
|
||||
std::map<std::wstring,size_t> outputNameToIdMap;
|
||||
std::map<std::wstring,size_t> outputNameToDimMap;
|
||||
std::map<std::wstring,size_t> outputNameToTypeMap;
|
||||
unsigned int sampPeriod;
|
||||
size_t outputFileIndex;
|
||||
void SaveToFile(std::wstring& outputFile, Matrix<ElemType>& outputData);
|
||||
ElemType * m_tempArray;
|
||||
size_t m_tempArraySize;
|
||||
|
||||
enum OutputTypes
|
||||
{
|
||||
outputReal,
|
||||
outputCategory,
|
||||
};
|
||||
|
||||
public:
|
||||
using LabelType = typename IDataWriter<ElemType>::LabelType;
|
||||
using LabelIdType = typename IDataWriter<ElemType>::LabelIdType;
|
||||
virtual void Init(const ConfigParameters& writerConfig);
|
||||
virtual void Destroy();
|
||||
virtual void GetSections(std::map<std::wstring, SectionType, nocase_compare>& sections);
|
||||
virtual bool SaveData(size_t recordStart, const std::map<std::wstring, void*, nocase_compare>& matrices, size_t numRecords, size_t datasetSize, size_t byteVariableSized);
|
||||
virtual void SaveMapping(std::wstring saveId, const std::map<LabelIdType, LabelType>& labelMapping);
|
||||
};
|
||||
|
||||
}}}
|
Разница между файлами не показана из-за своего большого размера
Загрузить разницу
|
@ -1,885 +0,0 @@
|
|||
// TODO: This is a dup, we should get back to the shared one. But this one has some stuff the other doesn't.
|
||||
|
||||
//
|
||||
// <copyright file="basetypes.h" company="Microsoft">
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// </copyright>
|
||||
//
|
||||
|
||||
#pragma once
|
||||
#ifndef _BASETYPES_
|
||||
#define _BASETYPES_
|
||||
|
||||
// [kit]: seems SECURE_SCL=0 doesn't work - causes crashes in release mode
|
||||
// there are some complaints along this line on the web
|
||||
// so disabled for now
|
||||
//
|
||||
//// we have agreed that _SECURE_SCL is disabled for release builds
|
||||
//// it would be super dangerous to mix projects where this is inconsistent
|
||||
//// this is one way to detect possible mismatches
|
||||
//#ifdef NDEBUG
|
||||
//#if !defined(_CHECKED) && _SECURE_SCL != 0
|
||||
//#error "_SECURE_SCL should be disabled for release builds"
|
||||
//#endif
|
||||
//#endif
|
||||
|
||||
#ifndef UNDER_CE // fixed-buffer overloads not available for wince
|
||||
#ifdef _CRT_SECURE_CPP_OVERLOAD_STANDARD_NAMES // fixed-buffer overloads for strcpy() etc.
|
||||
#undef _CRT_SECURE_CPP_OVERLOAD_STANDARD_NAMES
|
||||
#endif
|
||||
#define _CRT_SECURE_CPP_OVERLOAD_STANDARD_NAMES 1
|
||||
#endif
|
||||
|
||||
#pragma warning (push)
|
||||
#pragma warning (disable: 4793) // caused by varargs
|
||||
|
||||
// disable certain parts of basetypes for wince compilation
|
||||
#ifdef UNDER_CE
|
||||
#define BASETYPES_NO_UNSAFECRTOVERLOAD // disable unsafe CRT overloads (safe functions don't exist in wince)
|
||||
#define BASETYPES_NO_STRPRINTF // dependent functions here are not defined for wince
|
||||
#endif
|
||||
|
||||
#ifndef OACR // dummies when we are not compiling under Office
|
||||
#define OACR_WARNING_SUPPRESS(x, y)
|
||||
#define OACR_WARNING_DISABLE(x, y)
|
||||
#define OACR_WARNING_PUSH
|
||||
#define OACR_WARNING_POP
|
||||
#endif
|
||||
#ifndef OACR_ASSUME // this seems to be a different one
|
||||
#define OACR_ASSUME(x)
|
||||
#endif
|
||||
|
||||
// following oacr warnings are not level1 or level2-security
|
||||
// in currect stage we want to ignore those warnings
|
||||
// if necessay this can be fixed at later stage
|
||||
|
||||
// not a bug
|
||||
OACR_WARNING_DISABLE(EXC_NOT_CAUGHT_BY_REFERENCE, "Not indicating a bug or security threat.");
|
||||
OACR_WARNING_DISABLE(LOCALDECLHIDESLOCAL, "Not indicating a bug or security threat.");
|
||||
|
||||
// not reviewed
|
||||
OACR_WARNING_DISABLE(MISSING_OVERRIDE, "Not level1 or level2_security.");
|
||||
OACR_WARNING_DISABLE(EMPTY_DTOR, "Not level1 or level2_security.");
|
||||
OACR_WARNING_DISABLE(DEREF_NULL_PTR, "Not level1 or level2_security.");
|
||||
OACR_WARNING_DISABLE(INVALID_PARAM_VALUE_1, "Not level1 or level2_security.");
|
||||
OACR_WARNING_DISABLE(VIRTUAL_CALL_IN_CTOR, "Not level1 or level2_security.");
|
||||
OACR_WARNING_DISABLE(POTENTIAL_ARGUMENT_TYPE_MISMATCH, "Not level1 or level2_security.");
|
||||
|
||||
// determine WIN32 api calling convention
|
||||
// it seems this is normally stdcall?? but when compiling as /clr:pure or /clr:Safe
|
||||
// this is not supported, so in this case, we need to use the 'default' calling convention
|
||||
// TODO: can we reuse the #define of WINAPI??
|
||||
#ifdef _WIN32
|
||||
#ifdef _M_CEE_SAFE
|
||||
#define WINAPI_CC __clrcall
|
||||
#elif _M_CEE
|
||||
#define WINAPI_CC __clrcall
|
||||
#else
|
||||
#define WINAPI_CC __stdcall
|
||||
#endif
|
||||
#endif
|
||||
|
||||
// fix some warnings in STL
|
||||
#if !defined(_DEBUG) || defined(_CHECKED) || defined(_MANAGED)
|
||||
#pragma warning(disable : 4702) // unreachable code
|
||||
#endif
|
||||
#include <stdarg.h>
|
||||
#include <stdio.h>
|
||||
#include <string.h> // include here because we redefine some names later
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <cmath> // for HUGE_VAL
|
||||
#include <assert.h>
|
||||
#include <map>
|
||||
#ifdef __windows__
|
||||
#include <windows.h> // for CRITICAL_SECTION
|
||||
#include <strsafe.h> // for strbcpy() etc templates
|
||||
#endif
|
||||
#if __unix__
|
||||
#include <strings.h>
|
||||
#include <chrono>
|
||||
#include <thread>
|
||||
#include <unistd.h>
|
||||
#include <sys/stat.h>
|
||||
#include <dlfcn.h>
|
||||
typedef unsigned char byte;
|
||||
#endif
|
||||
|
||||
|
||||
#pragma push_macro("STRSAFE_NO_DEPRECATE")
|
||||
#define STRSAFE_NO_DEPRECATE // deprecation managed elsewhere, not by strsafe
|
||||
#pragma pop_macro("STRSAFE_NO_DEPRECATE")
|
||||
|
||||
// CRT error handling seems to not be included in wince headers
|
||||
// so we define our own imports
|
||||
#ifdef UNDER_CE
|
||||
|
||||
// TODO: is this true - is GetLastError == errno?? - also this adds a dependency on windows.h
|
||||
#define errno GetLastError()
|
||||
|
||||
// strerror(x) - x here is normally errno - TODO: make this return errno as a string
|
||||
#define strerror(x) "strerror error but can't report error number sorry!"
|
||||
#endif
|
||||
|
||||
#ifndef __in // dummies for sal annotations if compiler does not support it
|
||||
#define __in
|
||||
#define __inout_z
|
||||
#define __in_count(x)
|
||||
#define __inout_cap(x)
|
||||
#define __inout_cap_c(x)
|
||||
#endif
|
||||
#ifndef __out_z_cap // non-VS2005 annotations
|
||||
#define __out_cap(x)
|
||||
#define __out_z_cap(x)
|
||||
#define __out_cap_c(x)
|
||||
#endif
|
||||
|
||||
#ifndef __override // and some more non-std extensions required by Office
|
||||
#define __override virtual
|
||||
#endif
|
||||
|
||||
// disable warnings for which fixing would make code less readable
|
||||
#pragma warning(disable : 4290) // throw() declaration ignored
|
||||
#pragma warning(disable : 4244) // conversion from typeA to typeB, possible loss of data
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// basic macros
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
#define SAFE_DELETE(p) { if(p) { delete (p); (p)=NULL; } }
|
||||
#define SAFE_RELEASE(p) { if(p) { (p)->Release(); (p)=NULL; } } // nasty! use CComPtr<>
|
||||
#ifndef ASSERT
|
||||
#ifdef _CHECKED // basetypes.h expects this function to be defined (it is in message.h)
|
||||
extern void _CHECKED_ASSERT_error(const char * file, int line, const char * exp);
|
||||
#define ASSERT(exp) ((exp)||(_CHECKED_ASSERT_error(__FILE__,__LINE__,#exp),0))
|
||||
#else
|
||||
#define ASSERT assert
|
||||
#endif
|
||||
#endif
|
||||
|
||||
using namespace std;
|
||||
// ----------------------------------------------------------------------------
|
||||
// basic data types
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
namespace msra { namespace basetypes {
|
||||
|
||||
// class ARRAY -- std::vector with array-bounds checking
|
||||
// VS 2008 and above do this, so there is no longer a need for this.
|
||||
|
||||
template<class _ElemType>
|
||||
class ARRAY : public std::vector<_ElemType>
|
||||
{
|
||||
#if defined (_DEBUG) || defined (_CHECKED) // debug version with range checking
|
||||
static void throwOutOfBounds()
|
||||
{ // (moved to separate function hoping to keep inlined code smaller
|
||||
OACR_WARNING_PUSH;
|
||||
OACR_WARNING_DISABLE(IGNOREDBYCOMMA, "Reviewd OK. Special trick below to show a message when assertion fails"
|
||||
"[rogeryu 2006/03/24]");
|
||||
OACR_WARNING_DISABLE(BOGUS_EXPRESSION_LIST, "This is intentional. [rogeryu 2006/03/24]");
|
||||
ASSERT (("ARRAY::operator[] out of bounds", false));
|
||||
OACR_WARNING_POP;
|
||||
}
|
||||
#endif
|
||||
|
||||
public:
|
||||
|
||||
ARRAY() : std::vector<_ElemType> () { }
|
||||
ARRAY (int size) : std::vector<_ElemType> (size) { }
|
||||
|
||||
#if defined (_DEBUG) || defined (_CHECKED) // debug version with range checking
|
||||
// ------------------------------------------------------------------------
|
||||
// operator[]: with array-bounds checking
|
||||
// ------------------------------------------------------------------------
|
||||
|
||||
inline _ElemType & operator[] (int index) // writing
|
||||
{
|
||||
if (index < 0 || index >= size()) throwOutOfBounds();
|
||||
return (*(std::vector<_ElemType>*) this)[index];
|
||||
}
|
||||
|
||||
// ------------------------------------------------------------------------
|
||||
|
||||
inline const _ElemType & operator[] (int index) const // reading
|
||||
{
|
||||
if (index < 0 || index >= size()) throwOutOfBounds();
|
||||
return (*(std::vector<_ElemType>*) this)[index];
|
||||
}
|
||||
#endif
|
||||
|
||||
// ------------------------------------------------------------------------
|
||||
// size(): same as base class, but returning an 'int' instead of 'size_t'
|
||||
// to allow for better readable code
|
||||
// ------------------------------------------------------------------------
|
||||
|
||||
inline int size() const
|
||||
{
|
||||
size_t siz = ((std::vector<_ElemType>*) this)->size();
|
||||
return (int) siz;
|
||||
}
|
||||
};
|
||||
// overload swap(), otherwise we'd fallback to 3-way assignment & possibly throw
|
||||
template<class _T> inline void swap (ARRAY<_T> & L, ARRAY<_T> & R) throw()
|
||||
{ swap ((std::vector<_T> &) L, (std::vector<_T> &) R); }
|
||||
|
||||
// class fixed_vector - non-resizable vector
|
||||
|
||||
template<class _T> class fixed_vector
|
||||
{
|
||||
_T * p; // pointer array
|
||||
size_t n; // number of elements
|
||||
void check (int index) const { index; ASSERT (index >= 0 && (size_t) index < n); }
|
||||
void check (size_t index) const { index; ASSERT (index < n); }
|
||||
// ... TODO: when I make this public, LinearTransform.h acts totally up but I cannot see where it comes from.
|
||||
//fixed_vector (const fixed_vector & other) : n (0), p (NULL) { *this = other; }
|
||||
public:
|
||||
fixed_vector() : n (0), p (NULL) { }
|
||||
void resize (int size) { clear(); if (size > 0) { p = new _T[size]; n = size; } }
|
||||
void resize (size_t size) { clear(); if (size > 0) { p = new _T[size]; n = size; } }
|
||||
fixed_vector (int size) : n (size), p (size > 0 ? new _T[size] : NULL) { }
|
||||
fixed_vector (size_t size) : n ((int) size), p (size > 0 ? new _T[size] : NULL) { }
|
||||
~fixed_vector() { delete[] p; }
|
||||
inline int size() const { return (int) n; }
|
||||
inline int capacity() const { return (int) n; }
|
||||
inline bool empty() const { return n == 0; }
|
||||
void clear() { delete[] p; p = NULL; n = 0; }
|
||||
_T * begin() { return p; }
|
||||
const _T * begin() const { return p; }
|
||||
_T * end() { return p + n; } // note: n == 0 so result is NULL
|
||||
inline _T & operator[] (int index) { check (index); return p[index]; } // writing
|
||||
inline const _T & operator[] (int index) const { check (index); return p[index]; } // reading
|
||||
inline _T & operator[] (size_t index) { check (index); return p[index]; } // writing
|
||||
inline const _T & operator[] (size_t index) const { check (index); return p[index]; } // reading
|
||||
inline int indexof (const _T & elem) const { ASSERT (&elem >= p && &elem < p + n); return &elem - p; }
|
||||
inline void swap (fixed_vector & other) throw() { std::swap (other.p, p); std::swap (other.n, n); }
|
||||
template<class VECTOR> fixed_vector & operator= (const VECTOR & other)
|
||||
{
|
||||
int other_n = (int) other.size();
|
||||
fixed_vector tmp (other_n);
|
||||
for (int k = 0; k < other_n; k++) tmp[k] = other[k];
|
||||
swap (tmp);
|
||||
return *this;
|
||||
}
|
||||
fixed_vector & operator= (const fixed_vector & other)
|
||||
{
|
||||
int other_n = (int) other.size();
|
||||
fixed_vector tmp (other_n);
|
||||
for (int k = 0; k < other_n; k++) tmp[k] = other[k];
|
||||
swap (tmp);
|
||||
return *this;
|
||||
}
|
||||
template<class VECTOR> fixed_vector (const VECTOR & other) : n (0), p (NULL) { *this = other; }
|
||||
};
|
||||
template<class _T> inline void swap (fixed_vector<_T> & L, fixed_vector<_T> & R) throw() { L.swap (R); }
|
||||
|
||||
// class matrix - simple fixed-size 2-dimensional array, access elements as m(i,j)
|
||||
// stored as concatenation of rows
|
||||
|
||||
template<class T> class matrix : fixed_vector<T>
|
||||
{
|
||||
size_t numcols;
|
||||
size_t locate (size_t i, size_t j) const { ASSERT (i < rows() && j < cols()); return i * cols() + j; }
|
||||
public:
|
||||
typedef T elemtype;
|
||||
matrix() : numcols (0) {}
|
||||
matrix (size_t n, size_t m) { resize (n, m); }
|
||||
void resize (size_t n, size_t m) { numcols = m; fixed_vector<T>::resize (n * m); }
|
||||
size_t cols() const { return numcols; }
|
||||
size_t rows() const { return empty() ? 0 : size() / cols(); }
|
||||
size_t size() const { return fixed_vector<T>::size(); } // use this for reading and writing... not nice!
|
||||
bool empty() const { return fixed_vector<T>::empty(); }
|
||||
T & operator() (size_t i, size_t j) { return (*this)[locate(i,j)]; }
|
||||
const T & operator() (size_t i, size_t j) const { return (*this)[locate(i,j)]; }
|
||||
void swap (matrix & other) throw() { std::swap (numcols, other.numcols); fixed_vector<T>::swap (other); }
|
||||
};
|
||||
template<class _T> inline void swap (matrix<_T> & L, matrix<_T> & R) throw() { L.swap (R); }
|
||||
|
||||
// TODO: get rid of these
|
||||
typedef std::string STRING;
|
||||
typedef std::wstring WSTRING;
|
||||
#ifdef __unix__
|
||||
typedef wchar_t TCHAR;
|
||||
#endif
|
||||
typedef std::basic_string<TCHAR> TSTRING; // wide/narrow character string
|
||||
|
||||
// derive from this for noncopyable classes (will get you private unimplemented copy constructors)
|
||||
// ... TODO: change all of basetypes classes/structs to use this
|
||||
class noncopyable
|
||||
{
|
||||
noncopyable & operator= (const noncopyable &);
|
||||
noncopyable (const noncopyable &);
|
||||
public:
|
||||
noncopyable(){}
|
||||
};
|
||||
|
||||
struct throw_hr
|
||||
{
|
||||
const char * msg;
|
||||
inline throw_hr (const char * msg = NULL) : msg (msg) {}
|
||||
};
|
||||
|
||||
// back-mapping of exceptions to HRESULT codes
|
||||
// usage pattern: HRESULT COM_function (...) { try { exception-based function body } catch_hr_return; }
|
||||
#define catch_hr_return \
|
||||
catch (const bad_alloc &) { return E_OUTOFMEMORY; } \
|
||||
catch (const bad_hr & e) { return e.hr; } \
|
||||
catch (const invalid_argument &) { return E_INVALIDARG; } \
|
||||
catch (const runtime_error &) { return E_FAIL; } \
|
||||
catch (const logic_error &) { return E_UNEXPECTED; } \
|
||||
catch (const exception &) { return E_FAIL; } \
|
||||
return S_OK;
|
||||
|
||||
};}; // namespace
|
||||
|
||||
#ifndef BASETYPES_NO_UNSAFECRTOVERLOAD // if on, no unsafe CRT overload functions
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// overloads for "unsafe" CRT functions used in our code base
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
// strlen/wcslen overloads for fixed-buffer size
|
||||
|
||||
// Note: Careful while fixing bug related to these templates.
|
||||
// In all attempted experiments, in seems all 6 definitions are required
|
||||
// below to get the correct behaviour. Be very very careful
|
||||
// not to delete something without testing that case 5&6 have "size" deduced.
|
||||
// 1. char *
|
||||
// 2. char * const
|
||||
// 3. const char *
|
||||
// 4. const char * const
|
||||
// 5. char (&) [size]
|
||||
// 6. const char (&) [size]
|
||||
// the following includes all headers that use strlen() and fail because of the mapping below
|
||||
// to find those, change #define strlen strlen_ to something invalid e.g. strlen::strlen_
|
||||
#if _MSC_VER >= 1600 // VS 2010 --TODO: fix this by correct include order instead
|
||||
#include <intrin.h> // defines strlen() as an intrinsic in VS 2010
|
||||
#include <typeinfo> // uses strlen()
|
||||
#include <xlocale> // uses strlen()
|
||||
#endif
|
||||
#define strlen strlen_
|
||||
template<typename _T>
|
||||
size_t strlen_(_T &s) { return strnlen_s(static_cast<const char *>(s), SIZE_MAX); } // never be called but needed to keep compiler happy
|
||||
template<typename _T> inline size_t strlen_(const _T &s) { return strnlen(static_cast<const char *>(s), SIZE_MAX); }
|
||||
template<> inline size_t strlen_(char * &s) { return strnlen(s, SIZE_MAX); }
|
||||
template<> inline size_t strlen_(const char * &s) { return strnlen(s, SIZE_MAX); }
|
||||
template<size_t n> inline size_t strlen_(const char (&s)[n]) { return (strnlen(s, n)); }
|
||||
template<size_t n> inline size_t strlen_(char (&s)[n]) { return (strnlen(s, n)); }
|
||||
#define wcslen wcslen_
|
||||
template<typename _T>
|
||||
size_t wcslen_(_T &s) { return wcsnlen_s(static_cast<const wchar_t *>(s), SIZE_MAX); } // never be called but needed to keep compiler happy
|
||||
template<> inline size_t wcslen_(wchar_t * &s) { return wcsnlen(s, SIZE_MAX); }
|
||||
template<> inline size_t wcslen_(const wchar_t * &s) { return wcsnlen(s, SIZE_MAX); }
|
||||
template<size_t n> inline size_t wcslen_(const wchar_t (&s)[n]) { return (wcsnlen(s, n)); }
|
||||
template<size_t n> inline size_t wcslen_(wchar_t (&s)[n]) { return (wcsnlen(s, n)); }
|
||||
|
||||
// xscanf wrappers -- one overload for each actual use case in our code base
|
||||
static inline int sscanf (const char * buf, const char * format, int * i1) { return sscanf (buf, format, i1); }
|
||||
static inline int sscanf (const char * buf, const char * format, int * i1, int * i2) { return sscanf (buf, format, i1, i2); }
|
||||
static inline int sscanf (const char * buf, const char * format, int * i1, int * i2, int * i3) { return sscanf (buf, format, i1, i2, i3); }
|
||||
static inline int sscanf (const char * buf, const char * format, double * f1) { return sscanf (buf, format, f1); }
|
||||
static inline int swscanf (const wchar_t * buf, const wchar_t * format, int * i1) { return swscanf (buf, format, i1); }
|
||||
static inline int fscanf (FILE * file, const char * format, float * f1) { return fscanf (file, format, f1); }
|
||||
|
||||
// cacpy -- fixed-size character array (same as original strncpy (dst, src, sizeof (dst)))
|
||||
// NOTE: THIS FUNCTION HAS NEVER BEEN TESTED. REMOVE THIS COMMENT ONCE IT HAS.
|
||||
template<class T, size_t n> static inline void cacpy (T (&dst)[n], const T * src)
|
||||
{ for (int i = 0; i < n; i++) { dst[i] = *src; if (*src) src++; } }
|
||||
// { return strncpy (dst, src, n); } // using original C std lib function
|
||||
|
||||
#endif
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// frequently missing string functions
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
namespace msra { namespace strfun {
|
||||
|
||||
#ifndef BASETYPES_NO_STRPRINTF
|
||||
template<typename C> struct basic_cstring : public std::basic_string<C>
|
||||
{
|
||||
template<typename S> basic_cstring (S p) : std::basic_string<C> (p) { }
|
||||
operator const C * () const { return this->c_str(); }
|
||||
};
|
||||
|
||||
typedef basic_cstring<char> cstring;
|
||||
typedef basic_cstring<wchar_t> wcstring;
|
||||
|
||||
// [w]strprintf() -- like sprintf() but resulting in a C++ string
|
||||
template<class _T> struct _strprintf : public std::basic_string<_T>
|
||||
{ // works for both wchar_t* and char*
|
||||
_strprintf (const _T * format, ...)
|
||||
{
|
||||
va_list args; va_start (args, format); // varargs stuff
|
||||
size_t n = _cprintf (format, args); // num chars excl. '\0'
|
||||
const int FIXBUF_SIZE = 128; // incl. '\0'
|
||||
if (n < FIXBUF_SIZE)
|
||||
{
|
||||
_T fixbuf[FIXBUF_SIZE];
|
||||
this->assign (_sprintf (&fixbuf[0], sizeof (fixbuf)/sizeof (*fixbuf), format, args), n);
|
||||
}
|
||||
else // too long: use dynamically allocated variable-size buffer
|
||||
{
|
||||
std::vector<_T> varbuf (n + 1); // incl. '\0'
|
||||
this->assign (_sprintf (&varbuf[0], varbuf.size(), format, args), n);
|
||||
}
|
||||
}
|
||||
private:
|
||||
// helpers
|
||||
inline size_t _cprintf (const wchar_t * format, va_list args) { return _vscwprintf (format, args); }
|
||||
inline size_t _cprintf (const char * format, va_list args) { return _vscprintf (format, args); }
|
||||
inline const wchar_t * _sprintf (wchar_t * buf, size_t bufsiz, const wchar_t * format, va_list args) { vswprintf_s (buf, bufsiz, format, args); return buf; }
|
||||
inline const char * _sprintf ( char * buf, size_t bufsiz, const char * format, va_list args) { vsprintf_s (buf, bufsiz, format, args); return buf; }
|
||||
};
|
||||
|
||||
typedef strfun::_strprintf<char> strprintf; // char version
|
||||
typedef strfun::_strprintf<wchar_t> wstrprintf; // wchar_t version
|
||||
|
||||
#endif
|
||||
|
||||
//http://www.nanobit.net/putty/doxy/PUTTY_8H-source.html
|
||||
#ifndef CP_UTF8
|
||||
#define CP_UTF8 65001
|
||||
#endif
|
||||
// string-encoding conversion functions
|
||||
#ifdef _WIN32
|
||||
struct utf8 : std::string { utf8 (const std::wstring & p) // utf-16 to -8
|
||||
{
|
||||
size_t len = p.length();
|
||||
if (len == 0) { return;} // empty string
|
||||
msra::basetypes::fixed_vector<char> buf (3 * len + 1); // max: 1 wchar => up to 3 mb chars
|
||||
// ... TODO: this fill() should be unnecessary (a 0 is appended)--but verify
|
||||
std::fill (buf.begin (), buf.end (), 0);
|
||||
int rc = WideCharToMultiByte (CP_UTF8, 0, p.c_str(), (int) len,
|
||||
&buf[0], (int) buf.size(), NULL, NULL);
|
||||
if (rc == 0) throw std::runtime_error ("WideCharToMultiByte");
|
||||
(*(std::string*)this) = &buf[0];
|
||||
}};
|
||||
struct utf16 : std::wstring { utf16 (const std::string & p) // utf-8 to -16
|
||||
{
|
||||
size_t len = p.length();
|
||||
if (len == 0) { return;} // empty string
|
||||
msra::basetypes::fixed_vector<wchar_t> buf (len + 1);
|
||||
// ... TODO: this fill() should be unnecessary (a 0 is appended)--but verify
|
||||
std::fill (buf.begin (), buf.end (), (wchar_t) 0);
|
||||
int rc = MultiByteToWideChar (CP_UTF8, 0, p.c_str(), (int) len,
|
||||
&buf[0], (int) buf.size());
|
||||
if (rc == 0) throw std::runtime_error ("MultiByteToWideChar");
|
||||
ASSERT (rc < buf.size ());
|
||||
(*(std::wstring*)this) = &buf[0];
|
||||
}};
|
||||
#endif
|
||||
|
||||
|
||||
#pragma warning(push)
|
||||
#pragma warning(disable : 4996) // Reviewed by Yusheng Li, March 14, 2006. depr. fn (wcstombs, mbstowcs)
|
||||
static inline std::string wcstombs (const std::wstring & p) // output: MBCS
|
||||
{
|
||||
size_t len = p.length();
|
||||
msra::basetypes::fixed_vector<char> buf (2 * len + 1); // max: 1 wchar => 2 mb chars
|
||||
std::fill (buf.begin (), buf.end (), 0);
|
||||
::wcstombs (&buf[0], p.c_str(), 2 * len + 1);
|
||||
return std::string (&buf[0]);
|
||||
}
|
||||
static inline std::wstring mbstowcs (const std::string & p) // input: MBCS
|
||||
{
|
||||
size_t len = p.length();
|
||||
msra::basetypes::fixed_vector<wchar_t> buf (len + 1); // max: >1 mb chars => 1 wchar
|
||||
std::fill (buf.begin (), buf.end (), (wchar_t) 0);
|
||||
OACR_WARNING_SUPPRESS(UNSAFE_STRING_FUNCTION, "Reviewed OK. size checked. [rogeryu 2006/03/21]");
|
||||
::mbstowcs (&buf[0], p.c_str(), len + 1);
|
||||
return std::wstring (&buf[0]);
|
||||
}
|
||||
#pragma warning(pop)
|
||||
static inline std::string utf8 (const std::wstring & p) { return msra::strfun::wcstombs (p.c_str()); } // output: UTF-8... not really
|
||||
static inline std::wstring utf16 (const std::string & p) { return msra::strfun::mbstowcs(p.c_str()); } // input: UTF-8... not really
|
||||
|
||||
|
||||
|
||||
// split and join -- tokenize a string like strtok() would, join() strings together
|
||||
template<class _T> static inline std::vector<std::basic_string<_T>> split (const std::basic_string<_T> & s, const _T * delim)
|
||||
{
|
||||
std::vector<std::basic_string<_T>> res;
|
||||
for (size_t st = s.find_first_not_of (delim); st != std::basic_string<_T>::npos; )
|
||||
{
|
||||
size_t en = s.find_first_of (delim, st +1);
|
||||
if (en == std::basic_string<_T>::npos) en = s.length();
|
||||
res.push_back (s.substr (st, en-st));
|
||||
st = s.find_first_not_of (delim, en +1); // may exceed
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
template<class _T> static inline std::basic_string<_T> join (const std::vector<std::basic_string<_T>> & a, const _T * delim)
|
||||
{
|
||||
std::basic_string<_T> res;
|
||||
for (int i = 0; i < (int) a.size(); i++)
|
||||
{
|
||||
if (i > 0) res.append (delim);
|
||||
res.append (a[i]);
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
#ifdef _WIN32
|
||||
// parsing strings to numbers
|
||||
static inline int toint (const wchar_t * s)
|
||||
{
|
||||
return _wtoi (s); // ... TODO: check it
|
||||
}
|
||||
#endif
|
||||
static inline int toint (const char * s)
|
||||
{
|
||||
return atoi (s); // ... TODO: check it
|
||||
}
|
||||
static inline int toint (const std::wstring & s) { return toint (s.c_str()); }
|
||||
|
||||
static inline double todouble (const char * s)
|
||||
{
|
||||
char * ep; // will be set to point to first character that failed parsing
|
||||
double value = strtod (s, &ep);
|
||||
if (*s == 0 || *ep != 0)
|
||||
throw std::runtime_error ("todouble: invalid input string");
|
||||
return value;
|
||||
}
|
||||
|
||||
// TODO: merge this with todouble(const char*) above
|
||||
static inline double todouble (const std::string & s)
|
||||
{
|
||||
s.size(); // just used to remove the unreferenced warning
|
||||
|
||||
double value = 0.0;
|
||||
|
||||
// stod supposedly exists in VS2010, but some folks have compilation errors
|
||||
// If this causes errors again, change the #if into the respective one for VS 2010.
|
||||
#if _MSC_VER > 1400 // VS 2010+
|
||||
size_t * idx = 0;
|
||||
value = std::stod (s, idx);
|
||||
if (idx) throw std::runtime_error ("todouble: invalid input string");
|
||||
#else
|
||||
char *ep = 0; // will be updated by strtod to point to first character that failed parsing
|
||||
value = strtod (s.c_str(), &ep);
|
||||
|
||||
// strtod documentation says ep points to first unconverted character OR
|
||||
// return value will be +/- HUGE_VAL for overflow/underflow
|
||||
if (ep != s.c_str() + s.length() || value == HUGE_VAL || value == -HUGE_VAL)
|
||||
throw std::runtime_error ("todouble: invalid input string");
|
||||
#endif
|
||||
|
||||
return value;
|
||||
}
|
||||
|
||||
static inline double todouble (const std::wstring & s)
|
||||
{
|
||||
wchar_t * endptr;
|
||||
double value = wcstod (s.c_str(), &endptr);
|
||||
if (*endptr) throw std::runtime_error ("todouble: invalid input string");
|
||||
return value;
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// tokenizer -- utility for white-space tokenizing strings in a character buffer
|
||||
// This simple class just breaks a string, but does not own the string buffer.
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
class tokenizer : public std::vector<char*>
|
||||
{
|
||||
const char * delim;
|
||||
public:
|
||||
tokenizer (const char * delim, size_t cap) : delim (delim) { reserve (cap); }
|
||||
// Usage: tokenizer tokens (delim, capacity); tokens = buf; tokens.size(), tokens[i]
|
||||
void operator= (char * buf)
|
||||
{
|
||||
resize (0);
|
||||
|
||||
// strtok_s not available on all platforms - so backoff to strtok on those
|
||||
#ifdef strtok_s
|
||||
char * context; // for strtok_s()
|
||||
for (char * p = strtok_s (buf, delim, &context); p; p = strtok_s (NULL, delim, &context))
|
||||
push_back (p);
|
||||
#else
|
||||
for (char * p = strtok (buf, delim); p; p = strtok (NULL, delim))
|
||||
push_back (p);
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
};}; // namespace
|
||||
static inline msra::strfun::cstring charpath (const std::wstring & p)
|
||||
{
|
||||
#ifdef _WIN32
|
||||
return std::wstring_convert<std::codecvt_utf8_utf16<wchar_t>>().to_bytes(p);
|
||||
#else // old version, delete once we know it works
|
||||
size_t len = p.length();
|
||||
std::vector<char> buf(2 * len + 1, 0); // max: 1 wchar => 2 mb chars
|
||||
::wcstombs(buf.data(), p.c_str(), 2 * len + 1);
|
||||
return msra::strfun::cstring (&buf[0]);
|
||||
#endif
|
||||
}
|
||||
static inline FILE* _wfopen (const wchar_t * path, const wchar_t * mode) { return fopen(charpath(path), charpath(mode)); }
|
||||
static inline void Sleep (size_t ms) { std::this_thread::sleep_for (std::chrono::milliseconds (ms)); }
|
||||
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// wrappers for some basic types (files, handles, timer)
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
namespace msra { namespace basetypes {
|
||||
|
||||
// FILE* with auto-close; use auto_file_ptr instead of FILE*.
|
||||
// Warning: do not pass an auto_file_ptr to a function that calls fclose(),
|
||||
// except for fclose() itself.
|
||||
class auto_file_ptr
|
||||
{
|
||||
FILE * f;
|
||||
FILE * operator= (auto_file_ptr &); // can't ref-count: no assignment
|
||||
auto_file_ptr (auto_file_ptr &);
|
||||
// implicit close (destructor, assignment): we ignore error
|
||||
void close() throw() { if (f) try { if (f != stdin && f != stdout && f != stderr) ::fclose (f); } catch (...) { } f = NULL; }
|
||||
void openfailed (const std::string & path) { throw std::runtime_error ("auto_file_ptr: error opening file '" + path + "': " + strerror (errno)); }
|
||||
protected:
|
||||
friend int fclose (auto_file_ptr&); // explicit close (note: may fail)
|
||||
int fclose() { int rc = ::fclose (f); if (rc == 0) f = NULL; return rc; }
|
||||
public:
|
||||
auto_file_ptr() : f (NULL) { }
|
||||
~auto_file_ptr() { close(); }
|
||||
auto_file_ptr (const char * path, const char * mode) { f = fopen (path, mode); if (f == NULL) openfailed (path); }
|
||||
auto_file_ptr (const wchar_t * wpath, const char * mode) { f = _wfopen (wpath, msra::strfun::utf16 (mode).c_str()); if (f == NULL) openfailed (msra::strfun::utf8 (wpath)); }
|
||||
FILE * operator= (FILE * other) { close(); f = other; return f; }
|
||||
auto_file_ptr (FILE * other) : f (other) { }
|
||||
operator FILE * () const { return f; }
|
||||
FILE * operator->() const { return f; }
|
||||
void swap (auto_file_ptr & other) throw() { std::swap (f, other.f); }
|
||||
};
|
||||
inline int fclose (auto_file_ptr & af) { return af.fclose(); }
|
||||
|
||||
|
||||
};};
|
||||
|
||||
namespace msra { namespace files {
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// textreader -- simple reader for text files --we need this all the time!
|
||||
// Currently reads 8-bit files, but can return as wstring, in which case
|
||||
// they are interpreted as UTF-8 (without BOM).
|
||||
// Note: Not suitable for pipes or typed input due to readahead (fixable if needed).
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
class textreader
|
||||
{
|
||||
msra::basetypes::auto_file_ptr f;
|
||||
std::vector<char> buf; // read buffer (will only grow, never shrink)
|
||||
int ch; // next character (we need to read ahead by one...)
|
||||
char getch() { char prevch = (char) ch; ch = fgetc (f); return prevch; }
|
||||
public:
|
||||
textreader (const std::wstring & path) : f (path.c_str(), "rb") { buf.reserve (10000); ch = fgetc (f); }
|
||||
operator bool() const { return ch != EOF; } // true if still a line to read
|
||||
std::string getline() // get and consume the next line
|
||||
{
|
||||
if (ch == EOF) throw std::logic_error ("textreader: attempted to read beyond EOF");
|
||||
assert (buf.empty());
|
||||
// get all line's characters --we recognize UNIX (LF), DOS (CRLF), and Mac (CR) convention
|
||||
while (ch != EOF && ch != '\n' && ch != '\r') buf.push_back (getch());
|
||||
if (ch != EOF && getch() == '\r' && ch == '\n') getch(); // consume EOLN char
|
||||
std::string line (buf.begin(), buf.end());
|
||||
buf.clear();
|
||||
return line;
|
||||
}
|
||||
std::wstring wgetline() { return msra::strfun::utf16 (getline()); }
|
||||
};
|
||||
|
||||
};};
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// functional-programming style helper macros (...do this with templates?)
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
#define foreach_index(_i,_dat) for (int _i = 0; _i < (int) (_dat).size(); _i++)
|
||||
#define map_array(_x,_expr,_y) { _y.resize (_x.size()); foreach_index(_i,_x) _y[_i]=_expr(_x[_i]); }
|
||||
#define reduce_array(_x,_expr,_y) { foreach_index(_i,_x) _y = (_i==0) ? _x[_i] : _expr(_y,_x[_i]); }
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// frequently missing utility functions
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
namespace msra { namespace util {
|
||||
|
||||
// to (slightly) simplify processing of command-line arguments.
|
||||
// command_line args (argc, argv);
|
||||
// while (args.has (1) && args[0][0] == '-') { option = args.shift(); process (option); }
|
||||
// for (const wchar_t * arg = args.shift(); arg; arg = args.shift()) { process (arg); }
|
||||
class command_line
|
||||
{
|
||||
int num;
|
||||
const wchar_t * * args;
|
||||
public:
|
||||
command_line (int argc, wchar_t * argv[]) : num (argc), args ((const wchar_t **) argv) { shift(); }
|
||||
inline int size() const { return num; }
|
||||
inline bool has (int left) { return size() >= left; }
|
||||
const wchar_t * shift() { if (size() == 0) return NULL; num--; return *args++; }
|
||||
const wchar_t * operator[] (int i) const { return (i < 0 || i >= size()) ? NULL : args[i]; }
|
||||
};
|
||||
|
||||
// byte-reverse a variable --reverse all bytes (intended for integral types and float)
|
||||
template<typename T> static inline void bytereverse (T & v) throw()
|
||||
{ // note: this is more efficient than it looks because sizeof (v[0]) is a constant
|
||||
char * p = (char *) &v;
|
||||
const size_t elemsize = sizeof (v);
|
||||
for (int k = 0; k < elemsize / 2; k++) // swap individual bytes
|
||||
swap (p[k], p[elemsize-1 - k]);
|
||||
}
|
||||
|
||||
// byte-swap an entire array
|
||||
template<class V> static inline void byteswap (V & v) throw()
|
||||
{
|
||||
foreach_index (i, v)
|
||||
bytereverse (v[i]);
|
||||
}
|
||||
|
||||
// execute a block with retry
|
||||
// Block must be restartable.
|
||||
// Use this when writing small files to those unreliable Windows servers.
|
||||
// TODO: This will fail to compile under VS 2008--we need an #ifdef around this
|
||||
template<typename FUNCTION> static void attempt (int retries, const FUNCTION & body)
|
||||
{
|
||||
for (int attempt = 1; ; attempt++)
|
||||
{
|
||||
try
|
||||
{
|
||||
body();
|
||||
if (attempt > 1) fprintf (stderr, "attempt: success after %d retries\n", attempt);
|
||||
break;
|
||||
}
|
||||
catch (const std::exception & e)
|
||||
{
|
||||
if (attempt >= retries)
|
||||
throw; // failed N times --give up and rethrow the error
|
||||
fprintf (stderr, "attempt: %s, retrying %d-th time out of %d...\n", e.what(), attempt+1, retries);
|
||||
::Sleep (1000); // wait a little, then try again
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
};}; // namespace
|
||||
|
||||
|
||||
#ifdef _WIN32
|
||||
// ----------------------------------------------------------------------------
|
||||
// frequently missing Win32 functions
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
// strerror() for Win32 error codes
|
||||
static inline std::wstring FormatWin32Error (DWORD error)
|
||||
{
|
||||
wchar_t buf[1024] = { 0 };
|
||||
::FormatMessageW (FORMAT_MESSAGE_FROM_SYSTEM, "", error, 0, buf, sizeof (buf)/sizeof (*buf) -1, NULL);
|
||||
std::wstring res (buf);
|
||||
// eliminate newlines (and spaces) from the end
|
||||
size_t last = res.find_last_not_of (L" \t\r\n");
|
||||
if (last != std::string::npos) res.erase (last +1, res.length());
|
||||
return res;
|
||||
}
|
||||
// we always wanted this!
|
||||
#pragma warning (push)
|
||||
#pragma warning (disable: 6320) // Exception-filter expression is the constant EXCEPTION_EXECUTE_HANDLER
|
||||
#pragma warning (disable: 6322) // Empty _except block
|
||||
static inline void SetCurrentThreadName (const char* threadName)
|
||||
{ // from http://msdn.microsoft.com/en-us/library/xcb2z8hs.aspx
|
||||
::Sleep(10);
|
||||
#pragma pack(push,8)
|
||||
struct { DWORD dwType; LPCSTR szName; DWORD dwThreadID; DWORD dwFlags; } info = { 0x1000, threadName, (DWORD) -1, 0 };
|
||||
#pragma pack(pop)
|
||||
__try { RaiseException (0x406D1388, 0, sizeof(info)/sizeof(ULONG_PTR), (ULONG_PTR*)&info); }
|
||||
__except(EXCEPTION_EXECUTE_HANDLER) { }
|
||||
}
|
||||
#pragma warning (pop)
|
||||
|
||||
// return a string as a CoTaskMemAlloc'ed memory object
|
||||
// Returns NULL if out of memory (we don't throw because we'd just catch it outside and convert to HRESULT anyway).
|
||||
static inline LPWSTR CoTaskMemString (const wchar_t * s)
|
||||
{
|
||||
size_t n = wcslen (s) + 1; // number of chars to allocate and copy
|
||||
LPWSTR p = (LPWSTR) ::CoTaskMemAlloc (sizeof (*p) * n);
|
||||
if (p) for (size_t i = 0; i < n; i++) p[i] = s[i];
|
||||
return p;
|
||||
}
|
||||
|
||||
template<class S> static inline void ZeroStruct (S & s) { memset (&s, 0, sizeof (s)); }
|
||||
|
||||
#endif
|
||||
// ----------------------------------------------------------------------------
|
||||
// machine dependent
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
#define MACHINE_IS_BIG_ENDIAN (false)
|
||||
|
||||
using namespace msra::basetypes; // for compatibility
|
||||
|
||||
#pragma warning (pop)
|
||||
|
||||
// RuntimeError - throw a std::runtime_error with a formatted error string
|
||||
#ifdef _MSC_VER
|
||||
__declspec(noreturn)
|
||||
#endif
|
||||
static inline void RuntimeError(const char * format, ...)
|
||||
{
|
||||
va_list args;
|
||||
char buffer[1024];
|
||||
|
||||
va_start(args, format);
|
||||
vsprintf(buffer, format, args);
|
||||
throw std::runtime_error(buffer);
|
||||
};
|
||||
|
||||
// LogicError - throw a std::logic_error with a formatted error string
|
||||
#ifdef _MSC_VER
|
||||
__declspec(noreturn)
|
||||
#endif
|
||||
static inline void LogicError(const char * format, ...)
|
||||
{
|
||||
va_list args;
|
||||
char buffer[1024];
|
||||
|
||||
va_start(args, format);
|
||||
vsprintf(buffer, format, args);
|
||||
throw std::logic_error(buffer);
|
||||
};
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// dynamic loading of modules
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
#ifdef _WIN32
|
||||
class Plugin
|
||||
{
|
||||
HMODULE m_hModule; // module handle for the writer DLL
|
||||
std::wstring m_dllName; // name of the writer DLL
|
||||
public:
|
||||
Plugin() { m_hModule = NULL; }
|
||||
template<class STRING> // accepts char (UTF-8) and wide string
|
||||
FARPROC Load(const STRING & plugin, const std::string & proc)
|
||||
{
|
||||
m_dllName = msra::strfun::utf16(plugin);
|
||||
m_dllName += L".dll";
|
||||
m_hModule = LoadLibrary(m_dllName.c_str());
|
||||
if (m_hModule == NULL)
|
||||
RuntimeError("Plugin not found: %s", msra::strfun::utf8(m_dllName));
|
||||
|
||||
// create a variable of each type just to call the proper templated version
|
||||
return GetProcAddress(m_hModule, proc.c_str());
|
||||
}
|
||||
~Plugin() { if (m_hModule) FreeLibrary(m_hModule); }
|
||||
};
|
||||
#else
|
||||
class Plugin
|
||||
{
|
||||
public:
|
||||
template<class STRING> // accepts char (UTF-8) and wide string
|
||||
void * Load(const STRING & plugin, const std::string & proc)
|
||||
{
|
||||
RuntimeError("Plugins not implemented on Linux yet");
|
||||
return nullptr;
|
||||
}
|
||||
};
|
||||
#endif
|
||||
|
||||
#endif // _BASETYPES_
|
|
@ -1,122 +0,0 @@
|
|||
//
|
||||
// <copyright file="biggrowablevectors.h" company="Microsoft">
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// </copyright>
|
||||
//
|
||||
// biggrowablevectors.h -- big growable vector that uses two layers and optionally a disk backing store for paging
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace msra { namespace dbn {
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// growablevectorbase -- helper for two-layer growable random-access array
|
||||
// This allows both a fully allocated vector (with push_back()), e.g. for uids,
|
||||
// as well as a partially allocated one (content managed by derived class), for features and lattice blocks.
|
||||
// TODO:
|
||||
// - test this (make copy of binary first before full compilation; or rebuild the previous version)
|
||||
// - fully move in-mem range here, test again
|
||||
// - then we can move towards paging from archive directly (biggrowablevectorarray gets tossed)
|
||||
// ---------------------------------------------------------------------------
|
||||
template<class BLOCKTYPE> class growablevectorbase
|
||||
{
|
||||
protected: // fix this later
|
||||
const size_t elementsperblock;
|
||||
size_t n; // number of elements
|
||||
std::vector<std::unique_ptr<BLOCKTYPE>> blocks; // the data blocks
|
||||
void operator= (const growablevectorbase &); // (non-assignable)
|
||||
void check (size_t t) const { if (t >= n) throw std::logic_error ("growablevectorbase: out of bounds"); } // bounds check helper
|
||||
|
||||
// resize intermediate level, but do not allocate blocks
|
||||
// (may deallocate if shrinking)
|
||||
void resize_without_commit (size_t T)
|
||||
{
|
||||
blocks.resize ((T + elementsperblock-1) / elementsperblock);
|
||||
n = T;
|
||||
// TODO: update allocated range
|
||||
}
|
||||
|
||||
// commit memory
|
||||
// begin/end must be block boundaries
|
||||
void commit (size_t begin, size_t end, BLOCKTYPE * blockdata)
|
||||
{
|
||||
auto blockptr = getblock (begin, end); // memory leak: if this fails (logic error; should never happen)
|
||||
blockptr.set (blockdata); // take ownership of the block
|
||||
// TODO: update allocated range --also enforce consecutiveness
|
||||
}
|
||||
|
||||
// flush a block
|
||||
// begin/end must be block boundaries
|
||||
void flush (size_t begin, size_t end)
|
||||
{
|
||||
auto blockptr = getblock (begin, end); // memory leak: if this fails (logic error; should never happen)
|
||||
blockptr.reset(); // release it
|
||||
// TODO: update allocated range --also enforce consecutiveness
|
||||
}
|
||||
|
||||
// helper to get a block pointer, with block referenced as its entire range
|
||||
std::unique_ptr<BLOCKTYPE> & getblockptr (size_t t) // const
|
||||
{
|
||||
check (t);
|
||||
return blocks[t / elementsperblock];
|
||||
}
|
||||
|
||||
// helper to get a block pointer, with block referenced as its entire range
|
||||
std::unique_ptr<BLOCKTYPE> & getblockptr (size_t begin, size_t end) const
|
||||
{
|
||||
// BUGBUG: last block may be shorter than elementsperblock
|
||||
if (end - begin != elementsperblock || getblockt (begin) != 0)
|
||||
throw std::logic_error ("growablevectorbase: non-block boundaries passed to block-level function");
|
||||
return getblockptr (begin);
|
||||
}
|
||||
public:
|
||||
growablevectorbase (size_t elementsperblock) : elementsperblock (elementsperblock), n (0) { blocks.reserve (1000); }
|
||||
size_t size() const { return n; } // number of frames
|
||||
bool empty() const { return size() == 0; }
|
||||
|
||||
// to access an element t -> getblock(t)[getblockt(t)]
|
||||
BLOCKTYPE & getblock (size_t t) const
|
||||
{
|
||||
check (t);
|
||||
const size_t blockid = t / elementsperblock;
|
||||
return *blocks[blockid].get();
|
||||
}
|
||||
|
||||
size_t getblockt (size_t t) const
|
||||
{
|
||||
check (t);
|
||||
return t % elementsperblock;
|
||||
}
|
||||
};
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// biggrowablevector -- big vector we can push_back to
|
||||
// ---------------------------------------------------------------------------
|
||||
template<typename ELEMTYPE> class biggrowablevector : public growablevectorbase<std::vector<ELEMTYPE>>
|
||||
{
|
||||
public:
|
||||
biggrowablevector() : growablevectorbase<std::vector<ELEMTYPE>>::growablevectorbase (65536) { }
|
||||
|
||||
template<typename VALTYPE> void push_back (VALTYPE e) // VALTYPE could be an rvalue reference
|
||||
{
|
||||
size_t i = this->size();
|
||||
this->resize_without_commit (i + 1);
|
||||
auto & block = this->getblockptr (i);
|
||||
if (block.get() == NULL)
|
||||
block.reset (new std::vector<ELEMTYPE> (this->elementsperblock));
|
||||
(*block)[this->getblockt (i)] = e;
|
||||
}
|
||||
|
||||
ELEMTYPE & operator[] (size_t t) { return this->getblock(t)[this->getblockt (t)]; } // get an element
|
||||
const ELEMTYPE & operator[] (size_t t) const { return this->getblock(t)[this->getblockt (t)]; } // get an element
|
||||
|
||||
void resize (const size_t n)
|
||||
{
|
||||
this->resize_without_commit (n);
|
||||
foreach_index (i, this->blocks)
|
||||
if (this->blocks[i].get() == NULL)
|
||||
this->blocks[i].reset (new std::vector<ELEMTYPE> (this->elementsperblock));
|
||||
}
|
||||
};
|
||||
|
||||
};};
|
|
@ -1,373 +0,0 @@
|
|||
//
|
||||
// <copyright file="chunkevalsource.h" company="Microsoft">
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// </copyright>
|
||||
//
|
||||
#pragma once
|
||||
|
||||
|
||||
//#include <objbase.h>
|
||||
#include "basetypes.h" // for attempt()
|
||||
#include "htkfeatio.h" // for reading HTK features
|
||||
#include "minibatchsourcehelpers.h"
|
||||
|
||||
#ifndef __unix__
|
||||
#include "ssematrix.h"
|
||||
#endif
|
||||
|
||||
#ifdef LEAKDETECT
|
||||
#include <vld.h> // for memory leak detection
|
||||
#endif
|
||||
|
||||
namespace msra { namespace dbn {
|
||||
|
||||
class chunkevalsource // : public numamodelmanager
|
||||
{
|
||||
const size_t chunksize; // actual block size to perform computation on
|
||||
|
||||
// data FIFO
|
||||
msra::dbn::matrix feat;
|
||||
std::vector<std::vector<float>> frames; // [t] all feature frames concatenated into a big block
|
||||
std::vector<char> boundaryflags; // [t] -1 for first and +1 last frame, 0 else (for augmentneighbors())
|
||||
std::vector<size_t> numframes; // [k] number of frames for all appended files
|
||||
std::vector<std::wstring> outpaths; // [k] and their pathnames
|
||||
std::vector<unsigned int> sampperiods; // [k] and sample periods (they should really all be the same...)
|
||||
size_t vdim; // input dimension
|
||||
size_t udim; // output dimension
|
||||
bool minibatchready;
|
||||
void operator=(const chunkevalsource &);
|
||||
private:
|
||||
void clear() // empty the FIFO
|
||||
{
|
||||
frames.clear();
|
||||
boundaryflags.clear();
|
||||
numframes.clear();
|
||||
outpaths.clear();
|
||||
sampperiods.clear();
|
||||
minibatchready=false;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void saveandflush(msra::dbn::matrix &pred)
|
||||
{
|
||||
const size_t framesinblock = frames.size();
|
||||
|
||||
// write out all files
|
||||
size_t firstframe = 0;
|
||||
foreach_index (k, numframes)
|
||||
{
|
||||
const wstring & outfile = outpaths[k];
|
||||
unsigned int sampperiod = sampperiods[k];
|
||||
size_t n = numframes[k];
|
||||
msra::files::make_intermediate_dirs (outfile);
|
||||
fprintf (stderr, "saveandflush: writing %zu frames to %S\n", n, outfile.c_str());
|
||||
msra::dbn::matrixstripe thispred (pred, firstframe, n);
|
||||
// some sanity check for the data we've written
|
||||
const size_t nansinf = thispred.countnaninf();
|
||||
if (nansinf > 0)
|
||||
fprintf (stderr, "chunkeval: %d NaNs or INF detected in '%S' (%d frames)\n", (int) nansinf, outfile.c_str(), (int) thispred.cols());
|
||||
// save it
|
||||
msra::util::attempt (5, [&]()
|
||||
{
|
||||
msra::asr::htkfeatwriter::write (outfile, "USER", sampperiod, thispred);
|
||||
});
|
||||
firstframe += n;
|
||||
}
|
||||
assert (firstframe == framesinblock); framesinblock;
|
||||
|
||||
// and we are done --forget the FIFO content & get ready for next chunk
|
||||
clear();
|
||||
|
||||
}
|
||||
|
||||
public:
|
||||
chunkevalsource (size_t numinput, size_t numoutput, size_t chunksize)
|
||||
:vdim(numinput),udim(numoutput),chunksize(chunksize)
|
||||
{
|
||||
frames.reserve (chunksize * 2);
|
||||
feat.resize(vdim,chunksize); // initialize to size chunksize
|
||||
}
|
||||
|
||||
// append data to chunk
|
||||
template<class MATRIX> void addfile (const MATRIX & feat, const string & featkind, unsigned int sampperiod, const std::wstring & outpath)
|
||||
{
|
||||
// append to frames; also expand neighbor frames
|
||||
if (feat.cols() < 2)
|
||||
throw std::runtime_error ("evaltofile: utterances < 2 frames not supported");
|
||||
foreach_column (t, feat)
|
||||
{
|
||||
std::vector<float> v (&feat(0,t), &feat(0,t) + feat.rows());
|
||||
frames.push_back (v);
|
||||
boundaryflags.push_back ((t == 0) ? -1 : (t == feat.cols() -1) ? +1 : 0);
|
||||
}
|
||||
|
||||
numframes.push_back (feat.cols());
|
||||
outpaths.push_back (outpath);
|
||||
sampperiods.push_back (sampperiod);
|
||||
|
||||
}
|
||||
|
||||
void createevalminibatch()
|
||||
{
|
||||
const size_t framesinblock = frames.size();
|
||||
feat.resize(vdim, framesinblock); // input features for whole utt (col vectors)
|
||||
// augment the features
|
||||
msra::dbn::augmentneighbors (frames, boundaryflags, 0, framesinblock, feat);
|
||||
minibatchready=true;
|
||||
}
|
||||
|
||||
void writetofiles(msra::dbn::matrix &pred){ saveandflush(pred); }
|
||||
|
||||
msra::dbn::matrix chunkofframes() { assert(minibatchready); return feat; }
|
||||
|
||||
bool isminibatchready() { return minibatchready; }
|
||||
|
||||
size_t currentchunksize() { return frames.size(); }
|
||||
void flushinput(){createevalminibatch();}
|
||||
void reset() { clear(); }
|
||||
|
||||
};
|
||||
|
||||
|
||||
class chunkevalsourcemulti // : public numamodelmanager
|
||||
{
|
||||
const size_t chunksize; // actual block size to perform computation on
|
||||
|
||||
// data FIFO
|
||||
std::vector<msra::dbn::matrix> feat;
|
||||
std::vector<std::vector<std::vector<float>>> framesmulti; // [t] all feature frames concatenated into a big block
|
||||
std::vector<char> boundaryflags; // [t] -1 for first and +1 last frame, 0 else (for augmentneighbors())
|
||||
std::vector<size_t> numframes; // [k] number of frames for all appended files
|
||||
std::vector<std::vector<std::wstring>> outpaths; // [k] and their pathnames
|
||||
std::vector<std::vector<unsigned int>> sampperiods; // [k] and sample periods (they should really all be the same...)
|
||||
std::vector<size_t> vdims; // input dimension
|
||||
std::vector<size_t> udims; // output dimension
|
||||
bool minibatchready;
|
||||
|
||||
void operator=(const chunkevalsourcemulti &);
|
||||
private:
|
||||
void clear() // empty the FIFO
|
||||
{
|
||||
foreach_index(i, vdims)
|
||||
{
|
||||
framesmulti[i].clear();
|
||||
outpaths[i].clear();
|
||||
sampperiods[i].clear();
|
||||
}
|
||||
boundaryflags.clear();
|
||||
numframes.clear();
|
||||
minibatchready=false;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void saveandflush(msra::dbn::matrix &pred, size_t index)
|
||||
{
|
||||
const size_t framesinblock = framesmulti[index].size();
|
||||
|
||||
// write out all files
|
||||
size_t firstframe = 0;
|
||||
foreach_index (k, numframes)
|
||||
{
|
||||
const wstring & outfile = outpaths[index][k];
|
||||
unsigned int sampperiod = sampperiods[index][k];
|
||||
size_t n = numframes[k];
|
||||
msra::files::make_intermediate_dirs (outfile);
|
||||
fprintf (stderr, "saveandflush: writing %zu frames to %S\n", n, outfile.c_str());
|
||||
msra::dbn::matrixstripe thispred (pred, firstframe, n);
|
||||
// some sanity check for the data we've written
|
||||
const size_t nansinf = thispred.countnaninf();
|
||||
if (nansinf > 0)
|
||||
fprintf (stderr, "chunkeval: %d NaNs or INF detected in '%S' (%d frames)\n", (int) nansinf, outfile.c_str(), (int) thispred.cols());
|
||||
// save it
|
||||
msra::util::attempt (5, [&]()
|
||||
{
|
||||
msra::asr::htkfeatwriter::write (outfile, "USER", sampperiod, thispred);
|
||||
});
|
||||
firstframe += n;
|
||||
}
|
||||
assert (firstframe == framesinblock); framesinblock;
|
||||
|
||||
// and we are done --forget the FIFO content & get ready for next chunk
|
||||
|
||||
}
|
||||
|
||||
public:
|
||||
chunkevalsourcemulti (std::vector<size_t> vdims, std::vector<size_t> udims, size_t chunksize)
|
||||
:vdims(vdims),udims(udims),chunksize(chunksize)
|
||||
{
|
||||
|
||||
foreach_index(i, vdims)
|
||||
{
|
||||
msra::dbn::matrix thisfeat;
|
||||
std::vector<std::vector<float>> frames; // [t] all feature frames concatenated into a big block
|
||||
|
||||
frames.reserve(chunksize * 2);
|
||||
framesmulti.push_back(frames);
|
||||
//framesmulti[i].reserve (chunksize * 2);
|
||||
|
||||
thisfeat.resize(vdims[i], chunksize);
|
||||
feat.push_back(thisfeat);
|
||||
|
||||
outpaths.push_back(std::vector<std::wstring>());
|
||||
sampperiods.push_back(std::vector<unsigned int>());
|
||||
//feat[i].resize(vdims[i],chunksize); // initialize to size chunksize
|
||||
}
|
||||
}
|
||||
|
||||
// append data to chunk
|
||||
template<class MATRIX> void addfile (const MATRIX & feat, const string & featkind, unsigned int sampperiod, const std::wstring & outpath, size_t index)
|
||||
{
|
||||
// append to frames; also expand neighbor frames
|
||||
if (feat.cols() < 2)
|
||||
throw std::runtime_error ("evaltofile: utterances < 2 frames not supported");
|
||||
foreach_column (t, feat)
|
||||
{
|
||||
std::vector<float> v (&feat(0,t), &feat(0,t) + feat.rows());
|
||||
framesmulti[index].push_back (v);
|
||||
if (index==0)
|
||||
boundaryflags.push_back ((t == 0) ? -1 : (t == feat.cols() -1) ? +1 : 0);
|
||||
}
|
||||
if (index==0)
|
||||
numframes.push_back (feat.cols());
|
||||
|
||||
outpaths[index].push_back (outpath);
|
||||
sampperiods[index].push_back (sampperiod);
|
||||
|
||||
}
|
||||
|
||||
void createevalminibatch()
|
||||
{
|
||||
foreach_index(i, framesmulti)
|
||||
{
|
||||
const size_t framesinblock = framesmulti[i].size();
|
||||
feat[i].resize(vdims[i], framesinblock); // input features for whole utt (col vectors)
|
||||
// augment the features
|
||||
msra::dbn::augmentneighbors (framesmulti[i], boundaryflags, 0, framesinblock, feat[i]);
|
||||
}
|
||||
minibatchready=true;
|
||||
}
|
||||
|
||||
void writetofiles(msra::dbn::matrix &pred, size_t index){ saveandflush(pred, index); }
|
||||
|
||||
msra::dbn::matrix chunkofframes(size_t index) { assert(minibatchready); assert(index<=feat.size()); return feat[index]; }
|
||||
|
||||
bool isminibatchready() { return minibatchready; }
|
||||
|
||||
size_t currentchunksize() { return framesmulti[0].size(); }
|
||||
void flushinput(){createevalminibatch();}
|
||||
void reset() { clear(); }
|
||||
|
||||
};
|
||||
|
||||
class FileEvalSource // : public numamodelmanager
|
||||
{
|
||||
const size_t chunksize; // actual block size to perform computation on
|
||||
|
||||
// data FIFO
|
||||
std::vector<msra::dbn::matrix> feat;
|
||||
std::vector<std::vector<std::vector<float>>> framesMulti; // [t] all feature frames concatenated into a big block
|
||||
std::vector<char> boundaryFlags; // [t] -1 for first and +1 last frame, 0 else (for augmentneighbors())
|
||||
std::vector<size_t> numFrames; // [k] number of frames for all appended files
|
||||
std::vector<std::vector<unsigned int>> sampPeriods; // [k] and sample periods (they should really all be the same...)
|
||||
std::vector<size_t> vdims; // input dimension
|
||||
std::vector<size_t> leftcontext;
|
||||
std::vector<size_t> rightcontext;
|
||||
bool minibatchReady;
|
||||
size_t minibatchSize;
|
||||
size_t frameIndex;
|
||||
|
||||
void operator=(const FileEvalSource &);
|
||||
|
||||
private:
|
||||
void Clear() // empty the FIFO
|
||||
{
|
||||
foreach_index(i, vdims)
|
||||
{
|
||||
framesMulti[i].clear();
|
||||
sampPeriods[i].clear();
|
||||
}
|
||||
boundaryFlags.clear();
|
||||
numFrames.clear();
|
||||
minibatchReady=false;
|
||||
frameIndex=0;
|
||||
}
|
||||
|
||||
public:
|
||||
FileEvalSource(std::vector<size_t> vdims, std::vector<size_t> leftcontext, std::vector<size_t> rightcontext, size_t chunksize) :vdims(vdims), leftcontext(leftcontext), rightcontext(rightcontext), chunksize(chunksize)
|
||||
{
|
||||
foreach_index(i, vdims)
|
||||
{
|
||||
msra::dbn::matrix thisfeat;
|
||||
std::vector<std::vector<float>> frames; // [t] all feature frames concatenated into a big block
|
||||
|
||||
frames.reserve(chunksize * 2);
|
||||
framesMulti.push_back(frames);
|
||||
//framesmulti[i].reserve (chunksize * 2);
|
||||
|
||||
thisfeat.resize(vdims[i], chunksize);
|
||||
feat.push_back(thisfeat);
|
||||
|
||||
sampPeriods.push_back(std::vector<unsigned int>());
|
||||
//feat[i].resize(vdims[i],chunksize); // initialize to size chunksize
|
||||
}
|
||||
}
|
||||
|
||||
// append data to chunk
|
||||
template<class MATRIX> void AddFile (const MATRIX & feat, const string & /*featkind*/, unsigned int sampPeriod, size_t index)
|
||||
{
|
||||
// append to frames; also expand neighbor frames
|
||||
if (feat.cols() < 2)
|
||||
throw std::runtime_error ("evaltofile: utterances < 2 frames not supported");
|
||||
foreach_column (t, feat)
|
||||
{
|
||||
std::vector<float> v (&feat(0,t), &feat(0,t) + feat.rows());
|
||||
framesMulti[index].push_back (v);
|
||||
if (index==0)
|
||||
boundaryFlags.push_back ((t == 0) ? -1 : (t == feat.cols() -1) ? +1 : 0);
|
||||
}
|
||||
if (index==0)
|
||||
numFrames.push_back (feat.cols());
|
||||
|
||||
sampPeriods[index].push_back (sampPeriod);
|
||||
|
||||
}
|
||||
|
||||
void CreateEvalMinibatch()
|
||||
{
|
||||
foreach_index(i, framesMulti)
|
||||
{
|
||||
const size_t framesInBlock = framesMulti[i].size();
|
||||
feat[i].resize(vdims[i], framesInBlock); // input features for whole utt (col vectors)
|
||||
// augment the features
|
||||
size_t leftextent, rightextent;
|
||||
// page in the needed range of frames
|
||||
if (leftcontext[i] == 0 && rightcontext[i] == 0)
|
||||
{
|
||||
leftextent = rightextent = augmentationextent(framesMulti[i][0].size(), vdims[i]);
|
||||
}
|
||||
else
|
||||
{
|
||||
leftextent = leftcontext[i];
|
||||
rightextent = rightcontext[i];
|
||||
}
|
||||
|
||||
//msra::dbn::augmentneighbors(framesMulti[i], boundaryFlags, 0, leftcontext[i], rightcontext[i],)
|
||||
msra::dbn::augmentneighbors (framesMulti[i], boundaryFlags, leftextent, rightextent, 0, framesInBlock, feat[i]);
|
||||
}
|
||||
minibatchReady=true;
|
||||
}
|
||||
|
||||
void SetMinibatchSize(size_t mbSize){ minibatchSize=mbSize;}
|
||||
msra::dbn::matrix ChunkOfFrames(size_t index) { assert(minibatchReady); assert(index<=feat.size()); return feat[index]; }
|
||||
|
||||
bool IsMinibatchReady() { return minibatchReady; }
|
||||
|
||||
size_t CurrentFileSize() { return framesMulti[0].size(); }
|
||||
void FlushInput(){CreateEvalMinibatch();}
|
||||
void Reset() { Clear(); }
|
||||
};
|
||||
|
||||
|
||||
};};
|
|
@ -1,24 +0,0 @@
|
|||
//
|
||||
// <copyright file="dllmain.cpp" company="Microsoft">
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// </copyright>
|
||||
//
|
||||
// dllmain.cpp : Defines the entry point for the DLL application.
|
||||
#include "stdafx.h"
|
||||
|
||||
BOOL APIENTRY DllMain( HMODULE /*hModule*/,
|
||||
DWORD ul_reason_for_call,
|
||||
LPVOID /*lpReserved*/
|
||||
)
|
||||
{
|
||||
switch (ul_reason_for_call)
|
||||
{
|
||||
case DLL_PROCESS_ATTACH:
|
||||
case DLL_THREAD_ATTACH:
|
||||
case DLL_THREAD_DETACH:
|
||||
case DLL_PROCESS_DETACH:
|
||||
break;
|
||||
}
|
||||
return TRUE;
|
||||
}
|
||||
|
Разница между файлами не показана из-за своего большого размера
Загрузить разницу
|
@ -1,620 +0,0 @@
|
|||
//
|
||||
// fileutil.h - file I/O with error checking
|
||||
//
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
//
|
||||
#pragma once
|
||||
#ifndef _FILEUTIL_
|
||||
#define _FILEUTIL_
|
||||
|
||||
#include "Platform.h"
|
||||
#include <stdio.h>
|
||||
#ifdef __unix__
|
||||
#include <sys/types.h>
|
||||
#include <sys/stat.h>
|
||||
#endif
|
||||
#include <algorithm> // for std::find
|
||||
#include <vector>
|
||||
#include <map>
|
||||
#include <functional>
|
||||
#include <cctype>
|
||||
#include <errno.h>
|
||||
#include <stdint.h>
|
||||
#include <assert.h>
|
||||
#include <string.h> // for strerror()
|
||||
|
||||
using namespace std;
|
||||
|
||||
#define SAFE_CLOSE(f) (((f) == NULL) || (fcloseOrDie ((f)), (f) = NULL))
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// fopenOrDie(): like fopen() but terminate with err msg in case of error.
|
||||
// A pathname of "-" returns stdout or stdin, depending on mode, and it will
|
||||
// change the binary mode if 'b' or 't' are given. If you use this, make sure
|
||||
// not to fclose() such a handle.
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
FILE * fopenOrDie (const string & pathname, const char * mode);
|
||||
FILE * fopenOrDie (const wstring & pathname, const wchar_t * mode);
|
||||
|
||||
#ifndef __unix__
|
||||
// ----------------------------------------------------------------------------
|
||||
// fsetmode(): set mode to binary or text
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
void fsetmode (FILE * f, char type);
|
||||
#endif
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// freadOrDie(): like fread() but terminate with err msg in case of error
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
void freadOrDie (void * ptr, size_t size, size_t count, FILE * f);
|
||||
|
||||
template<class _T>
|
||||
void freadOrDie (_T & data, int num, FILE * f) // template for vector<>
|
||||
{ data.resize (num); if (data.size() > 0) freadOrDie (&data[0], sizeof (data[0]), data.size(), f); }
|
||||
template<class _T>
|
||||
void freadOrDie (_T & data, size_t num, FILE * f) // template for vector<>
|
||||
{ data.resize (num); if (data.size() > 0) freadOrDie (&data[0], sizeof (data[0]), data.size(), f); }
|
||||
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// fwriteOrDie(): like fwrite() but terminate with err msg in case of error
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
void fwriteOrDie (const void * ptr, size_t size, size_t count, FILE * f);
|
||||
|
||||
template<class _T>
|
||||
void fwriteOrDie (const _T & data, FILE * f) // template for vector<>
|
||||
{ if (data.size() > 0) fwriteOrDie (&data[0], sizeof (data[0]), data.size(), f); }
|
||||
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// fprintfOrDie(): like fprintf() but terminate with err msg in case of error
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
void fprintfOrDie (FILE * f, const char *format, ...);
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// fcloseOrDie(): like fclose() but terminate with err msg in case of error
|
||||
// not yet implemented, but we should
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
#define fcloseOrDie fclose
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// fflushOrDie(): like fflush() but terminate with err msg in case of error
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
void fflushOrDie (FILE * f);
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// filesize(): determine size of the file in bytes
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
size_t filesize (const wchar_t * pathname);
|
||||
size_t filesize (FILE * f);
|
||||
int64_t filesize64 (const wchar_t * pathname);
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// fseekOrDie(),ftellOrDie(), fget/setpos(): seek functions with error handling
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
// 32-bit offsets only
|
||||
long fseekOrDie (FILE * f, long offset, int mode = SEEK_SET);
|
||||
#define ftellOrDie ftell
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// fget/setpos(): seek functions with error handling
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
uint64_t fgetpos (FILE * f);
|
||||
void fsetpos (FILE * f, uint64_t pos);
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// unlinkOrDie(): unlink() with error handling
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
void unlinkOrDie (const std::string & pathname);
|
||||
void unlinkOrDie (const std::wstring & pathname);
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// renameOrDie(): rename() with error handling
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
void renameOrDie (const std::string & from, const std::string & to);
|
||||
void renameOrDie (const std::wstring & from, const std::wstring & to);
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// fexists(): test if a file exists
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
bool fexists (const char * pathname);
|
||||
bool fexists (const wchar_t * pathname);
|
||||
inline bool fexists (const std::string & pathname) { return fexists (pathname.c_str()); }
|
||||
inline bool fexists (const std::wstring & pathname) { return fexists (pathname.c_str()); }
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// funicode(): test if a file uses unicode
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
bool funicode (FILE * f);
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// fskipspace(): skip space characters
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
bool fskipspace (FILE * F);
|
||||
bool fskipwspace (FILE * F);
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// fgetline(): like fgets() but terminate with err msg in case of error;
|
||||
// removes the newline character at the end (like gets()), returned buffer is
|
||||
// always 0-terminated; has second version that returns an STL string instead
|
||||
// fgetstring(): read a 0-terminated string (terminate if error)
|
||||
// fgetword(): read a space-terminated token (terminate if error)
|
||||
// fskipNewLine(): skip all white space until end of line incl. the newline
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// fputstring(): write a 0-terminated string (terminate if error)
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
void fputstring (FILE * f, const char *);
|
||||
void fputstring (const HANDLE f, const char * str);
|
||||
void fputstring (FILE * f, const std::string &);
|
||||
void fputstring (FILE * f, const wchar_t *);
|
||||
void fputstring (FILE * f, const std::wstring &);
|
||||
|
||||
template<class CHAR> CHAR * fgetline (FILE * f, CHAR * buf, int size);
|
||||
template<class CHAR, size_t n> CHAR * fgetline (FILE * f, CHAR (& buf)[n]) { return fgetline (f, buf, n); }
|
||||
string fgetline (FILE * f);
|
||||
wstring fgetlinew (FILE * f);
|
||||
void fgetline (FILE * f, std::string & s, std::vector<char> & buf);
|
||||
void fgetline (FILE * f, std::wstring & s, std::vector<char> & buf);
|
||||
void fgetline (FILE * f, std::vector<char> & buf);
|
||||
void fgetline (FILE * f, std::vector<wchar_t> & buf);
|
||||
|
||||
const char * fgetstring (FILE * f, char * buf, int size);
|
||||
template<size_t n> const char * fgetstring (FILE * f, char (& buf)[n]) { return fgetstring (f, buf, n); }
|
||||
const char * fgetstring (const HANDLE f, char * buf, int size);
|
||||
template<size_t n> const char * fgetstring (const HANDLE f, char (& buf)[n]) { return fgetstring (f, buf, n); }
|
||||
|
||||
const wchar_t * fgetstring (FILE * f, wchar_t * buf, int size);
|
||||
wstring fgetwstring (FILE * f);
|
||||
string fgetstring (FILE * f);
|
||||
|
||||
const char * fgettoken (FILE * f, char * buf, int size);
|
||||
template<size_t n> const char * fgettoken (FILE * f, char (& buf)[n]) { return fgettoken (f, buf, n); }
|
||||
string fgettoken (FILE * f);
|
||||
const wchar_t * fgettoken (FILE * f, wchar_t * buf, int size);
|
||||
wstring fgetwtoken (FILE * f);
|
||||
|
||||
int fskipNewline (FILE * f, bool skip = true);
|
||||
int fskipwNewline (FILE * f, bool skip = true);
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// fputstring(): write a 0-terminated string (terminate if error)
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
void fputstring (FILE * f, const char *);
|
||||
void fputstring (FILE * f, const std::string &);
|
||||
void fputstring (FILE * f, const wchar_t *);
|
||||
void fputstring (FILE * f, const std::wstring &);
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// fgetTag(): read a 4-byte tag & return as a string
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
string fgetTag (FILE * f);
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// fcheckTag(): read a 4-byte tag & verify it; terminate if wrong tag
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
void fcheckTag (FILE * f, const char * expectedTag);
|
||||
void fcheckTag_ascii (FILE * f, const string & expectedTag);
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// fcompareTag(): compare two tags; terminate if wrong tag
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
void fcompareTag (const string & readTag, const string & expectedTag);
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// fputTag(): write a 4-byte tag
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
void fputTag (FILE * f, const char * tag);
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// fskipstring(): skip a 0-terminated string, such as a pad string
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
void fskipstring (FILE * f);
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// fpad(): write a 0-terminated string to pad file to a n-byte boundary
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
void fpad (FILE * f, int n);
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// fgetbyte(): read a byte value
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
char fgetbyte (FILE * f);
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// fgetshort(): read a short value
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
short fgetshort (FILE * f);
|
||||
short fgetshort_bigendian (FILE * f);
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// fgetint24(): read a 3-byte (24-bit) int value
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
int fgetint24 (FILE * f);
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// fgetint(): read an int value
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
int fgetint (FILE * f);
|
||||
int fgetint_bigendian (FILE * f);
|
||||
int fgetint_ascii (FILE * f);
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// fgetlong(): read an long value
|
||||
// ----------------------------------------------------------------------------
|
||||
long fgetlong (FILE * f);
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// fgetfloat(): read a float value
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
float fgetfloat (FILE * f);
|
||||
float fgetfloat_bigendian (FILE * f);
|
||||
float fgetfloat_ascii (FILE * f);
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// fgetdouble(): read a double value
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
double fgetdouble (FILE * f);
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// fputbyte(): write a byte value
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
void fputbyte (FILE * f, char val);
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// fputshort(): write a short value
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
void fputshort (FILE * f, short val);
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// fputint24(): write a 3-byte (24-bit) int value
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
void fputint24 (FILE * f, int v);
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// fputint(): write an int value
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
void fputint (FILE * f, int val);
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// fputlong(): write an long value
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
void fputlong (FILE * f, long val);
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// fputfloat(): write a float value
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
void fputfloat (FILE * f, float val);
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// fputdouble(): write a double value
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
void fputdouble (FILE * f, double val);
|
||||
|
||||
|
||||
// template versions of put/get functions for binary files
|
||||
template <typename T>
|
||||
void fput(FILE * f, T v)
|
||||
{
|
||||
fwriteOrDie (&v, sizeof (v), 1, f);
|
||||
}
|
||||
|
||||
|
||||
// template versions of put/get functions for binary files
|
||||
template <typename T>
|
||||
void fget(FILE * f, T& v)
|
||||
{
|
||||
freadOrDie ((void *)&v, sizeof (v), 1, f);
|
||||
}
|
||||
|
||||
|
||||
// GetFormatString - get the format string for a particular type
|
||||
template <typename T>
|
||||
const wchar_t* GetFormatString(T /*t*/)
|
||||
{
|
||||
// if this _ASSERT goes off it means that you are using a type that doesn't have
|
||||
// a read and/or write routine.
|
||||
// If the type is a user defined class, you need to create some global functions that handles file in/out.
|
||||
// for example:
|
||||
//File& operator>>(File& stream, MyClass& test);
|
||||
//File& operator<<(File& stream, MyClass& test);
|
||||
//
|
||||
// in your class you will probably want to add these functions as friends so you can access any private members
|
||||
// friend File& operator>>(File& stream, MyClass& test);
|
||||
// friend File& operator<<(File& stream, MyClass& test);
|
||||
//
|
||||
// if you are using wchar_t* or char* types, these use other methods because they require buffers to be passed
|
||||
// either use std::string and std::wstring, or use the WriteString() and ReadString() methods
|
||||
assert(false); // need a specialization
|
||||
return NULL;
|
||||
}
|
||||
|
||||
// GetFormatString - specalizations to get the format string for a particular type
|
||||
template <> const wchar_t* GetFormatString(char);
|
||||
template <> const wchar_t* GetFormatString(wchar_t);
|
||||
template <> const wchar_t* GetFormatString(short);
|
||||
template <> const wchar_t* GetFormatString(int);
|
||||
template <> const wchar_t* GetFormatString(long);
|
||||
template <> const wchar_t* GetFormatString(unsigned short);
|
||||
template <> const wchar_t* GetFormatString(unsigned int);
|
||||
template <> const wchar_t* GetFormatString(unsigned long);
|
||||
template <> const wchar_t* GetFormatString(float);
|
||||
template <> const wchar_t* GetFormatString(double);
|
||||
template <> const wchar_t* GetFormatString(size_t);
|
||||
template <> const wchar_t* GetFormatString(long long);
|
||||
template <> const wchar_t* GetFormatString(const char*);
|
||||
template <> const wchar_t* GetFormatString(const wchar_t*);
|
||||
|
||||
// GetScanFormatString - get the format string for a particular type
|
||||
template <typename T>
|
||||
const wchar_t* GetScanFormatString(T t)
|
||||
{
|
||||
assert(false); // need a specialization
|
||||
return NULL;
|
||||
}
|
||||
|
||||
// GetScanFormatString - specalizations to get the format string for a particular type
|
||||
template <> const wchar_t* GetScanFormatString(char);
|
||||
template <> const wchar_t* GetScanFormatString(wchar_t);
|
||||
template <> const wchar_t* GetScanFormatString(short);
|
||||
template <> const wchar_t* GetScanFormatString(int);
|
||||
template <> const wchar_t* GetScanFormatString(long);
|
||||
template <> const wchar_t* GetScanFormatString(unsigned short);
|
||||
template <> const wchar_t* GetScanFormatString(unsigned int);
|
||||
template <> const wchar_t* GetScanFormatString(unsigned long);
|
||||
template <> const wchar_t* GetScanFormatString(float);
|
||||
template <> const wchar_t* GetScanFormatString(double);
|
||||
template <> const wchar_t* GetScanFormatString(size_t);
|
||||
template <> const wchar_t* GetScanFormatString(long long);
|
||||
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// fgetText(): get a value from a text file
|
||||
// ----------------------------------------------------------------------------
|
||||
template <typename T>
|
||||
void fgetText(FILE * f, T& v)
|
||||
{
|
||||
int rc = ftrygetText(f, v);
|
||||
if (rc == 0)
|
||||
throw std::runtime_error("error reading value from file (invalid format)");
|
||||
else if (rc == EOF)
|
||||
throw std::runtime_error(std::string("error reading from file: ") + strerror(errno));
|
||||
assert(rc == 1);
|
||||
}
|
||||
|
||||
// version to try and get a string, and not throw exceptions if contents don't match
|
||||
template <typename T>
|
||||
int ftrygetText(FILE * f, T& v)
|
||||
{
|
||||
const wchar_t* formatString = GetScanFormatString<T>(v);
|
||||
int rc = fwscanf (f, formatString, &v);
|
||||
assert(rc == 1 || rc == 0);
|
||||
return rc;
|
||||
}
|
||||
|
||||
template <> int ftrygetText<bool>(FILE * f, bool& v);
|
||||
// ----------------------------------------------------------------------------
|
||||
// fgetText() specializations for fwscanf_s differences: get a value from a text file
|
||||
// ----------------------------------------------------------------------------
|
||||
void fgetText(FILE * f, char& v);
|
||||
void fgetText(FILE * f, wchar_t& v);
|
||||
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// fputText(): write a value out as text
|
||||
// ----------------------------------------------------------------------------
|
||||
template <typename T>
|
||||
void fputText(FILE * f, T v)
|
||||
{
|
||||
const wchar_t* formatString = GetFormatString(v);
|
||||
int rc = fwprintf(f, formatString, v);
|
||||
if (rc == 0)
|
||||
throw std::runtime_error("error writing value to file, no values written");
|
||||
else if (rc < 0)
|
||||
throw std::runtime_error(std::string("error writing to file: ") + strerror(errno));
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// fputText(): write a bool out as character
|
||||
// ----------------------------------------------------------------------------
|
||||
template <> void fputText<bool>(FILE * f, bool v);
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// fputfile(): write a binary block or a string as a file
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
void fputfile (const wstring & pathname, const std::vector<char> & buffer);
|
||||
void fputfile (const wstring & pathname, const std::wstring & string);
|
||||
void fputfile (const wstring & pathname, const std::string & string);
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// fgetfile(): load a file as a binary block
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
void fgetfile (const wstring & pathname, std::vector<char> & buffer);
|
||||
void fgetfile (FILE * f, std::vector<char> & buffer);
|
||||
namespace msra { namespace files {
|
||||
void fgetfilelines (const std::wstring & pathname, vector<char> & readbuffer, std::vector<std::string> & lines);
|
||||
static inline std::vector<std::string> fgetfilelines (const std::wstring & pathname) { vector<char> buffer; std::vector<std::string> lines; fgetfilelines (pathname, buffer, lines); return lines; }
|
||||
vector<char*> fgetfilelines (const wstring & pathname, vector<char> & readbuffer);
|
||||
};};
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// expand_wildcards() -- expand a path with wildcards (also intermediate ones)
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
void expand_wildcards (const wstring & path, vector<wstring> & paths);
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// make_intermediate_dirs() -- make all intermediate dirs on a path
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
namespace msra { namespace files {
|
||||
void make_intermediate_dirs (const wstring & filepath);
|
||||
};};
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// fuptodate() -- test whether an output file is at least as new as an input file
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
namespace msra { namespace files {
|
||||
bool fuptodate (const wstring & target, const wstring & input, bool inputrequired = true);
|
||||
};};
|
||||
|
||||
#if 0
|
||||
// ----------------------------------------------------------------------------
|
||||
// simple support for WAV file I/O
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
// define the header if we haven't seen it yet
|
||||
#ifndef _WAVEFORMATEX_
|
||||
#define _WAVEFORMATEX_
|
||||
|
||||
/*
|
||||
* extended waveform format structure used for all non-PCM formats. this
|
||||
* structure is common to all non-PCM formats.
|
||||
*/
|
||||
typedef unsigned short WORD; // in case not defined yet (i.e. linux)
|
||||
typedef struct tWAVEFORMATEX
|
||||
{
|
||||
WORD wFormatTag; /* format type */
|
||||
WORD nChannels; /* number of channels (i.e. mono, stereo...) */
|
||||
DWORD nSamplesPerSec; /* sample rate */
|
||||
DWORD nAvgBytesPerSec; /* for buffer estimation */
|
||||
WORD nBlockAlign; /* block size of data */
|
||||
WORD wBitsPerSample; /* number of bits per sample of mono data */
|
||||
WORD cbSize; /* the count in bytes of the size of */
|
||||
/* extra information (after cbSize) */
|
||||
} WAVEFORMATEX, *PWAVEFORMATEX;
|
||||
|
||||
#endif /* _WAVEFORMATEX_ */
|
||||
|
||||
typedef struct wavehder{
|
||||
char riffchar[4];
|
||||
unsigned int RiffLength;
|
||||
char wavechar[8];
|
||||
unsigned int FmtLength;
|
||||
signed short wFormatTag;
|
||||
signed short nChannels;
|
||||
unsigned int nSamplesPerSec;
|
||||
unsigned int nAvgBytesPerSec;
|
||||
signed short nBlockAlign;
|
||||
signed short wBitsPerSample;
|
||||
char datachar[4];
|
||||
unsigned int DataLength;
|
||||
private:
|
||||
void prepareRest (int SampleCount);
|
||||
public:
|
||||
void prepare (unsigned int Fs, int Bits, int Channels, int SampleCount);
|
||||
void prepare (const WAVEFORMATEX & wfx, int SampleCount);
|
||||
unsigned int read (FILE * f, signed short & wRealFormatTag, int & bytesPerSample);
|
||||
void write (FILE * f);
|
||||
static void update (FILE * f);
|
||||
} WAVEHEADER;
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// fgetwfx(), fputwfx(): I/O of wave file headers only
|
||||
// ----------------------------------------------------------------------------
|
||||
unsigned int fgetwfx (FILE *f, WAVEFORMATEX & wfx);
|
||||
void fputwfx (FILE *f, const WAVEFORMATEX & wfx, unsigned int numSamples);
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// fgetraw(): read data of .wav file, and separate data of multiple channels.
|
||||
// For example, data[i][j]: i is channel index, 0 means the first
|
||||
// channel. j is sample index.
|
||||
// ----------------------------------------------------------------------------
|
||||
void fgetraw (FILE *f,std::vector< std::vector<short> > & data,const WAVEHEADER & wavhd);
|
||||
#endif
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// temp functions -- clean these up
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
// split a pathname into directory and filename
|
||||
static inline void splitpath (const wstring & path, wstring & dir, wstring & file)
|
||||
{
|
||||
size_t pos = path.find_last_of (L"\\:/"); // DOS drives, UNIX, Windows
|
||||
if (pos == path.npos) // no directory found
|
||||
{
|
||||
dir.clear();
|
||||
file = path;
|
||||
}
|
||||
else
|
||||
{
|
||||
dir = path.substr (0, pos);
|
||||
file = path.substr (pos +1);
|
||||
}
|
||||
}
|
||||
|
||||
// test if a pathname is a relative path
|
||||
// A relative path is one that can be appended to a directory.
|
||||
// Drive-relative paths, such as D:file, are considered non-relative.
|
||||
static inline bool relpath (const wchar_t * path)
|
||||
{ // this is a wild collection of pathname conventions in Windows
|
||||
if (path[0] == '/' || path[0] == '\\') // e.g. \WINDOWS
|
||||
return false;
|
||||
if (path[0] && path[1] == ':') // drive syntax
|
||||
return false;
|
||||
// ... TODO: handle long NT paths
|
||||
return true; // all others
|
||||
}
|
||||
template<class CHAR>
|
||||
static inline bool relpath (const std::basic_string<CHAR> & s) { return relpath (s.c_str()); }
|
||||
|
||||
// trim from start
|
||||
static inline std::string <rim(std::string &s) {
|
||||
s.erase(s.begin(), std::find_if(s.begin(), s.end(), std::not1(std::ptr_fun<int, int>(std::isspace))));
|
||||
return s;
|
||||
}
|
||||
|
||||
// trim from end
|
||||
static inline std::string &rtrim(std::string &s) {
|
||||
s.erase(std::find_if(s.rbegin(), s.rend(), std::not1(std::ptr_fun<int, int>(std::isspace))).base(), s.end());
|
||||
return s;
|
||||
}
|
||||
|
||||
// trim from both ends
|
||||
static inline std::string &trim(std::string &s) {
|
||||
return ltrim(rtrim(s));
|
||||
}
|
||||
|
||||
vector<string> sep_string(const string & str, const string & sep);
|
||||
|
||||
#endif // _FILEUTIL_
|
|
@ -1,448 +0,0 @@
|
|||
// TODO: this is a dup; use the one in Include/ instead
|
||||
|
||||
//
|
||||
// <copyright file="fileutil.h" company="Microsoft">
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// </copyright>
|
||||
//
|
||||
|
||||
#pragma once
|
||||
#ifndef _FILEUTIL_
|
||||
#define _FILEUTIL_
|
||||
|
||||
#include "basetypes.h"
|
||||
#include <stdio.h>
|
||||
#ifdef __WINDOWS__
|
||||
#include <windows.h> // for mmreg.h and FILETIME
|
||||
#include <mmreg.h>
|
||||
#endif
|
||||
#include <stdint.h>
|
||||
using namespace std;
|
||||
|
||||
#define SAFE_CLOSE(f) (((f) == NULL) || (fcloseOrDie ((f)), (f) = NULL))
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// fopenOrDie(): like fopen() but terminate with err msg in case of error.
|
||||
// A pathname of "-" returns stdout or stdin, depending on mode, and it will
|
||||
// change the binary mode if 'b' or 't' are given. If you use this, make sure
|
||||
// not to fclose() such a handle.
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
FILE * fopenOrDie (const STRING & pathname, const char * mode);
|
||||
FILE * fopenOrDie (const WSTRING & pathname, const wchar_t * mode);
|
||||
|
||||
#ifndef __unix__ // don't need binary/text distinction on unix
|
||||
// ----------------------------------------------------------------------------
|
||||
// fsetmode(): set mode to binary or text
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
void fsetmode (FILE * f, char type);
|
||||
#endif
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// freadOrDie(): like fread() but terminate with err msg in case of error
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
void freadOrDie (void * ptr, size_t size, size_t count, FILE * f);
|
||||
void freadOrDie (void * ptr, size_t size, size_t count, const HANDLE f);
|
||||
|
||||
template<class _T>
|
||||
void freadOrDie (_T & data, int num, FILE * f) // template for vector<>
|
||||
{ data.resize (num); if (data.size() > 0) freadOrDie (&data[0], sizeof (data[0]), data.size(), f); }
|
||||
template<class _T>
|
||||
void freadOrDie (_T & data, size_t num, FILE * f) // template for vector<>
|
||||
{ data.resize (num); if (data.size() > 0) freadOrDie (&data[0], sizeof (data[0]), data.size(), f); }
|
||||
|
||||
template<class _T>
|
||||
void freadOrDie (_T & data, int num, const HANDLE f) // template for vector<>
|
||||
{ data.resize (num); if (data.size() > 0) freadOrDie (&data[0], sizeof (data[0]), data.size(), f); }
|
||||
template<class _T>
|
||||
void freadOrDie (_T & data, size_t num, const HANDLE f) // template for vector<>
|
||||
{ data.resize (num); if (data.size() > 0) freadOrDie (&data[0], sizeof (data[0]), data.size(), f); }
|
||||
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// fwriteOrDie(): like fwrite() but terminate with err msg in case of error
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
void fwriteOrDie (const void * ptr, size_t size, size_t count, FILE * f);
|
||||
void fwriteOrDie (const void * ptr, size_t size, size_t count, const HANDLE f);
|
||||
|
||||
template<class _T>
|
||||
void fwriteOrDie (const _T & data, FILE * f) // template for vector<>
|
||||
{ if (data.size() > 0) fwriteOrDie (&data[0], sizeof (data[0]), data.size(), f); }
|
||||
|
||||
template<class _T>
|
||||
void fwriteOrDie (const _T & data, const HANDLE f) // template for vector<>
|
||||
{ if (data.size() > 0) fwriteOrDie (&data[0], sizeof (data[0]), data.size(), f); }
|
||||
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// fprintfOrDie(): like fprintf() but terminate with err msg in case of error
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
void fprintfOrDie (FILE * f, const char *format, ...);
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// fcloseOrDie(): like fclose() but terminate with err msg in case of error
|
||||
// not yet implemented, but we should
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
#define fcloseOrDie fclose
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// fflushOrDie(): like fflush() but terminate with err msg in case of error
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
void fflushOrDie (FILE * f);
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// filesize(): determine size of the file in bytes
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
size_t filesize (const wchar_t * pathname);
|
||||
size_t filesize (FILE * f);
|
||||
int64_t filesize64 (const wchar_t * pathname);
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// fseekOrDie(),ftellOrDie(), fget/setpos(): seek functions with error handling
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
// 32-bit offsets only
|
||||
long fseekOrDie (FILE * f, long offset, int mode = SEEK_SET);
|
||||
#define ftellOrDie ftell
|
||||
uint64_t fgetpos (FILE * f);
|
||||
void fsetpos (FILE * f, uint64_t pos);
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// unlinkOrDie(): unlink() with error handling
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
void unlinkOrDie (const std::string & pathname);
|
||||
void unlinkOrDie (const std::wstring & pathname);
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// renameOrDie(): rename() with error handling
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
void renameOrDie (const std::string & from, const std::string & to);
|
||||
void renameOrDie (const std::wstring & from, const std::wstring & to);
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// fexists(): test if a file exists
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
bool fexists (const char * pathname);
|
||||
bool fexists (const wchar_t * pathname);
|
||||
inline bool fexists (const std::string & pathname) { return fexists (pathname.c_str()); }
|
||||
inline bool fexists (const std::wstring & pathname) { return fexists (pathname.c_str()); }
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// funicode(): test if a file uses unicode
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
bool funicode (FILE * f);
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// fskipspace(): skip space characters
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
void fskipspace (FILE * F);
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// fgetline(): like fgets() but terminate with err msg in case of error;
|
||||
// removes the newline character at the end (like gets()), returned buffer is
|
||||
// always 0-terminated; has second version that returns an STL string instead
|
||||
// fgetstring(): read a 0-terminated string (terminate if error)
|
||||
// fgetword(): read a space-terminated token (terminate if error)
|
||||
// fskipNewLine(): skip all white space until end of line incl. the newline
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
template<class CHAR> CHAR * fgetline (FILE * f, CHAR * buf, int size);
|
||||
template<class CHAR, size_t n> CHAR * fgetline (FILE * f, CHAR (& buf)[n]) { return fgetline (f, buf, n); }
|
||||
STRING fgetline (FILE * f);
|
||||
WSTRING fgetlinew (FILE * f);
|
||||
void fgetline (FILE * f, std::string & s, ARRAY<char> & buf);
|
||||
void fgetline (FILE * f, std::wstring & s, ARRAY<char> & buf);
|
||||
void fgetline (FILE * f, ARRAY<char> & buf);
|
||||
void fgetline (FILE * f, ARRAY<wchar_t> & buf);
|
||||
|
||||
const char * fgetstring (FILE * f, char * buf, int size);
|
||||
template<size_t n> const char * fgetstring (FILE * f, char (& buf)[n]) { return fgetstring (f, buf, n); }
|
||||
const char * fgetstring (const HANDLE f, char * buf, int size);
|
||||
template<size_t n> const char * fgetstring (const HANDLE f, char (& buf)[n]) { return fgetstring (f, buf, n); }
|
||||
wstring fgetwstring (FILE * f);
|
||||
|
||||
const char * fgettoken (FILE * f, char * buf, int size);
|
||||
template<size_t n> const char * fgettoken (FILE * f, char (& buf)[n]) { return fgettoken (f, buf, n); }
|
||||
STRING fgettoken (FILE * f);
|
||||
|
||||
void fskipNewline (FILE * f);
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// fputstring(): write a 0-terminated string (terminate if error)
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
void fputstring (FILE * f, const char *);
|
||||
void fputstring (const HANDLE f, const char * str);
|
||||
void fputstring (FILE * f, const std::string &);
|
||||
void fputstring (FILE * f, const wchar_t *);
|
||||
void fputstring (FILE * f, const std::wstring &);
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// fgetTag(): read a 4-byte tag & return as a string
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
STRING fgetTag (FILE * f);
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// fcheckTag(): read a 4-byte tag & verify it; terminate if wrong tag
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
void fcheckTag (FILE * f, const char * expectedTag);
|
||||
void fcheckTag (const HANDLE f, const char * expectedTag);
|
||||
void fcheckTag_ascii (FILE * f, const STRING & expectedTag);
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// fcompareTag(): compare two tags; terminate if wrong tag
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
void fcompareTag (const STRING & readTag, const STRING & expectedTag);
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// fputTag(): write a 4-byte tag
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
void fputTag (FILE * f, const char * tag);
|
||||
void fputTag(const HANDLE f, const char * tag);
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// fskipstring(): skip a 0-terminated string, such as a pad string
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
void fskipstring (FILE * f);
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// fpad(): write a 0-terminated string to pad file to a n-byte boundary
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
void fpad (FILE * f, int n);
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// fgetbyte(): read a byte value
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
char fgetbyte (FILE * f);
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// fgetshort(): read a short value
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
short fgetshort (FILE * f);
|
||||
short fgetshort_bigendian (FILE * f);
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// fgetint24(): read a 3-byte (24-bit) int value
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
int fgetint24 (FILE * f);
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// fgetint(): read an int value
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
int fgetint (FILE * f);
|
||||
int fgetint (const HANDLE f);
|
||||
int fgetint_bigendian (FILE * f);
|
||||
int fgetint_ascii (FILE * f);
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// fgetfloat(): read a float value
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
float fgetfloat (FILE * f);
|
||||
float fgetfloat_bigendian (FILE * f);
|
||||
float fgetfloat_ascii (FILE * f);
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// fgetdouble(): read a double value
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
double fgetdouble (FILE * f);
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// fgetwav(): read an entire .wav file
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
void fgetwav (FILE * f, ARRAY<short> & wav, int & sampleRate);
|
||||
void fgetwav (const wstring & fn, ARRAY<short> & wav, int & sampleRate);
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// fputwav(): save data into a .wav file
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
void fputwav (FILE * f, const vector<short> & wav, int sampleRate, int nChannels = 1);
|
||||
void fputwav (const wstring & fn, const vector<short> & wav, int sampleRate, int nChannels = 1);
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// fputbyte(): write a byte value
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
void fputbyte (FILE * f, char val);
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// fputshort(): write a short value
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
void fputshort (FILE * f, short val);
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// fputint24(): write a 3-byte (24-bit) int value
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
void fputint24 (FILE * f, int v);
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// fputint(): write an int value
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
void fputint (FILE * f, int val);
|
||||
void fputint (const HANDLE f, int v);
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// fputfloat(): write a float value
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
void fputfloat (FILE * f, float val);
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// fputdouble(): write a double value
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
void fputdouble (FILE * f, double val);
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// fputfile(): write a binary block or a string as a file
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
void fputfile (const WSTRING & pathname, const ARRAY<char> & buffer);
|
||||
void fputfile (const WSTRING & pathname, const std::wstring & string);
|
||||
void fputfile (const WSTRING & pathname, const std::string & string);
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// fgetfile(): load a file as a binary block
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
void fgetfile (const WSTRING & pathname, ARRAY<char> & buffer);
|
||||
void fgetfile (FILE * f, ARRAY<char> & buffer);
|
||||
namespace msra { namespace files {
|
||||
void fgetfilelines (const std::wstring & pathname, vector<char> & readbuffer, std::vector<std::string> & lines);
|
||||
static inline std::vector<std::string> fgetfilelines (const std::wstring & pathname) { vector<char> buffer; std::vector<std::string> lines; fgetfilelines (pathname, buffer, lines); return lines; }
|
||||
vector<char*> fgetfilelines (const wstring & pathname, vector<char> & readbuffer);
|
||||
};};
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// getfiletime(), setfiletime(): access modification time
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
bool getfiletime (const std::wstring & path, FILETIME & time);
|
||||
void setfiletime (const std::wstring & path, const FILETIME & time);
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// expand_wildcards() -- expand a path with wildcards (also intermediate ones)
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
void expand_wildcards (const wstring & path, vector<wstring> & paths);
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// make_intermediate_dirs() -- make all intermediate dirs on a path
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
namespace msra { namespace files {
|
||||
void make_intermediate_dirs (const wstring & filepath);
|
||||
};};
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// fuptodate() -- test whether an output file is at least as new as an input file
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
namespace msra { namespace files {
|
||||
bool fuptodate (const wstring & target, const wstring & input, bool inputrequired = true);
|
||||
};};
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// simple support for WAV file I/O
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
typedef struct wavehder{
|
||||
char riffchar[4];
|
||||
unsigned int RiffLength;
|
||||
char wavechar[8];
|
||||
unsigned int FmtLength;
|
||||
signed short wFormatTag;
|
||||
signed short nChannels;
|
||||
unsigned int nSamplesPerSec;
|
||||
unsigned int nAvgBytesPerSec;
|
||||
signed short nBlockAlign;
|
||||
signed short wBitsPerSample;
|
||||
char datachar[4];
|
||||
unsigned int DataLength;
|
||||
private:
|
||||
void prepareRest (int SampleCount);
|
||||
public:
|
||||
void prepare (unsigned int Fs, int Bits, int Channels, int SampleCount);
|
||||
void prepare (const WAVEFORMATEX & wfx, int SampleCount);
|
||||
unsigned int read (FILE * f, signed short & wRealFormatTag, int & bytesPerSample);
|
||||
void write (FILE * f);
|
||||
static void update (FILE * f);
|
||||
} WAVEHEADER;
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// fgetwfx(), fputwfx(): I/O of wave file headers only
|
||||
// ----------------------------------------------------------------------------
|
||||
unsigned int fgetwfx (FILE *f, WAVEFORMATEX & wfx);
|
||||
void fputwfx (FILE *f, const WAVEFORMATEX & wfx, unsigned int numSamples);
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// fgetraw(): read data of .wav file, and separate data of multiple channels.
|
||||
// For example, data[i][j]: i is channel index, 0 means the first
|
||||
// channel. j is sample index.
|
||||
// ----------------------------------------------------------------------------
|
||||
void fgetraw (FILE *f,ARRAY< ARRAY<short> > & data,const WAVEHEADER & wavhd);
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// temp functions -- clean these up
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
// split a pathname into directory and filename
|
||||
static inline void splitpath (const wstring & path, wstring & dir, wstring & file)
|
||||
{
|
||||
size_t pos = path.find_last_of (L"\\:/"); // DOS drives, UNIX, Windows
|
||||
if (pos == path.npos) // no directory found
|
||||
{
|
||||
dir.clear();
|
||||
file = path;
|
||||
}
|
||||
else
|
||||
{
|
||||
dir = path.substr (0, pos);
|
||||
file = path.substr (pos +1);
|
||||
}
|
||||
}
|
||||
|
||||
// test if a pathname is a relative path
|
||||
// A relative path is one that can be appended to a directory.
|
||||
// Drive-relative paths, such as D:file, are considered non-relative.
|
||||
static inline bool relpath (const wchar_t * path)
|
||||
{ // this is a wild collection of pathname conventions in Windows
|
||||
if (path[0] == '/' || path[0] == '\\') // e.g. \WINDOWS
|
||||
return false;
|
||||
if (path[0] && path[1] == ':') // drive syntax
|
||||
return false;
|
||||
// ... TODO: handle long NT paths
|
||||
return true; // all others
|
||||
}
|
||||
template<class CHAR>
|
||||
static inline bool relpath (const std::basic_string<CHAR> & s) { return relpath (s.c_str()); }
|
||||
|
||||
#endif // _FILEUTIL_
|
|
@ -1,951 +0,0 @@
|
|||
//
|
||||
// <copyright file="htkfeatio.h" company="Microsoft">
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// </copyright>
|
||||
//
|
||||
// htkfeatio.h -- helper for I/O of HTK feature files
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "basetypes.h"
|
||||
#include "fileutil.h"
|
||||
#include "simple_checked_arrays.h"
|
||||
|
||||
#include <string>
|
||||
#include <regex>
|
||||
#include <set>
|
||||
#include <hash_map>
|
||||
#include <stdint.h>
|
||||
#include <limits.h>
|
||||
#include <wchar.h>
|
||||
namespace msra { namespace asr {
|
||||
|
||||
// ===========================================================================
|
||||
// htkfeatio -- common base class for reading and writing HTK feature files
|
||||
// ===========================================================================
|
||||
|
||||
class htkfeatio
|
||||
{
|
||||
protected:
|
||||
auto_file_ptr f;
|
||||
wstring physicalpath; // path of this file
|
||||
bool needbyteswapping; // need to swap the bytes?
|
||||
|
||||
string featkind; // HTK feature-kind string
|
||||
size_t featdim; // feature dimension
|
||||
unsigned int featperiod; // sampling period
|
||||
|
||||
// note that by default we assume byte swapping (seems to be HTK default)
|
||||
htkfeatio() : needbyteswapping (true), featdim (0), featperiod (0) {}
|
||||
|
||||
// set the feature kind variables --if already set then validate that they are the same
|
||||
// Path is only for error message.
|
||||
void setkind (string kind, size_t dim, unsigned int period, const wstring & path)
|
||||
{
|
||||
if (featkind.empty()) // not set yet: just memorize them
|
||||
{
|
||||
assert (featdim == 0 && featperiod == 0);
|
||||
featkind = kind;
|
||||
featdim = dim;
|
||||
featperiod = period;
|
||||
}
|
||||
else // set already: check if consistent
|
||||
{
|
||||
if (featkind != kind || featdim != dim || featperiod != period)
|
||||
throw std::runtime_error (msra::strfun::strprintf ("setkind: inconsistent feature kind for file '%S'", path.c_str()));
|
||||
}
|
||||
}
|
||||
|
||||
static short swapshort (short v) throw()
|
||||
{
|
||||
const unsigned char * b = (const unsigned char *) &v;
|
||||
return (short) ((b[0] << 8) + b[1]);
|
||||
}
|
||||
static int swapint (int v) throw()
|
||||
{
|
||||
const unsigned char * b = (const unsigned char *) &v;
|
||||
return (int) (((((b[0] << 8) + b[1]) << 8) + b[2]) << 8) + b[3];
|
||||
}
|
||||
|
||||
struct fileheader
|
||||
{
|
||||
int nsamples;
|
||||
int sampperiod;
|
||||
short sampsize;
|
||||
short sampkind;
|
||||
void read (FILE * f)
|
||||
{
|
||||
nsamples = fgetint (f);
|
||||
sampperiod = fgetint (f);
|
||||
sampsize = fgetshort (f);
|
||||
sampkind = fgetshort (f);
|
||||
}
|
||||
|
||||
// read header of idx feature cach
|
||||
void idxRead (FILE * f)
|
||||
{
|
||||
int magic = swapint(fgetint (f));
|
||||
if (magic != 2051)
|
||||
throw std::runtime_error ("reading idx feature cache header: invalid magic");
|
||||
nsamples = swapint(fgetint(f));
|
||||
sampperiod = 0;
|
||||
sampkind = (short)9; //user type
|
||||
int nRows = swapint(fgetint(f));
|
||||
int nCols = swapint(fgetint(f));
|
||||
sampsize = (short) (nRows * nCols); // features are stored as bytes;
|
||||
}
|
||||
|
||||
void write (FILE * f)
|
||||
{
|
||||
fputint (f, nsamples);
|
||||
fputint (f, sampperiod);
|
||||
fputshort (f, sampsize);
|
||||
fputshort (f, sampkind);
|
||||
}
|
||||
void byteswap()
|
||||
{
|
||||
nsamples = swapint (nsamples);
|
||||
sampperiod = swapint (sampperiod);
|
||||
sampsize = swapshort (sampsize);
|
||||
sampkind = swapshort (sampkind);
|
||||
}
|
||||
};
|
||||
|
||||
static const int BASEMASK = 077;
|
||||
static const int PLP = 11;
|
||||
static const int MFCC = 6;
|
||||
static const int FBANK = 7;
|
||||
static const int USER = 9;
|
||||
static const int FESTREAM = 12;
|
||||
static const int HASENERGY = 0100; // _E log energy included
|
||||
static const int HASNULLE = 0200; // _N absolute energy suppressed
|
||||
static const int HASDELTA = 0400; // _D delta coef appended
|
||||
static const int HASACCS = 01000; // _A acceleration coefs appended
|
||||
static const int HASCOMPX = 02000; // _C is compressed
|
||||
static const int HASZEROM = 04000; // _Z zero meaned
|
||||
static const int HASCRCC = 010000; // _K has CRC check
|
||||
static const int HASZEROC = 020000; // _0 0'th Cepstra included
|
||||
static const int HASVQ = 040000; // _V has VQ index attached
|
||||
static const int HASTHIRD = 0100000; // _T has Delta-Delta-Delta index attached
|
||||
};
|
||||
|
||||
// ===========================================================================
|
||||
// htkfeatwriter -- write HTK feature file
|
||||
// This is designed to write a single file only (no archive mode support).
|
||||
// ===========================================================================
|
||||
|
||||
class htkfeatwriter : protected htkfeatio
|
||||
{
|
||||
size_t curframe;
|
||||
vector<float> tmp;
|
||||
public:
|
||||
short parsekind (const string & str)
|
||||
{
|
||||
vector<string> params = msra::strfun::split (str, ";");
|
||||
if (params.empty())
|
||||
throw std::runtime_error ("parsekind: invalid param kind string");
|
||||
vector<string> parts = msra::strfun::split (params[0], "_");
|
||||
// map base kind
|
||||
short sampkind;
|
||||
string basekind = parts[0];
|
||||
if (basekind == "PLP") sampkind = PLP;
|
||||
else if (basekind == "MFCC") sampkind = MFCC;
|
||||
else if (basekind == "FBANK") sampkind = FBANK;
|
||||
else if (basekind == "USER") sampkind = USER;
|
||||
else throw std::runtime_error ("parsekind: unsupported param base kind");
|
||||
// map qualifiers
|
||||
for (size_t i = 1; i < parts.size(); i++)
|
||||
{
|
||||
string opt = parts[i];
|
||||
if (opt.length() != 1)
|
||||
throw std::runtime_error ("parsekind: invalid param kind string");
|
||||
switch (opt[0])
|
||||
{
|
||||
case 'E': sampkind |= HASENERGY; break;
|
||||
case 'D': sampkind |= HASDELTA; break;
|
||||
case 'N': sampkind |= HASNULLE; break;
|
||||
case 'A': sampkind |= HASACCS; break;
|
||||
case 'T': sampkind |= HASTHIRD; break;
|
||||
case 'Z': sampkind |= HASZEROM; break;
|
||||
case '0': sampkind |= HASZEROC; break;
|
||||
default: throw std::runtime_error ("parsekind: invalid qualifier in param kind string");
|
||||
}
|
||||
}
|
||||
return sampkind;
|
||||
}
|
||||
public:
|
||||
// open the file for writing
|
||||
htkfeatwriter (wstring path, string kind, size_t dim, unsigned int period)
|
||||
{
|
||||
setkind (kind, dim, period, path);
|
||||
// write header
|
||||
fileheader H;
|
||||
H.nsamples = 0; // unknown for now, updated in close()
|
||||
H.sampperiod = period;
|
||||
const int bytesPerValue = sizeof (float); // we do not support compression for now
|
||||
H.sampsize = (short) featdim * bytesPerValue;
|
||||
H.sampkind = parsekind (kind);
|
||||
if (needbyteswapping)
|
||||
H.byteswap();
|
||||
f = fopenOrDie (path, L"wbS");
|
||||
H.write (f);
|
||||
curframe = 0;
|
||||
}
|
||||
// write a frame
|
||||
void write (const vector<float> & v)
|
||||
{
|
||||
if (v.size() != featdim)
|
||||
throw std::logic_error ("htkfeatwriter: inconsistent feature dimension");
|
||||
if (needbyteswapping)
|
||||
{
|
||||
tmp.resize (v.size());
|
||||
foreach_index (k, v) tmp[k] = v[k];
|
||||
msra::util::byteswap (tmp);
|
||||
fwriteOrDie (tmp, f);
|
||||
}
|
||||
else
|
||||
fwriteOrDie (v, f);
|
||||
curframe++;
|
||||
}
|
||||
// finish
|
||||
// This updates the header.
|
||||
// BUGBUG: need to implement safe-save semantics! Otherwise won't work reliably with -make mode.
|
||||
// ... e.g. set DeleteOnClose temporarily, and clear at the end?
|
||||
void close (size_t numframes)
|
||||
{
|
||||
if (curframe != numframes)
|
||||
throw std::logic_error ("htkfeatwriter: inconsistent number of frames passed to close()");
|
||||
fflushOrDie (f);
|
||||
// now implant the length field; it's at offset 0
|
||||
int nSamplesFile = (int) numframes;
|
||||
if (needbyteswapping)
|
||||
nSamplesFile = swapint (nSamplesFile);
|
||||
fseekOrDie (f, 0);
|
||||
fputint (f, nSamplesFile);
|
||||
fflushOrDie (f);
|
||||
f = NULL; // this triggers an fclose() on auto_file_ptr
|
||||
}
|
||||
// read an entire utterance into a matrix
|
||||
// Matrix type needs to have operator(i,j) and resize(n,m).
|
||||
// We write to a tmp file first to ensure we don't leave broken files that would confuse make mode.
|
||||
template<class MATRIX> static void write (const wstring & path, const string & kindstr, unsigned int period, const MATRIX & feat)
|
||||
{
|
||||
wstring tmppath = path + L"$$"; // tmp path for make-mode compliant
|
||||
unlinkOrDie (path); // delete if old file is already there
|
||||
// write it out
|
||||
size_t featdim = feat.rows();
|
||||
size_t numframes = feat.cols();
|
||||
vector<float> v (featdim);
|
||||
htkfeatwriter W (tmppath, kindstr, feat.rows(), period);
|
||||
#ifdef SAMPLING_EXPERIMENT
|
||||
for (size_t i = 0; i < numframes; i++)
|
||||
{
|
||||
foreach_index (k, v)
|
||||
{
|
||||
float val = feat(k,i) - logf((float) SAMPLING_EXPERIMENT);
|
||||
if (i % SAMPLING_EXPERIMENT == 0)
|
||||
v[k] = val;
|
||||
else
|
||||
v[k] += (float) (log (1 + exp (val - v[k]))); // log add
|
||||
}
|
||||
if (i % SAMPLING_EXPERIMENT == SAMPLING_EXPERIMENT -1)
|
||||
W.write (v);
|
||||
}
|
||||
#else
|
||||
for (size_t i = 0; i < numframes; i++)
|
||||
{
|
||||
foreach_index (k, v)
|
||||
v[k] = feat(k,i);
|
||||
W.write (v);
|
||||
}
|
||||
#endif
|
||||
#ifdef SAMPLING_EXPERIMENT
|
||||
W.close (numframes / SAMPLING_EXPERIMENT);
|
||||
#else
|
||||
W.close (numframes);
|
||||
#endif
|
||||
// rename to final destination
|
||||
// (This would only fail in strange circumstances such as accidental multiple processes writing to the same file.)
|
||||
// renameOrDie (tmppath, path);
|
||||
}
|
||||
};
|
||||
|
||||
// ===========================================================================
|
||||
// htkfeatreader -- read HTK feature file, with archive support
|
||||
//
|
||||
// To support archives, one instance of this can (and is supposed to) be used
|
||||
// repeatedly. All feat files read on the same instance are validated to have
|
||||
// the same feature kind.
|
||||
//
|
||||
// For archives, this caches the last used file handle, in expectation that most reads
|
||||
// are sequential anyway. In conjunction with a big buffer, this makes a huge difference.
|
||||
// ===========================================================================
|
||||
|
||||
class htkfeatreader : protected htkfeatio
|
||||
{
|
||||
// information on current file
|
||||
// File handle and feature type information is stored in the underlying htkfeatio object.
|
||||
size_t physicalframes; // total number of frames in physical file
|
||||
//TODO make this nicer
|
||||
bool isidxformat; // support reading of features in idxformat as well (it's a hack, but different format's are not supported yet)
|
||||
uint64_t physicaldatastart; // byte offset of first data byte
|
||||
size_t vecbytesize; // size of one vector in bytes
|
||||
|
||||
bool addEnergy; // add in energy as data is read (will all have zero values)
|
||||
bool compressed; // is compressed to 16-bit values
|
||||
bool hascrcc; // need to skip crcc
|
||||
vector<float> a, b; // for decompression
|
||||
vector<short> tmp; // for decompression
|
||||
vector<unsigned char> tmpByteVector; // for decompression of idx files
|
||||
size_t curframe; // current # samples read so far
|
||||
size_t numframes; // number of samples for current logical file
|
||||
size_t energyElements; // how many energy elements to add if addEnergy is true
|
||||
|
||||
public:
|
||||
|
||||
// parser for complex a=b[s,e] syntax
|
||||
struct parsedpath
|
||||
{
|
||||
protected:
|
||||
friend class htkfeatreader;
|
||||
bool isarchive; // true if archive (range specified)
|
||||
bool isidxformat; // support reading of features in idxformat as well (it's a hack, but different format's are not supported yet)
|
||||
wstring xpath; // original full path specification as passed to constructor (for error messages)
|
||||
wstring logicalpath; // virtual path that this file should be understood to belong to
|
||||
wstring archivepath; // physical path of archive file
|
||||
size_t s, e; // first and last frame inside the archive file; (0, INT_MAX) if not given
|
||||
void malformed() const { throw std::runtime_error (msra::strfun::strprintf ("parsedpath: malformed path '%S'", xpath.c_str())); }
|
||||
|
||||
// consume and return up to 'delim'; remove from 'input' (we try to avoid C++0x here for VS 2008 compat)
|
||||
wstring consume (wstring & input, const wchar_t * delim)
|
||||
{
|
||||
vector<wstring> parts = msra::strfun::split (input, delim); // (not very efficient, but does not matter here)
|
||||
if (parts.size() == 1) input.clear(); // not found: consume to end
|
||||
else input = parts[1]; // found: break at delimiter
|
||||
return parts[0];
|
||||
}
|
||||
public:
|
||||
// constructor parses a=b[s,e] syntax and fills in the file
|
||||
// Can be used implicitly e.g. by passing a string to open().
|
||||
parsedpath (wstring xpath) : xpath (xpath)
|
||||
{
|
||||
// parse out logical path
|
||||
logicalpath = consume (xpath, L"=");
|
||||
isidxformat = false;
|
||||
if (xpath.empty()) // no '=' detected: pass entire file (it's not an archive)
|
||||
{
|
||||
archivepath = logicalpath;
|
||||
s = 0;
|
||||
e = INT_MAX;
|
||||
isarchive = false;
|
||||
// check for "-ubyte" suffix in path name => it is an idx file
|
||||
wstring ubyte(L"-ubyte");
|
||||
size_t pos = archivepath.size() >= ubyte.size() ? archivepath.size() - ubyte.size() : 0;
|
||||
wstring suffix = archivepath.substr(pos , ubyte.size());
|
||||
isidxformat = ubyte == suffix;
|
||||
}
|
||||
else // a=b[s,e] syntax detected
|
||||
{
|
||||
archivepath = consume (xpath, L"[");
|
||||
if (xpath.empty()) // actually it's only a=b
|
||||
{
|
||||
s = 0;
|
||||
e = INT_MAX;
|
||||
isarchive = false;
|
||||
}
|
||||
else
|
||||
{
|
||||
s = msra::strfun::toint (consume (xpath, L","));
|
||||
if (xpath.empty()) malformed();
|
||||
e = msra::strfun::toint (consume (xpath, L"]"));
|
||||
if (!xpath.empty()) malformed();
|
||||
isarchive = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// get the physical path for 'make' test
|
||||
const wstring & physicallocation() const { return archivepath; }
|
||||
|
||||
// casting to wstring yields the logical path
|
||||
operator const wstring & () const { return logicalpath; }
|
||||
|
||||
// get duration in frames
|
||||
size_t numframes() const
|
||||
{
|
||||
if (!isarchive)
|
||||
throw runtime_error ("parsedpath: this mode requires an input script with start and end frames given");
|
||||
return e - s + 1;
|
||||
}
|
||||
};
|
||||
|
||||
private:
|
||||
|
||||
// open the physical HTK file
|
||||
// This is different from the logical (virtual) path name in the case of an archive.
|
||||
void openphysical (const parsedpath & ppath)
|
||||
{
|
||||
wstring physpath = ppath.physicallocation();
|
||||
//auto_file_ptr f = fopenOrDie (physpath, L"rbS");
|
||||
auto_file_ptr f (fopenOrDie (physpath, L"rb")); // removed 'S' for now, as we mostly run local anyway, and this will speed up debugging
|
||||
|
||||
// read the header (12 bytes for htk feature files)
|
||||
fileheader H;
|
||||
isidxformat = ppath.isidxformat;
|
||||
if (!isidxformat)
|
||||
H.read (f);
|
||||
else // read header of idxfile
|
||||
H.idxRead (f);
|
||||
|
||||
// take a guess as to whether we need byte swapping or not
|
||||
bool needbyteswapping = ((unsigned int) swapint (H.sampperiod) < (unsigned int) H.sampperiod);
|
||||
if (needbyteswapping)
|
||||
H.byteswap();
|
||||
|
||||
// interpret sampkind
|
||||
int basekind = H.sampkind & BASEMASK;
|
||||
string kind;
|
||||
switch (basekind)
|
||||
{
|
||||
case PLP: kind = "PLP"; break;
|
||||
case MFCC: kind = "MFCC"; break;
|
||||
case FBANK: kind = "FBANK"; break;
|
||||
case USER: kind = "USER"; break;
|
||||
case FESTREAM: kind = "USER"; break; // we return this as USER type (with guid)
|
||||
default: throw std::runtime_error ("htkfeatreader:unsupported feature kind");
|
||||
}
|
||||
// add qualifiers
|
||||
if (H.sampkind & HASENERGY) kind += "_E";
|
||||
if (H.sampkind & HASDELTA) kind += "_D";
|
||||
if (H.sampkind & HASNULLE) kind += "_N";
|
||||
if (H.sampkind & HASACCS) kind += "_A";
|
||||
if (H.sampkind & HASTHIRD) kind += "_T";
|
||||
bool compressed = (H.sampkind & HASCOMPX) != 0;
|
||||
bool hascrcc = (H.sampkind & HASCRCC) != 0;
|
||||
if (H.sampkind & HASZEROM) kind += "_Z";
|
||||
if (H.sampkind & HASZEROC) kind += "_0";
|
||||
if (H.sampkind & HASVQ) throw std::runtime_error ("htkfeatreader:we do not support VQ");
|
||||
// skip additional GUID in FESTREAM features
|
||||
if (H.sampkind == FESTREAM)
|
||||
{ // ... note: untested
|
||||
unsigned char guid[16];
|
||||
freadOrDie (&guid, sizeof (guid), 1, f);
|
||||
kind += ";guid=";
|
||||
for (int i = 0; i < sizeof (guid)/sizeof (*guid); i++)
|
||||
kind += msra::strfun::strprintf ("%02x", guid[i]);
|
||||
}
|
||||
|
||||
// other checks
|
||||
size_t bytesPerValue = isidxformat ? 1 : (compressed ? sizeof (short) : sizeof (float));
|
||||
|
||||
if (H.sampsize % bytesPerValue != 0) throw std::runtime_error ("htkfeatreader:sample size not multiple of dimension");
|
||||
size_t dim = H.sampsize / bytesPerValue;
|
||||
|
||||
// read the values for decompressing
|
||||
vector<float> a, b;
|
||||
if (compressed)
|
||||
{
|
||||
freadOrDie (a, dim, f);
|
||||
freadOrDie (b, dim, f);
|
||||
H.nsamples -= 4; // these are counted as 4 frames--that's the space they use
|
||||
if (needbyteswapping) { msra::util::byteswap (a); msra::util::byteswap (b); }
|
||||
}
|
||||
|
||||
// done: swap it in
|
||||
int64_t bytepos = fgetpos (f);
|
||||
setkind (kind, dim, H.sampperiod, ppath); // this checks consistency
|
||||
this->physicalpath.swap (physpath);
|
||||
this->physicaldatastart = bytepos;
|
||||
this->physicalframes = H.nsamples;
|
||||
this->f.swap (f); // note: this will get the previous f auto-closed at the end of this function
|
||||
this->needbyteswapping = needbyteswapping;
|
||||
this->compressed = compressed;
|
||||
this->a.swap (a);
|
||||
this->b.swap (b);
|
||||
this->vecbytesize = H.sampsize;
|
||||
this->hascrcc = hascrcc;
|
||||
}
|
||||
void close() // force close the open file --use this in case of read failure
|
||||
{
|
||||
f = NULL; // assigning a new FILE* to f will close the old FILE* if any
|
||||
physicalpath.clear();
|
||||
}
|
||||
|
||||
public:
|
||||
|
||||
htkfeatreader() {addEnergy = false; energyElements = 0;}
|
||||
|
||||
// helper to create a parsed-path object
|
||||
// const auto path = parse (xpath)
|
||||
parsedpath parse (const wstring & xpath) { return parsedpath (xpath); }
|
||||
|
||||
// read a feature file
|
||||
// Returns number of frames in that file.
|
||||
// This understands the more complex syntax a=b[s,e] and optimizes a little
|
||||
size_t open (const parsedpath & ppath)
|
||||
{
|
||||
// do not reopen the file if it is the same; use fsetpos() instead
|
||||
if (f == NULL || ppath.physicallocation() != physicalpath)
|
||||
openphysical (ppath);
|
||||
|
||||
if (ppath.isarchive) // reading a sub-range from an archive
|
||||
{
|
||||
if (ppath.s > ppath.e)
|
||||
throw std::runtime_error (msra::strfun::strprintf ("open: start frame > end frame in '%S'", ppath.e, physicalframes, ppath.xpath.c_str()));
|
||||
if (ppath.e >= physicalframes)
|
||||
throw std::runtime_error (msra::strfun::strprintf ("open: end frame exceeds archive's total number of frames %d in '%S'", physicalframes, ppath.xpath.c_str()));
|
||||
|
||||
int64_t dataoffset = physicaldatastart + ppath.s * vecbytesize;
|
||||
fsetpos (f, dataoffset); // we assume fsetpos(), which is our own, is smart to not flush the read buffer
|
||||
curframe = 0;
|
||||
numframes = ppath.e + 1 - ppath.s;
|
||||
}
|
||||
else // reading a full file
|
||||
{
|
||||
curframe = 0;
|
||||
numframes = physicalframes;
|
||||
assert (fgetpos (f) == physicaldatastart);
|
||||
}
|
||||
return numframes;
|
||||
}
|
||||
// get dimension and type information for a feature file
|
||||
// This will alter the state of this object in that it opens the file. It is efficient to read it right afterwards
|
||||
void getinfo (const parsedpath & ppath, string & featkind, size_t & featdim, unsigned int & featperiod)
|
||||
{
|
||||
open (ppath);
|
||||
featkind = this->featkind;
|
||||
featdim = this->featdim;
|
||||
featperiod = this->featperiod;
|
||||
}
|
||||
|
||||
// called to add energy as we read
|
||||
void AddEnergy(size_t energyElements)
|
||||
{
|
||||
this->energyElements = energyElements;
|
||||
this->addEnergy = energyElements != 0;
|
||||
}
|
||||
const string & getfeattype() const { return featkind; }
|
||||
operator bool() const { return curframe < numframes; }
|
||||
// read a vector from the open file
|
||||
void read (std::vector<float> & v)
|
||||
{
|
||||
if (curframe >= numframes) throw std::runtime_error ("htkfeatreader:attempted to read beyond end");
|
||||
if (!compressed && !isidxformat) // not compressed--the easy one
|
||||
{
|
||||
freadOrDie (v, featdim, f);
|
||||
if (needbyteswapping) msra::util::byteswap (v);
|
||||
}
|
||||
else if (isidxformat)
|
||||
{
|
||||
// read into temp vector
|
||||
freadOrDie (tmpByteVector, featdim, f);
|
||||
v.resize (featdim);
|
||||
foreach_index (k, v)
|
||||
v[k] = (float) tmpByteVector[k];
|
||||
}
|
||||
else // need to decompress
|
||||
{
|
||||
// read into temp vector
|
||||
freadOrDie (tmp, featdim, f);
|
||||
if (needbyteswapping) msra::util::byteswap (tmp);
|
||||
// 'decompress' it
|
||||
v.resize (tmp.size());
|
||||
foreach_index (k, v)
|
||||
v[k] = (tmp[k] + b[k]) / a[k];
|
||||
}
|
||||
curframe++;
|
||||
}
|
||||
// read a sequence of vectors from the open file into a range of frames [ts,te)
|
||||
template<class MATRIX> void read (MATRIX & feat, size_t ts, size_t te)
|
||||
{
|
||||
// read vectors from file and push to our target structure
|
||||
vector<float> v(featdim+energyElements);
|
||||
for (size_t t = ts; t < te; t++)
|
||||
{
|
||||
read (v);
|
||||
// add the energy elements (all zero) if needed
|
||||
if (addEnergy)
|
||||
{
|
||||
// we add the energy elements at the end of each section of features, (features, delta, delta-delta)
|
||||
size_t posIncrement = featdim/energyElements;
|
||||
size_t pos = posIncrement;
|
||||
for (size_t i=0;i < energyElements;i++,pos+=posIncrement)
|
||||
{
|
||||
auto iter = v.begin() + pos + i;
|
||||
v.insert(iter,0.0f);
|
||||
}
|
||||
}
|
||||
foreach_index (k, v)
|
||||
feat(k,t) = v[k];
|
||||
}
|
||||
}
|
||||
// read an entire utterance into an already allocated matrix
|
||||
// Matrix type needs to have operator(i,j)
|
||||
template<class MATRIX> void read (const parsedpath & ppath, const string & kindstr, const unsigned int period, MATRIX & feat)
|
||||
{
|
||||
// open the file and check dimensions
|
||||
size_t numframes = open (ppath);
|
||||
if (feat.cols() != numframes || feat.rows() != featdim)
|
||||
throw std::logic_error ("read: stripe read called with wrong dimensions");
|
||||
if (kindstr != featkind || period != featperiod)
|
||||
throw std::logic_error ("read: attempting to mixing different feature kinds");
|
||||
|
||||
// read vectors from file and push to our target structure
|
||||
try { read (feat, 0, numframes); } catch (...) { close(); throw; }
|
||||
}
|
||||
// read an entire utterance into a virgen, allocatable matrix
|
||||
// Matrix type needs to have operator(i,j) and resize(n,m)
|
||||
template<class MATRIX> void read (const parsedpath & ppath, string & kindstr, unsigned int & period, MATRIX & feat)
|
||||
{
|
||||
// get the file
|
||||
size_t numframes = open (ppath);
|
||||
feat.resize (featdim+energyElements, numframes); // result matrix--columns are features
|
||||
|
||||
// read vectors from file and push to our target structure
|
||||
try { read (feat, 0, numframes); } catch (...) { close(); throw; }
|
||||
|
||||
// return file info
|
||||
kindstr = featkind;
|
||||
period = featperiod;
|
||||
}
|
||||
};
|
||||
|
||||
struct htkmlfentry
|
||||
{
|
||||
unsigned int firstframe; // range [firstframe,firstframe+numframes)
|
||||
unsigned int numframes;
|
||||
//unsigned short classid; // numeric state id
|
||||
unsigned int classid; // numeric state id - mseltzer changed from ushort to uint for untied cd phones > 2^16
|
||||
|
||||
private:
|
||||
// verify and save data
|
||||
void setdata (size_t ts, size_t te, size_t uid)
|
||||
{
|
||||
if (te < ts) throw std::runtime_error ("htkmlfentry: end time below start time??");
|
||||
// save
|
||||
firstframe = (unsigned int) ts;
|
||||
numframes = (unsigned int) (te - ts);
|
||||
classid = (unsigned int) uid;
|
||||
// check for numeric overflow
|
||||
if (firstframe != ts || firstframe + numframes != te || classid != uid)
|
||||
throw std::runtime_error ("htkmlfentry: not enough bits for one of the values");
|
||||
}
|
||||
|
||||
// parse the time range
|
||||
// There are two formats:
|
||||
// - original HTK
|
||||
// - Dong's hacked format: ts te senonename senoneid
|
||||
// We distinguish
|
||||
static void parseframerange (const vector<char*> & toks, size_t & ts, size_t & te, const double htkTimeToFrame)
|
||||
{
|
||||
const double maxFrameNumber = htkTimeToFrame / 2.0; // if frame number is greater than this we assume it is time instead of frame
|
||||
double rts = msra::strfun::todouble (toks[0]);
|
||||
double rte = msra::strfun::todouble (toks[1]);
|
||||
if (rte > maxFrameNumber) // convert time to frame
|
||||
{
|
||||
ts = (size_t) (rts/htkTimeToFrame + 0.5); // get start frame
|
||||
te = (size_t) (rte/htkTimeToFrame + 0.5); // get end frame
|
||||
}
|
||||
else
|
||||
{
|
||||
ts = (size_t)(rts);
|
||||
te = (size_t)(rte);
|
||||
}
|
||||
}
|
||||
|
||||
public:
|
||||
|
||||
// parse format with original HTK state align MLF format and state list
|
||||
void parsewithstatelist (const vector<char*> & toks, const hash_map<std::string, size_t> & statelisthash, const double htkTimeToFrame)
|
||||
{
|
||||
size_t ts, te;
|
||||
parseframerange (toks, ts, te, htkTimeToFrame);
|
||||
auto iter = statelisthash.find (toks[2]);
|
||||
if (iter == statelisthash.end())
|
||||
throw std::runtime_error (msra::strfun::strprintf ("htkmlfentry: state %s not found in statelist", toks[2]));
|
||||
const size_t uid = iter->second; // get state index
|
||||
setdata (ts, te, uid);
|
||||
}
|
||||
|
||||
// ... note: this will be too simplistic for parsing more complex MLF formats. Fix when needed.
|
||||
// add support so that it can handle conditions where time instead of frame numer is used.
|
||||
void parse (const vector<char*> & toks, const double htkTimeToFrame)
|
||||
{
|
||||
if (toks.size() != 4) throw std::runtime_error ("htkmlfentry: currently we only support 4-column format");
|
||||
size_t ts, te;
|
||||
parseframerange (toks, ts, te, htkTimeToFrame);
|
||||
size_t uid = msra::strfun::toint (toks[3]);
|
||||
setdata(ts, te, uid);
|
||||
}
|
||||
};
|
||||
|
||||
template<class ENTRY, class WORDSEQUENCE>
|
||||
class htkmlfreader : public map<wstring,vector<ENTRY>> // [key][i] the data
|
||||
{
|
||||
wstring curpath; // for error messages
|
||||
hash_map<std::string, size_t> statelistmap; // for state <=> index
|
||||
map<wstring,WORDSEQUENCE> wordsequences; // [key] word sequences (if we are building word entries as well, for MMI)
|
||||
|
||||
void strtok (char * s, const char * delim, vector<char*> & toks)
|
||||
{
|
||||
toks.resize (0);
|
||||
char * context = nullptr;
|
||||
for (char * p = strtok_s (s, delim, &context); p; p = strtok_s (NULL, delim, &context))
|
||||
toks.push_back (p);
|
||||
}
|
||||
void malformed (string what)
|
||||
{
|
||||
throw std::runtime_error (msra::strfun::strprintf ("htkmlfreader: %s in '%S'", what.c_str(), curpath.c_str()));
|
||||
}
|
||||
|
||||
vector<char*> readlines (const wstring & path, vector<char> & buffer)
|
||||
{
|
||||
// load it into RAM in one huge chunk
|
||||
auto_file_ptr f (fopenOrDie (path, L"rb"));
|
||||
size_t len = filesize (f);
|
||||
buffer.reserve (len +1);
|
||||
freadOrDie (buffer, len, f);
|
||||
buffer.push_back (0); // this makes it a proper C string
|
||||
|
||||
// parse into lines
|
||||
vector<char *> lines;
|
||||
lines.reserve (len / 20);
|
||||
strtok (&buffer[0], "\r\n", lines);
|
||||
return lines;
|
||||
}
|
||||
|
||||
// determine mlf entry lines range
|
||||
// lines range: [s,e)
|
||||
size_t getnextmlfstart (vector<char*> & lines, size_t s)
|
||||
{
|
||||
// determine lines range
|
||||
size_t e;
|
||||
for (e = s ; ; e++)
|
||||
{
|
||||
if (e >= lines.size()) malformed ("unexpected end in mid-utterance");
|
||||
char * ll = lines[e];
|
||||
if (ll[0] == '.' && ll[1] == 0) // end delimiter: a single dot on a line
|
||||
break;
|
||||
}
|
||||
return (e + 1);
|
||||
// lines range: [s,e)
|
||||
}
|
||||
|
||||
template<typename WORDSYMBOLTABLE, typename UNITSYMBOLTABLE>
|
||||
void parseentry (vector<char*> & lines, size_t & line, const set<wstring> & restricttokeys,
|
||||
const WORDSYMBOLTABLE * wordmap, const UNITSYMBOLTABLE * unitmap, vector<typename WORDSEQUENCE::word> & wordseqbuffer, vector<typename WORDSEQUENCE::aligninfo> & alignseqbuffer,
|
||||
const double htkTimeToFrame)
|
||||
{
|
||||
assert (line < lines.size());
|
||||
string filename = lines[line++];
|
||||
while (filename == "#!MLF!#") // skip embedded duplicate MLF headers (so user can 'cat' MLFs)
|
||||
filename = lines[line++];
|
||||
|
||||
// some mlf file have write errors, so skip malformed entry
|
||||
if (filename.length() < 3 || filename[0] != '"' || filename[filename.length()-1] != '"')
|
||||
{
|
||||
fprintf (stderr, "warning: filename entry (%s)\n", filename.c_str());
|
||||
size_t s = line;
|
||||
line = getnextmlfstart (lines, s);
|
||||
fprintf (stderr, "skip current mlf entry form line (%lu) until line (%lu).\n", s, line);
|
||||
return;
|
||||
}
|
||||
//fprintf (stderr,"start parse %s\n", filename.c_str());
|
||||
|
||||
filename = filename.substr (1, filename.length() -2); // strip quotes
|
||||
if (filename.find ("*/") == 0) filename = filename.substr (2);
|
||||
#ifdef _WIN32
|
||||
wstring key = msra::strfun::utf16 (regex_replace (filename, regex ("\\.[^\\.\\\\/:]*$", std::regex_constants::extended), string())); // delete extension (or not if none)
|
||||
#endif
|
||||
#ifdef __unix__
|
||||
wstring key = msra::strfun::utf16(removeExtension(basename(filename))); // note that c++ 4.8 is incomplete for supporting regex
|
||||
#endif
|
||||
//fwprintf (stderr,L"after parse %S\n",key.c_str());
|
||||
|
||||
// determine lines range
|
||||
size_t s = line;
|
||||
line = getnextmlfstart (lines, line);
|
||||
size_t e = line - 1;
|
||||
// lines range: [s,e)
|
||||
|
||||
// don't parse unused entries (this is supposed to be used for very small debugging setups with huge MLFs)
|
||||
if (!restricttokeys.empty() && restricttokeys.find (key) == restricttokeys.end())
|
||||
return;
|
||||
|
||||
vector<ENTRY> & entries = (*this)[key]; // this creates a new entry
|
||||
if (!entries.empty()) malformed (msra::strfun::strprintf ("duplicate entry '%S'", key.c_str()));
|
||||
entries.resize (e-s);
|
||||
wordseqbuffer.resize (0);
|
||||
alignseqbuffer.resize (0);
|
||||
vector<char*> toks;
|
||||
for (size_t i = s; i < e; i++)
|
||||
{
|
||||
strtok (lines[i], " \t", toks);
|
||||
if (statelistmap.size() == 0)
|
||||
entries[i-s].parse (toks, htkTimeToFrame);
|
||||
else
|
||||
entries[i-s].parsewithstatelist (toks, statelistmap, htkTimeToFrame);
|
||||
// if we also read word entries, do it here
|
||||
if (wordmap)
|
||||
{
|
||||
if (toks.size() > 6/*word entry are in this column*/)
|
||||
{
|
||||
const char * w = toks[6]; // the word name
|
||||
int wid = (*wordmap)[w]; // map to word id --may be -1 for unseen words in the transcript (word list typically comes from a test LM)
|
||||
size_t wordindex = (wid == -1) ? WORDSEQUENCE::word::unknownwordindex : (size_t) wid;
|
||||
wordseqbuffer.push_back (typename WORDSEQUENCE::word (wordindex, entries[i-s].firstframe, alignseqbuffer.size()));
|
||||
}
|
||||
if (unitmap)
|
||||
{
|
||||
if (toks.size() > 4)
|
||||
{
|
||||
const char * u = toks[4]; // the triphone name
|
||||
auto iter = unitmap->find (u); // map to unit id
|
||||
if (iter == unitmap->end())
|
||||
throw std::runtime_error (string ("parseentry: unknown unit ") + u + " in utterance " + strfun::utf8 (key));
|
||||
const size_t uid = iter->second;
|
||||
alignseqbuffer.push_back (typename WORDSEQUENCE::aligninfo (uid, 0/*#frames--we accumulate*/));
|
||||
}
|
||||
if (alignseqbuffer.empty())
|
||||
throw std::runtime_error ("parseentry: lonely senone entry at start without phone/word entry found, for utterance " + strfun::utf8 (key));
|
||||
alignseqbuffer.back().frames += entries[i-s].numframes; // (we do not have an overflow check here, but should...)
|
||||
}
|
||||
}
|
||||
}
|
||||
if (wordmap) // if reading word sequences as well (for MMI), then record it (in a separate map)
|
||||
{
|
||||
if (!entries.empty() && wordseqbuffer.empty())
|
||||
throw std::runtime_error ("parseentry: got state alignment but no word-level info, although being requested, for utterance " + strfun::utf8 (key));
|
||||
// post-process silence
|
||||
// - first !silence -> !sent_start
|
||||
// - last !silence -> !sent_end
|
||||
int silence = (*wordmap)["!silence"];
|
||||
if (silence >= 0)
|
||||
{
|
||||
int sentstart = (*wordmap)["!sent_start"]; // these must have been created
|
||||
int sentend = (*wordmap)["!sent_end"];
|
||||
// map first and last !silence to !sent_start and !sent_end, respectively
|
||||
if (sentstart >= 0 && wordseqbuffer.front().wordindex == (size_t) silence)
|
||||
wordseqbuffer.front().wordindex = sentstart;
|
||||
if (sentend >= 0 && wordseqbuffer.back().wordindex == (size_t) silence)
|
||||
wordseqbuffer.back().wordindex = sentend;
|
||||
}
|
||||
//if (sentstart < 0 || sentend < 0 || silence < 0)
|
||||
// throw std::logic_error ("parseentry: word map must contain !silence, !sent_start, and !sent_end");
|
||||
// implant
|
||||
auto & wordsequence = wordsequences[key]; // this creates the map entry
|
||||
wordsequence.words = wordseqbuffer; // makes a copy
|
||||
wordsequence.align = alignseqbuffer;
|
||||
}
|
||||
}
|
||||
|
||||
public:
|
||||
|
||||
// return if input statename is sil state (hard code to compared first 3 chars with "sil")
|
||||
bool issilstate (const string & statename) const // (later use some configuration table)
|
||||
{
|
||||
return (statename.size() > 3 && statename.at(0) == 's' && statename.at(1) == 'i' && statename.at(2) == 'l');
|
||||
}
|
||||
|
||||
vector<bool> issilstatetable; // [state index] => true if is sil state (cached)
|
||||
|
||||
// return if input stateid represent sil state (by table lookup)
|
||||
bool issilstate (const size_t id) const
|
||||
{
|
||||
assert (id < issilstatetable.size());
|
||||
return issilstatetable[id];
|
||||
}
|
||||
|
||||
struct nullmap { int operator[] (const char * s) const { throw std::logic_error ("nullmap: should never be used"); } }; // to satisfy a template, never used... :(
|
||||
|
||||
// constructor reads multiple MLF files
|
||||
htkmlfreader (const vector<wstring> & paths, const set<wstring> & restricttokeys, const wstring & stateListPath = L"", const double htkTimeToFrame = 100000.0)
|
||||
{
|
||||
// read state list
|
||||
if (stateListPath != L"")
|
||||
readstatelist (stateListPath);
|
||||
|
||||
// read MLF(s) --note: there can be multiple, so this is a loop
|
||||
foreach_index (i, paths)
|
||||
read (paths[i], restricttokeys, (nullmap* /*to satisfy C++ template resolution*/) NULL, (map<string,size_t>*) NULL, htkTimeToFrame);
|
||||
}
|
||||
|
||||
// alternate constructor that optionally also reads word alignments (for MMI training); triggered by providing a 'wordmap'
|
||||
// (We cannot use an optional arg in the constructor aboe because it interferes with teh template resolution.)
|
||||
template<typename WORDSYMBOLTABLE, typename UNITSYMBOLTABLE>
|
||||
htkmlfreader (const vector<wstring> & paths, const set<wstring> & restricttokeys, const wstring & stateListPath, const WORDSYMBOLTABLE * wordmap, const UNITSYMBOLTABLE * unitmap, const double htkTimeToFrame)
|
||||
{
|
||||
// read state list
|
||||
if (stateListPath != L"")
|
||||
readstatelist (stateListPath);
|
||||
|
||||
// read MLF(s) --note: there can be multiple, so this is a loop
|
||||
foreach_index (i, paths)
|
||||
read (paths[i], restricttokeys, wordmap, unitmap, htkTimeToFrame);
|
||||
}
|
||||
|
||||
// note: this function is not designed to be pretty but to be fast
|
||||
template<typename WORDSYMBOLTABLE, typename UNITSYMBOLTABLE>
|
||||
void read (const wstring & path, const set<wstring> & restricttokeys, const WORDSYMBOLTABLE * wordmap, const UNITSYMBOLTABLE * unitmap, const double htkTimeToFrame)
|
||||
{
|
||||
if (!restricttokeys.empty() && this->size() >= restricttokeys.size()) // no need to even read the file if we are there (we support multiple files)
|
||||
return;
|
||||
|
||||
fprintf (stderr, "htkmlfreader: reading MLF file %S ...", path.c_str());
|
||||
curpath = path; // for error messages only
|
||||
|
||||
vector<char> buffer; // buffer owns the characters--don't release until done
|
||||
vector<char*> lines = readlines (path, buffer);
|
||||
vector<typename WORDSEQUENCE::word> wordsequencebuffer;
|
||||
vector<typename WORDSEQUENCE::aligninfo> alignsequencebuffer;
|
||||
|
||||
if (lines.empty() || strcmp (lines[0], "#!MLF!#")) malformed ("header missing");
|
||||
|
||||
// parse entries
|
||||
fprintf (stderr, "parse the line %zu\n", lines.size());
|
||||
size_t line = 1;
|
||||
while (line < lines.size() && (restricttokeys.empty() || this->size() < restricttokeys.size()))
|
||||
parseentry (lines, line, restricttokeys, wordmap, unitmap, wordsequencebuffer, alignsequencebuffer, htkTimeToFrame);
|
||||
|
||||
curpath.clear();
|
||||
fprintf (stderr, " total %lu entries\n", this->size());
|
||||
}
|
||||
|
||||
// read state list, index is from 0
|
||||
void readstatelist (const wstring & stateListPath = L"")
|
||||
{
|
||||
if (stateListPath != L"")
|
||||
{
|
||||
vector<char> buffer; // buffer owns the characters--don't release until done
|
||||
vector<char*> lines = readlines (stateListPath, buffer);
|
||||
size_t index;
|
||||
issilstatetable.reserve (lines.size());
|
||||
for (index = 0; index < lines.size(); index++)
|
||||
{
|
||||
statelistmap[lines[index]] = index;
|
||||
issilstatetable.push_back (issilstate (lines[index]));
|
||||
}
|
||||
if (index != statelistmap.size())
|
||||
throw std::runtime_error (msra::strfun::strprintf ("readstatelist: lines (%d) not equal to statelistmap size (%d)", index, statelistmap.size()));
|
||||
if (statelistmap.size() != issilstatetable.size())
|
||||
throw std::runtime_error (msra::strfun::strprintf ("readstatelist: size of statelookuparray (%d) not equal to statelistmap size (%d)", issilstatetable.size(), statelistmap.size()));
|
||||
fprintf (stderr, "total %lu state names in state list %S\n", statelistmap.size(), stateListPath.c_str());
|
||||
}
|
||||
}
|
||||
|
||||
// return state num: varify the fintune layer dim
|
||||
size_t getstatenum () const
|
||||
{
|
||||
return statelistmap.size();
|
||||
}
|
||||
|
||||
size_t getstateid (string statename) // added by Hang Su adaptation
|
||||
{
|
||||
return statelistmap[statename];
|
||||
}
|
||||
|
||||
// access to word sequences
|
||||
const map<wstring,WORDSEQUENCE> & allwordtranscripts() const { return wordsequences; }
|
||||
};
|
||||
|
||||
};}; // namespaces
|
|
@ -1,743 +0,0 @@
|
|||
//
|
||||
// <copyright file="latticearchive.cpp" company="Microsoft">
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// </copyright>
|
||||
//
|
||||
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "stdafx.h"
|
||||
#include "basetypes.h"
|
||||
#include "fileutil.h"
|
||||
#include "htkfeatio.h" // for MLF reading for numer lattices
|
||||
#include "latticearchive.h"
|
||||
#include "msra_mgram.h" // for MLF reading for numer lattices
|
||||
#include <stdio.h>
|
||||
#include <stdint.h>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <set>
|
||||
#include <hash_map>
|
||||
#include <regex>
|
||||
|
||||
#pragma warning(disable : 4996)
|
||||
namespace msra { namespace lattices {
|
||||
|
||||
// helper to write a symbol hash (string -> int) to a file
|
||||
// File has two sections:
|
||||
// - physicalunitname // line number is mapping, starting with 0
|
||||
// - logunitname physicalunitname // establishes a mapping; logunitname will get the same numeric index as physicalunitname
|
||||
template<class UNITMAP>
|
||||
static void writeunitmap (const wstring & symlistpath, const UNITMAP & unitmap)
|
||||
{
|
||||
std::vector<std::string> units;
|
||||
units.reserve (unitmap.size());
|
||||
std::vector<std::string> mappings;
|
||||
mappings.reserve (unitmap.size());
|
||||
for (auto iter = unitmap.cbegin(); iter != unitmap.cend(); iter++) // why would 'for (auto iter : unitmap)' not work?
|
||||
{
|
||||
const std::string label = iter->first;
|
||||
const size_t unitid = iter->second;
|
||||
if (units.size() <= unitid)
|
||||
units.resize (unitid + 1); // we grow it on demand; the result must be compact (all entries filled), we check that later
|
||||
if (!units[unitid].empty()) // many-to-one mapping: remember the unit; look it up while writing
|
||||
mappings.push_back (label);
|
||||
else
|
||||
units[unitid] = label;
|
||||
}
|
||||
|
||||
auto_file_ptr flist = fopenOrDie (symlistpath, L"wb");
|
||||
// write (physical) units
|
||||
foreach_index (k, units)
|
||||
{
|
||||
if (units[k].empty())
|
||||
throw std::logic_error ("build: unitmap has gaps");
|
||||
fprintfOrDie (flist, "%s\n", units[k].c_str());
|
||||
}
|
||||
// write log-phys mappings
|
||||
foreach_index (k, mappings)
|
||||
{
|
||||
const std::string unit = mappings[k]; // logical name
|
||||
const size_t unitid = unitmap.find (unit)->second; // get its unit id; this indexes the units array
|
||||
const std::string tounit = units[unitid]; // and get the name from tehre
|
||||
fprintfOrDie (flist, "%s %s\n", unit.c_str(), tounit.c_str());
|
||||
}
|
||||
fflushOrDie (flist);
|
||||
}
|
||||
|
||||
// (little helper to do a map::find() with default value)
|
||||
template<typename MAPTYPE, typename KEYTYPE, typename VALTYPE>
|
||||
static size_t tryfind (const MAPTYPE & map, const KEYTYPE & key, VALTYPE deflt)
|
||||
{
|
||||
auto iter = map.find (key);
|
||||
if (iter == map.end())
|
||||
return deflt;
|
||||
else
|
||||
return iter->second;
|
||||
}
|
||||
|
||||
// archive format:
|
||||
// - output files of build():
|
||||
// - OUTPATH --the resulting archive (a huge file), simple concatenation of binary blocks
|
||||
// - OUTPATH.toc --contains keys and offsets; this is how content in archive is found
|
||||
// KEY=ARCHIVE[BYTEOFFSET] // where ARCHIVE can be empty, meaning same as previous
|
||||
// - OUTPATH.symlist --list of all unit names encountered, in order of numeric index used in archive (first = index 0)
|
||||
// This file is suitable as an input to HHEd's AU command.
|
||||
// - in actual use,
|
||||
// - .toc files can be concatenated
|
||||
// - .symlist files must remain paired with the archive file
|
||||
// - for actual training, user also needs to provide, typically from an HHEd AU run:
|
||||
// - OUTPATH.tying --map from triphone units to senone sequence by name; get full phone set from .symlist above
|
||||
// UNITNAME SENONE[2] SENONE[3] SENONE[4]
|
||||
/*static*/ void archive::build (const std::vector<std::wstring> & infiles, const std::wstring & outpath,
|
||||
const std::unordered_map<std::string,size_t> & modelsymmap,
|
||||
const msra::asr::htkmlfreader<msra::asr::htkmlfentry,msra::lattices::lattice::htkmlfwordsequence> & labels, // non-empty: build numer lattices
|
||||
const msra::lm::CMGramLM & unigram, const msra::lm::CSymbolSet & unigramsymbols) // for numer lattices
|
||||
{
|
||||
#if 0 // little unit test helper for testing the read function
|
||||
bool test = true;
|
||||
if (test)
|
||||
{
|
||||
archive a;
|
||||
a.open (outpath + L".toc");
|
||||
lattice L;
|
||||
std::hash_map<string,size_t> symmap;
|
||||
a.getlattice (L"sw2001_A_1263622500_1374610000", L, symmap);
|
||||
a.getlattice (L"sw2001_A_1391162500_1409287500", L, symmap);
|
||||
return;
|
||||
}
|
||||
#endif
|
||||
|
||||
const bool numermode = !labels.empty(); // if labels are passed then we shall convert the MLFs to lattices, and 'infiles' are regular keys
|
||||
|
||||
const std::wstring tocpath = outpath + L".toc";
|
||||
const std::wstring symlistpath = outpath + L".symlist";
|
||||
|
||||
// process all files
|
||||
std::set<std::wstring> seenkeys; // (keep track of seen keys; throw error for duplicate keys)
|
||||
msra::files::make_intermediate_dirs (outpath);
|
||||
|
||||
auto_file_ptr f = fopenOrDie (outpath, L"wb");
|
||||
auto_file_ptr ftoc = fopenOrDie (tocpath, L"wb");
|
||||
size_t brokeninputfiles = 0;
|
||||
foreach_index (i, infiles)
|
||||
{
|
||||
const std::wstring & inlatpath = infiles[i];
|
||||
fprintf (stderr, "build: processing lattice '%S'\n", inlatpath.c_str());
|
||||
|
||||
// get key
|
||||
std::wstring key = regex_replace (inlatpath, wregex (L"=.*"), wstring()); // delete mapping
|
||||
key = regex_replace (key, wregex (L".*[\\\\/]"), wstring()); // delete path
|
||||
key = regex_replace (key, wregex (L"\\.[^\\.\\\\/:]*$"), wstring()); // delete extension (or not if none)
|
||||
if (!seenkeys.insert (key).second)
|
||||
throw std::runtime_error (msra::strfun::strprintf ("build: duplicate key for lattice '%S'", inlatpath.c_str()));
|
||||
|
||||
// we fail all the time due to totally broken HDecode/copy process, OK if not too many files are missing
|
||||
bool latticeread = false;
|
||||
try
|
||||
{
|
||||
// fetch lattice
|
||||
lattice L;
|
||||
if (!numermode)
|
||||
L.fromhtklattice (inlatpath, modelsymmap); // read HTK lattice
|
||||
else
|
||||
L.frommlf (key, modelsymmap, labels, unigram, unigramsymbols); // read MLF into a numerator lattice
|
||||
latticeread = true;
|
||||
|
||||
// write to archive
|
||||
uint64_t offset = fgetpos (f);
|
||||
L.fwrite (f);
|
||||
fflushOrDie (f);
|
||||
|
||||
// write reference to TOC file --note: TOC file is a headerless UTF8 file; so don't use fprintf %S format (default code page)
|
||||
fprintfOrDie (ftoc, "%s=%s[%llu]\n", msra::strfun::utf8 (key).c_str(), ((i - brokeninputfiles) == 0) ? msra::strfun::utf8 (outpath).c_str() : "", offset);
|
||||
fflushOrDie (ftoc);
|
||||
|
||||
fprintf (stderr, "written lattice to offset %llu as '%S'\n", offset, key.c_str());
|
||||
}
|
||||
catch (const exception & e)
|
||||
{
|
||||
if (latticeread) throw; // write failure
|
||||
// we ignore read failures
|
||||
fprintf (stderr, "ERROR: skipping unreadable lattice '%S': %s\n", inlatpath.c_str(), e.what());
|
||||
brokeninputfiles++;
|
||||
}
|
||||
}
|
||||
|
||||
// write out the unit map
|
||||
// TODO: This is sort of redundant now--it gets the symmap from the HMM, i.e. always the same for all archives.
|
||||
writeunitmap (symlistpath, modelsymmap);
|
||||
|
||||
fprintf (stderr, "completed %lu out of %lu lattices (%lu read failures, %.1f%%)\n", infiles.size(), infiles.size()-brokeninputfiles, brokeninputfiles, 100.0f * brokeninputfiles / infiles.size());
|
||||
}
|
||||
|
||||
// helper to set a context value (left, right) with checking of uniqueness
|
||||
void lattice::nodecontext::setcontext (int & lr, int val)
|
||||
{
|
||||
if (lr == unknown)
|
||||
lr = val;
|
||||
else if (lr != val)
|
||||
lr = (signed short) ambiguous;
|
||||
}
|
||||
|
||||
// helper for merge() to determine the unique node contexts
|
||||
vector<lattice::nodecontext> lattice::determinenodecontexts (const msra::asr::simplesenonehmm & hset) const
|
||||
{
|
||||
const size_t spunit = tryfind (hset.getsymmap(), "sp", SIZE_MAX);
|
||||
const size_t silunit = tryfind (hset.getsymmap(), "sil", SIZE_MAX);
|
||||
vector<lattice::nodecontext> nodecontexts (nodes.size());
|
||||
nodecontexts.front().left = nodecontext::start;
|
||||
nodecontexts.front().right = nodecontext::ambiguous; // (should not happen, but won't harm either)
|
||||
nodecontexts.back().right = nodecontext::end;
|
||||
nodecontexts.back().left = nodecontext::ambiguous; // (should not happen--we require !sent_end; but who knows)
|
||||
size_t multispseen = 0; // bad entries with multi-sp
|
||||
foreach_index (j, edges)
|
||||
{
|
||||
const auto & e = edges[j];
|
||||
const size_t S = e.S;
|
||||
const size_t E = e.E;
|
||||
auto a = getaligninfo (j);
|
||||
if (a.size() == 0) // !NULL edge
|
||||
throw std::logic_error ("determinenodecontexts: !NULL edges not allowed in merging, should be removed before");
|
||||
size_t A = a[0].unit;
|
||||
size_t Z = a[a.size()-1].unit;
|
||||
if (Z == spunit)
|
||||
{
|
||||
if (a.size() < 2)
|
||||
throw std::runtime_error ("determinenodecontexts: context-free unit (/sp/) found as a single-phone word");
|
||||
else
|
||||
{
|
||||
Z = a[a.size()-2].unit;
|
||||
if (Z == spunit) // a bugg lattice --I got this from HVite, to be tracked down
|
||||
{
|
||||
// search from end once again, to print a warning
|
||||
int n;
|
||||
for (n = (int) a.size() -1; n >= 0; n--)
|
||||
if (a[n].unit != spunit)
|
||||
break;
|
||||
// ends with n = position of furthest non-sp
|
||||
if (n < 0) // only sp?
|
||||
throw std::runtime_error ("determinenodecontexts: word consists only of /sp/");
|
||||
fprintf (stderr, "determinenodecontexts: word with %lu /sp/ at the end found, edge %d\n", a.size() -1 - n, j);
|
||||
multispseen++;
|
||||
Z = a[n].unit;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (A == spunit || Z == spunit)
|
||||
{
|
||||
#if 0
|
||||
fprintf (stderr, "A=%d Z=%d fa=%d j=%d/N=%d L=%d n=%d totalalign=%d ts/te=%d/%d\n", (int) A, (int) Z, (int) e.firstalign,(int) j, (int) edges.size(), (int) nodes.size(), (int) a.size(), (int) align.size(),
|
||||
nodes[S].t, nodes[E].t);
|
||||
foreach_index (kk, a)
|
||||
fprintf (stderr, "a[%d] = %d\n", kk, a[kk].unit);
|
||||
dump (stderr, [&] (size_t i) { return hset.gethmm (i).getname(); });
|
||||
#endif
|
||||
throw std::runtime_error ("determinenodecontexts: context-free unit (/sp/) found as a start phone or second last phone");
|
||||
}
|
||||
const auto & Ahmm = hset.gethmm (A);
|
||||
const auto & Zhmm = hset.gethmm (Z);
|
||||
int Aid = (int) Ahmm.gettransPindex();
|
||||
int Zid = (int) Zhmm.gettransPindex();
|
||||
nodecontexts[S].setright (Aid);
|
||||
nodecontexts[E].setleft (Zid);
|
||||
}
|
||||
if (multispseen > 0)
|
||||
fprintf (stderr, "determinenodecontexts: %lu broken edges in %lu with multiple /sp/ at the end seen\n", multispseen, edges.size());
|
||||
// check CI conditions and put in 't'
|
||||
// We make the hard assumption that there is only one CI phone, /sil/.
|
||||
const auto & silhmm = hset.gethmm (silunit);
|
||||
int silid = silhmm.gettransPindex();
|
||||
foreach_index (i, nodecontexts)
|
||||
{
|
||||
auto & nc = nodecontexts[i];
|
||||
if ((nc.left == nodecontext::unknown) ^ (nc.right == nodecontext::unknown))
|
||||
throw std::runtime_error ("determinenodecontexts: invalid dead-end node in lattice");
|
||||
if (nc.left == nodecontext::ambiguous && nc.right != silid && nc.right != nodecontext::end)
|
||||
throw std::runtime_error ("determinenodecontexts: invalid ambiguous left context (right context is not CI)");
|
||||
if (nc.right == nodecontext::ambiguous && nc.left != silid && nc.left != nodecontext::start)
|
||||
throw std::runtime_error ("determinenodecontexts: invalid ambiguous right context (left context is not CI)");
|
||||
nc.t = nodes[i].t;
|
||||
}
|
||||
return nodecontexts; // (will this use a move constructor??)
|
||||
}
|
||||
|
||||
// compar function for sorting and merging
|
||||
bool lattice::nodecontext::operator< (const nodecontext & other) const
|
||||
{
|
||||
// sort by t, left, right, i --sort by i to make i appear before iother, as assumed in merge function
|
||||
int diff = (int) t - (int) other.t;
|
||||
if (diff == 0)
|
||||
{
|
||||
diff = left - other.left;
|
||||
if (diff == 0)
|
||||
{
|
||||
diff = right - other.right;
|
||||
if (diff == 0)
|
||||
return i < other.i; // (cannot use 'diff=' pattern since unsigned but may be SIZE_MAX)
|
||||
}
|
||||
}
|
||||
return diff < 0;
|
||||
}
|
||||
|
||||
// remove that final !NULL edge
|
||||
// We have that in HAPI lattices, but there can be only one at the end.
|
||||
void lattice::removefinalnull()
|
||||
{
|
||||
const auto & lastedge = edges.back();
|
||||
// last edge can be !NULL, recognized as having 0 alignment records
|
||||
if (lastedge.firstalign < align.size()) // has alignment records --not !NULL
|
||||
return;
|
||||
if (lastedge.S != nodes.size() -2 || lastedge.E != nodes.size() -1)
|
||||
throw std::runtime_error ("removefinalnull: malformed final !NULL edge");
|
||||
edges.resize (edges.size() -1); // remove it
|
||||
nodes.resize (nodes.size() -1); // its start node is now the new end node
|
||||
foreach_index (j, edges)
|
||||
if (edges[j].E >= nodes.size())
|
||||
throw std::runtime_error ("removefinalnull: cannot have final !NULL edge and other edges connecting to end node at the same time");
|
||||
}
|
||||
|
||||
// merge a secondary lattice into the first
|
||||
// With lots of caveats:
|
||||
// - this optimizes lattices to true unigram lattices where the only unique node condition is acoustic context
|
||||
// - no !NULL edge at the end, call removefinalnull() before
|
||||
// - this function returns an unsorted edges[] array, i.e. invalid. We sort in uniq'ed representation, which is easier.
|
||||
// This function is not elegant at all, just hard labor!
|
||||
void lattice::merge (const lattice & other, const msra::asr::simplesenonehmm & hset)
|
||||
{
|
||||
if (!edges2.empty() || !other.edges2.empty())
|
||||
throw std::logic_error ("merge: lattice(s) must be in non-uniq'ed format (V1)");
|
||||
if (!info.numframes || !other.info.numframes)
|
||||
throw std::logic_error ("merge: lattice(s) must have identical number of frames");
|
||||
|
||||
// establish node contexts
|
||||
auto contexts = determinenodecontexts (hset);
|
||||
auto othercontexts = other.determinenodecontexts (hset);
|
||||
|
||||
// create joint node space and node mapping
|
||||
// This also collapses non-unique nodes.
|
||||
// Note the edge case sil-sil in one lattice which may be sil-ambiguous or ambiguous-sil on the other.
|
||||
// We ignore this, keeping such nodes unmerged. That's OK since middle /sil/ words have zero LM, and thus it's OK to keep them non-connected.
|
||||
foreach_index (i, contexts) contexts[i].i = i;
|
||||
foreach_index (i, othercontexts) othercontexts[i].iother = i;
|
||||
contexts.insert (contexts.end(), othercontexts.begin(), othercontexts.end()); // append othercontext
|
||||
sort (contexts.begin(), contexts.end());
|
||||
vector<size_t> nodemap (nodes.size(), SIZE_MAX);
|
||||
vector<size_t> othernodemap (other.nodes.size(), SIZE_MAX);
|
||||
int j = 0;
|
||||
foreach_index (i, contexts) // merge identical nodes --this is the critical step
|
||||
{
|
||||
if (j == 0 || contexts[j-1].t != contexts[i].t || contexts[j-1].left != contexts[i].left || contexts[j-1].right != contexts[i].right)
|
||||
contexts[j++] = contexts[i]; // entered a new one
|
||||
// node map
|
||||
if (contexts[i].i != SIZE_MAX)
|
||||
nodemap[contexts[i].i] = j-1;
|
||||
if (contexts[i].iother != SIZE_MAX)
|
||||
othernodemap[contexts[i].iother] = j-1;
|
||||
}
|
||||
fprintf (stderr, "merge: joint node space uniq'ed to %d from %d\n", j, contexts.size());
|
||||
contexts.resize (j);
|
||||
|
||||
// create a new node array (just copy the contexts[].t fields)
|
||||
nodes.resize (contexts.size());
|
||||
foreach_index (inew, nodes)
|
||||
nodes[inew].t = (unsigned short) contexts[inew].t;
|
||||
info.numnodes = nodes.size();
|
||||
|
||||
// incorporate the alignment records
|
||||
const size_t alignoffset = align.size();
|
||||
align.insert (align.end(), other.align.begin(), other.align.end());
|
||||
|
||||
// map existing edges' S and E fields, and also 'firstalign'
|
||||
foreach_index (j, edges)
|
||||
{
|
||||
edges[j].S = nodemap[edges[j].S];
|
||||
edges[j].E = nodemap[edges[j].E];
|
||||
}
|
||||
auto otheredges = other.edges;
|
||||
foreach_index (j, otheredges)
|
||||
{
|
||||
otheredges[j].S = othernodemap[otheredges[j].S];
|
||||
otheredges[j].E = othernodemap[otheredges[j].E];
|
||||
otheredges[j].firstalign += alignoffset; // that's where they are now
|
||||
}
|
||||
|
||||
// at this point, a new 'nodes' array exists, and the edges already are w.r.t. the new node space and align space
|
||||
|
||||
// now we are read to merge 'other' edges into this, simply by concatenation
|
||||
edges.insert (edges.end(), otheredges.begin(), otheredges.end());
|
||||
|
||||
// remove acoustic scores --they are likely not identical if they come from different decoders
|
||||
// If we don't do that, this will break the sorting in builduniquealignments()
|
||||
info.hasacscores = 0;
|
||||
foreach_index (j, edges)
|
||||
edges[j].a = 0.0f;
|
||||
|
||||
// Note: we have NOT sorted or de-duplicated yet. That is best done after conversion to the uniq'ed format.
|
||||
}
|
||||
|
||||
// remove duplicates
|
||||
// This must be called in uniq'ed format.
|
||||
void lattice::dedup()
|
||||
{
|
||||
if (edges2.empty())
|
||||
throw std::logic_error ("dedup: lattice must be in uniq'ed format (V2)");
|
||||
|
||||
size_t k = 0;
|
||||
foreach_index (j, edges2)
|
||||
{
|
||||
if (k > 0 && edges2[k-1].S == edges2[j].S && edges2[k-1].E == edges2[j].E && edges2[k-1].firstalign == edges2[j].firstalign)
|
||||
{
|
||||
if (edges2[k-1].implysp != edges2[j].implysp)
|
||||
throw std::logic_error ("dedup: inconsistent 'implysp' flag for otherwise identical edges");
|
||||
continue;
|
||||
}
|
||||
edges2[k++] = edges2[j];
|
||||
}
|
||||
fprintf (stderr, "dedup: edges reduced to %d from %d\n", k, edges2.size());
|
||||
edges2.resize (k);
|
||||
info.numedges = edges2.size();
|
||||
edges.clear(); // (should already be, but isn't; make sure we no longer use it)
|
||||
}
|
||||
|
||||
// load all lattices from a TOC file and write them to a new archive
|
||||
// Use this to
|
||||
// - upgrade the file format to latest in case of format changes
|
||||
// - check consistency (read only; don't write out)
|
||||
// - dump to stdout
|
||||
// - merge two lattices (for merging numer into denom lattices)
|
||||
// Input path is an actual TOC path, output is the stem (.TOC will be added). --yes, not nice, maybe fix it later
|
||||
// Example command:
|
||||
// convertlatticearchive --latticetocs dummy c:\smbrdebug\sw20_small.den.lats.toc.10 -w c:\smbrdebug\sw20_small.den.lats.converted --cdphonetying c:\smbrdebug\combined.tying --statelist c:\smbrdebug\swb300h.9304.aligned.statelist --transprobs c:\smbrdebug\MMF.9304.transprobs
|
||||
// How to regenerate from my test lattices:
|
||||
// buildlatticearchive c:\smbrdebug\sw20_small.den.lats.regenerated c:\smbrdebug\hvitelat\*lat
|
||||
// We support two special output path syntaxs:
|
||||
// - empty ("") -> don't output, just check the format
|
||||
// - dash ("-") -> dump lattice to stdout instead
|
||||
/*static*/ void archive::convert (const std::wstring & intocpath, const std::wstring & intocpath2, const std::wstring & outpath,
|
||||
const msra::asr::simplesenonehmm & hset)
|
||||
{
|
||||
const auto & modelsymmap = hset.getsymmap();
|
||||
|
||||
const std::wstring tocpath = outpath + L".toc";
|
||||
const std::wstring symlistpath = outpath + L".symlist";
|
||||
|
||||
// open input archive
|
||||
// TODO: I find that HVite emits redundant physical triphones, and even HHEd seems so (in .tying file).
|
||||
// Thus, we should uniq the units before sorting. We can do that here if we have the .tying file.
|
||||
// And then use the modelsymmap to map them down.
|
||||
// Do this directly in the hset module (it will be transparent).
|
||||
std::vector<std::wstring> intocpaths (1, intocpath); // set of paths consisting of 1
|
||||
msra::lattices::archive archive (intocpaths, modelsymmap);
|
||||
|
||||
// secondary archive for optional merging operation
|
||||
const bool mergemode = !intocpath2.empty(); // true if merging two lattices
|
||||
std::vector<std::wstring> intocpaths2;
|
||||
if (mergemode)
|
||||
intocpaths2.push_back (intocpath2);
|
||||
msra::lattices::archive archive2 (intocpaths2, modelsymmap); // (if no merging then this archive2 is empty)
|
||||
|
||||
// read the intocpath file once again to get the keys in original order
|
||||
std::vector<char> textbuffer;
|
||||
auto toclines = msra::files::fgetfilelines (intocpath, textbuffer);
|
||||
|
||||
auto_file_ptr f = NULL;
|
||||
auto_file_ptr ftoc = NULL;
|
||||
|
||||
// process all files
|
||||
if (outpath != L"" && outpath != L"-") // test for special syntaxes that bypass to actually create an output archive
|
||||
{
|
||||
msra::files::make_intermediate_dirs (outpath);
|
||||
f = fopenOrDie (outpath, L"wb");
|
||||
ftoc = fopenOrDie (tocpath, L"wb");
|
||||
}
|
||||
vector<const char *> invmodelsymmap; // only used for dump() mode
|
||||
|
||||
// we must parse the toc file once again to get the keys in original order
|
||||
size_t skippedmerges = 0;
|
||||
foreach_index (i, toclines)
|
||||
{
|
||||
const char * line = toclines[i];
|
||||
const char * p = strchr (line, '=');
|
||||
if (p == NULL)
|
||||
throw std::runtime_error ("open: invalid TOC line (no = sign): " + std::string (line));
|
||||
const std::wstring key = msra::strfun::utf16 (std::string (line, p - line));
|
||||
|
||||
fprintf (stderr, "convert: processing lattice '%S'\n", key.c_str());
|
||||
|
||||
// fetch lattice --this performs any necessary format conversions already
|
||||
lattice L;
|
||||
archive.getlattice (key, L);
|
||||
|
||||
lattice L2;
|
||||
if (mergemode)
|
||||
{
|
||||
if (!archive2.haslattice (key))
|
||||
{
|
||||
fprintf (stderr, "convert: cannot merge because lattice '%S' missing in secondary archive; skipping\n", key.c_str());
|
||||
skippedmerges++;
|
||||
continue;
|
||||
}
|
||||
archive2.getlattice (key, L2);
|
||||
|
||||
// merge it in
|
||||
// This will connect each node with matching 1-phone context conditions; aimed at merging numer lattices.
|
||||
L.removefinalnull(); // get rid of that final !NULL headache
|
||||
L2.removefinalnull();
|
||||
L.merge (L2, hset);
|
||||
// note: we are left with dups due to true unigram merging (HTK lattices cannot represent true unigram lattices since id is on the nodes)
|
||||
}
|
||||
//L.removefinalnull();
|
||||
//L.determinenodecontexts (hset);
|
||||
|
||||
// convert it --TODO: once we permanently use the new format, do this in fread() for V1
|
||||
// Note: Merging may have left this in unsorted format; we need to be robust against that.
|
||||
const size_t spunit = tryfind (modelsymmap, "sp", SIZE_MAX);
|
||||
L.builduniquealignments (spunit);
|
||||
|
||||
if (mergemode)
|
||||
L.dedup();
|
||||
|
||||
if (f && ftoc)
|
||||
{
|
||||
// write to archive
|
||||
uint64_t offset = fgetpos (f);
|
||||
L.fwrite (f);
|
||||
fflushOrDie (f);
|
||||
|
||||
// write reference to TOC file --note: TOC file is a headerless UTF8 file; so don't use fprintf %S format (default code page)
|
||||
fprintfOrDie (ftoc, "%s=%s[%llu]\n", msra::strfun::utf8 (key).c_str(), (i == 0) ? msra::strfun::utf8 (outpath).c_str() : "", offset);
|
||||
fflushOrDie (ftoc);
|
||||
|
||||
fprintf (stderr, "written converted lattice to offset %llu as '%S'\n", offset, key.c_str());
|
||||
}
|
||||
else if (outpath == L"-")
|
||||
{
|
||||
if (invmodelsymmap.empty()) // build this lazily
|
||||
{
|
||||
invmodelsymmap.resize (modelsymmap.size());
|
||||
for (auto iter = modelsymmap.begin(); iter != modelsymmap.end(); iter++)
|
||||
invmodelsymmap[iter->second] = iter->first.c_str();
|
||||
}
|
||||
L.rebuildedges (false);
|
||||
L.dump (stdout, [&] (size_t i) { return invmodelsymmap[i]; } );
|
||||
}
|
||||
} // end for (toclines)
|
||||
if (skippedmerges > 0)
|
||||
fprintf (stderr, "convert: %d out of %d merge operations skipped due to secondary lattice missing\n", skippedmerges, toclines.size());
|
||||
|
||||
// write out the updated unit map
|
||||
if (f && ftoc)
|
||||
writeunitmap (symlistpath, modelsymmap);
|
||||
|
||||
fprintf (stderr, "converted %d lattices\n", toclines.size());
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// reading lattices from external formats (HTK lat, MLF)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// read an HTK lattice
|
||||
// The lattice is expected to be freshly constructed (I did not bother to check).
|
||||
void lattice::fromhtklattice (const wstring & path, const std::unordered_map<std::string,size_t> & unitmap)
|
||||
{
|
||||
vector<char> textbuffer;
|
||||
auto lines = msra::files::fgetfilelines (path, textbuffer);
|
||||
if (lines.empty())
|
||||
throw std::runtime_error ("lattice: mal-formed lattice--empty input file (or all-zeroes)");
|
||||
auto iter = lines.begin();
|
||||
// parse out LMF and WP
|
||||
char dummychar = 0; // dummy for sscanf() end checking
|
||||
for ( ; iter != lines.end() && strncmp (*iter, "N=", 2); iter++)
|
||||
{
|
||||
if (strncmp (*iter, "lmscale=", 8) == 0) // note: HTK sometimes generates extra garbage space at the end of this line
|
||||
if (sscanf_s (*iter, "lmscale=%f wdpenalty=%f%c", &info.lmf, &info.wp, &dummychar, sizeof (dummychar)) != 2 && dummychar != ' ')
|
||||
throw std::runtime_error ("lattice: mal-formed lmscale/wdpenalty line in lattice: " + string (*iter));
|
||||
}
|
||||
|
||||
// parse N and L
|
||||
if (iter != lines.end())
|
||||
{
|
||||
unsigned long N, L;
|
||||
if (sscanf_s (*iter, "N=%lu L=%lu %c", &N, &L, &dummychar, sizeof (dummychar)) != 2)
|
||||
throw std::runtime_error ("lattice: mal-formed N=/L= line in lattice: " + string (*iter));
|
||||
info.numnodes = N;
|
||||
info.numedges = L;
|
||||
iter++;
|
||||
}
|
||||
else
|
||||
throw std::runtime_error ("lattice: mal-formed before parse N=/L= line in lattice.");
|
||||
|
||||
ASSERT(info.numnodes > 0);
|
||||
nodes.reserve (info.numnodes);
|
||||
// parse the nodes
|
||||
for (size_t i = 0; i < info.numnodes; i++, iter++)
|
||||
{
|
||||
if (iter == lines.end())
|
||||
throw std::runtime_error ("lattice: not enough I lines in lattice");
|
||||
unsigned long itest;
|
||||
float t;
|
||||
if (sscanf_s (*iter, "I=%lu t=%f%c", &itest, &t, &dummychar, sizeof (dummychar)) < 2)
|
||||
throw std::runtime_error ("lattice: mal-formed node line in lattice: " + string (*iter));
|
||||
if (i != (size_t) itest)
|
||||
throw std::runtime_error ("lattice: out-of-sequence node line in lattice: " + string (*iter));
|
||||
nodes.push_back (nodeinfo ((unsigned int) (t / info.frameduration + 0.5)));
|
||||
info.numframes = max (info.numframes, (size_t) nodes.back().t);
|
||||
}
|
||||
// parse the edges
|
||||
ASSERT(info.numedges > 0);
|
||||
edges.reserve (info.numedges);
|
||||
align.reserve (info.numedges * 10); // 10 phones per word on av. should be enough
|
||||
std::string label;
|
||||
for (size_t j = 0; j < info.numedges; j++, iter++)
|
||||
{
|
||||
if (iter == lines.end())
|
||||
throw std::runtime_error ("lattice: not enough J lines in lattice");
|
||||
unsigned long jtest;
|
||||
unsigned long S, E;
|
||||
float a, l;
|
||||
char d[1024];
|
||||
// example:
|
||||
// J=12 S=1 E=13 a=-326.81 l=-5.090 d=:sil-t:s+k:e,0.03:dh:m-ax:m+sil,0.03:sil,0.02:
|
||||
int nvals = sscanf_s (*iter, "J=%lu S=%lu E=%lu a=%f l=%f d=%s", &jtest, &S, &E, &a, &l, &d, sizeof (d));
|
||||
if (nvals == 5 && j == info.numedges - 1) // special case: last edge is a !NULL and thus may have the d= record missing
|
||||
strcpy (d, ":");
|
||||
else if (nvals != 6)
|
||||
throw std::runtime_error ("lattice: mal-formed edge line in lattice: " + string (*iter));
|
||||
if (j != (size_t) jtest)
|
||||
throw std::runtime_error ("lattice: out-of-sequence edge line in lattice: " + string (*iter));
|
||||
edges.push_back (edgeinfowithscores (S, E, a, l, align.size()));
|
||||
// build align array
|
||||
size_t edgeframes = 0; // (for checking whether the alignment sums up right)
|
||||
const char * p = d;
|
||||
if (p[0] != ':' || (p[1] == 0 && j < info.numedges-1)) // last edge may be empty
|
||||
throw std::runtime_error ("lattice: alignment info must start with a colon and must have at least one entry: " + string (*iter));
|
||||
p++;
|
||||
while (*p)
|
||||
{
|
||||
// p points to an entry of the form TRIPHONE,DURATION
|
||||
const char * q = strchr (p, ',');
|
||||
if (q == NULL)
|
||||
throw std::runtime_error ("lattice: alignment entry lacking a comma: " + string (*iter));
|
||||
if (q == p)
|
||||
throw std::runtime_error ("lattice: alignment entry label empty: " + string (*iter));
|
||||
label.assign (p, q-p); // the triphone label
|
||||
q++;
|
||||
char * ep;
|
||||
double duration = strtod (q, &ep); // (weird--returns a non-const ptr in ep to a const object)
|
||||
p = ep;
|
||||
if (*p != ':')
|
||||
throw std::runtime_error ("lattice: alignment entry not ending with a colon: " + string (*iter));
|
||||
p++;
|
||||
// create the alignment entry
|
||||
const size_t frames = (unsigned int) (duration / info.frameduration + 0.5);
|
||||
auto it = unitmap.find (label);
|
||||
if (it == unitmap.end())
|
||||
throw std::runtime_error ("lattice: unit in alignment that is not in model: " + label);
|
||||
const size_t unitid = it->second;
|
||||
//const size_t unitid = unitmap.insert (make_pair (label, unitmap.size())).first->second; // may create a new entry with index = #entries
|
||||
align.push_back (aligninfo (unitid, frames));
|
||||
edgeframes += frames;
|
||||
}
|
||||
if (edgeframes != nodes[E].t - (size_t) nodes[S].t)
|
||||
{
|
||||
char msg[128];
|
||||
sprintf (msg, "\n-- where edgeframes=%d != (nodes[E].t - nodes[S].t=%d), the gap is %d.", edgeframes, nodes[E].t - (size_t) nodes[S].t, edgeframes + nodes[S].t - nodes[E].t);
|
||||
throw std::runtime_error ("lattice: alignment info duration mismatches edge duration: " + string (*iter) + msg);
|
||||
}
|
||||
}
|
||||
if (iter != lines.end())
|
||||
throw std::runtime_error ("lattice: unexpected garbage at end of lattice: " + string (*iter));
|
||||
checklattice();
|
||||
|
||||
// create more efficient storage for alignments
|
||||
const size_t spunit = tryfind (unitmap, "sp", SIZE_MAX);
|
||||
builduniquealignments (spunit);
|
||||
|
||||
showstats();
|
||||
}
|
||||
|
||||
// construct a numerator lattice from an MLF entry
|
||||
// The lattice is expected to be freshly constructed (I did not bother to check).
|
||||
void lattice::frommlf (const wstring & key, const std::unordered_map<std::string,size_t> & unitmap,
|
||||
const msra::asr::htkmlfreader<msra::asr::htkmlfentry,lattice::htkmlfwordsequence> & labels,
|
||||
const msra::lm::CMGramLM & unigram, const msra::lm::CSymbolSet & unigramsymbols)
|
||||
{
|
||||
const auto & transcripts = labels.allwordtranscripts(); // (TODO: we could just pass the transcripts map--does not really matter)
|
||||
|
||||
// get the labels (state and word)
|
||||
auto iter = transcripts.find (key);
|
||||
if (iter == transcripts.end())
|
||||
throw std::runtime_error ("frommlf: no reference word sequence in MLF for lattice with key " + strfun::utf8 (key));
|
||||
const auto & transcript = iter->second;
|
||||
if (transcript.words.size() == 0)
|
||||
throw std::runtime_error ("frommlf: empty reference word sequence for lattice with key " + strfun::utf8 (key));
|
||||
|
||||
// determine unigram scores for all words
|
||||
vector<float> lmscores (transcript.words.size());
|
||||
size_t silence = unigramsymbols["!silence"];
|
||||
size_t lmend = unigramsymbols["</s>"];
|
||||
size_t sentstart = unigramsymbols["!sent_start"];
|
||||
size_t sentend = unigramsymbols["!sent_end"];
|
||||
|
||||
// create the lattice
|
||||
nodes.resize (transcript.words.size() +1);
|
||||
edges.resize (transcript.words.size());
|
||||
align.reserve (transcript.align.size());
|
||||
size_t numframes = 0;
|
||||
foreach_index (j, transcript.words)
|
||||
{
|
||||
const auto & w = transcript.words[j];
|
||||
nodes[j].t = w.firstframe;
|
||||
auto & e = edges[j];
|
||||
e.unused = 0;
|
||||
e.S = j;
|
||||
e.E = j+1;
|
||||
if (e.E != j+1)
|
||||
throw std::runtime_error (msra::strfun::strprintf ("frommlf: too many tokens to be represented as edgeinfo::E in label set: %S", key.c_str()));
|
||||
e.a = 0.0f; // no ac score
|
||||
|
||||
// LM score
|
||||
// !sent_start and !silence are patched to LM score 0
|
||||
size_t wid = w.wordindex;
|
||||
if (wid == sentstart)
|
||||
{
|
||||
if (j != 0)
|
||||
throw std::logic_error ("frommlf: found an !sent_start token not at the first position");
|
||||
}
|
||||
else if (wid == sentend)
|
||||
{
|
||||
if (j != (int) transcript.words.size()-1)
|
||||
throw std::logic_error ("frommlf: found an !sent_end token not at the end position");
|
||||
wid = lmend; // use </s> for score lookup
|
||||
}
|
||||
const int iwid = (int) wid;
|
||||
e.l = (wid != sentstart && wid != silence) ? (float) unigram.score (&iwid, 1) : 0.0f;
|
||||
|
||||
// alignment
|
||||
e.implysp = 0;
|
||||
e.firstalign = align.size();
|
||||
auto a = transcript.getaligninfo (j);
|
||||
align.insert (align.end(), a.begin(), a.end());
|
||||
foreach_index (k, a)
|
||||
numframes += a[k].frames;
|
||||
}
|
||||
nodes[transcript.words.size()].t = (unsigned short) numframes;
|
||||
if (nodes[transcript.words.size()].t != numframes)
|
||||
throw std::runtime_error (msra::strfun::strprintf ("frommlf: too many frames to be represented as nodeinfo::t in label set: %S", key.c_str()));
|
||||
info.lmf = -1.0f; // indicates not set
|
||||
info.wp = 0.0f; // not set indicated by lmf < 0
|
||||
info.numedges = edges.size();
|
||||
info.numnodes = nodes.size();
|
||||
info.numframes = numframes;
|
||||
checklattice();
|
||||
|
||||
// create more efficient storage for alignments
|
||||
const size_t spunit = tryfind (unitmap, "sp", SIZE_MAX);
|
||||
builduniquealignments (spunit);
|
||||
|
||||
showstats();
|
||||
}
|
||||
|
||||
};};
|
Разница между файлами не показана из-за своего большого размера
Загрузить разницу
|
@ -1,119 +0,0 @@
|
|||
//
|
||||
// <copyright file="latticestorage.h" company="Microsoft">
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// </copyright>
|
||||
//
|
||||
// latticestorage.h -- basic data structures for storing lattices
|
||||
|
||||
|
||||
#if 0 // [v-hansu] separate code with history
|
||||
#endif
|
||||
|
||||
#pragma once
|
||||
#include <string> // for the error message in checkoverflow() only
|
||||
#include <stdexcept>
|
||||
#include <stdint.h>
|
||||
|
||||
#undef INITIAL_STRANGE // [v-hansu] intialize structs to strange values
|
||||
#define PARALLEL_SIL // [v-hansu] process sil on CUDA, used in other files, please search this
|
||||
#define LOGZERO -1e30f
|
||||
|
||||
namespace msra { namespace lattices {
|
||||
|
||||
static void checkoverflow (size_t fieldval, size_t targetval, const char * fieldname)
|
||||
{
|
||||
if (fieldval != targetval)
|
||||
{
|
||||
char buf[1000];
|
||||
sprintf_s (buf, "lattice: bit field %s too small for value 0x%zu (cut from 0x%zu)", fieldname, targetval, fieldval);
|
||||
throw std::runtime_error (buf);
|
||||
}
|
||||
}
|
||||
|
||||
struct nodeinfo
|
||||
{
|
||||
//uint64_t firstinedge : 24; // index of first incoming edge
|
||||
//uint64_t firstoutedge : 24; // index of first outgoing edge
|
||||
//uint64_t t : 16; // time associated with this
|
||||
unsigned short t; // time associated with this
|
||||
nodeinfo (size_t pt) : t ((unsigned short) pt) //, firstinedge (NOEDGE), firstoutedge (NOEDGE)
|
||||
{
|
||||
checkoverflow (t, pt, "nodeinfo::t");
|
||||
//checkoverflow (firstinedge, NOEDGE, "nodeinfo::firstinedge");
|
||||
//checkoverflow (firstoutedge, NOEDGE, "nodeinfo::firstoutedge");
|
||||
}
|
||||
nodeinfo() // [v-hansu] initialize to impossible values
|
||||
{
|
||||
#ifdef INITIAL_STRANGE
|
||||
t = unsigned short (-1);
|
||||
#endif
|
||||
}
|
||||
};
|
||||
// V2 format: a and l are stored in separate vectors
|
||||
struct edgeinfo
|
||||
{
|
||||
uint64_t S : 19; // start node
|
||||
uint64_t unused : 1; // (for future use)
|
||||
uint64_t E : 19; // end node
|
||||
uint64_t implysp : 1; // 1--alignment ends with a /sp/ that is not stored
|
||||
uint64_t firstalign : 24; // index into align for first entry; end is firstalign of next edge
|
||||
edgeinfo (size_t pS, size_t pE, size_t pfirstalign) : S (pS), E (pE), firstalign (pfirstalign), unused (0), implysp (0)
|
||||
{
|
||||
checkoverflow (S, pS, "edgeinfowithscores::S");
|
||||
checkoverflow (E, pE, "edgeinfowithscores::E");
|
||||
checkoverflow (firstalign, pfirstalign, "edgeinfowithscores::firstalign");
|
||||
}
|
||||
edgeinfo() // [v-hansu] initialize to impossible values
|
||||
{
|
||||
#ifdef INITIAL_STRANGE
|
||||
S = uint64_t (-1);
|
||||
unused = uint64_t (-1);
|
||||
E = uint64_t (-1);
|
||||
implysp = uint64_t (-1);
|
||||
firstalign = uint64_t (-1);
|
||||
#endif
|
||||
}
|
||||
};
|
||||
// V1 format: a and l are included in the edge itself
|
||||
struct edgeinfowithscores : edgeinfo
|
||||
{
|
||||
float a;
|
||||
float l;
|
||||
edgeinfowithscores (size_t pS, size_t pE, float a, float l, size_t pfirstalign) : edgeinfo (pS, pE, pfirstalign), a(a), l(l) {}
|
||||
edgeinfowithscores() // [v-hansu] initialize to impossible values
|
||||
{
|
||||
#ifdef INITIAL_STRANGE
|
||||
a = LOGZERO;
|
||||
l = LOGZERO;
|
||||
#endif
|
||||
}
|
||||
};
|
||||
struct aligninfo // phonetic alignment
|
||||
{
|
||||
unsigned int unit : 19; // triphone index
|
||||
unsigned int frames : 11; // duration in frames
|
||||
// note: V1 did not have the following, which were instead the two 2 bits of 'frames'
|
||||
unsigned int unused : 1; // (for future use)
|
||||
unsigned int last : 1; // set for last entry
|
||||
aligninfo (size_t punit, size_t pframes) : unit ((unsigned int) punit), frames ((unsigned int) pframes), unused (0), last (0)
|
||||
{
|
||||
checkoverflow (unit, punit, "aligninfo::unit");
|
||||
checkoverflow (frames, pframes, "aligninfo::frames");
|
||||
}
|
||||
aligninfo() // [v-hansu] initialize to impossible values
|
||||
{
|
||||
#ifdef INITIAL_STRANGE
|
||||
unit = unsigned int (-1);
|
||||
frames = unsigned int (-1);
|
||||
unused = unsigned int (-1);
|
||||
last = unsigned int (-1);
|
||||
#endif
|
||||
}
|
||||
template<class IDMAP> void updateunit (const IDMAP & idmap/*[unit] -> new unit*/) // update 'unit' w.r.t. a different mapping, with bit-field overflow check
|
||||
{
|
||||
const size_t mappedunit = idmap[unit];
|
||||
unit = (unsigned int) mappedunit;
|
||||
checkoverflow (unit, mappedunit, "aligninfo::unit");
|
||||
}
|
||||
};
|
||||
};};
|
|
@ -1,299 +0,0 @@
|
|||
//
|
||||
// <copyright file="minibatchiterator.h" company="Microsoft">
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// </copyright>
|
||||
//
|
||||
// minibatchiterator.h -- iterator for minibatches
|
||||
|
||||
|
||||
#pragma once
|
||||
#define NONUMLATTICEMMI // [v-hansu] move from main.cpp, no numerator lattice for mmi training
|
||||
|
||||
#include <vector>
|
||||
#include <unordered_map>
|
||||
#include "ssematrix.h"
|
||||
#include "latticearchive.h" // for reading HTK phoneme lattices (MMI training)
|
||||
#include "simple_checked_arrays.h" // for const_array_ref
|
||||
|
||||
namespace msra { namespace dbn {
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// latticesource -- manages loading of lattices for MMI (in pairs for numer and denom)
|
||||
// ---------------------------------------------------------------------------
|
||||
class latticesource
|
||||
{
|
||||
const msra::lattices::archive numlattices, denlattices;
|
||||
public:
|
||||
latticesource (std::pair<std::vector<wstring>,std::vector<wstring>> latticetocs, const std::unordered_map<std::string,size_t> & modelsymmap)
|
||||
: numlattices (latticetocs.first, modelsymmap), denlattices (latticetocs.second, modelsymmap) {}
|
||||
|
||||
bool empty() const
|
||||
{
|
||||
#ifndef NONUMLATTICEMMI // TODO:set NUM lattice to null so as to save memory
|
||||
if (numlattices.empty() ^ denlattices.empty())
|
||||
throw std::runtime_error("latticesource: numerator and denominator lattices must be either both empty or both not empty");
|
||||
#endif
|
||||
return denlattices.empty();
|
||||
}
|
||||
|
||||
bool haslattice (wstring key) const
|
||||
{
|
||||
#ifdef NONUMLATTICEMMI
|
||||
return denlattices.haslattice (key);
|
||||
#else
|
||||
return numlattices.haslattice (key) && denlattices.haslattice (key);
|
||||
#endif
|
||||
}
|
||||
|
||||
class latticepair : public pair<msra::lattices::lattice,msra::lattices::lattice>
|
||||
{
|
||||
public:
|
||||
// NOTE: we don't check numerator lattice now
|
||||
size_t getnumframes () const { return second.getnumframes(); }
|
||||
size_t getnumnodes () const { return second.getnumnodes(); }
|
||||
size_t getnumedges () const { return second.getnumedges(); }
|
||||
wstring getkey () const { return second.getkey(); }
|
||||
};
|
||||
|
||||
void getlattices (const std::wstring & key, shared_ptr<const latticesource::latticepair> & L, size_t expectedframes) const
|
||||
{
|
||||
shared_ptr<latticepair> LP (new latticepair);
|
||||
denlattices.getlattice (key, LP->second, expectedframes); // this loads the lattice from disk, using the existing L.second object
|
||||
L = LP;
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// minibatchsource -- abstracted interface into frame sources
|
||||
// There are three implementations:
|
||||
// - the old minibatchframesource to randomize across frames and page to disk
|
||||
// - minibatchutterancesource that randomizes in chunks and pages from input files directly
|
||||
// - a wrapper that uses a thread to read ahead in parallel to CPU/GPU processing
|
||||
// ---------------------------------------------------------------------------
|
||||
class minibatchsource
|
||||
{
|
||||
public:
|
||||
// read a minibatch
|
||||
// This function returns all values in a "caller can keep them" fashion:
|
||||
// - uids are stored in a huge 'const' array, and will never go away
|
||||
// - transcripts are copied by value
|
||||
// - lattices are returned as a shared_ptr
|
||||
// Thus, getbatch() can be called in a thread-safe fashion, allowing for a 'minibatchsource' implementation that wraps another with a read-ahead thread.
|
||||
// Return value is 'true' if it did read anything from disk, and 'false' if data came only from RAM cache. This is used for controlling the read-ahead thread.
|
||||
virtual bool getbatch (const size_t globalts,
|
||||
const size_t framesrequested, msra::dbn::matrix & feat, std::vector<size_t> & uids,
|
||||
std::vector<const_array_ref<msra::lattices::lattice::htkmlfwordsequence::word>> & transcripts,
|
||||
std::vector<shared_ptr<const latticesource::latticepair>> & lattices) = 0;
|
||||
// alternate (updated) definition for multiple inputs/outputs - read as a vector of feature matrixes or a vector of label strings
|
||||
virtual bool getbatch (const size_t globalts,
|
||||
const size_t framesrequested, std::vector<msra::dbn::matrix> & feat, std::vector<std::vector<size_t>> & uids,
|
||||
std::vector<const_array_ref<msra::lattices::lattice::htkmlfwordsequence::word>> & transcripts,
|
||||
std::vector<shared_ptr<const latticesource::latticepair>> & lattices) = 0;
|
||||
|
||||
// getbatch() overload to support subsetting of mini-batches for parallel training
|
||||
// Default implementation does not support subsetting and throws an exception on
|
||||
// calling this overload with a numsubsets value other than 1.
|
||||
virtual bool getbatch(const size_t globalts,
|
||||
const size_t framesrequested, const size_t subsetnum, const size_t numsubsets, size_t & framesadvanced,
|
||||
std::vector<msra::dbn::matrix> & feat, std::vector<std::vector<size_t>> & uids,
|
||||
std::vector<const_array_ref<msra::lattices::lattice::htkmlfwordsequence::word>> & transcripts,
|
||||
std::vector<shared_ptr<const latticesource::latticepair>> & lattices)
|
||||
{
|
||||
assert((subsetnum == 0) && (numsubsets == 1) && !supportsbatchsubsetting()); subsetnum; numsubsets;
|
||||
bool retVal = getbatch(globalts, framesrequested, feat, uids, transcripts, lattices);
|
||||
framesadvanced = feat[0].cols();
|
||||
|
||||
return retVal;
|
||||
}
|
||||
|
||||
virtual bool supportsbatchsubsetting() const
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
virtual size_t totalframes() const = 0;
|
||||
|
||||
virtual double gettimegetbatch () = 0; // used to report runtime
|
||||
virtual size_t firstvalidglobalts (const size_t globalts) = 0; // get first valid epoch start from intended 'globalts'
|
||||
virtual const std::vector<size_t> & unitcounts() const = 0; // report number of senones
|
||||
virtual void setverbosity(int newverbosity) = 0;
|
||||
virtual ~minibatchsource() { }
|
||||
};
|
||||
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// minibatchiterator -- class to iterate over one epoch, minibatch by minibatch
|
||||
// This iterator supports both random frames and random utterances through the minibatchsource interface whichis common to both.
|
||||
// This supports multiple data passes with identical randomization; which is intended to be used for utterance-based training.
|
||||
// ---------------------------------------------------------------------------
|
||||
class minibatchiterator
|
||||
{
|
||||
void operator= (const minibatchiterator &); // (non-copyable)
|
||||
|
||||
const size_t epochstartframe;
|
||||
const size_t epochendframe;
|
||||
size_t firstvalidepochstartframe; // epoch start frame rounded up to first utterance boundary after epoch boundary
|
||||
const size_t requestedmbframes; // requested mb size; actual minibatches can be smaller (or even larger for lattices)
|
||||
const size_t datapasses; // we return the data this many times; caller must sub-sample with 'datapass'
|
||||
|
||||
msra::dbn::minibatchsource & source; // feature source to read from
|
||||
|
||||
// subset to read during distributed data-parallel training (no subsetting: (0,1))
|
||||
size_t subsetnum;
|
||||
size_t numsubsets;
|
||||
|
||||
std::vector<msra::dbn::matrix> featbuf; // buffer for holding curernt minibatch's frames
|
||||
std::vector<std::vector<size_t>> uids; // buffer for storing current minibatch's frame-level label sequence
|
||||
std::vector<const_array_ref<msra::lattices::lattice::htkmlfwordsequence::word>> transcripts; // buffer for storing current minibatch's word-level label sequences (if available and used; empty otherwise)
|
||||
std::vector<shared_ptr<const latticesource::latticepair>> lattices; // lattices of the utterances in current minibatch (empty in frame mode)
|
||||
|
||||
size_t mbstartframe; // current start frame into generalized time line (used for frame-wise mode and for diagnostic messages)
|
||||
size_t actualmbframes; // actual number of frames in current minibatch
|
||||
size_t mbframesadvanced; // logical number of frames the current MB represents (to advance time; > featbuf.cols() possible, intended for the case of distributed data-parallel training)
|
||||
size_t datapass; // current datapass = pass through the data
|
||||
double timegetbatch; // [v-hansu] for time measurement
|
||||
double timechecklattice;
|
||||
private:
|
||||
// fetch the next mb
|
||||
// This updates featbuf, uids[], mbstartframe, and actualmbframes.
|
||||
void fillorclear()
|
||||
{
|
||||
if (!hasdata()) // we hit the end of the epoch: just cleanly clear out everything (not really needed, can't be requested ever)
|
||||
{
|
||||
foreach_index(i, featbuf)
|
||||
featbuf[i].resize (0, 0);
|
||||
|
||||
foreach_index(i,uids)
|
||||
uids[i].clear();
|
||||
|
||||
transcripts.clear();
|
||||
actualmbframes = 0;
|
||||
return;
|
||||
}
|
||||
// process one mini-batch (accumulation and update)
|
||||
assert (requestedmbframes > 0);
|
||||
const size_t requestedframes = min (requestedmbframes, epochendframe - mbstartframe); // (< mbsize at end)
|
||||
assert (requestedframes > 0);
|
||||
source.getbatch (mbstartframe, requestedframes, subsetnum, numsubsets, mbframesadvanced, featbuf, uids, transcripts, lattices);
|
||||
timegetbatch = source.gettimegetbatch();
|
||||
actualmbframes = featbuf[0].cols(); // for single i/o, there featbuf is length 1
|
||||
// note:
|
||||
// - in frame mode, actualmbframes may still return less if at end of sweep
|
||||
// - in utterance mode, it likely returns less than requested, and
|
||||
// it may also be > epochendframe (!) for the last utterance, which, most likely, crosses the epoch boundary
|
||||
// - in case of data parallelism, featbuf.cols() < mbframesadvanced
|
||||
auto_timer timerchecklattice;
|
||||
if (!lattices.empty())
|
||||
{
|
||||
size_t totalframes = 0;
|
||||
foreach_index (i, lattices)
|
||||
totalframes += lattices[i]->getnumframes();
|
||||
if (totalframes != actualmbframes)
|
||||
throw std::logic_error ("fillorclear: frames in lattices do not match minibatch size");
|
||||
}
|
||||
timechecklattice = timerchecklattice;
|
||||
}
|
||||
bool hasdata() const { return mbstartframe < epochendframe; } // true if we can access and/or advance
|
||||
void checkhasdata() const { if (!hasdata()) throw std::logic_error ("minibatchiterator: access beyond end of epoch"); }
|
||||
public:
|
||||
// interface: for (minibatchiterator i (...), i, i++) { ... }
|
||||
minibatchiterator (msra::dbn::minibatchsource & source, size_t epoch, size_t epochframes, size_t requestedmbframes, size_t subsetnum, size_t numsubsets, size_t datapasses)
|
||||
: source (source),
|
||||
epochstartframe (epoch * epochframes),
|
||||
epochendframe (epochstartframe + epochframes),
|
||||
requestedmbframes (requestedmbframes),
|
||||
subsetnum(subsetnum), numsubsets(numsubsets),
|
||||
datapasses (datapasses),
|
||||
timegetbatch (0), timechecklattice (0)
|
||||
{
|
||||
firstvalidepochstartframe = source.firstvalidglobalts (epochstartframe); // epochstartframe may fall between utterance boundaries; this gets us the first valid boundary
|
||||
fprintf (stderr, "minibatchiterator: epoch %d: frames [%d..%d] (first utterance at frame %d), data subset %d of %d, with %d datapasses\n",
|
||||
epoch, epochstartframe, epochendframe, firstvalidepochstartframe, subsetnum, numsubsets, datapasses);
|
||||
mbstartframe = firstvalidepochstartframe;
|
||||
datapass = 0;
|
||||
fillorclear(); // get the first batch
|
||||
}
|
||||
|
||||
// TODO not nice, but don't know how to access these frames otherwise
|
||||
// mbiterator constructor, set epochstart and -endframe explicitly
|
||||
minibatchiterator(msra::dbn::minibatchsource & source, size_t epoch, size_t epochstart, size_t epochend, size_t requestedmbframes, size_t subsetnum, size_t numsubsets, size_t datapasses)
|
||||
: source (source),
|
||||
epochstartframe (epochstart),
|
||||
epochendframe (epochend),
|
||||
requestedmbframes (requestedmbframes),
|
||||
subsetnum(subsetnum), numsubsets(numsubsets),
|
||||
datapasses (datapasses),
|
||||
timegetbatch (0), timechecklattice (0)
|
||||
{
|
||||
firstvalidepochstartframe = source.firstvalidglobalts (epochstartframe); // epochstartframe may fall between utterance boundaries; this gets us the first valid boundary
|
||||
fprintf (stderr, "minibatchiterator: epoch %d: frames [%d..%d] (first utterance at frame %d), data subset %d of %d, with %d datapasses\n",
|
||||
epoch, epochstartframe, epochendframe, firstvalidepochstartframe, subsetnum, numsubsets, datapasses);
|
||||
mbstartframe = firstvalidepochstartframe;
|
||||
datapass = 0;
|
||||
fillorclear(); // get the first batch
|
||||
}
|
||||
|
||||
// need virtual destructor to ensure proper destruction
|
||||
virtual ~minibatchiterator()
|
||||
{}
|
||||
|
||||
// returns true if we still have data
|
||||
operator bool() const { return hasdata(); }
|
||||
|
||||
// advance to the next minimb
|
||||
void operator++(int/*denotes postfix version*/)
|
||||
{
|
||||
checkhasdata();
|
||||
mbstartframe += mbframesadvanced;
|
||||
// if we hit the end, we will get mbstartframe >= epochendframe <=> !hasdata()
|
||||
// (most likely actually mbstartframe > epochendframe since the last utterance likely crosses the epoch boundary)
|
||||
// in case of multiple datapasses, reset to start when hitting the end
|
||||
if (!hasdata() && datapass + 1 < datapasses)
|
||||
{
|
||||
mbstartframe = firstvalidepochstartframe;
|
||||
datapass++;
|
||||
fprintf (stderr, "\nminibatchiterator: entering %zu-th repeat pass through the data\n", datapass+1);
|
||||
}
|
||||
fillorclear();
|
||||
}
|
||||
|
||||
// accessors to current minibatch
|
||||
size_t currentmbstartframe() const { return mbstartframe; }
|
||||
size_t currentmbframes() const { return actualmbframes; }
|
||||
size_t currentmbframesadvanced() const { return mbframesadvanced; }
|
||||
size_t currentmblattices() const { return lattices.size(); }
|
||||
size_t currentdatapass() const { return datapass; } // 0..datapasses-1; use this for sub-sampling
|
||||
size_t requestedframes() const {return requestedmbframes; }
|
||||
double gettimegetbatch () {return timegetbatch;}
|
||||
double gettimechecklattice () {return timechecklattice;}
|
||||
bool isfirst() const { return mbstartframe == firstvalidepochstartframe && datapass == 0; }
|
||||
float progress() const // (note: 100%+eps possible for last utterance)
|
||||
{
|
||||
const float epochframes = (float) (epochendframe - epochstartframe);
|
||||
return (mbstartframe + mbframesadvanced - epochstartframe + datapass * epochframes) / (datapasses * epochframes);
|
||||
}
|
||||
std::pair<size_t,size_t> range() const { return make_pair (epochstartframe, epochendframe); }
|
||||
|
||||
// return the current minibatch frames as a matrix ref into the feature buffer
|
||||
// Number of frames is frames().cols() == currentmbframes().
|
||||
// For frame-based randomization, this is 'requestedmbframes' most of the times, while for utterance randomization,
|
||||
// this depends highly on the utterance lengths.
|
||||
// User is allowed to manipulate the frames... for now--TODO: move silence filtering here as well
|
||||
|
||||
msra::dbn::matrixstripe frames(size_t i) { checkhasdata(); assert(featbuf.size()>=i+1); return msra::dbn::matrixstripe (featbuf[i], 0, actualmbframes); }
|
||||
|
||||
msra::dbn::matrixstripe frames() { checkhasdata(); assert(featbuf.size()==1); return msra::dbn::matrixstripe (featbuf[0], 0, actualmbframes); }
|
||||
|
||||
// return the reference transcript labels (state alignment) for current minibatch
|
||||
/*const*/ std::vector<size_t> & labels() { checkhasdata(); assert(uids.size()==1);return uids[0]; }
|
||||
/*const*/ std::vector<size_t> & labels(size_t i) { checkhasdata(); assert(uids.size()>=i+1); return uids[i]; }
|
||||
|
||||
// return a lattice for an utterance (caller should first get total through currentmblattices())
|
||||
shared_ptr<const msra::dbn::latticesource::latticepair> lattice (size_t uttindex) const { return lattices[uttindex]; } // lattices making up the current
|
||||
|
||||
// return the reference transcript labels (words with alignments) for current minibatch (or empty if no transcripts requested)
|
||||
const_array_ref<msra::lattices::lattice::htkmlfwordsequence::word> transcript (size_t uttindex) { return transcripts.empty() ? const_array_ref<msra::lattices::lattice::htkmlfwordsequence::word>() : transcripts[uttindex]; }
|
||||
};
|
||||
|
||||
};};
|
|
@ -1,279 +0,0 @@
|
|||
//
|
||||
// <copyright file="minibatchsourcehelpers.h" company="Microsoft">
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// </copyright>
|
||||
//
|
||||
// minibatchsourcehelpers.h -- helper classes for minibatch sources
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "basetypes.h"
|
||||
#include <stdio.h>
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
|
||||
#ifndef __unix__
|
||||
#include "ssematrix.h" // for matrix type
|
||||
#endif
|
||||
|
||||
namespace msra { namespace dbn {
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// augmentneighbors() -- augmenting features with their neighbor frames
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// implant a sub-vector into a vector, for use in augmentneighbors
|
||||
template<class INV, class OUTV> static void copytosubvector (const INV & inv, size_t subvecindex, OUTV & outv)
|
||||
{
|
||||
size_t subdim = inv.size();
|
||||
assert (outv.size() % subdim == 0);
|
||||
size_t k0 = subvecindex * subdim;
|
||||
foreach_index (k, inv)
|
||||
outv[k + k0] = inv[k];
|
||||
}
|
||||
|
||||
// compute the augmentation extent (how many frames added on each side)
|
||||
static size_t augmentationextent (size_t featdim/*augment from*/, size_t modeldim/*to*/)
|
||||
{
|
||||
const size_t windowframes = modeldim / featdim; // total number of frames to generate
|
||||
const size_t extent = windowframes / 2; // extend each side by this
|
||||
|
||||
if (modeldim % featdim != 0)
|
||||
throw runtime_error ("augmentationextent: model vector size not multiple of input features");
|
||||
if (windowframes % 2 == 0)
|
||||
throw runtime_error (msra::strfun::strprintf ("augmentationextent: neighbor expansion of input features to %d not symmetrical", windowframes));
|
||||
|
||||
return extent;
|
||||
}
|
||||
|
||||
// augment neighbor frames for a frame (correctly not expanding across utterance boundaries)
|
||||
// The boundaryflags[] array, if not empty, flags first (-1) and last (+1) frame, i.e. frames to not expand across.
|
||||
// The output 'v' must have te-ts columns.
|
||||
template<class MATRIX, class VECTOR> static void augmentneighbors (const MATRIX & frames, const std::vector<char> & boundaryflags, size_t t,
|
||||
VECTOR & v)
|
||||
{
|
||||
// how many frames are we adding on each side
|
||||
const size_t extent = augmentationextent (frames[t].size(), v.size());
|
||||
|
||||
// copy the frame and its neighbors
|
||||
// Once we hit a boundaryflag in either direction, do not move index beyond.
|
||||
copytosubvector (frames[t], extent, v); // frame[t] sits right in the middle
|
||||
size_t t1 = t; // index for frames on to the left
|
||||
size_t t2 = t; // and right
|
||||
for (size_t n = 1; n <= extent; n++)
|
||||
{
|
||||
#ifdef SAMPLING_EXPERIMENT
|
||||
if (boundaryflags.empty()) // boundary flags not given: 'frames' is full utterance
|
||||
{
|
||||
if (t1 >= SAMPLING_EXPERIMENT) t1 -= SAMPLING_EXPERIMENT; // index does not move beyond boundary
|
||||
if (t2 + SAMPLING_EXPERIMENT < frames.size()) t2 += SAMPLING_EXPERIMENT;
|
||||
}
|
||||
else
|
||||
{
|
||||
if (boundaryflags[t1] != -1) t1 -= SAMPLING_EXPERIMENT; // index does not move beyond a set boundaryflag,
|
||||
if (boundaryflags[t2] != 1) t2 += SAMPLING_EXPERIMENT; // because that's the start/end of the utterance
|
||||
}
|
||||
#else
|
||||
if (boundaryflags.empty()) // boundary flags not given: 'frames' is full utterance
|
||||
{
|
||||
if (t1 > 0) t1--; // index does not move beyond boundary
|
||||
if (t2 + 1 < frames.size()) t2++;
|
||||
}
|
||||
else
|
||||
{
|
||||
if (boundaryflags[t1] != -1) t1--; // index does not move beyond a set boundaryflag,
|
||||
if (boundaryflags[t2] != 1) t2++; // because that's the start/end of the utterance
|
||||
}
|
||||
#endif
|
||||
copytosubvector (frames[t1], extent - n, v);
|
||||
copytosubvector (frames[t2], extent + n, v);
|
||||
}
|
||||
}
|
||||
|
||||
// augment neighbor frames for a frame (correctly not expanding across utterance boundaries)
|
||||
// The boundaryflags[] array, if not empty, flags first (-1) and last (+1) frame, i.e. frames to not expand across.
|
||||
// The output 'v' must have te-ts columns.
|
||||
template<class MATRIX, class VECTOR> static void augmentneighbors(const MATRIX & frames, const std::vector<char> & boundaryflags, size_t t, const size_t leftextent, const size_t rightextent,
|
||||
VECTOR & v)
|
||||
{
|
||||
|
||||
// copy the frame and its neighbors
|
||||
// Once we hit a boundaryflag in either direction, do not move index beyond.
|
||||
copytosubvector(frames[t], leftextent, v); // frame[t] sits right in the middle
|
||||
size_t t1 = t; // index for frames on to the left
|
||||
size_t t2 = t; // and right
|
||||
|
||||
for (size_t n = 1; n <= leftextent; n++)
|
||||
{
|
||||
if (boundaryflags.empty()) // boundary flags not given: 'frames' is full utterance
|
||||
{
|
||||
if (t1 > 0) t1--; // index does not move beyond boundary
|
||||
}
|
||||
else
|
||||
{
|
||||
if (boundaryflags[t1] != -1) t1--; // index does not move beyond a set boundaryflag,
|
||||
}
|
||||
copytosubvector(frames[t1], leftextent - n, v);
|
||||
}
|
||||
for (size_t n = 1; n <= rightextent; n++)
|
||||
{
|
||||
if (boundaryflags.empty()) // boundary flags not given: 'frames' is full utterance
|
||||
{
|
||||
if (t2 + 1 < frames.size()) t2++;
|
||||
}
|
||||
else
|
||||
{
|
||||
if (boundaryflags[t2] != 1) t2++; // because that's the start/end of the utterance
|
||||
}
|
||||
copytosubvector(frames[t2], rightextent + n, v);
|
||||
}
|
||||
}
|
||||
|
||||
// augment neighbor frames for one frame t in frames[] according to boundaryflags[]; result returned in column j of v
|
||||
template<class INMATRIX, class OUTMATRIX> static void augmentneighbors (const INMATRIX & frames, const std::vector<char> & boundaryflags, size_t t,
|
||||
OUTMATRIX & v, size_t j)
|
||||
{
|
||||
auto v_j = v.col(j); // the vector to fill in
|
||||
augmentneighbors (frames, boundaryflags, t, v_j);
|
||||
}
|
||||
|
||||
// augment neighbor frames for one frame t in frames[] according to boundaryflags[]; result returned in column j of v
|
||||
template<class INMATRIX, class OUTMATRIX> static void augmentneighbors(const INMATRIX & frames, const std::vector<char> & boundaryflags, size_t t, size_t leftextent, size_t rightextent,
|
||||
OUTMATRIX & v, size_t j)
|
||||
{
|
||||
auto v_j = v.col(j); // the vector to fill in
|
||||
augmentneighbors(frames, boundaryflags, t, leftextent, rightextent, v_j);
|
||||
}
|
||||
|
||||
// augment neighbor frames for a sequence of frames (part of an utterance, possibly spanning across boundaries)
|
||||
template<class MATRIX> static void augmentneighbors (const std::vector<std::vector<float>> & frames, const std::vector<char> & boundaryflags,
|
||||
size_t ts, size_t te, // range [ts,te)
|
||||
MATRIX & v)
|
||||
{
|
||||
for (size_t t = ts; t < te; t++)
|
||||
{
|
||||
auto v_t = v.col(t-ts); // the vector to fill in
|
||||
augmentneighbors (frames, boundaryflags, t, v_t);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// augment neighbor frames for a sequence of frames (part of an utterance, possibly spanning across boundaries)
|
||||
template<class MATRIX> static void augmentneighbors(const std::vector<std::vector<float>> & frames, const std::vector<char> & boundaryflags, size_t leftextent, size_t rightextent,
|
||||
size_t ts, size_t te, // range [ts,te)
|
||||
MATRIX & v)
|
||||
{
|
||||
for (size_t t = ts; t < te; t++)
|
||||
{
|
||||
auto v_t = v.col(t - ts); // the vector to fill in
|
||||
augmentneighbors(frames, boundaryflags, t, leftextent, rightextent, v_t);
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// randomordering -- class to help manage randomization of input data
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
static inline size_t rand (const size_t begin, const size_t end)
|
||||
{
|
||||
const size_t randno = ::rand() * RAND_MAX + ::rand(); // BUGBUG: still only covers 32-bit range
|
||||
return begin + randno % (end - begin);
|
||||
}
|
||||
|
||||
class randomordering // note: NOT thread-safe at all
|
||||
{
|
||||
// constants for randomization
|
||||
const static size_t randomizeAuto=0;
|
||||
const static size_t randomizeDisable=(size_t)-1;
|
||||
|
||||
typedef unsigned int INDEXTYPE; // don't use size_t, as this saves HUGE amounts of RAM
|
||||
std::vector<INDEXTYPE> map; // [t] -> t' indices in randomized order
|
||||
size_t currentseed; // seed for current sequence
|
||||
size_t randomizationrange; // t - randomizationrange/2 <= t' < t + randomizationrange/2 (we support this to enable swapping)
|
||||
// special values (randomizeAuto, randomizeDisable)
|
||||
void invalidate() { currentseed = (size_t) -1; }
|
||||
public:
|
||||
randomordering() { invalidate(); }
|
||||
|
||||
void resize (size_t len, size_t p_randomizationrange) { randomizationrange = p_randomizationrange>0?p_randomizationrange:len; map.resize (len); invalidate(); }
|
||||
|
||||
// return the randomized feature bounds for a time range
|
||||
std::pair<size_t,size_t> bounds (size_t ts, size_t te) const
|
||||
{
|
||||
size_t tbegin = max (ts, randomizationrange/2) - randomizationrange/2;
|
||||
size_t tend = min (te + randomizationrange/2, map.size());
|
||||
return std::make_pair<size_t,size_t> (move(tbegin), move(tend));
|
||||
}
|
||||
|
||||
// this returns the map directly (read-only) and will lazily initialize it for a given seed
|
||||
const std::vector<INDEXTYPE> & operator() (size_t seed) //throw()
|
||||
{
|
||||
// if wrong seed then lazily recache the sequence
|
||||
if (seed != currentseed)
|
||||
{
|
||||
// test for numeric overflow
|
||||
if (map.size()-1 != (INDEXTYPE) (map.size()-1))
|
||||
throw std::runtime_error ("randomordering: INDEXTYPE has too few bits for this corpus");
|
||||
// 0, 1, 2...
|
||||
foreach_index (t, map) map[t] = (INDEXTYPE) t;
|
||||
// now randomize them
|
||||
if (randomizationrange != randomizeDisable)
|
||||
{
|
||||
#if 1 // change to 0 to disable randomizing
|
||||
if (map.size() > RAND_MAX * (size_t) RAND_MAX)
|
||||
throw std::runtime_error ("randomordering: too large training set: need to change to different random generator!");
|
||||
srand ((unsigned int) seed);
|
||||
size_t retries = 0;
|
||||
foreach_index (t, map)
|
||||
{
|
||||
for (int tries = 0; tries < 5; tries++)
|
||||
{
|
||||
// swap current pos with a random position
|
||||
// Random positions are limited to t+randomizationrange.
|
||||
// This ensures some locality suitable for paging with a sliding window.
|
||||
const size_t tbegin = max ((size_t) t, randomizationrange/2) - randomizationrange/2; // range of window --TODO: use bounds() function above
|
||||
const size_t tend = min (t + randomizationrange/2, map.size());
|
||||
assert (tend >= tbegin); // (guard against potential numeric-wraparound bug)
|
||||
const size_t trand = rand (tbegin, tend); // random number within windows
|
||||
assert ((size_t) t <= trand + randomizationrange/2 && trand < (size_t) t + randomizationrange/2);
|
||||
// if range condition is fulfilled then swap
|
||||
if (trand <= map[t] + randomizationrange/2 && map[t] < trand + randomizationrange/2
|
||||
&& (size_t) t <= map[trand] + randomizationrange/2 && map[trand] < (size_t) t + randomizationrange/2)
|
||||
{
|
||||
::swap (map[t], map[trand]);
|
||||
break;
|
||||
}
|
||||
// but don't multi-swap stuff out of its range (for swapping positions that have been swapped before)
|
||||
// instead, try again with a different random number
|
||||
retries++;
|
||||
}
|
||||
}
|
||||
fprintf (stderr, "randomordering: %zu retries for %zu elements (%.1f%%) to ensure window condition\n", retries, map.size(), 100.0 * retries / map.size());
|
||||
// ensure the window condition
|
||||
foreach_index (t, map) assert ((size_t) t <= map[t] + randomizationrange/2 && map[t] < (size_t) t + randomizationrange/2);
|
||||
#if 1 // and a live check since I don't trust myself here yet
|
||||
foreach_index (t, map) if (!((size_t) t <= map[t] + randomizationrange/2 && map[t] < (size_t) t + randomizationrange/2))
|
||||
{
|
||||
fprintf (stderr, "randomordering: windowing condition violated %d -> %d\n", t, map[t]);
|
||||
throw std::logic_error ("randomordering: windowing condition violated");
|
||||
}
|
||||
#endif
|
||||
#endif
|
||||
#if 1 // test whether it is indeed a unique complete sequence
|
||||
auto map2 = map;
|
||||
::sort (map2.begin(), map2.end());
|
||||
foreach_index (t, map2) assert (map2[t] == (size_t) t);
|
||||
#endif
|
||||
fprintf (stderr, "randomordering: recached sequence for seed %d: %d, %d, ...\n", (int) seed, (int) map[0], (int) map[1]);
|
||||
}
|
||||
currentseed = seed;
|
||||
}
|
||||
return map; // caller can now access it through operator[]
|
||||
}
|
||||
};
|
||||
|
||||
//typedef unsigned short CLASSIDTYPE; // type to store state ids; don't use size_t --saves HUGE amounts of RAM
|
||||
typedef unsigned int CLASSIDTYPE; //mseltzer - change to unsigned int for untied context-dependent phones
|
||||
|
||||
};};
|
Разница между файлами не показана из-за своего большого размера
Загрузить разницу
|
@ -1,254 +0,0 @@
|
|||
//
|
||||
// <copyright file="numahelpers.h" company="Microsoft">
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// </copyright>
|
||||
//
|
||||
// numahelpers.h -- some helpers with NUMA
|
||||
|
||||
#pragma once
|
||||
|
||||
#ifndef __unix__
|
||||
#include <Windows.h>
|
||||
#include "pplhelpers.h"
|
||||
|
||||
#endif
|
||||
#include <stdexcept>
|
||||
#include "simple_checked_arrays.h"
|
||||
#include "basetypes.h" // for FormatWin32Error
|
||||
|
||||
namespace msra { namespace numa {
|
||||
|
||||
// ... TODO: this can be a 'static', as it should only be set during foreach_node but not outside
|
||||
extern int node_override; // -1 = normal operation; >= 0: force a specific NUMA node
|
||||
|
||||
// force a specific NUMA node (only do this during single-threading!)
|
||||
static inline void overridenode (int n = -1)
|
||||
{
|
||||
node_override = n;
|
||||
}
|
||||
|
||||
// get the number of NUMA nodes we would like to distinguish
|
||||
static inline size_t getnumnodes()
|
||||
{
|
||||
ULONG n;
|
||||
if (!GetNumaHighestNodeNumber (&n)) return 1;
|
||||
return n +1;
|
||||
}
|
||||
|
||||
// execute body (node, i, n), i in [0,n) on all NUMA nodes in small chunks
|
||||
template <typename FUNCTION> void parallel_for_on_each_numa_node (bool multistep, const FUNCTION & body)
|
||||
{
|
||||
// get our configuration
|
||||
const size_t cores = ppl_cores;
|
||||
assert (cores > 0);
|
||||
const size_t nodes = getnumnodes();
|
||||
const size_t corespernode = (cores -1) / nodes + 1;
|
||||
// break into 8 steps per thread
|
||||
const size_t stepspernode = multistep ? 16 : 1;
|
||||
const size_t steps = corespernode * stepspernode;
|
||||
// now run on many threads, hoping to hit all NUMA nodes, until we are done
|
||||
hardcoded_array<LONG/*unsigned int*/,256> nextstepcounters; // next block to run for a NUMA node
|
||||
if (nodes > nextstepcounters.size())
|
||||
throw std::logic_error ("parallel_for_on_each_numa_node: nextstepcounters buffer too small, need to increase hard-coded size");
|
||||
for (size_t k = 0; k < nodes; k++) nextstepcounters[k] = 0;
|
||||
overridenode();
|
||||
//unsigned int totalloops = 0; // for debugging only, can be removed later
|
||||
msra::parallel::parallel_for (0, nodes * steps /*execute each step on each NUMA node*/, 1, [&](size_t /*dummy*/)
|
||||
{
|
||||
const size_t numanodeid = getcurrentnode();
|
||||
// find a node that still has work left, preferring our own node
|
||||
// Towards the end we will run on wrong nodes, but what can we do.
|
||||
for (size_t node1 = numanodeid; node1 < numanodeid + nodes; node1++)
|
||||
{
|
||||
const size_t node = node1 % nodes;
|
||||
const unsigned int step = InterlockedIncrement (&nextstepcounters[node]) -1; // grab this step
|
||||
if (step >= steps) // if done then counter has exceeded the required number of steps
|
||||
continue; // so try next NUMA node
|
||||
// found one: execute and terminate loop
|
||||
body (node, step, steps);
|
||||
//InterlockedIncrement (&totalloops);
|
||||
return; // done
|
||||
}
|
||||
// oops??
|
||||
throw std::logic_error ("parallel_for_on_each_numa_node: no left-over block found--should not get here!!");
|
||||
});
|
||||
//assert (totalloops == nodes * steps);
|
||||
}
|
||||
|
||||
// execute a passed function once for each NUMA node
|
||||
// This must be run from the main thread only.
|
||||
// ... TODO: honor ppl_cores == 1 for comparative measurements against single threads.
|
||||
template<typename FUNCTION>
|
||||
static void foreach_node_single_threaded (const FUNCTION & f)
|
||||
{
|
||||
const size_t n = getnumnodes();
|
||||
for (size_t i = 0; i < n; i++)
|
||||
{
|
||||
overridenode ((int) i);
|
||||
f();
|
||||
}
|
||||
overridenode (-1);
|
||||
}
|
||||
|
||||
// get the current NUMA node
|
||||
static inline size_t getcurrentnode()
|
||||
{
|
||||
// we can force it to be a certain node, for use in initializations
|
||||
if (node_override >= 0)
|
||||
return (size_t) node_override;
|
||||
// actually use current node
|
||||
DWORD i = GetCurrentProcessorNumber(); // note: need to change for >63 processors
|
||||
UCHAR n;
|
||||
if (!GetNumaProcessorNode ((UCHAR) i, &n)) return 0;
|
||||
if (n == 0xff)
|
||||
throw std::logic_error ("GetNumaProcessorNode() failed to determine NUMA node for GetCurrentProcessorNumber()??");
|
||||
return n;
|
||||
}
|
||||
|
||||
// allocate memory
|
||||
// Allocation seems to be at least on a 512-byte boundary. We nevertheless verify alignment requirements.
|
||||
typedef LPVOID (WINAPI *VirtualAllocExNuma_t) (HANDLE,LPVOID,SIZE_T,DWORD,DWORD,DWORD);
|
||||
static VirtualAllocExNuma_t VirtualAllocExNuma = (VirtualAllocExNuma_t)-1;
|
||||
|
||||
static inline void * malloc (size_t n, size_t align)
|
||||
{
|
||||
// VirtualAllocExNuma() only exists on Vista+, so go through an explicit function pointer
|
||||
if (VirtualAllocExNuma == (VirtualAllocExNuma_t)-1)
|
||||
{
|
||||
VirtualAllocExNuma = (VirtualAllocExNuma_t) GetProcAddress (GetModuleHandle ( TEXT ("kernel32.dll")), "VirtualAllocExNuma");
|
||||
}
|
||||
|
||||
// if we have the function then do a NUMA-aware allocation
|
||||
void * p;
|
||||
if (VirtualAllocExNuma != NULL)
|
||||
{
|
||||
size_t node = getcurrentnode();
|
||||
// "all Win32 heap allocations that are 1 MB or greater are forwarded directly to NtAllocateVirtualMemory
|
||||
// when they are allocated and passed directly to NtFreeVirtualMemory when they are freed" Greg Colombo, 2010/11/17
|
||||
if (n < 1024*1024)
|
||||
n = 1024*1024; // -> brings NUMA-optimized code back to Node Interleave level (slightly faster)
|
||||
p = VirtualAllocExNuma (GetCurrentProcess(), NULL, n, MEM_COMMIT | MEM_RESERVE, PAGE_READWRITE, (DWORD) node);
|
||||
}
|
||||
else // on old OS call no-NUMA version
|
||||
{
|
||||
p = VirtualAllocEx (GetCurrentProcess(), NULL, n, MEM_COMMIT | MEM_RESERVE, PAGE_READWRITE);
|
||||
}
|
||||
if (p == NULL)
|
||||
fprintf (stderr, "numa::malloc: failed allocating %d bytes with alignment %d\n", n, align);
|
||||
if (((size_t) p) % align != 0)
|
||||
throw std::logic_error ("VirtualAllocExNuma() returned an address that does not match the alignment requirement");
|
||||
return p;
|
||||
}
|
||||
|
||||
// free memory allocated with numa::malloc()
|
||||
static inline void free (void * p)
|
||||
{
|
||||
assert (p != NULL);
|
||||
if (!VirtualFree (p, 0, MEM_RELEASE))
|
||||
throw std::logic_error ("VirtualFreeEx failure");
|
||||
}
|
||||
|
||||
// dump memory allocation
|
||||
static inline void showavailablememory (const char * what)
|
||||
{
|
||||
size_t n = getnumnodes();
|
||||
for (size_t i = 0; i < n; i++)
|
||||
{
|
||||
ULONGLONG availbytes = 0;
|
||||
BOOL rc = GetNumaAvailableMemoryNode ((UCHAR) i, &availbytes);
|
||||
const double availmb = availbytes / (1024.0*1024.0);
|
||||
if (rc)
|
||||
fprintf (stderr, "%s: %8.2f MB available on NUMA node %d\n", what, availmb, i);
|
||||
else
|
||||
fprintf (stderr, "%s: error '%S' for getting available memory on NUMA node %d\n", what, FormatWin32Error (::GetLastError()).c_str(), i);
|
||||
}
|
||||
}
|
||||
|
||||
// determine NUMA node with most memory available
|
||||
static inline size_t getmostspaciousnumanode()
|
||||
{
|
||||
size_t n = getnumnodes();
|
||||
size_t bestnode = 0;
|
||||
ULONGLONG bestavailbytes = 0;
|
||||
for (size_t i = 0; i < n; i++)
|
||||
{
|
||||
ULONGLONG availbytes = 0;
|
||||
GetNumaAvailableMemoryNode ((UCHAR) i, &availbytes);
|
||||
if (availbytes > bestavailbytes)
|
||||
{
|
||||
bestavailbytes = availbytes;
|
||||
bestnode = i;
|
||||
}
|
||||
}
|
||||
return bestnode;
|
||||
}
|
||||
|
||||
#if 0 // this is no longer used (we now parallelize the big matrix products directly)
|
||||
// class to manage multiple copies of data on local NUMA nodes
|
||||
template<class DATATYPE,class CACHEDTYPE> class numalocaldatacache
|
||||
{
|
||||
numalocaldatacache (const numalocaldatacache&); numalocaldatacache & operator= (const numalocaldatacache&);
|
||||
|
||||
// the data set we associate to
|
||||
const DATATYPE & data;
|
||||
|
||||
// cached copies of the models for NUMA
|
||||
vector<unique_ptr<CACHEDTYPE>> cache;
|
||||
|
||||
// get the pointer to the clone for the NUMA node of the current thread (must exist)
|
||||
CACHEDTYPE * getcloneptr()
|
||||
{
|
||||
return cache[getcurrentnode()].get();
|
||||
}
|
||||
public:
|
||||
numalocaldatacache (const DATATYPE & data) : data (data), cache (getnumnodes())
|
||||
{
|
||||
foreach_node_single_threaded ([&]()
|
||||
{
|
||||
cache[getcurrentnode()].reset (new CACHEDTYPE (data));
|
||||
});
|
||||
}
|
||||
|
||||
// this takes the cached versions of the parent model
|
||||
template<typename ARGTYPE1,typename ARGTYPE2,typename ARGTYPE3>
|
||||
numalocaldatacache (numalocaldatacache<DATATYPE,DATATYPE> & parentcache, const ARGTYPE1 & arg1, const ARGTYPE2 & arg2, const ARGTYPE3 & arg3) : data (*(DATATYPE*)nullptr), cache (getnumnodes())
|
||||
{
|
||||
foreach_node_single_threaded ([&]()
|
||||
{
|
||||
const DATATYPE & parent = parentcache.getclone();
|
||||
size_t numanodeid = getcurrentnode();
|
||||
cache[numanodeid].reset (new CACHEDTYPE (parent, arg1, arg2, arg3));
|
||||
});
|
||||
}
|
||||
|
||||
// re-clone --update clones from the cached 'data' reference
|
||||
// This is only valid if CACHEDTYPE==DATATYPE.
|
||||
// ... parallelize this!
|
||||
void reclone()
|
||||
{
|
||||
parallel_for_on_each_numa_node (true, [&] (size_t numanodeid, size_t step, size_t steps)
|
||||
{
|
||||
if (step != 0)
|
||||
return; // ... TODO: tell parallel_for_on_each_numa_node() to only have one step, or parallelize
|
||||
cache[numanodeid].get()->copyfrom (data); // copy it all over
|
||||
});
|
||||
}
|
||||
|
||||
// post-process all clones
|
||||
// 'numanodeid' is ideally the current NUMA node most of the time, but not required.
|
||||
template<typename POSTPROCFUNC>
|
||||
void process (const POSTPROCFUNC & postprocess)
|
||||
{
|
||||
parallel_for_on_each_numa_node (true, [&] (size_t numanodeid, size_t step, size_t steps)
|
||||
{
|
||||
postprocess (*cache[numanodeid].get(), step, steps);
|
||||
});
|
||||
}
|
||||
|
||||
// a thread calls this to get the data pre-cloned for its optimal NUMA node
|
||||
// (only works for memory allocated through msra::numa::malloc())
|
||||
const CACHEDTYPE & getclone() const { return *getcloneptr(); }
|
||||
CACHEDTYPE & getclone() { return *getcloneptr(); }
|
||||
};
|
||||
#endif
|
||||
};};
|
|
@ -1,99 +0,0 @@
|
|||
//
|
||||
// <copyright file="pplhelpers.h" company="Microsoft">
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// </copyright>
|
||||
//
|
||||
// pplhelpers.h -- some helpers for PPL library
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#ifndef __unix__
|
||||
#include <ppl.h>
|
||||
#endif
|
||||
namespace msra { namespace parallel {
|
||||
|
||||
// ===========================================================================
|
||||
// helpers related to multiprocessing and NUMA
|
||||
// ===========================================================================
|
||||
|
||||
// determine number of CPU cores on this machine
|
||||
static inline size_t determine_num_cores()
|
||||
{
|
||||
SYSTEM_INFO sysInfo;
|
||||
GetSystemInfo (&sysInfo);
|
||||
return sysInfo.dwNumberOfProcessors;
|
||||
}
|
||||
|
||||
extern size_t ppl_cores; // number of cores to run on as requested by user
|
||||
|
||||
static inline void set_cores (size_t cores)
|
||||
{
|
||||
ppl_cores = cores;
|
||||
}
|
||||
|
||||
static inline size_t get_cores() // if returns 1 then no parallelization will be done
|
||||
{
|
||||
return ppl_cores;
|
||||
}
|
||||
|
||||
#if 0
|
||||
// execute body() a bunch of times for hopefully each core
|
||||
// This is not precise. Cores will be hit multiple times, and some cores may not be touched.
|
||||
template <typename FUNCTION> void for_all_numa_nodes_approximately (const FUNCTION & body)
|
||||
{
|
||||
if (ppl_cores > 1) // parallel computation (regular)
|
||||
parallel_for ((size_t) 0, ppl_cores * 2, (size_t) 1, [&](size_t) { body(); });
|
||||
else // for comparison: single-threaded (this also documents what the above means)
|
||||
body();
|
||||
}
|
||||
#endif
|
||||
|
||||
// wrapper around Concurrency::parallel_for() to allow disabling parallelization altogether
|
||||
template <typename FUNCTION> void parallel_for (size_t begin, size_t end, size_t step, const FUNCTION & f)
|
||||
{
|
||||
const size_t cores = ppl_cores;
|
||||
if (cores > 1) // parallel computation (regular)
|
||||
{
|
||||
//fprintf (stderr, "foreach_index_block: computing %d blocks of %d frames on %d cores\n", nblocks, nfwd, determine_num_cores());
|
||||
Concurrency::parallel_for (begin, end, step, f);
|
||||
}
|
||||
else // for comparison: single-threaded (this also documents what the above means)
|
||||
{
|
||||
//fprintf (stderr, "foreach_index_block: computing %d blocks of %d frames on a single thread\n", nblocks, nfwd);
|
||||
for (size_t j0 = begin; j0 < end; j0 += step) f (j0);
|
||||
}
|
||||
}
|
||||
|
||||
// execute a function 'body (j0, j1)' for j = [0..n) in chunks of ~targetstep in 'cores' cores
|
||||
// Very similar to parallel_for() except that body function also takes end index,
|
||||
// and the 'targetsteps' gets rounded a little to better map to 'cores.'
|
||||
// ... TODO: Currently, 'cores' does not limit the number of threads in parallel_for() (not so critical, fix later or never)
|
||||
template <typename FUNCTION> void foreach_index_block (size_t n, size_t targetstep, size_t targetalignment, const FUNCTION & body)
|
||||
{
|
||||
const size_t cores = ppl_cores;
|
||||
const size_t maxnfwd = 2 * targetstep;
|
||||
size_t nblocks = (n + targetstep / 2) / targetstep;
|
||||
if (nblocks == 0) nblocks = 1;
|
||||
// round to a multiple of the number of cores
|
||||
if (nblocks < cores) // less than # cores -> round up
|
||||
nblocks = (1+(nblocks-1)/cores) * cores;
|
||||
else // more: round down (reduce overhead)
|
||||
nblocks = nblocks / cores * cores;
|
||||
size_t nfwd = 1 + (n - 1) / nblocks;
|
||||
assert (nfwd * nblocks >= n);
|
||||
if (nfwd > maxnfwd) nfwd = maxnfwd; // limit to allocated memory just in case
|
||||
// ... TODO: does the above actually do anything/significant? nfwd != targetstep?
|
||||
|
||||
// enforce alignment
|
||||
nfwd = (1 + (nfwd -1) / targetalignment) * targetalignment;
|
||||
|
||||
// execute it!
|
||||
parallel_for (0, n, nfwd, [&](size_t j0)
|
||||
{
|
||||
size_t j1 = min (j0 + nfwd, n);
|
||||
body (j0, j1);
|
||||
});
|
||||
}
|
||||
|
||||
};};
|
|
@ -1,249 +0,0 @@
|
|||
//
|
||||
// <copyright file="readaheadsource.h" company="Microsoft">
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// </copyright>
|
||||
//
|
||||
// readaheadsource.h -- wrapper ('minibatchreadaheadsource') of a read-ahead thread that pre-rolls feature and lattice data
|
||||
//
|
||||
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "basetypes.h"
|
||||
#include "minibatchiterator.h"
|
||||
#include "latticearchive.h"
|
||||
#ifdef _WIN32
|
||||
#include "simplethread.h"
|
||||
#endif
|
||||
#include <deque>
|
||||
#include <stdexcept>
|
||||
|
||||
namespace msra { namespace dbn {
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// minibatchreadaheadsource -- read-ahead thread that pre-rolls feature and lattice data
|
||||
// ---------------------------------------------------------------------------
|
||||
class minibatchreadaheadsource : public minibatchsource/*the interface we implement*/,
|
||||
noncopyable/*assignment operator needed somewhere*/,
|
||||
CCritSec/*for multi-threaded access*/
|
||||
{
|
||||
minibatchsource & source; // the underlying source we read from
|
||||
const size_t epochframes; // epoch size
|
||||
unique_ptr<msra::util::simplethread> thread;
|
||||
int verbosity;
|
||||
// the FIFO
|
||||
struct batchdata // all arguments to/from getbatch
|
||||
{
|
||||
size_t globalts; // time for which we get the data
|
||||
// return values
|
||||
msra::dbn::matrix feat;
|
||||
std::vector<size_t> uids;
|
||||
std::vector<const_array_ref<msra::lattices::lattice::htkmlfwordsequence::word>> transcripts;
|
||||
std::vector<shared_ptr<const latticesource::latticepair>> lattices;
|
||||
batchdata (size_t globalts) : globalts (globalts) { }
|
||||
};
|
||||
deque<batchdata> fifo; // this is guarded by the CCritSec
|
||||
size_t epoch; // which epoch we are in currently
|
||||
// parameters for the thread proc (set by caller; taken over once newglobalts is set to non-SIZE_MAX (cleared back by thread))
|
||||
volatile size_t newglobalts; // reset request
|
||||
volatile size_t currentepochreqframes; // minibatch size for this epoch (taken from the first getbatch() call)
|
||||
volatile size_t currentepochendframe; // we cannot request beyond
|
||||
// signalling
|
||||
mutable msra::util::signallingevent callerchangedsignal, threadchangedsignal;
|
||||
void waitcallerchanged() const { callerchangedsignal.wait(); }
|
||||
void flagcallerchanged() const { callerchangedsignal.flag(); }
|
||||
void waitthreadchanged() const { threadchangedsignal.wait(); }
|
||||
void flagthreadchanged() const { threadchangedsignal.flag(); }
|
||||
// the thread proc
|
||||
volatile bool terminaterequest; // threadproc must respond to this
|
||||
size_t globalts; // read cursor, owned by thread only
|
||||
void threadproc()
|
||||
{
|
||||
// note on signaling:
|
||||
// This thread will always flag 'threadchangedsignal' if there is a state change,
|
||||
// e.g. a new batch is available, or we have successfully initialized.
|
||||
// The main ('caller') thread would check whether it finds a state it can make use of, and if not,
|
||||
// it will wait for the 'threadchangedsignal' and then check again the state etc.
|
||||
fprintf (stderr, "minibatchreadaheadsource: read-ahead thread entered\n");
|
||||
try
|
||||
{
|
||||
size_t epochreqframes = 0; // minibatch size for this epoch (taken from the first getbatch() call)
|
||||
size_t epochendframe = 0; // we cannot request beyond
|
||||
size_t globalts = 0; // reset request
|
||||
while (!terminaterequest)
|
||||
{
|
||||
bool stillhasdata;
|
||||
{
|
||||
CAutoLock lock (*this);
|
||||
// if reset request then do it
|
||||
if (newglobalts != SIZE_MAX)
|
||||
{
|
||||
// take over parameters from caller
|
||||
globalts = newglobalts;
|
||||
epochreqframes = currentepochreqframes;
|
||||
epochendframe = currentepochendframe;
|
||||
newglobalts = SIZE_MAX; // remember we got it
|
||||
// reset the FIFO
|
||||
fifo.clear();
|
||||
flagthreadchanged(); // signal state change (needed?)
|
||||
fprintf (stderr, "minibatchreadaheadsource: thread entered new epoch, frame pos reset to %d\n", (int) globalts);
|
||||
continue;
|
||||
}
|
||||
// did we run out of data to give to the caller?
|
||||
stillhasdata = !fifo.empty();
|
||||
}
|
||||
// we kick in once the FIFO is empty (and only once we know the mbsize)
|
||||
// Note that the underlying source will be able to fulfill many more minibatches at no cost
|
||||
// since we stopped pulling minibatches from it once it told us it read something from the disk.
|
||||
// Thus it is OK (efficient) to run the FIFO empty before we continue asking the underlying source
|
||||
// for more data--it will give us quite some more data for free--which the caller can go and process--
|
||||
// before an expensive read operation is needed again.
|
||||
if (globalts >= epochendframe || stillhasdata)
|
||||
{
|
||||
waitcallerchanged(); // nothing to do: wait for caller state change and check again
|
||||
continue;
|
||||
}
|
||||
// we will bring in data from the current 'globalts' until the sub-getbatch() tells us
|
||||
// that we loaded new data (which means subsequent getbatch() will be free until the next load).
|
||||
// We assume the access pattern that
|
||||
// - we start at or closely after the epoch boundary
|
||||
// - we never go across an epoch boundary
|
||||
// - the number of requested frames within an epoch is always the same except for the last MB
|
||||
// This pattern is implemented by the minibatchiterator. We require it.
|
||||
// (but it is possible that less is returned, i.e. at a sweep boundary or epoch end).
|
||||
bool readfromdisk = false;
|
||||
// we stop once data was read (the subsequent fetches will be cheap until the next data read)
|
||||
// For small setups, all data may be in RAM and thus no reading will happen anymore.
|
||||
// To guard against that, we limit the number of frames we pre-read.
|
||||
fprintf (stderr, "minibatchreadaheadsource: thread entering reading loop, frame read pos %d\n", (int) globalts);
|
||||
size_t batchesread = 0;
|
||||
const size_t prerollendframe = globalts + 360000; // read max. 1 hour --to guard against setups that fit to RAM entirely (no disk reading after startup)
|
||||
while (!terminaterequest && !readfromdisk && globalts < epochendframe && globalts < prerollendframe)
|
||||
{
|
||||
// get batch and append to FIFO (outside the lock)
|
||||
batchdata batch (globalts);
|
||||
const size_t requestedframes = min (epochreqframes, epochendframe - globalts); // we must not request beyond the epoch
|
||||
readfromdisk = source.getbatch (globalts, requestedframes, batch.feat, batch.uids, batch.transcripts, batch.lattices);
|
||||
batchesread++;
|
||||
// Note: We may still get data beyond the end of the epoch, in utterance mode, since the epoch boundary likely falls within an utterance.
|
||||
CAutoLock lock (*this);
|
||||
if (!fifo.empty() && globalts != fifo.back().globalts + fifo.back().feat.cols())
|
||||
throw std::logic_error ("minibatchreadaheadsource: FIFO got out of order while pre-reading new batch");
|
||||
if (newglobalts != SIZE_MAX)
|
||||
throw std::logic_error ("minibatchreadaheadsource: main thread reset to new epoch while current epoch not yet finished");
|
||||
globalts += batch.feat.cols();
|
||||
fifo.push_back (std::move (batch));
|
||||
flagthreadchanged(); // signal state change so caller can pick up the new batch
|
||||
}
|
||||
fprintf (stderr, "minibatchreadaheadsource: thread exited reading loop, %d batches read up to frame position %d-1\n", (int) batchesread, (int) globalts);
|
||||
}
|
||||
fprintf (stderr, "minibatchreadaheadsource: reading loop was terminated at frame position %d-1\n", (int) globalts);
|
||||
}
|
||||
catch (const exception & e)
|
||||
{
|
||||
fprintf (stderr, "minibatchreadaheadsource: exception caught in read-ahead thread: %s\n", e.what());
|
||||
thread->fail (e); // set the error first before we signal the caller
|
||||
flagthreadchanged();
|
||||
throw; // (this will set the error a second time; OK)
|
||||
}
|
||||
fprintf (stderr, "minibatchreadaheadsource: read-ahead thread exited normally\n");
|
||||
}
|
||||
void cancelthread() // this is only ever called by the destructor
|
||||
{
|
||||
fprintf (stderr, "minibatchreadaheadsource: requesting thread termination\n");
|
||||
terminaterequest = true;
|
||||
flagcallerchanged();
|
||||
thread->wait();
|
||||
}
|
||||
public:
|
||||
minibatchreadaheadsource (minibatchsource & source, size_t epochframes)
|
||||
: source (source), epochframes (epochframes),
|
||||
terminaterequest (false), globalts (SIZE_MAX),
|
||||
epoch (SIZE_MAX), currentepochreqframes (0), currentepochendframe (0), newglobalts (SIZE_MAX), verbosity(2)
|
||||
{
|
||||
// kick off the thread
|
||||
fprintf (stderr, "minibatchreadaheadsource: kicking off read-ahead thread\n");
|
||||
thread.reset (new msra::util::simplethread ([this] () { threadproc(); }));
|
||||
}
|
||||
~minibatchreadaheadsource()
|
||||
{
|
||||
fprintf (stderr, "~minibatchreadaheadsource: destructing read-ahead thread\n");
|
||||
cancelthread();
|
||||
}
|
||||
void setverbosity(int newverbosity){ verbosity = newverbosity; }
|
||||
bool getbatch (const size_t globalts,
|
||||
const size_t framesrequested, msra::dbn::matrix & feat, std::vector<size_t> & uids,
|
||||
std::vector<const_array_ref<msra::lattices::lattice::htkmlfwordsequence::word>> & transcripts,
|
||||
std::vector<shared_ptr<const latticesource::latticepair>> & lattices)
|
||||
{
|
||||
#if 1
|
||||
// first check whether the thread is still alive
|
||||
thread->check();
|
||||
// in case of epoch change, we signal the thread
|
||||
size_t thisepoch = globalts / epochframes;
|
||||
if (thisepoch != epoch)
|
||||
{
|
||||
fprintf (stderr, "minibatchreadaheadsource: signalling thread to enter new epoch\n");
|
||||
epoch = thisepoch; // remember for next check --we have officially changed epochs
|
||||
CAutoLock lock (*this);
|
||||
if (!fifo.empty())
|
||||
throw std::logic_error ("getbatch: FIFO not cleared at end of epoch");
|
||||
newglobalts = globalts;
|
||||
currentepochreqframes = framesrequested; // it is assumed that these won't change
|
||||
currentepochendframe = (epoch + 1) * epochframes;
|
||||
flagcallerchanged();
|
||||
}
|
||||
else if (globalts + framesrequested < currentepochendframe && currentepochreqframes != framesrequested)
|
||||
throw std::logic_error ("getbatch: cannot change minibatch size mid-epoch");
|
||||
// loop
|
||||
bool readfromdisk = false;
|
||||
for(;;) // wait for batch to appear
|
||||
{
|
||||
thread->check();
|
||||
{
|
||||
CAutoLock lock (*this);
|
||||
if (!fifo.empty())
|
||||
{
|
||||
// get the first batch from the FIFO
|
||||
batchdata front = std::move (fifo.front());
|
||||
fifo.pop_front();
|
||||
flagcallerchanged();
|
||||
// it must be the correct one
|
||||
if (front.globalts != globalts)
|
||||
throw std::logic_error ("getbatch: data in FIFO out of sequence");
|
||||
// return it
|
||||
feat = std::move (front.feat);
|
||||
uids = std::move (front.uids);
|
||||
transcripts = std::move (front.transcripts);
|
||||
lattices = std::move (front.lattices);
|
||||
return readfromdisk;
|
||||
}
|
||||
}
|
||||
// batch not there --keep looping
|
||||
waitthreadchanged();
|
||||
readfromdisk = true; // we had to wait --use to indicate that we needed to read data (does not really matter...)
|
||||
}
|
||||
#else
|
||||
return source.getbatch (globalts, framesrequested, feat, uids, transcripts, lattices);
|
||||
#endif
|
||||
}
|
||||
bool getbatch (const size_t globalts,
|
||||
const size_t framesrequested, std::vector<msra::dbn::matrix> & feat, std::vector<std::vector<size_t>> & uids,
|
||||
std::vector<const_array_ref<msra::lattices::lattice::htkmlfwordsequence::word>> & transcripts,
|
||||
std::vector<shared_ptr<const latticesource::latticepair>> & lattices)
|
||||
{
|
||||
|
||||
feat.resize(1);
|
||||
uids.resize(1);
|
||||
//transcripts.resize(1);
|
||||
//lattices.resize(1);
|
||||
return getbatch(globalts, framesrequested, feat[0], uids[0], transcripts, lattices);
|
||||
}
|
||||
|
||||
size_t totalframes() const { return source.totalframes(); }
|
||||
size_t epochsize() const {return epochframes;}double gettimegetbatch() { return source.gettimegetbatch(); } // TODO: no, use our own time measurement
|
||||
size_t firstvalidglobalts (const size_t globalts) { return source.firstvalidglobalts (globalts); }
|
||||
const std::vector<size_t> & unitcounts() const { return source.unitcounts(); }
|
||||
};
|
||||
|
||||
};};
|
|
@ -1,827 +0,0 @@
|
|||
//
|
||||
// <copyright file="rollingwindowsource.h" company="Microsoft">
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// </copyright>
|
||||
//
|
||||
// rollingwindowsource.h -- implementation of a rolling-window minibatch source ('minibatchframesource') with a disk page file
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "basetypes.h" // for attempt()
|
||||
//#include "numahelpers.h" // for NUMA allocation
|
||||
#include "minibatchsourcehelpers.h"
|
||||
#include "minibatchiterator.h"
|
||||
#include "biggrowablevectors.h"
|
||||
#include "ssematrix.h"
|
||||
|
||||
namespace msra { namespace dbn {
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// biggrowablevectorarray -- a big array of vectors for features, growable (push_back)
|
||||
// Data is striped across NUMA nodes, as to not clog them up.
|
||||
// This also supports paging to disk, which is used for the old minibatchframesource.
|
||||
// ---------------------------------------------------------------------------
|
||||
class biggrowablevectorarray : public growablevectorbase<msra::dbn::matrix>
|
||||
{
|
||||
size_t m; // dim
|
||||
|
||||
size_t inmembegin; // range we have in memory, rounded to enclosing blocks (not rounded at end)
|
||||
size_t inmemend;
|
||||
|
||||
wstring pagepath; // path for paging, empty if no paging
|
||||
auto_file_ptr f; // file handle for paging
|
||||
bool reading; // have we begun reading?
|
||||
|
||||
// allocate a block
|
||||
msra::dbn::matrix * newblock() const
|
||||
{
|
||||
// we stripe the data across NUMA nodes as to not fill up one node with the feature data
|
||||
//msra::numa::overridenode ((int) msra::numa::getmostspaciousnumanode());
|
||||
msra::dbn::matrix * res = new msra::dbn::matrix (m, elementsperblock);
|
||||
//msra::numa::overridenode (-1); // note: we really should reset it also in case of failure
|
||||
return res;
|
||||
}
|
||||
|
||||
// handling of page file
|
||||
bool paging() const { return !pagepath.empty(); }
|
||||
void openpagefile (bool wantread)
|
||||
{
|
||||
if (!paging()) return;
|
||||
msra::files::make_intermediate_dirs (pagepath);
|
||||
|
||||
if (!wantread)
|
||||
{
|
||||
FILE *ftry = NULL;
|
||||
wstring pathname (pagepath);
|
||||
ftry = _wfopen (pathname.c_str(), L"wbS");
|
||||
if (ftry) fclose (ftry);
|
||||
}
|
||||
|
||||
/*
|
||||
code below to cycle through a-z appended to file name is no longer necessary
|
||||
since caller guarantees unique file names via HTKMLFReader
|
||||
and we want the pagepath logged to the user to be the actual one used by the code
|
||||
|
||||
// try to open the pagepath from a to z
|
||||
if (!wantread)
|
||||
{
|
||||
FILE *ftry = NULL;
|
||||
char trynum = 'a';
|
||||
while (!ftry && trynum <= 'z')
|
||||
{
|
||||
wstring pathname (pagepath);
|
||||
pathname += trynum++;
|
||||
ftry = _wfopen (pathname.c_str(), L"wbS");
|
||||
}
|
||||
if (ftry) fclose (ftry);
|
||||
pagepath += --trynum;
|
||||
}
|
||||
*/
|
||||
f = fopenOrDie (pagepath, wantread ? L"rbS" : L"wbS");
|
||||
reading = wantread;
|
||||
}
|
||||
void flushlastblock() // during population phase, must be called once per block in sequence
|
||||
{
|
||||
if (!paging()) return;
|
||||
assert (!reading);
|
||||
if (blocks.empty()) return;
|
||||
const size_t blockid = blocks.size() -1;
|
||||
msra::dbn::matrix & block = *blocks[blockid];
|
||||
assert (fgetpos (f) == blockid * block.sizeinpagefile());
|
||||
block.topagefile (f);
|
||||
blocks[blockid].reset(); // free the memory
|
||||
assert (blockid * elementsperblock == inmembegin);
|
||||
inmembegin = inmemend; // empty range
|
||||
}
|
||||
void releaseblock (size_t t0) // t0=block start time
|
||||
{
|
||||
assert (paging() && reading);
|
||||
size_t blockid = t0 / elementsperblock;
|
||||
assert (blockid * elementsperblock == t0);
|
||||
assert (blocks[blockid]);
|
||||
fprintf (stderr, "recoverblock: releasing feature block %zu [%zu..%zu)\n", blockid, t0, t0 + elementsperblock -1);
|
||||
blocks[blockid].reset(); // free the memory
|
||||
}
|
||||
void recoverblock (size_t t0) // t0=block start time
|
||||
{
|
||||
assert (paging() && reading);
|
||||
size_t blockid = t0 / elementsperblock;
|
||||
assert (blockid * elementsperblock == t0);
|
||||
assert (!blocks[blockid]);
|
||||
fprintf (stderr, "recoverblock: recovering feature block %zu [%zu..%zu)\n", blockid, t0, t0 + elementsperblock -1);
|
||||
blocks[blockid].reset (newblock());
|
||||
msra::dbn::matrix & block = *blocks[blockid];
|
||||
fsetpos (f, blockid * block.sizeinpagefile());
|
||||
block.frompagefile (f);
|
||||
}
|
||||
|
||||
public:
|
||||
biggrowablevectorarray (const wstring & pagepath)
|
||||
: growablevectorbase (65536), m (0),
|
||||
inmembegin (0), inmemend (0), pagepath (pagepath), reading (false)
|
||||
{
|
||||
openpagefile (false);
|
||||
if (paging())
|
||||
fprintf (stderr, "biggrowablevectorarray: creating disk backup store at '%S'\n", pagepath.c_str());
|
||||
}
|
||||
~biggrowablevectorarray() { // clean up the big temp file
|
||||
if (paging()) {
|
||||
fclose (f);
|
||||
if (_wunlink (pagepath.c_str())==0)
|
||||
fprintf (stderr, "biggrowablevectorarray: deleted disk backup store at '%S'\n", pagepath.c_str());
|
||||
else
|
||||
fprintf (stderr, "biggrowablevectorarray: unable to delete disk backup store at '%S'\n", pagepath.c_str());
|
||||
}
|
||||
}
|
||||
|
||||
size_t dim() const { return m; } // dimension of a frame
|
||||
|
||||
// reading phase
|
||||
void push_back (const std::vector<float> & in)
|
||||
{
|
||||
assert (!in.empty());
|
||||
assert (m == 0 || m == in.size());
|
||||
m = in.size();
|
||||
const size_t blockid = n / elementsperblock;
|
||||
assert (blockid <= blocks.size());
|
||||
if (blockid == blocks.size()) // a new block is needed
|
||||
{
|
||||
flushlastblock();
|
||||
blocks.push_back (std::unique_ptr<msra::dbn::matrix> (newblock()));
|
||||
}
|
||||
const size_t blockn = n % elementsperblock;
|
||||
msra::dbn::matrix & block = *blocks[blockid].get();
|
||||
foreach_index (k, in)
|
||||
block(k,blockn) = in[k];
|
||||
n++;
|
||||
inmemend = n;
|
||||
}
|
||||
void no_more_push_back() // done pushing --switch to consumption mode
|
||||
{
|
||||
if (!paging()) return;
|
||||
// finish off last block
|
||||
flushlastblock();
|
||||
fflushOrDie (f);
|
||||
fprintf (stderr, "biggrowablevectorarray: disk backup store created, %d frames, %zu bytes\n", (int) n, fgetpos (f));
|
||||
fclose (f);
|
||||
foreach_index (i, blocks) assert (!blocks[i]); // ensure we flushed
|
||||
assert (inmembegin == inmemend); // nothing in cache
|
||||
// switch to reading mode
|
||||
openpagefile (true);
|
||||
}
|
||||
|
||||
// access phase
|
||||
// Returns 'true' if data was actually read from disk.
|
||||
bool require (pair<size_t,size_t> bounds) // we require this range of frames
|
||||
{
|
||||
bool readfromdisk = false;
|
||||
|
||||
// get bounds rounded to block boundaries
|
||||
const size_t ts = bounds.first / elementsperblock * elementsperblock;
|
||||
const size_t te = min (n, (bounds.second + elementsperblock -1) / elementsperblock * elementsperblock);
|
||||
assert (paging());
|
||||
// free all the memmory
|
||||
for (size_t t = inmembegin; t < inmemend; t += elementsperblock)
|
||||
{
|
||||
if (t >= ts && t < te) // if in wanted range then skip to end of it
|
||||
t = te - elementsperblock;
|
||||
else
|
||||
releaseblock (t);
|
||||
}
|
||||
// page in all required blocks
|
||||
for (size_t t = ts; t < te; t += elementsperblock)
|
||||
{
|
||||
if (t >= inmembegin && t < inmemend) // if in memory already then skip to end of it
|
||||
t = inmemend - elementsperblock;
|
||||
else
|
||||
{
|
||||
recoverblock (t);
|
||||
readfromdisk = true; // tell caller we did something expensive
|
||||
}
|
||||
}
|
||||
// got it
|
||||
inmembegin = ts;
|
||||
inmemend = te;
|
||||
return readfromdisk;
|
||||
}
|
||||
const msra::dbn::matrixstripe operator[] (size_t t) const // get a feature vector
|
||||
{
|
||||
if (t < inmembegin || t >= inmemend)
|
||||
throw std::logic_error ("biggrowablevectorarray: attempt to access vector without requesting to page it in first");
|
||||
const size_t blockt = getblockt (t);
|
||||
/*const*/ msra::dbn::matrix & block = getblock (t);
|
||||
return msra::dbn::matrixstripe (block, blockt, 1);
|
||||
}
|
||||
wstring pagepathname(){ return pagepath;}
|
||||
void cleanuppagefile()
|
||||
{
|
||||
if (paging()) {
|
||||
fclose (f);
|
||||
if (_wunlink (pagepath.c_str())==0){
|
||||
fprintf (stderr, "biggrowablevectorarray: deleted disk backup store at '%S'\n", pagepath.c_str());
|
||||
}
|
||||
else{
|
||||
fprintf (stderr, "biggrowablevectorarray: could NOT delete disk backup store at '%S'\n", pagepath.c_str());
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// minibatchframesource -- feature source to provide randomized frames in minibatches
|
||||
// This is the old code that pages all frames to a huge disk file first.
|
||||
// (The new minibatchutterancesource pages from input files directly and can also
|
||||
// operate in utterance mode for MMI training.)
|
||||
// ---------------------------------------------------------------------------
|
||||
class minibatchframesource : public minibatchsource
|
||||
{
|
||||
size_t vdim; // feature dimension after augmenting neighhors (0: don't read features)
|
||||
unsigned int sampperiod; // (for reference and to check against model)
|
||||
string featkind;
|
||||
size_t featdim;
|
||||
// cache
|
||||
biggrowablevectorarray frames; // [t][i] all features concatenated
|
||||
std::vector<char> boundaryflags; // [t] -1 for first and +1 for last frame, 0 else (for augmentneighbors())
|
||||
std::vector<CLASSIDTYPE> classids; // [t] the state that the frame belongs to
|
||||
size_t numframes; // total frames (==frames.size()==boundaryflags.size()==classids.size()) unless special modes vdim == 0 and/or no labels
|
||||
msra::dbn::randomordering randomordering; // [t] -> t'
|
||||
double timegetbatch;
|
||||
int verbosity;
|
||||
public:
|
||||
// constructor
|
||||
// Pass empty labels to denote unsupervised training (so getbatch() will not return uids).
|
||||
minibatchframesource (const std::vector<wstring> & infiles, const map<wstring,std::vector<msra::asr::htkmlfentry>> & labels,
|
||||
size_t vdim, size_t udim, size_t randomizationrange, const wstring & pagepath, const bool mayhavenoframe=false, int addEnergy=0)
|
||||
: vdim (vdim), sampperiod (0), featdim (0), numframes (0), frames (pagepath), timegetbatch (0), verbosity(2)
|
||||
{
|
||||
if (vdim == 0 && labels.empty())
|
||||
throw runtime_error ("minibatchframesource: when running without features, labels are needed");
|
||||
// at this stage, we simply page in the entire training set at once and work off RAM
|
||||
// We will benefit from feature archives indirectly through htkfeatio.
|
||||
// TODO:
|
||||
// - infiles must specify time range
|
||||
// - at this stage only reserve() (we know the time range; allocate second-layer structure)
|
||||
// - implement block-wise paging directly from HTK feature files through htkfeatreader
|
||||
featkind.clear();
|
||||
std::vector<float> frame;
|
||||
fprintf (stderr, "minibatchframesource: reading %zu utterances..", infiles.size());
|
||||
size_t numclasses = 0; // number of units found (actually max id +1)
|
||||
size_t notfound = 0; // number of entries missing in MLF
|
||||
msra::asr::htkfeatreader reader; // feature reader
|
||||
reader.AddEnergy(addEnergy);
|
||||
|
||||
foreach_index (i, infiles)
|
||||
{
|
||||
if (i % (infiles.size() / 100 + 1) == 0) { fprintf (stderr, "."); fflush (stderr); }
|
||||
msra::basetypes::matrix<float> feat;
|
||||
msra::asr::htkfeatreader::parsedpath ppath (infiles[i]);
|
||||
|
||||
// skip files for which labels don't exist (assuming bad alignment)
|
||||
wstring key;
|
||||
if (!labels.empty()) // empty means unsupervised mode (don't load any)
|
||||
{
|
||||
#ifdef _WIN32
|
||||
key = regex_replace ((wstring)ppath, wregex (L"\\.[^\\.\\\\/:]*$"), wstring()); // delete extension (or not if none)
|
||||
#endif
|
||||
#ifdef __unix__
|
||||
key = removeExtension(basename(ppath));
|
||||
#endif
|
||||
if (labels.find (key) == labels.end())
|
||||
{
|
||||
if (notfound < 5)
|
||||
fprintf (stderr, "\nminibatchframesource: %d-th file not found in MLF label set: %S", i, key.c_str());
|
||||
notfound++;
|
||||
continue; // skip this utterance at all
|
||||
}
|
||||
}
|
||||
|
||||
// get feature frames
|
||||
if (vdim != 0) // (vdim == special mode to not read features at all)
|
||||
{
|
||||
msra::util::attempt (5, [&]()
|
||||
{
|
||||
reader.read (ppath, featkind, sampperiod, feat); // whole file read as columns of feature vectors
|
||||
});
|
||||
if (featdim == 0) // first time
|
||||
featdim = feat.rows();
|
||||
else if (featdim != feat.rows())
|
||||
throw std::runtime_error ("minibatchframesource: inconsistent feature dimension across files");
|
||||
// HVite occasionally generates mismatching output --skip such files
|
||||
if (!key.empty()) // (we have a key if supervised mode)
|
||||
{
|
||||
const auto & labseq = labels.find (key)->second; // (we already checked above that it exists)
|
||||
size_t labframes = labseq.empty() ? 0 : (labseq[labseq.size()-1].firstframe + labseq[labseq.size()-1].numframes);
|
||||
if (abs ((int) labframes - (int) feat.cols()) > 0)
|
||||
{
|
||||
fprintf (stderr, "\nminibatchframesource: %d-th file has small duration mismatch (%zu in label vs. %zu in feat file), skipping: %S", i, labframes, feat.cols(), key.c_str());
|
||||
notfound++;
|
||||
continue; // skip this utterance at all
|
||||
}
|
||||
}
|
||||
// append to cache
|
||||
frame.resize (featdim);
|
||||
if (feat.cols() < 2) // (2 frames needed for boundary markers)
|
||||
throw std::runtime_error ("minibatchframesource: utterances < 2 frames not supported");
|
||||
foreach_column (t, feat)
|
||||
{
|
||||
foreach_index (k, frame)
|
||||
frame[k] = feat(k,t);
|
||||
frames.push_back (frame);
|
||||
numframes++;
|
||||
boundaryflags.push_back ((t == 0) ? -1 : (t == feat.cols() -1) ? +1 : 0);
|
||||
}
|
||||
assert (numframes == frames.size());
|
||||
assert (numframes == boundaryflags.size());
|
||||
}
|
||||
|
||||
// get label sequence
|
||||
if (!key.empty()) // (we have a key if supervised mode)
|
||||
{
|
||||
const auto & labseq = labels.find (key)->second; // (we already checked above that it exists)
|
||||
foreach_index (i, labseq)
|
||||
{
|
||||
const auto & e = labseq[i];
|
||||
if ((i > 0 && labseq[i-1].firstframe + labseq[i-1].numframes != e.firstframe) || (i == 0 && e.firstframe != 0))
|
||||
throw std::runtime_error (msra::strfun::strprintf ("minibatchframesource: labels not in consecutive order MLF in label set: %S", key.c_str()));
|
||||
for (size_t t = e.firstframe; t < e.firstframe + e.numframes; t++)
|
||||
{
|
||||
if (e.classid >= udim)
|
||||
throw std::runtime_error (msra::strfun::strprintf ("minibatchframesource: class id exceeds model dimension in file %S", key.c_str()));
|
||||
if (e.classid != (CLASSIDTYPE) e.classid)
|
||||
throw std::runtime_error ("CLASSIDTYPE has too few bits");
|
||||
classids.push_back ((CLASSIDTYPE) e.classid);
|
||||
numclasses = max ((size_t)numclasses, (size_t)(1u + e.classid));
|
||||
}
|
||||
}
|
||||
if (vdim == 0)
|
||||
numframes = classids.size();
|
||||
if (numframes != classids.size()) // TODO: remove this once we are confident
|
||||
throw std::runtime_error (msra::strfun::strprintf ("minibatchframesource: label duration inconsistent with feature file in MLF label set: %S", key.c_str()));
|
||||
assert (numframes == classids.size());
|
||||
}
|
||||
else
|
||||
{
|
||||
assert (classids.empty()); // that's how we detect it later
|
||||
}
|
||||
}
|
||||
assert (vdim == 0 || numframes == frames.size());
|
||||
assert (labels.empty() || numframes == classids.size());
|
||||
if ((vdim != 0 && numframes != frames.size()) || (!labels.empty() && numframes != classids.size()))
|
||||
throw std::runtime_error ("minibatchframesource: numframes variable screwup");
|
||||
fprintf (stderr, " %zu frames read from %zu utterances; %zu classes\n", numframes, infiles.size(), numclasses);
|
||||
if (notfound > 0)
|
||||
{
|
||||
fprintf (stderr, "minibatchframesource: %zu files out of %zu not found in label set\n", notfound, infiles.size());
|
||||
if (notfound > infiles.size() / 2)
|
||||
throw std::runtime_error ("minibatchframesource: too many files not found in label set--assuming broken configuration\n");
|
||||
}
|
||||
|
||||
if (numframes == 0 && !mayhavenoframe)
|
||||
throw std::runtime_error ("minibatchframesource: no input features given!");
|
||||
|
||||
// notify frames source to switch from population to consumption mode
|
||||
frames.no_more_push_back();
|
||||
|
||||
// initialize randomizer
|
||||
if (numframes > 0)
|
||||
randomordering.resize (numframes, randomizationrange);
|
||||
}
|
||||
virtual ~minibatchframesource() {}
|
||||
size_t totalframes() const { assert (vdim == 0 || numframes == frames.size()); assert (!issupervised() || numframes == classids.size()); return numframes; }
|
||||
|
||||
bool issupervised() const { return !classids.empty(); }
|
||||
|
||||
void setverbosity(int newverbosity) { verbosity = newverbosity; }
|
||||
|
||||
// retrieve one minibatch
|
||||
// Minibatches are deterministic pseudo-random samples. The entire corpus
|
||||
// is repeated infinitely, but each repetition (a 'sweep') is randomized
|
||||
// differently.
|
||||
// This function allows to retrieve a mini-batch starting from any frame
|
||||
// within this infinitely extended repetition. To the end, mini-batches are
|
||||
// specified by start frame and #frames.
|
||||
// This function returns the same data independent on #frames, i.e. the concept
|
||||
// of the mini-batch is not defined in here, but on the caller side. The caller
|
||||
// can retrieve the frames of a mini-batch in chunks that do not match the
|
||||
// caller's definition of "mini-batch," e.g. bigger or smaller chunks.
|
||||
// If a requested mini-batch spans a sweep boundary, then this function will
|
||||
// not return samples after the sweep boundary. Instead, the returned frame
|
||||
// set is shortened to not exceed the end of the sweep. The caller must make
|
||||
// a separate second call to get the rest. In trainlayer(), the one
|
||||
// sweep-boundary-spanning mini-batch will simply be shortened.
|
||||
// This function is NOT thread-safe (due to caching of random sequence).
|
||||
bool getbatch (const size_t globalts, const size_t framesrequested, msra::dbn::matrix & feat, std::vector<size_t> & uids,
|
||||
std::vector<const_array_ref<msra::lattices::lattice::htkmlfwordsequence::word>> & transcripts,
|
||||
std::vector<shared_ptr<const latticesource::latticepair>> & latticepairs)
|
||||
{
|
||||
auto_timer timergetbatch;
|
||||
|
||||
transcripts.clear(); // word-level transcripts not supported by frame source (aimed at MMI)
|
||||
latticepairs.clear(); // neither are lattices
|
||||
|
||||
assert (totalframes() > 0);
|
||||
const size_t sweep = globalts / totalframes(); // which sweep (this determines randomization)
|
||||
const size_t ts = globalts % totalframes(); // start frame within the sweep
|
||||
const size_t te = min (ts + framesrequested, totalframes()); // do not go beyond sweep boundary
|
||||
assert (te > ts);
|
||||
if (verbosity >= 2)
|
||||
fprintf (stderr, "getbatch: frames [%zu..%zu] in sweep %zu\n", ts, te-1, sweep);
|
||||
|
||||
// get random sequence (each time index occurs exactly once)
|
||||
// If the sweep changes, this will re-cache the sequence. We optimize for rare, monotonous sweep changes.
|
||||
const auto & tmap = randomordering (sweep);
|
||||
|
||||
// page in the needed range of frames
|
||||
const size_t extent = augmentationextent (frames.dim(), vdim);
|
||||
bool readfromdisk = frames.require (randomordering.bounds (max (ts, extent) - extent, te + 1 + extent));
|
||||
|
||||
// generate features and uids
|
||||
feat.resize (vdim, te - ts); // note: special mode vdim == 0 means no features to be loaded
|
||||
if (issupervised()) // empty means unsupervised training -> return empty uids
|
||||
uids.resize (te - ts);
|
||||
else
|
||||
uids.clear();
|
||||
for (size_t t = ts; t < te; t++)
|
||||
{
|
||||
size_t trand = tmap[t]; // the random-sequence sample point for this point in time
|
||||
if (vdim != 0)
|
||||
{
|
||||
auto v_t = feat.col(t-ts); // the vector to fill in
|
||||
augmentneighbors (frames, boundaryflags, trand, v_t);
|
||||
}
|
||||
if (issupervised())
|
||||
uids[t-ts] = classids[trand];
|
||||
}
|
||||
timegetbatch = timergetbatch;
|
||||
return readfromdisk;
|
||||
}
|
||||
|
||||
bool getbatch (const size_t globalts, const size_t framesrequested, std::vector<msra::dbn::matrix> & feat, std::vector<std::vector<size_t>> & uids,
|
||||
std::vector<const_array_ref<msra::lattices::lattice::htkmlfwordsequence::word>> & transcripts,
|
||||
std::vector<shared_ptr<const latticesource::latticepair>> & latticepairs)
|
||||
{
|
||||
// for single input/output set size to be 1 and run old getbatch
|
||||
feat.resize(1);
|
||||
uids.resize(1);
|
||||
//transcripts.resize(1);
|
||||
//latticepairs.resize(1);
|
||||
return getbatch(globalts, framesrequested, feat[0], uids[0], transcripts, latticepairs);
|
||||
}
|
||||
|
||||
double gettimegetbatch () { return timegetbatch;}
|
||||
|
||||
// return first valid globalts to ask getbatch() for
|
||||
// In frame mode, there is no constraint, i.e. it is 'globalts' itself.
|
||||
/*implement*/ size_t firstvalidglobalts (const size_t globalts) { return globalts; }
|
||||
|
||||
/*implement*/ const std::vector<size_t> & unitcounts() const { throw logic_error ("unitcounts: not implemented for this feature source"); static std::vector<size_t> x; return x;/*keep compiler happy*/ }
|
||||
};
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// minibatchframesourcemulti -- feature source to provide randomized frames in minibatches
|
||||
// this is derived from minibatchframesource but worked with multiple inputs and/or outputs
|
||||
// by making "frames" and "classids" a vector of vectors
|
||||
// ---------------------------------------------------------------------------
|
||||
class minibatchframesourcemulti : public minibatchsource
|
||||
{
|
||||
std::vector<size_t> vdim; // feature dimension after augmenting neighhors (0: don't read features)
|
||||
std::vector<size_t> leftcontext; // number of frames to the left of the target frame in the context window
|
||||
std::vector<size_t> rightcontext; // number of frames to the right of the target frame in the context window
|
||||
unsigned int sampperiod; // (for reference and to check against model)
|
||||
string featkind;
|
||||
size_t featdim;
|
||||
size_t maxvdim;
|
||||
// cache
|
||||
//std::vector<biggrowablevectorarray> frames;
|
||||
std::vector<unique_ptr<biggrowablevectorarray>> pframes; // [t][i] all features concatenated
|
||||
std::vector<char> boundaryflags; // [t] -1 for first and +1 for last frame, 0 else (for augmentneighbors())
|
||||
std::vector<std::vector<CLASSIDTYPE>> classids; // [t] the state that the frame belongs to
|
||||
size_t numframes; // total frames (==frames.size()==boundaryflags.size()==classids.size()) unless special modes vdim == 0 and/or no labels
|
||||
msra::dbn::randomordering randomordering; // [t] -> t'
|
||||
double timegetbatch;
|
||||
int verbosity;
|
||||
|
||||
public:
|
||||
// constructor
|
||||
// Pass empty labels to denote unsupervised training (so getbatch() will not return uids).
|
||||
minibatchframesourcemulti (const std::vector<std::vector<wstring>> & infiles, const std::vector<map<std::wstring,std::vector<msra::asr::htkmlfentry>>> & labels,
|
||||
std::vector<size_t> vdim, std::vector<size_t> udim, std::vector<size_t> leftcontext, std::vector<size_t> rightcontext, size_t randomizationrange, const std::vector<wstring> & pagepath, const bool mayhavenoframe=false, int addEnergy=0)
|
||||
: vdim (vdim), leftcontext(leftcontext), rightcontext(rightcontext), sampperiod (0), featdim (0), numframes (0), timegetbatch (0), verbosity(2), maxvdim(0)
|
||||
{
|
||||
|
||||
if (vdim[0] == 0 && labels.empty())
|
||||
throw runtime_error ("minibatchframesourcemulti: when running without features, labels are needed");
|
||||
// at this stage, we simply page in the entire training set at once and work off RAM
|
||||
// We will benefit from feature archives indirectly through htkfeatio.
|
||||
// TODO:
|
||||
// - infiles must specify time range
|
||||
// - at this stage only reserve() (we know the time range; allocate second-layer structure)
|
||||
// - implement block-wise paging directly from HTK feature files through htkfeatreader
|
||||
featkind.clear();
|
||||
std::vector<float> frame;
|
||||
std::vector<size_t>numclasses; // number of units found (actually max id +1)
|
||||
size_t notfound = 0; // number of entries missing in MLF
|
||||
|
||||
|
||||
std::vector<size_t>framesaccum;
|
||||
|
||||
if (infiles.size()==0)
|
||||
throw runtime_error("minibatchframesourcemulti: need at least one network input specified with features");
|
||||
|
||||
if (labels.size()==0)
|
||||
fprintf(stderr,"no MLF label files detected\n");
|
||||
|
||||
foreach_index (i, infiles)
|
||||
{
|
||||
pframes.push_back(unique_ptr<biggrowablevectorarray>(new biggrowablevectorarray(pagepath[i])));
|
||||
|
||||
if (vdim[i]>maxvdim)
|
||||
maxvdim=vdim[i];
|
||||
}
|
||||
|
||||
|
||||
foreach_index (i, labels)
|
||||
{
|
||||
classids.push_back(std::vector<CLASSIDTYPE>());
|
||||
numclasses.push_back(0);
|
||||
}
|
||||
|
||||
|
||||
fprintf (stderr, "minibatchframesourcemulti: reading %zu feature sets and %zu label sets...", infiles.size(),labels.size());
|
||||
|
||||
foreach_index (m, infiles)
|
||||
{
|
||||
|
||||
|
||||
featdim=0;
|
||||
numframes=0;
|
||||
featkind.clear();
|
||||
msra::asr::htkfeatreader reader; // feature reader
|
||||
reader.AddEnergy(addEnergy);
|
||||
|
||||
foreach_index (i, infiles[m]) // read each feature file in set m
|
||||
{
|
||||
if (i % (infiles[m].size() / 100 + 1) == 0) { fprintf (stderr, "."); fflush (stderr); }
|
||||
msra::basetypes::matrix<float> feat;
|
||||
msra::asr::htkfeatreader::parsedpath ppath (infiles[m][i]);
|
||||
|
||||
// skip files for which labels don't exist (assuming bad alignment)
|
||||
wstring key;
|
||||
if (!labels.empty())
|
||||
{
|
||||
if (!labels[0].empty()) // empty means unsupervised mode (don't load any)
|
||||
{
|
||||
#ifdef _WIN32
|
||||
key = regex_replace ((wstring)ppath, wregex (L"\\.[^\\.\\\\/:]*$"), wstring()); // delete extension (or not if none)
|
||||
#endif
|
||||
#ifdef __unix__
|
||||
key = removeExtension(basename(ppath));
|
||||
#endif
|
||||
if (labels[0].find (key) == labels[0].end())
|
||||
{
|
||||
if (notfound < 5)
|
||||
fprintf (stderr, "\nminibatchframesourcemulti: %d-th file not found in MLF label set: %S", i, key.c_str());
|
||||
notfound++;
|
||||
continue; // skip this utterance at all
|
||||
}
|
||||
}
|
||||
}
|
||||
// get feature frames
|
||||
if (vdim[m] != 0) // (vdim == special mode to not read features at all)
|
||||
{
|
||||
msra::util::attempt (5, [&]()
|
||||
{
|
||||
reader.read (ppath, featkind, sampperiod, feat); // whole file read as columns of feature vectors
|
||||
});
|
||||
if (featdim == 0) // first time
|
||||
featdim = feat.rows();
|
||||
else if (featdim != feat.rows())
|
||||
throw std::runtime_error ("minibatchframesourcemulti: inconsistent feature dimension across files");
|
||||
// HVite occasionally generates mismatching output --skip such files
|
||||
if (!key.empty()) // (we have a key if supervised mode)
|
||||
{
|
||||
const auto & labseq = labels[0].find (key)->second; // (we already checked above that it exists)
|
||||
size_t labframes = labseq.empty() ? 0 : (labseq[labseq.size()-1].firstframe + labseq[labseq.size()-1].numframes);
|
||||
if (abs ((int) labframes - (int) feat.cols()) > 0)
|
||||
{
|
||||
fprintf (stderr, "\nminibatchframesourcemulti: %d-th file has small duration mismatch (%zu in label vs. %zu in feat file), skipping: %S", i, labframes, feat.cols(), key.c_str());
|
||||
notfound++;
|
||||
continue; // skip this utterance at all
|
||||
}
|
||||
}
|
||||
// append to cache
|
||||
frame.resize (featdim);
|
||||
if (feat.cols() < 2) // (2 frames needed for boundary markers)
|
||||
throw std::runtime_error ("minibatchframesourcemulti: utterances < 2 frames not supported");
|
||||
foreach_column (t, feat)
|
||||
{
|
||||
foreach_index (k, frame)
|
||||
frame[k] = feat(k,t);
|
||||
|
||||
pframes[m]->push_back (frame);
|
||||
numframes++;
|
||||
if (m==0)
|
||||
boundaryflags.push_back ((t == 0) ? -1 : (t == feat.cols() -1) ? +1 : 0);
|
||||
}
|
||||
if (m==0)
|
||||
framesaccum.push_back(numframes);
|
||||
else
|
||||
assert(numframes == framesaccum[i]);
|
||||
|
||||
assert (numframes == pframes[m]->size());
|
||||
}
|
||||
if (m==0)
|
||||
assert (numframes == boundaryflags.size());
|
||||
|
||||
|
||||
|
||||
if (m==0) // after we get the key for this file, read all labels (only done for first feature)
|
||||
{
|
||||
if (!key.empty())
|
||||
{
|
||||
foreach_index (j, labels)
|
||||
{
|
||||
const auto & labseq = labels[j].find (key)->second; // (we already checked above that it exists)
|
||||
foreach_index (i, labseq)
|
||||
{
|
||||
const auto & e = labseq[i];
|
||||
if ((i > 0 && labseq[i-1].firstframe + labseq[i-1].numframes != e.firstframe) || (i == 0 && e.firstframe != 0))
|
||||
throw std::runtime_error (msra::strfun::strprintf ("minibatchframesourcemulti: labels not in consecutive order MLF in label set: %S", key.c_str()));
|
||||
for (size_t t = e.firstframe; t < e.firstframe + e.numframes; t++)
|
||||
{
|
||||
if (e.classid >= udim[j])
|
||||
throw std::runtime_error (msra::strfun::strprintf ("minibatchframesourcemulti: class id exceeds model dimension in file %S", key.c_str()));
|
||||
if (e.classid != (CLASSIDTYPE) e.classid)
|
||||
throw std::runtime_error ("CLASSIDTYPE has too few bits");
|
||||
classids[j].push_back ((CLASSIDTYPE) e.classid);
|
||||
numclasses[j] = max (numclasses[j], (long unsigned int)(1u + e.classid));
|
||||
}
|
||||
}
|
||||
if (vdim[m] == 0)
|
||||
numframes = classids[j].size();
|
||||
if (numframes != classids[j].size()) // TODO: remove this once we are confident
|
||||
throw std::runtime_error (msra::strfun::strprintf ("minibatchframesourcemulti: label duration inconsistent with feature file in MLF label set: %S", key.c_str()));
|
||||
assert (numframes == classids[j].size());
|
||||
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
assert(classids.empty());
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
||||
assert (vdim[m] == 0 || numframes == pframes[m]->size());
|
||||
|
||||
foreach_index(j, labels)
|
||||
assert (labels[j].empty() || numframes == classids[j].size());
|
||||
|
||||
if (vdim[m] != 0 && numframes != pframes[m]->size()) // || (!labels.empty() && numframes != classids.size()))
|
||||
throw std::runtime_error ("\nminibatchframesource: numframes variable screwup");
|
||||
if (m==0)
|
||||
{
|
||||
foreach_index (j, numclasses)
|
||||
fprintf (stderr, "\nminibatchframesourcemulti: read label set %d: %zu classes\n", j, numclasses[j]);
|
||||
}
|
||||
fprintf (stderr, "\nminibatchframesourcemulti: feature set %d: %zu frames read from %zu utterances\n", m, pframes[m]->size(), infiles[m].size());
|
||||
if (notfound > 0)
|
||||
{
|
||||
fprintf (stderr, "minibatchframesourcemulti: %zu files out of %zu not found in label set\n", notfound, infiles[m].size());
|
||||
if (notfound > infiles[m].size() / 2)
|
||||
throw std::runtime_error ("minibatchframesourcemulti: too many files not found in label set--assuming broken configuration\n");
|
||||
}
|
||||
// notify frames source to switch from population to consumption mode
|
||||
pframes[m]->no_more_push_back();
|
||||
|
||||
}
|
||||
|
||||
if (numframes == 0 && !mayhavenoframe)
|
||||
throw std::runtime_error ("minibatchframesource: no input features given!");
|
||||
|
||||
|
||||
// initialize randomizer
|
||||
if (numframes > 0)
|
||||
randomordering.resize (numframes, randomizationrange);
|
||||
|
||||
}
|
||||
virtual ~minibatchframesourcemulti() {}
|
||||
size_t totalframes() const {
|
||||
assert (maxvdim == 0 || numframes == pframes[0]->size()); assert (!issupervised() || numframes == classids[0].size()); return numframes; }
|
||||
|
||||
bool issupervised() const { return !classids.empty(); }
|
||||
|
||||
void setverbosity(int newverbosity) { verbosity = newverbosity; }
|
||||
|
||||
// retrieve one minibatch
|
||||
// Minibatches are deterministic pseudo-random samples. The entire corpus
|
||||
// is repeated infinitely, but each repetition (a 'sweep') is randomized
|
||||
// differently.
|
||||
// This function allows to retrieve a mini-batch starting from any frame
|
||||
// within this infinitely extended repetition. To the end, mini-batches are
|
||||
// specified by start frame and #frames.
|
||||
// This function returns the same data independent on #frames, i.e. the concept
|
||||
// of the mini-batch is not defined in here, but on the caller side. The caller
|
||||
// can retrieve the frames of a mini-batch in chunks that do not match the
|
||||
// caller's definition of "mini-batch," e.g. bigger or smaller chunks.
|
||||
// If a requested mini-batch spans a sweep boundary, then this function will
|
||||
// not return samples after the sweep boundary. Instead, the returned frame
|
||||
// set is shortened to not exceed the end of the sweep. The caller must make
|
||||
// a separate second call to get the rest. In trainlayer(), the one
|
||||
// sweep-boundary-spanning mini-batch will simply be shortened.
|
||||
// This function is NOT thread-safe (due to caching of random sequence).
|
||||
bool getbatch (const size_t globalts, const size_t framesrequested, std::vector<msra::dbn::matrix> & feat, std::vector<std::vector<size_t>> & uids,
|
||||
std::vector<const_array_ref<msra::lattices::lattice::htkmlfwordsequence::word>> & transcripts,
|
||||
std::vector<shared_ptr<const latticesource::latticepair>> & latticepairs)
|
||||
{
|
||||
|
||||
auto_timer timergetbatch;
|
||||
bool readfromdisk;
|
||||
size_t nreadfromdisk=0;
|
||||
transcripts.clear(); // word-level transcripts not supported by frame source (aimed at MMI)
|
||||
latticepairs.clear(); // neither are lattices
|
||||
|
||||
assert (totalframes() > 0);
|
||||
const size_t sweep = globalts / totalframes(); // which sweep (this determines randomization)
|
||||
const size_t ts = globalts % totalframes(); // start frame within the sweep
|
||||
const size_t te = min (ts + framesrequested, totalframes()); // do not go beyond sweep boundary
|
||||
assert (te > ts);
|
||||
if (verbosity >= 2)
|
||||
fprintf (stderr, "getbatch: frames [%zu..%zu] in sweep %zu\n", ts, te-1, sweep);
|
||||
|
||||
// get random sequence (each time index occurs exactly once)
|
||||
// If the sweep changes, this will re-cache the sequence. We optimize for rare, monotonous sweep changes.
|
||||
const auto & tmap = randomordering (sweep);
|
||||
|
||||
feat.resize(pframes.size());
|
||||
uids.resize(classids.size());
|
||||
foreach_index(i, feat)
|
||||
{
|
||||
size_t leftextent, rightextent;
|
||||
// page in the needed range of frames
|
||||
if (leftcontext[i] == 0 && rightcontext[i] == 0)
|
||||
{
|
||||
leftextent = rightextent = augmentationextent(pframes[i]->dim(), vdim[i]);
|
||||
}
|
||||
else
|
||||
{
|
||||
leftextent = leftcontext[i];
|
||||
rightextent = rightcontext[i];
|
||||
}
|
||||
readfromdisk = pframes[i]->require (randomordering.bounds (max (ts, leftextent) - leftextent, te + 1 + rightextent));
|
||||
// generate features and uids
|
||||
feat[i].resize (vdim[i], te - ts); // note: special mode vdim == 0 means no features to be loaded
|
||||
if (issupervised()) // empty means unsupervised training -> return empty uids
|
||||
foreach_index(j, uids)
|
||||
uids[j].resize (te - ts);
|
||||
else
|
||||
uids.clear();
|
||||
|
||||
for (size_t t = ts; t < te; t++)
|
||||
{
|
||||
size_t trand = tmap[t]; // the random-sequence sample point for this point in time
|
||||
if (vdim[i] != 0)
|
||||
{
|
||||
auto v_t = feat[i].col(t-ts); // the vector to fill in
|
||||
augmentneighbors (*pframes[i], boundaryflags, trand, leftextent, rightextent, v_t);
|
||||
}
|
||||
if (i==0){ // read labels for all outputs on first pass thru features. this guarantees they will be read if only one feature set but > 1 label set
|
||||
if (issupervised())
|
||||
foreach_index(j, uids)
|
||||
uids[j][t-ts] = classids[j][trand];
|
||||
}
|
||||
}
|
||||
timegetbatch = timergetbatch;
|
||||
if (readfromdisk)
|
||||
nreadfromdisk++;
|
||||
|
||||
}
|
||||
|
||||
(nreadfromdisk==feat.size()) ? readfromdisk = true : readfromdisk = false;
|
||||
|
||||
return readfromdisk;
|
||||
|
||||
}
|
||||
|
||||
bool getbatch (const size_t /*globalts*/, const size_t /*framesrequested*/, msra::dbn::matrix & /*feat*/, std::vector<size_t> & /*uids*/,
|
||||
std::vector<const_array_ref<msra::lattices::lattice::htkmlfwordsequence::word>> & /*transcripts*/,
|
||||
std::vector<shared_ptr<const latticesource::latticepair>> & /*latticepairs*/)
|
||||
{
|
||||
// should never get here
|
||||
throw runtime_error("minibatchframesourcemulti: getbatch() being called for single input feature and single output feature, should use minibatchframesource instead\n");
|
||||
}
|
||||
|
||||
double gettimegetbatch () { return timegetbatch;}
|
||||
|
||||
// return first valid globalts to ask getbatch() for
|
||||
// In frame mode, there is no constraint, i.e. it is 'globalts' itself.
|
||||
/*implement*/ size_t firstvalidglobalts (const size_t globalts) { return globalts; }
|
||||
|
||||
/*implement*/ const std::vector<size_t> & unitcounts() const { throw logic_error ("unitcounts: not implemented for this feature source"); }
|
||||
|
||||
};
|
||||
};};
|
|
@ -1,89 +0,0 @@
|
|||
//
|
||||
// <copyright file="simple_checked_arrays.h" company="Microsoft">
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// </copyright>
|
||||
//
|
||||
// simple_checked_arrays.h -- a simple wrapper around pointers used as arrays to allow bounds checking
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <stddef.h> // for size_t
|
||||
#include <assert.h>
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// array_ref -- wraps a C pointer to an array together with its size.
|
||||
//
|
||||
// Called _ref because this is a reference to the array rather than the array
|
||||
// itself (since it wraps a pointer). No need to pass an array_ref by reference.
|
||||
//
|
||||
// operator[] checks index bounds in Debug builds. size() is provided such
|
||||
// that this class can be substituted for STL vector in many cases.
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
template<class _T> class array_ref
|
||||
{
|
||||
_T * data;
|
||||
size_t n;
|
||||
inline void check_index (size_t i) const { i; assert (i < n); }
|
||||
inline void check_ptr() const { n; data; assert (n == 0 || data != NULL); }
|
||||
public:
|
||||
inline array_ref (_T * ptr, size_t size) throw() : data (ptr), n (size) { }
|
||||
inline array_ref() throw() : data (NULL), n (0) { } // in case we have a vector of this
|
||||
inline _T & operator[] (size_t i) throw() { check_index (i); return data[i]; }
|
||||
inline const _T & operator[] (size_t i) const throw() { check_index (i); return data[i]; }
|
||||
inline size_t size() const throw() { return n; }
|
||||
inline _T * begin() { return data; }
|
||||
inline _T * end() { return data + n; }
|
||||
inline void resize (size_t sz) { sz; assert (n == sz); } // allow compatibility with some functions
|
||||
// construct from other vector types
|
||||
template<class _V> inline array_ref (_V & v) : data (v.size() > 0 ? &v[0] : NULL), n ((size_t) v.size()) { }
|
||||
};
|
||||
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// const_array_ref -- same as array_ref for 'const' (read-only) pointers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
template<class _T> class const_array_ref
|
||||
{
|
||||
const _T * data;
|
||||
size_t n;
|
||||
inline void check_index (size_t i) const { i; assert (i < n); }
|
||||
inline void check_ptr() const { n; data; assert (n == 0 || data != NULL); }
|
||||
public:
|
||||
inline const_array_ref (const _T * ptr, size_t size) throw() : data (ptr), n (size) { }
|
||||
inline const_array_ref() throw() : data (NULL), n (0) { } // in case we have a vector of this
|
||||
inline const _T & operator[] (size_t i) const throw() { check_index (i); return data[i]; }
|
||||
inline size_t size() const throw() { return n; }
|
||||
inline const _T * begin() { return data; }
|
||||
inline const _T * end() { return data + n; }
|
||||
inline const _T & front() const throw() { check_index (0); return data[0];}
|
||||
inline const _T & back() const throw() {check_index (0); return data[n-1];}
|
||||
// construct from other vector types
|
||||
template<class _V> inline const_array_ref (const _V & v) : data (v.size() > 0 ? &v[0] : NULL), n ((size_t) v.size()) { }
|
||||
};
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// hardcoded_array -- wraps a fixed-size C array together with its size.
|
||||
//
|
||||
// operator[] checks index bounds in Debug builds. size() is provided such
|
||||
// that this class can be substituted for STL vector in many cases.
|
||||
// Can be constructed with a size parameter--it will be checked against the
|
||||
// hard-coded size.
|
||||
// Can also be constructed with an initialization parameter (typ. 0).
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
template<class _T, int _N> class hardcoded_array
|
||||
{
|
||||
_T data[_N];
|
||||
inline void check_index (size_t i) const { i; assert (i < _N); }
|
||||
inline void check_size (size_t n) const { n; assert (n == _N); }
|
||||
public:
|
||||
inline hardcoded_array() throw() {}
|
||||
inline hardcoded_array (size_t n) throw() { check_size (n); } // we can instantiate with a size parameter--just checks the size
|
||||
inline hardcoded_array (size_t n, const _T & val) throw() { check_size (n); for (size_t i = 0; i < n; i++) data[i] = val; }
|
||||
inline _T & operator[] (size_t i) throw() { check_index (i); return data[i]; }
|
||||
inline const _T & operator[] (size_t i) const throw() { check_index (i); return data[i]; }
|
||||
inline size_t size() const throw() { return _N; }
|
||||
};
|
|
@ -1,241 +0,0 @@
|
|||
//
|
||||
// <copyright file="simplesenonehmm.h" company="Microsoft">
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// </copyright>
|
||||
//
|
||||
// latticearchive.h -- managing lattice archives
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "basetypes.h"
|
||||
#include "fileutil.h"
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <algorithm> // for find()
|
||||
#include "simple_checked_arrays.h"
|
||||
|
||||
namespace msra { namespace asr {
|
||||
|
||||
// ===========================================================================
|
||||
// simplesenonehmm -- simple senone-based CD-HMM
|
||||
// ===========================================================================
|
||||
|
||||
class simplesenonehmm
|
||||
{
|
||||
public: // (TODO: better encapsulation)
|
||||
static const size_t MAXSTATES = 3; // we use a fixed memory allocation since it's almost always 3 anyway
|
||||
struct transP;
|
||||
struct hmm
|
||||
{
|
||||
const char * name; // (this points into the key in the hash table to save memory)
|
||||
struct transP * transP; // underlying transition matrix
|
||||
unsigned char transPindex; // index of transP in struct transP
|
||||
unsigned char numstates; // number of states
|
||||
unsigned short senoneids[MAXSTATES]; // [0..numstates-1] senone indices
|
||||
|
||||
const char * getname() const { return name; } // (should be used for diagnostics only)
|
||||
size_t getsenoneid (size_t i) const { if (i < numstates) return (size_t) senoneids[i]; throw std::logic_error ("getsenoneid: out of bounds access"); }
|
||||
size_t getnumstates() const { return (size_t) numstates; }
|
||||
unsigned char gettransPindex() const { return transPindex;}
|
||||
const struct transP & gettransP() const { return *transP; }
|
||||
|
||||
bool operator< (const hmm & other) const
|
||||
{
|
||||
return memcmp (this, &other, sizeof (other)) < 0;
|
||||
}
|
||||
};
|
||||
std::vector<hmm> hmms; // the set of HMMs
|
||||
std::unordered_map<std::string,size_t> symmap; // [name] -> index into hmms[]
|
||||
struct transP
|
||||
{
|
||||
private:
|
||||
size_t numstates;
|
||||
float loga[MAXSTATES+1][MAXSTATES+1];
|
||||
void check (int from, size_t to) const { if (from < -1 || from >= (int) numstates || to > numstates) throw std::logic_error ("transP: index out of bounds"); }
|
||||
public:
|
||||
void resize (size_t n) { if (n > MAXSTATES) throw std::runtime_error ("resize: requested transP that exceeds MAXSTATES"); numstates = n; }
|
||||
size_t getnumstates() const { return numstates; }
|
||||
// from = -1 and to = numstates are allowed, but we also allow 'from' to be size_t to avoid silly typecasts
|
||||
float & operator() (int from, size_t to) { check (from, to); return loga[from+1][to]; } // from >= -1
|
||||
const float & operator() (int from, size_t to) const { check (from, to); return loga[from+1][to]; } // from >= -1
|
||||
const float & operator() (size_t from, size_t to) const { check ((int)from, to); return loga[from+1][to]; } // from >= 0
|
||||
transP() : numstates (0) {}
|
||||
};
|
||||
std::vector<transP> transPs; // the transition matrices --TODO: finish this
|
||||
std::hash_map<std::string,size_t> transPmap; // [transPname] -> index into transPs[]
|
||||
public:
|
||||
// get an hmm by index
|
||||
const hmm & gethmm (size_t i) const { return hmms[i]; }
|
||||
|
||||
// get an hmm by name
|
||||
size_t gethmmid (const string & name) const
|
||||
{
|
||||
auto iter = symmap.find (name);
|
||||
if (iter == symmap.end())
|
||||
throw std::logic_error ("gethmm: unknown unit name: " + name);
|
||||
return iter->second;
|
||||
}
|
||||
|
||||
// diagnostics: map state id to senone name
|
||||
std::vector<std::string> statenames;
|
||||
const char * getsenonename (size_t senoneid) const { return statenames[senoneid].c_str(); }
|
||||
|
||||
// inverse lookup, for re-scoring the ground-truth path for sequence training
|
||||
// This may be ambiguous, but we know that for current setup, that's only the case for /sil/ and /sp/.
|
||||
std::vector<int> senoneid2transPindex; // or -1 if ambiguous
|
||||
std::vector<int> senoneid2stateindex; // 0..2, or -1 if ambiguous
|
||||
|
||||
// construct from model files
|
||||
simplesenonehmm (const std::wstring & cdphonetyingpath, const std::wstring & statelistpath, const std::wstring & transPpath)
|
||||
{
|
||||
if (cdphonetyingpath.empty()) // no tying info specified --just leave an empty object
|
||||
return;
|
||||
fprintf (stderr, "simplesenonehmm: reading '%S', '%S', '%S'\n", cdphonetyingpath.c_str(), statelistpath.c_str(), transPpath.c_str());
|
||||
// read the state list
|
||||
vector<char> textbuffer;
|
||||
auto readstatenames = msra::files::fgetfilelines (statelistpath, textbuffer);
|
||||
foreach_index (s, readstatenames)
|
||||
statenames.push_back (readstatenames[s]);
|
||||
std::unordered_map<std::string,size_t> statemap; // [name] -> index
|
||||
statemap.rehash (readstatenames.size());
|
||||
foreach_index (i, readstatenames)
|
||||
statemap[readstatenames[i]] = i;
|
||||
// TRANSPNAME NUMSTATES (ROW_from[to])+
|
||||
msra::strfun::tokenizer toks (" \t", 5);
|
||||
auto transPlines = msra::files::fgetfilelines (transPpath, textbuffer);
|
||||
transPs.resize (transPlines.size());
|
||||
string key; key.reserve (100);
|
||||
foreach_index (i, transPlines)
|
||||
{
|
||||
toks = transPlines[i];
|
||||
if (toks.size() < 3)
|
||||
throw std::runtime_error ("simplesenonehmm: too few tokens in transP line: " + string (transPlines[i]));
|
||||
key = toks[0]; // transPname --using existing object to avoid malloc
|
||||
transPmap[key] = i;
|
||||
size_t numstates = msra::strfun::toint (toks[1]);
|
||||
if (numstates == 0)
|
||||
throw std::runtime_error ("simplesenonehmm: invalid numstates: " + string (transPlines[i]));
|
||||
auto & transP = transPs[i];
|
||||
transP.resize (numstates);
|
||||
size_t k = 2; // index into tokens; transP values start at toks[2]
|
||||
for (int from = -1; from < (int) numstates; from++) for (size_t to = 0; to <= numstates; to++)
|
||||
{
|
||||
if (k >= toks.size())
|
||||
throw std::runtime_error ("simplesenonehmm: not enough tokens on transP line: " + string (transPlines[i]));
|
||||
const char * sval = toks[k++];
|
||||
const double aij = msra::strfun::todouble (sval);
|
||||
if (aij > 1e-10) // non-0
|
||||
transP(from,to) = logf ((float) aij); // we store log probs
|
||||
else
|
||||
transP(from,to) = -1e30f;
|
||||
}
|
||||
if (toks.size() > k)
|
||||
throw std::runtime_error ("simplesenonehmm: unexpected garbage at endof transP line: " + string (transPlines[i]));
|
||||
}
|
||||
// allocate inverse lookup
|
||||
senoneid2transPindex.resize (readstatenames.size(), -2);
|
||||
senoneid2stateindex.resize (readstatenames.size(), -2);
|
||||
// read the cd-phone tying info
|
||||
// HMMNAME TRANSPNAME SENONENAME+
|
||||
auto lines = msra::files::fgetfilelines (cdphonetyingpath, textbuffer);
|
||||
hmms.reserve (lines.size());
|
||||
symmap.rehash (lines.size());
|
||||
// two tables: (1) name -> HMM; (2) HMM -> HMM index (uniq'ed)
|
||||
map<string,hmm> name2hmm; // [name] -> unique HMM struct (without name)
|
||||
map<hmm,size_t> hmm2index; // [unique HMM struct] -> hmm index, hmms[i] contains full hmm
|
||||
foreach_index (i, lines)
|
||||
{
|
||||
toks = lines[i];
|
||||
if (toks.size() < 3)
|
||||
throw std::runtime_error ("simplesenonehmm: too few tokens in line: " + string (lines[i]));
|
||||
const char * hmmname = toks[0];
|
||||
const char * transPname = toks[1];
|
||||
// build the HMM structure
|
||||
hmm hmm;
|
||||
hmm.name = NULL; // for use as key in hash tables, we keep this NULL
|
||||
// get the transP pointer
|
||||
// TODO: this becomes a hard lookup with failure
|
||||
key = transPname; // (reuse existing memory)
|
||||
auto iter = transPmap.find (key);
|
||||
if (iter == transPmap.end())
|
||||
throw std::runtime_error ("simplesenonehmm: unknown transP name: " + string (lines[i]));
|
||||
size_t transPindex = iter->second;
|
||||
hmm.transPindex = (unsigned char) transPindex;
|
||||
hmm.transP = &transPs[transPindex];
|
||||
if (hmm.transPindex != transPindex)
|
||||
throw std::runtime_error ("simplesenonehmm: numeric overflow for transPindex field");
|
||||
// get the senones
|
||||
hmm.numstates = (unsigned char) (toks.size() - 2); // remaining tokens
|
||||
if (hmm.numstates != transPs[transPindex].getnumstates())
|
||||
throw std::runtime_error ("simplesenonehmm: number of states mismatches that of transP: " + string (lines[i]));
|
||||
if (hmm.numstates > _countof (hmm.senoneids))
|
||||
throw std::runtime_error ("simplesenonehmm: hmm.senoneids[MAXSTATES] is too small in line: " + string (lines[i]));
|
||||
for (size_t s = 0; s < hmm.numstates; s++)
|
||||
{
|
||||
const char * senonename = toks[s+2];
|
||||
key = senonename; // (reuse existing memory)
|
||||
auto iter = statemap.find (key);
|
||||
if (iter == statemap.end())
|
||||
throw std::runtime_error ("simplesenonehmm: unrecognized senone name in line: " + string (lines[i]));
|
||||
hmm.senoneids[s] = (unsigned short) iter->second;
|
||||
if (hmm.getsenoneid(s) != iter->second)
|
||||
throw std::runtime_error ("simplesenonehmm: not enough bits to store senone index in line: " + string (lines[i]));
|
||||
// inverse lookup
|
||||
if (senoneid2transPindex[hmm.senoneids[s]] == -2) // no value yet
|
||||
senoneid2transPindex[hmm.senoneids[s]] = hmm.transPindex;
|
||||
else if (senoneid2transPindex[hmm.senoneids[s]] != hmm.transPindex)
|
||||
senoneid2transPindex[hmm.senoneids[s]] = -1; // multiple inconsistent values
|
||||
if (senoneid2stateindex[hmm.senoneids[s]] == -2)
|
||||
senoneid2stateindex[hmm.senoneids[s]] = (int) s;
|
||||
else if (senoneid2stateindex[hmm.senoneids[s]] != (int) s)
|
||||
senoneid2stateindex[hmm.senoneids[s]] = -1;
|
||||
}
|
||||
for (size_t s = hmm.numstates; s < _countof (hmm.senoneids); s++) // clear out the rest if needed
|
||||
hmm.senoneids[s] = USHRT_MAX;
|
||||
// add to name-to-HMM hash
|
||||
auto ir = name2hmm.insert (std::make_pair (hmmname, hmm)); // insert into hash table
|
||||
if (!ir.second) // not inserted
|
||||
throw std::runtime_error ("simplesenonehmm: duplicate unit name in line: " + string (lines[i]));
|
||||
// add to hmm-to-index hash
|
||||
// and update the actual lookup table
|
||||
size_t hmmindex = hmms.size(); // (assume it's a new entry)
|
||||
auto is = hmm2index.insert (std::make_pair (hmm, hmmindex));
|
||||
if (is.second) // was indeed inserted: add to hmms[]
|
||||
{
|
||||
// insert first, as this copies the name; we can then point to it
|
||||
auto it = symmap.insert (std::make_pair (hmmname, hmmindex)); // insert into hash table
|
||||
hmm.name = it.first->first.c_str(); // only use first name if multiple (the name is informative only anyway)
|
||||
hmms.push_back (hmm);
|
||||
}
|
||||
else // not inserted
|
||||
{
|
||||
hmmindex = is.first->second; // use existing value
|
||||
symmap.insert (std::make_pair (hmmname, hmmindex)); // insert into hash table
|
||||
}
|
||||
}
|
||||
fprintf (stderr, "simplesenonehmm: %zu units with %zu unique HMMs, %zu tied states, and %zu trans matrices read\n",
|
||||
symmap.size(), hmms.size(), statemap.size(), transPs.size());
|
||||
}
|
||||
|
||||
// exposed so we can pass it to the lattice reader, which maps the symbol ids for us
|
||||
const std::unordered_map<std::string,size_t> & getsymmap() const { return symmap; }
|
||||
|
||||
// inverse lookup --for scoring the ground-truth
|
||||
// Note: /sil/ and /sp/ will be ambiguous, so need to handle them as a special case.
|
||||
int senonetransP (size_t senoneid) const { return senoneid2transPindex[senoneid]; }
|
||||
int senonestate (size_t senoneid) const { return senoneid2stateindex[senoneid]; }
|
||||
const size_t getnumsenone () const {return senoneid2stateindex.size(); }
|
||||
const bool statebelongstohmm (const size_t senoneid, const hmm & hmm) const // reutrn true if one of the states of this hmm == senoneid
|
||||
{
|
||||
size_t numstates = hmm.getnumstates();
|
||||
for (size_t i = 0; i < numstates; i++)
|
||||
if (hmm.senoneids[i] == senoneid)
|
||||
return true;
|
||||
return false;
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
};};
|
|
@ -1,152 +0,0 @@
|
|||
//
|
||||
// <copyright file="simplethread.h" company="Microsoft">
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// </copyright>
|
||||
//
|
||||
// simplethread.h -- a simple thread implementation
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "basetypes.h"
|
||||
#ifdef _WIN32
|
||||
#include <process.h> // for _beginthread()
|
||||
#endif
|
||||
|
||||
namespace msra { namespace util {
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// signallingevent -- wrapper around Windows events
|
||||
// ---------------------------------------------------------------------------
|
||||
class signallingevent // TODO: should this go into basetypes.h?
|
||||
{
|
||||
HANDLE h;
|
||||
public:
|
||||
signallingevent (bool initialstate = true)
|
||||
{
|
||||
h = ::CreateEvent (NULL, FALSE/*manual reset*/, initialstate ? TRUE : FALSE, NULL);
|
||||
if (h == NULL)
|
||||
throw std::runtime_error ("signallingevent: CreateEvent() failed");
|
||||
}
|
||||
~signallingevent() { ::CloseHandle (h); }
|
||||
void wait() { if (::WaitForSingleObject (h, INFINITE) != WAIT_OBJECT_0) throw std::runtime_error ("wait: WaitForSingleObject() unexpectedly failed"); }
|
||||
void flag() { if (::SetEvent (h) == 0) throw std::runtime_error ("flag: SetEvent() unexpectedly failed"); }
|
||||
};
|
||||
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// simplethread -- simple thread wrapper
|
||||
// ---------------------------------------------------------------------------
|
||||
class simplethread : CCritSec
|
||||
{
|
||||
std::shared_ptr<std::exception> badallocexceptionptr; // in case we fail to copy the exception
|
||||
std::shared_ptr<std::exception> exceptionptr; // if non-NULL, then thread failed with exception
|
||||
// wrapper around passing the functor
|
||||
signallingevent startsignal;
|
||||
const void * functorptr;
|
||||
template<typename FUNCTION> static unsigned int __stdcall staticthreadproc (void * usv)
|
||||
{
|
||||
simplethread * us = (simplethread*) usv;
|
||||
const FUNCTION body = *(const FUNCTION *) us->functorptr;
|
||||
us->startsignal.flag();
|
||||
us->threadproc (body);
|
||||
return 0;
|
||||
}
|
||||
template<typename FUNCTION> void threadproc (const FUNCTION & body)
|
||||
{
|
||||
try
|
||||
{
|
||||
body(); // execute the function
|
||||
}
|
||||
catch (const std::exception & e)
|
||||
{
|
||||
fail (e);
|
||||
}
|
||||
catch (...) // we do not catch anything that is not based on std::exception
|
||||
{
|
||||
fprintf (stderr, "simplethread: thread proc failed with unexpected unknown exception, which is not allowed. Terminating\n");
|
||||
fflush (stderr); // (needed?)
|
||||
abort(); // should never happen
|
||||
}
|
||||
}
|
||||
HANDLE threadhandle;
|
||||
public:
|
||||
template<typename FUNCTION> simplethread (const FUNCTION & body) : badallocexceptionptr (new std::bad_alloc()), functorptr (&body), startsignal (false)
|
||||
{
|
||||
unsigned int threadid;
|
||||
uintptr_t rc = _beginthreadex (NULL/*security*/, 0/*stack*/, staticthreadproc<FUNCTION>, this, CREATE_SUSPENDED, &threadid);
|
||||
if (rc == 0)
|
||||
throw std::runtime_error ("simplethread: _beginthreadex() failed");
|
||||
threadhandle = OpenThread (THREAD_ALL_ACCESS, FALSE, threadid);
|
||||
if (threadhandle == NULL)
|
||||
throw std::logic_error ("simplethread: _beginthreadex() unexpectedly did not return valid thread id"); // BUGBUG: leaking something
|
||||
DWORD rc1 = ::ResumeThread (threadhandle);
|
||||
if (rc1 == (DWORD) -1)
|
||||
{
|
||||
::TerminateThread (threadhandle, 0);
|
||||
::CloseHandle (threadhandle);
|
||||
throw std::logic_error ("simplethread: ResumeThread() failed unexpectedly");
|
||||
}
|
||||
try
|
||||
{
|
||||
startsignal.wait(); // wait until functor has been copied
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
::TerminateThread (threadhandle, 0);
|
||||
::CloseHandle (threadhandle);
|
||||
throw;
|
||||
}
|
||||
}
|
||||
// check if the thread is still alive and without error
|
||||
void check()
|
||||
{
|
||||
CAutoLock lock (*this);
|
||||
// pass on a pending exception
|
||||
if (exceptionptr)
|
||||
throw *exceptionptr.get();
|
||||
// the thread going away without error is also unexpected at this point
|
||||
if (wait (0)) // (0 means don't block, so OK to call inside lock)
|
||||
throw std::runtime_error ("check: thread terminated unexpectedly");
|
||||
}
|
||||
bool wait (DWORD dwMilliseconds = INFINITE)
|
||||
{
|
||||
DWORD rc = ::WaitForSingleObject (threadhandle, dwMilliseconds);
|
||||
if (rc == WAIT_TIMEOUT)
|
||||
return false;
|
||||
else if (rc == WAIT_OBJECT_0)
|
||||
return true;
|
||||
else
|
||||
throw std::runtime_error ("wait: WaitForSingleObject() failed unexpectedly");
|
||||
}
|
||||
// thread itself can set the failure condition, e.g. before it signals some other thread to pick it up
|
||||
void fail (const std::exception & e)
|
||||
{
|
||||
// exception: remember it --this will remove the type info :(
|
||||
CAutoLock lock (*this);
|
||||
try // copy the exception--this may fail if we are out of memory
|
||||
{
|
||||
exceptionptr.reset (new std::runtime_error (e.what()));
|
||||
}
|
||||
catch (...) // failed to alloc: fall back to bad_alloc, which is most likely the cause in such situation
|
||||
{
|
||||
exceptionptr = badallocexceptionptr;
|
||||
}
|
||||
}
|
||||
//void join()
|
||||
//{
|
||||
// check();
|
||||
// wait();
|
||||
// check_for_exception(); // (check() not sufficient because it would fail since thread is gone)
|
||||
//}
|
||||
~simplethread() throw()
|
||||
{
|
||||
// wait until it shuts down
|
||||
try { wait(); }
|
||||
catch (...) { ::TerminateThread (threadhandle, 0); }
|
||||
// close the handle
|
||||
::CloseHandle (threadhandle);
|
||||
}
|
||||
};
|
||||
|
||||
};};
|
|
@ -1,123 +0,0 @@
|
|||
//
|
||||
// <copyright file="ssefloat4.h" company="Microsoft">
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// </copyright>
|
||||
//
|
||||
// ssematrix.h -- matrix with SSE-accelerated operations
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#ifdef _WIN32
|
||||
#include <intrin.h> // for intrinsics
|
||||
#endif
|
||||
#ifdef __unix__
|
||||
#include <x86intrin.h>
|
||||
#endif
|
||||
|
||||
namespace msra { namespace math {
|
||||
|
||||
// ===========================================================================
|
||||
// float4 -- wrapper around the rather ugly SSE intrinsics for float[4]
|
||||
//
|
||||
// Do not use the intrinsics outside anymore; instead add all you need into this class.
|
||||
//
|
||||
// MSDN links:
|
||||
// basic: http://msdn.microsoft.com/en-us/library/x5c07e2a%28v=VS.80%29.aspx
|
||||
// load/store: (add this)
|
||||
// newer ones: (seems no single list available)
|
||||
// ===========================================================================
|
||||
|
||||
class float4
|
||||
{
|
||||
__m128 v; // value
|
||||
private:
|
||||
// return the low 'float'
|
||||
float f0() const { float f; _mm_store_ss (&f, v); return f; }
|
||||
// construct from a __m128, assuming it is a f32 vector (needed for directly returning __m128 below)
|
||||
float4 (const __m128 & v) : v (v) {}
|
||||
// return as a __m128 --should this be a reference?
|
||||
operator __m128() const { return v; }
|
||||
// assign a __m128 (needed for using nested float4 objects inside this class, e.g. sum())
|
||||
float4 & operator= (const __m128 & other) { v = other; return *this; }
|
||||
public:
|
||||
float4() {} // uninitialized
|
||||
float4 (const float4 & f4) : v (f4.v) {}
|
||||
float4 & operator= (const float4 & other) { v = other.v; return *this; }
|
||||
|
||||
// construct from a single float, copy to all components
|
||||
float4 (float f) : v (_mm_load1_ps (&f)) {}
|
||||
//float4 (float f) : v (_mm_set_ss (f)) {} // code seems more complex than _mm_load1_ps()
|
||||
|
||||
// basic math
|
||||
float4 operator-() const { return _mm_sub_ps (_mm_setzero_ps(), v); } // UNTESTED; setzero is a composite
|
||||
|
||||
float4 operator& (const float4 & other) const { return _mm_and_ps (v, other); }
|
||||
float4 operator| (const float4 & other) const { return _mm_or_ps (v, other); }
|
||||
float4 operator+ (const float4 & other) const { return _mm_add_ps (v, other); }
|
||||
float4 operator- (const float4 & other) const { return _mm_sub_ps (v, other); }
|
||||
float4 operator* (const float4 & other) const { return _mm_mul_ps (v, other); }
|
||||
float4 operator/ (const float4 & other) const { return _mm_div_ps (v, other); }
|
||||
|
||||
float4 & operator&= (const float4 & other) { v = _mm_and_ps (v, other); return *this; }
|
||||
float4 & operator|= (const float4 & other) { v = _mm_or_ps (v, other); return *this; }
|
||||
float4 & operator+= (const float4 & other) { v = _mm_add_ps (v, other); return *this; }
|
||||
float4 & operator-= (const float4 & other) { v = _mm_sub_ps (v, other); return *this; }
|
||||
float4 & operator*= (const float4 & other) { v = _mm_mul_ps (v, other); return *this; }
|
||||
float4 & operator/= (const float4 & other) { v = _mm_div_ps (v, other); return *this; }
|
||||
|
||||
float4 operator>= (const float4 & other) const { return _mm_cmpge_ps (v, other); }
|
||||
float4 operator<= (const float4 & other) const { return _mm_cmple_ps (v, other); }
|
||||
|
||||
// not yet implemented binary arithmetic ops: sqrt, rcp (reciprocal), rqsrt, min, max
|
||||
|
||||
// other goodies I came across (intrin.h):
|
||||
// - _mm_prefetch
|
||||
// - _mm_stream_ps --store without polluting cache
|
||||
// - unknown: _mm_addsub_ps, _mm_hsub_ps, _mm_movehdup_ps, _mm_moveldup_ps, _mm_blend_ps, _mm_blendv_ps, _mm_insert_ps, _mm_extract_ps, _mm_round_ps
|
||||
// - _mm_dp_ps dot product! http://msdn.microsoft.com/en-us/library/bb514054.aspx
|
||||
// Not so interesting for long vectors, we get better numerical precision with parallel adds and hadd at the end
|
||||
|
||||
// prefetch a float4 from an address
|
||||
static void prefetch (const float4 * p) { _mm_prefetch ((const char *) const_cast<float4 *> (p), _MM_HINT_T0); }
|
||||
|
||||
// transpose a 4x4 matrix
|
||||
// Passing input as const ref to ensure aligned-ness
|
||||
static void transpose (const float4 & col0, const float4 & col1, const float4 & col2, const float4 & col3,
|
||||
float4 & row0, float4 & row1, float4 & row2, float4 & row3)
|
||||
{ // note: the temp variable here gets completely eliminated by optimization
|
||||
float4 m0 = col0; float4 m1 = col1; float4 m2 = col2; float4 m3 = col3;
|
||||
_MM_TRANSPOSE4_PS (m0, m1, m2, m3); // 8 instructions for 16 elements
|
||||
row0 = m0; row1 = m1; row2 = m2; row3 = m3;
|
||||
}
|
||||
|
||||
// save a float4 to RAM bypassing the cache ('without polluting the cache')
|
||||
void storewithoutcache (float4 & r4) const
|
||||
{
|
||||
//_mm_stream_ps ((float*) &r4, v);
|
||||
r4 = v;
|
||||
}
|
||||
|
||||
#if 0
|
||||
// save a float4 to RAM bypassing the cache ('without polluting the cache')
|
||||
void storewithoutcache (float4 * p4) const
|
||||
{
|
||||
//_mm_stream_ps ((float*) p4, v);
|
||||
*p4 = v;
|
||||
}
|
||||
|
||||
// save a float to RAM bypassing the cache ('without polluting the cache')
|
||||
void storewithoutcache (float & r) const
|
||||
{
|
||||
_mm_stream_ss (&r, v);
|
||||
}
|
||||
#endif
|
||||
|
||||
// return the horizontal sum of all 4 components
|
||||
// ... return float4, use another mechanism to store the low word
|
||||
float sum() const { float4 hsum = _mm_hadd_ps (v, v); hsum = _mm_hadd_ps (hsum, hsum); return hsum.f0(); }
|
||||
|
||||
// please add anything else you might need HERE
|
||||
};
|
||||
|
||||
};};
|
Разница между файлами не показана из-за своего большого размера
Загрузить разницу
Разница между файлами не показана из-за своего большого размера
Загрузить разницу
Разница между файлами не показана из-за своего большого размера
Загрузить разницу
|
@ -1,156 +1,156 @@
|
|||
<?xml version="1.0" encoding="utf-8"?>
|
||||
<Project DefaultTargets="Build" ToolsVersion="12.0" xmlns="http://schemas.microsoft.com/developer/msbuild/2003">
|
||||
<ItemGroup Label="ProjectConfigurations">
|
||||
<ProjectConfiguration Include="Debug|x64">
|
||||
<Configuration>Debug</Configuration>
|
||||
<Platform>x64</Platform>
|
||||
</ProjectConfiguration>
|
||||
<ProjectConfiguration Include="Release|x64">
|
||||
<Configuration>Release</Configuration>
|
||||
<Platform>x64</Platform>
|
||||
</ProjectConfiguration>
|
||||
</ItemGroup>
|
||||
<PropertyGroup Label="Globals">
|
||||
<ProjectGuid>{9A2F2441-5972-4EA8-9215-4119FCE0FB68}</ProjectGuid>
|
||||
<SccProjectName>
|
||||
</SccProjectName>
|
||||
<SccAuxPath>
|
||||
</SccAuxPath>
|
||||
<SccLocalPath>
|
||||
</SccLocalPath>
|
||||
<SccProvider>
|
||||
</SccProvider>
|
||||
<Keyword>Win32Proj</Keyword>
|
||||
<RootNamespace>UCIReader</RootNamespace>
|
||||
<ProjectName>LMSequenceReader</ProjectName>
|
||||
</PropertyGroup>
|
||||
<Import Project="$(VCTargetsPath)\Microsoft.Cpp.Default.props" />
|
||||
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'" Label="Configuration">
|
||||
<ConfigurationType>DynamicLibrary</ConfigurationType>
|
||||
<UseDebugLibraries>true</UseDebugLibraries>
|
||||
<PlatformToolset>v120</PlatformToolset>
|
||||
<CharacterSet>Unicode</CharacterSet>
|
||||
</PropertyGroup>
|
||||
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'" Label="Configuration">
|
||||
<ConfigurationType>DynamicLibrary</ConfigurationType>
|
||||
<UseDebugLibraries>false</UseDebugLibraries>
|
||||
<PlatformToolset>v120</PlatformToolset>
|
||||
<WholeProgramOptimization>true</WholeProgramOptimization>
|
||||
<CharacterSet>Unicode</CharacterSet>
|
||||
</PropertyGroup>
|
||||
<Import Project="$(VCTargetsPath)\Microsoft.Cpp.props" />
|
||||
<ImportGroup Label="ExtensionSettings">
|
||||
</ImportGroup>
|
||||
<ImportGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'" Label="PropertySheets">
|
||||
<Import Project="$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props" Condition="exists('$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props')" Label="LocalAppDataPlatform" />
|
||||
</ImportGroup>
|
||||
<ImportGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'" Label="PropertySheets">
|
||||
<Import Project="$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props" Condition="exists('$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props')" Label="LocalAppDataPlatform" />
|
||||
</ImportGroup>
|
||||
<PropertyGroup Label="UserMacros" />
|
||||
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">
|
||||
<LinkIncremental>true</LinkIncremental>
|
||||
<IncludePath>..\..\common\include;..\..\math\math;$(VCInstallDir)include;$(VCInstallDir)atlmfc\include;$(WindowsSDK_IncludePath);</IncludePath>
|
||||
<LibraryPath>$(SolutionDir)$(Platform)\$(Configuration);$(VCInstallDir)lib\amd64;$(VCInstallDir)atlmfc\lib\amd64;$(WindowsSDK_LibraryPath_x64);</LibraryPath>
|
||||
<IntDir>$(Platform)\$(Configuration)\$(ProjectName)\</IntDir>
|
||||
</PropertyGroup>
|
||||
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'">
|
||||
<LinkIncremental>false</LinkIncremental>
|
||||
<IncludePath>..\..\common\include;..\..\math\math;$(VCInstallDir)include;$(VCInstallDir)atlmfc\include;$(WindowsSDK_IncludePath);</IncludePath>
|
||||
<LibraryPath>$(SolutionDir)$(Platform)\$(Configuration);$(VCInstallDir)lib\amd64;$(VCInstallDir)atlmfc\lib\amd64;$(WindowsSDK_LibraryPath_x64);</LibraryPath>
|
||||
<IntDir>$(Platform)\$(Configuration)\$(ProjectName)\</IntDir>
|
||||
</PropertyGroup>
|
||||
<ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">
|
||||
<ClCompile>
|
||||
<PrecompiledHeader>Use</PrecompiledHeader>
|
||||
<WarningLevel>Level4</WarningLevel>
|
||||
<Optimization>Disabled</Optimization>
|
||||
<PreprocessorDefinitions>WIN32;_DEBUG;_WINDOWS;_USRDLL;UCIREADER_EXPORTS;%(PreprocessorDefinitions)</PreprocessorDefinitions>
|
||||
<SDLCheck>true</SDLCheck>
|
||||
<AdditionalIncludeDirectories>..\..\common\include;..\..\math\math</AdditionalIncludeDirectories>
|
||||
<TreatWarningAsError>true</TreatWarningAsError>
|
||||
</ClCompile>
|
||||
<Link>
|
||||
<SubSystem>Windows</SubSystem>
|
||||
<GenerateDebugInformation>true</GenerateDebugInformation>
|
||||
<AdditionalDependencies>CNTKMath.lib;kernel32.lib;user32.lib;gdi32.lib;winspool.lib;comdlg32.lib;advapi32.lib;shell32.lib;ole32.lib;oleaut32.lib;uuid.lib;odbc32.lib;odbccp32.lib;%(AdditionalDependencies)</AdditionalDependencies>
|
||||
<AdditionalLibraryDirectories>$(SolutionDir)$(Platform)\$(Configuration)\;..\..\math\$(Platform)\$(Configuration);..\$(Platform)\$(Configuration)</AdditionalLibraryDirectories>
|
||||
</Link>
|
||||
</ItemDefinitionGroup>
|
||||
<ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'">
|
||||
<ClCompile>
|
||||
<WarningLevel>Level4</WarningLevel>
|
||||
<PrecompiledHeader>Use</PrecompiledHeader>
|
||||
<Optimization>MaxSpeed</Optimization>
|
||||
<FunctionLevelLinking>true</FunctionLevelLinking>
|
||||
<IntrinsicFunctions>true</IntrinsicFunctions>
|
||||
<PreprocessorDefinitions>WIN32;NDEBUG;_WINDOWS;_USRDLL;UCIREADER_EXPORTS;%(PreprocessorDefinitions)</PreprocessorDefinitions>
|
||||
<SDLCheck>true</SDLCheck>
|
||||
<AdditionalIncludeDirectories>..\..\common\include;..\..\math\math</AdditionalIncludeDirectories>
|
||||
<OpenMPSupport>false</OpenMPSupport>
|
||||
<AdditionalOptions>/d2Zi+ %(AdditionalOptions)</AdditionalOptions>
|
||||
<TreatWarningAsError>true</TreatWarningAsError>
|
||||
</ClCompile>
|
||||
<Link>
|
||||
<SubSystem>Windows</SubSystem>
|
||||
<GenerateDebugInformation>true</GenerateDebugInformation>
|
||||
<EnableCOMDATFolding>true</EnableCOMDATFolding>
|
||||
<OptimizeReferences>true</OptimizeReferences>
|
||||
<AdditionalDependencies>CNTKmath.lib;kernel32.lib;user32.lib;gdi32.lib;winspool.lib;comdlg32.lib;advapi32.lib;shell32.lib;ole32.lib;oleaut32.lib;uuid.lib;odbc32.lib;odbccp32.lib;%(AdditionalDependencies)</AdditionalDependencies>
|
||||
<AdditionalLibraryDirectories>..\..\math\$(Platform)\$(Configuration);$(SolutionDir)$(Platform)\$(Configuration)\</AdditionalLibraryDirectories>
|
||||
<Profile>true</Profile>
|
||||
</Link>
|
||||
</ItemDefinitionGroup>
|
||||
<ItemGroup>
|
||||
<ClInclude Include="..\..\Common\Include\basetypes.h" />
|
||||
<ClInclude Include="..\..\Common\Include\DataReader.h" />
|
||||
<ClInclude Include="..\..\Common\Include\DataWriter.h" />
|
||||
<ClInclude Include="..\..\Common\Include\File.h" />
|
||||
<ClInclude Include="..\..\Common\Include\fileutil.h" />
|
||||
<ClInclude Include="minibatchsourcehelpers.h" />
|
||||
<ClInclude Include="SequenceWriter.h" />
|
||||
<ClInclude Include="stdafx.h" />
|
||||
<ClInclude Include="targetver.h" />
|
||||
<ClInclude Include="SequenceReader.h" />
|
||||
<ClInclude Include="SequenceParser.h" />
|
||||
</ItemGroup>
|
||||
<ItemGroup>
|
||||
<ClCompile Include="..\..\Common\ConfigFile.cpp">
|
||||
<PrecompiledHeader Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">NotUsing</PrecompiledHeader>
|
||||
<PrecompiledHeader Condition="'$(Configuration)|$(Platform)'=='Release|x64'">NotUsing</PrecompiledHeader>
|
||||
</ClCompile>
|
||||
<ClCompile Include="..\..\Common\DataReader.cpp" />
|
||||
<ClCompile Include="..\..\Common\DataWriter.cpp" />
|
||||
<ClCompile Include="..\..\Common\File.cpp">
|
||||
<PrecompiledHeader Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">NotUsing</PrecompiledHeader>
|
||||
<PrecompiledHeader Condition="'$(Configuration)|$(Platform)'=='Release|x64'">NotUsing</PrecompiledHeader>
|
||||
</ClCompile>
|
||||
<ClCompile Include="..\..\Common\fileutil.cpp">
|
||||
<PrecompiledHeader Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">NotUsing</PrecompiledHeader>
|
||||
<PrecompiledHeader Condition="'$(Configuration)|$(Platform)'=='Release|x64'">NotUsing</PrecompiledHeader>
|
||||
</ClCompile>
|
||||
<ClCompile Include="Exports.cpp" />
|
||||
<ClCompile Include="dllmain.cpp">
|
||||
<CompileAsManaged Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">false</CompileAsManaged>
|
||||
<PrecompiledHeader Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">
|
||||
</PrecompiledHeader>
|
||||
<CompileAsManaged Condition="'$(Configuration)|$(Platform)'=='Release|x64'">false</CompileAsManaged>
|
||||
<PrecompiledHeader Condition="'$(Configuration)|$(Platform)'=='Release|x64'">
|
||||
</PrecompiledHeader>
|
||||
</ClCompile>
|
||||
<ClCompile Include="SequenceWriter.cpp" />
|
||||
<ClCompile Include="stdafx.cpp">
|
||||
<PrecompiledHeader Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">Create</PrecompiledHeader>
|
||||
<PrecompiledHeader Condition="'$(Configuration)|$(Platform)'=='Release|x64'">Create</PrecompiledHeader>
|
||||
</ClCompile>
|
||||
<ClCompile Include="SequenceReader.cpp" />
|
||||
<ClCompile Include="SequenceParser.cpp" />
|
||||
</ItemGroup>
|
||||
<ItemGroup>
|
||||
<Text Include="SentenceTest.txt" />
|
||||
<Text Include="SequenceTest.txt" />
|
||||
</ItemGroup>
|
||||
<Import Project="$(VCTargetsPath)\Microsoft.Cpp.targets" />
|
||||
<ImportGroup Label="ExtensionTargets">
|
||||
</ImportGroup>
|
||||
<?xml version="1.0" encoding="utf-8"?>
|
||||
<Project DefaultTargets="Build" ToolsVersion="12.0" xmlns="http://schemas.microsoft.com/developer/msbuild/2003">
|
||||
<ItemGroup Label="ProjectConfigurations">
|
||||
<ProjectConfiguration Include="Debug|x64">
|
||||
<Configuration>Debug</Configuration>
|
||||
<Platform>x64</Platform>
|
||||
</ProjectConfiguration>
|
||||
<ProjectConfiguration Include="Release|x64">
|
||||
<Configuration>Release</Configuration>
|
||||
<Platform>x64</Platform>
|
||||
</ProjectConfiguration>
|
||||
</ItemGroup>
|
||||
<PropertyGroup Label="Globals">
|
||||
<ProjectGuid>{9A2F2441-5972-4EA8-9215-4119FCE0FB68}</ProjectGuid>
|
||||
<SccProjectName>
|
||||
</SccProjectName>
|
||||
<SccAuxPath>
|
||||
</SccAuxPath>
|
||||
<SccLocalPath>
|
||||
</SccLocalPath>
|
||||
<SccProvider>
|
||||
</SccProvider>
|
||||
<Keyword>Win32Proj</Keyword>
|
||||
<RootNamespace>UCIReader</RootNamespace>
|
||||
<ProjectName>LMSequenceReader</ProjectName>
|
||||
</PropertyGroup>
|
||||
<Import Project="$(VCTargetsPath)\Microsoft.Cpp.Default.props" />
|
||||
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'" Label="Configuration">
|
||||
<ConfigurationType>DynamicLibrary</ConfigurationType>
|
||||
<UseDebugLibraries>true</UseDebugLibraries>
|
||||
<PlatformToolset>v120</PlatformToolset>
|
||||
<CharacterSet>Unicode</CharacterSet>
|
||||
</PropertyGroup>
|
||||
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'" Label="Configuration">
|
||||
<ConfigurationType>DynamicLibrary</ConfigurationType>
|
||||
<UseDebugLibraries>false</UseDebugLibraries>
|
||||
<PlatformToolset>v120</PlatformToolset>
|
||||
<WholeProgramOptimization>true</WholeProgramOptimization>
|
||||
<CharacterSet>Unicode</CharacterSet>
|
||||
</PropertyGroup>
|
||||
<Import Project="$(VCTargetsPath)\Microsoft.Cpp.props" />
|
||||
<ImportGroup Label="ExtensionSettings">
|
||||
</ImportGroup>
|
||||
<ImportGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'" Label="PropertySheets">
|
||||
<Import Project="$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props" Condition="exists('$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props')" Label="LocalAppDataPlatform" />
|
||||
</ImportGroup>
|
||||
<ImportGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'" Label="PropertySheets">
|
||||
<Import Project="$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props" Condition="exists('$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props')" Label="LocalAppDataPlatform" />
|
||||
</ImportGroup>
|
||||
<PropertyGroup Label="UserMacros" />
|
||||
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">
|
||||
<LinkIncremental>true</LinkIncremental>
|
||||
<IncludePath>..\..\common\include;..\..\math\math;$(VCInstallDir)include;$(VCInstallDir)atlmfc\include;$(WindowsSDK_IncludePath);</IncludePath>
|
||||
<LibraryPath>$(SolutionDir)$(Platform)\$(Configuration);$(VCInstallDir)lib\amd64;$(VCInstallDir)atlmfc\lib\amd64;$(WindowsSDK_LibraryPath_x64);</LibraryPath>
|
||||
<IntDir>$(Platform)\$(Configuration)\$(ProjectName)\</IntDir>
|
||||
</PropertyGroup>
|
||||
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'">
|
||||
<LinkIncremental>false</LinkIncremental>
|
||||
<IncludePath>..\..\common\include;..\..\math\math;$(VCInstallDir)include;$(VCInstallDir)atlmfc\include;$(WindowsSDK_IncludePath);</IncludePath>
|
||||
<LibraryPath>$(SolutionDir)$(Platform)\$(Configuration);$(VCInstallDir)lib\amd64;$(VCInstallDir)atlmfc\lib\amd64;$(WindowsSDK_LibraryPath_x64);</LibraryPath>
|
||||
<IntDir>$(Platform)\$(Configuration)\$(ProjectName)\</IntDir>
|
||||
</PropertyGroup>
|
||||
<ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">
|
||||
<ClCompile>
|
||||
<PrecompiledHeader>Use</PrecompiledHeader>
|
||||
<WarningLevel>Level4</WarningLevel>
|
||||
<Optimization>Disabled</Optimization>
|
||||
<PreprocessorDefinitions>WIN32;_DEBUG;_WINDOWS;_USRDLL;UCIREADER_EXPORTS;%(PreprocessorDefinitions)</PreprocessorDefinitions>
|
||||
<SDLCheck>true</SDLCheck>
|
||||
<AdditionalIncludeDirectories>..\..\common\include;..\..\math\math</AdditionalIncludeDirectories>
|
||||
<TreatWarningAsError>true</TreatWarningAsError>
|
||||
</ClCompile>
|
||||
<Link>
|
||||
<SubSystem>Windows</SubSystem>
|
||||
<GenerateDebugInformation>true</GenerateDebugInformation>
|
||||
<AdditionalDependencies>CNTKMathDll.lib;kernel32.lib;user32.lib;gdi32.lib;winspool.lib;comdlg32.lib;advapi32.lib;shell32.lib;ole32.lib;oleaut32.lib;uuid.lib;odbc32.lib;odbccp32.lib;%(AdditionalDependencies)</AdditionalDependencies>
|
||||
<AdditionalLibraryDirectories>$(SolutionDir)$(Platform)\$(Configuration)\;..\..\math\$(Platform)\$(Configuration);..\$(Platform)\$(Configuration)</AdditionalLibraryDirectories>
|
||||
</Link>
|
||||
</ItemDefinitionGroup>
|
||||
<ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'">
|
||||
<ClCompile>
|
||||
<WarningLevel>Level4</WarningLevel>
|
||||
<PrecompiledHeader>Use</PrecompiledHeader>
|
||||
<Optimization>MaxSpeed</Optimization>
|
||||
<FunctionLevelLinking>true</FunctionLevelLinking>
|
||||
<IntrinsicFunctions>true</IntrinsicFunctions>
|
||||
<PreprocessorDefinitions>WIN32;NDEBUG;_WINDOWS;_USRDLL;UCIREADER_EXPORTS;%(PreprocessorDefinitions)</PreprocessorDefinitions>
|
||||
<SDLCheck>true</SDLCheck>
|
||||
<AdditionalIncludeDirectories>..\..\common\include;..\..\math\math</AdditionalIncludeDirectories>
|
||||
<OpenMPSupport>false</OpenMPSupport>
|
||||
<AdditionalOptions>/d2Zi+ %(AdditionalOptions)</AdditionalOptions>
|
||||
<TreatWarningAsError>true</TreatWarningAsError>
|
||||
</ClCompile>
|
||||
<Link>
|
||||
<SubSystem>Windows</SubSystem>
|
||||
<GenerateDebugInformation>true</GenerateDebugInformation>
|
||||
<EnableCOMDATFolding>true</EnableCOMDATFolding>
|
||||
<OptimizeReferences>true</OptimizeReferences>
|
||||
<AdditionalDependencies>CNTKMathDll.lib;kernel32.lib;user32.lib;gdi32.lib;winspool.lib;comdlg32.lib;advapi32.lib;shell32.lib;ole32.lib;oleaut32.lib;uuid.lib;odbc32.lib;odbccp32.lib;%(AdditionalDependencies)</AdditionalDependencies>
|
||||
<AdditionalLibraryDirectories>..\..\math\$(Platform)\$(Configuration);$(SolutionDir)$(Platform)\$(Configuration)\</AdditionalLibraryDirectories>
|
||||
<Profile>true</Profile>
|
||||
</Link>
|
||||
</ItemDefinitionGroup>
|
||||
<ItemGroup>
|
||||
<ClInclude Include="..\..\Common\Include\basetypes.h" />
|
||||
<ClInclude Include="..\..\Common\Include\DataReader.h" />
|
||||
<ClInclude Include="..\..\Common\Include\DataWriter.h" />
|
||||
<ClInclude Include="..\..\Common\Include\File.h" />
|
||||
<ClInclude Include="..\..\Common\Include\fileutil.h" />
|
||||
<ClInclude Include="minibatchsourcehelpers.h" />
|
||||
<ClInclude Include="SequenceWriter.h" />
|
||||
<ClInclude Include="stdafx.h" />
|
||||
<ClInclude Include="targetver.h" />
|
||||
<ClInclude Include="SequenceReader.h" />
|
||||
<ClInclude Include="SequenceParser.h" />
|
||||
</ItemGroup>
|
||||
<ItemGroup>
|
||||
<ClCompile Include="..\..\Common\ConfigFile.cpp">
|
||||
<PrecompiledHeader Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">NotUsing</PrecompiledHeader>
|
||||
<PrecompiledHeader Condition="'$(Configuration)|$(Platform)'=='Release|x64'">NotUsing</PrecompiledHeader>
|
||||
</ClCompile>
|
||||
<ClCompile Include="..\..\Common\DataReader.cpp" />
|
||||
<ClCompile Include="..\..\Common\DataWriter.cpp" />
|
||||
<ClCompile Include="..\..\Common\File.cpp">
|
||||
<PrecompiledHeader Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">NotUsing</PrecompiledHeader>
|
||||
<PrecompiledHeader Condition="'$(Configuration)|$(Platform)'=='Release|x64'">NotUsing</PrecompiledHeader>
|
||||
</ClCompile>
|
||||
<ClCompile Include="..\..\Common\fileutil.cpp">
|
||||
<PrecompiledHeader Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">NotUsing</PrecompiledHeader>
|
||||
<PrecompiledHeader Condition="'$(Configuration)|$(Platform)'=='Release|x64'">NotUsing</PrecompiledHeader>
|
||||
</ClCompile>
|
||||
<ClCompile Include="Exports.cpp" />
|
||||
<ClCompile Include="dllmain.cpp">
|
||||
<CompileAsManaged Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">false</CompileAsManaged>
|
||||
<PrecompiledHeader Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">
|
||||
</PrecompiledHeader>
|
||||
<CompileAsManaged Condition="'$(Configuration)|$(Platform)'=='Release|x64'">false</CompileAsManaged>
|
||||
<PrecompiledHeader Condition="'$(Configuration)|$(Platform)'=='Release|x64'">
|
||||
</PrecompiledHeader>
|
||||
</ClCompile>
|
||||
<ClCompile Include="SequenceWriter.cpp" />
|
||||
<ClCompile Include="stdafx.cpp">
|
||||
<PrecompiledHeader Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">Create</PrecompiledHeader>
|
||||
<PrecompiledHeader Condition="'$(Configuration)|$(Platform)'=='Release|x64'">Create</PrecompiledHeader>
|
||||
</ClCompile>
|
||||
<ClCompile Include="SequenceReader.cpp" />
|
||||
<ClCompile Include="SequenceParser.cpp" />
|
||||
</ItemGroup>
|
||||
<ItemGroup>
|
||||
<Text Include="SentenceTest.txt" />
|
||||
<Text Include="SequenceTest.txt" />
|
||||
</ItemGroup>
|
||||
<Import Project="$(VCTargetsPath)\Microsoft.Cpp.targets" />
|
||||
<ImportGroup Label="ExtensionTargets">
|
||||
</ImportGroup>
|
||||
</Project>
|
|
@ -2101,7 +2101,7 @@ void BatchSequenceReader<ElemType>::GetLabelOutput(std::map < std::wstring,
|
|||
}
|
||||
|
||||
template<class ElemType>
|
||||
void BatchSequenceReader<ElemType>::SetSentenceSegBatch(Matrix<ElemType>& sentenceBegin, vector<MinibatchPackingFlag>& minibatchPackingFlag)
|
||||
void BatchSequenceReader<ElemType>::SetSentenceSegBatch(Matrix<float>& sentenceBegin, vector<MinibatchPackingFlag>& minibatchPackingFlag)
|
||||
{
|
||||
DEVICEID_TYPE device = mtSentenceBegin.GetDeviceId();
|
||||
mtSentenceBegin.TransferFromDeviceToDevice(device, sentenceBegin.GetDeviceId(), true);
|
||||
|
|
|
@ -76,7 +76,7 @@ public:
|
|||
double logprob(int i) const { if (uniform_sampling) return uniform_log_prob; else return m_log_prob[i]; }
|
||||
|
||||
template <typename Engine>
|
||||
int sample(Engine &eng) const
|
||||
int sample(Engine &eng)
|
||||
{
|
||||
int m = unif_int(eng);
|
||||
if (uniform_sampling)
|
||||
|
@ -353,7 +353,7 @@ private:
|
|||
bool mSentenceEnd;
|
||||
bool mSentenceBegin;
|
||||
|
||||
Matrix<ElemType> mtSentenceBegin;
|
||||
Matrix<float> mtSentenceBegin;
|
||||
vector<MinibatchPackingFlag> m_minibatchPackingFlag;
|
||||
|
||||
public:
|
||||
|
@ -396,7 +396,7 @@ public:
|
|||
size_t NumberSlicesInEachRecurrentIter();
|
||||
|
||||
void SetSentenceSegBatch(std::vector<size_t> &sentenceEnd);
|
||||
void SetSentenceSegBatch(Matrix<ElemType>& sentenceBegin, vector<MinibatchPackingFlag>& minibatchPackingFlag);
|
||||
void SetSentenceSegBatch(Matrix<float>& sentenceBegin, vector<MinibatchPackingFlag>& minibatchPackingFlag);
|
||||
|
||||
int GetSentenceEndIdFromOutputLabel();
|
||||
|
||||
|
|
|
@ -984,7 +984,7 @@ size_t BatchLUSequenceReader<ElemType>::GetLabelOutput(std::map<std::wstring,
|
|||
}
|
||||
|
||||
template<class ElemType>
|
||||
void BatchLUSequenceReader<ElemType>::SetSentenceSegBatch(Matrix<ElemType>& sentenceBegin, vector<MinibatchPackingFlag>& minibatchPackingFlag)
|
||||
void BatchLUSequenceReader<ElemType>::SetSentenceSegBatch(Matrix<float>& sentenceBegin, vector<MinibatchPackingFlag>& minibatchPackingFlag)
|
||||
{
|
||||
DEVICEID_TYPE device = mtSentenceBegin.GetDeviceId();
|
||||
mtSentenceBegin.TransferFromDeviceToDevice(device, sentenceBegin.GetDeviceId(), true);
|
||||
|
@ -1291,7 +1291,7 @@ void MultiIOBatchLUSequenceReader<ElemType>::StartMinibatchLoop(size_t mbSize, s
|
|||
}
|
||||
|
||||
template<class ElemType>
|
||||
void MultiIOBatchLUSequenceReader<ElemType>::SetSentenceSegBatch(Matrix<ElemType> & sentenceBegin, vector<MinibatchPackingFlag>& minibatchPackingFlag)
|
||||
void MultiIOBatchLUSequenceReader<ElemType>::SetSentenceSegBatch(Matrix<float> & sentenceBegin, vector<MinibatchPackingFlag>& minibatchPackingFlag)
|
||||
{
|
||||
/// run for each reader
|
||||
vector<size_t> col;
|
||||
|
|
|
@ -301,7 +301,7 @@ public:
|
|||
size_t NumberSlicesInEachRecurrentIter();
|
||||
void SetNbrSlicesEachRecurrentIter(const size_t mz);
|
||||
|
||||
void SetSentenceSegBatch(Matrix<ElemType> & sentenceBegin, vector<MinibatchPackingFlag>& minibatchPackingFlag);
|
||||
void SetSentenceSegBatch(Matrix<float> & sentenceBegin, vector<MinibatchPackingFlag>& minibatchPackingFlag);
|
||||
|
||||
public:
|
||||
void GetClassInfo(LabelInfo& lblInfo);
|
||||
|
@ -359,7 +359,7 @@ public:
|
|||
/// the second data stream has two sentences, with 0 indicating begining of sentences
|
||||
/// you may use 1 even if a sentence begins at that position, in this case, the trainer will carry over hidden states to the following
|
||||
/// frame.
|
||||
Matrix<ElemType> mtSentenceBegin;
|
||||
Matrix<float> mtSentenceBegin;
|
||||
|
||||
/// a matrix of 1 x n_length
|
||||
/// 1 denotes the case that there exists sentnece begin or no_labels case in this frame
|
||||
|
@ -399,7 +399,7 @@ public:
|
|||
|
||||
void StartMinibatchLoop(size_t mbSize, size_t epoch, size_t requestedEpochSamples);
|
||||
|
||||
void SetSentenceSegBatch(Matrix<ElemType> & sentenceBegin, vector<MinibatchPackingFlag>& minibatchPackingFlag);
|
||||
void SetSentenceSegBatch(Matrix<float> & sentenceBegin, vector<MinibatchPackingFlag>& minibatchPackingFlag);
|
||||
|
||||
size_t NumberSlicesInEachRecurrentIter();
|
||||
|
||||
|
|
|
@ -73,7 +73,7 @@
|
|||
<Link>
|
||||
<SubSystem>Windows</SubSystem>
|
||||
<GenerateDebugInformation>true</GenerateDebugInformation>
|
||||
<AdditionalDependencies>CNTKMath.lib;kernel32.lib;user32.lib;gdi32.lib;winspool.lib;comdlg32.lib;advapi32.lib;shell32.lib;ole32.lib;oleaut32.lib;uuid.lib;odbc32.lib;odbccp32.lib;%(AdditionalDependencies)</AdditionalDependencies>
|
||||
<AdditionalDependencies>CNTKMathDll.lib;kernel32.lib;user32.lib;gdi32.lib;winspool.lib;comdlg32.lib;advapi32.lib;shell32.lib;ole32.lib;oleaut32.lib;uuid.lib;odbc32.lib;odbccp32.lib;%(AdditionalDependencies)</AdditionalDependencies>
|
||||
<AdditionalLibraryDirectories>$(SolutionDir)$(Platform)\$(Configuration)\;..\..\math\$(Platform)\$(Configuration);..\$(Platform)\$(Configuration)</AdditionalLibraryDirectories>
|
||||
</Link>
|
||||
</ItemDefinitionGroup>
|
||||
|
@ -96,7 +96,7 @@
|
|||
<GenerateDebugInformation>true</GenerateDebugInformation>
|
||||
<EnableCOMDATFolding>true</EnableCOMDATFolding>
|
||||
<OptimizeReferences>true</OptimizeReferences>
|
||||
<AdditionalDependencies>CNTKmath.lib;kernel32.lib;user32.lib;gdi32.lib;winspool.lib;comdlg32.lib;advapi32.lib;shell32.lib;ole32.lib;oleaut32.lib;uuid.lib;odbc32.lib;odbccp32.lib;%(AdditionalDependencies)</AdditionalDependencies>
|
||||
<AdditionalDependencies>CNTKMathDll.lib;kernel32.lib;user32.lib;gdi32.lib;winspool.lib;comdlg32.lib;advapi32.lib;shell32.lib;ole32.lib;oleaut32.lib;uuid.lib;odbc32.lib;odbccp32.lib;%(AdditionalDependencies)</AdditionalDependencies>
|
||||
<AdditionalLibraryDirectories>..\..\math\$(Platform)\$(Configuration);$(SolutionDir)$(Platform)\$(Configuration)\</AdditionalLibraryDirectories>
|
||||
<Profile>true</Profile>
|
||||
</Link>
|
||||
|
|
|
@ -145,7 +145,7 @@ public:
|
|||
|
||||
size_t NumberSlicesInEachRecurrentIter() { return 1 ;}
|
||||
void SetNbrSlicesEachRecurrentIter(const size_t) { };
|
||||
void SetSentenceSegBatch(Matrix<ElemType> &, vector<MinibatchPackingFlag>& ){};
|
||||
void SetSentenceSegBatch(Matrix<float> &, vector<MinibatchPackingFlag>&){};
|
||||
virtual const std::map<LabelIdType, LabelType>& GetLabelMapping(const std::wstring& sectionName);
|
||||
virtual void SetLabelMapping(const std::wstring& sectionName, const std::map<LabelIdType, typename LabelType>& labelMapping);
|
||||
virtual bool GetData(const std::wstring& sectionName, size_t numRecords, void* data, size_t& dataBufferSize, size_t recordStart=0);
|
||||
|
|
|
@ -74,7 +74,7 @@
|
|||
<Link>
|
||||
<SubSystem>Windows</SubSystem>
|
||||
<GenerateDebugInformation>true</GenerateDebugInformation>
|
||||
<AdditionalDependencies>CNTKMath.lib;kernel32.lib;user32.lib;gdi32.lib;winspool.lib;comdlg32.lib;advapi32.lib;shell32.lib;ole32.lib;oleaut32.lib;uuid.lib;odbc32.lib;odbccp32.lib;%(AdditionalDependencies)</AdditionalDependencies>
|
||||
<AdditionalDependencies>CNTKMathDll.lib;kernel32.lib;user32.lib;gdi32.lib;winspool.lib;comdlg32.lib;advapi32.lib;shell32.lib;ole32.lib;oleaut32.lib;uuid.lib;odbc32.lib;odbccp32.lib;%(AdditionalDependencies)</AdditionalDependencies>
|
||||
<AdditionalLibraryDirectories>$(SolutionDir)$(Platform)\$(Configuration)\;..\..\math\$(Platform)\$(Configuration);..\$(Platform)\$(Configuration)</AdditionalLibraryDirectories>
|
||||
</Link>
|
||||
</ItemDefinitionGroup>
|
||||
|
@ -97,7 +97,7 @@
|
|||
<GenerateDebugInformation>true</GenerateDebugInformation>
|
||||
<EnableCOMDATFolding>true</EnableCOMDATFolding>
|
||||
<OptimizeReferences>true</OptimizeReferences>
|
||||
<AdditionalDependencies>CNTKmath.lib;kernel32.lib;user32.lib;gdi32.lib;winspool.lib;comdlg32.lib;advapi32.lib;shell32.lib;ole32.lib;oleaut32.lib;uuid.lib;odbc32.lib;odbccp32.lib;%(AdditionalDependencies)</AdditionalDependencies>
|
||||
<AdditionalDependencies>CNTKMathDll.lib;kernel32.lib;user32.lib;gdi32.lib;winspool.lib;comdlg32.lib;advapi32.lib;shell32.lib;ole32.lib;oleaut32.lib;uuid.lib;odbc32.lib;odbccp32.lib;%(AdditionalDependencies)</AdditionalDependencies>
|
||||
<AdditionalLibraryDirectories>..\..\math\$(Platform)\$(Configuration);$(SolutionDir)$(Platform)\$(Configuration)\</AdditionalLibraryDirectories>
|
||||
<Profile>true</Profile>
|
||||
</Link>
|
||||
|
|
|
@ -58,7 +58,7 @@ public:
|
|||
|
||||
size_t NumberSlicesInEachRecurrentIter() { return 1 ;}
|
||||
void SetNbrSlicesEachRecurrentIter(const size_t) { };
|
||||
void SetSentenceSegBatch(Matrix<ElemType> &/*sentenceBegin*/, vector<MinibatchPackingFlag>& /*sentenceExistsBeginOrNoLabels*/) {};
|
||||
void SetSentenceSegBatch(Matrix<float> &/*sentenceBegin*/, vector<MinibatchPackingFlag>& /*sentenceExistsBeginOrNoLabels*/) {};
|
||||
virtual const std::map<LabelIdType, LabelType>& GetLabelMapping(const std::wstring& sectionName);
|
||||
virtual void SetLabelMapping(const std::wstring& sectionName, const std::map<LabelIdType, typename LabelType>& labelMapping);
|
||||
virtual bool GetData(const std::wstring& /*sectionName*/, size_t /*numRecords*/, void* /*data*/, size_t& /*dataBufferSize*/, size_t /*recordStart*/) { throw runtime_error("GetData not supported in SparsePCReader"); };
|
||||
|
|
|
@ -1,151 +1,151 @@
|
|||
<?xml version="1.0" encoding="utf-8"?>
|
||||
<Project DefaultTargets="Build" ToolsVersion="12.0" xmlns="http://schemas.microsoft.com/developer/msbuild/2003">
|
||||
<ItemGroup Label="ProjectConfigurations">
|
||||
<ProjectConfiguration Include="Debug|x64">
|
||||
<Configuration>Debug</Configuration>
|
||||
<Platform>x64</Platform>
|
||||
</ProjectConfiguration>
|
||||
<ProjectConfiguration Include="Release|x64">
|
||||
<Configuration>Release</Configuration>
|
||||
<Platform>x64</Platform>
|
||||
</ProjectConfiguration>
|
||||
</ItemGroup>
|
||||
<PropertyGroup Label="Globals">
|
||||
<ProjectGuid>{CE429AA2-3778-4619-8FD1-49BA3B81197B}</ProjectGuid>
|
||||
<SccProjectName>
|
||||
</SccProjectName>
|
||||
<SccAuxPath>
|
||||
</SccAuxPath>
|
||||
<SccLocalPath>
|
||||
</SccLocalPath>
|
||||
<SccProvider>
|
||||
</SccProvider>
|
||||
<Keyword>Win32Proj</Keyword>
|
||||
<RootNamespace>SparsePCReader</RootNamespace>
|
||||
</PropertyGroup>
|
||||
<Import Project="$(VCTargetsPath)\Microsoft.Cpp.Default.props" />
|
||||
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'" Label="Configuration">
|
||||
<ConfigurationType>DynamicLibrary</ConfigurationType>
|
||||
<UseDebugLibraries>true</UseDebugLibraries>
|
||||
<PlatformToolset>v120</PlatformToolset>
|
||||
<CharacterSet>Unicode</CharacterSet>
|
||||
</PropertyGroup>
|
||||
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'" Label="Configuration">
|
||||
<ConfigurationType>DynamicLibrary</ConfigurationType>
|
||||
<UseDebugLibraries>false</UseDebugLibraries>
|
||||
<PlatformToolset>v120</PlatformToolset>
|
||||
<WholeProgramOptimization>true</WholeProgramOptimization>
|
||||
<CharacterSet>Unicode</CharacterSet>
|
||||
</PropertyGroup>
|
||||
<Import Project="$(VCTargetsPath)\Microsoft.Cpp.props" />
|
||||
<ImportGroup Label="ExtensionSettings">
|
||||
</ImportGroup>
|
||||
<ImportGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'" Label="PropertySheets">
|
||||
<Import Project="$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props" Condition="exists('$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props')" Label="LocalAppDataPlatform" />
|
||||
</ImportGroup>
|
||||
<ImportGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'" Label="PropertySheets">
|
||||
<Import Project="$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props" Condition="exists('$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props')" Label="LocalAppDataPlatform" />
|
||||
</ImportGroup>
|
||||
<PropertyGroup Label="UserMacros" />
|
||||
<PropertyGroup Label="UserMacros" />
|
||||
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">
|
||||
<LinkIncremental>true</LinkIncremental>
|
||||
<IncludePath>..\..\common\include;..\..\math\math;$(VCInstallDir)include;$(VCInstallDir)atlmfc\include;$(WindowsSDK_IncludePath);</IncludePath>
|
||||
<LibraryPath>$(SolutionDir)$(Platform)\$(Configuration);$(VCInstallDir)lib\amd64;$(VCInstallDir)atlmfc\lib\amd64;$(WindowsSDK_LibraryPath_x64);</LibraryPath>
|
||||
<IntDir>$(Platform)\$(Configuration)\$(ProjectName)\</IntDir>
|
||||
</PropertyGroup>
|
||||
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'">
|
||||
<LinkIncremental>false</LinkIncremental>
|
||||
<IncludePath>c:\Program Files\Microsoft MPI\Inc;..\..\common\include;..\..\math\math;$(VCInstallDir)include;$(VCInstallDir)atlmfc\include;$(WindowsSDK_IncludePath);</IncludePath>
|
||||
<LibraryPath>c:\Program Files\Microsoft MPI\Lib\amd64;$(SolutionDir)$(Platform)\$(Configuration);$(VCInstallDir)lib\amd64;$(VCInstallDir)atlmfc\lib\amd64;$(WindowsSDK_LibraryPath_x64);</LibraryPath>
|
||||
<IntDir>$(Platform)\$(Configuration)\$(ProjectName)\</IntDir>
|
||||
</PropertyGroup>
|
||||
<ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">
|
||||
<ClCompile>
|
||||
<PrecompiledHeader>NotUsing</PrecompiledHeader>
|
||||
<WarningLevel>Level4</WarningLevel>
|
||||
<Optimization>Disabled</Optimization>
|
||||
<PreprocessorDefinitions>_CRT_SECURE_NO_WARNINGS;WIN32;_DEBUG;_WINDOWS;_USRDLL;SparsePCREADER_EXPORTS;%(PreprocessorDefinitions)</PreprocessorDefinitions>
|
||||
<SDLCheck>true</SDLCheck>
|
||||
<AdditionalIncludeDirectories>..\..\common\include;..\..\math\math</AdditionalIncludeDirectories>
|
||||
<TreatWarningAsError>true</TreatWarningAsError>
|
||||
<AdditionalOptions>/bigobj %(AdditionalOptions)</AdditionalOptions>
|
||||
</ClCompile>
|
||||
<Link>
|
||||
<SubSystem>Windows</SubSystem>
|
||||
<GenerateDebugInformation>true</GenerateDebugInformation>
|
||||
<AdditionalDependencies>CNTKMath.lib;kernel32.lib;user32.lib;gdi32.lib;winspool.lib;comdlg32.lib;advapi32.lib;shell32.lib;ole32.lib;oleaut32.lib;uuid.lib;odbc32.lib;odbccp32.lib;%(AdditionalDependencies)</AdditionalDependencies>
|
||||
<AdditionalLibraryDirectories>$(SolutionDir)$(Platform)\$(Configuration)\;..\..\math\$(Platform)\$(Configuration);..\$(Platform)\$(Configuration)</AdditionalLibraryDirectories>
|
||||
</Link>
|
||||
</ItemDefinitionGroup>
|
||||
<ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'">
|
||||
<ClCompile>
|
||||
<WarningLevel>Level4</WarningLevel>
|
||||
<PrecompiledHeader>Use</PrecompiledHeader>
|
||||
<Optimization>MaxSpeed</Optimization>
|
||||
<FunctionLevelLinking>true</FunctionLevelLinking>
|
||||
<IntrinsicFunctions>true</IntrinsicFunctions>
|
||||
<PreprocessorDefinitions>_CRT_SECURE_NO_WARNINGS;WIN32;NDEBUG;_WINDOWS;_USRDLL;SparsePCREADER_EXPORTS;%(PreprocessorDefinitions)</PreprocessorDefinitions>
|
||||
<SDLCheck>true</SDLCheck>
|
||||
<AdditionalIncludeDirectories>..\..\common\include;..\..\math\math</AdditionalIncludeDirectories>
|
||||
<OpenMPSupport>false</OpenMPSupport>
|
||||
<AdditionalOptions>/d2Zi+ %(AdditionalOptions)</AdditionalOptions>
|
||||
<TreatWarningAsError>true</TreatWarningAsError>
|
||||
</ClCompile>
|
||||
<Link>
|
||||
<SubSystem>Windows</SubSystem>
|
||||
<GenerateDebugInformation>true</GenerateDebugInformation>
|
||||
<EnableCOMDATFolding>true</EnableCOMDATFolding>
|
||||
<OptimizeReferences>true</OptimizeReferences>
|
||||
<AdditionalDependencies>CNTKmath.lib;kernel32.lib;user32.lib;gdi32.lib;winspool.lib;comdlg32.lib;advapi32.lib;shell32.lib;ole32.lib;oleaut32.lib;uuid.lib;odbc32.lib;odbccp32.lib;%(AdditionalDependencies)</AdditionalDependencies>
|
||||
<AdditionalLibraryDirectories>..\..\math\$(Platform)\$(Configuration);$(SolutionDir)$(Platform)\$(Configuration)\</AdditionalLibraryDirectories>
|
||||
<Profile>true</Profile>
|
||||
</Link>
|
||||
</ItemDefinitionGroup>
|
||||
<ItemGroup>
|
||||
<ClInclude Include="..\..\Common\Include\basetypes.h" />
|
||||
<ClInclude Include="..\..\Common\Include\DataReader.h" />
|
||||
<ClInclude Include="..\..\Common\Include\DataWriter.h">
|
||||
<ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">false</ExcludedFromBuild>
|
||||
</ClInclude>
|
||||
<ClInclude Include="..\..\Common\Include\File.h">
|
||||
<ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">false</ExcludedFromBuild>
|
||||
</ClInclude>
|
||||
<ClInclude Include="..\..\Common\Include\fileutil.h">
|
||||
<ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">false</ExcludedFromBuild>
|
||||
</ClInclude>
|
||||
<ClInclude Include="SparsePCReader.h" />
|
||||
<ClInclude Include="minibatchsourcehelpers.h" />
|
||||
<ClInclude Include="stdafx.h" />
|
||||
<ClInclude Include="targetver.h" />
|
||||
</ItemGroup>
|
||||
<ItemGroup>
|
||||
<ClCompile Include="..\..\Common\ConfigFile.cpp">
|
||||
<PrecompiledHeader Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">NotUsing</PrecompiledHeader>
|
||||
<PrecompiledHeader Condition="'$(Configuration)|$(Platform)'=='Release|x64'">NotUsing</PrecompiledHeader>
|
||||
</ClCompile>
|
||||
<ClCompile Include="..\..\Common\DataReader.cpp" />
|
||||
<ClCompile Include="..\..\Common\DataWriter.cpp">
|
||||
<PrecompiledHeader Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">NotUsing</PrecompiledHeader>
|
||||
</ClCompile>
|
||||
<ClCompile Include="..\..\Common\File.cpp">
|
||||
<PrecompiledHeader Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">NotUsing</PrecompiledHeader>
|
||||
<PrecompiledHeader Condition="'$(Configuration)|$(Platform)'=='Release|x64'">NotUsing</PrecompiledHeader>
|
||||
</ClCompile>
|
||||
<ClCompile Include="..\..\Common\fileutil.cpp">
|
||||
<PrecompiledHeader Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">NotUsing</PrecompiledHeader>
|
||||
<PrecompiledHeader Condition="'$(Configuration)|$(Platform)'=='Release|x64'">NotUsing</PrecompiledHeader>
|
||||
</ClCompile>
|
||||
<ClCompile Include="dllmain.cpp" />
|
||||
<ClCompile Include="SparsePCReader.cpp">
|
||||
<PrecompiledHeader Condition="'$(Configuration)|$(Platform)'=='Release|x64'">Use</PrecompiledHeader>
|
||||
</ClCompile>
|
||||
<ClCompile Include="Exports.cpp" />
|
||||
<ClCompile Include="stdafx.cpp">
|
||||
<PrecompiledHeader Condition="'$(Configuration)|$(Platform)'=='Release|x64'">Create</PrecompiledHeader>
|
||||
</ClCompile>
|
||||
</ItemGroup>
|
||||
<Import Project="$(VCTargetsPath)\Microsoft.Cpp.targets" />
|
||||
<ImportGroup Label="ExtensionTargets">
|
||||
</ImportGroup>
|
||||
<?xml version="1.0" encoding="utf-8"?>
|
||||
<Project DefaultTargets="Build" ToolsVersion="12.0" xmlns="http://schemas.microsoft.com/developer/msbuild/2003">
|
||||
<ItemGroup Label="ProjectConfigurations">
|
||||
<ProjectConfiguration Include="Debug|x64">
|
||||
<Configuration>Debug</Configuration>
|
||||
<Platform>x64</Platform>
|
||||
</ProjectConfiguration>
|
||||
<ProjectConfiguration Include="Release|x64">
|
||||
<Configuration>Release</Configuration>
|
||||
<Platform>x64</Platform>
|
||||
</ProjectConfiguration>
|
||||
</ItemGroup>
|
||||
<PropertyGroup Label="Globals">
|
||||
<ProjectGuid>{CE429AA2-3778-4619-8FD1-49BA3B81197B}</ProjectGuid>
|
||||
<SccProjectName>
|
||||
</SccProjectName>
|
||||
<SccAuxPath>
|
||||
</SccAuxPath>
|
||||
<SccLocalPath>
|
||||
</SccLocalPath>
|
||||
<SccProvider>
|
||||
</SccProvider>
|
||||
<Keyword>Win32Proj</Keyword>
|
||||
<RootNamespace>SparsePCReader</RootNamespace>
|
||||
</PropertyGroup>
|
||||
<Import Project="$(VCTargetsPath)\Microsoft.Cpp.Default.props" />
|
||||
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'" Label="Configuration">
|
||||
<ConfigurationType>DynamicLibrary</ConfigurationType>
|
||||
<UseDebugLibraries>true</UseDebugLibraries>
|
||||
<PlatformToolset>v120</PlatformToolset>
|
||||
<CharacterSet>Unicode</CharacterSet>
|
||||
</PropertyGroup>
|
||||
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'" Label="Configuration">
|
||||
<ConfigurationType>DynamicLibrary</ConfigurationType>
|
||||
<UseDebugLibraries>false</UseDebugLibraries>
|
||||
<PlatformToolset>v120</PlatformToolset>
|
||||
<WholeProgramOptimization>true</WholeProgramOptimization>
|
||||
<CharacterSet>Unicode</CharacterSet>
|
||||
</PropertyGroup>
|
||||
<Import Project="$(VCTargetsPath)\Microsoft.Cpp.props" />
|
||||
<ImportGroup Label="ExtensionSettings">
|
||||
</ImportGroup>
|
||||
<ImportGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'" Label="PropertySheets">
|
||||
<Import Project="$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props" Condition="exists('$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props')" Label="LocalAppDataPlatform" />
|
||||
</ImportGroup>
|
||||
<ImportGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'" Label="PropertySheets">
|
||||
<Import Project="$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props" Condition="exists('$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props')" Label="LocalAppDataPlatform" />
|
||||
</ImportGroup>
|
||||
<PropertyGroup Label="UserMacros" />
|
||||
<PropertyGroup Label="UserMacros" />
|
||||
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">
|
||||
<LinkIncremental>true</LinkIncremental>
|
||||
<IncludePath>..\..\common\include;..\..\math\math;$(VCInstallDir)include;$(VCInstallDir)atlmfc\include;$(WindowsSDK_IncludePath);</IncludePath>
|
||||
<LibraryPath>$(SolutionDir)$(Platform)\$(Configuration);$(VCInstallDir)lib\amd64;$(VCInstallDir)atlmfc\lib\amd64;$(WindowsSDK_LibraryPath_x64);</LibraryPath>
|
||||
<IntDir>$(Platform)\$(Configuration)\$(ProjectName)\</IntDir>
|
||||
</PropertyGroup>
|
||||
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'">
|
||||
<LinkIncremental>false</LinkIncremental>
|
||||
<IncludePath>c:\Program Files\Microsoft MPI\Inc;..\..\common\include;..\..\math\math;$(VCInstallDir)include;$(VCInstallDir)atlmfc\include;$(WindowsSDK_IncludePath);</IncludePath>
|
||||
<LibraryPath>c:\Program Files\Microsoft MPI\Lib\amd64;$(SolutionDir)$(Platform)\$(Configuration);$(VCInstallDir)lib\amd64;$(VCInstallDir)atlmfc\lib\amd64;$(WindowsSDK_LibraryPath_x64);</LibraryPath>
|
||||
<IntDir>$(Platform)\$(Configuration)\$(ProjectName)\</IntDir>
|
||||
</PropertyGroup>
|
||||
<ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">
|
||||
<ClCompile>
|
||||
<PrecompiledHeader>NotUsing</PrecompiledHeader>
|
||||
<WarningLevel>Level4</WarningLevel>
|
||||
<Optimization>Disabled</Optimization>
|
||||
<PreprocessorDefinitions>_CRT_SECURE_NO_WARNINGS;WIN32;_DEBUG;_WINDOWS;_USRDLL;SparsePCREADER_EXPORTS;%(PreprocessorDefinitions)</PreprocessorDefinitions>
|
||||
<SDLCheck>true</SDLCheck>
|
||||
<AdditionalIncludeDirectories>..\..\common\include;..\..\math\math</AdditionalIncludeDirectories>
|
||||
<TreatWarningAsError>true</TreatWarningAsError>
|
||||
<AdditionalOptions>/bigobj %(AdditionalOptions)</AdditionalOptions>
|
||||
</ClCompile>
|
||||
<Link>
|
||||
<SubSystem>Windows</SubSystem>
|
||||
<GenerateDebugInformation>true</GenerateDebugInformation>
|
||||
<AdditionalDependencies>CNTKMathDll.lib;kernel32.lib;user32.lib;gdi32.lib;winspool.lib;comdlg32.lib;advapi32.lib;shell32.lib;ole32.lib;oleaut32.lib;uuid.lib;odbc32.lib;odbccp32.lib;%(AdditionalDependencies)</AdditionalDependencies>
|
||||
<AdditionalLibraryDirectories>$(SolutionDir)$(Platform)\$(Configuration)\;..\..\math\$(Platform)\$(Configuration);..\$(Platform)\$(Configuration)</AdditionalLibraryDirectories>
|
||||
</Link>
|
||||
</ItemDefinitionGroup>
|
||||
<ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'">
|
||||
<ClCompile>
|
||||
<WarningLevel>Level4</WarningLevel>
|
||||
<PrecompiledHeader>Use</PrecompiledHeader>
|
||||
<Optimization>MaxSpeed</Optimization>
|
||||
<FunctionLevelLinking>true</FunctionLevelLinking>
|
||||
<IntrinsicFunctions>true</IntrinsicFunctions>
|
||||
<PreprocessorDefinitions>_CRT_SECURE_NO_WARNINGS;WIN32;NDEBUG;_WINDOWS;_USRDLL;SparsePCREADER_EXPORTS;%(PreprocessorDefinitions)</PreprocessorDefinitions>
|
||||
<SDLCheck>true</SDLCheck>
|
||||
<AdditionalIncludeDirectories>..\..\common\include;..\..\math\math</AdditionalIncludeDirectories>
|
||||
<OpenMPSupport>false</OpenMPSupport>
|
||||
<AdditionalOptions>/d2Zi+ %(AdditionalOptions)</AdditionalOptions>
|
||||
<TreatWarningAsError>true</TreatWarningAsError>
|
||||
</ClCompile>
|
||||
<Link>
|
||||
<SubSystem>Windows</SubSystem>
|
||||
<GenerateDebugInformation>true</GenerateDebugInformation>
|
||||
<EnableCOMDATFolding>true</EnableCOMDATFolding>
|
||||
<OptimizeReferences>true</OptimizeReferences>
|
||||
<AdditionalDependencies>CNTKMathDll.lib;kernel32.lib;user32.lib;gdi32.lib;winspool.lib;comdlg32.lib;advapi32.lib;shell32.lib;ole32.lib;oleaut32.lib;uuid.lib;odbc32.lib;odbccp32.lib;%(AdditionalDependencies)</AdditionalDependencies>
|
||||
<AdditionalLibraryDirectories>..\..\math\$(Platform)\$(Configuration);$(SolutionDir)$(Platform)\$(Configuration)\</AdditionalLibraryDirectories>
|
||||
<Profile>true</Profile>
|
||||
</Link>
|
||||
</ItemDefinitionGroup>
|
||||
<ItemGroup>
|
||||
<ClInclude Include="..\..\Common\Include\basetypes.h" />
|
||||
<ClInclude Include="..\..\Common\Include\DataReader.h" />
|
||||
<ClInclude Include="..\..\Common\Include\DataWriter.h">
|
||||
<ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">false</ExcludedFromBuild>
|
||||
</ClInclude>
|
||||
<ClInclude Include="..\..\Common\Include\File.h">
|
||||
<ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">false</ExcludedFromBuild>
|
||||
</ClInclude>
|
||||
<ClInclude Include="..\..\Common\Include\fileutil.h">
|
||||
<ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">false</ExcludedFromBuild>
|
||||
</ClInclude>
|
||||
<ClInclude Include="SparsePCReader.h" />
|
||||
<ClInclude Include="minibatchsourcehelpers.h" />
|
||||
<ClInclude Include="stdafx.h" />
|
||||
<ClInclude Include="targetver.h" />
|
||||
</ItemGroup>
|
||||
<ItemGroup>
|
||||
<ClCompile Include="..\..\Common\ConfigFile.cpp">
|
||||
<PrecompiledHeader Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">NotUsing</PrecompiledHeader>
|
||||
<PrecompiledHeader Condition="'$(Configuration)|$(Platform)'=='Release|x64'">NotUsing</PrecompiledHeader>
|
||||
</ClCompile>
|
||||
<ClCompile Include="..\..\Common\DataReader.cpp" />
|
||||
<ClCompile Include="..\..\Common\DataWriter.cpp">
|
||||
<PrecompiledHeader Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">NotUsing</PrecompiledHeader>
|
||||
</ClCompile>
|
||||
<ClCompile Include="..\..\Common\File.cpp">
|
||||
<PrecompiledHeader Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">NotUsing</PrecompiledHeader>
|
||||
<PrecompiledHeader Condition="'$(Configuration)|$(Platform)'=='Release|x64'">NotUsing</PrecompiledHeader>
|
||||
</ClCompile>
|
||||
<ClCompile Include="..\..\Common\fileutil.cpp">
|
||||
<PrecompiledHeader Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">NotUsing</PrecompiledHeader>
|
||||
<PrecompiledHeader Condition="'$(Configuration)|$(Platform)'=='Release|x64'">NotUsing</PrecompiledHeader>
|
||||
</ClCompile>
|
||||
<ClCompile Include="dllmain.cpp" />
|
||||
<ClCompile Include="SparsePCReader.cpp">
|
||||
<PrecompiledHeader Condition="'$(Configuration)|$(Platform)'=='Release|x64'">Use</PrecompiledHeader>
|
||||
</ClCompile>
|
||||
<ClCompile Include="Exports.cpp" />
|
||||
<ClCompile Include="stdafx.cpp">
|
||||
<PrecompiledHeader Condition="'$(Configuration)|$(Platform)'=='Release|x64'">Create</PrecompiledHeader>
|
||||
</ClCompile>
|
||||
</ItemGroup>
|
||||
<Import Project="$(VCTargetsPath)\Microsoft.Cpp.targets" />
|
||||
<ImportGroup Label="ExtensionTargets">
|
||||
</ImportGroup>
|
||||
</Project>
|
|
@ -112,13 +112,13 @@ public:
|
|||
virtual bool GetMinibatch(std::map<std::wstring, Matrix<ElemType>*>& matrices);
|
||||
|
||||
size_t NumberSlicesInEachRecurrentIter() { return mBlgSize; }
|
||||
void SetSentenceSegBatch(Matrix<ElemType> &, vector<MinibatchPackingFlag>& ){};
|
||||
void SetSentenceSegBatch(Matrix<float> &, vector<MinibatchPackingFlag>&){};
|
||||
virtual const std::map<LabelIdType, LabelType>& GetLabelMapping(const std::wstring& sectionName);
|
||||
virtual void SetLabelMapping(const std::wstring& sectionName, const std::map<LabelIdType, LabelType>& labelMapping);
|
||||
virtual bool GetData(const std::wstring& sectionName, size_t numRecords, void* data, size_t& dataBufferSize, size_t recordStart=0);
|
||||
|
||||
virtual bool DataEnd(EndDataType endDataType);
|
||||
void SetSentenceSegBatch(Matrix<ElemType>&, Matrix<ElemType>&) { };
|
||||
void SetSentenceSegBatch(Matrix<float>&, Matrix<ElemType>&) { };
|
||||
|
||||
void SetNbrSlicesEachRecurrentIter(const size_t sz);
|
||||
|
||||
|
|
|
@ -72,7 +72,7 @@
|
|||
<Link>
|
||||
<SubSystem>Windows</SubSystem>
|
||||
<GenerateDebugInformation>true</GenerateDebugInformation>
|
||||
<AdditionalDependencies>CNTKMath.lib;kernel32.lib;user32.lib;gdi32.lib;winspool.lib;comdlg32.lib;advapi32.lib;shell32.lib;ole32.lib;oleaut32.lib;uuid.lib;odbc32.lib;odbccp32.lib;%(AdditionalDependencies)</AdditionalDependencies>
|
||||
<AdditionalDependencies>CNTKMathDll.lib;kernel32.lib;user32.lib;gdi32.lib;winspool.lib;comdlg32.lib;advapi32.lib;shell32.lib;ole32.lib;oleaut32.lib;uuid.lib;odbc32.lib;odbccp32.lib;%(AdditionalDependencies)</AdditionalDependencies>
|
||||
<AdditionalLibraryDirectories>$(SolutionDir)$(Platform)\$(Configuration)\;..\..\math\$(Platform)\$(Configuration);..\$(Platform)\$(Configuration)</AdditionalLibraryDirectories>
|
||||
</Link>
|
||||
</ItemDefinitionGroup>
|
||||
|
@ -95,7 +95,7 @@
|
|||
<GenerateDebugInformation>true</GenerateDebugInformation>
|
||||
<EnableCOMDATFolding>true</EnableCOMDATFolding>
|
||||
<OptimizeReferences>true</OptimizeReferences>
|
||||
<AdditionalDependencies>CNTKmath.lib;kernel32.lib;user32.lib;gdi32.lib;winspool.lib;comdlg32.lib;advapi32.lib;shell32.lib;ole32.lib;oleaut32.lib;uuid.lib;odbc32.lib;odbccp32.lib;%(AdditionalDependencies)</AdditionalDependencies>
|
||||
<AdditionalDependencies>CNTKMathDll.lib;kernel32.lib;user32.lib;gdi32.lib;winspool.lib;comdlg32.lib;advapi32.lib;shell32.lib;ole32.lib;oleaut32.lib;uuid.lib;odbc32.lib;odbccp32.lib;%(AdditionalDependencies)</AdditionalDependencies>
|
||||
<AdditionalLibraryDirectories>..\..\math\$(Platform)\$(Configuration);$(SolutionDir)$(Platform)\$(Configuration)\</AdditionalLibraryDirectories>
|
||||
<Profile>true</Profile>
|
||||
</Link>
|
||||
|
|
Некоторые файлы не были показаны из-за слишком большого количества измененных файлов Показать больше
Загрузка…
Ссылка в новой задаче